diff options
author | Andrej Shadura <andrewsh@debian.org> | 2022-10-30 12:23:11 +0100 |
---|---|---|
committer | Andrej Shadura <andrewsh@debian.org> | 2022-10-30 12:23:11 +0100 |
commit | 53aa9684018b7b070d195d0e66d121ea43441434 (patch) | |
tree | 12b790285708053512a77dd536754d744281b8eb /synapse/storage/databases | |
parent | 41aea8fc55649be44b0d55810d8080f8b45fea9e (diff) |
New upstream version 1.70.1
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r-- | synapse/storage/databases/main/cache.py | 10 | ||||
-rw-r--r-- | synapse/storage/databases/main/devices.py | 10 | ||||
-rw-r--r-- | synapse/storage/databases/main/event_federation.py | 54 | ||||
-rw-r--r-- | synapse/storage/databases/main/event_push_actions.py | 620 | ||||
-rw-r--r-- | synapse/storage/databases/main/events.py | 100 | ||||
-rw-r--r-- | synapse/storage/databases/main/events_worker.py | 107 | ||||
-rw-r--r-- | synapse/storage/databases/main/push_rule.py | 22 | ||||
-rw-r--r-- | synapse/storage/databases/main/receipts.py | 2 | ||||
-rw-r--r-- | synapse/storage/databases/main/relations.py | 315 | ||||
-rw-r--r-- | synapse/storage/databases/main/room.py | 86 | ||||
-rw-r--r-- | synapse/storage/databases/main/roommember.py | 17 | ||||
-rw-r--r-- | synapse/storage/databases/main/stream.py | 59 |
12 files changed, 1083 insertions, 319 deletions
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 3b8ed1f7..ddb73977 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -244,12 +244,18 @@ class CacheInvalidationWorkerStore(SQLBaseStore): # redacted. self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,)) self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,)) + self._attempt_to_invalidate_cache("get_thread_id", (redacts,)) + self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,)) if etype == EventTypes.Member: self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) self._attempt_to_invalidate_cache( "get_invited_rooms_for_local_user", (state_key,) ) + self._attempt_to_invalidate_cache( + "get_rooms_for_user_with_stream_ordering", (state_key,) + ) + self._attempt_to_invalidate_cache("get_rooms_for_user", (state_key,)) if relates_to: self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,)) @@ -259,9 +265,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,)) - self._attempt_to_invalidate_cache( - "get_mutual_event_relations_for_rel_type", (relates_to,) - ) + self._attempt_to_invalidate_cache("get_threads", (room_id,)) async def invalidate_cache_and_stream( self, cache_name: str, keys: Tuple[Any, ...] diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 18358eca..830b076a 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -539,9 +539,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): "device_id": device_id, "prev_id": [prev_id] if prev_id else [], "stream_id": stream_id, - "org.matrix.opentracing_context": opentracing_context, } + if opentracing_context != "{}": + result["org.matrix.opentracing_context"] = opentracing_context + prev_id = stream_id if device is not None: @@ -549,7 +551,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): if keys: result["keys"] = keys - device_display_name = device.display_name + device_display_name = None + if ( + self.hs.config.federation.allow_device_name_lookup_over_federation + ): + device_display_name = device.display_name if device_display_name: result["device_display_name"] = device_display_name else: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 6b9a629e..309a4ba6 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1501,6 +1501,12 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas event_id: The event that failed to be fetched or processed cause: The error message or reason that we failed to pull the event """ + logger.debug( + "record_event_failed_pull_attempt room_id=%s, event_id=%s, cause=%s", + room_id, + event_id, + cause, + ) await self.db_pool.runInteraction( "record_event_failed_pull_attempt", self._record_event_failed_pull_attempt_upsert_txn, @@ -1530,6 +1536,54 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas txn.execute(sql, (room_id, event_id, 1, self._clock.time_msec(), cause)) + @trace + async def get_event_ids_to_not_pull_from_backoff( + self, + room_id: str, + event_ids: Collection[str], + ) -> List[str]: + """ + Filter down the events to ones that we've failed to pull before recently. Uses + exponential backoff. + + Args: + room_id: The room that the events belong to + event_ids: A list of events to filter down + + Returns: + List of event_ids that should not be attempted to be pulled + """ + event_failed_pull_attempts = await self.db_pool.simple_select_many_batch( + table="event_failed_pull_attempts", + column="event_id", + iterable=event_ids, + keyvalues={}, + retcols=( + "event_id", + "last_attempt_ts", + "num_attempts", + ), + desc="get_event_ids_to_not_pull_from_backoff", + ) + + current_time = self._clock.time_msec() + return [ + event_failed_pull_attempt["event_id"] + for event_failed_pull_attempt in event_failed_pull_attempts + # Exponential back-off (up to the upper bound) so we don't try to + # pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc. + if current_time + < event_failed_pull_attempt["last_attempt_ts"] + + ( + 2 + ** min( + event_failed_pull_attempt["num_attempts"], + BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS, + ) + ) + * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS + ] + async def get_missing_events( self, room_id: str, diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 72cf91eb..b283ab0f 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -88,7 +88,7 @@ from typing import ( import attr -from synapse.api.constants import ReceiptTypes +from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -119,6 +119,32 @@ DEFAULT_HIGHLIGHT_ACTION: List[Union[dict, str]] = [ ] +@attr.s(slots=True, auto_attribs=True) +class _RoomReceipt: + """ + HttpPushAction instances include the information used to generate HTTP + requests to a push gateway. + """ + + unthreaded_stream_ordering: int = 0 + # threaded_stream_ordering includes the main pseudo-thread. + threaded_stream_ordering: Dict[str, int] = attr.Factory(dict) + + def is_unread(self, thread_id: str, stream_ordering: int) -> bool: + """Returns True if the stream ordering is unread according to the receipt information.""" + + # Only include push actions with a stream ordering after both the unthreaded + # and threaded receipt. Properly handles a user without any receipts present. + return ( + self.unthreaded_stream_ordering < stream_ordering + and self.threaded_stream_ordering.get(thread_id, 0) < stream_ordering + ) + + +# A _RoomReceipt with no receipts in it. +MISSING_ROOM_RECEIPT = _RoomReceipt() + + @attr.s(slots=True, frozen=True, auto_attribs=True) class HttpPushAction: """ @@ -157,7 +183,7 @@ class UserPushAction(EmailPushAction): @attr.s(slots=True, auto_attribs=True) class NotifCounts: """ - The per-user, per-room count of notifications. Used by sync and push. + The per-user, per-room, per-thread count of notifications. Used by sync and push. """ notify_count: int = 0 @@ -165,6 +191,21 @@ class NotifCounts: highlight_count: int = 0 +@attr.s(slots=True, auto_attribs=True) +class RoomNotifCounts: + """ + The per-user, per-room count of notifications. Used by sync and push. + """ + + main_timeline: NotifCounts + # Map of thread ID to the notification counts. + threads: Dict[str, NotifCounts] + + def __len__(self) -> int: + # To properly account for the amount of space in any caches. + return len(self.threads) + 1 + + def _serialize_action( actions: Collection[Union[Mapping, str]], is_highlight: bool ) -> str: @@ -253,6 +294,44 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas self._background_backfill_thread_id, ) + # Indexes which will be used to quickly make the thread_id column non-null. + self.db_pool.updates.register_background_index_update( + "event_push_actions_thread_id_null", + index_name="event_push_actions_thread_id_null", + table="event_push_actions", + columns=["thread_id"], + where_clause="thread_id IS NULL", + ) + self.db_pool.updates.register_background_index_update( + "event_push_summary_thread_id_null", + index_name="event_push_summary_thread_id_null", + table="event_push_summary", + columns=["thread_id"], + where_clause="thread_id IS NULL", + ) + + # Check ASAP (and then later, every 1s) to see if we have finished + # background updates the event_push_actions and event_push_summary tables. + self._clock.call_later(0.0, self._check_event_push_backfill_thread_id) + self._event_push_backfill_thread_id_done = False + + @wrap_as_background_process("check_event_push_backfill_thread_id") + async def _check_event_push_backfill_thread_id(self) -> None: + """ + Has thread_id finished backfilling? + + If not, we need to just-in-time update it so the queries work. + """ + done = await self.db_pool.updates.has_completed_background_update( + "event_push_backfill_thread_id" + ) + + if done: + self._event_push_backfill_thread_id_done = True + else: + # Reschedule to run. + self._clock.call_later(15.0, self._check_event_push_backfill_thread_id) + async def _background_backfill_thread_id( self, progress: JsonDict, batch_size: int ) -> int: @@ -384,12 +463,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas return result - @cached(tree=True, max_entries=5000) + @cached(tree=True, max_entries=5000, iterable=True) async def get_unread_event_push_actions_by_room_for_user( self, room_id: str, user_id: str, - ) -> NotifCounts: + ) -> RoomNotifCounts: """Get the notification count, the highlight count and the unread message count for a given user in a given room after their latest read receipt. @@ -402,8 +481,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas user_id: The user to retrieve the counts for. Returns - A NotifCounts object containing the notification count, the highlight count - and the unread message count. + A RoomNotifCounts object containing the notification count, the + highlight count and the unread message count for both the main timeline + and threads. """ return await self.db_pool.runInteraction( "get_unread_event_push_actions_by_room", @@ -417,7 +497,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas txn: LoggingTransaction, room_id: str, user_id: str, - ) -> NotifCounts: + ) -> RoomNotifCounts: # Get the stream ordering of the user's latest receipt in the room. result = self.get_last_unthreaded_receipt_for_user_txn( txn, @@ -451,8 +531,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas txn: LoggingTransaction, room_id: str, user_id: str, - receipt_stream_ordering: int, - ) -> NotifCounts: + unthreaded_receipt_stream_ordering: int, + ) -> RoomNotifCounts: """Get the number of unread messages for a user/room that have happened since the given stream ordering. @@ -460,78 +540,223 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas txn: The database transaction. room_id: The room ID to get unread counts for. user_id: The user ID to get unread counts for. - receipt_stream_ordering: The stream ordering of the user's latest - receipt in the room. If there are no receipts, the stream ordering - of the user's join event. + unthreaded_receipt_stream_ordering: The stream ordering of the user's latest + unthreaded receipt in the room. If there are no unthreaded receipts, + the stream ordering of the user's join event. - Returns - A NotifCounts object containing the notification count, the highlight count - and the unread message count. + Returns: + A RoomNotifCounts object containing the notification count, the + highlight count and the unread message count for both the main timeline + and threads. """ - counts = NotifCounts() + main_counts = NotifCounts() + thread_counts: Dict[str, NotifCounts] = {} + + def _get_thread(thread_id: str) -> NotifCounts: + if thread_id == MAIN_TIMELINE: + return main_counts + return thread_counts.setdefault(thread_id, NotifCounts()) + + receipt_types_clause, receipts_args = make_in_list_sql_clause( + self.database_engine, + "receipt_type", + (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), + ) + + # First ensure that the existing rows have an updated thread_id field. + if not self._event_push_backfill_thread_id_done: + txn.execute( + """ + UPDATE event_push_summary + SET thread_id = ? + WHERE room_id = ? AND user_id = ? AND thread_id is NULL + """, + (MAIN_TIMELINE, room_id, user_id), + ) + txn.execute( + """ + UPDATE event_push_actions + SET thread_id = ? + WHERE room_id = ? AND user_id = ? AND thread_id is NULL + """, + (MAIN_TIMELINE, room_id, user_id), + ) # First we pull the counts from the summary table. # - # We check that `last_receipt_stream_ordering` matches the stream - # ordering given. If it doesn't match then a new read receipt has arrived and - # we haven't yet updated the counts in `event_push_summary` to reflect - # that; in that case we simply ignore `event_push_summary` counts - # and do a manual count of all of the rows in the `event_push_actions` table - # for this user/room. + # We check that `last_receipt_stream_ordering` matches the stream ordering of the + # latest receipt for the thread (which may be either the unthreaded read receipt + # or the threaded read receipt). + # + # If it doesn't match then a new read receipt has arrived and we haven't yet + # updated the counts in `event_push_summary` to reflect that; in that case we + # simply ignore `event_push_summary` counts. # - # If `last_receipt_stream_ordering` is null then that means it's up to - # date (as the row was written by an older version of Synapse that + # We then do a manual count of all the rows in the `event_push_actions` table + # for any user/room/thread which did not have a valid summary found. + # + # If `last_receipt_stream_ordering` is null then that means it's up-to-date + # (as the row was written by an older version of Synapse that # updated `event_push_summary` synchronously when persisting a new read # receipt). txn.execute( - """ - SELECT stream_ordering, notif_count, COALESCE(unread_count, 0) + f""" + SELECT notif_count, COALESCE(unread_count, 0), thread_id FROM event_push_summary + LEFT JOIN ( + SELECT thread_id, MAX(stream_ordering) AS threaded_receipt_stream_ordering + FROM receipts_linearized + LEFT JOIN events USING (room_id, event_id) + WHERE + user_id = ? + AND room_id = ? + AND stream_ordering > ? + AND {receipt_types_clause} + GROUP BY thread_id + ) AS receipts USING (thread_id) WHERE room_id = ? AND user_id = ? AND ( - (last_receipt_stream_ordering IS NULL AND stream_ordering > ?) - OR last_receipt_stream_ordering = ? - ) + (last_receipt_stream_ordering IS NULL AND stream_ordering > COALESCE(threaded_receipt_stream_ordering, ?)) + OR last_receipt_stream_ordering = COALESCE(threaded_receipt_stream_ordering, ?) + ) AND (notif_count != 0 OR COALESCE(unread_count, 0) != 0) """, - (room_id, user_id, receipt_stream_ordering, receipt_stream_ordering), + ( + user_id, + room_id, + unthreaded_receipt_stream_ordering, + *receipts_args, + room_id, + user_id, + unthreaded_receipt_stream_ordering, + unthreaded_receipt_stream_ordering, + ), ) - row = txn.fetchone() - - summary_stream_ordering = 0 - if row: - summary_stream_ordering = row[0] - counts.notify_count += row[1] - counts.unread_count += row[2] + summarised_threads = set() + for notif_count, unread_count, thread_id in txn: + summarised_threads.add(thread_id) + counts = _get_thread(thread_id) + counts.notify_count += notif_count + counts.unread_count += unread_count # Next we need to count highlights, which aren't summarised - sql = """ - SELECT COUNT(*) FROM event_push_actions + sql = f""" + SELECT COUNT(*), thread_id FROM event_push_actions + LEFT JOIN ( + SELECT thread_id, MAX(stream_ordering) AS threaded_receipt_stream_ordering + FROM receipts_linearized + LEFT JOIN events USING (room_id, event_id) + WHERE + user_id = ? + AND room_id = ? + AND stream_ordering > ? + AND {receipt_types_clause} + GROUP BY thread_id + ) AS receipts USING (thread_id) WHERE user_id = ? AND room_id = ? - AND stream_ordering > ? + AND stream_ordering > COALESCE(threaded_receipt_stream_ordering, ?) AND highlight = 1 + GROUP BY thread_id """ - txn.execute(sql, (user_id, room_id, receipt_stream_ordering)) - row = txn.fetchone() - if row: - counts.highlight_count += row[0] + txn.execute( + sql, + ( + user_id, + room_id, + unthreaded_receipt_stream_ordering, + *receipts_args, + user_id, + room_id, + unthreaded_receipt_stream_ordering, + ), + ) + for highlight_count, thread_id in txn: + _get_thread(thread_id).highlight_count += highlight_count + + # For threads which were summarised we need to count actions since the last + # rotation. + thread_id_clause, thread_id_args = make_in_list_sql_clause( + self.database_engine, "thread_id", summarised_threads + ) + + # The (inclusive) event stream ordering that was previously summarised. + rotated_upto_stream_ordering = self.db_pool.simple_select_one_onecol_txn( + txn, + table="event_push_summary_stream_ordering", + keyvalues={}, + retcol="stream_ordering", + ) + + unread_counts = self._get_notif_unread_count_for_user_room( + txn, room_id, user_id, rotated_upto_stream_ordering + ) + for notif_count, unread_count, thread_id in unread_counts: + if thread_id not in summarised_threads: + continue + + if thread_id == MAIN_TIMELINE: + counts.notify_count += notif_count + counts.unread_count += unread_count + elif thread_id in thread_counts: + thread_counts[thread_id].notify_count += notif_count + thread_counts[thread_id].unread_count += unread_count + else: + # Previous thread summaries of 0 are discarded above. + # + # TODO If empty summaries are deleted this can be removed. + thread_counts[thread_id] = NotifCounts( + notify_count=notif_count, + unread_count=unread_count, + highlight_count=0, + ) # Finally we need to count push actions that aren't included in the # summary returned above. This might be due to recent events that haven't # been summarised yet or the summary is out of date due to a recent read # receipt. - start_unread_stream_ordering = max( - receipt_stream_ordering, summary_stream_ordering - ) - notify_count, unread_count = self._get_notif_unread_count_for_user_room( - txn, room_id, user_id, start_unread_stream_ordering + sql = f""" + SELECT + COUNT(CASE WHEN notif = 1 THEN 1 END), + COUNT(CASE WHEN unread = 1 THEN 1 END), + thread_id + FROM event_push_actions + LEFT JOIN ( + SELECT thread_id, MAX(stream_ordering) AS threaded_receipt_stream_ordering + FROM receipts_linearized + LEFT JOIN events USING (room_id, event_id) + WHERE + user_id = ? + AND room_id = ? + AND stream_ordering > ? + AND {receipt_types_clause} + GROUP BY thread_id + ) AS receipts USING (thread_id) + WHERE user_id = ? + AND room_id = ? + AND stream_ordering > COALESCE(threaded_receipt_stream_ordering, ?) + AND NOT {thread_id_clause} + GROUP BY thread_id + """ + txn.execute( + sql, + ( + user_id, + room_id, + unthreaded_receipt_stream_ordering, + *receipts_args, + user_id, + room_id, + unthreaded_receipt_stream_ordering, + *thread_id_args, + ), ) + for notif_count, unread_count, thread_id in txn: + counts = _get_thread(thread_id) + counts.notify_count += notif_count + counts.unread_count += unread_count - counts.notify_count += notify_count - counts.unread_count += unread_count - - return counts + return RoomNotifCounts(main_counts, thread_counts) def _get_notif_unread_count_for_user_room( self, @@ -540,7 +765,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas user_id: str, stream_ordering: int, max_stream_ordering: Optional[int] = None, - ) -> Tuple[int, int]: + thread_id: Optional[str] = None, + ) -> List[Tuple[int, int, str]]: """Returns the notify and unread counts from `event_push_actions` for the given user/room in the given range. @@ -554,45 +780,55 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas stream_ordering: The (exclusive) minimum stream ordering to consider. max_stream_ordering: The (inclusive) maximum stream ordering to consider. If this is not given, then no maximum is applied. + thread_id: The thread ID to fetch unread counts for. If this is not provided + then the results for *all* threads is returned. + + Note that if this is provided the resulting list will only have 0 or + 1 tuples in it. Return: - A tuple of the notif count and unread count in the given range. + A tuple of the notif count and unread count in the given range for + each thread. """ # If there have been no events in the room since the stream ordering, # there can't be any push actions either. if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering): - return 0, 0 + return [] - clause = "" + stream_ordering_clause = "" args = [user_id, room_id, stream_ordering] if max_stream_ordering is not None: - clause = "AND ea.stream_ordering <= ?" + stream_ordering_clause = "AND ea.stream_ordering <= ?" args.append(max_stream_ordering) # If the max stream ordering is less than the min stream ordering, # then obviously there are zero push actions in that range. if max_stream_ordering <= stream_ordering: - return 0, 0 + return [] + + # Either limit the results to a specific thread or fetch all threads. + thread_id_clause = "" + if thread_id is not None: + thread_id_clause = "AND thread_id = ?" + args.append(thread_id) sql = f""" SELECT COUNT(CASE WHEN notif = 1 THEN 1 END), - COUNT(CASE WHEN unread = 1 THEN 1 END) - FROM event_push_actions ea - WHERE user_id = ? + COUNT(CASE WHEN unread = 1 THEN 1 END), + thread_id + FROM event_push_actions ea + WHERE user_id = ? AND room_id = ? AND ea.stream_ordering > ? - {clause} + {stream_ordering_clause} + {thread_id_clause} + GROUP BY thread_id """ txn.execute(sql, args) - row = txn.fetchone() - - if row: - return cast(Tuple[int, int], row) - - return 0, 0 + return cast(List[Tuple[int, int, str]], txn.fetchall()) async def get_push_action_users_in_range( self, min_stream_ordering: int, max_stream_ordering: int @@ -609,7 +845,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas def _get_receipts_by_room_txn( self, txn: LoggingTransaction, user_id: str - ) -> Dict[str, int]: + ) -> Dict[str, _RoomReceipt]: """ Generate a map of room ID to the latest stream ordering that has been read by the given user. @@ -619,7 +855,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas user_id: The user to fetch receipts for. Returns: - A map of room ID to stream ordering for all rooms the user has a receipt in. + A map including all rooms the user is in with a receipt. It maps + room IDs to _RoomReceipt instances """ receipt_types_clause, args = make_in_list_sql_clause( self.database_engine, @@ -628,20 +865,26 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) sql = f""" - SELECT room_id, MAX(stream_ordering) + SELECT room_id, thread_id, MAX(stream_ordering) FROM receipts_linearized INNER JOIN events USING (room_id, event_id) WHERE {receipt_types_clause} AND user_id = ? - GROUP BY room_id + GROUP BY room_id, thread_id """ args.extend((user_id,)) txn.execute(sql, args) - return { - room_id: latest_stream_ordering - for room_id, latest_stream_ordering in txn.fetchall() - } + + result: Dict[str, _RoomReceipt] = {} + for room_id, thread_id, stream_ordering in txn: + room_receipt = result.setdefault(room_id, _RoomReceipt()) + if thread_id is None: + room_receipt.unthreaded_stream_ordering = stream_ordering + else: + room_receipt.threaded_stream_ordering[thread_id] = stream_ordering + + return result async def get_unread_push_actions_for_user_in_range_for_http( self, @@ -674,9 +917,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas def get_push_actions_txn( txn: LoggingTransaction, - ) -> List[Tuple[str, str, int, str, bool]]: + ) -> List[Tuple[str, str, str, int, str, bool]]: sql = """ - SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight + SELECT ep.event_id, ep.room_id, ep.thread_id, ep.stream_ordering, + ep.actions, ep.highlight FROM event_push_actions AS ep WHERE ep.user_id = ? @@ -686,7 +930,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ORDER BY ep.stream_ordering ASC LIMIT ? """ txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit)) - return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) + return cast(List[Tuple[str, str, str, int, str, bool]], txn.fetchall()) push_actions = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_http", get_push_actions_txn @@ -699,10 +943,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas stream_ordering=stream_ordering, actions=_deserialize_action(actions, highlight), ) - for event_id, room_id, stream_ordering, actions, highlight in push_actions - # Only include push actions with a stream ordering after any receipt, or without any - # receipt present (invited to but never read rooms). - if stream_ordering > receipts_by_room.get(room_id, 0) + for event_id, room_id, thread_id, stream_ordering, actions, highlight in push_actions + if receipts_by_room.get(room_id, MISSING_ROOM_RECEIPT).is_unread( + thread_id, stream_ordering + ) ] # Now sort it so it's ordered correctly, since currently it will @@ -746,10 +990,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas def get_push_actions_txn( txn: LoggingTransaction, - ) -> List[Tuple[str, str, int, str, bool, int]]: + ) -> List[Tuple[str, str, str, int, str, bool, int]]: sql = """ - SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, - ep.highlight, e.received_ts + SELECT ep.event_id, ep.room_id, ep.thread_id, ep.stream_ordering, + ep.actions, ep.highlight, e.received_ts FROM event_push_actions AS ep INNER JOIN events AS e USING (room_id, event_id) WHERE @@ -760,7 +1004,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ORDER BY ep.stream_ordering DESC LIMIT ? """ txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit)) - return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) + return cast(List[Tuple[str, str, str, int, str, bool, int]], txn.fetchall()) push_actions = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_email", get_push_actions_txn @@ -775,10 +1019,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas actions=_deserialize_action(actions, highlight), received_ts=received_ts, ) - for event_id, room_id, stream_ordering, actions, highlight, received_ts in push_actions - # Only include push actions with a stream ordering after any receipt, or without any - # receipt present (invited to but never read rooms). - if stream_ordering > receipts_by_room.get(room_id, 0) + for event_id, room_id, thread_id, stream_ordering, actions, highlight, received_ts in push_actions + if receipts_by_room.get(room_id, MISSING_ROOM_RECEIPT).is_unread( + thread_id, stream_ordering + ) ] # Now sort it so it's ordered correctly, since currently it will @@ -1102,7 +1346,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) sql = """ - SELECT r.stream_id, r.room_id, r.user_id, e.stream_ordering + SELECT r.stream_id, r.room_id, r.user_id, r.thread_id, e.stream_ordering FROM receipts_linearized AS r INNER JOIN events AS e USING (event_id) WHERE ? < r.stream_id AND r.stream_id <= ? AND user_id LIKE ? @@ -1123,55 +1367,105 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas limit, ), ) - rows = cast(List[Tuple[int, str, str, int]], txn.fetchall()) + rows = cast(List[Tuple[int, str, str, Optional[str], int]], txn.fetchall()) # For each new read receipt we delete push actions from before it and # recalculate the summary. - for _, room_id, user_id, stream_ordering in rows: + # + # Care must be taken of whether it is a threaded or unthreaded receipt. + for _, room_id, user_id, thread_id, stream_ordering in rows: # Only handle our own read receipts. if not self.hs.is_mine_id(user_id): continue + thread_clause = "" + thread_args: Tuple = () + if thread_id is not None: + thread_clause = "AND thread_id = ?" + thread_args = (thread_id,) + + # For each new read receipt we delete push actions from before it and + # recalculate the summary. txn.execute( - """ + f""" DELETE FROM event_push_actions WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? AND highlight = 0 + {thread_clause} """, - (room_id, user_id, stream_ordering), + (room_id, user_id, stream_ordering, *thread_args), ) + # First ensure that the existing rows have an updated thread_id field. + if not self._event_push_backfill_thread_id_done: + txn.execute( + """ + UPDATE event_push_summary + SET thread_id = ? + WHERE room_id = ? AND user_id = ? AND thread_id is NULL + """, + (MAIN_TIMELINE, room_id, user_id), + ) + txn.execute( + """ + UPDATE event_push_actions + SET thread_id = ? + WHERE room_id = ? AND user_id = ? AND thread_id is NULL + """, + (MAIN_TIMELINE, room_id, user_id), + ) + # Fetch the notification counts between the stream ordering of the # latest receipt and what was previously summarised. - notif_count, unread_count = self._get_notif_unread_count_for_user_room( - txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering + unread_counts = self._get_notif_unread_count_for_user_room( + txn, + room_id, + user_id, + stream_ordering, + old_rotate_stream_ordering, + thread_id, ) - # First ensure that the existing rows have an updated thread_id field. - txn.execute( - """ - UPDATE event_push_summary - SET thread_id = ? - WHERE room_id = ? AND user_id = ? AND thread_id is NULL - """, - ("main", room_id, user_id), - ) + # For an unthreaded receipt, mark the summary for all threads in the room + # as cleared. + if thread_id is None: + self.db_pool.simple_update_txn( + txn, + table="event_push_summary", + keyvalues={"user_id": user_id, "room_id": room_id}, + updatevalues={ + "notif_count": 0, + "unread_count": 0, + "stream_ordering": old_rotate_stream_ordering, + "last_receipt_stream_ordering": stream_ordering, + }, + ) + + # For a threaded receipt, we *always* want to update that receipt, + # event if there are no new notifications in that thread. This ensures + # the stream_ordering & last_receipt_stream_ordering are updated. + elif not unread_counts: + unread_counts = [(0, 0, thread_id)] - # Replace the previous summary with the new counts. - # - # TODO(threads): Upsert per-thread instead of setting them all to main. - self.db_pool.simple_upsert_txn( + # Then any updated threads get their notification count and unread + # count updated. + self.db_pool.simple_update_many_txn( txn, table="event_push_summary", - keyvalues={"room_id": room_id, "user_id": user_id, "thread_id": "main"}, - values={ - "notif_count": notif_count, - "unread_count": unread_count, - "stream_ordering": old_rotate_stream_ordering, - "last_receipt_stream_ordering": stream_ordering, - }, + key_names=("room_id", "user_id", "thread_id"), + key_values=[(room_id, user_id, row[2]) for row in unread_counts], + value_names=( + "notif_count", + "unread_count", + "stream_ordering", + "last_receipt_stream_ordering", + ), + value_values=[ + (row[0], row[1], old_rotate_stream_ordering, stream_ordering) + for row in unread_counts + ], ) # We always update `event_push_summary_last_receipt_stream_id` to @@ -1257,25 +1551,38 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas rotate_to_stream_ordering: The new maximum event stream ordering to summarise. """ + # Ensure that any new actions have an updated thread_id. + if not self._event_push_backfill_thread_id_done: + txn.execute( + """ + UPDATE event_push_actions + SET thread_id = ? + WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL + """, + (MAIN_TIMELINE, old_rotate_stream_ordering, rotate_to_stream_ordering), + ) + + # XXX Do we need to update summaries here too? + # Calculate the new counts that should be upserted into event_push_summary sql = """ - SELECT user_id, room_id, + SELECT user_id, room_id, thread_id, coalesce(old.%s, 0) + upd.cnt, upd.stream_ordering FROM ( - SELECT user_id, room_id, count(*) as cnt, + SELECT user_id, room_id, thread_id, count(*) as cnt, max(ea.stream_ordering) as stream_ordering FROM event_push_actions AS ea - LEFT JOIN event_push_summary AS old USING (user_id, room_id) + LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id) WHERE ? < ea.stream_ordering AND ea.stream_ordering <= ? AND ( old.last_receipt_stream_ordering IS NULL OR old.last_receipt_stream_ordering < ea.stream_ordering ) AND %s = 1 - GROUP BY user_id, room_id + GROUP BY user_id, room_id, thread_id ) AS upd - LEFT JOIN event_push_summary AS old USING (user_id, room_id) + LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id) """ # First get the count of unread messages. @@ -1289,11 +1596,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # object because we might not have the same amount of rows in each of them. To do # this, we use a dict indexed on the user ID and room ID to make it easier to # populate. - summaries: Dict[Tuple[str, str], _EventPushSummary] = {} + summaries: Dict[Tuple[str, str, str], _EventPushSummary] = {} for row in txn: - summaries[(row[0], row[1])] = _EventPushSummary( - unread_count=row[2], - stream_ordering=row[3], + summaries[(row[0], row[1], row[2])] = _EventPushSummary( + unread_count=row[3], + stream_ordering=row[4], notif_count=0, ) @@ -1304,48 +1611,50 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) for row in txn: - if (row[0], row[1]) in summaries: - summaries[(row[0], row[1])].notif_count = row[2] + if (row[0], row[1], row[2]) in summaries: + summaries[(row[0], row[1], row[2])].notif_count = row[3] else: # Because the rules on notifying are different than the rules on marking # a message unread, we might end up with messages that notify but aren't # marked unread, so we might not have a summary for this (user, room) # tuple to complete. - summaries[(row[0], row[1])] = _EventPushSummary( + summaries[(row[0], row[1], row[2])] = _EventPushSummary( unread_count=0, - stream_ordering=row[3], - notif_count=row[2], + stream_ordering=row[4], + notif_count=row[3], ) logger.info("Rotating notifications, handling %d rows", len(summaries)) - # Ensure that any updated threads have an updated thread_id. - txn.execute_batch( - """ - UPDATE event_push_summary - SET thread_id = ? - WHERE room_id = ? AND user_id = ? AND thread_id is NULL - """, - [("main", room_id, user_id) for user_id, room_id in summaries], - ) - self.db_pool.simple_update_many_txn( - txn, - table="event_push_summary", - key_names=("user_id", "room_id", "thread_id"), - key_values=[(user_id, room_id, None) for user_id, room_id in summaries], - value_names=("thread_id",), - value_values=[("main",) for _ in summaries], - ) + # Ensure that any updated threads have the proper thread_id. + if not self._event_push_backfill_thread_id_done: + txn.execute_batch( + """ + UPDATE event_push_summary + SET thread_id = ? + WHERE room_id = ? AND user_id = ? AND thread_id is NULL + """, + [ + (MAIN_TIMELINE, room_id, user_id) + for user_id, room_id, _ in summaries + ], + ) - # TODO(threads): Update on a per-thread basis. self.db_pool.simple_upsert_many_txn( txn, table="event_push_summary", key_names=("user_id", "room_id", "thread_id"), - key_values=[(user_id, room_id, "main") for user_id, room_id in summaries], + key_values=[ + (user_id, room_id, thread_id) + for user_id, room_id, thread_id in summaries + ], value_names=("notif_count", "unread_count", "stream_ordering"), value_values=[ - (summary.notif_count, summary.unread_count, summary.stream_ordering) + ( + summary.notif_count, + summary.unread_count, + summary.stream_ordering, + ) for summary in summaries.values() ], ) @@ -1356,7 +1665,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) async def _remove_old_push_actions_that_have_rotated(self) -> None: - """Clear out old push actions that have been summarised.""" + """ + Clear out old push actions that have been summarised (and are older than + 1 day ago). + """ # We want to clear out anything that is older than a day that *has* already # been rotated. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 3e158279..00880bb3 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -35,7 +35,7 @@ import attr from prometheus_client import Counter import synapse.metrics -from synapse.api.constants import EventContentFields, EventTypes +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, relation_from_event @@ -1616,7 +1616,7 @@ class PersistEventsStore: ) # Remove from relations table. - self._handle_redact_relations(txn, event.redacts) + self._handle_redact_relations(txn, event.room_id, event.redacts) # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. @@ -1866,6 +1866,34 @@ class PersistEventsStore: }, ) + if relation.rel_type == RelationTypes.THREAD: + # Upsert into the threads table, but only overwrite the value if the + # new event is of a later topological order OR if the topological + # ordering is equal, but the stream ordering is later. + sql = """ + INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT (room_id, thread_id) + DO UPDATE SET + latest_event_id = excluded.latest_event_id, + topological_ordering = excluded.topological_ordering, + stream_ordering = excluded.stream_ordering + WHERE + threads.topological_ordering <= excluded.topological_ordering AND + threads.stream_ordering < excluded.stream_ordering + """ + + txn.execute( + sql, + ( + event.room_id, + relation.parent_id, + event.event_id, + event.depth, + event.internal_metadata.stream_ordering, + ), + ) + def _handle_insertion_event( self, txn: LoggingTransaction, event: EventBase ) -> None: @@ -1989,35 +2017,48 @@ class PersistEventsStore: txn.execute(sql, (batch_id,)) def _handle_redact_relations( - self, txn: LoggingTransaction, redacted_event_id: str + self, txn: LoggingTransaction, room_id: str, redacted_event_id: str ) -> None: """Handles receiving a redaction and checking whether the redacted event has any relations which must be removed from the database. Args: txn + room_id: The room ID of the event that was redacted. redacted_event_id: The event that was redacted. """ - # Fetch the current relation of the event being redacted. - redacted_relates_to = self.db_pool.simple_select_one_onecol_txn( + # Fetch the relation of the event being redacted. + row = self.db_pool.simple_select_one_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id}, - retcol="relates_to_id", + retcols=("relates_to_id", "relation_type"), allow_none=True, ) + # Nothing to do if no relation is found. + if row is None: + return + + redacted_relates_to = row["relates_to_id"] + rel_type = row["relation_type"] + self.db_pool.simple_delete_txn( + txn, table="event_relations", keyvalues={"event_id": redacted_event_id} + ) + # Any relation information for the related event must be cleared. - if redacted_relates_to is not None: - self.store._invalidate_cache_and_stream( - txn, self.store.get_relations_for_event, (redacted_relates_to,) - ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_relations_for_event, (redacted_relates_to,) + ) + if rel_type == RelationTypes.ANNOTATION: self.store._invalidate_cache_and_stream( txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,) ) + if rel_type == RelationTypes.REPLACE: self.store._invalidate_cache_and_stream( txn, self.store.get_applicable_edit, (redacted_relates_to,) ) + if rel_type == RelationTypes.THREAD: self.store._invalidate_cache_and_stream( txn, self.store.get_thread_summary, (redacted_relates_to,) ) @@ -2025,14 +2066,41 @@ class PersistEventsStore: txn, self.store.get_thread_participated, (redacted_relates_to,) ) self.store._invalidate_cache_and_stream( - txn, - self.store.get_mutual_event_relations_for_rel_type, - (redacted_relates_to,), + txn, self.store.get_threads, (room_id,) ) - self.db_pool.simple_delete_txn( - txn, table="event_relations", keyvalues={"event_id": redacted_event_id} - ) + # Find the new latest event in the thread. + sql = """ + SELECT event_id, topological_ordering, stream_ordering + FROM event_relations + INNER JOIN events USING (event_id) + WHERE relates_to_id = ? AND relation_type = ? + ORDER BY topological_ordering DESC, stream_ordering DESC + LIMIT 1 + """ + txn.execute(sql, (redacted_relates_to, RelationTypes.THREAD)) + + # If a latest event is found, update the threads table, this might + # be the same current latest event (if an earlier event in the thread + # was redacted). + latest_event_row = txn.fetchone() + if latest_event_row: + self.db_pool.simple_upsert_txn( + txn, + table="threads", + keyvalues={"room_id": room_id, "thread_id": redacted_relates_to}, + values={ + "latest_event_id": latest_event_row[0], + "topological_ordering": latest_event_row[1], + "stream_ordering": latest_event_row[2], + }, + ) + + # Otherwise, delete the thread: it no longer exists. + else: + self.db_pool.simple_delete_one_txn( + txn, table="threads", keyvalues={"thread_id": redacted_relates_to} + ) def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None: if isinstance(event.content.get("topic"), str): diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 7cdc9fe9..69fea452 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -374,7 +374,7 @@ class EventsWorkerStore(SQLBaseStore): If there is a mismatch, behave as per allow_none. Returns: - The event, or None if the event was not found. + The event, or None if the event was not found and allow_none is `True`. """ if not isinstance(event_id, str): raise TypeError("Invalid event event_id %r" % (event_id,)) @@ -474,7 +474,7 @@ class EventsWorkerStore(SQLBaseStore): return [] # there may be duplicates so we cast the list to a set - event_entry_map = await self._get_events_from_cache_or_db( + event_entry_map = await self.get_unredacted_events_from_cache_or_db( set(event_ids), allow_rejected=allow_rejected ) @@ -509,7 +509,9 @@ class EventsWorkerStore(SQLBaseStore): continue redacted_event_id = entry.event.redacts - event_map = await self._get_events_from_cache_or_db([redacted_event_id]) + event_map = await self.get_unredacted_events_from_cache_or_db( + [redacted_event_id] + ) original_event_entry = event_map.get(redacted_event_id) if not original_event_entry: # we don't have the redacted event (or it was rejected). @@ -588,11 +590,16 @@ class EventsWorkerStore(SQLBaseStore): return events @cancellable - async def _get_events_from_cache_or_db( - self, event_ids: Iterable[str], allow_rejected: bool = False + async def get_unredacted_events_from_cache_or_db( + self, + event_ids: Iterable[str], + allow_rejected: bool = False, ) -> Dict[str, EventCacheEntry]: """Fetch a bunch of events from the cache or the database. + Note that the events pulled by this function will not have any redactions + applied, and no guarantee is made about the ordering of the events returned. + If events are pulled from the database, they will be cached for future lookups. Unknown events are omitted from the response. @@ -1495,21 +1502,15 @@ class EventsWorkerStore(SQLBaseStore): Returns: a dict {event_id -> bool} """ - # if the event cache contains the event, obviously we've seen it. - - cache_results = { - event_id - for event_id in event_ids - if await self._get_event_cache.contains((event_id,)) - } - results = dict.fromkeys(cache_results, True) - remaining = [ - event_id for event_id in event_ids if event_id not in cache_results - ] - if not remaining: - return results + # TODO: We used to query the _get_event_cache here as a fast-path before + # hitting the database. For if an event were in the cache, we've presumably + # seen it before. + # + # But this is currently an invalid assumption due to the _get_event_cache + # not being invalidated when purging events from a room. The optimisation can + # be re-added after https://github.com/matrix-org/synapse/issues/13476 - def have_seen_events_txn(txn: LoggingTransaction) -> None: + def have_seen_events_txn(txn: LoggingTransaction) -> Dict[str, bool]: # we deliberately do *not* query the database for room_id, to make the # query an index-only lookup on `events_event_id_key`. # @@ -1517,16 +1518,17 @@ class EventsWorkerStore(SQLBaseStore): sql = "SELECT event_id FROM events AS e WHERE " clause, args = make_in_list_sql_clause( - txn.database_engine, "e.event_id", remaining + txn.database_engine, "e.event_id", event_ids ) txn.execute(sql + clause, args) found_events = {eid for eid, in txn} # ... and then we can update the results for each key - results.update({eid: (eid in found_events) for eid in remaining}) + return {eid: (eid in found_events) for eid in event_ids} - await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn) - return results + return await self.db_pool.runInteraction( + "have_seen_events", have_seen_events_txn + ) @cached(max_entries=100000, tree=True) async def have_seen_event(self, room_id: str, event_id: str) -> bool: @@ -1969,12 +1971,17 @@ class EventsWorkerStore(SQLBaseStore): Args: room_id: room where the event lives - event_id: event to check + event: event to check (can't be an `outlier`) Returns: Boolean indicating whether it's an extremity """ + assert not event.internal_metadata.is_outlier(), ( + "is_event_next_to_backward_gap(...) can't be used with `outlier` events. " + "This function relies on `event_backward_extremities` which won't be filled in for `outliers`." + ) + def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool: # If the event in question has any of its prev_events listed as a # backward extremity, it's next to a gap. @@ -2024,12 +2031,17 @@ class EventsWorkerStore(SQLBaseStore): Args: room_id: room where the event lives - event_id: event to check + event: event to check (can't be an `outlier`) Returns: Boolean indicating whether it's an extremity """ + assert not event.internal_metadata.is_outlier(), ( + "is_event_next_to_forward_gap(...) can't be used with `outlier` events. " + "This function relies on `event_edges` and `event_forward_extremities` which won't be filled in for `outliers`." + ) + def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool: # If the event in question is a forward extremity, we will just # consider any potential forward gap as not a gap since it's one of @@ -2110,13 +2122,33 @@ class EventsWorkerStore(SQLBaseStore): The closest event_id otherwise None if we can't find any event in the given direction. """ + if direction == "b": + # Find closest event *before* a given timestamp. We use descending + # (which gives values largest to smallest) because we want the + # largest possible timestamp *before* the given timestamp. + comparison_operator = "<=" + order = "DESC" + else: + # Find closest event *after* a given timestamp. We use ascending + # (which gives values smallest to largest) because we want the + # closest possible timestamp *after* the given timestamp. + comparison_operator = ">=" + order = "ASC" - sql_template = """ + sql_template = f""" SELECT event_id FROM events LEFT JOIN rejections USING (event_id) WHERE - origin_server_ts %s ? - AND room_id = ? + room_id = ? + AND origin_server_ts {comparison_operator} ? + /** + * Make sure the event isn't an `outlier` because we have no way + * to later check whether it's next to a gap. `outliers` do not + * have entries in the `event_edges`, `event_forward_extremeties`, + * and `event_backward_extremities` tables to check against + * (used by `is_event_next_to_backward_gap` and `is_event_next_to_forward_gap`). + */ + AND NOT outlier /* Make sure event is not rejected */ AND rejections.event_id IS NULL /** @@ -2126,27 +2158,14 @@ class EventsWorkerStore(SQLBaseStore): * Finally, we can tie-break based on when it was received on the server * (`stream_ordering`). */ - ORDER BY origin_server_ts %s, depth %s, stream_ordering %s + ORDER BY origin_server_ts {order}, depth {order}, stream_ordering {order} LIMIT 1; """ def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]: - if direction == "b": - # Find closest event *before* a given timestamp. We use descending - # (which gives values largest to smallest) because we want the - # largest possible timestamp *before* the given timestamp. - comparison_operator = "<=" - order = "DESC" - else: - # Find closest event *after* a given timestamp. We use ascending - # (which gives values smallest to largest) because we want the - # closest possible timestamp *after* the given timestamp. - comparison_operator = ">=" - order = "ASC" - txn.execute( - sql_template % (comparison_operator, order, order, order), - (timestamp, room_id), + sql_template, + (room_id, timestamp), ) row = txn.fetchone() if row: diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index ed17b2e7..51416b22 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -29,7 +29,6 @@ from typing import ( ) from synapse.api.errors import StoreError -from synapse.config.homeserver import ExperimentalConfig from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -63,9 +62,7 @@ logger = logging.getLogger(__name__) def _load_rules( - rawrules: List[JsonDict], - enabled_map: Dict[str, bool], - experimental_config: ExperimentalConfig, + rawrules: List[JsonDict], enabled_map: Dict[str, bool] ) -> FilteredPushRules: """Take the DB rows returned from the DB and convert them into a full `FilteredPushRules` object. @@ -81,16 +78,9 @@ def _load_rules( for rawrule in rawrules ] - push_rules = PushRules( - ruleslist, - ) + push_rules = PushRules(ruleslist) - filtered_rules = FilteredPushRules( - push_rules, - enabled_map, - msc3786_enabled=experimental_config.msc3786_enabled, - msc3772_enabled=experimental_config.msc3772_enabled, - ) + filtered_rules = FilteredPushRules(push_rules, enabled_map) return filtered_rules @@ -170,7 +160,7 @@ class PushRulesWorkerStore( enabled_map = await self.get_push_rules_enabled_for_user(user_id) - return _load_rules(rows, enabled_map, self.hs.config.experimental) + return _load_rules(rows, enabled_map) async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: results = await self.db_pool.simple_select_list( @@ -229,9 +219,7 @@ class PushRulesWorkerStore( results: Dict[str, FilteredPushRules] = {} for user_id, rules in raw_rules.items(): - results[user_id] = _load_rules( - rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental - ) + results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) return results diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 246f78ac..dc698952 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -418,6 +418,8 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type = event_entry.setdefault(row["receipt_type"], {}) receipt_type[row["user_id"]] = db_to_json(row["data"]) + if row["thread_id"]: + receipt_type[row["user_id"]]["thread_id"] = row["thread_id"] results = { room_id: [results[room_id]] if room_id in results else [] diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 898947af..c022510e 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -14,6 +14,7 @@ import logging from typing import ( + TYPE_CHECKING, Collection, Dict, FrozenSet, @@ -28,19 +29,48 @@ from typing import ( import attr -from synapse.api.constants import RelationTypes +from synapse.api.constants import MAIN_TIMELINE, RelationTypes +from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.engines import PostgresEngine from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken from synapse.util.caches.descriptors import cached, cachedList +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @attr.s(slots=True, frozen=True, auto_attribs=True) +class ThreadsNextBatch: + topological_ordering: int + stream_ordering: int + + def __str__(self) -> str: + return f"{self.topological_ordering}_{self.stream_ordering}" + + @classmethod + def from_string(cls, string: str) -> "ThreadsNextBatch": + """ + Creates a ThreadsNextBatch from its textual representation. + """ + try: + keys = (int(s) for s in string.split("_")) + return cls(*keys) + except Exception: + raise SynapseError(400, "Invalid threads token") + + +@attr.s(slots=True, frozen=True, auto_attribs=True) class _RelatedEvent: """ Contains enough information about a related event in order to properly filter @@ -56,6 +86,76 @@ class _RelatedEvent: class RelationsWorkerStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_update_handler( + "threads_backfill", self._backfill_threads + ) + + async def _backfill_threads(self, progress: JsonDict, batch_size: int) -> int: + """Backfill the threads table.""" + + def threads_backfill_txn(txn: LoggingTransaction) -> int: + last_thread_id = progress.get("last_thread_id", "") + + # Get the latest event in each thread by topo ordering / stream ordering. + # + # Note that the MAX(event_id) is needed to abide by the rules of group by, + # but doesn't actually do anything since there should only be a single event + # ID per topo/stream ordering pair. + sql = f""" + SELECT room_id, relates_to_id, MAX(topological_ordering), MAX(stream_ordering), MAX(event_id) + FROM event_relations + INNER JOIN events USING (event_id) + WHERE + relates_to_id > ? AND + relation_type = '{RelationTypes.THREAD}' + GROUP BY room_id, relates_to_id + ORDER BY relates_to_id + LIMIT ? + """ + txn.execute(sql, (last_thread_id, batch_size)) + + # No more rows to process. + rows = txn.fetchall() + if not rows: + return 0 + + # Insert the rows into the threads table. If a matching thread already exists, + # assume it is from a newer event. + sql = """ + INSERT INTO threads (room_id, thread_id, topological_ordering, stream_ordering, latest_event_id) + VALUES %s + ON CONFLICT (room_id, thread_id) + DO NOTHING + """ + if isinstance(txn.database_engine, PostgresEngine): + txn.execute_values(sql % ("?",), rows, fetch=False) + else: + txn.execute_batch(sql % ("(?, ?, ?, ?, ?)",), rows) + + # Mark the progress. + self.db_pool.updates._background_update_progress_txn( + txn, "threads_backfill", {"last_thread_id": rows[-1][1]} + ) + + return txn.rowcount + + result = await self.db_pool.runInteraction( + "threads_backfill", threads_backfill_txn + ) + + if not result: + await self.db_pool.updates._end_background_update("threads_backfill") + + return result + @cached(uncached_args=("event",), tree=True) async def get_relations_for_event( self, @@ -779,57 +879,192 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) - @cached(iterable=True) - async def get_mutual_event_relations_for_rel_type( - self, event_id: str, relation_type: str - ) -> Set[Tuple[str, str]]: - raise NotImplementedError() + @cached(tree=True) + async def get_threads( + self, + room_id: str, + limit: int = 5, + from_token: Optional[ThreadsNextBatch] = None, + ) -> Tuple[List[str], Optional[ThreadsNextBatch]]: + """Get a list of thread IDs, ordered by topological ordering of their + latest reply. + + Args: + room_id: The room the event belongs to. + limit: Only fetch the most recent `limit` threads. + from_token: Fetch rows from a previous next_batch, or from the start if None. + + Returns: + A tuple of: + A list of thread root event IDs. + + The next_batch, if one exists. + """ + # Generate the pagination clause, if necessary. + # + # Find any threads where the latest reply is equal / before the last + # thread's topo ordering and earlier in stream ordering. + pagination_clause = "" + pagination_args: tuple = () + if from_token: + pagination_clause = "AND topological_ordering <= ? AND stream_ordering < ?" + pagination_args = ( + from_token.topological_ordering, + from_token.stream_ordering, + ) - @cachedList( - cached_method_name="get_mutual_event_relations_for_rel_type", - list_name="relation_types", - ) - async def get_mutual_event_relations( - self, event_id: str, relation_types: Collection[str] - ) -> Dict[str, Set[Tuple[str, str]]]: + sql = f""" + SELECT thread_id, topological_ordering, stream_ordering + FROM threads + WHERE + room_id = ? + {pagination_clause} + ORDER BY topological_ordering DESC, stream_ordering DESC + LIMIT ? + """ + + def _get_threads_txn( + txn: LoggingTransaction, + ) -> Tuple[List[str], Optional[ThreadsNextBatch]]: + txn.execute(sql, (room_id, *pagination_args, limit + 1)) + + rows = cast(List[Tuple[str, int, int]], txn.fetchall()) + thread_ids = [r[0] for r in rows] + + # If there are more events, generate the next pagination key from the + # last thread which will be returned. + next_token = None + if len(thread_ids) > limit: + last_topo_id = rows[-2][1] + last_stream_id = rows[-2][2] + next_token = ThreadsNextBatch(last_topo_id, last_stream_id) + + return thread_ids[:limit], next_token + + return await self.db_pool.runInteraction("get_threads", _get_threads_txn) + + @cached() + async def get_thread_id(self, event_id: str) -> str: """ - Fetch event metadata for events which related to the same event as the given event. + Get the thread ID for an event. This considers multi-level relations, + e.g. an annotation to an event which is part of a thread. + + It only searches up the relations tree, i.e. it only searches for events + which the given event is related to (and which those events are related + to, etc.) + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + get_thread_id(X) considers events B and C as part of thread A. - If the given event has no relation information, returns an empty dictionary. + See also get_thread_id_for_receipts. Args: - event_id: The event ID which is targeted by relations. - relation_types: The relation types to check for mutual relations. + event_id: The event ID to fetch the thread ID for. Returns: - A dictionary of relation type to: - A set of tuples of: - The sender - The event type + The event ID of the root event in the thread, if this event is part + of a thread. "main", otherwise. """ - rel_type_sql, rel_type_args = make_in_list_sql_clause( - self.database_engine, "relation_type", relation_types - ) - sql = f""" - SELECT DISTINCT relation_type, sender, type FROM event_relations - INNER JOIN events USING (event_id) - WHERE relates_to_id = ? AND {rel_type_sql} + # Recurse event relations up to the *root* event, then search that chain + # of relations for a thread relation. If one is found, the root event is + # returned. + # + # Note that this should only ever find 0 or 1 entries since it is invalid + # for an event to have a thread relation to an event which also has a + # relation. + sql = """ + WITH RECURSIVE related_events AS ( + SELECT event_id, relates_to_id, relation_type, 0 depth + FROM event_relations + WHERE event_id = ? + UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1 + FROM event_relations e + INNER JOIN related_events r ON r.relates_to_id = e.event_id + WHERE depth <= 3 + ) + SELECT relates_to_id FROM related_events + WHERE relation_type = 'm.thread' + ORDER BY depth DESC + LIMIT 1; """ - def _get_event_relations( - txn: LoggingTransaction, - ) -> Dict[str, Set[Tuple[str, str]]]: - txn.execute(sql, [event_id] + rel_type_args) - result: Dict[str, Set[Tuple[str, str]]] = { - rel_type: set() for rel_type in relation_types - } - for rel_type, sender, type in txn.fetchall(): - result[rel_type].add((sender, type)) - return result + def _get_thread_id(txn: LoggingTransaction) -> str: + txn.execute(sql, (event_id,)) + row = txn.fetchone() + if row: + return row[0] + + # If no thread was found, it is part of the main timeline. + return MAIN_TIMELINE + + return await self.db_pool.runInteraction("get_thread_id", _get_thread_id) + + @cached() + async def get_thread_id_for_receipts(self, event_id: str) -> str: + """ + Get the thread ID for an event by traversing to the top-most related event + and confirming any children events form a thread. + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + get_thread_id_for_receipts(X) considers events A, B, C, D, and E as part + of thread A. + + See also get_thread_id. + + Args: + event_id: The event ID to fetch the thread ID for. + + Returns: + The event ID of the root event in the thread, if this event is part + of a thread. "main", otherwise. + """ + + # Recurse event relations up to the *root* event, then search for any events + # related to that root node for a thread relation. If one is found, the + # root event is returned. + # + # Note that there cannot be thread relations in the middle of the chain since + # it is invalid for an event to have a thread relation to an event which also + # has a relation. + sql = """ + SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE(( + WITH RECURSIVE related_events AS ( + SELECT event_id, relates_to_id, relation_type, 0 depth + FROM event_relations + WHERE event_id = ? + UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1 + FROM event_relations e + INNER JOIN related_events r ON r.relates_to_id = e.event_id + WHERE depth <= 3 + ) + SELECT relates_to_id FROM related_events + ORDER BY depth DESC + LIMIT 1 + ), ?) AND relation_type = 'm.thread' LIMIT 1; + """ + + def _get_related_thread_id(txn: LoggingTransaction) -> str: + txn.execute(sql, (event_id, event_id)) + row = txn.fetchone() + if row: + return row[0] + + # If no thread was found, it is part of the main timeline. + return MAIN_TIMELINE return await self.db_pool.runInteraction( - "get_event_relations", _get_event_relations + "get_related_thread_id", _get_related_thread_id ) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 7412bce2..7d97f8f6 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -97,6 +97,12 @@ class RoomSortOrder(Enum): STATE_EVENTS = "state_events" +@attr.s(slots=True, frozen=True, auto_attribs=True) +class PartialStateResyncInfo: + joined_via: Optional[str] + servers_in_room: List[str] = attr.ib(factory=list) + + class RoomWorkerStore(CacheInvalidationWorkerStore): def __init__( self, @@ -207,21 +213,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): def _construct_room_type_where_clause( self, room_types: Union[List[Union[str, None]], None] - ) -> Tuple[Union[str, None], List[str]]: + ) -> Tuple[Union[str, None], list]: if not room_types: return None, [] - else: - # We use None when we want get rooms without a type - is_null_clause = "" - if None in room_types: - is_null_clause = "OR room_type IS NULL" - room_types = [value for value in room_types if value is not None] + # Since None is used to represent a room without a type, care needs to + # be taken into account when constructing the where clause. + clauses = [] + args: list = [] + + room_types_set = set(room_types) + + # We use None to represent a room without a type. + if None in room_types_set: + clauses.append("room_type IS NULL") + room_types_set.remove(None) + + # If there are other room types, generate the proper clause. + if room_types: list_clause, args = make_in_list_sql_clause( - self.database_engine, "room_type", room_types + self.database_engine, "room_type", room_types_set ) + clauses.append(list_clause) - return f"({list_clause} {is_null_clause})", args + return f"({' OR '.join(clauses)})", args async def count_public_rooms( self, @@ -241,14 +256,6 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): def _count_public_rooms_txn(txn: LoggingTransaction) -> int: query_args = [] - room_type_clause, args = self._construct_room_type_where_clause( - search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None) - if search_filter - else None - ) - room_type_clause = f" AND {room_type_clause}" if room_type_clause else "" - query_args += args - if network_tuple: if network_tuple.appservice_id: published_sql = """ @@ -268,6 +275,14 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): UNION SELECT room_id from appservice_room_list """ + room_type_clause, args = self._construct_room_type_where_clause( + search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None) + if search_filter + else None + ) + room_type_clause = f" AND {room_type_clause}" if room_type_clause else "" + query_args += args + sql = f""" SELECT COUNT(*) @@ -1151,17 +1166,29 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): desc="get_partial_state_servers_at_join", ) - async def get_partial_state_rooms_and_servers( + async def get_partial_state_room_resync_info( self, - ) -> Mapping[str, Collection[str]]: - """Get all rooms containing events with partial state, and the servers known - to be in the room. + ) -> Mapping[str, PartialStateResyncInfo]: + """Get all rooms containing events with partial state, and the information + needed to restart a "resync" of those rooms. Returns: A dictionary of rooms with partial state, with room IDs as keys and lists of servers in rooms as values. """ - room_servers: Dict[str, List[str]] = {} + room_servers: Dict[str, PartialStateResyncInfo] = {} + + rows = await self.db_pool.simple_select_list( + table="partial_state_rooms", + keyvalues={}, + retcols=("room_id", "joined_via"), + desc="get_server_which_served_partial_join", + ) + + for row in rows: + room_id = row["room_id"] + joined_via = row["joined_via"] + room_servers[room_id] = PartialStateResyncInfo(joined_via=joined_via) rows = await self.db_pool.simple_select_list( "partial_state_rooms_servers", @@ -1173,7 +1200,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): for row in rows: room_id = row["room_id"] server_name = row["server_name"] - room_servers.setdefault(room_id, []).append(server_name) + entry = room_servers.get(room_id) + if entry is None: + # There is a foreign key constraint which enforces that every room_id in + # partial_state_rooms_servers appears in partial_state_rooms. So we + # expect `entry` to be non-null. (This reasoning fails if we've + # partial-joined between the two SELECTs, but this is unlikely to happen + # in practice.) + continue + entry.servers_in_room.append(server_name) return room_servers @@ -1818,6 +1853,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): room_id: str, servers: Collection[str], device_lists_stream_id: int, + joined_via: str, ) -> None: """Mark the given room as containing events with partial state. @@ -1833,6 +1869,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): servers: other servers known to be in the room device_lists_stream_id: the device_lists stream ID at the time when we first joined the room. + joined_via: the server name we requested a partial join from. """ await self.db_pool.runInteraction( "store_partial_state_room", @@ -1840,6 +1877,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): room_id, servers, device_lists_stream_id, + joined_via, ) def _store_partial_state_room_txn( @@ -1848,6 +1886,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): room_id: str, servers: Collection[str], device_lists_stream_id: int, + joined_via: str, ) -> None: DatabasePool.simple_insert_txn( txn, @@ -1857,6 +1896,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): "device_lists_stream_id": device_lists_stream_id, # To be updated later once the join event is persisted. "join_event_id": None, + "joined_via": joined_via, }, ) DatabasePool.simple_insert_many_txn( diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 2337289d..2ed6ad75 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -666,7 +666,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): cached_method_name="get_rooms_for_user", list_name="user_ids", ) - async def get_rooms_for_users( + async def _get_rooms_for_users( self, user_ids: Collection[str] ) -> Dict[str, FrozenSet[str]]: """A batched version of `get_rooms_for_user`. @@ -697,6 +697,21 @@ class RoomMemberWorkerStore(EventsWorkerStore): return {key: frozenset(rooms) for key, rooms in user_rooms.items()} + async def get_rooms_for_users( + self, user_ids: Collection[str] + ) -> Dict[str, FrozenSet[str]]: + """A batched wrapper around `_get_rooms_for_users`, to prevent locking + other calls to `get_rooms_for_user` for large user lists. + """ + all_user_rooms: Dict[str, FrozenSet[str]] = {} + + # 250 users is pretty arbitrary but the data can be quite large if users + # are in many rooms. + for user_ids in batch_iter(user_ids, 250): + all_user_rooms.update(await self._get_rooms_for_users(user_ids)) + + return all_user_rooms + @cached(max_entries=10000) async def does_pair_of_users_share_a_room( self, user_id: str, other_user_id: str diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 530f04e1..09ce855a 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -357,6 +357,24 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: ) args.extend(event_filter.related_by_rel_types) + if event_filter.rel_types: + clauses.append( + "(%s)" + % " OR ".join( + "event_relation.relation_type = ?" for _ in event_filter.rel_types + ) + ) + args.extend(event_filter.rel_types) + + if event_filter.not_rel_types: + clauses.append( + "((%s) OR event_relation.relation_type IS NULL)" + % " AND ".join( + "event_relation.relation_type != ?" for _ in event_filter.not_rel_types + ) + ) + args.extend(event_filter.not_rel_types) + return " AND ".join(clauses), args @@ -1024,28 +1042,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): "after": {"event_ids": events_after, "token": end_token}, } - async def get_all_new_events_stream( - self, from_id: int, current_id: int, limit: int, get_prev_content: bool = False - ) -> Tuple[int, List[EventBase], Dict[str, Optional[int]]]: + async def get_all_new_event_ids_stream( + self, + from_id: int, + current_id: int, + limit: int, + ) -> Tuple[int, Dict[str, Optional[int]]]: """Get all new events - Returns all events with from_id < stream_ordering <= current_id. + Returns all event ids with from_id < stream_ordering <= current_id. Args: from_id: the stream_ordering of the last event we processed current_id: the stream_ordering of the most recently processed event limit: the maximum number of events to return - get_prev_content: whether to fetch previous event content Returns: - A tuple of (next_id, events, event_to_received_ts), where `next_id` + A tuple of (next_id, event_to_received_ts), where `next_id` is the next value to pass as `from_id` (it will either be the stream_ordering of the last returned event, or, if fewer than `limit` events were found, the `current_id`). The `event_to_received_ts` is - a dictionary mapping event ID to the event `received_ts`. + a dictionary mapping event ID to the event `received_ts`, sorted by ascending + stream_ordering. """ - def get_all_new_events_stream_txn( + def get_all_new_event_ids_stream_txn( txn: LoggingTransaction, ) -> Tuple[int, Dict[str, Optional[int]]]: sql = ( @@ -1070,15 +1091,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, event_to_received_ts upper_bound, event_to_received_ts = await self.db_pool.runInteraction( - "get_all_new_events_stream", get_all_new_events_stream_txn - ) - - events = await self.get_events_as_list( - event_to_received_ts.keys(), - get_prev_content=get_prev_content, + "get_all_new_event_ids_stream", get_all_new_event_ids_stream_txn ) - return upper_bound, events, event_to_received_ts + return upper_bound, event_to_received_ts async def get_federation_out_pos(self, typ: str) -> int: if self._need_to_reset_federation_stream_positions: @@ -1202,8 +1218,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): `to_token`), or `limit` is zero. """ - assert int(limit) >= 0 - # Tokens really represent positions between elements, but we use # the convention of pointing to the event before the gap. Hence # we have a bit of asymmetry when it comes to equalities. @@ -1282,8 +1296,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # Multiple labels could cause the same event to appear multiple times. needs_distinct = True - # If there is a filter on relation_senders and relation_types join to the - # relations table. + # If there is a relation_senders and relation_types filter join to the + # relations table to get events related to the current event. if event_filter and ( event_filter.related_by_senders or event_filter.related_by_rel_types ): @@ -1298,6 +1312,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id) """ + # If there is a not_rel_types filter join to the relations table to get + # the event's relation information. + if event_filter and (event_filter.rel_types or event_filter.not_rel_types): + join_clause += """ + LEFT JOIN event_relations AS event_relation USING (event_id) + """ + if needs_distinct: select_keywords += " DISTINCT" |