summaryrefslogtreecommitdiff
path: root/synapse/federation
diff options
context:
space:
mode:
authorAndrej Shadura <andrewsh@debian.org>2021-08-15 10:52:27 +0100
committerAndrej Shadura <andrewsh@debian.org>2021-08-15 10:52:27 +0100
commita48716699a33ad533b4b6d088449e4bbc4528e38 (patch)
tree0807e24466a1b4044870b2f85bd703a1673b79a1 /synapse/federation
parent679ff900f5e9b83af346904d7c8604cc5917608d (diff)
New upstream version 1.40.0
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_base.py28
-rw-r--r--synapse/federation/federation_client.py104
-rw-r--r--synapse/federation/federation_server.py58
-rw-r--r--synapse/federation/transport/client.py529
-rw-r--r--synapse/federation/transport/server.py13
5 files changed, 497 insertions, 235 deletions
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 2bfe6a3d..024e440f 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -178,6 +178,34 @@ async def _check_sigs_on_pdu(
)
raise SynapseError(403, errmsg, Codes.FORBIDDEN)
+ # If this is a join event for a restricted room it may have been authorised
+ # via a different server from the sending server. Check those signatures.
+ if (
+ room_version.msc3083_join_rules
+ and pdu.type == EventTypes.Member
+ and pdu.membership == Membership.JOIN
+ and "join_authorised_via_users_server" in pdu.content
+ ):
+ authorising_server = get_domain_from_id(
+ pdu.content["join_authorised_via_users_server"]
+ )
+ try:
+ await keyring.verify_event_for_server(
+ authorising_server,
+ pdu,
+ pdu.origin_server_ts if room_version.enforce_key_validity else 0,
+ )
+ except Exception as e:
+ errmsg = (
+ "event id %s: unable to verify signature for authorising server %s: %s"
+ % (
+ pdu.event_id,
+ authorising_server,
+ e,
+ )
+ )
+ raise SynapseError(403, errmsg, Codes.FORBIDDEN)
+
def _is_invite_via_3pid(event: EventBase) -> bool:
return (
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index c767d306..b7a10da1 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -19,10 +19,10 @@ import itertools
import logging
from typing import (
TYPE_CHECKING,
- Any,
Awaitable,
Callable,
Collection,
+ Container,
Dict,
Iterable,
List,
@@ -79,7 +79,15 @@ class InvalidResponseError(RuntimeError):
we couldn't parse
"""
- pass
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class SendJoinResult:
+ # The event to persist.
+ event: EventBase
+ # A string giving the server the event was sent to.
+ origin: str
+ state: List[EventBase]
+ auth_chain: List[EventBase]
class FederationClient(FederationBase):
@@ -506,6 +514,7 @@ class FederationClient(FederationBase):
description: str,
destinations: Iterable[str],
callback: Callable[[str], Awaitable[T]],
+ failover_errcodes: Optional[Container[str]] = None,
failover_on_unknown_endpoint: bool = False,
) -> T:
"""Try an operation on a series of servers, until it succeeds
@@ -526,6 +535,9 @@ class FederationClient(FederationBase):
next server tried. Normally the stacktrace is logged but this is
suppressed if the exception is an InvalidResponseError.
+ failover_errcodes: Error codes (specific to this endpoint) which should
+ cause a failover when received as part of an HTTP 400 error.
+
failover_on_unknown_endpoint: if True, we will try other servers if it looks
like a server doesn't support the endpoint. This is typically useful
if the endpoint in question is new or experimental.
@@ -537,6 +549,9 @@ class FederationClient(FederationBase):
SynapseError if the chosen remote server returns a 300/400 code, or
no servers were reachable.
"""
+ if failover_errcodes is None:
+ failover_errcodes = ()
+
for destination in destinations:
if destination == self.server_name:
continue
@@ -551,11 +566,17 @@ class FederationClient(FederationBase):
synapse_error = e.to_synapse_error()
failover = False
- # Failover on an internal server error, or if the destination
- # doesn't implemented the endpoint for some reason.
+ # Failover should occur:
+ #
+ # * On internal server errors.
+ # * If the destination responds that it cannot complete the request.
+ # * If the destination doesn't implemented the endpoint for some reason.
if 500 <= e.code < 600:
failover = True
+ elif e.code == 400 and synapse_error.errcode in failover_errcodes:
+ failover = True
+
elif failover_on_unknown_endpoint and self._is_unknown_endpoint(
e, synapse_error
):
@@ -671,13 +692,25 @@ class FederationClient(FederationBase):
return destination, ev, room_version
+ # MSC3083 defines additional error codes for room joins. Unfortunately
+ # we do not yet know the room version, assume these will only be returned
+ # by valid room versions.
+ failover_errcodes = (
+ (Codes.UNABLE_AUTHORISE_JOIN, Codes.UNABLE_TO_GRANT_JOIN)
+ if membership == Membership.JOIN
+ else None
+ )
+
return await self._try_destination_list(
- "make_" + membership, destinations, send_request
+ "make_" + membership,
+ destinations,
+ send_request,
+ failover_errcodes=failover_errcodes,
)
async def send_join(
self, destinations: Iterable[str], pdu: EventBase, room_version: RoomVersion
- ) -> Dict[str, Any]:
+ ) -> SendJoinResult:
"""Sends a join event to one of a list of homeservers.
Doing so will cause the remote server to add the event to the graph,
@@ -691,18 +724,38 @@ class FederationClient(FederationBase):
did the make_join)
Returns:
- a dict with members ``origin`` (a string
- giving the server the event was sent to, ``state`` (?) and
- ``auth_chain``.
+ The result of the send join request.
Raises:
SynapseError: if the chosen remote server returns a 300/400 code, or
no servers successfully handle the request.
"""
- async def send_request(destination) -> Dict[str, Any]:
+ async def send_request(destination) -> SendJoinResult:
response = await self._do_send_join(room_version, destination, pdu)
+ # If an event was returned (and expected to be returned):
+ #
+ # * Ensure it has the same event ID (note that the event ID is a hash
+ # of the event fields for versions which support MSC3083).
+ # * Ensure the signatures are good.
+ #
+ # Otherwise, fallback to the provided event.
+ if room_version.msc3083_join_rules and response.event:
+ event = response.event
+
+ valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
+ pdu=event,
+ origin=destination,
+ outlier=True,
+ room_version=room_version,
+ )
+
+ if valid_pdu is None or event.event_id != pdu.event_id:
+ raise InvalidResponseError("Returned an invalid join event")
+ else:
+ event = pdu
+
state = response.state
auth_chain = response.auth_events
@@ -784,13 +837,32 @@ class FederationClient(FederationBase):
% (auth_chain_create_events,)
)
- return {
- "state": signed_state,
- "auth_chain": signed_auth,
- "origin": destination,
- }
+ return SendJoinResult(
+ event=event,
+ state=signed_state,
+ auth_chain=signed_auth,
+ origin=destination,
+ )
- return await self._try_destination_list("send_join", destinations, send_request)
+ # MSC3083 defines additional error codes for room joins.
+ failover_errcodes = None
+ if room_version.msc3083_join_rules:
+ failover_errcodes = (
+ Codes.UNABLE_AUTHORISE_JOIN,
+ Codes.UNABLE_TO_GRANT_JOIN,
+ )
+
+ # If the join is being authorised via allow rules, we need to send
+ # the /send_join back to the same server that was originally used
+ # with /make_join.
+ if "join_authorised_via_users_server" in pdu.content:
+ destinations = [
+ get_domain_from_id(pdu.content["join_authorised_via_users_server"])
+ ]
+
+ return await self._try_destination_list(
+ "send_join", destinations, send_request, failover_errcodes=failover_errcodes
+ )
async def _do_send_join(
self, room_version: RoomVersion, destination: str, pdu: EventBase
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 29619aee..145b9161 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -45,6 +45,7 @@ from synapse.api.errors import (
UnsupportedRoomVersionError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
+from synapse.crypto.event_signing import compute_event_signature
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
@@ -64,7 +65,7 @@ from synapse.replication.http.federation import (
ReplicationGetQueryRestServlet,
)
from synapse.storage.databases.main.lock import Lock
-from synapse.types import JsonDict
+from synapse.types import JsonDict, get_domain_from_id
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
@@ -586,7 +587,7 @@ class FederationServer(FederationBase):
async def on_send_join_request(
self, origin: str, content: JsonDict, room_id: str
) -> Dict[str, Any]:
- context = await self._on_send_membership_event(
+ event, context = await self._on_send_membership_event(
origin, content, Membership.JOIN, room_id
)
@@ -597,6 +598,7 @@ class FederationServer(FederationBase):
time_now = self._clock.time_msec()
return {
+ "org.matrix.msc3083.v2.event": event.get_pdu_json(),
"state": [p.get_pdu_json(time_now) for p in state.values()],
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
}
@@ -681,7 +683,7 @@ class FederationServer(FederationBase):
Returns:
The stripped room state.
"""
- event_context = await self._on_send_membership_event(
+ _, context = await self._on_send_membership_event(
origin, content, Membership.KNOCK, room_id
)
@@ -690,14 +692,14 @@ class FederationServer(FederationBase):
# related to the room while the knock request is pending.
stripped_room_state = (
await self.store.get_stripped_room_state_from_event_context(
- event_context, self._room_prejoin_state_types
+ context, self._room_prejoin_state_types
)
)
return {"knock_state_events": stripped_room_state}
async def _on_send_membership_event(
self, origin: str, content: JsonDict, membership_type: str, room_id: str
- ) -> EventContext:
+ ) -> Tuple[EventBase, EventContext]:
"""Handle an on_send_{join,leave,knock} request
Does some preliminary validation before passing the request on to the
@@ -712,7 +714,7 @@ class FederationServer(FederationBase):
in the event
Returns:
- The context of the event after inserting it into the room graph.
+ The event and context of the event after inserting it into the room graph.
Raises:
SynapseError if there is a problem with the request, including things like
@@ -748,6 +750,33 @@ class FederationServer(FederationBase):
logger.debug("_on_send_membership_event: pdu sigs: %s", event.signatures)
+ # Sign the event since we're vouching on behalf of the remote server that
+ # the event is valid to be sent into the room. Currently this is only done
+ # if the user is being joined via restricted join rules.
+ if (
+ room_version.msc3083_join_rules
+ and event.membership == Membership.JOIN
+ and "join_authorised_via_users_server" in event.content
+ ):
+ # We can only authorise our own users.
+ authorising_server = get_domain_from_id(
+ event.content["join_authorised_via_users_server"]
+ )
+ if authorising_server != self.server_name:
+ raise SynapseError(
+ 400,
+ f"Cannot authorise request from resident server: {authorising_server}",
+ )
+
+ event.signatures.update(
+ compute_event_signature(
+ room_version,
+ event.get_pdu_json(),
+ self.hs.hostname,
+ self.hs.signing_key,
+ )
+ )
+
event = await self._check_sigs_and_hash(room_version, event)
return await self.handler.on_send_membership_event(origin, event)
@@ -995,6 +1024,23 @@ class FederationServer(FederationBase):
origin, event = next
+ # Prune the event queue if it's getting large.
+ #
+ # We do this *after* handling the first event as the common case is
+ # that the queue is empty (/has the single event in), and so there's
+ # no need to do this check.
+ pruned = await self.store.prune_staged_events_in_room(room_id, room_version)
+ if pruned:
+ # If we have pruned the queue check we need to refetch the next
+ # event to handle.
+ next = await self.store.get_next_staged_event_for_room(
+ room_id, room_version
+ )
+ if not next:
+ break
+
+ origin, event = next
+
lock = await self.store.try_acquire_lock(
_INBOUND_EVENT_HANDLING_LOCK_NAME, room_id
)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 98b1bf77..6a8d3ad4 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -15,7 +15,7 @@
import logging
import urllib
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
import attr
import ijson
@@ -29,6 +29,7 @@ from synapse.api.urls import (
FEDERATION_V2_PREFIX,
)
from synapse.events import EventBase, make_event_from_dict
+from synapse.federation.units import Transaction
from synapse.http.matrixfederationclient import ByteParser
from synapse.logging.utils import log_function
from synapse.types import JsonDict
@@ -49,23 +50,25 @@ class TransportLayerClient:
self.client = hs.get_federation_http_client()
@log_function
- def get_room_state_ids(self, destination, room_id, event_id):
+ async def get_room_state_ids(
+ self, destination: str, room_id: str, event_id: str
+ ) -> JsonDict:
"""Requests all state for a given room from the given server at the
given event. Returns the state's event_id's
Args:
- destination (str): The host name of the remote homeserver we want
+ destination: The host name of the remote homeserver we want
to get the state from.
- context (str): The name of the context we want the state of
- event_id (str): The event we want the context at.
+ context: The name of the context we want the state of
+ event_id: The event we want the context at.
Returns:
- Awaitable: Results in a dict received from the remote homeserver.
+ Results in a dict received from the remote homeserver.
"""
logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id)
path = _create_v1_path("/state_ids/%s", room_id)
- return self.client.get_json(
+ return await self.client.get_json(
destination,
path=path,
args={"event_id": event_id},
@@ -73,39 +76,43 @@ class TransportLayerClient:
)
@log_function
- def get_event(self, destination, event_id, timeout=None):
+ async def get_event(
+ self, destination: str, event_id: str, timeout: Optional[int] = None
+ ) -> JsonDict:
"""Requests the pdu with give id and origin from the given server.
Args:
- destination (str): The host name of the remote homeserver we want
+ destination: The host name of the remote homeserver we want
to get the state from.
- event_id (str): The id of the event being requested.
- timeout (int): How long to try (in ms) the destination for before
+ event_id: The id of the event being requested.
+ timeout: How long to try (in ms) the destination for before
giving up. None indicates no timeout.
Returns:
- Awaitable: Results in a dict received from the remote homeserver.
+ Results in a dict received from the remote homeserver.
"""
logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id)
path = _create_v1_path("/event/%s", event_id)
- return self.client.get_json(
+ return await self.client.get_json(
destination, path=path, timeout=timeout, try_trailing_slash_on_400=True
)
@log_function
- def backfill(self, destination, room_id, event_tuples, limit):
+ async def backfill(
+ self, destination: str, room_id: str, event_tuples: Iterable[str], limit: int
+ ) -> Optional[JsonDict]:
"""Requests `limit` previous PDUs in a given context before list of
PDUs.
Args:
- dest (str)
- room_id (str)
- event_tuples (list)
- limit (int)
+ destination
+ room_id
+ event_tuples
+ limit
Returns:
- Awaitable: Results in a dict received from the remote homeserver.
+ Results in a dict received from the remote homeserver.
"""
logger.debug(
"backfill dest=%s, room_id=%s, event_tuples=%r, limit=%s",
@@ -117,18 +124,22 @@ class TransportLayerClient:
if not event_tuples:
# TODO: raise?
- return
+ return None
path = _create_v1_path("/backfill/%s", room_id)
args = {"v": event_tuples, "limit": [str(limit)]}
- return self.client.get_json(
+ return await self.client.get_json(
destination, path=path, args=args, try_trailing_slash_on_400=True
)
@log_function
- async def send_transaction(self, transaction, json_data_callback=None):
+ async def send_transaction(
+ self,
+ transaction: Transaction,
+ json_data_callback: Optional[Callable[[], JsonDict]] = None,
+ ) -> JsonDict:
"""Sends the given Transaction to its destination
Args:
@@ -149,21 +160,21 @@ class TransportLayerClient:
"""
logger.debug(
"send_data dest=%s, txid=%s",
- transaction.destination,
- transaction.transaction_id,
+ transaction.destination, # type: ignore
+ transaction.transaction_id, # type: ignore
)
- if transaction.destination == self.server_name:
+ if transaction.destination == self.server_name: # type: ignore
raise RuntimeError("Transport layer cannot send to itself!")
# FIXME: This is only used by the tests. The actual json sent is
# generated by the json_data_callback.
json_data = transaction.get_dict()
- path = _create_v1_path("/send/%s", transaction.transaction_id)
+ path = _create_v1_path("/send/%s", transaction.transaction_id) # type: ignore
- response = await self.client.put_json(
- transaction.destination,
+ return await self.client.put_json(
+ transaction.destination, # type: ignore
path=path,
data=json_data,
json_data_callback=json_data_callback,
@@ -172,8 +183,6 @@ class TransportLayerClient:
try_trailing_slash_on_400=True,
)
- return response
-
@log_function
async def make_query(
self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False
@@ -193,8 +202,13 @@ class TransportLayerClient:
@log_function
async def make_membership_event(
- self, destination, room_id, user_id, membership, params
- ):
+ self,
+ destination: str,
+ room_id: str,
+ user_id: str,
+ membership: str,
+ params: Optional[Mapping[str, Union[str, Iterable[str]]]],
+ ) -> JsonDict:
"""Asks a remote server to build and sign us a membership event
Note that this does not append any events to any graphs.
@@ -240,7 +254,7 @@ class TransportLayerClient:
ignore_backoff = True
retry_on_dns_fail = True
- content = await self.client.get_json(
+ return await self.client.get_json(
destination=destination,
path=path,
args=params,
@@ -249,20 +263,18 @@ class TransportLayerClient:
ignore_backoff=ignore_backoff,
)
- return content
-
@log_function
async def send_join_v1(
self,
- room_version,
- destination,
- room_id,
- event_id,
- content,
+ room_version: RoomVersion,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ content: JsonDict,
) -> "SendJoinResponse":
path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
- response = await self.client.put_json(
+ return await self.client.put_json(
destination=destination,
path=path,
data=content,
@@ -270,15 +282,18 @@ class TransportLayerClient:
max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
)
- return response
-
@log_function
async def send_join_v2(
- self, room_version, destination, room_id, event_id, content
+ self,
+ room_version: RoomVersion,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ content: JsonDict,
) -> "SendJoinResponse":
path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
- response = await self.client.put_json(
+ return await self.client.put_json(
destination=destination,
path=path,
data=content,
@@ -286,13 +301,13 @@ class TransportLayerClient:
max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
)
- return response
-
@log_function
- async def send_leave_v1(self, destination, room_id, event_id, content):
+ async def send_leave_v1(
+ self, destination: str, room_id: str, event_id: str, content: JsonDict
+ ) -> Tuple[int, JsonDict]:
path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)
- response = await self.client.put_json(
+ return await self.client.put_json(
destination=destination,
path=path,
data=content,
@@ -303,13 +318,13 @@ class TransportLayerClient:
ignore_backoff=True,
)
- return response
-
@log_function
- async def send_leave_v2(self, destination, room_id, event_id, content):
+ async def send_leave_v2(
+ self, destination: str, room_id: str, event_id: str, content: JsonDict
+ ) -> JsonDict:
path = _create_v2_path("/send_leave/%s/%s", room_id, event_id)
- response = await self.client.put_json(
+ return await self.client.put_json(
destination=destination,
path=path,
data=content,
@@ -320,8 +335,6 @@ class TransportLayerClient:
ignore_backoff=True,
)
- return response
-
@log_function
async def send_knock_v1(
self,
@@ -357,25 +370,25 @@ class TransportLayerClient:
)
@log_function
- async def send_invite_v1(self, destination, room_id, event_id, content):
+ async def send_invite_v1(
+ self, destination: str, room_id: str, event_id: str, content: JsonDict
+ ) -> Tuple[int, JsonDict]:
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
- response = await self.client.put_json(
+ return await self.client.put_json(
destination=destination, path=path, data=content, ignore_backoff=True
)
- return response
-
@log_function
- async def send_invite_v2(self, destination, room_id, event_id, content):
+ async def send_invite_v2(
+ self, destination: str, room_id: str, event_id: str, content: JsonDict
+ ) -> JsonDict:
path = _create_v2_path("/invite/%s/%s", room_id, event_id)
- response = await self.client.put_json(
+ return await self.client.put_json(
destination=destination, path=path, data=content, ignore_backoff=True
)
- return response
-
@log_function
async def get_public_rooms(
self,
@@ -385,7 +398,7 @@ class TransportLayerClient:
search_filter: Optional[Dict] = None,
include_all_networks: bool = False,
third_party_instance_id: Optional[str] = None,
- ):
+ ) -> JsonDict:
"""Get the list of public rooms from a remote homeserver
See synapse.federation.federation_client.FederationClient.get_public_rooms for
@@ -450,25 +463,27 @@ class TransportLayerClient:
return response
@log_function
- async def exchange_third_party_invite(self, destination, room_id, event_dict):
+ async def exchange_third_party_invite(
+ self, destination: str, room_id: str, event_dict: JsonDict
+ ) -> JsonDict:
path = _create_v1_path("/exchange_third_party_invite/%s", room_id)
- response = await self.client.put_json(
+ return await self.client.put_json(
destination=destination, path=path, data=event_dict
)
- return response
-
@log_function
- async def get_event_auth(self, destination, room_id, event_id):
+ async def get_event_auth(
+ self, destination: str, room_id: str, event_id: str
+ ) -> JsonDict:
path = _create_v1_path("/event_auth/%s/%s", room_id, event_id)
- content = await self.client.get_json(destination=destination, path=path)
-
- return content
+ return await self.client.get_json(destination=destination, path=path)
@log_function
- async def query_client_keys(self, destination, query_content, timeout):
+ async def query_client_keys(
+ self, destination: str, query_content: JsonDict, timeout: int
+ ) -> JsonDict:
"""Query the device keys for a list of user ids hosted on a remote
server.
@@ -496,20 +511,21 @@ class TransportLayerClient:
}
Args:
- destination(str): The server to query.
- query_content(dict): The user ids to query.
+ destination: The server to query.
+ query_content: The user ids to query.
Returns:
A dict containing device and cross-signing keys.
"""
path = _create_v1_path("/user/keys/query")
- content = await self.client.post_json(
+ return await self.client.post_json(
destination=destination, path=path, data=query_content, timeout=timeout
)
- return content
@log_function
- async def query_user_devices(self, destination, user_id, timeout):
+ async def query_user_devices(
+ self, destination: str, user_id: str, timeout: int
+ ) -> JsonDict:
"""Query the devices for a user id hosted on a remote server.
Response:
@@ -535,20 +551,21 @@ class TransportLayerClient:
}
Args:
- destination(str): The server to query.
- query_content(dict): The user ids to query.
+ destination: The server to query.
+ query_content: The user ids to query.
Returns:
A dict containing device and cross-signing keys.
"""
path = _create_v1_path("/user/devices/%s", user_id)
- content = await self.client.get_json(
+ return await self.client.get_json(
destination=destination, path=path, timeout=timeout
)
- return content
@log_function
- async def claim_client_keys(self, destination, query_content, timeout):
+ async def claim_client_keys(
+ self, destination: str, query_content: JsonDict, timeout: int
+ ) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server.
Request:
@@ -572,33 +589,32 @@ class TransportLayerClient:
}
Args:
- destination(str): The server to query.
- query_content(dict): The user ids to query.
+ destination: The server to query.
+ query_content: The user ids to query.
Returns:
A dict containing the one-time keys.
"""
path = _create_v1_path("/user/keys/claim")
- content = await self.client.post_json(
+ return await self.client.post_json(
destination=destination, path=path, data=query_content, timeout=timeout
)
- return content
@log_function
async def get_missing_events(
self,
- destination,
- room_id,
- earliest_events,
- latest_events,
- limit,
- min_depth,
- timeout,
- ):
+ destination: str,
+ room_id: str,
+ earliest_events: Iterable[str],
+ latest_events: Iterable[str],
+ limit: int,
+ min_depth: int,
+ timeout: int,
+ ) -> JsonDict:
path = _create_v1_path("/get_missing_events/%s", room_id)
- content = await self.client.post_json(
+ return await self.client.post_json(
destination=destination,
path=path,
data={
@@ -610,14 +626,14 @@ class TransportLayerClient:
timeout=timeout,
)
- return content
-
@log_function
- def get_group_profile(self, destination, group_id, requester_user_id):
+ async def get_group_profile(
+ self, destination: str, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get a group profile"""
path = _create_v1_path("/groups/%s/profile", group_id)
- return self.client.get_json(
+ return await self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -625,14 +641,16 @@ class TransportLayerClient:
)
@log_function
- def update_group_profile(self, destination, group_id, requester_user_id, content):
+ async def update_group_profile(
+ self, destination: str, group_id: str, requester_user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Update a remote group profile
Args:
- destination (str)
- group_id (str)
- requester_user_id (str)
- content (dict): The new profile of the group
+ destination
+ group_id
+ requester_user_id
+ content: The new profile of the group
"""
path = _create_v1_path("/groups/%s/profile", group_id)
@@ -645,11 +663,13 @@ class TransportLayerClient:
)
@log_function
- def get_group_summary(self, destination, group_id, requester_user_id):
+ async def get_group_summary(
+ self, destination: str, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get a group summary"""
path = _create_v1_path("/groups/%s/summary", group_id)
- return self.client.get_json(
+ return await self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -657,24 +677,31 @@ class TransportLayerClient:
)
@log_function
- def get_rooms_in_group(self, destination, group_id, requester_user_id):
+ async def get_rooms_in_group(
+ self, destination: str, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get all rooms in a group"""
path = _create_v1_path("/groups/%s/rooms", group_id)
- return self.client.get_json(
+ return await self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
ignore_backoff=True,
)
- def add_room_to_group(
- self, destination, group_id, requester_user_id, room_id, content
- ):
+ async def add_room_to_group(
+ self,
+ destination: str,
+ group_id: str,
+ requester_user_id: str,
+ room_id: str,
+ content: JsonDict,
+ ) -> JsonDict:
"""Add a room to a group"""
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -682,15 +709,21 @@ class TransportLayerClient:
ignore_backoff=True,
)
- def update_room_in_group(
- self, destination, group_id, requester_user_id, room_id, config_key, content
- ):
+ async def update_room_in_group(
+ self,
+ destination: str,
+ group_id: str,
+ requester_user_id: str,
+ room_id: str,
+ config_key: str,
+ content: JsonDict,
+ ) -> JsonDict:
"""Update room in group"""
path = _create_v1_path(
"/groups/%s/room/%s/config/%s", group_id, room_id, config_key
)
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -698,11 +731,13 @@ class TransportLayerClient:
ignore_backoff=True,
)
- def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
+ async def remove_room_from_group(
+ self, destination: str, group_id: str, requester_user_id: str, room_id: str
+ ) -> JsonDict:
"""Remove a room from a group"""
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
- return self.client.delete_json(
+ return await self.client.delete_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -710,11 +745,13 @@ class TransportLayerClient:
)
@log_function
- def get_users_in_group(self, destination, group_id, requester_user_id):
+ async def get_users_in_group(
+ self, destination: str, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get users in a group"""
path = _create_v1_path("/groups/%s/users", group_id)
- return self.client.get_json(
+ return await self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -722,11 +759,13 @@ class TransportLayerClient:
)
@log_function
- def get_invited_users_in_group(self, destination, group_id, requester_user_id):
+ async def get_invited_users_in_group(
+ self, destination: str, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get users that have been invited to a group"""
path = _create_v1_path("/groups/%s/invited_users", group_id)
- return self.client.get_json(
+ return await self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -734,16 +773,20 @@ class TransportLayerClient:
)
@log_function
- def accept_group_invite(self, destination, group_id, user_id, content):
+ async def accept_group_invite(
+ self, destination: str, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Accept a group invite"""
path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id)
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination, path=path, data=content, ignore_backoff=True
)
@log_function
- def join_group(self, destination, group_id, user_id, content):
+ def join_group(
+ self, destination: str, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Attempts to join a group"""
path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id)
@@ -752,13 +795,18 @@ class TransportLayerClient:
)
@log_function
- def invite_to_group(
- self, destination, group_id, user_id, requester_user_id, content
- ):
+ async def invite_to_group(
+ self,
+ destination: str,
+ group_id: str,
+ user_id: str,
+ requester_user_id: str,
+ content: JsonDict,
+ ) -> JsonDict:
"""Invite a user to a group"""
path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id)
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -767,25 +815,32 @@ class TransportLayerClient:
)
@log_function
- def invite_to_group_notification(self, destination, group_id, user_id, content):
+ async def invite_to_group_notification(
+ self, destination: str, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Sent by group server to inform a user's server that they have been
invited.
"""
path = _create_v1_path("/groups/local/%s/users/%s/invite", group_id, user_id)
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination, path=path, data=content, ignore_backoff=True
)
@log_function
- def remove_user_from_group(
- self, destination, group_id, requester_user_id, user_id, content
- ):
+ async def remove_user_from_group(
+ self,
+ destination: str,
+ group_id: str,
+ requester_user_id: str,
+ user_id: str,
+ content: JsonDict,
+ ) -> JsonDict:
"""Remove a user from a group"""
path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -794,35 +849,43 @@ class TransportLayerClient:
)
@log_function
- def remove_user_from_group_notification(
- self, destination, group_id, user_id, content
- ):
+ async def remove_user_from_group_notification(
+ self, destination: str, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Sent by group server to inform a user's server that they have been
kicked from the group.
"""
path = _create_v1_path("/groups/local/%s/users/%s/remove", group_id, user_id)
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination, path=path, data=content, ignore_backoff=True
)
@log_function
- def renew_group_attestation(self, destination, group_id, user_id, content):
+ async def renew_group_attestation(
+ self, destination: str, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Sent by either a group server or a user's server to periodically update
the attestations
"""
path = _create_v1_path("/groups/%s/renew_attestation/%s", group_id, user_id)
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination, path=path, data=content, ignore_backoff=True
)
@log_function
- def update_group_summary_room(
- self, destination, group_id, user_id, room_id, category_id, content
- ):
+ async def update_group_summary_room(
+ self,
+ destination: str,
+ group_id: str,
+ user_id: str,
+ room_id: str,
+ category_id: str,
+ content: JsonDict,
+ ) -> JsonDict:
"""Update a room entry in a group summary"""
if category_id:
path = _create_v1_path(
@@ -834,7 +897,7 @@ class TransportLayerClient:
else:
path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id)
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination,
path=path,
args={"requester_user_id": user_id},
@@ -843,9 +906,14 @@ class TransportLayerClient:
)
@log_function
- def delete_group_summary_room(
- self, destination, group_id, user_id, room_id, category_id
- ):
+ async def delete_group_summary_room(
+ self,
+ destination: str,
+ group_id: str,
+ user_id: str,
+ room_id: str,
+ category_id: str,
+ ) -> JsonDict:
"""Delete a room entry in a group summary"""
if category_id:
path = _create_v1_path(
@@ -857,7 +925,7 @@ class TransportLayerClient:
else:
path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id)
- return self.client.delete_json(
+ return await self.client.delete_json(
destination=destination,
path=path,
args={"requester_user_id": user_id},
@@ -865,11 +933,13 @@ class TransportLayerClient:
)
@log_function
- def get_group_categories(self, destination, group_id, requester_user_id):
+ async def get_group_categories(
+ self, destination: str, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get all categories in a group"""
path = _create_v1_path("/groups/%s/categories", group_id)
- return self.client.get_json(
+ return await self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -877,11 +947,13 @@ class TransportLayerClient:
)
@log_function
- def get_group_category(self, destination, group_id, requester_user_id, category_id):
+ async def get_group_category(
+ self, destination: str, group_id: str, requester_user_id: str, category_id: str
+ ) -> JsonDict:
"""Get category info in a group"""
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
- return self.client.get_json(
+ return await self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -889,13 +961,18 @@ class TransportLayerClient:
)
@log_function
- def update_group_category(
- self, destination, group_id, requester_user_id, category_id, content
- ):
+ async def update_group_category(
+ self,
+ destination: str,
+ group_id: str,
+ requester_user_id: str,
+ category_id: str,
+ content: JsonDict,
+ ) -> JsonDict:
"""Update a category in a group"""
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -904,13 +981,13 @@ class TransportLayerClient:
)
@log_function
- def delete_group_category(
- self, destination, group_id, requester_user_id, category_id
- ):
+ async def delete_group_category(
+ self, destination: str, group_id: str, requester_user_id: str, category_id: str
+ ) -> JsonDict:
"""Delete a category in a group"""
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
- return self.client.delete_json(
+ return await self.client.delete_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -918,11 +995,13 @@ class TransportLayerClient:
)
@log_function
- def get_group_roles(self, destination, group_id, requester_user_id):
+ async def get_group_roles(
+ self, destination: str, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get all roles in a group"""
path = _create_v1_path("/groups/%s/roles", group_id)
- return self.client.get_json(
+ return await self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -930,11 +1009,13 @@ class TransportLayerClient:
)
@log_function
- def get_group_role(self, destination, group_id, requester_user_id, role_id):
+ async def get_group_role(
+ self, destination: str, group_id: str, requester_user_id: str, role_id: str
+ ) -> JsonDict:
"""Get a roles info"""
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
- return self.client.get_json(
+ return await self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -942,13 +1023,18 @@ class TransportLayerClient:
)
@log_function
- def update_group_role(
- self, destination, group_id, requester_user_id, role_id, content
- ):
+ async def update_group_role(
+ self,
+ destination: str,
+ group_id: str,
+ requester_user_id: str,
+ role_id: str,
+ content: JsonDict,
+ ) -> JsonDict:
"""Update a role in a group"""
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -957,11 +1043,13 @@ class TransportLayerClient:
)
@log_function
- def delete_group_role(self, destination, group_id, requester_user_id, role_id):
+ async def delete_group_role(
+ self, destination: str, group_id: str, requester_user_id: str, role_id: str
+ ) -> JsonDict:
"""Delete a role in a group"""
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
- return self.client.delete_json(
+ return await self.client.delete_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -969,9 +1057,15 @@ class TransportLayerClient:
)
@log_function
- def update_group_summary_user(
- self, destination, group_id, requester_user_id, user_id, role_id, content
- ):
+ async def update_group_summary_user(
+ self,
+ destination: str,
+ group_id: str,
+ requester_user_id: str,
+ user_id: str,
+ role_id: str,
+ content: JsonDict,
+ ) -> JsonDict:
"""Update a users entry in a group"""
if role_id:
path = _create_v1_path(
@@ -980,7 +1074,7 @@ class TransportLayerClient:
else:
path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id)
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -989,11 +1083,13 @@ class TransportLayerClient:
)
@log_function
- def set_group_join_policy(self, destination, group_id, requester_user_id, content):
+ async def set_group_join_policy(
+ self, destination: str, group_id: str, requester_user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Sets the join policy for a group"""
path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id)
- return self.client.put_json(
+ return await self.client.put_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -1002,9 +1098,14 @@ class TransportLayerClient:
)
@log_function
- def delete_group_summary_user(
- self, destination, group_id, requester_user_id, user_id, role_id
- ):
+ async def delete_group_summary_user(
+ self,
+ destination: str,
+ group_id: str,
+ requester_user_id: str,
+ user_id: str,
+ role_id: str,
+ ) -> JsonDict:
"""Delete a users entry in a group"""
if role_id:
path = _create_v1_path(
@@ -1013,33 +1114,35 @@ class TransportLayerClient:
else:
path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id)
- return self.client.delete_json(
+ return await self.client.delete_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
ignore_backoff=True,
)
- def bulk_get_publicised_groups(self, destination, user_ids):
+ async def bulk_get_publicised_groups(
+ self, destination: str, user_ids: Iterable[str]
+ ) -> JsonDict:
"""Get the groups a list of users are publicising"""
path = _create_v1_path("/get_groups_publicised")
content = {"user_ids": user_ids}
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination, path=path, data=content, ignore_backoff=True
)
- def get_room_complexity(self, destination, room_id):
+ async def get_room_complexity(self, destination: str, room_id: str) -> JsonDict:
"""
Args:
- destination (str): The remote server
- room_id (str): The room ID to ask about.
+ destination: The remote server
+ room_id: The room ID to ask about.
"""
path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/rooms/%s/complexity", room_id)
- return self.client.get_json(destination=destination, path=path)
+ return await self.client.get_json(destination=destination, path=path)
async def get_space_summary(
self,
@@ -1075,14 +1178,14 @@ class TransportLayerClient:
)
-def _create_path(federation_prefix, path, *args):
+def _create_path(federation_prefix: str, path: str, *args: str) -> str:
"""
Ensures that all args are url encoded.
"""
return federation_prefix + path % tuple(urllib.parse.quote(arg, "") for arg in args)
-def _create_v1_path(path, *args):
+def _create_v1_path(path: str, *args: str) -> str:
"""Creates a path against V1 federation API from the path template and
args. Ensures that all args are url encoded.
@@ -1091,16 +1194,13 @@ def _create_v1_path(path, *args):
_create_v1_path("/event/%s", event_id)
Args:
- path (str): String template for the path
- args: ([str]): Args to insert into path. Each arg will be url encoded
-
- Returns:
- str
+ path: String template for the path
+ args: Args to insert into path. Each arg will be url encoded
"""
return _create_path(FEDERATION_V1_PREFIX, path, *args)
-def _create_v2_path(path, *args):
+def _create_v2_path(path: str, *args: str) -> str:
"""Creates a path against V2 federation API from the path template and
args. Ensures that all args are url encoded.
@@ -1109,11 +1209,8 @@ def _create_v2_path(path, *args):
_create_v2_path("/event/%s", event_id)
Args:
- path (str): String template for the path
- args: ([str]): Args to insert into path. Each arg will be url encoded
-
- Returns:
- str
+ path: String template for the path
+ args: Args to insert into path. Each arg will be url encoded
"""
return _create_path(FEDERATION_V2_PREFIX, path, *args)
@@ -1122,8 +1219,26 @@ def _create_v2_path(path, *args):
class SendJoinResponse:
"""The parsed response of a `/send_join` request."""
+ # The list of auth events from the /send_join response.
auth_events: List[EventBase]
+ # The list of state from the /send_join response.
state: List[EventBase]
+ # The raw join event from the /send_join response.
+ event_dict: JsonDict
+ # The parsed join event from the /send_join response. This will be None if
+ # "event" is not included in the response.
+ event: Optional[EventBase] = None
+
+
+@ijson.coroutine
+def _event_parser(event_dict: JsonDict):
+ """Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
+ to add them to a given dictionary.
+ """
+
+ while True:
+ key, value = yield
+ event_dict[key] = value
@ijson.coroutine
@@ -1149,7 +1264,8 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
CONTENT_TYPE = "application/json"
def __init__(self, room_version: RoomVersion, v1_api: bool):
- self._response = SendJoinResponse([], [])
+ self._response = SendJoinResponse([], [], {})
+ self._room_version = room_version
# The V1 API has the shape of `[200, {...}]`, which we handle by
# prefixing with `item.*`.
@@ -1163,12 +1279,21 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
_event_list_parser(room_version, self._response.auth_events),
prefix + "auth_chain.item",
)
+ self._coro_event = ijson.kvitems_coro(
+ _event_parser(self._response.event_dict),
+ prefix + "org.matrix.msc3083.v2.event",
+ )
def write(self, data: bytes) -> int:
self._coro_state.send(data)
self._coro_auth.send(data)
+ self._coro_event.send(data)
return len(data)
def finish(self) -> SendJoinResponse:
+ if self._response.event_dict:
+ self._response.event = make_event_from_dict(
+ self._response.event_dict, self._room_version
+ )
return self._response
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 2974d4d0..5e059d6e 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -984,7 +984,7 @@ class PublicRoomList(BaseFederationServlet):
limit = parse_integer_from_args(query, "limit", 0)
since_token = parse_string_from_args(query, "since", None)
include_all_networks = parse_boolean_from_args(
- query, "include_all_networks", False
+ query, "include_all_networks", default=False
)
third_party_instance_id = parse_string_from_args(
query, "third_party_instance_id", None
@@ -1908,16 +1908,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet):
suggested_only = parse_boolean_from_args(query, "suggested_only", default=False)
max_rooms_per_space = parse_integer_from_args(query, "max_rooms_per_space")
- exclude_rooms = []
- if b"exclude_rooms" in query:
- try:
- exclude_rooms = [
- room_id.decode("ascii") for room_id in query[b"exclude_rooms"]
- ]
- except Exception:
- raise SynapseError(
- 400, "Bad query parameter for exclude_rooms", Codes.INVALID_PARAM
- )
+ exclude_rooms = parse_strings_from_args(query, "exclude_rooms", default=[])
return 200, await self.handler.federation_space_summary(
origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms