path: root/synapse/storage/databases
author     Andrej Shadura <andrewsh@debian.org>    2022-10-30 12:23:11 +0100
committer  Andrej Shadura <andrewsh@debian.org>    2022-10-30 12:23:11 +0100
commit     53aa9684018b7b070d195d0e66d121ea43441434 (patch)
tree       12b790285708053512a77dd536754d744281b8eb /synapse/storage/databases
parent     41aea8fc55649be44b0d55810d8080f8b45fea9e (diff)
New upstream version 1.70.1
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--  synapse/storage/databases/main/cache.py              |  10
-rw-r--r--  synapse/storage/databases/main/devices.py            |  10
-rw-r--r--  synapse/storage/databases/main/event_federation.py   |  54
-rw-r--r--  synapse/storage/databases/main/event_push_actions.py | 620
-rw-r--r--  synapse/storage/databases/main/events.py             | 100
-rw-r--r--  synapse/storage/databases/main/events_worker.py      | 107
-rw-r--r--  synapse/storage/databases/main/push_rule.py          |  22
-rw-r--r--  synapse/storage/databases/main/receipts.py           |   2
-rw-r--r--  synapse/storage/databases/main/relations.py          | 315
-rw-r--r--  synapse/storage/databases/main/room.py               |  86
-rw-r--r--  synapse/storage/databases/main/roommember.py         |  17
-rw-r--r--  synapse/storage/databases/main/stream.py             |  59
12 files changed, 1083 insertions, 319 deletions
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 3b8ed1f7..ddb73977 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -244,12 +244,18 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
# redacted.
self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,))
self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,))
+ self._attempt_to_invalidate_cache("get_thread_id", (redacts,))
+ self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,))
if etype == EventTypes.Member:
self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
self._attempt_to_invalidate_cache(
"get_invited_rooms_for_local_user", (state_key,)
)
+ self._attempt_to_invalidate_cache(
+ "get_rooms_for_user_with_stream_ordering", (state_key,)
+ )
+ self._attempt_to_invalidate_cache("get_rooms_for_user", (state_key,))
if relates_to:
self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,))
@@ -259,9 +265,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,))
self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,))
self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,))
- self._attempt_to_invalidate_cache(
- "get_mutual_event_relations_for_rel_type", (relates_to,)
- )
+ self._attempt_to_invalidate_cache("get_threads", (room_id,))
async def invalidate_cache_and_stream(
self, cache_name: str, keys: Tuple[Any, ...]
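As a rough illustration of the invalidation pattern used in the hunk above (not Synapse's actual cache implementation), a keyed cache where all per-user entries are evicted when that user's membership changes might look like this sketch; the class and helper names are invented for illustration:

    from typing import Any, Dict, Tuple

    class KeyedCache:
        """Toy cache keyed on argument tuples, e.g. (user_id,)."""

        def __init__(self) -> None:
            self._entries: Dict[Tuple[Any, ...], Any] = {}

        def set(self, key: Tuple[Any, ...], value: Any) -> None:
            self._entries[key] = value

        def invalidate(self, key: Tuple[Any, ...]) -> None:
            # Dropping the entry forces the next read to recompute it.
            self._entries.pop(key, None)

    # On a membership event for `state_key`, every per-user room cache must be
    # evicted, mirroring the calls added in the hunk above.
    caches = {
        "get_rooms_for_user": KeyedCache(),
        "get_rooms_for_user_with_stream_ordering": KeyedCache(),
    }

    def on_membership_change(state_key: str) -> None:
        for cache in caches.values():
            cache.invalidate((state_key,))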
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 18358eca..830b076a 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -539,9 +539,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
"device_id": device_id,
"prev_id": [prev_id] if prev_id else [],
"stream_id": stream_id,
- "org.matrix.opentracing_context": opentracing_context,
}
+ if opentracing_context != "{}":
+ result["org.matrix.opentracing_context"] = opentracing_context
+
prev_id = stream_id
if device is not None:
@@ -549,7 +551,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
if keys:
result["keys"] = keys
- device_display_name = device.display_name
+ device_display_name = None
+ if (
+ self.hs.config.federation.allow_device_name_lookup_over_federation
+ ):
+ device_display_name = device.display_name
if device_display_name:
result["device_display_name"] = device_display_name
else:
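A condensed sketch of the device-list payload construction after this change (field names taken from the hunk; the standalone function and its parameters are simplifications of the surrounding code):

    from typing import Any, Dict, Optional

    def build_device_update(
        device_id: str,
        stream_id: int,
        prev_id: Optional[int],
        opentracing_context: str,
        keys: Optional[Dict[str, Any]],
        display_name: Optional[str],
        allow_device_name_lookup_over_federation: bool,
    ) -> Dict[str, Any]:
        result: Dict[str, Any] = {
            "device_id": device_id,
            "prev_id": [prev_id] if prev_id else [],
            "stream_id": stream_id,
        }
        # Only ship the tracing context when it carries more than "{}".
        if opentracing_context != "{}":
            result["org.matrix.opentracing_context"] = opentracing_context
        if keys:
            result["keys"] = keys
        # Device names are only exposed over federation when the config allows it.
        if allow_device_name_lookup_over_federation and display_name:
            result["device_display_name"] = display_name
        return result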
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 6b9a629e..309a4ba6 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1501,6 +1501,12 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
event_id: The event that failed to be fetched or processed
cause: The error message or reason that we failed to pull the event
"""
+ logger.debug(
+ "record_event_failed_pull_attempt room_id=%s, event_id=%s, cause=%s",
+ room_id,
+ event_id,
+ cause,
+ )
await self.db_pool.runInteraction(
"record_event_failed_pull_attempt",
self._record_event_failed_pull_attempt_upsert_txn,
@@ -1530,6 +1536,54 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
txn.execute(sql, (room_id, event_id, 1, self._clock.time_msec(), cause))
+ @trace
+ async def get_event_ids_to_not_pull_from_backoff(
+ self,
+ room_id: str,
+ event_ids: Collection[str],
+ ) -> List[str]:
+ """
+ Filter down the events to ones that we've failed to pull before recently. Uses
+ exponential backoff.
+
+ Args:
+ room_id: The room that the events belong to
+ event_ids: A list of events to filter down
+
+ Returns:
+ List of event_ids that should not be attempted to be pulled
+ """
+ event_failed_pull_attempts = await self.db_pool.simple_select_many_batch(
+ table="event_failed_pull_attempts",
+ column="event_id",
+ iterable=event_ids,
+ keyvalues={},
+ retcols=(
+ "event_id",
+ "last_attempt_ts",
+ "num_attempts",
+ ),
+ desc="get_event_ids_to_not_pull_from_backoff",
+ )
+
+ current_time = self._clock.time_msec()
+ return [
+ event_failed_pull_attempt["event_id"]
+ for event_failed_pull_attempt in event_failed_pull_attempts
+ # Exponential back-off (up to the upper bound) so we don't try to
+ # pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc.
+ if current_time
+ < event_failed_pull_attempt["last_attempt_ts"]
+ + (
+ 2
+ ** min(
+ event_failed_pull_attempt["num_attempts"],
+ BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS,
+ )
+ )
+ * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS
+ ]
+
async def get_missing_events(
self,
room_id: str,
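The backoff test above reads as: skip an event while now < last_attempt_ts + 2**min(num_attempts, cap) * step. A minimal standalone sketch, assuming the upstream constants work out to a one-hour step and a cap on the doubling (the real constant names and values are defined elsewhere in this module):

    # Assumed values for illustration only.
    BACKOFF_STEP_MS = 60 * 60 * 1000   # one hour
    MAX_DOUBLING_STEPS = 8             # cap on the exponent

    def should_skip_pull(now_ms: int, last_attempt_ts: int, num_attempts: int) -> bool:
        """True while we are still inside the backoff window for this event."""
        backoff_ms = (2 ** min(num_attempts, MAX_DOUBLING_STEPS)) * BACKOFF_STEP_MS
        return now_ms < last_attempt_ts + backoff_ms

    # After one failed attempt we wait 2h, after two 4h, then 8h, 16h, ...
    assert should_skip_pull(now_ms=1_000_000, last_attempt_ts=0, num_attempts=1) is True
    assert should_skip_pull(now_ms=8_000_000, last_attempt_ts=0, num_attempts=1) is False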
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 72cf91eb..b283ab0f 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -88,7 +88,7 @@ from typing import (
import attr
-from synapse.api.constants import ReceiptTypes
+from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
@@ -119,6 +119,32 @@ DEFAULT_HIGHLIGHT_ACTION: List[Union[dict, str]] = [
]
+@attr.s(slots=True, auto_attribs=True)
+class _RoomReceipt:
+ """
+ The stream orderings of a user's latest receipts in a room: one for the
+ unthreaded receipt, plus a map of thread ID to threaded receipt.
+ """
+
+ unthreaded_stream_ordering: int = 0
+ # threaded_stream_ordering includes the main pseudo-thread.
+ threaded_stream_ordering: Dict[str, int] = attr.Factory(dict)
+
+ def is_unread(self, thread_id: str, stream_ordering: int) -> bool:
+ """Returns True if the stream ordering is unread according to the receipt information."""
+
+ # Only include push actions with a stream ordering after both the unthreaded
+ # and threaded receipt. Properly handles a user without any receipts present.
+ return (
+ self.unthreaded_stream_ordering < stream_ordering
+ and self.threaded_stream_ordering.get(thread_id, 0) < stream_ordering
+ )
+
+
+# A _RoomReceipt with no receipts in it.
+MISSING_ROOM_RECEIPT = _RoomReceipt()
+
+
@attr.s(slots=True, frozen=True, auto_attribs=True)
class HttpPushAction:
"""
@@ -157,7 +183,7 @@ class UserPushAction(EmailPushAction):
@attr.s(slots=True, auto_attribs=True)
class NotifCounts:
"""
- The per-user, per-room count of notifications. Used by sync and push.
+ The per-user, per-room, per-thread count of notifications. Used by sync and push.
"""
notify_count: int = 0
@@ -165,6 +191,21 @@ class NotifCounts:
highlight_count: int = 0
+@attr.s(slots=True, auto_attribs=True)
+class RoomNotifCounts:
+ """
+ The per-user, per-room count of notifications, covering the main timeline and each thread. Used by sync and push.
+ """
+
+ main_timeline: NotifCounts
+ # Map of thread ID to the notification counts.
+ threads: Dict[str, NotifCounts]
+
+ def __len__(self) -> int:
+ # To properly account for the amount of space in any caches.
+ return len(self.threads) + 1
+
+
def _serialize_action(
actions: Collection[Union[Mapping, str]], is_highlight: bool
) -> str:
@@ -253,6 +294,44 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
self._background_backfill_thread_id,
)
+ # Indexes which will be used to quickly make the thread_id column non-null.
+ self.db_pool.updates.register_background_index_update(
+ "event_push_actions_thread_id_null",
+ index_name="event_push_actions_thread_id_null",
+ table="event_push_actions",
+ columns=["thread_id"],
+ where_clause="thread_id IS NULL",
+ )
+ self.db_pool.updates.register_background_index_update(
+ "event_push_summary_thread_id_null",
+ index_name="event_push_summary_thread_id_null",
+ table="event_push_summary",
+ columns=["thread_id"],
+ where_clause="thread_id IS NULL",
+ )
+
+ # Check ASAP (and then again every 15s until done) to see if we have finished
+ # background updates on the event_push_actions and event_push_summary tables.
+ self._clock.call_later(0.0, self._check_event_push_backfill_thread_id)
+ self._event_push_backfill_thread_id_done = False
+
+ @wrap_as_background_process("check_event_push_backfill_thread_id")
+ async def _check_event_push_backfill_thread_id(self) -> None:
+ """
+ Has thread_id finished backfilling?
+
+ If not, we need to just-in-time update it so the queries work.
+ """
+ done = await self.db_pool.updates.has_completed_background_update(
+ "event_push_backfill_thread_id"
+ )
+
+ if done:
+ self._event_push_backfill_thread_id_done = True
+ else:
+ # Reschedule to run.
+ self._clock.call_later(15.0, self._check_event_push_backfill_thread_id)
+
async def _background_backfill_thread_id(
self, progress: JsonDict, batch_size: int
) -> int:
@@ -384,12 +463,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
return result
- @cached(tree=True, max_entries=5000)
+ @cached(tree=True, max_entries=5000, iterable=True)
async def get_unread_event_push_actions_by_room_for_user(
self,
room_id: str,
user_id: str,
- ) -> NotifCounts:
+ ) -> RoomNotifCounts:
"""Get the notification count, the highlight count and the unread message count
for a given user in a given room after their latest read receipt.
@@ -402,8 +481,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id: The user to retrieve the counts for.
Returns
- A NotifCounts object containing the notification count, the highlight count
- and the unread message count.
+ A RoomNotifCounts object containing the notification count, the
+ highlight count and the unread message count for both the main timeline
+ and threads.
"""
return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
@@ -417,7 +497,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
txn: LoggingTransaction,
room_id: str,
user_id: str,
- ) -> NotifCounts:
+ ) -> RoomNotifCounts:
# Get the stream ordering of the user's latest receipt in the room.
result = self.get_last_unthreaded_receipt_for_user_txn(
txn,
@@ -451,8 +531,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
txn: LoggingTransaction,
room_id: str,
user_id: str,
- receipt_stream_ordering: int,
- ) -> NotifCounts:
+ unthreaded_receipt_stream_ordering: int,
+ ) -> RoomNotifCounts:
"""Get the number of unread messages for a user/room that have happened
since the given stream ordering.
@@ -460,78 +540,223 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
txn: The database transaction.
room_id: The room ID to get unread counts for.
user_id: The user ID to get unread counts for.
- receipt_stream_ordering: The stream ordering of the user's latest
- receipt in the room. If there are no receipts, the stream ordering
- of the user's join event.
+ unthreaded_receipt_stream_ordering: The stream ordering of the user's latest
+ unthreaded receipt in the room. If there are no unthreaded receipts,
+ the stream ordering of the user's join event.
- Returns
- A NotifCounts object containing the notification count, the highlight count
- and the unread message count.
+ Returns:
+ A RoomNotifCounts object containing the notification count, the
+ highlight count and the unread message count for both the main timeline
+ and threads.
"""
- counts = NotifCounts()
+ main_counts = NotifCounts()
+ thread_counts: Dict[str, NotifCounts] = {}
+
+ def _get_thread(thread_id: str) -> NotifCounts:
+ if thread_id == MAIN_TIMELINE:
+ return main_counts
+ return thread_counts.setdefault(thread_id, NotifCounts())
+
+ receipt_types_clause, receipts_args = make_in_list_sql_clause(
+ self.database_engine,
+ "receipt_type",
+ (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
+ )
+
+ # First ensure that the existing rows have an updated thread_id field.
+ if not self._event_push_backfill_thread_id_done:
+ txn.execute(
+ """
+ UPDATE event_push_summary
+ SET thread_id = ?
+ WHERE room_id = ? AND user_id = ? AND thread_id is NULL
+ """,
+ (MAIN_TIMELINE, room_id, user_id),
+ )
+ txn.execute(
+ """
+ UPDATE event_push_actions
+ SET thread_id = ?
+ WHERE room_id = ? AND user_id = ? AND thread_id is NULL
+ """,
+ (MAIN_TIMELINE, room_id, user_id),
+ )
# First we pull the counts from the summary table.
#
- # We check that `last_receipt_stream_ordering` matches the stream
- # ordering given. If it doesn't match then a new read receipt has arrived and
- # we haven't yet updated the counts in `event_push_summary` to reflect
- # that; in that case we simply ignore `event_push_summary` counts
- # and do a manual count of all of the rows in the `event_push_actions` table
- # for this user/room.
+ # We check that `last_receipt_stream_ordering` matches the stream ordering of the
+ # latest receipt for the thread (which may be either the unthreaded read receipt
+ # or the threaded read receipt).
+ #
+ # If it doesn't match then a new read receipt has arrived and we haven't yet
+ # updated the counts in `event_push_summary` to reflect that; in that case we
+ # simply ignore `event_push_summary` counts.
#
- # If `last_receipt_stream_ordering` is null then that means it's up to
- # date (as the row was written by an older version of Synapse that
+ # We then do a manual count of all the rows in the `event_push_actions` table
+ # for any user/room/thread which did not have a valid summary found.
+ #
+ # If `last_receipt_stream_ordering` is null then that means it's up-to-date
+ # (as the row was written by an older version of Synapse that
# updated `event_push_summary` synchronously when persisting a new read
# receipt).
txn.execute(
- """
- SELECT stream_ordering, notif_count, COALESCE(unread_count, 0)
+ f"""
+ SELECT notif_count, COALESCE(unread_count, 0), thread_id
FROM event_push_summary
+ LEFT JOIN (
+ SELECT thread_id, MAX(stream_ordering) AS threaded_receipt_stream_ordering
+ FROM receipts_linearized
+ LEFT JOIN events USING (room_id, event_id)
+ WHERE
+ user_id = ?
+ AND room_id = ?
+ AND stream_ordering > ?
+ AND {receipt_types_clause}
+ GROUP BY thread_id
+ ) AS receipts USING (thread_id)
WHERE room_id = ? AND user_id = ?
AND (
- (last_receipt_stream_ordering IS NULL AND stream_ordering > ?)
- OR last_receipt_stream_ordering = ?
- )
+ (last_receipt_stream_ordering IS NULL AND stream_ordering > COALESCE(threaded_receipt_stream_ordering, ?))
+ OR last_receipt_stream_ordering = COALESCE(threaded_receipt_stream_ordering, ?)
+ ) AND (notif_count != 0 OR COALESCE(unread_count, 0) != 0)
""",
- (room_id, user_id, receipt_stream_ordering, receipt_stream_ordering),
+ (
+ user_id,
+ room_id,
+ unthreaded_receipt_stream_ordering,
+ *receipts_args,
+ room_id,
+ user_id,
+ unthreaded_receipt_stream_ordering,
+ unthreaded_receipt_stream_ordering,
+ ),
)
- row = txn.fetchone()
-
- summary_stream_ordering = 0
- if row:
- summary_stream_ordering = row[0]
- counts.notify_count += row[1]
- counts.unread_count += row[2]
+ summarised_threads = set()
+ for notif_count, unread_count, thread_id in txn:
+ summarised_threads.add(thread_id)
+ counts = _get_thread(thread_id)
+ counts.notify_count += notif_count
+ counts.unread_count += unread_count
# Next we need to count highlights, which aren't summarised
- sql = """
- SELECT COUNT(*) FROM event_push_actions
+ sql = f"""
+ SELECT COUNT(*), thread_id FROM event_push_actions
+ LEFT JOIN (
+ SELECT thread_id, MAX(stream_ordering) AS threaded_receipt_stream_ordering
+ FROM receipts_linearized
+ LEFT JOIN events USING (room_id, event_id)
+ WHERE
+ user_id = ?
+ AND room_id = ?
+ AND stream_ordering > ?
+ AND {receipt_types_clause}
+ GROUP BY thread_id
+ ) AS receipts USING (thread_id)
WHERE user_id = ?
AND room_id = ?
- AND stream_ordering > ?
+ AND stream_ordering > COALESCE(threaded_receipt_stream_ordering, ?)
AND highlight = 1
+ GROUP BY thread_id
"""
- txn.execute(sql, (user_id, room_id, receipt_stream_ordering))
- row = txn.fetchone()
- if row:
- counts.highlight_count += row[0]
+ txn.execute(
+ sql,
+ (
+ user_id,
+ room_id,
+ unthreaded_receipt_stream_ordering,
+ *receipts_args,
+ user_id,
+ room_id,
+ unthreaded_receipt_stream_ordering,
+ ),
+ )
+ for highlight_count, thread_id in txn:
+ _get_thread(thread_id).highlight_count += highlight_count
+
+ # For threads which were summarised we need to count actions since the last
+ # rotation.
+ thread_id_clause, thread_id_args = make_in_list_sql_clause(
+ self.database_engine, "thread_id", summarised_threads
+ )
+
+ # The (inclusive) event stream ordering that was previously summarised.
+ rotated_upto_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="event_push_summary_stream_ordering",
+ keyvalues={},
+ retcol="stream_ordering",
+ )
+
+ unread_counts = self._get_notif_unread_count_for_user_room(
+ txn, room_id, user_id, rotated_upto_stream_ordering
+ )
+ for notif_count, unread_count, thread_id in unread_counts:
+ if thread_id not in summarised_threads:
+ continue
+
+ if thread_id == MAIN_TIMELINE:
+ main_counts.notify_count += notif_count
+ main_counts.unread_count += unread_count
+ elif thread_id in thread_counts:
+ thread_counts[thread_id].notify_count += notif_count
+ thread_counts[thread_id].unread_count += unread_count
+ else:
+ # Previous thread summaries of 0 are discarded above.
+ #
+ # TODO If empty summaries are deleted this can be removed.
+ thread_counts[thread_id] = NotifCounts(
+ notify_count=notif_count,
+ unread_count=unread_count,
+ highlight_count=0,
+ )
# Finally we need to count push actions that aren't included in the
# summary returned above. This might be due to recent events that haven't
# been summarised yet or the summary is out of date due to a recent read
# receipt.
- start_unread_stream_ordering = max(
- receipt_stream_ordering, summary_stream_ordering
- )
- notify_count, unread_count = self._get_notif_unread_count_for_user_room(
- txn, room_id, user_id, start_unread_stream_ordering
+ sql = f"""
+ SELECT
+ COUNT(CASE WHEN notif = 1 THEN 1 END),
+ COUNT(CASE WHEN unread = 1 THEN 1 END),
+ thread_id
+ FROM event_push_actions
+ LEFT JOIN (
+ SELECT thread_id, MAX(stream_ordering) AS threaded_receipt_stream_ordering
+ FROM receipts_linearized
+ LEFT JOIN events USING (room_id, event_id)
+ WHERE
+ user_id = ?
+ AND room_id = ?
+ AND stream_ordering > ?
+ AND {receipt_types_clause}
+ GROUP BY thread_id
+ ) AS receipts USING (thread_id)
+ WHERE user_id = ?
+ AND room_id = ?
+ AND stream_ordering > COALESCE(threaded_receipt_stream_ordering, ?)
+ AND NOT {thread_id_clause}
+ GROUP BY thread_id
+ """
+ txn.execute(
+ sql,
+ (
+ user_id,
+ room_id,
+ unthreaded_receipt_stream_ordering,
+ *receipts_args,
+ user_id,
+ room_id,
+ unthreaded_receipt_stream_ordering,
+ *thread_id_args,
+ ),
)
+ for notif_count, unread_count, thread_id in txn:
+ counts = _get_thread(thread_id)
+ counts.notify_count += notif_count
+ counts.unread_count += unread_count
- counts.notify_count += notify_count
- counts.unread_count += unread_count
-
- return counts
+ return RoomNotifCounts(main_counts, thread_counts)
def _get_notif_unread_count_for_user_room(
self,
@@ -540,7 +765,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id: str,
stream_ordering: int,
max_stream_ordering: Optional[int] = None,
- ) -> Tuple[int, int]:
+ thread_id: Optional[str] = None,
+ ) -> List[Tuple[int, int, str]]:
"""Returns the notify and unread counts from `event_push_actions` for
the given user/room in the given range.
@@ -554,45 +780,55 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
stream_ordering: The (exclusive) minimum stream ordering to consider.
max_stream_ordering: The (inclusive) maximum stream ordering to consider.
If this is not given, then no maximum is applied.
+ thread_id: The thread ID to fetch unread counts for. If this is not provided
+ then the results for *all* threads are returned.
+
+ Note that if this is provided the resulting list will only have 0 or
+ 1 tuples in it.
Return:
- A tuple of the notif count and unread count in the given range.
+ A list of tuples of (notify count, unread count, thread ID), one per
+ thread with actions in the given range.
"""
# If there have been no events in the room since the stream ordering,
# there can't be any push actions either.
if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering):
- return 0, 0
+ return []
- clause = ""
+ stream_ordering_clause = ""
args = [user_id, room_id, stream_ordering]
if max_stream_ordering is not None:
- clause = "AND ea.stream_ordering <= ?"
+ stream_ordering_clause = "AND ea.stream_ordering <= ?"
args.append(max_stream_ordering)
# If the max stream ordering is less than the min stream ordering,
# then obviously there are zero push actions in that range.
if max_stream_ordering <= stream_ordering:
- return 0, 0
+ return []
+
+ # Either limit the results to a specific thread or fetch all threads.
+ thread_id_clause = ""
+ if thread_id is not None:
+ thread_id_clause = "AND thread_id = ?"
+ args.append(thread_id)
sql = f"""
SELECT
COUNT(CASE WHEN notif = 1 THEN 1 END),
- COUNT(CASE WHEN unread = 1 THEN 1 END)
- FROM event_push_actions ea
- WHERE user_id = ?
+ COUNT(CASE WHEN unread = 1 THEN 1 END),
+ thread_id
+ FROM event_push_actions ea
+ WHERE user_id = ?
AND room_id = ?
AND ea.stream_ordering > ?
- {clause}
+ {stream_ordering_clause}
+ {thread_id_clause}
+ GROUP BY thread_id
"""
txn.execute(sql, args)
- row = txn.fetchone()
-
- if row:
- return cast(Tuple[int, int], row)
-
- return 0, 0
+ return cast(List[Tuple[int, int, str]], txn.fetchall())
async def get_push_action_users_in_range(
self, min_stream_ordering: int, max_stream_ordering: int
@@ -609,7 +845,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
def _get_receipts_by_room_txn(
self, txn: LoggingTransaction, user_id: str
- ) -> Dict[str, int]:
+ ) -> Dict[str, _RoomReceipt]:
"""
Generate a map of room ID to the latest stream ordering that has been
read by the given user.
@@ -619,7 +855,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id: The user to fetch receipts for.
Returns:
- A map of room ID to stream ordering for all rooms the user has a receipt in.
+ A map of room ID to _RoomReceipt, covering every room in which the
+ user has a receipt.
"""
receipt_types_clause, args = make_in_list_sql_clause(
self.database_engine,
@@ -628,20 +865,26 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
sql = f"""
- SELECT room_id, MAX(stream_ordering)
+ SELECT room_id, thread_id, MAX(stream_ordering)
FROM receipts_linearized
INNER JOIN events USING (room_id, event_id)
WHERE {receipt_types_clause}
AND user_id = ?
- GROUP BY room_id
+ GROUP BY room_id, thread_id
"""
args.extend((user_id,))
txn.execute(sql, args)
- return {
- room_id: latest_stream_ordering
- for room_id, latest_stream_ordering in txn.fetchall()
- }
+
+ result: Dict[str, _RoomReceipt] = {}
+ for room_id, thread_id, stream_ordering in txn:
+ room_receipt = result.setdefault(room_id, _RoomReceipt())
+ if thread_id is None:
+ room_receipt.unthreaded_stream_ordering = stream_ordering
+ else:
+ room_receipt.threaded_stream_ordering[thread_id] = stream_ordering
+
+ return result
async def get_unread_push_actions_for_user_in_range_for_http(
self,
@@ -674,9 +917,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
def get_push_actions_txn(
txn: LoggingTransaction,
- ) -> List[Tuple[str, str, int, str, bool]]:
+ ) -> List[Tuple[str, str, str, int, str, bool]]:
sql = """
- SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight
+ SELECT ep.event_id, ep.room_id, ep.thread_id, ep.stream_ordering,
+ ep.actions, ep.highlight
FROM event_push_actions AS ep
WHERE
ep.user_id = ?
@@ -686,7 +930,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
ORDER BY ep.stream_ordering ASC LIMIT ?
"""
txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
- return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
+ return cast(List[Tuple[str, str, str, int, str, bool]], txn.fetchall())
push_actions = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http", get_push_actions_txn
@@ -699,10 +943,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
stream_ordering=stream_ordering,
actions=_deserialize_action(actions, highlight),
)
- for event_id, room_id, stream_ordering, actions, highlight in push_actions
- # Only include push actions with a stream ordering after any receipt, or without any
- # receipt present (invited to but never read rooms).
- if stream_ordering > receipts_by_room.get(room_id, 0)
+ for event_id, room_id, thread_id, stream_ordering, actions, highlight in push_actions
+ if receipts_by_room.get(room_id, MISSING_ROOM_RECEIPT).is_unread(
+ thread_id, stream_ordering
+ )
]
# Now sort it so it's ordered correctly, since currently it will
@@ -746,10 +990,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
def get_push_actions_txn(
txn: LoggingTransaction,
- ) -> List[Tuple[str, str, int, str, bool, int]]:
+ ) -> List[Tuple[str, str, str, int, str, bool, int]]:
sql = """
- SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
- ep.highlight, e.received_ts
+ SELECT ep.event_id, ep.room_id, ep.thread_id, ep.stream_ordering,
+ ep.actions, ep.highlight, e.received_ts
FROM event_push_actions AS ep
INNER JOIN events AS e USING (room_id, event_id)
WHERE
@@ -760,7 +1004,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
ORDER BY ep.stream_ordering DESC LIMIT ?
"""
txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
- return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
+ return cast(List[Tuple[str, str, str, int, str, bool, int]], txn.fetchall())
push_actions = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email", get_push_actions_txn
@@ -775,10 +1019,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
actions=_deserialize_action(actions, highlight),
received_ts=received_ts,
)
- for event_id, room_id, stream_ordering, actions, highlight, received_ts in push_actions
- # Only include push actions with a stream ordering after any receipt, or without any
- # receipt present (invited to but never read rooms).
- if stream_ordering > receipts_by_room.get(room_id, 0)
+ for event_id, room_id, thread_id, stream_ordering, actions, highlight, received_ts in push_actions
+ if receipts_by_room.get(room_id, MISSING_ROOM_RECEIPT).is_unread(
+ thread_id, stream_ordering
+ )
]
# Now sort it so it's ordered correctly, since currently it will
@@ -1102,7 +1346,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
sql = """
- SELECT r.stream_id, r.room_id, r.user_id, e.stream_ordering
+ SELECT r.stream_id, r.room_id, r.user_id, r.thread_id, e.stream_ordering
FROM receipts_linearized AS r
INNER JOIN events AS e USING (event_id)
WHERE ? < r.stream_id AND r.stream_id <= ? AND user_id LIKE ?
@@ -1123,55 +1367,105 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
limit,
),
)
- rows = cast(List[Tuple[int, str, str, int]], txn.fetchall())
+ rows = cast(List[Tuple[int, str, str, Optional[str], int]], txn.fetchall())
# For each new read receipt we delete push actions from before it and
# recalculate the summary.
- for _, room_id, user_id, stream_ordering in rows:
+ #
+ # Care must be taken of whether it is a threaded or unthreaded receipt.
+ for _, room_id, user_id, thread_id, stream_ordering in rows:
# Only handle our own read receipts.
if not self.hs.is_mine_id(user_id):
continue
+ thread_clause = ""
+ thread_args: Tuple = ()
+ if thread_id is not None:
+ thread_clause = "AND thread_id = ?"
+ thread_args = (thread_id,)
+
+ # For each new read receipt we delete push actions from before it and
+ # recalculate the summary.
txn.execute(
- """
+ f"""
DELETE FROM event_push_actions
WHERE room_id = ?
AND user_id = ?
AND stream_ordering <= ?
AND highlight = 0
+ {thread_clause}
""",
- (room_id, user_id, stream_ordering),
+ (room_id, user_id, stream_ordering, *thread_args),
)
+ # First ensure that the existing rows have an updated thread_id field.
+ if not self._event_push_backfill_thread_id_done:
+ txn.execute(
+ """
+ UPDATE event_push_summary
+ SET thread_id = ?
+ WHERE room_id = ? AND user_id = ? AND thread_id is NULL
+ """,
+ (MAIN_TIMELINE, room_id, user_id),
+ )
+ txn.execute(
+ """
+ UPDATE event_push_actions
+ SET thread_id = ?
+ WHERE room_id = ? AND user_id = ? AND thread_id is NULL
+ """,
+ (MAIN_TIMELINE, room_id, user_id),
+ )
+
# Fetch the notification counts between the stream ordering of the
# latest receipt and what was previously summarised.
- notif_count, unread_count = self._get_notif_unread_count_for_user_room(
- txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering
+ unread_counts = self._get_notif_unread_count_for_user_room(
+ txn,
+ room_id,
+ user_id,
+ stream_ordering,
+ old_rotate_stream_ordering,
+ thread_id,
)
- # First ensure that the existing rows have an updated thread_id field.
- txn.execute(
- """
- UPDATE event_push_summary
- SET thread_id = ?
- WHERE room_id = ? AND user_id = ? AND thread_id is NULL
- """,
- ("main", room_id, user_id),
- )
+ # For an unthreaded receipt, mark the summary for all threads in the room
+ # as cleared.
+ if thread_id is None:
+ self.db_pool.simple_update_txn(
+ txn,
+ table="event_push_summary",
+ keyvalues={"user_id": user_id, "room_id": room_id},
+ updatevalues={
+ "notif_count": 0,
+ "unread_count": 0,
+ "stream_ordering": old_rotate_stream_ordering,
+ "last_receipt_stream_ordering": stream_ordering,
+ },
+ )
+
+ # For a threaded receipt, we *always* want to update that receipt,
+ # even if there are no new notifications in that thread. This ensures
+ # the stream_ordering & last_receipt_stream_ordering are updated.
+ elif not unread_counts:
+ unread_counts = [(0, 0, thread_id)]
- # Replace the previous summary with the new counts.
- #
- # TODO(threads): Upsert per-thread instead of setting them all to main.
- self.db_pool.simple_upsert_txn(
+ # Then any updated threads get their notification count and unread
+ # count updated.
+ self.db_pool.simple_update_many_txn(
txn,
table="event_push_summary",
- keyvalues={"room_id": room_id, "user_id": user_id, "thread_id": "main"},
- values={
- "notif_count": notif_count,
- "unread_count": unread_count,
- "stream_ordering": old_rotate_stream_ordering,
- "last_receipt_stream_ordering": stream_ordering,
- },
+ key_names=("room_id", "user_id", "thread_id"),
+ key_values=[(room_id, user_id, row[2]) for row in unread_counts],
+ value_names=(
+ "notif_count",
+ "unread_count",
+ "stream_ordering",
+ "last_receipt_stream_ordering",
+ ),
+ value_values=[
+ (row[0], row[1], old_rotate_stream_ordering, stream_ordering)
+ for row in unread_counts
+ ],
)
# We always update `event_push_summary_last_receipt_stream_id` to
@@ -1257,25 +1551,38 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
rotate_to_stream_ordering: The new maximum event stream ordering to summarise.
"""
+ # Ensure that any new actions have an updated thread_id.
+ if not self._event_push_backfill_thread_id_done:
+ txn.execute(
+ """
+ UPDATE event_push_actions
+ SET thread_id = ?
+ WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL
+ """,
+ (MAIN_TIMELINE, old_rotate_stream_ordering, rotate_to_stream_ordering),
+ )
+
+ # XXX Do we need to update summaries here too?
+
# Calculate the new counts that should be upserted into event_push_summary
sql = """
- SELECT user_id, room_id,
+ SELECT user_id, room_id, thread_id,
coalesce(old.%s, 0) + upd.cnt,
upd.stream_ordering
FROM (
- SELECT user_id, room_id, count(*) as cnt,
+ SELECT user_id, room_id, thread_id, count(*) as cnt,
max(ea.stream_ordering) as stream_ordering
FROM event_push_actions AS ea
- LEFT JOIN event_push_summary AS old USING (user_id, room_id)
+ LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id)
WHERE ? < ea.stream_ordering AND ea.stream_ordering <= ?
AND (
old.last_receipt_stream_ordering IS NULL
OR old.last_receipt_stream_ordering < ea.stream_ordering
)
AND %s = 1
- GROUP BY user_id, room_id
+ GROUP BY user_id, room_id, thread_id
) AS upd
- LEFT JOIN event_push_summary AS old USING (user_id, room_id)
+ LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id)
"""
# First get the count of unread messages.
@@ -1289,11 +1596,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# object because we might not have the same amount of rows in each of them. To do
# this, we use a dict indexed on the user ID and room ID to make it easier to
# populate.
- summaries: Dict[Tuple[str, str], _EventPushSummary] = {}
+ summaries: Dict[Tuple[str, str, str], _EventPushSummary] = {}
for row in txn:
- summaries[(row[0], row[1])] = _EventPushSummary(
- unread_count=row[2],
- stream_ordering=row[3],
+ summaries[(row[0], row[1], row[2])] = _EventPushSummary(
+ unread_count=row[3],
+ stream_ordering=row[4],
notif_count=0,
)
@@ -1304,48 +1611,50 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
for row in txn:
- if (row[0], row[1]) in summaries:
- summaries[(row[0], row[1])].notif_count = row[2]
+ if (row[0], row[1], row[2]) in summaries:
+ summaries[(row[0], row[1], row[2])].notif_count = row[3]
else:
# Because the rules on notifying are different than the rules on marking
# a message unread, we might end up with messages that notify but aren't
# marked unread, so we might not have a summary for this (user, room)
# tuple to complete.
- summaries[(row[0], row[1])] = _EventPushSummary(
+ summaries[(row[0], row[1], row[2])] = _EventPushSummary(
unread_count=0,
- stream_ordering=row[3],
- notif_count=row[2],
+ stream_ordering=row[4],
+ notif_count=row[3],
)
logger.info("Rotating notifications, handling %d rows", len(summaries))
- # Ensure that any updated threads have an updated thread_id.
- txn.execute_batch(
- """
- UPDATE event_push_summary
- SET thread_id = ?
- WHERE room_id = ? AND user_id = ? AND thread_id is NULL
- """,
- [("main", room_id, user_id) for user_id, room_id in summaries],
- )
- self.db_pool.simple_update_many_txn(
- txn,
- table="event_push_summary",
- key_names=("user_id", "room_id", "thread_id"),
- key_values=[(user_id, room_id, None) for user_id, room_id in summaries],
- value_names=("thread_id",),
- value_values=[("main",) for _ in summaries],
- )
+ # Ensure that any updated threads have the proper thread_id.
+ if not self._event_push_backfill_thread_id_done:
+ txn.execute_batch(
+ """
+ UPDATE event_push_summary
+ SET thread_id = ?
+ WHERE room_id = ? AND user_id = ? AND thread_id is NULL
+ """,
+ [
+ (MAIN_TIMELINE, room_id, user_id)
+ for user_id, room_id, _ in summaries
+ ],
+ )
- # TODO(threads): Update on a per-thread basis.
self.db_pool.simple_upsert_many_txn(
txn,
table="event_push_summary",
key_names=("user_id", "room_id", "thread_id"),
- key_values=[(user_id, room_id, "main") for user_id, room_id in summaries],
+ key_values=[
+ (user_id, room_id, thread_id)
+ for user_id, room_id, thread_id in summaries
+ ],
value_names=("notif_count", "unread_count", "stream_ordering"),
value_values=[
- (summary.notif_count, summary.unread_count, summary.stream_ordering)
+ (
+ summary.notif_count,
+ summary.unread_count,
+ summary.stream_ordering,
+ )
for summary in summaries.values()
],
)
@@ -1356,7 +1665,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
async def _remove_old_push_actions_that_have_rotated(self) -> None:
- """Clear out old push actions that have been summarised."""
+ """
+ Clear out old push actions that have been summarised (and are older than
+ 1 day ago).
+ """
# We want to clear out anything that is older than a day that *has* already
# been rotated.
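The `_RoomReceipt.is_unread()` helper introduced near the top of this file can be exercised on its own; a minimal sketch, with the class re-declared here (it is private to the module) and made-up event stream orderings:

    from typing import Dict

    import attr

    @attr.s(slots=True, auto_attribs=True)
    class _RoomReceipt:
        unthreaded_stream_ordering: int = 0
        threaded_stream_ordering: Dict[str, int] = attr.Factory(dict)

        def is_unread(self, thread_id: str, stream_ordering: int) -> bool:
            # Unread only if the action is after both the unthreaded receipt
            # and the receipt for its specific thread.
            return (
                self.unthreaded_stream_ordering < stream_ordering
                and self.threaded_stream_ordering.get(thread_id, 0) < stream_ordering
            )

    receipt = _RoomReceipt(
        unthreaded_stream_ordering=10,
        threaded_stream_ordering={"$thread_a": 20},
    )

    assert receipt.is_unread("main", 15) is True        # after the unthreaded receipt
    assert receipt.is_unread("$thread_a", 15) is False  # thread already read up to 20
    assert receipt.is_unread("$thread_a", 25) is True   # newer than both receipts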
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 3e158279..00880bb3 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -35,7 +35,7 @@ import attr
from prometheus_client import Counter
import synapse.metrics
-from synapse.api.constants import EventContentFields, EventTypes
+from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, relation_from_event
@@ -1616,7 +1616,7 @@ class PersistEventsStore:
)
# Remove from relations table.
- self._handle_redact_relations(txn, event.redacts)
+ self._handle_redact_relations(txn, event.room_id, event.redacts)
# Update the event_forward_extremities, event_backward_extremities and
# event_edges tables.
@@ -1866,6 +1866,34 @@ class PersistEventsStore:
},
)
+ if relation.rel_type == RelationTypes.THREAD:
+ # Upsert into the threads table, but only overwrite the value if the
+ # new event is of a later topological order OR if the topological
+ # ordering is equal, but the stream ordering is later.
+ sql = """
+ INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering)
+ VALUES (?, ?, ?, ?, ?)
+ ON CONFLICT (room_id, thread_id)
+ DO UPDATE SET
+ latest_event_id = excluded.latest_event_id,
+ topological_ordering = excluded.topological_ordering,
+ stream_ordering = excluded.stream_ordering
+ WHERE
+ threads.topological_ordering <= excluded.topological_ordering AND
+ threads.stream_ordering < excluded.stream_ordering
+ """
+
+ txn.execute(
+ sql,
+ (
+ event.room_id,
+ relation.parent_id,
+ event.event_id,
+ event.depth,
+ event.internal_metadata.stream_ordering,
+ ),
+ )
+
def _handle_insertion_event(
self, txn: LoggingTransaction, event: EventBase
) -> None:
@@ -1989,35 +2017,48 @@ class PersistEventsStore:
txn.execute(sql, (batch_id,))
def _handle_redact_relations(
- self, txn: LoggingTransaction, redacted_event_id: str
+ self, txn: LoggingTransaction, room_id: str, redacted_event_id: str
) -> None:
"""Handles receiving a redaction and checking whether the redacted event
has any relations which must be removed from the database.
Args:
txn
+ room_id: The room ID of the event that was redacted.
redacted_event_id: The event that was redacted.
"""
- # Fetch the current relation of the event being redacted.
- redacted_relates_to = self.db_pool.simple_select_one_onecol_txn(
+ # Fetch the relation of the event being redacted.
+ row = self.db_pool.simple_select_one_txn(
txn,
table="event_relations",
keyvalues={"event_id": redacted_event_id},
- retcol="relates_to_id",
+ retcols=("relates_to_id", "relation_type"),
allow_none=True,
)
+ # Nothing to do if no relation is found.
+ if row is None:
+ return
+
+ redacted_relates_to = row["relates_to_id"]
+ rel_type = row["relation_type"]
+ self.db_pool.simple_delete_txn(
+ txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
+ )
+
# Any relation information for the related event must be cleared.
- if redacted_relates_to is not None:
- self.store._invalidate_cache_and_stream(
- txn, self.store.get_relations_for_event, (redacted_relates_to,)
- )
+ self.store._invalidate_cache_and_stream(
+ txn, self.store.get_relations_for_event, (redacted_relates_to,)
+ )
+ if rel_type == RelationTypes.ANNOTATION:
self.store._invalidate_cache_and_stream(
txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,)
)
+ if rel_type == RelationTypes.REPLACE:
self.store._invalidate_cache_and_stream(
txn, self.store.get_applicable_edit, (redacted_relates_to,)
)
+ if rel_type == RelationTypes.THREAD:
self.store._invalidate_cache_and_stream(
txn, self.store.get_thread_summary, (redacted_relates_to,)
)
@@ -2025,14 +2066,41 @@ class PersistEventsStore:
txn, self.store.get_thread_participated, (redacted_relates_to,)
)
self.store._invalidate_cache_and_stream(
- txn,
- self.store.get_mutual_event_relations_for_rel_type,
- (redacted_relates_to,),
+ txn, self.store.get_threads, (room_id,)
)
- self.db_pool.simple_delete_txn(
- txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
- )
+ # Find the new latest event in the thread.
+ sql = """
+ SELECT event_id, topological_ordering, stream_ordering
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE relates_to_id = ? AND relation_type = ?
+ ORDER BY topological_ordering DESC, stream_ordering DESC
+ LIMIT 1
+ """
+ txn.execute(sql, (redacted_relates_to, RelationTypes.THREAD))
+
+ # If a latest event is found, update the threads table; this might be
+ # the same as the current latest event (if an earlier event in the
+ # thread was redacted).
+ latest_event_row = txn.fetchone()
+ if latest_event_row:
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="threads",
+ keyvalues={"room_id": room_id, "thread_id": redacted_relates_to},
+ values={
+ "latest_event_id": latest_event_row[0],
+ "topological_ordering": latest_event_row[1],
+ "stream_ordering": latest_event_row[2],
+ },
+ )
+
+ # Otherwise, delete the thread: it no longer exists.
+ else:
+ self.db_pool.simple_delete_one_txn(
+ txn, table="threads", keyvalues={"thread_id": redacted_relates_to}
+ )
def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
if isinstance(event.content.get("topic"), str):
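The ON CONFLICT clause in the threads upsert above only replaces the recorded latest thread event when the incoming event sorts later. Expressed as a plain comparison, roughly (a sketch mirroring the SQL's WHERE condition, not the actual persistence code):

    from typing import NamedTuple

    class ThreadLatest(NamedTuple):
        event_id: str
        topological_ordering: int
        stream_ordering: int

    def pick_latest(current: ThreadLatest, candidate: ThreadLatest) -> ThreadLatest:
        # The candidate wins only if its topological ordering is at least as
        # large AND its stream ordering is strictly larger, matching the
        # upsert's WHERE clause.
        if (
            current.topological_ordering <= candidate.topological_ordering
            and current.stream_ordering < candidate.stream_ordering
        ):
            return candidate
        return current

    current = ThreadLatest("$old", topological_ordering=5, stream_ordering=100)
    newer = ThreadLatest("$new", topological_ordering=5, stream_ordering=101)
    assert pick_latest(current, newer) is newer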
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 7cdc9fe9..69fea452 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -374,7 +374,7 @@ class EventsWorkerStore(SQLBaseStore):
If there is a mismatch, behave as per allow_none.
Returns:
- The event, or None if the event was not found.
+ The event, or None if the event was not found and allow_none is `True`.
"""
if not isinstance(event_id, str):
raise TypeError("Invalid event event_id %r" % (event_id,))
@@ -474,7 +474,7 @@ class EventsWorkerStore(SQLBaseStore):
return []
# there may be duplicates so we cast the list to a set
- event_entry_map = await self._get_events_from_cache_or_db(
+ event_entry_map = await self.get_unredacted_events_from_cache_or_db(
set(event_ids), allow_rejected=allow_rejected
)
@@ -509,7 +509,9 @@ class EventsWorkerStore(SQLBaseStore):
continue
redacted_event_id = entry.event.redacts
- event_map = await self._get_events_from_cache_or_db([redacted_event_id])
+ event_map = await self.get_unredacted_events_from_cache_or_db(
+ [redacted_event_id]
+ )
original_event_entry = event_map.get(redacted_event_id)
if not original_event_entry:
# we don't have the redacted event (or it was rejected).
@@ -588,11 +590,16 @@ class EventsWorkerStore(SQLBaseStore):
return events
@cancellable
- async def _get_events_from_cache_or_db(
- self, event_ids: Iterable[str], allow_rejected: bool = False
+ async def get_unredacted_events_from_cache_or_db(
+ self,
+ event_ids: Iterable[str],
+ allow_rejected: bool = False,
) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the cache or the database.
+ Note that the events pulled by this function will not have any redactions
+ applied, and no guarantee is made about the ordering of the events returned.
+
If events are pulled from the database, they will be cached for future lookups.
Unknown events are omitted from the response.
@@ -1495,21 +1502,15 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
a dict {event_id -> bool}
"""
- # if the event cache contains the event, obviously we've seen it.
-
- cache_results = {
- event_id
- for event_id in event_ids
- if await self._get_event_cache.contains((event_id,))
- }
- results = dict.fromkeys(cache_results, True)
- remaining = [
- event_id for event_id in event_ids if event_id not in cache_results
- ]
- if not remaining:
- return results
+ # TODO: We used to query the _get_event_cache here as a fast-path before
+ # hitting the database. After all, if an event were in the cache, we've presumably
+ # seen it before.
+ #
+ # But this is currently an invalid assumption due to the _get_event_cache
+ # not being invalidated when purging events from a room. The optimisation can
+ # be re-added after https://github.com/matrix-org/synapse/issues/13476
- def have_seen_events_txn(txn: LoggingTransaction) -> None:
+ def have_seen_events_txn(txn: LoggingTransaction) -> Dict[str, bool]:
# we deliberately do *not* query the database for room_id, to make the
# query an index-only lookup on `events_event_id_key`.
#
@@ -1517,16 +1518,17 @@ class EventsWorkerStore(SQLBaseStore):
sql = "SELECT event_id FROM events AS e WHERE "
clause, args = make_in_list_sql_clause(
- txn.database_engine, "e.event_id", remaining
+ txn.database_engine, "e.event_id", event_ids
)
txn.execute(sql + clause, args)
found_events = {eid for eid, in txn}
# ... and then we can update the results for each key
- results.update({eid: (eid in found_events) for eid in remaining})
+ return {eid: (eid in found_events) for eid in event_ids}
- await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn)
- return results
+ return await self.db_pool.runInteraction(
+ "have_seen_events", have_seen_events_txn
+ )
@cached(max_entries=100000, tree=True)
async def have_seen_event(self, room_id: str, event_id: str) -> bool:
@@ -1969,12 +1971,17 @@ class EventsWorkerStore(SQLBaseStore):
Args:
room_id: room where the event lives
- event_id: event to check
+ event: event to check (can't be an `outlier`)
Returns:
Boolean indicating whether it's an extremity
"""
+ assert not event.internal_metadata.is_outlier(), (
+ "is_event_next_to_backward_gap(...) can't be used with `outlier` events. "
+ "This function relies on `event_backward_extremities` which won't be filled in for `outliers`."
+ )
+
def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool:
# If the event in question has any of its prev_events listed as a
# backward extremity, it's next to a gap.
@@ -2024,12 +2031,17 @@ class EventsWorkerStore(SQLBaseStore):
Args:
room_id: room where the event lives
- event_id: event to check
+ event: event to check (can't be an `outlier`)
Returns:
Boolean indicating whether it's an extremity
"""
+ assert not event.internal_metadata.is_outlier(), (
+ "is_event_next_to_forward_gap(...) can't be used with `outlier` events. "
+ "This function relies on `event_edges` and `event_forward_extremities` which won't be filled in for `outliers`."
+ )
+
def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool:
# If the event in question is a forward extremity, we will just
# consider any potential forward gap as not a gap since it's one of
@@ -2110,13 +2122,33 @@ class EventsWorkerStore(SQLBaseStore):
The closest event_id otherwise None if we can't find any event in
the given direction.
"""
+ if direction == "b":
+ # Find closest event *before* a given timestamp. We use descending
+ # (which gives values largest to smallest) because we want the
+ # largest possible timestamp *before* the given timestamp.
+ comparison_operator = "<="
+ order = "DESC"
+ else:
+ # Find closest event *after* a given timestamp. We use ascending
+ # (which gives values smallest to largest) because we want the
+ # closest possible timestamp *after* the given timestamp.
+ comparison_operator = ">="
+ order = "ASC"
- sql_template = """
+ sql_template = f"""
SELECT event_id FROM events
LEFT JOIN rejections USING (event_id)
WHERE
- origin_server_ts %s ?
- AND room_id = ?
+ room_id = ?
+ AND origin_server_ts {comparison_operator} ?
+ /**
+ * Make sure the event isn't an `outlier` because we have no way
+ * to later check whether it's next to a gap. `outliers` do not
+ * have entries in the `event_edges`, `event_forward_extremities`,
+ * and `event_backward_extremities` tables to check against
+ * (used by `is_event_next_to_backward_gap` and `is_event_next_to_forward_gap`).
+ */
+ AND NOT outlier
/* Make sure event is not rejected */
AND rejections.event_id IS NULL
/**
@@ -2126,27 +2158,14 @@ class EventsWorkerStore(SQLBaseStore):
* Finally, we can tie-break based on when it was received on the server
* (`stream_ordering`).
*/
- ORDER BY origin_server_ts %s, depth %s, stream_ordering %s
+ ORDER BY origin_server_ts {order}, depth {order}, stream_ordering {order}
LIMIT 1;
"""
def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]:
- if direction == "b":
- # Find closest event *before* a given timestamp. We use descending
- # (which gives values largest to smallest) because we want the
- # largest possible timestamp *before* the given timestamp.
- comparison_operator = "<="
- order = "DESC"
- else:
- # Find closest event *after* a given timestamp. We use ascending
- # (which gives values smallest to largest) because we want the
- # closest possible timestamp *after* the given timestamp.
- comparison_operator = ">="
- order = "ASC"
-
txn.execute(
- sql_template % (comparison_operator, order, order, order),
- (timestamp, room_id),
+ sql_template,
+ (room_id, timestamp),
)
row = txn.fetchone()
if row:
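The direction handling hoisted out of the transaction above boils down to a fixed mapping from the pagination direction to a comparison operator and sort order. A small sketch of that mapping and how it shapes the query (the helper name is invented for illustration):

    from typing import Tuple

    def direction_to_sql(direction: str) -> Tuple[str, str]:
        """Map a pagination direction ("b" = backwards, "f" = forwards) to the
        comparison operator and ORDER BY direction used to find the closest
        event to a timestamp."""
        if direction == "b":
            # Closest event at-or-before the timestamp: the largest value wins.
            return "<=", "DESC"
        # Closest event at-or-after the timestamp: the smallest value wins.
        return ">=", "ASC"

    comparison_operator, order = direction_to_sql("b")
    sql = f"""
        SELECT event_id FROM events
        WHERE room_id = ? AND origin_server_ts {comparison_operator} ?
        ORDER BY origin_server_ts {order}, depth {order}, stream_ordering {order}
        LIMIT 1
    """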
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index ed17b2e7..51416b22 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -29,7 +29,6 @@ from typing import (
)
from synapse.api.errors import StoreError
-from synapse.config.homeserver import ExperimentalConfig
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
@@ -63,9 +62,7 @@ logger = logging.getLogger(__name__)
def _load_rules(
- rawrules: List[JsonDict],
- enabled_map: Dict[str, bool],
- experimental_config: ExperimentalConfig,
+ rawrules: List[JsonDict], enabled_map: Dict[str, bool]
) -> FilteredPushRules:
"""Take the DB rows returned from the DB and convert them into a full
`FilteredPushRules` object.
@@ -81,16 +78,9 @@ def _load_rules(
for rawrule in rawrules
]
- push_rules = PushRules(
- ruleslist,
- )
+ push_rules = PushRules(ruleslist)
- filtered_rules = FilteredPushRules(
- push_rules,
- enabled_map,
- msc3786_enabled=experimental_config.msc3786_enabled,
- msc3772_enabled=experimental_config.msc3772_enabled,
- )
+ filtered_rules = FilteredPushRules(push_rules, enabled_map)
return filtered_rules
@@ -170,7 +160,7 @@ class PushRulesWorkerStore(
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
- return _load_rules(rows, enabled_map, self.hs.config.experimental)
+ return _load_rules(rows, enabled_map)
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
results = await self.db_pool.simple_select_list(
@@ -229,9 +219,7 @@ class PushRulesWorkerStore(
results: Dict[str, FilteredPushRules] = {}
for user_id, rules in raw_rules.items():
- results[user_id] = _load_rules(
- rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental
- )
+ results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
return results
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 246f78ac..dc698952 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -418,6 +418,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
receipt_type = event_entry.setdefault(row["receipt_type"], {})
receipt_type[row["user_id"]] = db_to_json(row["data"])
+ if row["thread_id"]:
+ receipt_type[row["user_id"]]["thread_id"] = row["thread_id"]
results = {
room_id: [results[room_id]] if room_id in results else []
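With the change above, a threaded receipt's per-user data carries the thread it applies to. A hedged example of the resulting shape (event, user, and thread IDs are made up):

    # One entry of an m.receipt content block after the change: the per-user
    # receipt data gains a "thread_id" key when the receipt was sent against a
    # thread rather than the main timeline.
    receipt_content = {
        "$some_event_id": {
            "m.read": {
                "@alice:example.org": {
                    "ts": 1666000000000,
                    "thread_id": "$thread_root_event_id",
                },
            },
        },
    }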
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 898947af..c022510e 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -14,6 +14,7 @@
import logging
from typing import (
+ TYPE_CHECKING,
Collection,
Dict,
FrozenSet,
@@ -28,19 +29,48 @@ from typing import (
import attr
-from synapse.api.constants import RelationTypes
+from synapse.api.constants import MAIN_TIMELINE, RelationTypes
+from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+ make_in_list_sql_clause,
+)
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
+class ThreadsNextBatch:
+ topological_ordering: int
+ stream_ordering: int
+
+ def __str__(self) -> str:
+ return f"{self.topological_ordering}_{self.stream_ordering}"
+
+ @classmethod
+ def from_string(cls, string: str) -> "ThreadsNextBatch":
+ """
+ Creates a ThreadsNextBatch from its textual representation.
+ """
+ try:
+ keys = (int(s) for s in string.split("_"))
+ return cls(*keys)
+ except Exception:
+ raise SynapseError(400, "Invalid threads token")
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class _RelatedEvent:
"""
Contains enough information about a related event in order to properly filter
@@ -56,6 +86,76 @@ class _RelatedEvent:
class RelationsWorkerStore(SQLBaseStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
+
+ self.db_pool.updates.register_background_update_handler(
+ "threads_backfill", self._backfill_threads
+ )
+
+ async def _backfill_threads(self, progress: JsonDict, batch_size: int) -> int:
+ """Backfill the threads table."""
+
+ def threads_backfill_txn(txn: LoggingTransaction) -> int:
+ last_thread_id = progress.get("last_thread_id", "")
+
+ # Get the latest event in each thread by topo ordering / stream ordering.
+ #
+ # Note that the MAX(event_id) is needed to abide by the rules of group by,
+ # but doesn't actually do anything since there should only be a single event
+ # ID per topo/stream ordering pair.
+ sql = f"""
+ SELECT room_id, relates_to_id, MAX(topological_ordering), MAX(stream_ordering), MAX(event_id)
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE
+ relates_to_id > ? AND
+ relation_type = '{RelationTypes.THREAD}'
+ GROUP BY room_id, relates_to_id
+ ORDER BY relates_to_id
+ LIMIT ?
+ """
+ txn.execute(sql, (last_thread_id, batch_size))
+
+ # No more rows to process.
+ rows = txn.fetchall()
+ if not rows:
+ return 0
+
+ # Insert the rows into the threads table. If a matching thread already exists,
+ # assume it is from a newer event.
+ sql = """
+ INSERT INTO threads (room_id, thread_id, topological_ordering, stream_ordering, latest_event_id)
+ VALUES %s
+ ON CONFLICT (room_id, thread_id)
+ DO NOTHING
+ """
+ if isinstance(txn.database_engine, PostgresEngine):
+ txn.execute_values(sql % ("?",), rows, fetch=False)
+ else:
+ txn.execute_batch(sql % ("(?, ?, ?, ?, ?)",), rows)
+
+ # Mark the progress.
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "threads_backfill", {"last_thread_id": rows[-1][1]}
+ )
+
+ return txn.rowcount
+
+ result = await self.db_pool.runInteraction(
+ "threads_backfill", threads_backfill_txn
+ )
+
+ if not result:
+ await self.db_pool.updates._end_background_update("threads_backfill")
+
+ return result
+
@cached(uncached_args=("event",), tree=True)
async def get_relations_for_event(
self,
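
The important behaviour in the _backfill_threads update above is that a row already written by the normal event-persistence path must win over the (older) backfilled row, which is exactly what ON CONFLICT ... DO NOTHING gives. A standalone sketch of that conflict behaviour using Python's built-in sqlite3 module; the schema and values here are simplified, illustrative stand-ins for the real threads table:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        """
        CREATE TABLE threads (
            room_id TEXT NOT NULL,
            thread_id TEXT NOT NULL,
            topological_ordering INTEGER NOT NULL,
            stream_ordering INTEGER NOT NULL,
            latest_event_id TEXT NOT NULL,
            PRIMARY KEY (room_id, thread_id)
        )
        """
    )
    # Row written when the thread's latest event was persisted.
    conn.execute(
        "INSERT INTO threads VALUES (?, ?, ?, ?, ?)",
        ("!room:test", "$root", 7, 70, "$newer"),
    )
    # Backfill tries to insert an older row for the same thread; DO NOTHING keeps
    # the existing (newer) row.
    conn.execute(
        """
        INSERT INTO threads VALUES (?, ?, ?, ?, ?)
        ON CONFLICT (room_id, thread_id) DO NOTHING
        """,
        ("!room:test", "$root", 3, 30, "$older"),
    )
    assert conn.execute("SELECT latest_event_id FROM threads").fetchone() == ("$newer",)
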
@@ -779,57 +879,192 @@ class RelationsWorkerStore(SQLBaseStore):
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
- @cached(iterable=True)
- async def get_mutual_event_relations_for_rel_type(
- self, event_id: str, relation_type: str
- ) -> Set[Tuple[str, str]]:
- raise NotImplementedError()
+ @cached(tree=True)
+ async def get_threads(
+ self,
+ room_id: str,
+ limit: int = 5,
+ from_token: Optional[ThreadsNextBatch] = None,
+ ) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
+ """Get a list of thread IDs, ordered by topological ordering of their
+ latest reply.
+
+ Args:
+ room_id: The room the event belongs to.
+ limit: Only fetch the most recent `limit` threads.
+ from_token: Fetch rows from a previous next_batch, or from the start if None.
+
+ Returns:
+ A tuple of:
+ A list of thread root event IDs.
+
+ The next_batch, if one exists.
+ """
+ # Generate the pagination clause, if necessary.
+ #
+        # Find any threads where the latest reply is at or before the last
+        # thread's topological ordering, and strictly earlier in stream ordering.
+ pagination_clause = ""
+ pagination_args: tuple = ()
+ if from_token:
+ pagination_clause = "AND topological_ordering <= ? AND stream_ordering < ?"
+ pagination_args = (
+ from_token.topological_ordering,
+ from_token.stream_ordering,
+ )
- @cachedList(
- cached_method_name="get_mutual_event_relations_for_rel_type",
- list_name="relation_types",
- )
- async def get_mutual_event_relations(
- self, event_id: str, relation_types: Collection[str]
- ) -> Dict[str, Set[Tuple[str, str]]]:
+ sql = f"""
+ SELECT thread_id, topological_ordering, stream_ordering
+ FROM threads
+ WHERE
+ room_id = ?
+ {pagination_clause}
+ ORDER BY topological_ordering DESC, stream_ordering DESC
+ LIMIT ?
+ """
+
+ def _get_threads_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
+ txn.execute(sql, (room_id, *pagination_args, limit + 1))
+
+ rows = cast(List[Tuple[str, int, int]], txn.fetchall())
+ thread_ids = [r[0] for r in rows]
+
+ # If there are more events, generate the next pagination key from the
+ # last thread which will be returned.
+ next_token = None
+ if len(thread_ids) > limit:
+ last_topo_id = rows[-2][1]
+ last_stream_id = rows[-2][2]
+ next_token = ThreadsNextBatch(last_topo_id, last_stream_id)
+
+ return thread_ids[:limit], next_token
+
+ return await self.db_pool.runInteraction("get_threads", _get_threads_txn)
+
+ @cached()
+ async def get_thread_id(self, event_id: str) -> str:
"""
- Fetch event metadata for events which related to the same event as the given event.
+ Get the thread ID for an event. This considers multi-level relations,
+ e.g. an annotation to an event which is part of a thread.
+
+ It only searches up the relations tree, i.e. it only searches for events
+ which the given event is related to (and which those events are related
+ to, etc.)
+
+ Given the following DAG:
+
+ A <---[m.thread]-- B <--[m.annotation]-- C
+ ^
+ |--[m.reference]-- D <--[m.annotation]-- E
+
+        Here, get_thread_id(B) and get_thread_id(C) both return A, i.e. B and C
+        are considered part of thread A.
- If the given event has no relation information, returns an empty dictionary.
+ See also get_thread_id_for_receipts.
Args:
- event_id: The event ID which is targeted by relations.
- relation_types: The relation types to check for mutual relations.
+ event_id: The event ID to fetch the thread ID for.
Returns:
- A dictionary of relation type to:
- A set of tuples of:
- The sender
- The event type
+            The event ID of the root event in the thread, if this event is part
+            of a thread; "main" otherwise.
"""
- rel_type_sql, rel_type_args = make_in_list_sql_clause(
- self.database_engine, "relation_type", relation_types
- )
- sql = f"""
- SELECT DISTINCT relation_type, sender, type FROM event_relations
- INNER JOIN events USING (event_id)
- WHERE relates_to_id = ? AND {rel_type_sql}
+ # Recurse event relations up to the *root* event, then search that chain
+ # of relations for a thread relation. If one is found, the root event is
+ # returned.
+ #
+ # Note that this should only ever find 0 or 1 entries since it is invalid
+ # for an event to have a thread relation to an event which also has a
+ # relation.
+ sql = """
+ WITH RECURSIVE related_events AS (
+ SELECT event_id, relates_to_id, relation_type, 0 depth
+ FROM event_relations
+ WHERE event_id = ?
+ UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
+ FROM event_relations e
+ INNER JOIN related_events r ON r.relates_to_id = e.event_id
+ WHERE depth <= 3
+ )
+ SELECT relates_to_id FROM related_events
+ WHERE relation_type = 'm.thread'
+ ORDER BY depth DESC
+ LIMIT 1;
"""
- def _get_event_relations(
- txn: LoggingTransaction,
- ) -> Dict[str, Set[Tuple[str, str]]]:
- txn.execute(sql, [event_id] + rel_type_args)
- result: Dict[str, Set[Tuple[str, str]]] = {
- rel_type: set() for rel_type in relation_types
- }
- for rel_type, sender, type in txn.fetchall():
- result[rel_type].add((sender, type))
- return result
+ def _get_thread_id(txn: LoggingTransaction) -> str:
+ txn.execute(sql, (event_id,))
+ row = txn.fetchone()
+ if row:
+ return row[0]
+
+ # If no thread was found, it is part of the main timeline.
+ return MAIN_TIMELINE
+
+ return await self.db_pool.runInteraction("get_thread_id", _get_thread_id)
+
+ @cached()
+ async def get_thread_id_for_receipts(self, event_id: str) -> str:
+ """
+ Get the thread ID for an event by traversing to the top-most related event
+ and confirming any children events form a thread.
+
+ Given the following DAG:
+
+ A <---[m.thread]-- B <--[m.annotation]-- C
+ ^
+ |--[m.reference]-- D <--[m.annotation]-- E
+
+        Here, get_thread_id_for_receipts returns A for any of events A, B, C, D,
+        and E, i.e. all of them are considered part of thread A.
+
+ See also get_thread_id.
+
+ Args:
+ event_id: The event ID to fetch the thread ID for.
+
+ Returns:
+            The event ID of the root event in the thread, if this event is part
+            of a thread; "main" otherwise.
+ """
+
+ # Recurse event relations up to the *root* event, then search for any events
+ # related to that root node for a thread relation. If one is found, the
+ # root event is returned.
+ #
+ # Note that there cannot be thread relations in the middle of the chain since
+ # it is invalid for an event to have a thread relation to an event which also
+ # has a relation.
+ sql = """
+ SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE((
+ WITH RECURSIVE related_events AS (
+ SELECT event_id, relates_to_id, relation_type, 0 depth
+ FROM event_relations
+ WHERE event_id = ?
+ UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
+ FROM event_relations e
+ INNER JOIN related_events r ON r.relates_to_id = e.event_id
+ WHERE depth <= 3
+ )
+ SELECT relates_to_id FROM related_events
+ ORDER BY depth DESC
+ LIMIT 1
+ ), ?) AND relation_type = 'm.thread' LIMIT 1;
+ """
+
+ def _get_related_thread_id(txn: LoggingTransaction) -> str:
+ txn.execute(sql, (event_id, event_id))
+ row = txn.fetchone()
+ if row:
+ return row[0]
+
+ # If no thread was found, it is part of the main timeline.
+ return MAIN_TIMELINE
return await self.db_pool.runInteraction(
- "get_event_relations", _get_event_relations
+ "get_related_thread_id", _get_related_thread_id
)
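
The two recursive CTEs above walk up the relations graph from a given event (to a bounded depth) and report the thread root, falling back to the main timeline. The same idea, sketched in plain Python over an in-memory relations map; the data and helper below are illustrative, not Synapse APIs:

    MAIN_TIMELINE = "main"

    # event_id -> (relates_to_id, relation_type), mirroring the event_relations
    # table for the DAG shown in the docstrings above.
    RELATIONS = {
        "B": ("A", "m.thread"),
        "C": ("B", "m.annotation"),
        "D": ("A", "m.reference"),
        "E": ("D", "m.annotation"),
    }

    def get_thread_id(event_id: str, max_depth: int = 3) -> str:
        """Walk up the relations chain looking for an m.thread relation."""
        depth = 0
        while event_id in RELATIONS and depth <= max_depth:
            relates_to_id, relation_type = RELATIONS[event_id]
            if relation_type == "m.thread":
                return relates_to_id
            event_id = relates_to_id
            depth += 1
        return MAIN_TIMELINE

    assert get_thread_id("B") == "A"            # the thread reply itself
    assert get_thread_id("C") == "A"            # annotation of a thread reply
    assert get_thread_id("E") == MAIN_TIMELINE  # reference chains are not threads
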
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 7412bce2..7d97f8f6 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -97,6 +97,12 @@ class RoomSortOrder(Enum):
STATE_EVENTS = "state_events"
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class PartialStateResyncInfo:
+ joined_via: Optional[str]
+ servers_in_room: List[str] = attr.ib(factory=list)
+
+
class RoomWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
@@ -207,21 +213,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def _construct_room_type_where_clause(
self, room_types: Union[List[Union[str, None]], None]
- ) -> Tuple[Union[str, None], List[str]]:
+ ) -> Tuple[Union[str, None], list]:
if not room_types:
return None, []
- else:
- # We use None when we want get rooms without a type
- is_null_clause = ""
- if None in room_types:
- is_null_clause = "OR room_type IS NULL"
- room_types = [value for value in room_types if value is not None]
+        # Since None is used to represent a room without a type, care needs to
+        # be taken when constructing the where clause.
+ clauses = []
+ args: list = []
+
+ room_types_set = set(room_types)
+
+ # We use None to represent a room without a type.
+ if None in room_types_set:
+ clauses.append("room_type IS NULL")
+ room_types_set.remove(None)
+
+ # If there are other room types, generate the proper clause.
+        if room_types_set:
list_clause, args = make_in_list_sql_clause(
- self.database_engine, "room_type", room_types
+ self.database_engine, "room_type", room_types_set
)
+ clauses.append(list_clause)
- return f"({list_clause} {is_null_clause})", args
+ return f"({' OR '.join(clauses)})", args
async def count_public_rooms(
self,
@@ -241,14 +256,6 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def _count_public_rooms_txn(txn: LoggingTransaction) -> int:
query_args = []
- room_type_clause, args = self._construct_room_type_where_clause(
- search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
- if search_filter
- else None
- )
- room_type_clause = f" AND {room_type_clause}" if room_type_clause else ""
- query_args += args
-
if network_tuple:
if network_tuple.appservice_id:
published_sql = """
@@ -268,6 +275,14 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
UNION SELECT room_id from appservice_room_list
"""
+ room_type_clause, args = self._construct_room_type_where_clause(
+ search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
+ if search_filter
+ else None
+ )
+ room_type_clause = f" AND {room_type_clause}" if room_type_clause else ""
+ query_args += args
+
sql = f"""
SELECT
COUNT(*)
@@ -1151,17 +1166,29 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
desc="get_partial_state_servers_at_join",
)
- async def get_partial_state_rooms_and_servers(
+ async def get_partial_state_room_resync_info(
self,
- ) -> Mapping[str, Collection[str]]:
- """Get all rooms containing events with partial state, and the servers known
- to be in the room.
+ ) -> Mapping[str, PartialStateResyncInfo]:
+ """Get all rooms containing events with partial state, and the information
+ needed to restart a "resync" of those rooms.
Returns:
A dictionary of rooms with partial state, with room IDs as keys and
lists of servers in rooms as values.
"""
- room_servers: Dict[str, List[str]] = {}
+ room_servers: Dict[str, PartialStateResyncInfo] = {}
+
+ rows = await self.db_pool.simple_select_list(
+ table="partial_state_rooms",
+ keyvalues={},
+ retcols=("room_id", "joined_via"),
+ desc="get_server_which_served_partial_join",
+ )
+
+ for row in rows:
+ room_id = row["room_id"]
+ joined_via = row["joined_via"]
+ room_servers[room_id] = PartialStateResyncInfo(joined_via=joined_via)
rows = await self.db_pool.simple_select_list(
"partial_state_rooms_servers",
@@ -1173,7 +1200,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
for row in rows:
room_id = row["room_id"]
server_name = row["server_name"]
- room_servers.setdefault(room_id, []).append(server_name)
+ entry = room_servers.get(room_id)
+ if entry is None:
+ # There is a foreign key constraint which enforces that every room_id in
+ # partial_state_rooms_servers appears in partial_state_rooms. So we
+ # expect `entry` to be non-null. (This reasoning fails if we've
+ # partial-joined between the two SELECTs, but this is unlikely to happen
+ # in practice.)
+ continue
+ entry.servers_in_room.append(server_name)
return room_servers
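
The two SELECTs above are merged in Python: the first pass creates one PartialStateResyncInfo per partially-joined room, and the second appends each known server to its room's entry, skipping orphaned rows. A standalone sketch of that merge; the dataclass and row values are illustrative stand-ins for the attrs class and database rows:

    from dataclasses import dataclass, field
    from typing import Dict, List, Optional

    @dataclass
    class PartialStateResyncInfo:
        joined_via: Optional[str]
        servers_in_room: List[str] = field(default_factory=list)

    partial_state_rooms = [{"room_id": "!a:x", "joined_via": "x"}]
    partial_state_rooms_servers = [
        {"room_id": "!a:x", "server_name": "x"},
        {"room_id": "!a:x", "server_name": "y"},
        {"room_id": "!b:x", "server_name": "z"},  # no matching room row; skipped
    ]

    room_servers: Dict[str, PartialStateResyncInfo] = {}
    for row in partial_state_rooms:
        room_servers[row["room_id"]] = PartialStateResyncInfo(joined_via=row["joined_via"])
    for row in partial_state_rooms_servers:
        entry = room_servers.get(row["room_id"])
        if entry is None:
            # A foreign key normally rules this out; skip rather than crash.
            continue
        entry.servers_in_room.append(row["server_name"])

    assert room_servers["!a:x"].servers_in_room == ["x", "y"]
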
@@ -1818,6 +1853,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
room_id: str,
servers: Collection[str],
device_lists_stream_id: int,
+ joined_via: str,
) -> None:
"""Mark the given room as containing events with partial state.
@@ -1833,6 +1869,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
servers: other servers known to be in the room
device_lists_stream_id: the device_lists stream ID at the time when we first
joined the room.
+ joined_via: the server name we requested a partial join from.
"""
await self.db_pool.runInteraction(
"store_partial_state_room",
@@ -1840,6 +1877,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
room_id,
servers,
device_lists_stream_id,
+ joined_via,
)
def _store_partial_state_room_txn(
@@ -1848,6 +1886,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
room_id: str,
servers: Collection[str],
device_lists_stream_id: int,
+ joined_via: str,
) -> None:
DatabasePool.simple_insert_txn(
txn,
@@ -1857,6 +1896,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
"device_lists_stream_id": device_lists_stream_id,
# To be updated later once the join event is persisted.
"join_event_id": None,
+ "joined_via": joined_via,
},
)
DatabasePool.simple_insert_many_txn(
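
Earlier in this file, the rewritten _construct_room_type_where_clause builds one clause per case and ORs them together, so that None (rooms without a type) can be combined with concrete room types. A standalone sketch of the clause it ends up producing; the SQL building here is simplified (Synapse's make_in_list_sql_clause emits "= ANY(?)" on Postgres rather than "IN (...)"):

    from typing import List, Optional, Tuple

    def room_type_where_clause(
        room_types: Optional[List[Optional[str]]],
    ) -> Tuple[Optional[str], list]:
        if not room_types:
            return None, []

        clauses = []
        args: list = []
        room_types_set = set(room_types)

        # None represents "rooms without a type".
        if None in room_types_set:
            clauses.append("room_type IS NULL")
            room_types_set.remove(None)

        if room_types_set:
            placeholders = ", ".join("?" for _ in room_types_set)
            clauses.append(f"room_type IN ({placeholders})")
            args = list(room_types_set)

        return f"({' OR '.join(clauses)})", args

    clause, args = room_type_where_clause(["m.space", None])
    assert clause == "(room_type IS NULL OR room_type IN (?))"
    assert args == ["m.space"]
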
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 2337289d..2ed6ad75 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -666,7 +666,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
cached_method_name="get_rooms_for_user",
list_name="user_ids",
)
- async def get_rooms_for_users(
+ async def _get_rooms_for_users(
self, user_ids: Collection[str]
) -> Dict[str, FrozenSet[str]]:
"""A batched version of `get_rooms_for_user`.
@@ -697,6 +697,21 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return {key: frozenset(rooms) for key, rooms in user_rooms.items()}
+ async def get_rooms_for_users(
+ self, user_ids: Collection[str]
+ ) -> Dict[str, FrozenSet[str]]:
+ """A batched wrapper around `_get_rooms_for_users`, to prevent locking
+ other calls to `get_rooms_for_user` for large user lists.
+ """
+ all_user_rooms: Dict[str, FrozenSet[str]] = {}
+
+ # 250 users is pretty arbitrary but the data can be quite large if users
+ # are in many rooms.
+ for user_ids in batch_iter(user_ids, 250):
+ all_user_rooms.update(await self._get_rooms_for_users(user_ids))
+
+ return all_user_rooms
+
@cached(max_entries=10000)
async def does_pair_of_users_share_a_room(
self, user_id: str, other_user_id: str
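
The new get_rooms_for_users wrapper above chunks its input so that no single call into the cached _get_rooms_for_users covers thousands of users at once. A standalone sketch of the chunking pattern; this batch_iter is a simplified stand-in for Synapse's helper of the same name:

    from typing import Iterable, Iterator, Tuple, TypeVar

    T = TypeVar("T")

    def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
        """Yield tuples of at most `size` items from `iterable`."""
        batch: list = []
        for item in iterable:
            batch.append(item)
            if len(batch) == size:
                yield tuple(batch)
                batch = []
        if batch:
            yield tuple(batch)

    user_ids = [f"@user{i}:example.com" for i in range(600)]
    assert [len(chunk) for chunk in batch_iter(user_ids, 250)] == [250, 250, 100]
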
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 530f04e1..09ce855a 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -357,6 +357,24 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
)
args.extend(event_filter.related_by_rel_types)
+ if event_filter.rel_types:
+ clauses.append(
+ "(%s)"
+ % " OR ".join(
+ "event_relation.relation_type = ?" for _ in event_filter.rel_types
+ )
+ )
+ args.extend(event_filter.rel_types)
+
+ if event_filter.not_rel_types:
+ clauses.append(
+ "((%s) OR event_relation.relation_type IS NULL)"
+ % " AND ".join(
+ "event_relation.relation_type != ?" for _ in event_filter.not_rel_types
+ )
+ )
+ args.extend(event_filter.not_rel_types)
+
return " AND ".join(clauses), args
@@ -1024,28 +1042,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"after": {"event_ids": events_after, "token": end_token},
}
- async def get_all_new_events_stream(
- self, from_id: int, current_id: int, limit: int, get_prev_content: bool = False
- ) -> Tuple[int, List[EventBase], Dict[str, Optional[int]]]:
+ async def get_all_new_event_ids_stream(
+ self,
+ from_id: int,
+ current_id: int,
+ limit: int,
+ ) -> Tuple[int, Dict[str, Optional[int]]]:
"""Get all new events
- Returns all events with from_id < stream_ordering <= current_id.
+ Returns all event ids with from_id < stream_ordering <= current_id.
Args:
from_id: the stream_ordering of the last event we processed
current_id: the stream_ordering of the most recently processed event
limit: the maximum number of events to return
- get_prev_content: whether to fetch previous event content
Returns:
- A tuple of (next_id, events, event_to_received_ts), where `next_id`
+ A tuple of (next_id, event_to_received_ts), where `next_id`
is the next value to pass as `from_id` (it will either be the
stream_ordering of the last returned event, or, if fewer than `limit`
events were found, the `current_id`). The `event_to_received_ts` is
- a dictionary mapping event ID to the event `received_ts`.
+ a dictionary mapping event ID to the event `received_ts`, sorted by ascending
+ stream_ordering.
"""
- def get_all_new_events_stream_txn(
+ def get_all_new_event_ids_stream_txn(
txn: LoggingTransaction,
) -> Tuple[int, Dict[str, Optional[int]]]:
sql = (
@@ -1070,15 +1091,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, event_to_received_ts
upper_bound, event_to_received_ts = await self.db_pool.runInteraction(
- "get_all_new_events_stream", get_all_new_events_stream_txn
- )
-
- events = await self.get_events_as_list(
- event_to_received_ts.keys(),
- get_prev_content=get_prev_content,
+ "get_all_new_event_ids_stream", get_all_new_event_ids_stream_txn
)
- return upper_bound, events, event_to_received_ts
+ return upper_bound, event_to_received_ts
async def get_federation_out_pos(self, typ: str) -> int:
if self._need_to_reset_federation_stream_positions:
@@ -1202,8 +1218,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
`to_token`), or `limit` is zero.
"""
- assert int(limit) >= 0
-
# Tokens really represent positions between elements, but we use
# the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities.
@@ -1282,8 +1296,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# Multiple labels could cause the same event to appear multiple times.
needs_distinct = True
- # If there is a filter on relation_senders and relation_types join to the
- # relations table.
+        # If there is a relation_senders or relation_types filter, join to the
+        # relations table to get events related to the current event.
if event_filter and (
event_filter.related_by_senders or event_filter.related_by_rel_types
):
@@ -1298,6 +1312,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id)
"""
+        # If there is a rel_types or not_rel_types filter, join to the relations
+        # table to get the event's relation information.
+ if event_filter and (event_filter.rel_types or event_filter.not_rel_types):
+ join_clause += """
+ LEFT JOIN event_relations AS event_relation USING (event_id)
+ """
+
if needs_distinct:
select_keywords += " DISTINCT"