author    Andrej Shadura <andrewsh@debian.org>  2022-06-19 15:20:00 +0200
committer Andrej Shadura <andrewsh@debian.org>  2022-06-19 15:20:00 +0200
commit    6dc64c92c6991f09910f3e6db368e6eeb4b1981e (patch)
tree      d8bab73ee460e0a96bbda9c5988d8025dbbe2eb3 /synapse/handlers/federation_event.py
parent    c2d3cd76c24f663449bfa209ac920305f0501d3a (diff)

New upstream version 1.61.0
Diffstat (limited to 'synapse/handlers/federation_event.py')
-rw-r--r--  synapse/handlers/federation_event.py  272
1 file changed, 158 insertions(+), 114 deletions(-)
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 6cf927e4..87a06083 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -30,6 +30,7 @@ from typing import (
from prometheus_client import Counter
+from synapse import event_auth
from synapse.api.constants import (
EventContentFields,
EventTypes,
@@ -63,6 +64,7 @@ from synapse.replication.http.federation import (
)
from synapse.state import StateResolutionStore
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.storage.state import StateFilter
from synapse.types import (
PersistedEventPosition,
RoomStreamToken,
@@ -96,14 +98,14 @@ class FederationEventHandler:
def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastores().main
- self._storage = hs.get_storage()
- self._state_store = self._storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
self._state_handler = hs.get_state_handler()
self._event_creation_handler = hs.get_event_creation_handler()
self._event_auth_handler = hs.get_event_auth_handler()
self._message_handler = hs.get_message_handler()
- self._action_generator = hs.get_action_generator()
+ self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
self._state_resolution_handler = hs.get_state_resolution_handler()
# avoid a circular dependency by deferring execution here
self._get_room_member_handler = hs.get_room_member_handler
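This hunk belongs to the 1.61 storage-controllers refactor: direct `hs.get_storage()` access and the standalone action generator give way to the controller hierarchy and the bulk push rule evaluator. A minimal sketch of the new wiring, assuming a live HomeServer instance `hs`; only the getter names visible in this hunk come from the source:

    # Hedged sketch; needs a real HomeServer `hs`, so illustrative only.
    storage_controllers = hs.get_storage_controllers()   # was: hs.get_storage()
    state_controller = storage_controllers.state         # was: storage.state
    push_evaluator = hs.get_bulk_push_rule_evaluator()   # was: hs.get_action_generator()
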
@@ -272,7 +274,7 @@ class FederationEventHandler:
affected=pdu.event_id,
)
- await self._process_received_pdu(origin, pdu, state=None)
+ await self._process_received_pdu(origin, pdu, state_ids=None)
async def on_send_membership_event(
self, origin: str, event: EventBase
@@ -461,7 +463,9 @@ class FederationEventHandler:
with nested_logging_context(suffix=event.event_id):
context = await self._state_handler.compute_event_context(
event,
- old_state=state,
+ state_ids_before_event={
+ (e.type, e.state_key): e.event_id for e in state
+ },
partial_state=partial_state,
)
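`compute_event_context` now takes `state_ids_before_event`, a `StateMap[str]` keyed by `(event type, state_key)` pairs, instead of a list of full events. A runnable toy model of the conversion this hunk performs; the `Event` NamedTuple is only a stand-in for Synapse's EventBase:

    from typing import Dict, List, NamedTuple, Tuple

    StateMap = Dict[Tuple[str, str], str]  # (event type, state_key) -> event_id

    class Event(NamedTuple):  # toy stand-in for EventBase
        type: str
        state_key: str
        event_id: str

    def to_state_ids(state: List[Event]) -> StateMap:
        # Same shape as the dict comprehension in the hunk above.
        return {(e.type, e.state_key): e.event_id for e in state}

    state = [
        Event("m.room.create", "", "$create"),
        Event("m.room.member", "@alice:hs", "$join"),
    ]
    assert to_state_ids(state)[("m.room.member", "@alice:hs")] == "$join"
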
@@ -475,7 +479,23 @@ class FederationEventHandler:
# and discover that we do not have it.
event.internal_metadata.proactively_send = False
- return await self.persist_events_and_notify(room_id, [(event, context)])
+ stream_id_after_persist = await self.persist_events_and_notify(
+ room_id, [(event, context)]
+ )
+
+ # If we're joining the room again, check if there is new marker
+ # state indicating that there is new history imported somewhere in
+ # the DAG. Multiple markers can exist in the current state with
+ # unique state_keys.
+ #
+ # Do this after the state from the remote join was persisted (via
+ # `persist_events_and_notify`). Otherwise we can run into a
+ # situation where the create event doesn't exist yet in the
+ # `current_state_events`
+ for e in state:
+ await self._handle_marker_event(origin, e)
+
+ return stream_id_after_persist
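The ordering in this hunk is deliberate: marker handling reads tables (such as `current_state_events`) that `persist_events_and_notify` populates, so markers are only scanned once the join state is persisted. A toy sketch of the control flow, with `persist` and `handle_marker` as hypothetical stand-ins for the real coroutines:

    async def persist_then_handle_markers(persist, handle_marker, room_id, pair, state):
        # Persist first so the create event exists in current_state_events
        # before any marker is processed.
        stream_id = await persist(room_id, [pair])
        for state_event in state:
            # Multiple markers may coexist under distinct state_keys.
            await handle_marker(state_event)
        return stream_id
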
async def update_state_for_partial_state_event(
self, destination: str, event: EventBase
@@ -485,6 +505,9 @@ class FederationEventHandler:
Args:
destination: server to request full state from
event: partial-state event to be de-partial-stated
+
+ Raises:
+ FederationError if we fail to request state from the remote server.
"""
logger.info("Updating state for %s", event.event_id)
with nested_logging_context(suffix=event.event_id):
@@ -494,12 +517,12 @@ class FederationEventHandler:
#
# This is the same operation as we do when we receive a regular event
# over federation.
- state = await self._resolve_state_at_missing_prevs(destination, event)
+ state_ids = await self._resolve_state_at_missing_prevs(destination, event)
# build a new state group for it if need be
context = await self._state_handler.compute_event_context(
event,
- old_state=state,
+ state_ids_before_event=state_ids,
)
if context.partial_state:
# this can happen if some or all of the event's prev_events still have
@@ -515,7 +538,9 @@ class FederationEventHandler:
)
return
await self._store.update_state_for_partial_state_event(event, context)
- self._state_store.notify_event_un_partial_stated(event.event_id)
+ self._state_storage_controller.notify_event_un_partial_stated(
+ event.event_id
+ )
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Collection[str]
@@ -749,11 +774,12 @@ class FederationEventHandler:
return
try:
- state = await self._resolve_state_at_missing_prevs(origin, event)
+ state_ids = await self._resolve_state_at_missing_prevs(origin, event)
# TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
# not return partial state
+
await self._process_received_pdu(
- origin, event, state=state, backfilled=backfilled
+ origin, event, state_ids=state_ids, backfilled=backfilled
)
except FederationError as e:
if e.code == 403:
@@ -763,7 +789,7 @@ class FederationEventHandler:
async def _resolve_state_at_missing_prevs(
self, dest: str, event: EventBase
- ) -> Optional[Iterable[EventBase]]:
+ ) -> Optional[StateMap[str]]:
"""Calculate the state at an event with missing prev_events.
This is used when we have pulled a batch of events from a remote server, and
@@ -790,8 +816,12 @@ class FederationEventHandler:
event: an event to check for missing prevs.
Returns:
- if we already had all the prev events, `None`. Otherwise, returns a list of
- the events in the state at `event`.
+ if we already had all the prev events, `None`. Otherwise, returns
+ the event ids of the state at `event`.
+
+ Raises:
+ FederationError if we fail to get the state from the remote server after any
+ missing `prev_event`s.
"""
room_id = event.room_id
event_id = event.event_id
@@ -811,10 +841,12 @@ class FederationEventHandler:
)
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
- event_map = {event_id: event}
+
try:
# Get the state of the events we know about
- ours = await self._state_store.get_state_groups_ids(room_id, seen)
+ ours = await self._state_storage_controller.get_state_groups_ids(
+ room_id, seen
+ )
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps: List[StateMap[str]] = list(ours.values())
@@ -831,40 +863,23 @@ class FederationEventHandler:
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
- remote_state = await self._get_state_after_missing_prev_event(
- dest, room_id, p
+ remote_state_map = (
+ await self._get_state_ids_after_missing_prev_event(
+ dest, room_id, p
+ )
)
- remote_state_map = {
- (x.type, x.state_key): x.event_id for x in remote_state
- }
state_maps.append(remote_state_map)
- for x in remote_state:
- event_map[x.event_id] = x
-
room_version = await self._store.get_room_version_id(room_id)
state_map = await self._state_resolution_handler.resolve_events_with_store(
room_id,
room_version,
state_maps,
- event_map,
+ event_map={event_id: event},
state_res_store=StateResolutionStore(self._store),
)
- # We need to give _process_received_pdu the actual state events
- # rather than event ids, so generate that now.
-
- # First though we need to fetch all the events that are in
- # state_map, so we can build up the state below.
- evs = await self._store.get_events(
- list(state_map.values()),
- get_prev_content=False,
- redact_behaviour=EventRedactBehaviour.as_is,
- )
- event_map.update(evs)
-
- state = [event_map[e] for e in state_map.values()]
except Exception:
logger.warning(
"Error attempting to resolve state at missing prev_events",
@@ -876,14 +891,14 @@ class FederationEventHandler:
"We can't get valid state history.",
affected=event_id,
)
- return state
+ return state_map
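The rewritten `_resolve_state_at_missing_prevs` works purely on `StateMap[str]` values: the state groups of the prev events we have, plus remote-supplied maps for the ones we lack, are fed to the state resolver. A runnable toy illustration of that input shape; the trivial "last map wins" merge below merely stands in for `resolve_events_with_store`, which implements full state resolution v2:

    from typing import Dict, List, Tuple

    StateMap = Dict[Tuple[str, str], str]

    def naive_resolve(state_maps: List[StateMap]) -> StateMap:
        # Toy stand-in: later maps win on conflict. Real state res v2
        # weighs auth chains and is considerably subtler.
        resolved: StateMap = {}
        for sm in state_maps:
            resolved.update(sm)
        return resolved

    ours = {("m.room.member", "@alice:hs"): "$ban"}
    remote = {("m.room.member", "@alice:hs"): "$join", ("m.room.name", ""): "$name"}
    print(naive_resolve([ours, remote]))  # remote's entries win in this toy
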
- async def _get_state_after_missing_prev_event(
+ async def _get_state_ids_after_missing_prev_event(
self,
destination: str,
room_id: str,
event_id: str,
- ) -> List[EventBase]:
+ ) -> StateMap[str]:
"""Requests all of the room state at a given event from a remote homeserver.
Args:
@@ -892,7 +907,11 @@ class FederationEventHandler:
event_id: The id of the event we want the state at.
Returns:
- A list of events in the state, including the event itself
+ The event ids of the state *after* the given event.
+
+ Raises:
+ InvalidResponseError: if the remote homeserver's response contains fields
+ of the wrong type.
"""
(
state_event_ids,
@@ -907,19 +926,17 @@ class FederationEventHandler:
len(auth_event_ids),
)
- # start by just trying to fetch the events from the store
+ # Start by checking events we already have in the DB
desired_events = set(state_event_ids)
desired_events.add(event_id)
logger.debug("Fetching %i events from cache/store", len(desired_events))
- fetched_events = await self._store.get_events(
- desired_events, allow_rejected=True
- )
+ have_events = await self._store.have_seen_events(room_id, desired_events)
- missing_desired_events = desired_events - fetched_events.keys()
+ missing_desired_events = desired_events - have_events
logger.debug(
"We are missing %i events (got %i)",
len(missing_desired_events),
- len(fetched_events),
+ len(have_events),
)
# We probably won't need most of the auth events, so let's just check which
@@ -930,7 +947,7 @@ class FederationEventHandler:
# already have a bunch of the state events. It would be nice if the
# federation api gave us a way of finding out which we actually need.
- missing_auth_events = set(auth_event_ids) - fetched_events.keys()
+ missing_auth_events = set(auth_event_ids) - have_events
missing_auth_events.difference_update(
await self._store.have_seen_events(room_id, missing_auth_events)
)
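Swapping `get_events` for `have_seen_events` means only event IDs are pulled from the store, and the missing sets then fall out of plain set arithmetic. A runnable sketch with a faked store; `have_seen` mimics the subset of IDs the DB already knows:

    desired_events = {"$state1", "$state2", "$prev"}
    auth_event_ids = {"$create", "$power", "$state1"}
    have_seen = {"$state1", "$prev"}  # pretend DB contents

    missing_desired = desired_events - have_seen
    missing_auth = auth_event_ids - have_seen
    print(sorted(missing_desired))  # ['$state2']
    print(sorted(missing_auth))     # ['$create', '$power']
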
@@ -956,47 +973,51 @@ class FederationEventHandler:
destination=destination, room_id=room_id, event_ids=missing_events
)
- # we need to make sure we re-load from the database to get the rejected
- # state correct.
- fetched_events.update(
- await self._store.get_events(missing_desired_events, allow_rejected=True)
- )
-
- # check for events which were in the wrong room.
- #
- # this can happen if a remote server claims that the state or
- # auth_events at an event in room A are actually events in room B
+ # We now need to fill out the state map, which involves fetching the
+ # type and state key for each event ID in the state.
+ state_map = {}
- bad_events = [
- (event_id, event.room_id)
- for event_id, event in fetched_events.items()
- if event.room_id != room_id
- ]
+ event_metadata = await self._store.get_metadata_for_events(state_event_ids)
+ for state_event_id, metadata in event_metadata.items():
+ if metadata.room_id != room_id:
+ # This is a bogus situation, but since we may only discover it a long time
+ # after it happened, we try our best to carry on, by just omitting the
+ # bad events from the returned state set.
+ #
+ # This can happen if a remote server claims that the state or
+ # auth_events at an event in room A are actually events in room B
+ logger.warning(
+ "Remote server %s claims event %s in room %s is an auth/state "
+ "event in room %s",
+ destination,
+ state_event_id,
+ metadata.room_id,
+ room_id,
+ )
+ continue
- for bad_event_id, bad_room_id in bad_events:
- # This is a bogus situation, but since we may only discover it a long time
- # after it happened, we try our best to carry on, by just omitting the
- # bad events from the returned state set.
- logger.warning(
- "Remote server %s claims event %s in room %s is an auth/state "
- "event in room %s",
- destination,
- bad_event_id,
- bad_room_id,
- room_id,
- )
+ if metadata.state_key is None:
+ logger.warning(
+ "Remote server gave us non-state event in state: %s", state_event_id
+ )
+ continue
- del fetched_events[bad_event_id]
+ state_map[(metadata.event_type, metadata.state_key)] = state_event_id
# if we couldn't get the prev event in question, that's a problem.
- remote_event = fetched_events.get(event_id)
+ remote_event = await self._store.get_event(
+ event_id,
+ allow_none=True,
+ allow_rejected=True,
+ redact_behaviour=EventRedactBehaviour.as_is,
+ )
if not remote_event:
raise Exception("Unable to get missing prev_event %s" % (event_id,))
# missing state at that event is a warning, not a blocker
# XXX: this doesn't sound right? it means that we'll end up with incomplete
# state.
- failed_to_fetch = desired_events - fetched_events.keys()
+ failed_to_fetch = desired_events - event_metadata.keys()
if failed_to_fetch:
logger.warning(
"Failed to fetch missing state events for %s %s",
@@ -1004,14 +1025,12 @@ class FederationEventHandler:
failed_to_fetch,
)
- remote_state = [
- fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
- ]
-
if remote_event.is_state() and remote_event.rejected_reason is None:
- remote_state.append(remote_event)
+ state_map[
+ (remote_event.type, remote_event.state_key)
+ ] = remote_event.event_id
- return remote_state
+ return state_map
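Taken together, the hunks above assemble the returned `StateMap[str]` from lightweight metadata: entries claiming the wrong room or lacking a state_key are dropped with a warning, and the prev event itself is appended if it is unrejected state. A runnable toy version of that filter; `Metadata` is a stand-in for what `get_metadata_for_events` returns:

    from typing import Dict, NamedTuple, Optional, Tuple

    class Metadata(NamedTuple):  # toy stand-in for the store's event metadata
        room_id: str
        event_type: str
        state_key: Optional[str]

    def build_state_map(room_id: str, metadata: Dict[str, Metadata]) -> Dict[Tuple[str, str], str]:
        state_map: Dict[Tuple[str, str], str] = {}
        for event_id, m in metadata.items():
            if m.room_id != room_id:
                continue  # bogus cross-room claim: omit, as the warning path above does
            if m.state_key is None:
                continue  # non-state event smuggled into the state list
            state_map[(m.event_type, m.state_key)] = event_id
        return state_map

    meta = {
        "$good": Metadata("!room:hs", "m.room.name", ""),
        "$wrong_room": Metadata("!other:hs", "m.room.name", ""),
        "$not_state": Metadata("!room:hs", "m.room.message", None),
    }
    assert build_state_map("!room:hs", meta) == {("m.room.name", ""): "$good"}
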
async def _get_state_and_persist(
self, destination: str, room_id: str, event_id: str
@@ -1038,7 +1057,7 @@ class FederationEventHandler:
self,
origin: str,
event: EventBase,
- state: Optional[Iterable[EventBase]],
+ state_ids: Optional[StateMap[str]],
backfilled: bool = False,
) -> None:
"""Called when we have a new non-outlier event.
@@ -1060,7 +1079,7 @@ class FederationEventHandler:
event: event to be persisted
- state: Normally None, but if we are handling a gap in the graph
+ state_ids: Normally None, but if we are handling a gap in the graph
(ie, we are missing one or more prev_events), the resolved state at the
event
@@ -1072,7 +1091,8 @@ class FederationEventHandler:
try:
context = await self._state_handler.compute_event_context(
- event, old_state=state
+ event,
+ state_ids_before_event=state_ids,
)
context = await self._check_event_auth(
origin,
@@ -1089,7 +1109,7 @@ class FederationEventHandler:
# For new (non-backfilled and non-outlier) events we check if the event
# passes auth based on the current state. If it doesn't then we
# "soft-fail" the event.
- await self._check_for_soft_fail(event, state, origin=origin)
+ await self._check_for_soft_fail(event, state_ids, origin=origin)
await self._run_push_actions_and_persist_event(event, context, backfilled)
@@ -1228,6 +1248,14 @@ class FederationEventHandler:
# Nothing to retrieve then (invalid marker)
return
+ already_seen_insertion_event = await self._store.have_seen_event(
+ marker_event.room_id, insertion_event_id
+ )
+ if already_seen_insertion_event:
+ # No need to process a marker again if we have already seen the
+ # insertion event that it was pointing to
+ return
+
logger.debug(
"_handle_marker_event: backfilling insertion event %s", insertion_event_id
)
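The added `have_seen_event` check makes marker handling idempotent: once the insertion event a marker points at has been backfilled, re-processing the marker is a no-op. A runnable toy of the guard pattern; the `seen` set stands in for the store lookup:

    import asyncio

    async def handle_marker(insertion_event_id: str, seen: set) -> bool:
        if insertion_event_id in seen:
            return False  # insertion event already seen: skip the marker
        # ... backfill the insertion event here ...
        seen.add(insertion_event_id)
        return True

    seen: set = set()
    assert asyncio.run(handle_marker("$insertion", seen)) is True
    assert asyncio.run(handle_marker("$insertion", seen)) is False  # deduped
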
@@ -1423,7 +1451,7 @@ class FederationEventHandler:
# we're not bothering about room state, so flag the event as an outlier.
event.internal_metadata.outlier = True
- context = EventContext.for_outlier()
+ context = EventContext.for_outlier(self._storage_controllers)
try:
validate_event_for_room_version(room_version_obj, event)
check_auth_rules_for_event(room_version_obj, event, auth)
@@ -1500,7 +1528,11 @@ class FederationEventHandler:
return context
# now check auth against what we think the auth events *should* be.
- prev_state_ids = await context.get_prev_state_ids()
+ event_types = event_auth.auth_types_for_event(event.room_version, event)
+ prev_state_ids = await context.get_prev_state_ids(
+ StateFilter.from_types(event_types)
+ )
+
auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=True
)
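Rather than materialising the full state before the event, the auth check now requests only the `(type, state_key)` pairs that can influence auth, via `StateFilter.from_types(auth_types_for_event(...))`. A runnable toy model of the narrowing; `filter_state` approximates what the real StateFilter restricts a state fetch to:

    from typing import Dict, Iterable, Tuple

    StateKey = Tuple[str, str]

    def filter_state(state: Dict[StateKey, str], types: Iterable[StateKey]) -> Dict[StateKey, str]:
        wanted = set(types)  # keep only auth-relevant keys
        return {k: v for k, v in state.items() if k in wanted}

    full_state = {
        ("m.room.create", ""): "$create",
        ("m.room.power_levels", ""): "$pl",
        ("m.room.topic", ""): "$topic",  # irrelevant to auth
    }
    auth_types = [("m.room.create", ""), ("m.room.power_levels", "")]
    assert ("m.room.topic", "") not in filter_state(full_state, auth_types)
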
@@ -1552,14 +1584,16 @@ class FederationEventHandler:
if guest_access == GuestAccess.CAN_JOIN:
return
- current_state_map = await self._state_handler.get_current_state(event.room_id)
- current_state = list(current_state_map.values())
- await self._get_room_member_handler().kick_guest_users(current_state)
+ current_state = await self._storage_controllers.state.get_current_state(
+ event.room_id
+ )
+ current_state_list = list(current_state.values())
+ await self._get_room_member_handler().kick_guest_users(current_state_list)
async def _check_for_soft_fail(
self,
event: EventBase,
- state: Optional[Iterable[EventBase]],
+ state_ids: Optional[StateMap[str]],
origin: str,
) -> None:
"""Checks if we should soft fail the event; if so, marks the event as
@@ -1567,7 +1601,7 @@ class FederationEventHandler:
Args:
event
- state: The state at the event if we don't have all the event's prev events
+ state_ids: The state at the event if we don't have all the event's prev events
origin: The host the event originates from.
"""
extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
@@ -1582,8 +1616,11 @@ class FederationEventHandler:
room_version = await self._store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+ # The event types we want to pull from the "current" state.
+ auth_types = auth_types_for_event(room_version_obj, event)
+
# Calculate the "current state".
- if state is not None:
+ if state_ids is not None:
# If we're explicitly given the state then we won't have all the
# prev events, and so we have a gap in the graph. In this case
# we want to be a little careful as we might have been down for
@@ -1596,20 +1633,25 @@ class FederationEventHandler:
# given state at the event. This should correctly handle cases
# like bans, especially with state res v2.
- state_sets_d = await self._state_store.get_state_groups(
+ state_sets_d = await self._state_storage_controller.get_state_groups_ids(
event.room_id, extrem_ids
)
- state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
- state_sets.append(state)
- current_states = await self._state_handler.resolve_events(
- room_version, state_sets, event
+ state_sets: List[StateMap[str]] = list(state_sets_d.values())
+ state_sets.append(state_ids)
+ current_state_ids = (
+ await self._state_resolution_handler.resolve_events_with_store(
+ event.room_id,
+ room_version,
+ state_sets,
+ event_map=None,
+ state_res_store=StateResolutionStore(self._store),
+ )
)
- current_state_ids: StateMap[str] = {
- k: e.event_id for k, e in current_states.items()
- }
else:
- current_state_ids = await self._state_handler.get_current_state_ids(
- event.room_id, latest_event_ids=extrem_ids
+ current_state_ids = (
+ await self._state_storage_controller.get_current_state_ids(
+ event.room_id, StateFilter.from_types(auth_types)
+ )
)
logger.debug(
@@ -1619,7 +1661,6 @@ class FederationEventHandler:
)
# Now check if event pass auth against said current state
- auth_types = auth_types_for_event(room_version_obj, event)
current_state_ids_list = [
e for k, e in current_state_ids.items() if k in auth_types
]
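With a gap in the graph, the "current state" is obtained by resolving the extremities' state groups together with the caller-supplied `state_ids`; the event is then auth-checked against only the auth-relevant subset of the result. A runnable sketch of that final filtering step, using made-up event IDs:

    current_state_ids = {
        ("m.room.power_levels", ""): "$pl",
        ("m.room.member", "@alice:hs"): "$member",
        ("m.room.topic", ""): "$topic",
    }
    auth_types = {("m.room.power_levels", ""), ("m.room.member", "@alice:hs")}
    current_state_ids_list = [
        e for k, e in current_state_ids.items() if k in auth_types
    ]
    print(current_state_ids_list)  # ['$pl', '$member']
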
@@ -1865,7 +1906,7 @@ class FederationEventHandler:
# create a new state group as a delta from the existing one.
prev_group = context.state_group
- state_group = await self._state_store.store_state_group(
+ state_group = await self._state_storage_controller.store_state_group(
event.event_id,
event.room_id,
prev_group=prev_group,
@@ -1874,10 +1915,10 @@ class FederationEventHandler:
)
return EventContext.with_state(
+ storage=self._storage_controllers,
state_group=state_group,
state_group_before_event=context.state_group_before_event,
- current_state_ids=current_state_ids,
- prev_state_ids=prev_state_ids,
+ state_delta_due_to_event=state_updates,
prev_group=prev_group,
delta_ids=state_updates,
partial_state=context.partial_state,
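`EventContext.with_state` now records the delta the event itself causes (`state_delta_due_to_event`) plus a pointer to the prior state group, instead of carrying full current/prev state maps; the full maps can be recomputed on demand from group storage. A toy sketch of that recovery, assuming dict-shaped state maps:

    # Hedged sketch: state after the event = state before it + the event's delta.
    prev_state = {("m.room.member", "@alice:hs"): "$old_membership"}
    state_delta_due_to_event = {("m.room.member", "@alice:hs"): "$new_membership"}
    current_state = {**prev_state, **state_delta_due_to_event}
    assert current_state[("m.room.member", "@alice:hs")] == "$new_membership"
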
@@ -1913,7 +1954,7 @@ class FederationEventHandler:
min_depth,
)
else:
- await self._action_generator.handle_push_actions_for_event(
+ await self._bulk_push_rule_evaluator.action_for_event_by_user(
event, context
)
@@ -1964,11 +2005,14 @@ class FederationEventHandler:
)
return result["max_stream_id"]
else:
- assert self._storage.persistence
+ assert self._storage_controllers.persistence
# Note that this returns the events that were persisted, which may not be
# the same as were passed in if some were deduplicated due to transaction IDs.
- events, max_stream_token = await self._storage.persistence.persist_events(
+ (
+ events,
+ max_stream_token,
+ ) = await self._storage_controllers.persistence.persist_events(
event_and_contexts, backfilled=backfilled
)
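On the main process the persist now goes through the persistence controller, and the caller keeps only the stream position. Note that the returned events may differ from the input when transaction-ID deduplication fires. A sketch of consuming the return value; `persist_events` here is a hypothetical stand-in for the controller coroutine:

    async def persist_and_get_position(persist_events, event_and_contexts, backfilled=False):
        events, max_stream_token = await persist_events(
            event_and_contexts, backfilled=backfilled
        )
        # `events` are what was actually persisted, which may not match the
        # input if some were deduplicated by transaction ID.
        return max_stream_token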