summaryrefslogtreecommitdiff
path: root/synapse/handlers/federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/federation.py')
-rw-r--r--synapse/handlers/federation.py121
1 files changed, 87 insertions, 34 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 593932ad..014dab29 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -72,7 +72,13 @@ from synapse.replication.http.federation import (
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
+from synapse.types import (
+ JsonDict,
+ MutableStateMap,
+ StateMap,
+ UserID,
+ get_domain_from_id,
+)
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination
@@ -96,7 +102,7 @@ class _NewEventInfo:
event = attr.ib(type=EventBase)
state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
- auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None)
+ auth_events = attr.ib(type=Optional[MutableStateMap[EventBase]], default=None)
class FederationHandler(BaseHandler):
@@ -434,11 +440,11 @@ class FederationHandler(BaseHandler):
if not prevs - seen:
return
- latest = await self.store.get_latest_event_ids_in_room(room_id)
+ latest_list = await self.store.get_latest_event_ids_in_room(room_id)
# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
- latest = set(latest)
+ latest = set(latest_list)
latest |= seen
logger.info(
@@ -775,7 +781,7 @@ class FederationHandler(BaseHandler):
# keys across all devices.
current_keys = [
key
- for device in cached_devices
+ for device in cached_devices.values()
for key in device.get("keys", {}).get("keys", {}).values()
]
@@ -937,15 +943,26 @@ class FederationHandler(BaseHandler):
return events
- async def maybe_backfill(self, room_id, current_depth):
+ async def maybe_backfill(
+ self, room_id: str, current_depth: int, limit: int
+ ) -> bool:
"""Checks the database to see if we should backfill before paginating,
and if so do.
+
+ Args:
+ room_id
+ current_depth: The depth from which we're paginating. This is
+ used to decide if we should backfill and what extremities to
+ use.
+ limit: The number of events that the pagination request will
+ return. This is used as part of the heuristic to decide if we
+ should back paginate.
"""
extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
if not extremities:
logger.debug("Not backfilling as no extremeties found.")
- return
+ return False
# We only want to paginate if we can actually see the events we'll get,
# as otherwise we'll just spend a lot of resources to get redacted
@@ -998,16 +1015,54 @@ class FederationHandler(BaseHandler):
sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1]))
max_depth = sorted_extremeties_tuple[0][1]
+ # If we're approaching an extremity we trigger a backfill, otherwise we
+ # no-op.
+ #
+ # We chose twice the limit here as then clients paginating backwards
+ # will send pagination requests that trigger backfill at least twice
+ # using the most recent extremity before it gets removed (see below). We
+ # chose more than one times the limit in case of failure, but choosing a
+ # much larger factor will result in triggering a backfill request much
+ # earlier than necessary.
+ if current_depth - 2 * limit > max_depth:
+ logger.debug(
+ "Not backfilling as we don't need to. %d < %d - 2 * %d",
+ max_depth,
+ current_depth,
+ limit,
+ )
+ return False
+
+ logger.debug(
+ "room_id: %s, backfill: current_depth: %s, max_depth: %s, extrems: %s",
+ room_id,
+ current_depth,
+ max_depth,
+ sorted_extremeties_tuple,
+ )
+
+ # We ignore extremities that have a greater depth than our current depth
+ # as:
+ # 1. we don't really care about getting events that have happened
+ # after our current position; and
+ # 2. we have likely previously tried and failed to backfill from that
+ # extremity, so to avoid getting "stuck" requesting the same
+ # backfill repeatedly we drop those extremities.
+ filtered_sorted_extremeties_tuple = [
+ t for t in sorted_extremeties_tuple if int(t[1]) <= current_depth
+ ]
+
+ # However, we need to check that the filtered extremities are non-empty.
+ # If they are empty then either we can a) bail or b) still attempt to
+ # backfill. We opt to try backfilling anyway just in case we do get
+ # relevant events.
+ if filtered_sorted_extremeties_tuple:
+ sorted_extremeties_tuple = filtered_sorted_extremeties_tuple
+
# We don't want to specify too many extremities as it causes the backfill
# request URI to be too long.
extremities = dict(sorted_extremeties_tuple[:5])
- if current_depth > max_depth:
- logger.debug(
- "Not backfilling as we don't need to. %d < %d", max_depth, current_depth
- )
- return
-
# Now we need to decide which hosts to hit first.
# First we try hosts that are already in the room
@@ -1777,9 +1832,7 @@ class FederationHandler(BaseHandler):
"""Returns the state at the event. i.e. not including said event.
"""
- event = await self.store.get_event(
- event_id, allow_none=False, check_room_id=room_id
- )
+ event = await self.store.get_event(event_id, check_room_id=room_id)
state_groups = await self.state_store.get_state_groups(room_id, [event_id])
@@ -1805,9 +1858,7 @@ class FederationHandler(BaseHandler):
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event.
"""
- event = await self.store.get_event(
- event_id, allow_none=False, check_room_id=room_id
- )
+ event = await self.store.get_event(event_id, check_room_id=room_id)
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
@@ -1877,8 +1928,8 @@ class FederationHandler(BaseHandler):
else:
return None
- def get_min_depth_for_context(self, context):
- return self.store.get_min_depth(context)
+ async def get_min_depth_for_context(self, context):
+ return await self.store.get_min_depth(context)
async def _handle_new_event(
self, origin, event, state=None, auth_events=None, backfilled=False
@@ -2057,7 +2108,7 @@ class FederationHandler(BaseHandler):
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
- auth_events: Optional[StateMap[EventBase]],
+ auth_events: Optional[MutableStateMap[EventBase]],
backfilled: bool,
) -> EventContext:
context = await self.state_handler.compute_event_context(event, old_state=state)
@@ -2107,8 +2158,8 @@ class FederationHandler(BaseHandler):
if backfilled or event.internal_metadata.is_outlier():
return
- extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
- extrem_ids = set(extrem_ids)
+ extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id)
+ extrem_ids = set(extrem_ids_list)
prev_event_ids = set(event.prev_event_ids())
if extrem_ids == prev_event_ids:
@@ -2138,10 +2189,12 @@ class FederationHandler(BaseHandler):
)
state_sets = list(state_sets.values())
state_sets.append(state)
- current_state_ids = await self.state_handler.resolve_events(
+ current_states = await self.state_handler.resolve_events(
room_version, state_sets, event
)
- current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
+ current_state_ids = {
+ k: e.event_id for k, e in current_states.items()
+ } # type: StateMap[str]
else:
current_state_ids = await self.state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids
@@ -2153,11 +2206,13 @@ class FederationHandler(BaseHandler):
# Now check if event pass auth against said current state
auth_types = auth_types_for_event(event)
- current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
+ current_state_ids_list = [
+ e for k, e in current_state_ids.items() if k in auth_types
+ ]
- current_auth_events = await self.store.get_events(current_state_ids)
+ auth_events_map = await self.store.get_events(current_state_ids_list)
current_auth_events = {
- (e.type, e.state_key): e for e in current_auth_events.values()
+ (e.type, e.state_key): e for e in auth_events_map.values()
}
try:
@@ -2173,9 +2228,7 @@ class FederationHandler(BaseHandler):
if not in_room:
raise AuthError(403, "Host not in room.")
- event = await self.store.get_event(
- event_id, allow_none=False, check_room_id=room_id
- )
+ event = await self.store.get_event(event_id, check_room_id=room_id)
# Just go through and process each event in `remote_auth_chain`. We
# don't want to fall into the trap of `missing` being wrong.
@@ -2227,7 +2280,7 @@ class FederationHandler(BaseHandler):
origin: str,
event: EventBase,
context: EventContext,
- auth_events: StateMap[EventBase],
+ auth_events: MutableStateMap[EventBase],
) -> EventContext:
"""
@@ -2278,7 +2331,7 @@ class FederationHandler(BaseHandler):
origin: str,
event: EventBase,
context: EventContext,
- auth_events: StateMap[EventBase],
+ auth_events: MutableStateMap[EventBase],
) -> EventContext:
"""Helper for do_auth. See there for docs.