diff options
Diffstat (limited to 'synapse/types.py')
-rw-r--r-- | synapse/types.py | 103 |
1 files changed, 58 insertions, 45 deletions
diff --git a/synapse/types.py b/synapse/types.py index 9ac688b2..0586d2cb 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -24,6 +24,7 @@ from typing import ( Mapping, Match, MutableMapping, + NoReturn, Optional, Set, Tuple, @@ -35,7 +36,8 @@ from typing import ( import attr from frozendict import frozendict from signedjson.key import decode_verify_key_bytes -from typing_extensions import TypedDict +from signedjson.types import VerifyKey +from typing_extensions import Final, TypedDict from unpaddedbase64 import decode_base64 from zope.interface import Interface @@ -55,6 +57,7 @@ from synapse.util.stringutils import parse_and_validate_server_name if TYPE_CHECKING: from synapse.appservice.api import ApplicationService from synapse.storage.databases.main import DataStore, PurgeEventsStore + from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore # Define a state map type from type/state_key to T (usually an event ID or # event) @@ -114,7 +117,7 @@ class Requester: app_service: Optional["ApplicationService"] authenticated_entity: str - def serialize(self): + def serialize(self) -> Dict[str, Any]: """Converts self to a type that can be serialized as JSON, and then deserialized by `deserialize` @@ -132,7 +135,9 @@ class Requester: } @staticmethod - def deserialize(store, input): + def deserialize( + store: "ApplicationServiceWorkerStore", input: Dict[str, Any] + ) -> "Requester": """Converts a dict that was produced by `serialize` back into a Requester. @@ -236,10 +241,10 @@ class DomainSpecificString(metaclass=abc.ABCMeta): domain: str # Because this is a frozen class, it is deeply immutable. - def __copy__(self): + def __copy__(self: DS) -> DS: return self - def __deepcopy__(self, memo): + def __deepcopy__(self: DS, memo: Dict[str, object]) -> DS: return self @classmethod @@ -315,29 +320,6 @@ class EventID(DomainSpecificString): SIGIL = "$" -@attr.s(slots=True, frozen=True, repr=False) -class GroupID(DomainSpecificString): - """Structure representing a group ID.""" - - SIGIL = "+" - - @classmethod - def from_string(cls: Type[DS], s: str) -> DS: - group_id: DS = super().from_string(s) # type: ignore - - if not group_id.localpart: - raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM) - - if contains_invalid_mxid_characters(group_id.localpart): - raise SynapseError( - 400, - "Group ID can only contain characters a-z, 0-9, or '=_-./'", - Codes.INVALID_PARAM, - ) - - return group_id - - mxid_localpart_allowed_characters = set( "_-./=" + string.ascii_lowercase + string.digits ) @@ -625,6 +607,22 @@ class RoomStreamToken: return "s%d" % (self.stream,) +class StreamKeyType: + """Known stream types. + + A stream is a list of entities ordered by an incrementing "stream token". + """ + + ROOM: Final = "room_key" + PRESENCE: Final = "presence_key" + TYPING: Final = "typing_key" + RECEIPT: Final = "receipt_key" + ACCOUNT_DATA: Final = "account_data_key" + PUSH_RULES: Final = "push_rules_key" + TO_DEVICE: Final = "to_device_key" + DEVICE_LIST: Final = "device_list_key" + + @attr.s(slots=True, frozen=True, auto_attribs=True) class StreamToken: """A collection of keys joined together by underscores in the following @@ -641,7 +639,7 @@ class StreamToken: 6. `push_rules_key`: `541479` 7. `to_device_key`: `274711` 8. `device_list_key`: `265584` - 9. `groups_key`: `1` + 9. `groups_key`: `1` (note that this key is now unused) You can see how many of these keys correspond to the various fields in a "/sync" response: @@ -693,6 +691,7 @@ class StreamToken: push_rules_key: int to_device_key: int device_list_key: int + # Note that the groups key is no longer used and may have bogus values. groups_key: int _SEPARATOR = "_" @@ -724,21 +723,26 @@ class StreamToken: str(self.push_rules_key), str(self.to_device_key), str(self.device_list_key), + # Note that the groups key is no longer used, but it is still + # serialized so that there will not be confusion in the future + # if additional tokens are added. str(self.groups_key), ] ) @property - def room_stream_id(self): + def room_stream_id(self) -> int: return self.room_key.stream - def copy_and_advance(self, key, new_value) -> "StreamToken": + def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken": """Advance the given key in the token to a new value if and only if the new value is after the old value. + + :raises TypeError: if `key` is not the one of the keys tracked by a StreamToken. """ - if key == "room_key": + if key == StreamKeyType.ROOM: new_token = self.copy_and_replace( - "room_key", self.room_key.copy_and_advance(new_value) + StreamKeyType.ROOM, self.room_key.copy_and_advance(new_value) ) return new_token @@ -751,7 +755,7 @@ class StreamToken: else: return self - def copy_and_replace(self, key, new_value) -> "StreamToken": + def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken": return attr.evolve(self, **{key: new_value}) @@ -793,14 +797,14 @@ class ThirdPartyInstanceID: # Deny iteration because it will bite you if you try to create a singleton # set by: # users = set(user) - def __iter__(self): + def __iter__(self) -> NoReturn: raise ValueError("Attempted to iterate a %s" % (type(self).__name__,)) # Because this class is a frozen class, it is deeply immutable. - def __copy__(self): + def __copy__(self) -> "ThirdPartyInstanceID": return self - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: Dict[str, object]) -> "ThirdPartyInstanceID": return self @classmethod @@ -852,25 +856,28 @@ class DeviceListUpdates: return bool(self.changed or self.left) -def get_verify_key_from_cross_signing_key(key_info): +def get_verify_key_from_cross_signing_key( + key_info: Mapping[str, Any] +) -> Tuple[str, VerifyKey]: """Get the key ID and signedjson verify key from a cross-signing key dict Args: - key_info (dict): a cross-signing key dict, which must have a "keys" + key_info: a cross-signing key dict, which must have a "keys" property that has exactly one item in it Returns: - (str, VerifyKey): the key ID and verify key for the cross-signing key + the key ID and verify key for the cross-signing key """ - # make sure that exactly one key is provided + # make sure that a `keys` field is provided if "keys" not in key_info: raise ValueError("Invalid key") keys = key_info["keys"] - if len(keys) != 1: - raise ValueError("Invalid key") - # and return that one key - for key_id, key_data in keys.items(): + # and that it contains exactly one key + if len(keys) == 1: + key_id, key_data = next(iter(keys.items())) return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data)) + else: + raise ValueError("Invalid key") @attr.s(auto_attribs=True, frozen=True, slots=True) @@ -906,3 +913,9 @@ class UserProfile(TypedDict): user_id: str display_name: Optional[str] avatar_url: Optional[str] + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class RetentionPolicy: + min_lifetime: Optional[int] = None + max_lifetime: Optional[int] = None |