summary | refs | log | tree | commit | diff
path: root/synapse/federation
diff options
context:
space:
mode:
author Andrej Shadura <andrewsh@debian.org> 2021-06-16 10:28:18 +0200
committer Andrej Shadura <andrewsh@debian.org> 2021-06-16 10:28:18 +0200
commit 219af4a8aef838c5e3689a2aa71cf72f2fd75aa2 (patch)
tree 3183d9a61335f862a9ddd3b3de2c804aaa93a6bf /synapse/federation
parent 396a9dfc77fc34b77d7ef552048f22ecb94e91ea (diff)
New upstream version 1.36.0
Diffstat (limited to 'synapse/federation')
-rw-r--r-- synapse/federation/federation_base.py 243
-rw-r--r-- synapse/federation/federation_client.py 147
-rw-r--r-- synapse/federation/transport/server.py 20
3 files changed, 184 insertions, 226 deletions
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 3fe496dc..c066617b 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -14,11 +14,6 @@
# limitations under the License.
import logging
from collections import namedtuple
-from typing import Iterable, List
-
-from twisted.internet import defer
-from twisted.internet.defer import Deferred, DeferredList
-from twisted.python.failure import Failure
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
@@ -28,11 +23,6 @@ from synapse.crypto.keyring import Keyring
from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event, validate_canonicaljson
from synapse.http.servlet import assert_params_in_dict
-from synapse.logging.context import (
- PreserveLoggingContext,
- current_context,
- make_deferred_yieldable,
-)
from synapse.types import JsonDict, get_domain_from_id
logger = logging.getLogger(__name__)
@@ -48,112 +38,82 @@ class FederationBase:
self.store = hs.get_datastore()
self._clock = hs.get_clock()
- def _check_sigs_and_hash(
+ async def _check_sigs_and_hash(
self, room_version: RoomVersion, pdu: EventBase
- ) -> Deferred:
- return make_deferred_yieldable(
- self._check_sigs_and_hashes(room_version, [pdu])[0]
- )
-
- def _check_sigs_and_hashes(
- self, room_version: RoomVersion, pdus: List[EventBase]
- ) -> List[Deferred]:
- """Checks that each of the received events is correctly signed by the
- sending server.
+ ) -> EventBase:
+ """Checks that the event is correctly signed by the sending server.
Args:
- room_version: The room version of the PDUs
- pdus: the events to be checked
+ room_version: The room version of the PDU
+ pdu: the event to be checked
Returns:
- For each input event, a deferred which:
- * returns the original event if the checks pass
- * returns a redacted version of the event (if the signature
+ * the original event if the checks pass
+ * a redacted version of the event (if the signature
matched but the hash did not)
- * throws a SynapseError if the signature check failed.
- The deferreds run their callbacks in the sentinel
- """
- deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
-
- ctx = current_context()
-
- @defer.inlineCallbacks
- def callback(_, pdu: EventBase):
- with PreserveLoggingContext(ctx):
- if not check_event_content_hash(pdu):
- # let's try to distinguish between failures because the event was
- # redacted (which are somewhat expected) vs actual ball-tampering
- # incidents.
- #
- # This is just a heuristic, so we just assume that if the keys are
- # about the same between the redacted and received events, then the
- # received event was probably a redacted copy (but we then use our
- # *actual* redacted copy to be on the safe side.)
- redacted_event = prune_event(pdu)
- if set(redacted_event.keys()) == set(pdu.keys()) and set(
- redacted_event.content.keys()
- ) == set(pdu.content.keys()):
- logger.info(
- "Event %s seems to have been redacted; using our redacted "
- "copy",
- pdu.event_id,
- )
- else:
- logger.warning(
- "Event %s content has been tampered, redacting",
- pdu.event_id,
- )
- return redacted_event
-
- result = yield defer.ensureDeferred(
- self.spam_checker.check_event_for_spam(pdu)
+ * throws a SynapseError if the signature check failed."""
+ try:
+ await _check_sigs_on_pdu(self.keyring, room_version, pdu)
+ except SynapseError as e:
+ logger.warning(
+ "Signature check failed for %s: %s",
+ pdu.event_id,
+ e,
+ )
+ raise
+
+ if not check_event_content_hash(pdu):
+ # let's try to distinguish between failures because the event was
+ # redacted (which are somewhat expected) vs actual ball-tampering
+ # incidents.
+ #
+ # This is just a heuristic, so we just assume that if the keys are
+ # about the same between the redacted and received events, then the
+ # received event was probably a redacted copy (but we then use our
+ # *actual* redacted copy to be on the safe side.)
+ redacted_event = prune_event(pdu)
+ if set(redacted_event.keys()) == set(pdu.keys()) and set(
+ redacted_event.content.keys()
+ ) == set(pdu.content.keys()):
+ logger.info(
+ "Event %s seems to have been redacted; using our redacted copy",
+ pdu.event_id,
)
-
- if result:
- logger.warning(
- "Event contains spam, redacting %s: %s",
- pdu.event_id,
- pdu.get_pdu_json(),
- )
- return prune_event(pdu)
-
- return pdu
-
- def errback(failure: Failure, pdu: EventBase):
- failure.trap(SynapseError)
- with PreserveLoggingContext(ctx):
+ else:
logger.warning(
- "Signature check failed for %s: %s",
+ "Event %s content has been tampered, redacting",
pdu.event_id,
- failure.getErrorMessage(),
)
- return failure
+ return redacted_event
- for deferred, pdu in zip(deferreds, pdus):
- deferred.addCallbacks(
- callback, errback, callbackArgs=[pdu], errbackArgs=[pdu]
+ result = await self.spam_checker.check_event_for_spam(pdu)
+
+ if result:
+ logger.warning(
+ "Event contains spam, redacting %s: %s",
+ pdu.event_id,
+ pdu.get_pdu_json(),
)
+ return prune_event(pdu)
- return deferreds
+ return pdu
class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
pass
-def _check_sigs_on_pdus(
- keyring: Keyring, room_version: RoomVersion, pdus: Iterable[EventBase]
-) -> List[Deferred]:
+async def _check_sigs_on_pdu(
+ keyring: Keyring, room_version: RoomVersion, pdu: EventBase
+) -> None:
"""Check that the given events are correctly signed
+ Raise a SynapseError if the event wasn't correctly signed.
+
Args:
keyring: keyring object to do the checks
room_version: the room version of the PDUs
pdus: the events to be checked
-
- Returns:
- A Deferred for each event in pdus, which will either succeed if
- the signatures are valid, or fail (with a SynapseError) if not.
"""
# we want to check that the event is signed by:
@@ -177,90 +137,47 @@ def _check_sigs_on_pdus(
# let's start by getting the domain for each pdu, and flattening the event back
# to JSON.
- pdus_to_check = [
- PduToCheckSig(
- pdu=p,
- sender_domain=get_domain_from_id(p.sender),
- deferreds=[],
- )
- for p in pdus
- ]
-
# First we check that the sender event is signed by the sender's domain
# (except if its a 3pid invite, in which case it may be sent by any server)
- pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
-
- more_deferreds = keyring.verify_events_for_server(
- [
- (
- p.sender_domain,
- p.pdu,
- p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
+ if not _is_invite_via_3pid(pdu):
+ try:
+ await keyring.verify_event_for_server(
+ get_domain_from_id(pdu.sender),
+ pdu,
+ pdu.origin_server_ts if room_version.enforce_key_validity else 0,
)
- for p in pdus_to_check_sender
- ]
- )
-
- def sender_err(e, pdu_to_check):
- errmsg = "event id %s: unable to verify signature for sender %s: %s" % (
- pdu_to_check.pdu.event_id,
- pdu_to_check.sender_domain,
- e.getErrorMessage(),
- )
- raise SynapseError(403, errmsg, Codes.FORBIDDEN)
-
- for p, d in zip(pdus_to_check_sender, more_deferreds):
- d.addErrback(sender_err, p)
- p.deferreds.append(d)
+ except Exception as e:
+ errmsg = "event id %s: unable to verify signature for sender %s: %s" % (
+ pdu.event_id,
+ get_domain_from_id(pdu.sender),
+ e,
+ )
+ raise SynapseError(403, errmsg, Codes.FORBIDDEN)
# now let's look for events where the sender's domain is different to the
# event id's domain (normally only the case for joins/leaves), and add additional
# checks. Only do this if the room version has a concept of event ID domain
# (ie, the room version uses old-style non-hash event IDs).
- if room_version.event_format == EventFormatVersions.V1:
- pdus_to_check_event_id = [
- p
- for p in pdus_to_check
- if p.sender_domain != get_domain_from_id(p.pdu.event_id)
- ]
-
- more_deferreds = keyring.verify_events_for_server(
- [
- (
- get_domain_from_id(p.pdu.event_id),
- p.pdu,
- p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
- )
- for p in pdus_to_check_event_id
- ]
- )
-
- def event_err(e, pdu_to_check):
+ if room_version.event_format == EventFormatVersions.V1 and get_domain_from_id(
+ pdu.event_id
+ ) != get_domain_from_id(pdu.sender):
+ try:
+ await keyring.verify_event_for_server(
+ get_domain_from_id(pdu.event_id),
+ pdu,
+ pdu.origin_server_ts if room_version.enforce_key_validity else 0,
+ )
+ except Exception as e:
errmsg = (
- "event id %s: unable to verify signature for event id domain: %s"
- % (pdu_to_check.pdu.event_id, e.getErrorMessage())
+ "event id %s: unable to verify signature for event id domain %s: %s"
+ % (
+ pdu.event_id,
+ get_domain_from_id(pdu.event_id),
+ e,
+ )
)
raise SynapseError(403, errmsg, Codes.FORBIDDEN)
- for p, d in zip(pdus_to_check_event_id, more_deferreds):
- d.addErrback(event_err, p)
- p.deferreds.append(d)
-
- # replace lists of deferreds with single Deferreds
- return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
-
-
-def _flatten_deferred_list(deferreds: List[Deferred]) -> Deferred:
- """Given a list of deferreds, either return the single deferred,
- combine into a DeferredList, or return an already resolved deferred.
- """
- if len(deferreds) > 1:
- return DeferredList(deferreds, fireOnOneErrback=True, consumeErrors=True)
- elif len(deferreds) == 1:
- return deferreds[0]
- else:
- return defer.succeed(None)
-
def _is_invite_via_3pid(event: EventBase) -> bool:
return (
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index e0e9f5d0..1076ebc0 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -21,6 +21,7 @@ from typing import (
Any,
Awaitable,
Callable,
+ Collection,
Dict,
Iterable,
List,
@@ -35,9 +36,6 @@ from typing import (
import attr
from prometheus_client import Counter
-from twisted.internet import defer
-from twisted.internet.defer import Deferred
-
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
CodeMessageException,
@@ -56,10 +54,9 @@ from synapse.api.room_versions import (
from synapse.events import EventBase, builder
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.transport.client import SendJoinResponse
-from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.logging.utils import log_function
from synapse.types import JsonDict, get_domain_from_id
-from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
@@ -360,10 +357,9 @@ class FederationClient(FederationBase):
async def _check_sigs_and_hash_and_fetch(
self,
origin: str,
- pdus: List[EventBase],
+ pdus: Collection[EventBase],
room_version: RoomVersion,
outlier: bool = False,
- include_none: bool = False,
) -> List[EventBase]:
"""Takes a list of PDUs and checks the signatures and hashes of each
one. If a PDU fails its signature check then we check if we have it in
@@ -380,57 +376,87 @@ class FederationClient(FederationBase):
pdu
room_version
outlier: Whether the events are outliers or not
- include_none: Whether to include None in the returned list
- for events that have failed their checks
Returns:
A list of PDUs that have valid signatures and hashes.
"""
- deferreds = self._check_sigs_and_hashes(room_version, pdus)
- async def handle_check_result(pdu: EventBase, deferred: Deferred):
- try:
- res = await make_deferred_yieldable(deferred)
- except SynapseError:
- res = None
+ # We limit how many PDUs we check at once, as if we try to do hundreds
+ # of thousands of PDUs at once we see large memory spikes.
- if not res:
- # Check local db.
- res = await self.store.get_event(
- pdu.event_id, allow_rejected=True, allow_none=True
- )
+ valid_pdus = []
- pdu_origin = get_domain_from_id(pdu.sender)
- if not res and pdu_origin != origin:
- try:
- res = await self.get_pdu(
- destinations=[pdu_origin],
- event_id=pdu.event_id,
- room_version=room_version,
- outlier=outlier,
- timeout=10000,
- )
- except SynapseError:
- pass
+ async def _execute(pdu: EventBase) -> None:
+ valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
+ pdu=pdu,
+ origin=origin,
+ outlier=outlier,
+ room_version=room_version,
+ )
- if not res:
- logger.warning(
- "Failed to find copy of %s with valid signature", pdu.event_id
- )
+ if valid_pdu:
+ valid_pdus.append(valid_pdu)
- return res
+ await concurrently_execute(_execute, pdus, 10000)
- handle = preserve_fn(handle_check_result)
- deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
+ return valid_pdus
- valid_pdus = await make_deferred_yieldable(
- defer.gatherResults(deferreds2, consumeErrors=True)
- ).addErrback(unwrapFirstError)
+ async def _check_sigs_and_hash_and_fetch_one(
+ self,
+ pdu: EventBase,
+ origin: str,
+ room_version: RoomVersion,
+ outlier: bool = False,
+ ) -> Optional[EventBase]:
+ """Takes a PDU and checks its signatures and hashes. If the PDU fails
+ its signature check then we check if we have it in the database and if
+ not then request it from the originating server of that PDU.
- if include_none:
- return valid_pdus
- else:
- return [p for p in valid_pdus if p]
+ If the PDU fails its content hash check then it is redacted.
+
+ Args:
+ origin
+ pdu
+ room_version
+ outlier: Whether the events are outliers or not
+ include_none: Whether to include None in the returned list
+ for events that have failed their checks
+
+ Returns:
+ The PDU (possibly redacted) if it has valid signatures and hashes.
+ """
+
+ res = None
+ try:
+ res = await self._check_sigs_and_hash(room_version, pdu)
+ except SynapseError:
+ pass
+
+ if not res:
+ # Check local db.
+ res = await self.store.get_event(
+ pdu.event_id, allow_rejected=True, allow_none=True
+ )
+
+ pdu_origin = get_domain_from_id(pdu.sender)
+ if not res and pdu_origin != origin:
+ try:
+ res = await self.get_pdu(
+ destinations=[pdu_origin],
+ event_id=pdu.event_id,
+ room_version=room_version,
+ outlier=outlier,
+ timeout=10000,
+ )
+ except SynapseError:
+ pass
+
+ if not res:
+ logger.warning(
+ "Failed to find copy of %s with valid signature", pdu.event_id
+ )
+
+ return res
async def get_event_auth(
self, destination: str, room_id: str, event_id: str
@@ -671,8 +697,6 @@ class FederationClient(FederationBase):
state = response.state
auth_chain = response.auth_events
- pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}
-
create_event = None
for e in state:
if (e.type, e.state_key) == (EventTypes.Create, ""):
@@ -696,14 +720,29 @@ class FederationClient(FederationBase):
% (create_room_version,)
)
- valid_pdus = await self._check_sigs_and_hash_and_fetch(
- destination,
- list(pdus.values()),
- outlier=True,
- room_version=room_version,
+ logger.info(
+ "Processing from send_join %d events", len(state) + len(auth_chain)
)
- valid_pdus_map = {p.event_id: p for p in valid_pdus}
+ # We now go and check the signatures and hashes for the events. Note
+ # that we limit how many events we process at a time to keep the
+ # memory overhead from exploding.
+ valid_pdus_map: Dict[str, EventBase] = {}
+
+ async def _execute(pdu: EventBase) -> None:
+ valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
+ pdu=pdu,
+ origin=destination,
+ outlier=True,
+ room_version=room_version,
+ )
+
+ if valid_pdu:
+ valid_pdus_map[valid_pdu.event_id] = valid_pdu
+
+ await concurrently_execute(
+ _execute, itertools.chain(state, auth_chain), 10000
+ )
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 40eab455..5756fcb5 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -37,6 +37,7 @@ from synapse.http.servlet import (
)
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
+ SynapseTags,
start_active_span,
start_active_span_from_request,
tags,
@@ -151,7 +152,9 @@ class Authenticator:
)
await self.keyring.verify_json_for_server(
- origin, json_request, now, "Incoming request"
+ origin,
+ json_request,
+ now,
)
logger.debug("Request from %s", origin)
@@ -314,7 +317,7 @@ class BaseFederationServlet:
raise
request_tags = {
- "request_id": request.get_request_id(),
+ SynapseTags.REQUEST_ID: request.get_request_id(),
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
tags.HTTP_METHOD: request.get_method(),
tags.HTTP_URL: request.get_redacted_uri(),
@@ -1562,13 +1565,12 @@ def register_servlets(
server_name=hs.hostname,
).register(resource)
- if hs.config.experimental.spaces_enabled:
- FederationSpaceSummaryServlet(
- handler=hs.get_space_summary_handler(),
- authenticator=authenticator,
- ratelimiter=ratelimiter,
- server_name=hs.hostname,
- ).register(resource)
+ FederationSpaceSummaryServlet(
+ handler=hs.get_space_summary_handler(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
if "openid" in servlet_groups:
for servletclass in OPENID_SERVLET_CLASSES: