diff options
author | Andrej Shadura <andrewsh@debian.org> | 2021-06-29 12:59:58 +0200 |
---|---|---|
committer | Andrej Shadura <andrewsh@debian.org> | 2021-06-29 12:59:58 +0200 |
commit | 364c37238258580e132178cc7b35acabce3ff326 (patch) | |
tree | dadc23431f7f55fd8bcd8780c64b519dae7d5a76 /synapse/handlers | |
parent | 219af4a8aef838c5e3689a2aa71cf72f2fd75aa2 (diff) |
New upstream version 1.37.0
Diffstat (limited to 'synapse/handlers')
-rw-r--r-- | synapse/handlers/acme.py | 117 | ||||
-rw-r--r-- | synapse/handlers/acme_issuing_service.py | 127 | ||||
-rw-r--r-- | synapse/handlers/auth.py | 7 | ||||
-rw-r--r-- | synapse/handlers/e2e_keys.py | 350 | ||||
-rw-r--r-- | synapse/handlers/event_auth.py | 45 | ||||
-rw-r--r-- | synapse/handlers/federation.py | 214 | ||||
-rw-r--r-- | synapse/handlers/message.py | 134 | ||||
-rw-r--r-- | synapse/handlers/register.py | 2 | ||||
-rw-r--r-- | synapse/handlers/room_list.py | 7 | ||||
-rw-r--r-- | synapse/handlers/room_member.py | 284 | ||||
-rw-r--r-- | synapse/handlers/room_member_worker.py | 55 | ||||
-rw-r--r-- | synapse/handlers/space_summary.py | 45 | ||||
-rw-r--r-- | synapse/handlers/sso.py | 25 | ||||
-rw-r--r-- | synapse/handlers/stats.py | 7 | ||||
-rw-r--r-- | synapse/handlers/sync.py | 123 |
15 files changed, 991 insertions, 551 deletions
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py deleted file mode 100644 index 16ab93f5..00000000 --- a/synapse/handlers/acme.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright 2019 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import TYPE_CHECKING - -import twisted -import twisted.internet.error -from twisted.web import server, static -from twisted.web.resource import Resource - -from synapse.app import check_bind_error - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - -ACME_REGISTER_FAIL_ERROR = """ --------------------------------------------------------------------------------- -Failed to register with the ACME provider. This is likely happening because the installation -is new, and ACME v1 has been deprecated by Let's Encrypt and disabled for -new installations since November 2019. -At the moment, Synapse doesn't support ACME v2. For more information and alternative -solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/ACME.md#deprecation-of-acme-v1 ---------------------------------------------------------------------------------""" - - -class AcmeHandler: - def __init__(self, hs: "HomeServer"): - self.hs = hs - self.reactor = hs.get_reactor() - self._acme_domain = hs.config.acme_domain - - async def start_listening(self) -> None: - from synapse.handlers import acme_issuing_service - - # Configure logging for txacme, if you need to debug - # from eliot import add_destinations - # from eliot.twisted import TwistedDestination - # - # add_destinations(TwistedDestination()) - - well_known = Resource() - - self._issuer = acme_issuing_service.create_issuing_service( - self.reactor, - acme_url=self.hs.config.acme_url, - account_key_file=self.hs.config.acme_account_key_file, - well_known_resource=well_known, - ) - - responder_resource = Resource() - responder_resource.putChild(b".well-known", well_known) - responder_resource.putChild(b"check", static.Data(b"OK", b"text/plain")) - srv = server.Site(responder_resource) - - bind_addresses = self.hs.config.acme_bind_addresses - for host in bind_addresses: - logger.info( - "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port - ) - try: - self.reactor.listenTCP( - self.hs.config.acme_port, srv, backlog=50, interface=host - ) - except twisted.internet.error.CannotListenError as e: - check_bind_error(e, host, bind_addresses) - - # Make sure we are registered to the ACME server. There's no public API - # for this, it is usually triggered by startService, but since we don't - # want it to control where we save the certificates, we have to reach in - # and trigger the registration machinery ourselves. - self._issuer._registered = False - - try: - await self._issuer._ensure_registered() - except Exception: - logger.error(ACME_REGISTER_FAIL_ERROR) - raise - - async def provision_certificate(self) -> None: - - logger.warning("Reprovisioning %s", self._acme_domain) - - try: - await self._issuer.issue_cert(self._acme_domain) - except Exception: - logger.exception("Fail!") - raise - logger.warning("Reprovisioned %s, saving.", self._acme_domain) - cert_chain = self._issuer.cert_store.certs[self._acme_domain] - - try: - with open(self.hs.config.tls_private_key_file, "wb") as private_key_file: - for x in cert_chain: - if x.startswith(b"-----BEGIN RSA PRIVATE KEY-----"): - private_key_file.write(x) - - with open(self.hs.config.tls_certificate_file, "wb") as certificate_file: - for x in cert_chain: - if x.startswith(b"-----BEGIN CERTIFICATE-----"): - certificate_file.write(x) - except Exception: - logger.exception("Failed saving!") - raise diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py deleted file mode 100644 index a972d3fa..00000000 --- a/synapse/handlers/acme_issuing_service.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright 2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Utility function to create an ACME issuing service. - -This file contains the unconditional imports on the acme and cryptography bits that we -only need (and may only have available) if we are doing ACME, so is designed to be -imported conditionally. -""" -import logging -from typing import Dict, Iterable, List - -import attr -import pem -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import serialization -from josepy import JWKRSA -from josepy.jwa import RS256 -from txacme.challenges import HTTP01Responder -from txacme.client import Client -from txacme.interfaces import ICertificateStore -from txacme.service import AcmeIssuingService -from txacme.util import generate_private_key -from zope.interface import implementer - -from twisted.internet import defer -from twisted.internet.interfaces import IReactorTCP -from twisted.python.filepath import FilePath -from twisted.python.url import URL -from twisted.web.resource import IResource - -logger = logging.getLogger(__name__) - - -def create_issuing_service( - reactor: IReactorTCP, - acme_url: str, - account_key_file: str, - well_known_resource: IResource, -) -> AcmeIssuingService: - """Create an ACME issuing service, and attach it to a web Resource - - Args: - reactor: twisted reactor - acme_url: URL to use to request certificates - account_key_file: where to store the account key - well_known_resource: web resource for .well-known. - we will attach a child resource for "acme-challenge". - - Returns: - AcmeIssuingService - """ - responder = HTTP01Responder() - - well_known_resource.putChild(b"acme-challenge", responder.resource) - - store = ErsatzStore() - - return AcmeIssuingService( - cert_store=store, - client_creator=( - lambda: Client.from_url( - reactor=reactor, - url=URL.from_text(acme_url), - key=load_or_create_client_key(account_key_file), - alg=RS256, - ) - ), - clock=reactor, - responders=[responder], - ) - - -@attr.s(slots=True) -@implementer(ICertificateStore) -class ErsatzStore: - """ - A store that only stores in memory. - """ - - certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict)) - - def store( - self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject] - ) -> defer.Deferred: - self.certs[server_name] = [o.as_bytes() for o in pem_objects] - return defer.succeed(None) - - -def load_or_create_client_key(key_file: str) -> JWKRSA: - """Load the ACME account key from a file, creating it if it does not exist. - - Args: - key_file: name of the file to use as the account key - """ - # this is based on txacme.endpoint.load_or_create_client_key, but doesn't - # hardcode the 'client.key' filename - acme_key_file = FilePath(key_file) - if acme_key_file.exists(): - logger.info("Loading ACME account key from '%s'", acme_key_file) - key = serialization.load_pem_private_key( - acme_key_file.getContent(), password=None, backend=default_backend() - ) - else: - logger.info("Saving new ACME account key to '%s'", acme_key_file) - key = generate_private_key("rsa") - acme_key_file.setContent( - key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption(), - ) - ) - return JWKRSA(key=key) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 8a6666a4..1971e373 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -302,6 +302,7 @@ class AuthHandler(BaseHandler): request: SynapseRequest, request_body: Dict[str, Any], description: str, + can_skip_ui_auth: bool = False, ) -> Tuple[dict, Optional[str]]: """ Checks that the user is who they claim to be, via a UI auth. @@ -320,6 +321,10 @@ class AuthHandler(BaseHandler): description: A human readable string to be displayed to the user that describes the operation happening on their account. + can_skip_ui_auth: True if the UI auth session timeout applies this + action. Should be set to False for any "dangerous" + actions (e.g. deactivating an account). + Returns: A tuple of (params, session_id). @@ -343,7 +348,7 @@ class AuthHandler(BaseHandler): """ if not requester.access_token_id: raise ValueError("Cannot validate a user without an access token") - if self._ui_auth_session_timeout: + if can_skip_ui_auth and self._ui_auth_session_timeout: last_validated = await self.store.get_access_token_last_validated( requester.access_token_id ) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 97448780..3972849d 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -79,9 +79,15 @@ class E2eKeysHandler: "client_keys", self.on_federation_query_client_keys ) + # Limit the number of in-flight requests from a single device. + self._query_devices_linearizer = Linearizer( + name="query_devices", + max_count=10, + ) + @trace async def query_devices( - self, query_body: JsonDict, timeout: int, from_user_id: str + self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str ) -> JsonDict: """Handle a device key query from a client @@ -105,191 +111,197 @@ class E2eKeysHandler: from_user_id: the user making the query. This is used when adding cross-signing signatures to limit what signatures users can see. + from_device_id: the device making the query. This is used to limit + the number of in-flight queries at a time. """ - - device_keys_query = query_body.get( - "device_keys", {} - ) # type: Dict[str, Iterable[str]] - - # separate users by domain. - # make a map from domain to user_id to device_ids - local_query = {} - remote_queries = {} - - for user_id, device_ids in device_keys_query.items(): - # we use UserID.from_string to catch invalid user ids - if self.is_mine(UserID.from_string(user_id)): - local_query[user_id] = device_ids - else: - remote_queries[user_id] = device_ids - - set_tag("local_key_query", local_query) - set_tag("remote_key_query", remote_queries) - - # First get local devices. - # A map of destination -> failure response. - failures = {} # type: Dict[str, JsonDict] - results = {} - if local_query: - local_result = await self.query_local_devices(local_query) - for user_id, keys in local_result.items(): - if user_id in local_query: - results[user_id] = keys - - # Get cached cross-signing keys - cross_signing_keys = await self.get_cross_signing_keys_from_cache( - device_keys_query, from_user_id - ) - - # Now attempt to get any remote devices from our local cache. - # A map of destination -> user ID -> device IDs. - remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]] - if remote_queries: - query_list = [] # type: List[Tuple[str, Optional[str]]] - for user_id, device_ids in remote_queries.items(): - if device_ids: - query_list.extend((user_id, device_id) for device_id in device_ids) + with await self._query_devices_linearizer.queue((from_user_id, from_device_id)): + device_keys_query = query_body.get( + "device_keys", {} + ) # type: Dict[str, Iterable[str]] + + # separate users by domain. + # make a map from domain to user_id to device_ids + local_query = {} + remote_queries = {} + + for user_id, device_ids in device_keys_query.items(): + # we use UserID.from_string to catch invalid user ids + if self.is_mine(UserID.from_string(user_id)): + local_query[user_id] = device_ids else: - query_list.append((user_id, None)) - - ( - user_ids_not_in_cache, - remote_results, - ) = await self.store.get_user_devices_from_cache(query_list) - for user_id, devices in remote_results.items(): - user_devices = results.setdefault(user_id, {}) - for device_id, device in devices.items(): - keys = device.get("keys", None) - device_display_name = device.get("device_display_name", None) - if keys: - result = dict(keys) - unsigned = result.setdefault("unsigned", {}) - if device_display_name: - unsigned["device_display_name"] = device_display_name - user_devices[device_id] = result - - # check for missing cross-signing keys. - for user_id in remote_queries.keys(): - cached_cross_master = user_id in cross_signing_keys["master_keys"] - cached_cross_selfsigning = ( - user_id in cross_signing_keys["self_signing_keys"] - ) - - # check if we are missing only one of cross-signing master or - # self-signing key, but the other one is cached. - # as we need both, this will issue a federation request. - # if we don't have any of the keys, either the user doesn't have - # cross-signing set up, or the cached device list - # is not (yet) updated. - if cached_cross_master ^ cached_cross_selfsigning: - user_ids_not_in_cache.add(user_id) - - # add those users to the list to fetch over federation. - for user_id in user_ids_not_in_cache: - domain = get_domain_from_id(user_id) - r = remote_queries_not_in_cache.setdefault(domain, {}) - r[user_id] = remote_queries[user_id] - - # Now fetch any devices that we don't have in our cache - @trace - async def do_remote_query(destination): - """This is called when we are querying the device list of a user on - a remote homeserver and their device list is not in the device list - cache. If we share a room with this user and we're not querying for - specific user we will update the cache with their device list. - """ - - destination_query = remote_queries_not_in_cache[destination] - - # We first consider whether we wish to update the device list cache with - # the users device list. We want to track a user's devices when the - # authenticated user shares a room with the queried user and the query - # has not specified a particular device. - # If we update the cache for the queried user we remove them from further - # queries. We use the more efficient batched query_client_keys for all - # remaining users - user_ids_updated = [] - for (user_id, device_list) in destination_query.items(): - if user_id in user_ids_updated: - continue - - if device_list: - continue + remote_queries[user_id] = device_ids + + set_tag("local_key_query", local_query) + set_tag("remote_key_query", remote_queries) + + # First get local devices. + # A map of destination -> failure response. + failures = {} # type: Dict[str, JsonDict] + results = {} + if local_query: + local_result = await self.query_local_devices(local_query) + for user_id, keys in local_result.items(): + if user_id in local_query: + results[user_id] = keys - room_ids = await self.store.get_rooms_for_user(user_id) - if not room_ids: - continue + # Get cached cross-signing keys + cross_signing_keys = await self.get_cross_signing_keys_from_cache( + device_keys_query, from_user_id + ) - # We've decided we're sharing a room with this user and should - # probably be tracking their device lists. However, we haven't - # done an initial sync on the device list so we do it now. - try: - if self._is_master: - user_devices = await self.device_handler.device_list_updater.user_device_resync( - user_id + # Now attempt to get any remote devices from our local cache. + # A map of destination -> user ID -> device IDs. + remote_queries_not_in_cache = ( + {} + ) # type: Dict[str, Dict[str, Iterable[str]]] + if remote_queries: + query_list = [] # type: List[Tuple[str, Optional[str]]] + for user_id, device_ids in remote_queries.items(): + if device_ids: + query_list.extend( + (user_id, device_id) for device_id in device_ids ) else: - user_devices = await self._user_device_resync_client( - user_id=user_id - ) - - user_devices = user_devices["devices"] - user_results = results.setdefault(user_id, {}) - for device in user_devices: - user_results[device["device_id"]] = device["keys"] - user_ids_updated.append(user_id) - except Exception as e: - failures[destination] = _exception_to_failure(e) - - if len(destination_query) == len(user_ids_updated): - # We've updated all the users in the query and we do not need to - # make any further remote calls. - return + query_list.append((user_id, None)) - # Remove all the users from the query which we have updated - for user_id in user_ids_updated: - destination_query.pop(user_id) + ( + user_ids_not_in_cache, + remote_results, + ) = await self.store.get_user_devices_from_cache(query_list) + for user_id, devices in remote_results.items(): + user_devices = results.setdefault(user_id, {}) + for device_id, device in devices.items(): + keys = device.get("keys", None) + device_display_name = device.get("device_display_name", None) + if keys: + result = dict(keys) + unsigned = result.setdefault("unsigned", {}) + if device_display_name: + unsigned["device_display_name"] = device_display_name + user_devices[device_id] = result + + # check for missing cross-signing keys. + for user_id in remote_queries.keys(): + cached_cross_master = user_id in cross_signing_keys["master_keys"] + cached_cross_selfsigning = ( + user_id in cross_signing_keys["self_signing_keys"] + ) - try: - remote_result = await self.federation.query_client_keys( - destination, {"device_keys": destination_query}, timeout=timeout - ) + # check if we are missing only one of cross-signing master or + # self-signing key, but the other one is cached. + # as we need both, this will issue a federation request. + # if we don't have any of the keys, either the user doesn't have + # cross-signing set up, or the cached device list + # is not (yet) updated. + if cached_cross_master ^ cached_cross_selfsigning: + user_ids_not_in_cache.add(user_id) + + # add those users to the list to fetch over federation. + for user_id in user_ids_not_in_cache: + domain = get_domain_from_id(user_id) + r = remote_queries_not_in_cache.setdefault(domain, {}) + r[user_id] = remote_queries[user_id] + + # Now fetch any devices that we don't have in our cache + @trace + async def do_remote_query(destination): + """This is called when we are querying the device list of a user on + a remote homeserver and their device list is not in the device list + cache. If we share a room with this user and we're not querying for + specific user we will update the cache with their device list. + """ + + destination_query = remote_queries_not_in_cache[destination] + + # We first consider whether we wish to update the device list cache with + # the users device list. We want to track a user's devices when the + # authenticated user shares a room with the queried user and the query + # has not specified a particular device. + # If we update the cache for the queried user we remove them from further + # queries. We use the more efficient batched query_client_keys for all + # remaining users + user_ids_updated = [] + for (user_id, device_list) in destination_query.items(): + if user_id in user_ids_updated: + continue + + if device_list: + continue + + room_ids = await self.store.get_rooms_for_user(user_id) + if not room_ids: + continue + + # We've decided we're sharing a room with this user and should + # probably be tracking their device lists. However, we haven't + # done an initial sync on the device list so we do it now. + try: + if self._is_master: + user_devices = await self.device_handler.device_list_updater.user_device_resync( + user_id + ) + else: + user_devices = await self._user_device_resync_client( + user_id=user_id + ) + + user_devices = user_devices["devices"] + user_results = results.setdefault(user_id, {}) + for device in user_devices: + user_results[device["device_id"]] = device["keys"] + user_ids_updated.append(user_id) + except Exception as e: + failures[destination] = _exception_to_failure(e) + + if len(destination_query) == len(user_ids_updated): + # We've updated all the users in the query and we do not need to + # make any further remote calls. + return + + # Remove all the users from the query which we have updated + for user_id in user_ids_updated: + destination_query.pop(user_id) - for user_id, keys in remote_result["device_keys"].items(): - if user_id in destination_query: - results[user_id] = keys + try: + remote_result = await self.federation.query_client_keys( + destination, {"device_keys": destination_query}, timeout=timeout + ) - if "master_keys" in remote_result: - for user_id, key in remote_result["master_keys"].items(): + for user_id, keys in remote_result["device_keys"].items(): if user_id in destination_query: - cross_signing_keys["master_keys"][user_id] = key + results[user_id] = keys - if "self_signing_keys" in remote_result: - for user_id, key in remote_result["self_signing_keys"].items(): - if user_id in destination_query: - cross_signing_keys["self_signing_keys"][user_id] = key + if "master_keys" in remote_result: + for user_id, key in remote_result["master_keys"].items(): + if user_id in destination_query: + cross_signing_keys["master_keys"][user_id] = key - except Exception as e: - failure = _exception_to_failure(e) - failures[destination] = failure - set_tag("error", True) - set_tag("reason", failure) + if "self_signing_keys" in remote_result: + for user_id, key in remote_result["self_signing_keys"].items(): + if user_id in destination_query: + cross_signing_keys["self_signing_keys"][user_id] = key - await make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background(do_remote_query, destination) - for destination in remote_queries_not_in_cache - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) - ) + except Exception as e: + failure = _exception_to_failure(e) + failures[destination] = failure + set_tag("error", True) + set_tag("reason", failure) + + await make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background(do_remote_query, destination) + for destination in remote_queries_not_in_cache + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + ) - ret = {"device_keys": results, "failures": failures} + ret = {"device_keys": results, "failures": failures} - ret.update(cross_signing_keys) + ret.update(cross_signing_keys) - return ret + return ret async def get_cross_signing_keys_from_cache( self, query: Iterable[str], from_user_id: Optional[str] diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index a0df16a3..989996b6 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -13,7 +13,12 @@ # limitations under the License. from typing import TYPE_CHECKING, Collection, Optional -from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.constants import ( + EventTypes, + JoinRules, + Membership, + RestrictedJoinRuleTypes, +) from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersion from synapse.events import EventBase @@ -42,7 +47,7 @@ class EventAuthHandler: Check whether a user can join a room without an invite due to restricted join rules. When joining a room with restricted joined rules (as defined in MSC3083), - the membership of spaces must be checked during a room join. + the membership of rooms must be checked during a room join. Args: state_ids: The state of the room as it currently is. @@ -67,20 +72,20 @@ class EventAuthHandler: if not await self.has_restricted_join_rules(state_ids, room_version): return - # Get the spaces which allow access to this room and check if the user is + # Get the rooms which allow access to this room and check if the user is # in any of them. - allowed_spaces = await self.get_spaces_that_allow_join(state_ids) - if not await self.is_user_in_rooms(allowed_spaces, user_id): + allowed_rooms = await self.get_rooms_that_allow_join(state_ids) + if not await self.is_user_in_rooms(allowed_rooms, user_id): raise AuthError( 403, - "You do not belong to any of the required spaces to join this room.", + "You do not belong to any of the required rooms to join this room.", ) async def has_restricted_join_rules( self, state_ids: StateMap[str], room_version: RoomVersion ) -> bool: """ - Return if the room has the proper join rules set for access via spaces. + Return if the room has the proper join rules set for access via rooms. Args: state_ids: The state of the room as it currently is. @@ -102,17 +107,17 @@ class EventAuthHandler: join_rules_event = await self._store.get_event(join_rules_event_id) return join_rules_event.content.get("join_rule") == JoinRules.MSC3083_RESTRICTED - async def get_spaces_that_allow_join( + async def get_rooms_that_allow_join( self, state_ids: StateMap[str] ) -> Collection[str]: """ - Generate a list of spaces which allow access to a room. + Generate a list of rooms in which membership allows access to a room. Args: - state_ids: The state of the room as it currently is. + state_ids: The current state of the room the user wishes to join Returns: - A collection of spaces which provide membership to the room. + A collection of room IDs. Membership in any of the rooms in the list grants the ability to join the target room. """ # If there's no join rule, then it defaults to invite (so this doesn't apply). join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None) @@ -123,21 +128,25 @@ class EventAuthHandler: join_rules_event = await self._store.get_event(join_rules_event_id) # If allowed is of the wrong form, then only allow invited users. - allowed_spaces = join_rules_event.content.get("allow", []) - if not isinstance(allowed_spaces, list): + allow_list = join_rules_event.content.get("allow", []) + if not isinstance(allow_list, list): return () # Pull out the other room IDs, invalid data gets filtered. result = [] - for space in allowed_spaces: - if not isinstance(space, dict): + for allow in allow_list: + if not isinstance(allow, dict): + continue + + # If the type is unexpected, skip it. + if allow.get("type") != RestrictedJoinRuleTypes.ROOM_MEMBERSHIP: continue - space_id = space.get("space") - if not isinstance(space_id, str): + room_id = allow.get("room_id") + if not isinstance(room_id, str): continue - result.append(space_id) + result.append(room_id) return result diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index abbb7142..1b566dbf 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1,6 +1,5 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2014-2021 The Matrix.org Foundation C.I.C. +# Copyright 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -34,6 +33,7 @@ from typing import ( ) import attr +from prometheus_client import Counter from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json from unpaddedbase64 import decode_base64 @@ -102,6 +102,11 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +soft_failed_event_counter = Counter( + "synapse_federation_soft_failed_events_total", + "Events received over federation that we marked as soft_failed", +) + @attr.s(slots=True) class _NewEventInfo: @@ -1550,6 +1555,77 @@ class FederationHandler(BaseHandler): run_in_background(self._handle_queued_pdus, room_queue) + @log_function + async def do_knock( + self, + target_hosts: List[str], + room_id: str, + knockee: str, + content: JsonDict, + ) -> Tuple[str, int]: + """Sends the knock to the remote server. + + This first triggers a make_knock request that returns a partial + event that we can fill out and sign. This is then sent to the + remote server via send_knock. + + Knock events must be signed by the knockee's server before distributing. + + Args: + target_hosts: A list of hosts that we want to try knocking through. + room_id: The ID of the room to knock on. + knockee: The ID of the user who is knocking. + content: The content of the knock event. + + Returns: + A tuple of (event ID, stream ID). + + Raises: + SynapseError: If the chosen remote server returns a 3xx/4xx code. + RuntimeError: If no servers were reachable. + """ + logger.debug("Knocking on room %s on behalf of user %s", room_id, knockee) + + # Inform the remote server of the room versions we support + supported_room_versions = list(KNOWN_ROOM_VERSIONS.keys()) + + # Ask the remote server to create a valid knock event for us. Once received, + # we sign the event + params = {"ver": supported_room_versions} # type: Dict[str, Iterable[str]] + origin, event, event_format_version = await self._make_and_verify_event( + target_hosts, room_id, knockee, Membership.KNOCK, content, params=params + ) + + # Record the room ID and its version so that we have a record of the room + await self._maybe_store_room_on_outlier_membership( + room_id=event.room_id, room_version=event_format_version + ) + + # Initially try the host that we successfully called /make_knock on + try: + target_hosts.remove(origin) + target_hosts.insert(0, origin) + except ValueError: + pass + + # Send the signed event back to the room, and potentially receive some + # further information about the room in the form of partial state events + stripped_room_state = await self.federation_client.send_knock( + target_hosts, event + ) + + # Store any stripped room state events in the "unsigned" key of the event. + # This is a bit of a hack and is cribbing off of invites. Basically we + # store the room state here and retrieve it again when this event appears + # in the invitee's sync stream. It is stripped out for all other local users. + event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"] + + context = await self.state_handler.compute_event_context(event) + stream_id = await self.persist_events_and_notify( + event.room_id, [(event, context)] + ) + return event.event_id, stream_id + async def _handle_queued_pdus( self, room_queue: List[Tuple[EventBase, str]] ) -> None: @@ -1885,7 +1961,7 @@ class FederationHandler(BaseHandler): return event async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None: - """ We have received a leave event for a room. Fully process it.""" + """We have received a leave event for a room. Fully process it.""" event = pdu logger.debug( @@ -1915,6 +1991,114 @@ class FederationHandler(BaseHandler): return None + @log_function + async def on_make_knock_request( + self, origin: str, room_id: str, user_id: str + ) -> EventBase: + """We've received a make_knock request, so we create a partial + knock event for the room and return that. We do *not* persist or + process it until the other server has signed it and sent it back. + + Args: + origin: The (verified) server name of the requesting server. + room_id: The room to create the knock event in. + user_id: The user to create the knock for. + + Returns: + The partial knock event. + """ + if get_domain_from_id(user_id) != origin: + logger.info( + "Get /make_knock request for user %r from different origin %s, ignoring", + user_id, + origin, + ) + raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) + + room_version = await self.store.get_room_version_id(room_id) + + builder = self.event_builder_factory.new( + room_version, + { + "type": EventTypes.Member, + "content": {"membership": Membership.KNOCK}, + "room_id": room_id, + "sender": user_id, + "state_key": user_id, + }, + ) + + event, context = await self.event_creation_handler.create_new_client_event( + builder=builder + ) + + event_allowed = await self.third_party_event_rules.check_event_allowed( + event, context + ) + if not event_allowed: + logger.warning("Creation of knock %s forbidden by third-party rules", event) + raise SynapseError( + 403, "This event is not allowed in this context", Codes.FORBIDDEN + ) + + try: + # The remote hasn't signed it yet, obviously. We'll do the full checks + # when we get the event back in `on_send_knock_request` + await self.auth.check_from_context( + room_version, event, context, do_sig_check=False + ) + except AuthError as e: + logger.warning("Failed to create new knock %r because %s", event, e) + raise e + + return event + + @log_function + async def on_send_knock_request( + self, origin: str, event: EventBase + ) -> EventContext: + """ + We have received a knock event for a room. Verify that event and send it into the room + on the knocking homeserver's behalf. + + Args: + origin: The remote homeserver of the knocking user. + event: The knocking member event that has been signed by the remote homeserver. + + Returns: + The context of the event after inserting it into the room graph. + """ + logger.debug( + "on_send_knock_request: Got event: %s, signatures: %s", + event.event_id, + event.signatures, + ) + + if get_domain_from_id(event.sender) != origin: + logger.info( + "Got /send_knock request for user %r from different origin %s", + event.sender, + origin, + ) + raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) + + event.internal_metadata.outlier = False + + context = await self.state_handler.compute_event_context(event) + + event_allowed = await self.third_party_event_rules.check_event_allowed( + event, context + ) + if not event_allowed: + logger.info("Sending of knock %s forbidden by third-party rules", event) + raise SynapseError( + 403, "This event is not allowed in this context", Codes.FORBIDDEN + ) + + await self._auth_and_persist_event(origin, event, context) + + return context + async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]: """Returns the state at the event. i.e. not including said event.""" @@ -2239,7 +2423,11 @@ class FederationHandler(BaseHandler): ) async def _check_for_soft_fail( - self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool + self, + event: EventBase, + state: Optional[Iterable[EventBase]], + backfilled: bool, + origin: str, ) -> None: """Checks if we should soft fail the event; if so, marks the event as such. @@ -2248,6 +2436,7 @@ class FederationHandler(BaseHandler): event state: The state at the event if we don't have all the event's prev events backfilled: Whether the event is from backfill + origin: The host the event originates from. """ # For new (non-backfilled and non-outlier) events we check if the event # passes auth based on the current state. If it doesn't then we @@ -2317,7 +2506,18 @@ class FederationHandler(BaseHandler): try: event_auth.check(room_version_obj, event, auth_events=current_auth_events) except AuthError as e: - logger.warning("Soft-failing %r because %s", event, e) + logger.warning( + "Soft-failing %r (from %s) because %s", + event, + e, + origin, + extra={ + "room_id": event.room_id, + "mxid": event.sender, + "hs": origin, + }, + ) + soft_failed_event_counter.inc() event.internal_metadata.soft_failed = True async def on_get_missing_events( @@ -2429,7 +2629,7 @@ class FederationHandler(BaseHandler): context.rejected = RejectedReason.AUTH_ERROR if not context.rejected: - await self._check_for_soft_fail(event, state, backfilled) + await self._check_for_soft_fail(event, state, backfilled, origin=origin) if event.type == EventTypes.GuestAccess and not context.rejected: await self.maybe_kick_guest_users(event) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9f365eb5..db12abd5 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1,6 +1,7 @@ # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2020 The Matrix.org Foundation C.I.C. +# Copyrignt 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -398,13 +399,14 @@ class EventCreationHandler: self._events_shard_config = self.config.worker.events_shard_config self._instance_name = hs.get_instance_name() - self.room_invite_state_types = self.hs.config.api.room_prejoin_state + self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state - self.membership_types_to_include_profile_data_in = ( - {Membership.JOIN, Membership.INVITE} - if self.hs.config.include_profile_data_on_invite - else {Membership.JOIN} - ) + self.membership_types_to_include_profile_data_in = { + Membership.JOIN, + Membership.KNOCK, + } + if self.hs.config.include_profile_data_on_invite: + self.membership_types_to_include_profile_data_in.add(Membership.INVITE) self.send_event = ReplicationSendEventRestServlet.make_client(hs) @@ -480,6 +482,9 @@ class EventCreationHandler: prev_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None, require_consent: bool = True, + outlier: bool = False, + historical: bool = False, + depth: Optional[int] = None, ) -> Tuple[EventBase, EventContext]: """ Given a dict from a client, create a new event. @@ -506,6 +511,14 @@ class EventCreationHandler: require_consent: Whether to check if the requester has consented to the privacy policy. + + outlier: Indicates whether the event is an `outlier`, i.e. if + it's from an arbitrary point and floating in the DAG as + opposed to being inline with the current DAG. + depth: Override the depth used to order the event in the DAG. + Should normally be set to None, which will cause the depth to be calculated + based on the prev_events. + Raises: ResourceLimitError if server is blocked to some resource being exceeded @@ -561,11 +574,36 @@ class EventCreationHandler: if txn_id is not None: builder.internal_metadata.txn_id = txn_id + builder.internal_metadata.outlier = outlier + + builder.internal_metadata.historical = historical + + # Strip down the auth_event_ids to only what we need to auth the event. + # For example, we don't need extra m.room.member that don't match event.sender + if auth_event_ids is not None: + temp_event = await builder.build( + prev_event_ids=prev_event_ids, + auth_event_ids=auth_event_ids, + depth=depth, + ) + auth_events = await self.store.get_events_as_list(auth_event_ids) + # Create a StateMap[str] + auth_event_state_map = { + (e.type, e.state_key): e.event_id for e in auth_events + } + # Actually strip down and use the necessary auth events + auth_event_ids = self.auth.compute_auth_events( + event=temp_event, + current_state_ids=auth_event_state_map, + for_verification=False, + ) + event, context = await self.create_new_client_event( builder=builder, requester=requester, prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, + depth=depth, ) # In an ideal world we wouldn't need the second part of this condition. However, @@ -722,9 +760,13 @@ class EventCreationHandler: self, requester: Requester, event_dict: dict, + prev_event_ids: Optional[List[str]] = None, + auth_event_ids: Optional[List[str]] = None, ratelimit: bool = True, txn_id: Optional[str] = None, ignore_shadow_ban: bool = False, + outlier: bool = False, + depth: Optional[int] = None, ) -> Tuple[EventBase, int]: """ Creates an event, then sends it. @@ -734,10 +776,24 @@ class EventCreationHandler: Args: requester: The requester sending the event. event_dict: An entire event. + prev_event_ids: + The event IDs to use as the prev events. + Should normally be left as None to automatically request them + from the database. + auth_event_ids: + The event ids to use as the auth_events for the new event. + Should normally be left as None, which will cause them to be calculated + based on the room state at the prev_events. ratelimit: Whether to rate limit this send. txn_id: The transaction ID. ignore_shadow_ban: True if shadow-banned users should be allowed to send this event. + outlier: Indicates whether the event is an `outlier`, i.e. if + it's from an arbitrary point and floating in the DAG as + opposed to being inline with the current DAG. + depth: Override the depth used to order the event in the DAG. + Should normally be set to None, which will cause the depth to be calculated + based on the prev_events. Returns: The event, and its stream ordering (if deduplication happened, @@ -777,7 +833,13 @@ class EventCreationHandler: return event, event.internal_metadata.stream_ordering event, context = await self.create_event( - requester, event_dict, txn_id=txn_id + requester, + event_dict, + txn_id=txn_id, + prev_event_ids=prev_event_ids, + auth_event_ids=auth_event_ids, + outlier=outlier, + depth=depth, ) assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( @@ -809,6 +871,7 @@ class EventCreationHandler: requester: Optional[Requester] = None, prev_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None, + depth: Optional[int] = None, ) -> Tuple[EventBase, EventContext]: """Create a new event for a local client @@ -826,6 +889,10 @@ class EventCreationHandler: Should normally be left as None, which will cause them to be calculated based on the room state at the prev_events. + depth: Override the depth used to order the event in the DAG. + Should normally be set to None, which will cause the depth to be calculated + based on the prev_events. + Returns: Tuple of created event, context """ @@ -849,9 +916,24 @@ class EventCreationHandler: ), "Attempting to create an event with no prev_events" event = await builder.build( - prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids + prev_event_ids=prev_event_ids, + auth_event_ids=auth_event_ids, + depth=depth, ) - context = await self.state.compute_event_context(event) + + old_state = None + + # Pass on the outlier property from the builder to the event + # after it is created + if builder.internal_metadata.outlier: + event.internal_metadata.outlier = builder.internal_metadata.outlier + + # Calculate the state for outliers that pass in their own `auth_event_ids` + if auth_event_ids: + old_state = await self.store.get_events_as_list(auth_event_ids) + + context = await self.state.compute_event_context(event, old_state=old_state) + if requester: context.app_service = requester.app_service @@ -961,8 +1043,8 @@ class EventCreationHandler: room_version = await self.store.get_room_version_id(event.room_id) if event.internal_metadata.is_out_of_band_membership(): - # the only sort of out-of-band-membership events we expect to see here - # are invite rejections we have generated ourselves. + # the only sort of out-of-band-membership events we expect to see here are + # invite rejections and rescinded knocks that we have generated ourselves. assert event.type == EventTypes.Member assert event.content["membership"] == Membership.LEAVE else: @@ -1016,7 +1098,13 @@ class EventCreationHandler: the arguments. """ - await self.action_generator.handle_push_actions_for_event(event, context) + # Skip push notification actions for historical messages + # because we don't want to notify people about old history back in time. + # The historical messages also do not have the proper `context.current_state_ids` + # and `state_groups` because they have `prev_events` that aren't persisted yet + # (historical messages persisted in reverse-chronological order). + if not event.internal_metadata.is_historical(): + await self.action_generator.handle_push_actions_for_event(event, context) try: # If we're a worker we need to hit out to the master. @@ -1239,7 +1327,7 @@ class EventCreationHandler: "invite_room_state" ] = await self.store.get_stripped_room_state_from_event_context( context, - self.room_invite_state_types, + self.room_prejoin_state_types, membership_user_id=event.sender, ) @@ -1257,6 +1345,14 @@ class EventCreationHandler: # TODO: Make sure the signatures actually are correct. event.signatures.update(returned_invite.signatures) + if event.content["membership"] == Membership.KNOCK: + event.unsigned[ + "knock_room_state" + ] = await self.store.get_stripped_room_state_from_event_context( + context, + self.room_prejoin_state_types, + ) + if event.type == EventTypes.Redaction: original_event = await self.store.get_event( event.redacts, @@ -1307,13 +1403,21 @@ class EventCreationHandler: if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") + # Mark any `m.historical` messages as backfilled so they don't appear + # in `/sync` and have the proper decrementing `stream_ordering` as we import + backfilled = False + if event.internal_metadata.is_historical(): + backfilled = True + # Note that this returns the event that was persisted, which may not be # the same as we passed in if it was deduplicated due transaction IDs. ( event, event_pos, max_stream_token, - ) = await self.storage.persistence.persist_event(event, context=context) + ) = await self.storage.persistence.persist_event( + event, context=context, backfilled=backfilled + ) if self._ephemeral_events_enabled: # If there's an expiry timestamp on the event, schedule its expiry. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 4ceef3fa..ca1ed6a5 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -195,7 +195,7 @@ class RegistrationHandler(BaseHandler): bind_emails: list of emails to bind to this account. by_admin: True if this registration is being made via the admin api, otherwise False. - user_agent_ips: Tuples of IP addresses and user-agents used + user_agent_ips: Tuples of user-agents and IP addresses used during the registration process. auth_provider_id: The SSO IdP the user used, if any. Returns: diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 141c9c04..5e3ef7ce 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -44,7 +44,7 @@ class RoomListHandler(BaseHandler): self.enable_room_list_search = hs.config.enable_room_list_search self.response_cache = ResponseCache( hs.get_clock(), "room_list" - ) # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]] + ) # type: ResponseCache[Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]] self.remote_response_cache = ResponseCache( hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000 ) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]] @@ -54,7 +54,7 @@ class RoomListHandler(BaseHandler): limit: Optional[int] = None, since_token: Optional[str] = None, search_filter: Optional[dict] = None, - network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID, + network_tuple: Optional[ThirdPartyInstanceID] = EMPTY_THIRD_PARTY_ID, from_federation: bool = False, ) -> JsonDict: """Generate a local public room list. @@ -111,7 +111,7 @@ class RoomListHandler(BaseHandler): limit: Optional[int] = None, since_token: Optional[str] = None, search_filter: Optional[dict] = None, - network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID, + network_tuple: Optional[ThirdPartyInstanceID] = EMPTY_THIRD_PARTY_ID, from_federation: bool = False, ) -> JsonDict: """Generate a public room list. @@ -169,6 +169,7 @@ class RoomListHandler(BaseHandler): "world_readable": room["history_visibility"] == HistoryVisibility.WORLD_READABLE, "guest_can_join": room["guest_access"] == "can_join", + "join_rule": room["join_rules"], } # Filter out Nones – rather omit the field altogether diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index d6fc43e7..11925916 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1,4 +1,5 @@ # Copyright 2016-2020 The Matrix.org Foundation C.I.C. +# Copyright 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import abc import logging import random @@ -30,7 +30,15 @@ from synapse.api.errors import ( from synapse.api.ratelimiting import Ratelimiter from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID +from synapse.types import ( + JsonDict, + Requester, + RoomAlias, + RoomID, + StateMap, + UserID, + get_domain_from_id, +) from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_left_room @@ -126,6 +134,24 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): raise NotImplementedError() @abc.abstractmethod + async def remote_knock( + self, + remote_room_hosts: List[str], + room_id: str, + user: UserID, + content: dict, + ) -> Tuple[str, int]: + """Try and knock on a room that this server is not in + + Args: + remote_room_hosts: List of servers that can be used to knock via. + room_id: Room that we are trying to knock on. + user: User who is trying to knock. + content: A dict that should be used as the content of the knock event. + """ + raise NotImplementedError() + + @abc.abstractmethod async def remote_reject_invite( self, invite_event_id: str, @@ -149,6 +175,27 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): raise NotImplementedError() @abc.abstractmethod + async def remote_rescind_knock( + self, + knock_event_id: str, + txn_id: Optional[str], + requester: Requester, + content: JsonDict, + ) -> Tuple[str, int]: + """Rescind a local knock made on a remote room. + + Args: + knock_event_id: The ID of the knock event to rescind. + txn_id: An optional transaction ID supplied by the client. + requester: The user making the request, according to the access token. + content: The content of the generated leave event. + + Returns: + A tuple containing (event_id, stream_id of the leave event). + """ + raise NotImplementedError() + + @abc.abstractmethod async def _user_left_room(self, target: UserID, room_id: str) -> None: """Notifies distributor on master process that the user has left the room. @@ -210,11 +257,42 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): room_id: str, membership: str, prev_event_ids: List[str], + auth_event_ids: Optional[List[str]] = None, txn_id: Optional[str] = None, ratelimit: bool = True, content: Optional[dict] = None, require_consent: bool = True, + outlier: bool = False, ) -> Tuple[str, int]: + """ + Internal membership update function to get an existing event or create + and persist a new event for the new membership change. + + Args: + requester: + target: + room_id: + membership: + prev_event_ids: The event IDs to use as the prev events + + auth_event_ids: + The event ids to use as the auth_events for the new event. + Should normally be left as None, which will cause them to be calculated + based on the room state at the prev_events. + + txn_id: + ratelimit: + content: + require_consent: + + outlier: Indicates whether the event is an `outlier`, i.e. if + it's from an arbitrary point and floating in the DAG as + opposed to being inline with the current DAG. + + Returns: + Tuple of event ID and stream ordering position + """ + user_id = target.to_string() if content is None: @@ -251,7 +329,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): }, txn_id=txn_id, prev_event_ids=prev_event_ids, + auth_event_ids=auth_event_ids, require_consent=require_consent, + outlier=outlier, ) prev_state_ids = await context.get_prev_state_ids() @@ -352,6 +432,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ratelimit: bool = True, content: Optional[dict] = None, require_consent: bool = True, + outlier: bool = False, + prev_event_ids: Optional[List[str]] = None, + auth_event_ids: Optional[List[str]] = None, ) -> Tuple[str, int]: """Update a user's membership in a room. @@ -366,6 +449,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ratelimit: Whether to rate limit the request. content: The content of the created event. require_consent: Whether consent is required. + outlier: Indicates whether the event is an `outlier`, i.e. if + it's from an arbitrary point and floating in the DAG as + opposed to being inline with the current DAG. + prev_event_ids: The event IDs to use as the prev events + auth_event_ids: + The event ids to use as the auth_events for the new event. + Should normally be left as None, which will cause them to be calculated + based on the room state at the prev_events. Returns: A tuple of the new event ID and stream ID. @@ -392,6 +483,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ratelimit=ratelimit, content=content, require_consent=require_consent, + outlier=outlier, + prev_event_ids=prev_event_ids, + auth_event_ids=auth_event_ids, ) return result @@ -408,10 +502,36 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ratelimit: bool = True, content: Optional[dict] = None, require_consent: bool = True, + outlier: bool = False, + prev_event_ids: Optional[List[str]] = None, + auth_event_ids: Optional[List[str]] = None, ) -> Tuple[str, int]: """Helper for update_membership. Assumes that the membership linearizer is already held for the room. + + Args: + requester: + target: + room_id: + action: + txn_id: + remote_room_hosts: + third_party_signed: + ratelimit: + content: + require_consent: + outlier: Indicates whether the event is an `outlier`, i.e. if + it's from an arbitrary point and floating in the DAG as + opposed to being inline with the current DAG. + prev_event_ids: The event IDs to use as the prev events + auth_event_ids: + The event ids to use as the auth_events for the new event. + Should normally be left as None, which will cause them to be calculated + based on the room state at the prev_events. + + Returns: + A tuple of the new event ID and stream ID. """ content_specified = bool(content) if content is None: @@ -496,6 +616,21 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if block_invite: raise SynapseError(403, "Invites have been disabled on this server") + if prev_event_ids: + return await self._local_membership_update( + requester=requester, + target=target, + room_id=room_id, + membership=effective_membership_state, + txn_id=txn_id, + ratelimit=ratelimit, + prev_event_ids=prev_event_ids, + auth_event_ids=auth_event_ids, + content=content, + require_consent=require_consent, + outlier=outlier, + ) + latest_event_ids = await self.store.get_prev_events_for_room(room_id) current_state_ids = await self.state_handler.get_current_state_ids( @@ -603,53 +738,79 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): elif effective_membership_state == Membership.LEAVE: if not is_host_in_room: - # perhaps we've been invited + # Figure out the user's current membership state for the room ( current_membership_type, current_membership_event_id, ) = await self.store.get_local_current_membership_for_user_in_room( target.to_string(), room_id ) - if ( - current_membership_type != Membership.INVITE - or not current_membership_event_id - ): + if not current_membership_type or not current_membership_event_id: logger.info( "%s sent a leave request to %s, but that is not an active room " - "on this server, and there is no pending invite", + "on this server, or there is no pending invite or knock", target, room_id, ) raise SynapseError(404, "Not a known room") - invite = await self.store.get_event(current_membership_event_id) - logger.info( - "%s rejects invite to %s from %s", target, room_id, invite.sender - ) + # perhaps we've been invited + if current_membership_type == Membership.INVITE: + invite = await self.store.get_event(current_membership_event_id) + logger.info( + "%s rejects invite to %s from %s", + target, + room_id, + invite.sender, + ) - if not self.hs.is_mine_id(invite.sender): - # send the rejection to the inviter's HS (with fallback to - # local event) - return await self.remote_reject_invite( - invite.event_id, - txn_id, - requester, - content, + if not self.hs.is_mine_id(invite.sender): + # send the rejection to the inviter's HS (with fallback to + # local event) + return await self.remote_reject_invite( + invite.event_id, + txn_id, + requester, + content, + ) + + # the inviter was on our server, but has now left. Carry on + # with the normal rejection codepath, which will also send the + # rejection out to any other servers we believe are still in the room. + + # thanks to overzealous cleaning up of event_forward_extremities in + # `delete_old_current_state_events`, it's possible to end up with no + # forward extremities here. If that happens, let's just hang the + # rejection off the invite event. + # + # see: https://github.com/matrix-org/synapse/issues/7139 + if len(latest_event_ids) == 0: + latest_event_ids = [invite.event_id] + + # or perhaps this is a remote room that a local user has knocked on + elif current_membership_type == Membership.KNOCK: + knock = await self.store.get_event(current_membership_event_id) + return await self.remote_rescind_knock( + knock.event_id, txn_id, requester, content ) - # the inviter was on our server, but has now left. Carry on - # with the normal rejection codepath, which will also send the - # rejection out to any other servers we believe are still in the room. + elif effective_membership_state == Membership.KNOCK: + if not is_host_in_room: + # The knock needs to be sent over federation instead + remote_room_hosts.append(get_domain_from_id(room_id)) + + content["membership"] = Membership.KNOCK + + profile = self.profile_handler + if "displayname" not in content: + content["displayname"] = await profile.get_displayname(target) + if "avatar_url" not in content: + content["avatar_url"] = await profile.get_avatar_url(target) - # thanks to overzealous cleaning up of event_forward_extremities in - # `delete_old_current_state_events`, it's possible to end up with no - # forward extremities here. If that happens, let's just hang the - # rejection off the invite event. - # - # see: https://github.com/matrix-org/synapse/issues/7139 - if len(latest_event_ids) == 0: - latest_event_ids = [invite.event_id] + return await self.remote_knock( + remote_room_hosts, room_id, target, content + ) return await self._local_membership_update( requester=requester, @@ -659,8 +820,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): txn_id=txn_id, ratelimit=ratelimit, prev_event_ids=latest_event_ids, + auth_event_ids=auth_event_ids, content=content, require_consent=require_consent, + outlier=outlier, ) async def transfer_room_state_on_room_upgrade( @@ -1209,6 +1372,35 @@ class RoomMemberMasterHandler(RoomMemberHandler): invite_event, txn_id, requester, content ) + async def remote_rescind_knock( + self, + knock_event_id: str, + txn_id: Optional[str], + requester: Requester, + content: JsonDict, + ) -> Tuple[str, int]: + """ + Rescinds a local knock made on a remote room + + Args: + knock_event_id: The ID of the knock event to rescind. + txn_id: The transaction ID to use. + requester: The originator of the request. + content: The content of the leave event. + + Implements RoomMemberHandler.remote_rescind_knock + """ + # TODO: We don't yet support rescinding knocks over federation + # as we don't know which homeserver to send it to. An obvious + # candidate is the remote homeserver we originally knocked through, + # however we don't currently store that information. + + # Just rescind the knock locally + knock_event = await self.store.get_event(knock_event_id) + return await self._generate_local_out_of_band_leave( + knock_event, txn_id, requester, content + ) + async def _generate_local_out_of_band_leave( self, previous_membership_event: EventBase, @@ -1272,6 +1464,36 @@ class RoomMemberMasterHandler(RoomMemberHandler): return result_event.event_id, result_event.internal_metadata.stream_ordering + async def remote_knock( + self, + remote_room_hosts: List[str], + room_id: str, + user: UserID, + content: dict, + ) -> Tuple[str, int]: + """Sends a knock to a room. Attempts to do so via one remote out of a given list. + + Args: + remote_room_hosts: A list of homeservers to try knocking through. + room_id: The ID of the room to knock on. + user: The user to knock on behalf of. + content: The content of the knock event. + + Returns: + A tuple of (event ID, stream ID). + """ + # filter ourselves out of remote_room_hosts + remote_room_hosts = [ + host for host in remote_room_hosts if host != self.hs.hostname + ] + + if len(remote_room_hosts) == 0: + raise SynapseError(404, "No known servers") + + return await self.federation_handler.do_knock( + remote_room_hosts, room_id, user.to_string(), content=content + ) + async def _user_left_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_left_room""" user_left_room(self.distributor, target, room_id) diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index 3e89dd23..221552a2 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -1,4 +1,4 @@ -# Copyright 2018 New Vector Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,10 +19,12 @@ from synapse.api.errors import SynapseError from synapse.handlers.room_member import RoomMemberHandler from synapse.replication.http.membership import ( ReplicationRemoteJoinRestServlet as ReplRemoteJoin, + ReplicationRemoteKnockRestServlet as ReplRemoteKnock, ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite, + ReplicationRemoteRescindKnockRestServlet as ReplRescindKnock, ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft, ) -from synapse.types import Requester, UserID +from synapse.types import JsonDict, Requester, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -35,7 +37,9 @@ class RoomMemberWorkerHandler(RoomMemberHandler): super().__init__(hs) self._remote_join_client = ReplRemoteJoin.make_client(hs) + self._remote_knock_client = ReplRemoteKnock.make_client(hs) self._remote_reject_client = ReplRejectInvite.make_client(hs) + self._remote_rescind_client = ReplRescindKnock.make_client(hs) self._notify_change_client = ReplJoinedLeft.make_client(hs) async def _remote_join( @@ -80,6 +84,53 @@ class RoomMemberWorkerHandler(RoomMemberHandler): ) return ret["event_id"], ret["stream_id"] + async def remote_rescind_knock( + self, + knock_event_id: str, + txn_id: Optional[str], + requester: Requester, + content: JsonDict, + ) -> Tuple[str, int]: + """ + Rescinds a local knock made on a remote room + + Args: + knock_event_id: the knock event + txn_id: optional transaction ID supplied by the client + requester: user making the request, according to the access token + content: additional content to include in the leave event. + Normally an empty dict. + + Returns: + A tuple containing (event_id, stream_id of the leave event) + """ + ret = await self._remote_rescind_client( + knock_event_id=knock_event_id, + txn_id=txn_id, + requester=requester, + content=content, + ) + return ret["event_id"], ret["stream_id"] + + async def remote_knock( + self, + remote_room_hosts: List[str], + room_id: str, + user: UserID, + content: dict, + ) -> Tuple[str, int]: + """Sends a knock to a room. + + Implements RoomMemberHandler.remote_knock + """ + ret = await self._remote_knock_client( + remote_room_hosts=remote_room_hosts, + room_id=room_id, + user=user, + content=content, + ) + return ret["event_id"], ret["stream_id"] + async def _user_left_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_left_room""" await self._notify_change_client( diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index 046dba6f..17fc47ce 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -160,14 +160,14 @@ class SpaceSummaryHandler: # Check if the user is a member of any of the allowed spaces # from the response. - allowed_spaces = room.get("allowed_spaces") + allowed_rooms = room.get("allowed_spaces") if ( not include_room - and allowed_spaces - and isinstance(allowed_spaces, list) + and allowed_rooms + and isinstance(allowed_rooms, list) ): include_room = await self._event_auth_handler.is_user_in_rooms( - allowed_spaces, requester + allowed_rooms, requester ) # Finally, if this isn't the requested room, check ourselves @@ -402,10 +402,7 @@ class SpaceSummaryHandler: return (), () return res.rooms, tuple( - ev.data - for ev in res.events - if ev.event_type == EventTypes.MSC1772_SPACE_CHILD - or ev.event_type == EventTypes.SpaceChild + ev.data for ev in res.events if ev.event_type == EventTypes.SpaceChild ) async def _is_room_accessible( @@ -448,21 +445,20 @@ class SpaceSummaryHandler: member_event_id = state_ids.get((EventTypes.Member, requester), None) # If they're in the room they can see info on it. - member_event = None if member_event_id: member_event = await self._store.get_event(member_event_id) if member_event.membership in (Membership.JOIN, Membership.INVITE): return True # Otherwise, check if they should be allowed access via membership in a space. - if self._event_auth_handler.has_restricted_join_rules( + if await self._event_auth_handler.has_restricted_join_rules( state_ids, room_version ): - allowed_spaces = ( - await self._event_auth_handler.get_spaces_that_allow_join(state_ids) + allowed_rooms = ( + await self._event_auth_handler.get_rooms_that_allow_join(state_ids) ) if await self._event_auth_handler.is_user_in_rooms( - allowed_spaces, requester + allowed_rooms, requester ): return True @@ -478,10 +474,10 @@ class SpaceSummaryHandler: if await self._event_auth_handler.has_restricted_join_rules( state_ids, room_version ): - allowed_spaces = ( - await self._event_auth_handler.get_spaces_that_allow_join(state_ids) + allowed_rooms = ( + await self._event_auth_handler.get_rooms_that_allow_join(state_ids) ) - for space_id in allowed_spaces: + for space_id in allowed_rooms: if await self._auth.check_host_in_room(space_id, origin): return True @@ -514,17 +510,12 @@ class SpaceSummaryHandler: current_state_ids[(EventTypes.Create, "")] ) - # TODO: update once MSC1772 lands - room_type = create_event.content.get(EventContentFields.ROOM_TYPE) - if not room_type: - room_type = create_event.content.get(EventContentFields.MSC1772_ROOM_TYPE) - room_version = await self._store.get_room_version(room_id) - allowed_spaces = None + allowed_rooms = None if await self._event_auth_handler.has_restricted_join_rules( current_state_ids, room_version ): - allowed_spaces = await self._event_auth_handler.get_spaces_that_allow_join( + allowed_rooms = await self._event_auth_handler.get_rooms_that_allow_join( current_state_ids ) @@ -540,8 +531,8 @@ class SpaceSummaryHandler: ), "guest_can_join": stats["guest_access"] == "can_join", "creation_ts": create_event.origin_server_ts, - "room_type": room_type, - "allowed_spaces": allowed_spaces, + "room_type": create_event.content.get(EventContentFields.ROOM_TYPE), + "allowed_spaces": allowed_rooms, } # Filter out Nones – rather omit the field altogether @@ -569,9 +560,7 @@ class SpaceSummaryHandler: [ event_id for key, event_id in current_state_ids.items() - # TODO: update once MSC1772 has been FCP for a period of time. - if key[0] == EventTypes.MSC1772_SPACE_CHILD - or key[0] == EventTypes.SpaceChild + if key[0] == EventTypes.SpaceChild ] ) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 044ff06d..0b297e54 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -41,7 +41,12 @@ from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent from synapse.http.server import respond_with_html, respond_with_redirect from synapse.http.site import SynapseRequest -from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters +from synapse.types import ( + JsonDict, + UserID, + contains_invalid_mxid_characters, + create_requester, +) from synapse.util.async_helpers import Linearizer from synapse.util.stringutils import random_string @@ -185,11 +190,14 @@ class SsoHandler: self._auth_handler = hs.get_auth_handler() self._error_template = hs.config.sso_error_template self._bad_user_template = hs.config.sso_auth_bad_user_template + self._profile_handler = hs.get_profile_handler() # The following template is shown after a successful user interactive # authentication session. It tells the user they can close the window. self._sso_auth_success_template = hs.config.sso_auth_success_template + self._sso_update_profile_information = hs.config.sso_update_profile_information + # a lock on the mappings self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock()) @@ -458,6 +466,21 @@ class SsoHandler: request.getClientIP(), ) new_user = True + elif self._sso_update_profile_information: + attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper) + if attributes.display_name: + user_id_obj = UserID.from_string(user_id) + profile_display_name = await self._profile_handler.get_displayname( + user_id_obj + ) + if profile_display_name != attributes.display_name: + requester = create_requester( + user_id, + authenticated_entity=user_id, + ) + await self._profile_handler.set_displayname( + user_id_obj, requester, attributes.display_name, True + ) await self._auth_handler.complete_sso_login( user_id, diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 383e3402..4e45d1da 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -1,4 +1,5 @@ -# Copyright 2018 New Vector Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# Copyright 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -230,6 +231,8 @@ class StatsHandler: room_stats_delta["left_members"] -= 1 elif prev_membership == Membership.BAN: room_stats_delta["banned_members"] -= 1 + elif prev_membership == Membership.KNOCK: + room_stats_delta["knocked_members"] -= 1 else: raise ValueError( "%r is not a valid prev_membership" % (prev_membership,) @@ -251,6 +254,8 @@ class StatsHandler: room_stats_delta["left_members"] += 1 elif membership == Membership.BAN: room_stats_delta["banned_members"] += 1 + elif membership == Membership.KNOCK: + room_stats_delta["knocked_members"] += 1 else: raise ValueError("%r is not a valid membership" % (membership,)) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index b1c58ffd..b9a03610 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -49,7 +49,7 @@ from synapse.types import ( from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.lrucache import LruCache -from synapse.util.caches.response_cache import ResponseCache +from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext from synapse.util.metrics import Measure, measure_func from synapse.visibility import filter_events_for_client @@ -83,12 +83,15 @@ LAZY_LOADED_MEMBERS_CACHE_MAX_AGE = 30 * 60 * 1000 LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100 +SyncRequestKey = Tuple[Any, ...] + + @attr.s(slots=True, frozen=True) class SyncConfig: user = attr.ib(type=UserID) filter_collection = attr.ib(type=FilterCollection) is_guest = attr.ib(type=bool) - request_key = attr.ib(type=Tuple[Any, ...]) + request_key = attr.ib(type=SyncRequestKey) device_id = attr.ib(type=Optional[str]) @@ -160,6 +163,16 @@ class InvitedSyncResult: @attr.s(slots=True, frozen=True) +class KnockedSyncResult: + room_id = attr.ib(type=str) + knock = attr.ib(type=EventBase) + + def __bool__(self) -> bool: + """Knocked rooms should always be reported to the client""" + return True + + +@attr.s(slots=True, frozen=True) class GroupsSyncResult: join = attr.ib(type=JsonDict) invite = attr.ib(type=JsonDict) @@ -192,6 +205,7 @@ class _RoomChanges: room_entries = attr.ib(type=List["RoomSyncResultBuilder"]) invited = attr.ib(type=List[InvitedSyncResult]) + knocked = attr.ib(type=List[KnockedSyncResult]) newly_joined_rooms = attr.ib(type=List[str]) newly_left_rooms = attr.ib(type=List[str]) @@ -205,6 +219,7 @@ class SyncResult: account_data: List of account_data events for the user. joined: JoinedSyncResult for each joined room. invited: InvitedSyncResult for each invited room. + knocked: KnockedSyncResult for each knocked on room. archived: ArchivedSyncResult for each archived room. to_device: List of direct messages for the device. device_lists: List of user_ids whose devices have changed @@ -220,6 +235,7 @@ class SyncResult: account_data = attr.ib(type=List[JsonDict]) joined = attr.ib(type=List[JoinedSyncResult]) invited = attr.ib(type=List[InvitedSyncResult]) + knocked = attr.ib(type=List[KnockedSyncResult]) archived = attr.ib(type=List[ArchivedSyncResult]) to_device = attr.ib(type=List[JsonDict]) device_lists = attr.ib(type=DeviceLists) @@ -236,6 +252,7 @@ class SyncResult: self.presence or self.joined or self.invited + or self.knocked or self.archived or self.account_data or self.to_device @@ -252,9 +269,9 @@ class SyncHandler: self.presence_handler = hs.get_presence_handler() self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() - self.response_cache = ResponseCache( + self.response_cache: ResponseCache[SyncRequestKey] = ResponseCache( hs.get_clock(), "sync" - ) # type: ResponseCache[Tuple[Any, ...]] + ) self.state = hs.get_state_handler() self.auth = hs.get_auth() self.storage = hs.get_storage() @@ -293,6 +310,7 @@ class SyncHandler: since_token, timeout, full_state, + cache_context=True, ) logger.debug("Returning sync response for %s", user_id) return res @@ -300,9 +318,10 @@ class SyncHandler: async def _wait_for_sync_for_user( self, sync_config: SyncConfig, - since_token: Optional[StreamToken] = None, - timeout: int = 0, - full_state: bool = False, + since_token: Optional[StreamToken], + timeout: int, + full_state: bool, + cache_context: ResponseCacheContext[SyncRequestKey], ) -> SyncResult: if since_token is None: sync_type = "initial_sync" @@ -329,13 +348,13 @@ class SyncHandler: if timeout == 0 or since_token is None or full_state: # we are going to return immediately, so don't bother calling # notifier.wait_for_events. - result = await self.current_sync_for_user( + result: SyncResult = await self.current_sync_for_user( sync_config, since_token, full_state=full_state ) else: - def current_sync_callback(before_token, after_token): - return self.current_sync_for_user(sync_config, since_token) + async def current_sync_callback(before_token, after_token) -> SyncResult: + return await self.current_sync_for_user(sync_config, since_token) result = await self.notifier.wait_for_events( sync_config.user.to_string(), @@ -344,6 +363,17 @@ class SyncHandler: from_token=since_token, ) + # if nothing has happened in any of the users' rooms since /sync was called, + # the resultant next_batch will be the same as since_token (since the result + # is generated when wait_for_events is first called, and not regenerated + # when wait_for_events times out). + # + # If that happens, we mustn't cache it, so that when the client comes back + # with the same cache token, we don't immediately return the same empty + # result, causing a tightloop. (#8518) + if result.next_batch == since_token: + cache_context.should_cache = False + if result: if sync_config.filter_collection.lazy_load_members(): lazy_loaded = "true" @@ -1031,7 +1061,7 @@ class SyncHandler: res = await self._generate_sync_entry_for_rooms( sync_result_builder, account_data_by_room ) - newly_joined_rooms, newly_joined_or_invited_users, _, _ = res + newly_joined_rooms, newly_joined_or_invited_or_knocked_users, _, _ = res _, _, newly_left_rooms, newly_left_users = res block_all_presence_data = ( @@ -1040,7 +1070,9 @@ class SyncHandler: if self.hs_config.use_presence and not block_all_presence_data: logger.debug("Fetching presence data") await self._generate_sync_entry_for_presence( - sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users + sync_result_builder, + newly_joined_rooms, + newly_joined_or_invited_or_knocked_users, ) logger.debug("Fetching to-device data") @@ -1049,7 +1081,7 @@ class SyncHandler: device_lists = await self._generate_sync_entry_for_device_list( sync_result_builder, newly_joined_rooms=newly_joined_rooms, - newly_joined_or_invited_users=newly_joined_or_invited_users, + newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users, newly_left_rooms=newly_left_rooms, newly_left_users=newly_left_users, ) @@ -1083,6 +1115,7 @@ class SyncHandler: account_data=sync_result_builder.account_data, joined=sync_result_builder.joined, invited=sync_result_builder.invited, + knocked=sync_result_builder.knocked, archived=sync_result_builder.archived, to_device=sync_result_builder.to_device, device_lists=device_lists, @@ -1142,7 +1175,7 @@ class SyncHandler: self, sync_result_builder: "SyncResultBuilder", newly_joined_rooms: Set[str], - newly_joined_or_invited_users: Set[str], + newly_joined_or_invited_or_knocked_users: Set[str], newly_left_rooms: Set[str], newly_left_users: Set[str], ) -> DeviceLists: @@ -1151,8 +1184,9 @@ class SyncHandler: Args: sync_result_builder newly_joined_rooms: Set of rooms user has joined since previous sync - newly_joined_or_invited_users: Set of users that have joined or - been invited to a room since previous sync. + newly_joined_or_invited_or_knocked_users: Set of users that have joined, + been invited to a room or are knocking on a room since + previous sync. newly_left_rooms: Set of rooms user has left since previous sync newly_left_users: Set of users that have left a room we're in since previous sync @@ -1163,7 +1197,9 @@ class SyncHandler: # We're going to mutate these fields, so lets copy them rather than # assume they won't get used later. - newly_joined_or_invited_users = set(newly_joined_or_invited_users) + newly_joined_or_invited_or_knocked_users = set( + newly_joined_or_invited_or_knocked_users + ) newly_left_users = set(newly_left_users) if since_token and since_token.device_list_key: @@ -1202,11 +1238,11 @@ class SyncHandler: # Step 1b, check for newly joined rooms for room_id in newly_joined_rooms: joined_users = await self.store.get_users_in_room(room_id) - newly_joined_or_invited_users.update(joined_users) + newly_joined_or_invited_or_knocked_users.update(joined_users) # TODO: Check that these users are actually new, i.e. either they # weren't in the previous sync *or* they left and rejoined. - users_that_have_changed.update(newly_joined_or_invited_users) + users_that_have_changed.update(newly_joined_or_invited_or_knocked_users) user_signatures_changed = ( await self.store.get_users_whose_signatures_changed( @@ -1452,6 +1488,7 @@ class SyncHandler: room_entries = room_changes.room_entries invited = room_changes.invited + knocked = room_changes.knocked newly_joined_rooms = room_changes.newly_joined_rooms newly_left_rooms = room_changes.newly_left_rooms @@ -1472,9 +1509,10 @@ class SyncHandler: await concurrently_execute(handle_room_entries, room_entries, 10) sync_result_builder.invited.extend(invited) + sync_result_builder.knocked.extend(knocked) - # Now we want to get any newly joined or invited users - newly_joined_or_invited_users = set() + # Now we want to get any newly joined, invited or knocking users + newly_joined_or_invited_or_knocked_users = set() newly_left_users = set() if since_token: for joined_sync in sync_result_builder.joined: @@ -1486,19 +1524,22 @@ class SyncHandler: if ( event.membership == Membership.JOIN or event.membership == Membership.INVITE + or event.membership == Membership.KNOCK ): - newly_joined_or_invited_users.add(event.state_key) + newly_joined_or_invited_or_knocked_users.add( + event.state_key + ) else: prev_content = event.unsigned.get("prev_content", {}) prev_membership = prev_content.get("membership", None) if prev_membership == Membership.JOIN: newly_left_users.add(event.state_key) - newly_left_users -= newly_joined_or_invited_users + newly_left_users -= newly_joined_or_invited_or_knocked_users return ( set(newly_joined_rooms), - newly_joined_or_invited_users, + newly_joined_or_invited_or_knocked_users, set(newly_left_rooms), newly_left_users, ) @@ -1553,6 +1594,7 @@ class SyncHandler: newly_left_rooms = [] room_entries = [] invited = [] + knocked = [] for room_id, events in mem_change_events_by_room_id.items(): logger.debug( "Membership changes in %s: [%s]", @@ -1632,9 +1674,17 @@ class SyncHandler: should_invite = non_joins[-1].membership == Membership.INVITE if should_invite: if event.sender not in ignored_users: - room_sync = InvitedSyncResult(room_id, invite=non_joins[-1]) - if room_sync: - invited.append(room_sync) + invite_room_sync = InvitedSyncResult(room_id, invite=non_joins[-1]) + if invite_room_sync: + invited.append(invite_room_sync) + + # Only bother if our latest membership in the room is knock (and we haven't + # been accepted/rejected in the meantime). + should_knock = non_joins[-1].membership == Membership.KNOCK + if should_knock: + knock_room_sync = KnockedSyncResult(room_id, knock=non_joins[-1]) + if knock_room_sync: + knocked.append(knock_room_sync) # Always include leave/ban events. Just take the last one. # TODO: How do we handle ban -> leave in same batch? @@ -1738,7 +1788,13 @@ class SyncHandler: ) room_entries.append(entry) - return _RoomChanges(room_entries, invited, newly_joined_rooms, newly_left_rooms) + return _RoomChanges( + room_entries, + invited, + knocked, + newly_joined_rooms, + newly_left_rooms, + ) async def _get_all_rooms( self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] @@ -1758,6 +1814,7 @@ class SyncHandler: membership_list = ( Membership.INVITE, + Membership.KNOCK, Membership.JOIN, Membership.LEAVE, Membership.BAN, @@ -1769,6 +1826,7 @@ class SyncHandler: room_entries = [] invited = [] + knocked = [] for event in room_list: if event.membership == Membership.JOIN: @@ -1788,8 +1846,11 @@ class SyncHandler: continue invite = await self.store.get_event(event.event_id) invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite)) + elif event.membership == Membership.KNOCK: + knock = await self.store.get_event(event.event_id) + knocked.append(KnockedSyncResult(room_id=event.room_id, knock=knock)) elif event.membership in (Membership.LEAVE, Membership.BAN): - # Always send down rooms we were banned or kicked from. + # Always send down rooms we were banned from or kicked from. if not sync_config.filter_collection.include_leave: if event.membership == Membership.LEAVE: if user_id == event.sender: @@ -1810,7 +1871,7 @@ class SyncHandler: ) ) - return _RoomChanges(room_entries, invited, [], []) + return _RoomChanges(room_entries, invited, knocked, [], []) async def _generate_room_entry( self, @@ -2101,6 +2162,7 @@ class SyncResultBuilder: account_data (list) joined (list[JoinedSyncResult]) invited (list[InvitedSyncResult]) + knocked (list[KnockedSyncResult]) archived (list[ArchivedSyncResult]) groups (GroupsSyncResult|None) to_device (list) @@ -2116,6 +2178,7 @@ class SyncResultBuilder: account_data = attr.ib(type=List[JsonDict], default=attr.Factory(list)) joined = attr.ib(type=List[JoinedSyncResult], default=attr.Factory(list)) invited = attr.ib(type=List[InvitedSyncResult], default=attr.Factory(list)) + knocked = attr.ib(type=List[KnockedSyncResult], default=attr.Factory(list)) archived = attr.ib(type=List[ArchivedSyncResult], default=attr.Factory(list)) groups = attr.ib(type=Optional[GroupsSyncResult], default=None) to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list)) |