path: root/synapse/storage/databases
author    Andrej Shadura <andrewsh@debian.org>  2022-02-22 14:34:35 +0100
committer Andrej Shadura <andrewsh@debian.org>  2022-02-22 14:34:35 +0100
commit    85ec4b0c69e373dfcc6a8b0ddee58875c84dcc7b (patch)
tree      f4ecf53009819d048563a09cd6e29c0092134117 /synapse/storage/databases
parent    c67eee54fd48a2e25a91e34498ad01c5f902d53c (diff)
New upstream version 1.53.0
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--  synapse/storage/databases/main/account_data.py       164
-rw-r--r--  synapse/storage/databases/main/appservice.py          24
-rw-r--r--  synapse/storage/databases/main/cache.py               16
-rw-r--r--  synapse/storage/databases/main/deviceinbox.py        279
-rw-r--r--  synapse/storage/databases/main/devices.py             22
-rw-r--r--  synapse/storage/databases/main/event_federation.py   325
-rw-r--r--  synapse/storage/databases/main/events.py              19
-rw-r--r--  synapse/storage/databases/main/push_rule.py           20
-rw-r--r--  synapse/storage/databases/main/relations.py           427
9 files changed, 916 insertions, 380 deletions
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 5bfa408f..52146aac 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -106,6 +106,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
"AccountDataAndTagsChangeCache", account_max
)
+ self.db_pool.updates.register_background_update_handler(
+ "delete_account_data_for_deactivated_users",
+ self._delete_account_data_for_deactivated_users,
+ )
+
def get_max_account_data_stream_id(self) -> int:
"""Get the current max stream ID for account data stream
@@ -549,72 +554,121 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
async def purge_account_data_for_user(self, user_id: str) -> None:
"""
- Removes the account data for a user.
+ Removes ALL the account data for a user.
+ Intended to be used upon user deactivation.
- This is intended to be used upon user deactivation and also removes any
- derived information from account data (e.g. push rules and ignored users).
+ Also purges the user from the ignored_users cache table
+ and the push_rules cache tables.
+ """
- Args:
- user_id: The user ID to remove data for.
+ await self.db_pool.runInteraction(
+ "purge_account_data_for_user_txn",
+ self._purge_account_data_for_user_txn,
+ user_id,
+ )
+
+ def _purge_account_data_for_user_txn(
+ self, txn: LoggingTransaction, user_id: str
+ ) -> None:
"""
+ See `purge_account_data_for_user`.
+ """
+ # Purge from the primary account_data tables.
+ self.db_pool.simple_delete_txn(
+ txn, table="account_data", keyvalues={"user_id": user_id}
+ )
- def purge_account_data_for_user_txn(txn: LoggingTransaction) -> None:
- # Purge from the primary account_data tables.
- self.db_pool.simple_delete_txn(
- txn, table="account_data", keyvalues={"user_id": user_id}
- )
+ self.db_pool.simple_delete_txn(
+ txn, table="room_account_data", keyvalues={"user_id": user_id}
+ )
- self.db_pool.simple_delete_txn(
- txn, table="room_account_data", keyvalues={"user_id": user_id}
- )
+ # Purge from ignored_users where this user is the ignorer.
+ # N.B. We don't purge where this user is the ignoree, because that
+ # interferes with other users' account data.
+ # It's also not this user's data to delete!
+ self.db_pool.simple_delete_txn(
+ txn, table="ignored_users", keyvalues={"ignorer_user_id": user_id}
+ )
- # Purge from ignored_users where this user is the ignorer.
- # N.B. We don't purge where this user is the ignoree, because that
- # interferes with other users' account data.
- # It's also not this user's data to delete!
- self.db_pool.simple_delete_txn(
- txn, table="ignored_users", keyvalues={"ignorer_user_id": user_id}
- )
+ # Remove the push rules
+ self.db_pool.simple_delete_txn(
+ txn, table="push_rules", keyvalues={"user_name": user_id}
+ )
+ self.db_pool.simple_delete_txn(
+ txn, table="push_rules_enable", keyvalues={"user_name": user_id}
+ )
+ self.db_pool.simple_delete_txn(
+ txn, table="push_rules_stream", keyvalues={"user_id": user_id}
+ )
- # Remove the push rules
- self.db_pool.simple_delete_txn(
- txn, table="push_rules", keyvalues={"user_name": user_id}
- )
- self.db_pool.simple_delete_txn(
- txn, table="push_rules_enable", keyvalues={"user_name": user_id}
- )
- self.db_pool.simple_delete_txn(
- txn, table="push_rules_stream", keyvalues={"user_id": user_id}
- )
+ # Invalidate caches as appropriate
+ self._invalidate_cache_and_stream(
+ txn, self.get_account_data_for_room_and_type, (user_id,)
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_account_data_for_user, (user_id,)
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_global_account_data_by_type_for_user, (user_id,)
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_account_data_for_room, (user_id,)
+ )
+ self._invalidate_cache_and_stream(txn, self.get_push_rules_for_user, (user_id,))
+ self._invalidate_cache_and_stream(
+ txn, self.get_push_rules_enabled_for_user, (user_id,)
+ )
+ # This user might be contained in the ignored_by cache for other users,
+ # so we have to invalidate it all.
+ self._invalidate_all_cache_and_stream(txn, self.ignored_by)
- # Invalidate caches as appropriate
- self._invalidate_cache_and_stream(
- txn, self.get_account_data_for_room_and_type, (user_id,)
- )
- self._invalidate_cache_and_stream(
- txn, self.get_account_data_for_user, (user_id,)
- )
- self._invalidate_cache_and_stream(
- txn, self.get_global_account_data_by_type_for_user, (user_id,)
- )
- self._invalidate_cache_and_stream(
- txn, self.get_account_data_for_room, (user_id,)
- )
- self._invalidate_cache_and_stream(
- txn, self.get_push_rules_for_user, (user_id,)
- )
- self._invalidate_cache_and_stream(
- txn, self.get_push_rules_enabled_for_user, (user_id,)
- )
- # This user might be contained in the ignored_by cache for other users,
- # so we have to invalidate it all.
- self._invalidate_all_cache_and_stream(txn, self.ignored_by)
+ async def _delete_account_data_for_deactivated_users(
+ self, progress: dict, batch_size: int
+ ) -> int:
+ """
+ Retroactively purges account data for users that have already been deactivated.
+ Gets run as a background update caused by a schema delta.
+ """
- await self.db_pool.runInteraction(
- "purge_account_data_for_user_txn",
- purge_account_data_for_user_txn,
+ last_user: str = progress.get("last_user", "")
+
+ def _delete_account_data_for_deactivated_users_txn(
+ txn: LoggingTransaction,
+ ) -> int:
+ sql = """
+ SELECT name FROM users
+ WHERE deactivated = ? and name > ?
+ ORDER BY name ASC
+ LIMIT ?
+ """
+
+ txn.execute(sql, (1, last_user, batch_size))
+ users = [row[0] for row in txn]
+
+ for user in users:
+ self._purge_account_data_for_user_txn(txn, user_id=user)
+
+ if users:
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ "delete_account_data_for_deactivated_users",
+ {"last_user": users[-1]},
+ )
+
+ return len(users)
+
+ number_deleted = await self.db_pool.runInteraction(
+ "_delete_account_data_for_deactivated_users",
+ _delete_account_data_for_deactivated_users_txn,
)
+ if number_deleted < batch_size:
+ await self.db_pool.updates._end_background_update(
+ "delete_account_data_for_deactivated_users"
+ )
+
+ return number_deleted
+
class AccountDataStore(AccountDataWorkerStore):
pass
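
The new _delete_account_data_for_deactivated_users update above follows a common background-update shape: process one page of rows per transaction, record the last processed key as progress, and end the update once a page comes back short of batch_size. A minimal standalone sketch of that resume-from-progress loop (fetch_batch, purge_user and save_progress are hypothetical stand-ins, not Synapse APIs):

    from typing import Callable, List

    def run_batched_update(
        fetch_batch: Callable[[str, int], List[str]],
        purge_user: Callable[[str], None],
        save_progress: Callable[[str], None],
        batch_size: int = 100,
    ) -> None:
        # Resume from an empty key; a real background update would load this
        # from its stored progress dict instead.
        last_user = ""
        while True:
            # Next page of deactivated users strictly after the last one handled.
            users = fetch_batch(last_user, batch_size)
            for user in users:
                purge_user(user)
            if users:
                last_user = users[-1]
                save_progress(last_user)
            # A short page means there is nothing left to do.
            if len(users) < batch_size:
                break
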
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 2bb52884..304814af 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -198,6 +198,7 @@ class ApplicationServiceTransactionWorkerStore(
service: ApplicationService,
events: List[EventBase],
ephemeral: List[JsonDict],
+ to_device_messages: List[JsonDict],
) -> AppServiceTransaction:
"""Atomically creates a new transaction for this application service
with the given list of events. Ephemeral events are NOT persisted to the
@@ -207,6 +208,7 @@ class ApplicationServiceTransactionWorkerStore(
service: The service who the transaction is for.
events: A list of persistent events to put in the transaction.
ephemeral: A list of ephemeral events to put in the transaction.
+ to_device_messages: A list of to-device messages to put in the transaction.
Returns:
A new transaction.
@@ -237,7 +239,11 @@ class ApplicationServiceTransactionWorkerStore(
(service.id, new_txn_id, event_ids),
)
return AppServiceTransaction(
- service=service, id=new_txn_id, events=events, ephemeral=ephemeral
+ service=service,
+ id=new_txn_id,
+ events=events,
+ ephemeral=ephemeral,
+ to_device_messages=to_device_messages,
)
return await self.db_pool.runInteraction(
@@ -330,7 +336,11 @@ class ApplicationServiceTransactionWorkerStore(
events = await self.get_events_as_list(event_ids)
return AppServiceTransaction(
- service=service, id=entry["txn_id"], events=events, ephemeral=[]
+ service=service,
+ id=entry["txn_id"],
+ events=events,
+ ephemeral=[],
+ to_device_messages=[],
)
def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
@@ -391,7 +401,7 @@ class ApplicationServiceTransactionWorkerStore(
async def get_type_stream_id_for_appservice(
self, service: ApplicationService, type: str
) -> int:
- if type not in ("read_receipt", "presence"):
+ if type not in ("read_receipt", "presence", "to_device"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (type,)
@@ -415,16 +425,16 @@ class ApplicationServiceTransactionWorkerStore(
"get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn
)
- async def set_type_stream_id_for_appservice(
+ async def set_appservice_stream_type_pos(
self, service: ApplicationService, stream_type: str, pos: Optional[int]
) -> None:
- if stream_type not in ("read_receipt", "presence"):
+ if stream_type not in ("read_receipt", "presence", "to_device"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (stream_type,)
)
- def set_type_stream_id_for_appservice_txn(txn):
+ def set_appservice_stream_type_pos_txn(txn):
stream_id_type = "%s_stream_id" % stream_type
txn.execute(
"UPDATE application_services_state SET %s = ? WHERE as_id=?"
@@ -433,7 +443,7 @@ class ApplicationServiceTransactionWorkerStore(
)
await self.db_pool.runInteraction(
- "set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn
+ "set_appservice_stream_type_pos", set_appservice_stream_type_pos_txn
)
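
The stream-type changes above extend the whitelist of valid per-appservice stream positions with "to_device" and keep the existing pattern of validating the type before interpolating it into a column name, so only the position value travels as a bound parameter. A small sketch of that validate-then-format step (build_update_sql is a hypothetical helper, not part of Synapse):

    VALID_STREAM_TYPES = ("read_receipt", "presence", "to_device")

    def build_update_sql(stream_type: str) -> str:
        if stream_type not in VALID_STREAM_TYPES:
            raise ValueError(
                "Expected type to be a valid application stream id type, got %s"
                % (stream_type,)
            )
        # Only the whitelisted type name is formatted into the statement; the
        # position itself stays a bound "?" parameter.
        column = "%s_stream_id" % stream_type
        return "UPDATE application_services_state SET %s = ? WHERE as_id=?" % (column,)
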
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 00243480..c428dd55 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -15,7 +15,7 @@
import itertools
import logging
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Tuple
from synapse.api.constants import EventTypes
from synapse.replication.tcp.streams import BackfillStream, CachesStream
@@ -25,7 +25,11 @@ from synapse.replication.tcp.streams.events import (
EventsStreamEventRow,
)
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter
@@ -236,7 +240,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
txn.call_after(cache_func.invalidate_all)
self._send_invalidation_to_replication(txn, cache_func.__name__, None)
- def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
+ def _invalidate_state_caches_and_stream(
+ self, txn: LoggingTransaction, room_id: str, members_changed: Collection[str]
+ ) -> None:
"""Special case invalidation of caches based on current state.
We special case this so that we can batch the cache invalidations into a
@@ -244,8 +250,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
Args:
txn
- room_id (str): Room where state changed
- members_changed (iterable[str]): The user_ids of members that have changed
+ room_id: Room where state changed
+ members_changed: The user_ids of members that have changed
"""
txn.call_after(self._invalidate_state_caches, room_id, members_changed)
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 4eca9718..1392363d 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast
from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
@@ -24,6 +24,7 @@ from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
+ make_in_list_sql_clause,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
@@ -136,63 +137,263 @@ class DeviceInboxWorkerStore(SQLBaseStore):
def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token()
- async def get_new_messages_for_device(
+ async def get_messages_for_user_devices(
+ self,
+ user_ids: Collection[str],
+ from_stream_id: int,
+ to_stream_id: int,
+ ) -> Dict[Tuple[str, str], List[JsonDict]]:
+ """
+ Retrieve to-device messages for a given set of users.
+
+ Only to-device messages with stream ids between the given boundaries
+ (from < X <= to) are returned.
+
+ Args:
+ user_ids: The users to retrieve to-device messages for.
+ from_stream_id: The lower boundary of stream id to filter with (exclusive).
+ to_stream_id: The upper boundary of stream id to filter with (inclusive).
+
+ Returns:
+ A dictionary of (user id, device id) -> list of to-device messages.
+ """
+ # We expect the stream ID returned by _get_device_messages to always
+ # be to_stream_id. So, no need to return it from this function.
+ (
+ user_id_device_id_to_messages,
+ last_processed_stream_id,
+ ) = await self._get_device_messages(
+ user_ids=user_ids,
+ from_stream_id=from_stream_id,
+ to_stream_id=to_stream_id,
+ )
+
+ assert (
+ last_processed_stream_id == to_stream_id
+ ), "Expected _get_device_messages to process all to-device messages up to `to_stream_id`"
+
+ return user_id_device_id_to_messages
+
+ async def get_messages_for_device(
self,
user_id: str,
- device_id: Optional[str],
- last_stream_id: int,
- current_stream_id: int,
+ device_id: str,
+ from_stream_id: int,
+ to_stream_id: int,
limit: int = 100,
- ) -> Tuple[List[dict], int]:
+ ) -> Tuple[List[JsonDict], int]:
"""
+ Retrieve to-device messages for a single user device.
+
+ Only to-device messages with stream ids between the given boundaries
+ (from < X <= to) are returned.
+
Args:
- user_id: The recipient user_id.
- device_id: The recipient device_id.
- last_stream_id: The last stream ID checked.
- current_stream_id: The current position of the to device
- message stream.
- limit: The maximum number of messages to retrieve.
+ user_id: The ID of the user to retrieve messages for.
+ device_id: The ID of the device to retrieve to-device messages for.
+ from_stream_id: The lower boundary of stream id to filter with (exclusive).
+ to_stream_id: The upper boundary of stream id to filter with (inclusive).
+ limit: A limit on the number of to-device messages returned.
Returns:
A tuple containing:
- * A list of messages for the device.
- * The max stream token of these messages. There may be more to retrieve
- if the given limit was reached.
+ * A list of to-device messages within the given stream id range intended for
+ the given user / device combo.
+ * The last-processed stream ID. Subsequent calls of this function with the
+ same device should pass this value as 'from_stream_id'.
"""
- has_changed = self._device_inbox_stream_cache.has_entity_changed(
- user_id, last_stream_id
+ (
+ user_id_device_id_to_messages,
+ last_processed_stream_id,
+ ) = await self._get_device_messages(
+ user_ids=[user_id],
+ device_id=device_id,
+ from_stream_id=from_stream_id,
+ to_stream_id=to_stream_id,
+ limit=limit,
)
- if not has_changed:
- return [], current_stream_id
- def get_new_messages_for_device_txn(txn):
- sql = (
- "SELECT stream_id, message_json FROM device_inbox"
- " WHERE user_id = ? AND device_id = ?"
- " AND ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC"
- " LIMIT ?"
+ if not user_id_device_id_to_messages:
+ # There were no messages!
+ return [], to_stream_id
+
+ # Extract the messages, no need to return the user and device ID again
+ to_device_messages = user_id_device_id_to_messages.get((user_id, device_id), [])
+
+ return to_device_messages, last_processed_stream_id
+
+ async def _get_device_messages(
+ self,
+ user_ids: Collection[str],
+ from_stream_id: int,
+ to_stream_id: int,
+ device_id: Optional[str] = None,
+ limit: Optional[int] = None,
+ ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
+ """
+ Retrieve pending to-device messages for a collection of user devices.
+
+ Only to-device messages with stream ids between the given boundaries
+ (from < X <= to) are returned.
+
+ Note that a stream ID can be shared by multiple copies of the same message with
+ different recipient devices. Stream IDs are only unique in the context of a single
+ user ID / device ID pair. Thus, applying a limit (of messages to return) when working
+ with a sliding window of stream IDs is only possible when querying messages of a
+ single user device.
+
+ Finally, note that device IDs are not unique across users.
+
+ Args:
+ user_ids: The user IDs to filter device messages by.
+ from_stream_id: The lower boundary of stream id to filter with (exclusive).
+ to_stream_id: The upper boundary of stream id to filter with (inclusive).
+ device_id: A device ID to query to-device messages for. If not provided, to-device
+ messages from all device IDs for the given user IDs will be queried. May not be
+ provided if `user_ids` contains more than one entry.
+ limit: The maximum number of to-device messages to return. Can only be used when
+ passing a single user ID / device ID tuple.
+
+ Returns:
+ A tuple containing:
+ * A dict of (user_id, device_id) -> list of to-device messages
+ * The last-processed stream ID. If this is less than `to_stream_id`, then
+ there may be more messages to retrieve. If `limit` is not set, then this
+ is always equal to 'to_stream_id'.
+ """
+ if not user_ids:
+ logger.warning("No users provided upon querying for device IDs")
+ return {}, to_stream_id
+
+ # Prevent a query for one user's device also retrieving another user's device with
+ # the same device ID (device IDs are not unique across users).
+ if len(user_ids) > 1 and device_id is not None:
+ raise AssertionError(
+ "Programming error: 'device_id' cannot be supplied to "
+ "_get_device_messages when >1 user_id has been provided"
)
- txn.execute(
- sql, (user_id, device_id, last_stream_id, current_stream_id, limit)
+
+ # A limit can only be applied when querying for a single user ID / device ID tuple.
+ # See the docstring of this function for more details.
+ if limit is not None and device_id is None:
+ raise AssertionError(
+ "Programming error: _get_device_messages was passed 'limit' "
+ "without a specific user_id/device_id"
)
- messages = []
- stream_pos = current_stream_id
+ user_ids_to_query: Set[str] = set()
+ device_ids_to_query: Set[str] = set()
+
+ # Note that a device ID could be an empty str
+ if device_id is not None:
+ # If a device ID was passed, use it to filter results.
+ # Otherwise, device IDs will be derived from the given collection of user IDs.
+ device_ids_to_query.add(device_id)
+
+ # Determine which users have devices with pending messages
+ for user_id in user_ids:
+ if self._device_inbox_stream_cache.has_entity_changed(
+ user_id, from_stream_id
+ ):
+ # This user has new messages sent to them. Query messages for them
+ user_ids_to_query.add(user_id)
+
+ def get_device_messages_txn(txn: LoggingTransaction):
+ # Build a query to select messages from any of the given devices that
+ # are between the given stream id bounds.
+
+ # If a list of device IDs was not provided, retrieve all device IDs
+ # for the given users. We explicitly do not query hidden devices, as
+ # hidden devices should not receive to-device messages.
+ # Note that this is more efficient than just dropping `device_id` from the query,
+ # since device_inbox has an index on `(user_id, device_id, stream_id)`
+ if not device_ids_to_query:
+ user_device_dicts = self.db_pool.simple_select_many_txn(
+ txn,
+ table="devices",
+ column="user_id",
+ iterable=user_ids_to_query,
+ keyvalues={"user_id": user_id, "hidden": False},
+ retcols=("device_id",),
+ )
- for row in txn:
- stream_pos = row[0]
- messages.append(db_to_json(row[1]))
+ device_ids_to_query.update(
+ {row["device_id"] for row in user_device_dicts}
+ )
- # If the limit was not reached we know that there's no more data for this
- # user/device pair up to current_stream_id.
- if len(messages) < limit:
- stream_pos = current_stream_id
+ if not device_ids_to_query:
+ # We've ended up with no devices to query.
+ return {}, to_stream_id
- return messages, stream_pos
+ # We include both user IDs and device IDs in this query, as we have an index
+ # (device_inbox_user_stream_id) for them.
+ user_id_many_clause_sql, user_id_many_clause_args = make_in_list_sql_clause(
+ self.database_engine, "user_id", user_ids_to_query
+ )
+ (
+ device_id_many_clause_sql,
+ device_id_many_clause_args,
+ ) = make_in_list_sql_clause(
+ self.database_engine, "device_id", device_ids_to_query
+ )
+
+ sql = f"""
+ SELECT stream_id, user_id, device_id, message_json FROM device_inbox
+ WHERE {user_id_many_clause_sql}
+ AND {device_id_many_clause_sql}
+ AND ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ """
+ sql_args = (
+ *user_id_many_clause_args,
+ *device_id_many_clause_args,
+ from_stream_id,
+ to_stream_id,
+ )
+
+ # If a limit was provided, limit the data retrieved from the database
+ if limit is not None:
+ sql += "LIMIT ?"
+ sql_args += (limit,)
+
+ txn.execute(sql, sql_args)
+
+ # Create and fill a dictionary of (user ID, device ID) -> list of messages
+ # intended for each device.
+ last_processed_stream_pos = to_stream_id
+ recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {}
+ rowcount = 0
+ for row in txn:
+ rowcount += 1
+
+ last_processed_stream_pos = row[0]
+ recipient_user_id = row[1]
+ recipient_device_id = row[2]
+ message_dict = db_to_json(row[3])
+
+ # Store the device details
+ recipient_device_to_messages.setdefault(
+ (recipient_user_id, recipient_device_id), []
+ ).append(message_dict)
+
+ if limit is not None and rowcount == limit:
+ # We ended up bumping up against the message limit. There may be more messages
+ # to retrieve. Return what we have, as well as the last stream position that
+ # was processed.
+ #
+ # The caller is expected to set this as the lower (exclusive) bound
+ # for the next query of this device.
+ return recipient_device_to_messages, last_processed_stream_pos
+
+ # The limit was not reached, thus we know that recipient_device_to_messages
+ # contains all to-device messages for the given device and stream id range.
+ #
+ # We return to_stream_id, which the caller should then provide as the lower
+ # (exclusive) bound on the next query of this device.
+ return recipient_device_to_messages, to_stream_id
return await self.db_pool.runInteraction(
- "get_new_messages_for_device", get_new_messages_for_device_txn
+ "get_device_messages", get_device_messages_txn
)
@trace
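
The docstrings above spell out the caller contract for the limited per-device query: pass the returned last-processed stream ID back in as the next exclusive lower bound until it reaches to_stream_id. A hedged sketch of that loop, where store is a hypothetical object exposing the get_messages_for_device coroutine from this diff:

    from typing import Any, Dict, List

    async def drain_device_messages(
        store: Any,
        user_id: str,
        device_id: str,
        from_stream_id: int,
        to_stream_id: int,
        batch_limit: int = 100,
    ) -> List[Dict[str, Any]]:
        messages: List[Dict[str, Any]] = []
        while from_stream_id < to_stream_id:
            batch, last_processed = await store.get_messages_for_device(
                user_id,
                device_id,
                from_stream_id,
                to_stream_id,
                limit=batch_limit,
            )
            messages.extend(batch)
            # The returned position becomes the exclusive lower bound of the
            # next call; it equals to_stream_id once everything is consumed.
            from_stream_id = last_processed
        return messages
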
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index b2a5cd9a..8d845fe9 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1496,13 +1496,23 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
async def add_device_change_to_streams(
- self, user_id: str, device_ids: Collection[str], hosts: List[str]
- ) -> int:
+ self, user_id: str, device_ids: Collection[str], hosts: Collection[str]
+ ) -> Optional[int]:
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
+
+ Args:
+ user_id: The ID of the user whose device changed.
+ device_ids: The IDs of any changed devices. If empty, this function will
+ return None.
+ hosts: The remote destinations that should be notified of the change.
+
+ Returns:
+ The maximum stream ID of device list updates that were added to the database, or
+ None if no updates were added.
"""
if not device_ids:
- return
+ return None
async with self._device_list_id_gen.get_next_mult(
len(device_ids)
@@ -1573,11 +1583,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self,
txn: LoggingTransaction,
user_id: str,
- device_ids: Collection[str],
- hosts: List[str],
+ device_ids: Iterable[str],
+ hosts: Collection[str],
stream_ids: List[str],
context: Dict[str, str],
- ):
+ ) -> None:
for host in hosts:
txn.call_after(
self._device_list_federation_stream_cache.entity_has_changed,
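
add_device_change_to_streams now returns Optional[int], with None meaning no device IDs were supplied and nothing was written. A short caller-side sketch of the contract described in the docstring (store and wake_replication are hypothetical stand-ins, not Synapse objects):

    from typing import Any, Callable, Collection

    async def notify_device_change(
        store: Any,
        wake_replication: Callable[[str, int], None],
        user_id: str,
        device_ids: Collection[str],
        hosts: Collection[str],
    ) -> None:
        position = await store.add_device_change_to_streams(user_id, device_ids, hosts)
        if position is None:
            # Nothing was persisted, so there is no stream position to advertise.
            return
        wake_replication("device_list_key", position)
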
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ca71f073..277e6422 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -16,9 +16,10 @@ import logging
from queue import Empty, PriorityQueue
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+import attr
from prometheus_client import Counter, Gauge
-from synapse.api.constants import MAX_DEPTH
+from synapse.api.constants import MAX_DEPTH, EventTypes
from synapse.api.errors import StoreError
from synapse.api.room_versions import EventFormatVersions, RoomVersion
from synapse.events import EventBase, make_event_from_dict
@@ -60,6 +61,15 @@ pdus_pruned_from_federation_queue = Counter(
logger = logging.getLogger(__name__)
+# All the info we need while iterating the DAG while backfilling
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class BackfillQueueNavigationItem:
+ depth: int
+ stream_ordering: int
+ event_id: str
+ type: str
+
+
class _NoChainCoverIndex(Exception):
def __init__(self, room_id: str):
super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
@@ -74,6 +84,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
):
super().__init__(database, db_conn, hs)
+ self.hs = hs
+
if hs.config.worker.run_background_tasks:
hs.get_clock().looping_call(
self._delete_old_forward_extrem_cache, 60 * 60 * 1000
@@ -109,7 +121,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
room_id: str,
event_ids: Collection[str],
include_given: bool = False,
- ) -> List[str]:
+ ) -> Set[str]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
@@ -118,7 +130,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
include_given: include the given events in result
Returns:
- list of event_ids
+ set of event_ids
"""
# Check if we have indexed the room so we can use the chain cover
@@ -147,7 +159,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
def _get_auth_chain_ids_using_cover_index_txn(
self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
- ) -> List[str]:
+ ) -> Set[str]:
"""Calculates the auth chain IDs using the chain index."""
# First we look up the chain ID/sequence numbers for the given events.
@@ -260,11 +272,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
txn.execute(sql, (chain_id, max_no))
results.update(r for r, in txn)
- return list(results)
+ return results
def _get_auth_chain_ids_txn(
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
- ) -> List[str]:
+ ) -> Set[str]:
"""Calculates the auth chain IDs.
This is used when we don't have a cover index for the room.
@@ -319,7 +331,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
front = new_front
results.update(front)
- return list(results)
+ return results
async def get_auth_chain_difference(
self, room_id: str, state_sets: List[Set[str]]
@@ -737,7 +749,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
room_id,
)
- async def get_insertion_event_backwards_extremities_in_room(
+ async def get_insertion_event_backward_extremities_in_room(
self, room_id
) -> Dict[str, int]:
"""Get the insertion events we know about that we haven't backfilled yet.
@@ -754,7 +766,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
Map from event_id to depth
"""
- def get_insertion_event_backwards_extremities_in_room_txn(txn, room_id):
+ def get_insertion_event_backward_extremities_in_room_txn(txn, room_id):
sql = """
SELECT b.event_id, MAX(e.depth) FROM insertion_events as i
/* We only want insertion events that are also marked as backwards extremities */
@@ -770,8 +782,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return dict(txn)
return await self.db_pool.runInteraction(
- "get_insertion_event_backwards_extremities_in_room",
- get_insertion_event_backwards_extremities_in_room_txn,
+ "get_insertion_event_backward_extremities_in_room",
+ get_insertion_event_backward_extremities_in_room_txn,
room_id,
)
@@ -997,143 +1009,242 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
)
- async def get_backfill_events(self, room_id: str, event_list: list, limit: int):
- """Get a list of Events for a given topic that occurred before (and
- including) the events in event_list. Return a list of max size `limit`
+ def _get_connected_batch_event_backfill_results_txn(
+ self, txn: LoggingTransaction, insertion_event_id: str, limit: int
+ ) -> List[BackfillQueueNavigationItem]:
+ """
+ Find any batch connections of a given insertion event.
+ A batch event points at an insertion event via:
+ batch_event.content[MSC2716_BATCH_ID] -> insertion_event.content[MSC2716_NEXT_BATCH_ID]
Args:
- room_id
- event_list
- limit
+ txn: The database transaction to use
+ insertion_event_id: The event ID to navigate from. We will find
+ batch events that point back at this insertion event.
+ limit: Max number of event IDs to query for and return
+
+ Returns:
+ List of batch events that the backfill queue can process
+ """
+ batch_connection_query = """
+ SELECT e.depth, e.stream_ordering, c.event_id, e.type FROM insertion_events AS i
+ /* Find the batch that connects to the given insertion event */
+ INNER JOIN batch_events AS c
+ ON i.next_batch_id = c.batch_id
+ /* Get the depth of the batch start event from the events table */
+ INNER JOIN events AS e USING (event_id)
+ /* Find an insertion event which matches the given event_id */
+ WHERE i.event_id = ?
+ LIMIT ?
"""
- event_ids = await self.db_pool.runInteraction(
- "get_backfill_events",
- self._get_backfill_events,
- room_id,
- event_list,
- limit,
- )
- events = await self.get_events_as_list(event_ids)
- return sorted(events, key=lambda e: -e.depth)
- def _get_backfill_events(self, txn, room_id, event_list, limit):
- logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
+ # Find any batch connections for the given insertion event
+ txn.execute(
+ batch_connection_query,
+ (insertion_event_id, limit),
+ )
+ return [
+ BackfillQueueNavigationItem(
+ depth=row[0],
+ stream_ordering=row[1],
+ event_id=row[2],
+ type=row[3],
+ )
+ for row in txn
+ ]
- event_results = set()
+ def _get_connected_prev_event_backfill_results_txn(
+ self, txn: LoggingTransaction, event_id: str, limit: int
+ ) -> List[BackfillQueueNavigationItem]:
+ """
+ Find any events connected by prev_event to the specified event_id.
- # We want to make sure that we do a breadth-first, "depth" ordered
- # search.
+ Args:
+ txn: The database transaction to use
+ event_id: The event ID to navigate from
+ limit: Max number of event IDs to query for and return
+ Returns:
+ List of prev events that the backfill queue can process
+ """
# Look for the prev_event_id connected to the given event_id
- query = """
- SELECT depth, prev_event_id FROM event_edges
- /* Get the depth of the prev_event_id from the events table */
+ connected_prev_event_query = """
+ SELECT depth, stream_ordering, prev_event_id, events.type FROM event_edges
+ /* Get the depth and stream_ordering of the prev_event_id from the events table */
INNER JOIN events
ON prev_event_id = events.event_id
- /* Find an event which matches the given event_id */
+ /* Look for an edge which matches the given event_id */
WHERE event_edges.event_id = ?
AND event_edges.is_state = ?
+ /* Because we can have many events at the same depth,
+ * we want to also tie-break and sort on stream_ordering */
+ ORDER BY depth DESC, stream_ordering DESC
LIMIT ?
"""
- # Look for the "insertion" events connected to the given event_id
- connected_insertion_event_query = """
- SELECT e.depth, i.event_id FROM insertion_event_edges AS i
- /* Get the depth of the insertion event from the events table */
- INNER JOIN events AS e USING (event_id)
- /* Find an insertion event which points via prev_events to the given event_id */
- WHERE i.insertion_prev_event_id = ?
- LIMIT ?
+ txn.execute(
+ connected_prev_event_query,
+ (event_id, False, limit),
+ )
+ return [
+ BackfillQueueNavigationItem(
+ depth=row[0],
+ stream_ordering=row[1],
+ event_id=row[2],
+ type=row[3],
+ )
+ for row in txn
+ ]
+
+ async def get_backfill_events(
+ self, room_id: str, seed_event_id_list: list, limit: int
+ ):
+ """Get a list of Events for a given topic that occurred before (and
+ including) the events in seed_event_id_list. Return a list of max size `limit`
+
+ Args:
+ room_id
+ seed_event_id_list
+ limit
"""
+ event_ids = await self.db_pool.runInteraction(
+ "get_backfill_events",
+ self._get_backfill_events,
+ room_id,
+ seed_event_id_list,
+ limit,
+ )
+ events = await self.get_events_as_list(event_ids)
+ return sorted(
+ events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering)
+ )
- # Find any batch connections of a given insertion event
- batch_connection_query = """
- SELECT e.depth, c.event_id FROM insertion_events AS i
- /* Find the batch that connects to the given insertion event */
- INNER JOIN batch_events AS c
- ON i.next_batch_id = c.batch_id
- /* Get the depth of the batch start event from the events table */
- INNER JOIN events AS e USING (event_id)
- /* Find an insertion event which matches the given event_id */
- WHERE i.event_id = ?
- LIMIT ?
+ def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit):
+ """
+ We want to make sure that we do a breadth-first, "depth" ordered search.
+ We also handle navigating historical branches of history connected by
+ insertion and batch events.
"""
+ logger.debug(
+ "_get_backfill_events(room_id=%s): seeding backfill with seed_event_id_list=%s limit=%s",
+ room_id,
+ seed_event_id_list,
+ limit,
+ )
+
+ event_id_results = set()
# In a PriorityQueue, the lowest valued entries are retrieved first.
- # We're using depth as the priority in the queue.
- # Depth is lowest at the oldest-in-time message and highest and
- # newest-in-time message. We add events to the queue with a negative depth so that
- # we process the newest-in-time messages first going backwards in time.
+ # We're using depth as the priority in the queue and tie-break based on
+ # stream_ordering. Depth is lowest at the oldest-in-time message and
+ # highest and newest-in-time message. We add events to the queue with a
+ # negative depth so that we process the newest-in-time messages first
+ # going backwards in time. stream_ordering follows the same pattern.
queue = PriorityQueue()
- for event_id in event_list:
- depth = self.db_pool.simple_select_one_onecol_txn(
+ for seed_event_id in seed_event_id_list:
+ event_lookup_result = self.db_pool.simple_select_one_txn(
txn,
table="events",
- keyvalues={"event_id": event_id, "room_id": room_id},
- retcol="depth",
+ keyvalues={"event_id": seed_event_id, "room_id": room_id},
+ retcols=(
+ "type",
+ "depth",
+ "stream_ordering",
+ ),
allow_none=True,
)
- if depth:
- queue.put((-depth, event_id))
+ if event_lookup_result is not None:
+ logger.debug(
+ "_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s",
+ room_id,
+ seed_event_id,
+ event_lookup_result["depth"],
+ event_lookup_result["stream_ordering"],
+ event_lookup_result["type"],
+ )
+
+ if event_lookup_result["depth"]:
+ queue.put(
+ (
+ -event_lookup_result["depth"],
+ -event_lookup_result["stream_ordering"],
+ seed_event_id,
+ event_lookup_result["type"],
+ )
+ )
- while not queue.empty() and len(event_results) < limit:
+ while not queue.empty() and len(event_id_results) < limit:
try:
- _, event_id = queue.get_nowait()
+ _, _, event_id, event_type = queue.get_nowait()
except Empty:
break
- if event_id in event_results:
+ if event_id in event_id_results:
continue
- event_results.add(event_id)
+ event_id_results.add(event_id)
# Try and find any potential historical batches of message history.
- #
- # First we look for an insertion event connected to the current
- # event (by prev_event). If we find any, we need to go and try to
- # find any batch events connected to the insertion event (by
- # batch_id). If we find any, we'll add them to the queue and
- # navigate up the DAG like normal in the next iteration of the loop.
- txn.execute(
- connected_insertion_event_query, (event_id, limit - len(event_results))
- )
- connected_insertion_event_id_results = txn.fetchall()
- logger.debug(
- "_get_backfill_events: connected_insertion_event_query %s",
- connected_insertion_event_id_results,
- )
- for row in connected_insertion_event_id_results:
- connected_insertion_event_depth = row[0]
- connected_insertion_event = row[1]
- queue.put((-connected_insertion_event_depth, connected_insertion_event))
+ if self.hs.config.experimental.msc2716_enabled:
+ # We need to go and try to find any batch events connected
+ # to a given insertion event (by batch_id). If we find any, we'll
+ # add them to the queue and navigate up the DAG like normal in the
+ # next iteration of the loop.
+ if event_type == EventTypes.MSC2716_INSERTION:
+ # Find any batch connections for the given insertion event
+ connected_batch_event_backfill_results = (
+ self._get_connected_batch_event_backfill_results_txn(
+ txn, event_id, limit - len(event_id_results)
+ )
+ )
+ logger.debug(
+ "_get_backfill_events(room_id=%s): connected_batch_event_backfill_results=%s",
+ room_id,
+ connected_batch_event_backfill_results,
+ )
+ for (
+ connected_batch_event_backfill_item
+ ) in connected_batch_event_backfill_results:
+ if (
+ connected_batch_event_backfill_item.event_id
+ not in event_id_results
+ ):
+ queue.put(
+ (
+ -connected_batch_event_backfill_item.depth,
+ -connected_batch_event_backfill_item.stream_ordering,
+ connected_batch_event_backfill_item.event_id,
+ connected_batch_event_backfill_item.type,
+ )
+ )
- # Find any batch connections for the given insertion event
- txn.execute(
- batch_connection_query,
- (connected_insertion_event, limit - len(event_results)),
- )
- batch_start_event_id_results = txn.fetchall()
- logger.debug(
- "_get_backfill_events: batch_start_event_id_results %s",
- batch_start_event_id_results,
+ # Now we just look up the DAG by prev_events as normal
+ connected_prev_event_backfill_results = (
+ self._get_connected_prev_event_backfill_results_txn(
+ txn, event_id, limit - len(event_id_results)
)
- for row in batch_start_event_id_results:
- if row[1] not in event_results:
- queue.put((-row[0], row[1]))
-
- txn.execute(query, (event_id, False, limit - len(event_results)))
- prev_event_id_results = txn.fetchall()
+ )
logger.debug(
- "_get_backfill_events: prev_event_ids %s", prev_event_id_results
+ "_get_backfill_events(room_id=%s): connected_prev_event_backfill_results=%s",
+ room_id,
+ connected_prev_event_backfill_results,
)
+ for (
+ connected_prev_event_backfill_item
+ ) in connected_prev_event_backfill_results:
+ if connected_prev_event_backfill_item.event_id not in event_id_results:
+ queue.put(
+ (
+ -connected_prev_event_backfill_item.depth,
+ -connected_prev_event_backfill_item.stream_ordering,
+ connected_prev_event_backfill_item.event_id,
+ connected_prev_event_backfill_item.type,
+ )
+ )
- for row in prev_event_id_results:
- if row[1] not in event_results:
- queue.put((-row[0], row[1]))
-
- return event_results
+ return event_id_results
async def get_missing_events(self, room_id, earliest_events, latest_events, limit):
ids = await self.db_pool.runInteraction(
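
The rewritten backfill walk above orders its queue by depth with stream_ordering as a tie-break, pushing negated values so that a PriorityQueue (which pops the smallest tuple first) yields newest-in-time events first. A standalone illustration of that ordering trick with made-up events:

    from queue import PriorityQueue

    queue: PriorityQueue = PriorityQueue()

    # (depth, stream_ordering, event_id) for a handful of made-up events.
    for depth, stream_ordering, event_id in [(3, 17, "$c"), (3, 12, "$b"), (1, 2, "$a")]:
        queue.put((-depth, -stream_ordering, event_id))

    while not queue.empty():
        neg_depth, neg_stream_ordering, event_id = queue.get_nowait()
        print(event_id, -neg_depth, -neg_stream_ordering)
    # Prints $c, $b, $a: highest depth first, ties broken by stream_ordering
    # descending, i.e. newest-in-time first when walking backwards.
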
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index b7554154..5246fcca 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1801,9 +1801,7 @@ class PersistEventsStore:
)
if rel_type == RelationTypes.REPLACE:
- txn.call_after(
- self.store.get_applicable_edit.invalidate, (parent_id, event.room_id)
- )
+ txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
if rel_type == RelationTypes.THREAD:
txn.call_after(
@@ -1814,7 +1812,7 @@ class PersistEventsStore:
# potentially error-prone) so it is always invalidated.
txn.call_after(
self.store.get_thread_participated.invalidate,
- (parent_id, event.room_id, event.sender),
+ (parent_id, event.sender),
)
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
@@ -2215,9 +2213,14 @@ class PersistEventsStore:
" SELECT 1 FROM event_backward_extremities"
" WHERE event_id = ? AND room_id = ?"
" )"
+ # 1. Don't add an event as an extremity again if we already persisted it
+ # as a non-outlier.
+ # 2. Don't add an outlier as an extremity if it has no prev_events
" AND NOT EXISTS ("
- " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
- " AND outlier = ?"
+ " SELECT 1 FROM events"
+ " LEFT JOIN event_edges edge"
+ " ON edge.event_id = events.event_id"
+ " WHERE events.event_id = ? AND events.room_id = ? AND (events.outlier = ? OR edge.event_id IS NULL)"
" )"
)
@@ -2243,6 +2246,10 @@ class PersistEventsStore:
(ev.event_id, ev.room_id)
for ev in events
if not ev.internal_metadata.is_outlier()
+ # If we encountered an event with no prev_events, then we might
+ # as well remove it now because it won't ever have anything else
+ # to backfill from.
+ or len(ev.prev_event_ids()) == 0
],
)
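
A hedged sketch of the second half of this change: when deleting processed backward extremities, events with no prev_events are now dropped too, since there is nothing behind them to backfill. _StubEvent below is a stand-in for EventBase, not a Synapse class:

    from dataclasses import dataclass, field
    from typing import List, Tuple

    @dataclass
    class _StubEvent:
        event_id: str
        room_id: str
        outlier: bool
        prev_events: List[str] = field(default_factory=list)

        def is_outlier(self) -> bool:
            return self.outlier

        def prev_event_ids(self) -> List[str]:
            return self.prev_events

    def extremities_to_delete(events: List[_StubEvent]) -> List[Tuple[str, str]]:
        return [
            (ev.event_id, ev.room_id)
            for ev in events
            # Persisted non-outliers are no longer backward extremities ...
            if not ev.is_outlier()
            # ... and an outlier with no prev_events can never be backfilled from.
            or len(ev.prev_event_ids()) == 0
        ]
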
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index e01c9493..92539f5d 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -42,7 +42,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-def _load_rules(rawrules, enabled_map, use_new_defaults=False):
+def _load_rules(rawrules, enabled_map):
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
@@ -52,7 +52,7 @@ def _load_rules(rawrules, enabled_map, use_new_defaults=False):
ruleslist.append(rule)
# We're going to be mutating this a lot, so do a deep copy
- rules = list(list_with_base_rules(ruleslist, use_new_defaults))
+ rules = list(list_with_base_rules(ruleslist))
for i, rule in enumerate(rules):
rule_id = rule["rule_id"]
@@ -112,10 +112,6 @@ class PushRulesWorkerStore(
prefilled_cache=push_rules_prefill,
)
- self._users_new_default_push_rules = (
- hs.config.server.users_new_default_push_rules
- )
-
@abc.abstractmethod
def get_max_push_rules_stream_id(self):
"""Get the position of the push rules stream.
@@ -145,9 +141,7 @@ class PushRulesWorkerStore(
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
- use_new_defaults = user_id in self._users_new_default_push_rules
-
- return _load_rules(rows, enabled_map, use_new_defaults)
+ return _load_rules(rows, enabled_map)
@cached(max_entries=5000)
async def get_push_rules_enabled_for_user(self, user_id) -> Dict[str, bool]:
@@ -206,13 +200,7 @@ class PushRulesWorkerStore(
enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
for user_id, rules in results.items():
- use_new_defaults = user_id in self._users_new_default_push_rules
-
- results[user_id] = _load_rules(
- rules,
- enabled_map_by_user.get(user_id, {}),
- use_new_defaults,
- )
+ results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
return results
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 37468a51..e2c27e59 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,12 +13,23 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union, cast
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ Union,
+ cast,
+)
import attr
from frozendict import frozendict
-from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.constants import RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
@@ -28,16 +39,14 @@ from synapse.storage.database import (
make_in_list_sql_clause,
)
from synapse.storage.databases.main.stream import generate_pagination_where_clause
-from synapse.storage.relations import (
- AggregationPaginationToken,
- PaginationChunk,
- RelationPaginationToken,
-)
-from synapse.types import JsonDict
-from synapse.util.caches.descriptors import cached
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
+from synapse.types import JsonDict, RoomStreamToken, StreamToken
+from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
from synapse.server import HomeServer
+ from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@@ -87,8 +96,8 @@ class RelationsWorkerStore(SQLBaseStore):
aggregation_key: Optional[str] = None,
limit: int = 5,
direction: str = "b",
- from_token: Optional[RelationPaginationToken] = None,
- to_token: Optional[RelationPaginationToken] = None,
+ from_token: Optional[StreamToken] = None,
+ to_token: Optional[StreamToken] = None,
) -> PaginationChunk:
"""Get a list of relations for an event, ordered by topological ordering.
@@ -127,8 +136,10 @@ class RelationsWorkerStore(SQLBaseStore):
pagination_clause = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
- from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
- to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
+ from_token=from_token.room_key.as_historical_tuple()
+ if from_token
+ else None,
+ to_token=to_token.room_key.as_historical_tuple() if to_token else None,
engine=self.database_engine,
)
@@ -166,12 +177,27 @@ class RelationsWorkerStore(SQLBaseStore):
last_topo_id = row[1]
last_stream_id = row[2]
- next_batch = None
+ # If there are more events, generate the next pagination key.
+ next_token = None
if len(events) > limit and last_topo_id and last_stream_id:
- next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
+ next_key = RoomStreamToken(last_topo_id, last_stream_id)
+ if from_token:
+ next_token = from_token.copy_and_replace("room_key", next_key)
+ else:
+ next_token = StreamToken(
+ room_key=next_key,
+ presence_key=0,
+ typing_key=0,
+ receipt_key=0,
+ account_data_key=0,
+ push_rules_key=0,
+ to_device_key=0,
+ device_list_key=0,
+ groups_key=0,
+ )
return PaginationChunk(
- chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+ chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token
)
return await self.db_pool.runInteraction(
@@ -340,20 +366,24 @@ class RelationsWorkerStore(SQLBaseStore):
)
@cached()
- async def get_applicable_edit(
- self, event_id: str, room_id: str
- ) -> Optional[EventBase]:
+ def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
+ raise NotImplementedError()
+
+ @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
+ async def _get_applicable_edits(
+ self, event_ids: Collection[str]
+ ) -> Dict[str, Optional[EventBase]]:
"""Get the most recent edit (if any) that has happened for the given
- event.
+ events.
Correctly handles checking whether edits were allowed to happen.
Args:
- event_id: The original event ID
- room_id: The original event's room ID
+ event_ids: The original event IDs
Returns:
- The most recent edit, if any.
+ A map of the most recent edit for each event. If there are no edits,
+ the event will map to None.
"""
# We only allow edits for `m.room.message` events that have the same sender
@@ -362,139 +392,238 @@ class RelationsWorkerStore(SQLBaseStore):
# Fetches latest edit that has the same type and sender as the
# original, and is an `m.room.message`.
- sql = """
- SELECT edit.event_id FROM events AS edit
- INNER JOIN event_relations USING (event_id)
- INNER JOIN events AS original ON
- original.event_id = relates_to_id
- AND edit.type = original.type
- AND edit.sender = original.sender
- WHERE
- relates_to_id = ?
- AND relation_type = ?
- AND edit.room_id = ?
- AND edit.type = 'm.room.message'
- ORDER by edit.origin_server_ts DESC, edit.event_id DESC
- LIMIT 1
- """
+ if isinstance(self.database_engine, PostgresEngine):
+ # The `DISTINCT ON` clause will pick the *first* row it encounters,
+ # so ordering by origin server ts + event ID desc will ensure we get
+ # the latest edit.
+ sql = """
+ SELECT DISTINCT ON (original.event_id) original.event_id, edit.event_id FROM events AS edit
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS original ON
+ original.event_id = relates_to_id
+ AND edit.type = original.type
+ AND edit.sender = original.sender
+ AND edit.room_id = original.room_id
+ WHERE
+ %s
+ AND relation_type = ?
+ AND edit.type = 'm.room.message'
+ ORDER by original.event_id DESC, edit.origin_server_ts DESC, edit.event_id DESC
+ """
+ else:
+ # SQLite uses a simplified query which returns all edits for an
+ # original event. The results are then de-duplicated when turned into
+ # a dict. Due to the chosen ordering, the latest edit stomps on
+ # earlier edits.
+ sql = """
+ SELECT original.event_id, edit.event_id FROM events AS edit
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS original ON
+ original.event_id = relates_to_id
+ AND edit.type = original.type
+ AND edit.sender = original.sender
+ AND edit.room_id = original.room_id
+ WHERE
+ %s
+ AND relation_type = ?
+ AND edit.type = 'm.room.message'
+ ORDER by edit.origin_server_ts, edit.event_id
+ """
- def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
- txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id))
- row = txn.fetchone()
- if row:
- return row[0]
- return None
+ def _get_applicable_edits_txn(txn: LoggingTransaction) -> Dict[str, str]:
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "relates_to_id", event_ids
+ )
+ args.append(RelationTypes.REPLACE)
- edit_id = await self.db_pool.runInteraction(
- "get_applicable_edit", _get_applicable_edit_txn
+ txn.execute(sql % (clause,), args)
+ return dict(cast(Iterable[Tuple[str, str]], txn.fetchall()))
+
+ edit_ids = await self.db_pool.runInteraction(
+ "get_applicable_edits", _get_applicable_edits_txn
)
- if not edit_id:
- return None
+ edits = await self.get_events(edit_ids.values()) # type: ignore[attr-defined]
- return await self.get_event(edit_id, allow_none=True) # type: ignore[attr-defined]
+ # Map the original event IDs to the edit events.
+ #
+ # There might not be an edit event due to there being no edits or
+ # due to the event not being known, either case is treated the same.
+ return {
+ original_event_id: edits.get(edit_ids.get(original_event_id))
+ for original_event_id in event_ids
+ }
@cached()
- async def get_thread_summary(
- self, event_id: str, room_id: str
- ) -> Tuple[int, Optional[EventBase]]:
+ def get_thread_summary(self, event_id: str) -> Optional[Tuple[int, EventBase]]:
+ raise NotImplementedError()
+
+ @cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
+ async def _get_thread_summaries(
+ self, event_ids: Collection[str]
+ ) -> Dict[str, Optional[Tuple[int, EventBase]]]:
"""Get the number of threaded replies and the latest reply (if any) for the given event.
Args:
- event_id: Summarize the thread related to this event ID.
- room_id: The room the event belongs to.
+ event_ids: Summarize the threads related to these event IDs.
Returns:
- The number of items in the thread and the most recent response, if any.
+ A map of the thread summary for each event. A missing event implies there
+ are no threaded replies.
+
+ Each summary includes the number of items in the thread and the most
+ recent response.
"""
- def _get_thread_summary_txn(
+ def _get_thread_summaries_txn(
txn: LoggingTransaction,
- ) -> Tuple[int, Optional[str]]:
- # Fetch the latest event ID in the thread.
+ ) -> Tuple[Dict[str, int], Dict[str, str]]:
+ # Fetch the count of threaded events and the latest event ID.
# TODO Should this only allow m.room.message events.
- sql = """
- SELECT event_id
- FROM event_relations
- INNER JOIN events USING (event_id)
- WHERE
- relates_to_id = ?
- AND room_id = ?
- AND relation_type = ?
- ORDER BY topological_ordering DESC, stream_ordering DESC
- LIMIT 1
- """
+ if isinstance(self.database_engine, PostgresEngine):
+ # The `DISTINCT ON` clause will pick the *first* row it encounters,
+ # so ordering by topological ordering + stream ordering desc will
+ # ensure we get the latest event in the thread.
+ sql = """
+ SELECT DISTINCT ON (parent.event_id) parent.event_id, child.event_id FROM events AS child
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS parent ON
+ parent.event_id = relates_to_id
+ AND parent.room_id = child.room_id
+ WHERE
+ %s
+ AND relation_type = ?
+ ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC
+ """
+ else:
+ # SQLite uses a simplified query which returns all entries for a
+ # thread. The first result for each thread is chosen and subsequent
+ # results for a thread are ignored.
+ sql = """
+ SELECT parent.event_id, child.event_id FROM events AS child
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS parent ON
+ parent.event_id = relates_to_id
+ AND parent.room_id = child.room_id
+ WHERE
+ %s
+ AND relation_type = ?
+ ORDER BY child.topological_ordering DESC, child.stream_ordering DESC
+ """
+
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "relates_to_id", event_ids
+ )
+ args.append(RelationTypes.THREAD)
- txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
- row = txn.fetchone()
- if row is None:
- return 0, None
+ txn.execute(sql % (clause,), args)
+ latest_event_ids = {}
+ for parent_event_id, child_event_id in txn:
+ # Only consider the latest threaded reply (by topological ordering).
+ if parent_event_id not in latest_event_ids:
+ latest_event_ids[parent_event_id] = child_event_id
- latest_event_id = row[0]
+ # If no threads were found, bail.
+ if not latest_event_ids:
+ return {}, latest_event_ids
# Fetch the number of threaded replies.
sql = """
- SELECT COUNT(event_id)
- FROM event_relations
- INNER JOIN events USING (event_id)
+ SELECT parent.event_id, COUNT(child.event_id) FROM events AS child
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS parent ON
+ parent.event_id = relates_to_id
+ AND parent.room_id = child.room_id
WHERE
- relates_to_id = ?
- AND room_id = ?
+ %s
AND relation_type = ?
+ GROUP BY parent.event_id
"""
- txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
- count = cast(Tuple[int], txn.fetchone())[0]
- return count, latest_event_id
+ # Regenerate the arguments since only threads found above could
+ # possibly have any replies.
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "relates_to_id", latest_event_ids.keys()
+ )
+ args.append(RelationTypes.THREAD)
- count, latest_event_id = await self.db_pool.runInteraction(
- "get_thread_summary", _get_thread_summary_txn
+ txn.execute(sql % (clause,), args)
+ counts = dict(cast(List[Tuple[str, int]], txn.fetchall()))
+
+ return counts, latest_event_ids
+
+ counts, latest_event_ids = await self.db_pool.runInteraction(
+ "get_thread_summaries", _get_thread_summaries_txn
)
- latest_event = None
- if latest_event_id:
- latest_event = await self.get_event(latest_event_id, allow_none=True) # type: ignore[attr-defined]
+ latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
- return count, latest_event
+ # Map the event IDs to the thread summary.
+ #
+ # There might not be a summary due to there not being a thread or
+ # due to the latest event not being known; either case is treated the same.
+ summaries = {}
+ for parent_event_id, latest_event_id in latest_event_ids.items():
+ latest_event = latest_events.get(latest_event_id)
+
+ summary = None
+ if latest_event:
+ summary = (counts[parent_event_id], latest_event)
+ summaries[parent_event_id] = summary
+
+ return summaries
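
Both queries above interpolate a bare %s because make_in_list_sql_clause builds an engine-appropriate WHERE fragment plus its bind parameters. Roughly the following shape, shown here as a simplified stand-in rather than the real helper from synapse.storage.database:

    from typing import Any, Iterable, List, Tuple

    def make_in_list_sql_clause_sketch(
        supports_any: bool, column: str, values: Iterable[Any]
    ) -> Tuple[str, List[Any]]:
        """Simplified stand-in for make_in_list_sql_clause (illustration only)."""
        values = list(values)
        if supports_any:
            # PostgreSQL can take the whole list as a single array parameter.
            return f"{column} = ANY(?)", [values]
        # SQLite gets one placeholder per value.
        placeholders = ", ".join("?" for _ in values)
        return f"{column} IN ({placeholders})", values

    clause, args = make_in_list_sql_clause_sketch(False, "relates_to_id", ["$a", "$b"])
    # clause == "relates_to_id IN (?, ?)"; args == ["$a", "$b"]
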
@cached()
- async def get_thread_participated(
- self, event_id: str, room_id: str, user_id: str
- ) -> bool:
- """Get whether the requesting user participated in a thread.
+ def get_thread_participated(self, event_id: str, user_id: str) -> bool:
+ raise NotImplementedError()
+
+ @cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
+ async def _get_threads_participated(
+ self, event_ids: Collection[str], user_id: str
+ ) -> Dict[str, bool]:
+ """Get whether the requesting user participated in the given threads.
- This is separate from get_thread_summary since that can be cached across
- all users while this value is specific to the requeser.
+ This is separate from get_thread_summaries since that can be cached across
+ all users while this value is specific to the requester.
Args:
- event_id: The thread related to this event ID.
- room_id: The room the event belongs to.
+ event_ids: The event IDs of the threads to check.
user_id: The user requesting the summary.
Returns:
- True if the requesting user participated in the thread, otherwise false.
+ A map of event ID to a boolean indicating whether the requesting
+ user participated in that event's thread.
"""
- def _get_thread_summary_txn(txn: LoggingTransaction) -> bool:
+ def _get_thread_summary_txn(txn: LoggingTransaction) -> Set[str]:
# Fetch whether the requester has participated or not.
sql = """
- SELECT 1
- FROM event_relations
- INNER JOIN events USING (event_id)
+ SELECT DISTINCT relates_to_id
+ FROM events AS child
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS parent ON
+ parent.event_id = relates_to_id
+ AND parent.room_id = child.room_id
WHERE
- relates_to_id = ?
- AND room_id = ?
+ %s
AND relation_type = ?
- AND sender = ?
+ AND child.sender = ?
"""
- txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id))
- return bool(txn.fetchone())
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "relates_to_id", event_ids
+ )
+ args.extend((RelationTypes.THREAD, user_id))
- return await self.db_pool.runInteraction(
+ txn.execute(sql % (clause,), args)
+ return {row[0] for row in txn.fetchall()}
+
+ participated_threads = await self.db_pool.runInteraction(
"get_thread_summary", _get_thread_summary_txn
)
+ return {event_id: event_id in participated_threads for event_id in event_ids}
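
The @cached stub that simply raises NotImplementedError is the usual Synapse pairing for @cachedList: the batch method fills the per-key cache named by cached_method_name, so repeat lookups are served from cache and only the misses hit the database. A toy model of that contract, with a fake lookup function standing in for the real descriptors:

    from typing import Callable, Collection, Dict

    class BatchedCache:
        """Toy per-key cache filled by a batch fetcher, loosely modelling the
        @cached / @cachedList pairing (not the real descriptor machinery)."""

        def __init__(self, batch_fetch: Callable[[Collection[str]], Dict[str, bool]]):
            self._cache: Dict[str, bool] = {}
            self._batch_fetch = batch_fetch

        def get_many(self, keys: Collection[str]) -> Dict[str, bool]:
            missing = [key for key in keys if key not in self._cache]
            if missing:
                # One query covers every cache miss, as in _get_threads_participated.
                self._cache.update(self._batch_fetch(missing))
            return {key: self._cache[key] for key in keys}

    cache = BatchedCache(lambda keys: {key: False for key in keys})
    cache.get_many(["$root_a", "$root_b"])  # first call fetches both keys
    cache.get_many(["$root_a"])             # answered entirely from cache
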
+
async def events_have_relations(
self,
parent_ids: List[str],
@@ -612,9 +741,6 @@ class RelationsWorkerStore(SQLBaseStore):
The bundled aggregations for an event, if bundled aggregations are
enabled and the event can have bundled aggregations.
"""
- # State events and redacted events do not get bundled aggregations.
- if event.is_state() or event.internal_metadata.is_redacted():
- return None
# Do not bundle aggregations for an event which represents an edit or an
# annotation. It does not make sense for them to have related events.
@@ -634,43 +760,21 @@ class RelationsWorkerStore(SQLBaseStore):
annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
if annotations.chunk:
- aggregations.annotations = annotations.to_dict()
+ aggregations.annotations = await annotations.to_dict(
+ cast("DataStore", self)
+ )
references = await self.get_relations_for_event(
event_id, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
- aggregations.references = references.to_dict()
-
- edit = None
- if event.type == EventTypes.Message:
- edit = await self.get_applicable_edit(event_id, room_id)
-
- if edit:
- aggregations.replace = edit
-
- # If this event is the start of a thread, include a summary of the replies.
- if self._msc3440_enabled:
- thread_count, latest_thread_event = await self.get_thread_summary(
- event_id, room_id
- )
- participated = await self.get_thread_participated(
- event_id, room_id, user_id
- )
- if latest_thread_event:
- aggregations.thread = _ThreadAggregation(
- latest_event=latest_thread_event,
- count=thread_count,
- current_user_participated=participated,
- )
+ aggregations.references = await references.to_dict(cast("DataStore", self))
# Store the bundled aggregations in the event metadata for later use.
return aggregations
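
For orientation, the aggregations assembled here are eventually serialized under the event's unsigned m.relations key. An illustrative shape following MSC2675 and MSC3440 (example values, not output captured from this code):

    unsigned_m_relations = {
        "m.annotation": {
            "chunk": [{"type": "m.reaction", "key": "👍", "count": 3}],
        },
        "m.reference": {
            "chunk": [{"event_id": "$some_reply"}],
        },
        # Added for thread roots when MSC3440 support is enabled.
        "m.thread": {
            "latest_event": {"event_id": "$latest_reply", "type": "m.room.message"},
            "count": 7,
            "current_user_participated": True,
        },
    }
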
async def get_bundled_aggregations(
- self,
- events: Iterable[EventBase],
- user_id: str,
+ self, events: Iterable[EventBase], user_id: str
) -> Dict[str, BundledAggregations]:
"""Generate bundled aggregations for events.
@@ -682,14 +786,59 @@ class RelationsWorkerStore(SQLBaseStore):
A map of event ID to the bundled aggregation for the event. Not all
events may have bundled aggregations in the results.
"""
+ # The already processed event IDs. Tracked separately from the result
+ # since the result omits events which do not have bundled aggregations.
+ seen_event_ids = set()
- # TODO Parallelize.
- results = {}
+ # State events and redacted events do not get bundled aggregations.
+ events = [
+ event
+ for event in events
+ if not event.is_state() and not event.internal_metadata.is_redacted()
+ ]
+
+ # event ID -> bundled aggregation in non-serialized form.
+ results: Dict[str, BundledAggregations] = {}
+
+ # Fetch other relations per event.
for event in events:
+ # De-duplicate events by ID to handle the same event requested multiple
+ # times. The caches that _get_bundled_aggregation_for_event uses should
+ # capture this, but best to reduce work.
+ if event.event_id in seen_event_ids:
+ continue
+ seen_event_ids.add(event.event_id)
+
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
if event_result:
results[event.event_id] = event_result
+ # Fetch any edits.
+ edits = await self._get_applicable_edits(seen_event_ids)
+ for event_id, edit in edits.items():
+ results.setdefault(event_id, BundledAggregations()).replace = edit
+
+ # Fetch thread summaries.
+ if self._msc3440_enabled:
+ summaries = await self._get_thread_summaries(seen_event_ids)
+ # Only fetch participation for the limited selection of threads
+ # which had summaries.
+ participated = await self._get_threads_participated(
+ summaries.keys(), user_id
+ )
+ for event_id, summary in summaries.items():
+ if summary:
+ thread_count, latest_thread_event = summary
+ results.setdefault(
+ event_id, BundledAggregations()
+ ).thread = _ThreadAggregation(
+ latest_event=latest_thread_event,
+ count=thread_count,
+ # If there's a thread summary it must also exist in the
+ # participated dictionary.
+ current_user_participated=participated[event_id],
+ )
+
return results
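
A caller-side sketch of the batched API (store, events and requester_user_id are hypothetical names here; the real callers sit in the event serialization path):

    from typing import Iterable

    async def log_thread_summaries(store, events: Iterable, requester_user_id: str) -> None:
        events = list(events)  # materialise so we can iterate twice
        # Events without any bundled aggregations are simply absent from the
        # returned map, hence the .get() below.
        aggregations = await store.get_bundled_aggregations(events, requester_user_id)
        for event in events:
            bundle = aggregations.get(event.event_id)
            if bundle is not None and bundle.thread is not None:
                print(
                    event.event_id,
                    bundle.thread.count,
                    bundle.thread.current_user_participated,
                )
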