author     Andrej Shadura <andrewsh@debian.org>    2022-04-22 20:34:38 +0200
committer  Andrej Shadura <andrewsh@debian.org>    2022-04-22 20:34:38 +0200
commit     1b9b92888056ce7fd1f3a010ca7afd5c3963d44e (patch)
tree       716cb361eb4332eca7b147f35c87ceb29b3ac958 /synapse/storage/databases
parent     02eb467b57bf21597094de52232c93d5f4a38b7d (diff)
New upstream version 1.57.1
Diffstat (limited to 'synapse/storage/databases')
 synapse/storage/databases/main/__init__.py             |  20
 synapse/storage/databases/main/appservice.py           | 107
 synapse/storage/databases/main/client_ips.py           | 167
 synapse/storage/databases/main/devices.py              | 315
 synapse/storage/databases/main/events.py               |  10
 synapse/storage/databases/main/events_worker.py        |   8
 synapse/storage/databases/main/monthly_active_users.py |  60
 synapse/storage/databases/main/receipts.py             |  13
 synapse/storage/databases/main/registration.py         | 157
 synapse/storage/databases/main/relations.py            | 227
 synapse/storage/databases/main/roommember.py           |  23
 synapse/storage/databases/main/signatures.py           |   2
 synapse/storage/databases/main/state.py                |  32
 synapse/storage/databases/main/stream.py               |  70
 synapse/storage/databases/main/tags.py                 |   4
15 files changed, 818 insertions, 397 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index f024761b..951031af 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -33,7 +33,7 @@
 from .account_data import AccountDataStore
 from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
 from .cache import CacheInvalidationWorkerStore
 from .censor_events import CensorEventsStore
-from .client_ips import ClientIpStore
+from .client_ips import ClientIpWorkerStore
 from .deviceinbox import DeviceInboxStore
 from .devices import DeviceStore
 from .directory import DirectoryStore
@@ -49,7 +49,7 @@
 from .keys import KeyStore
 from .lock import LockStore
 from .media_repository import MediaRepositoryStore
 from .metrics import ServerMetricsStore
-from .monthly_active_users import MonthlyActiveUsersStore
+from .monthly_active_users import MonthlyActiveUsersWorkerStore
 from .openid import OpenIdStore
 from .presence import PresenceStore
 from .profile import ProfileStore
@@ -112,13 +112,13 @@ class DataStore(
     AccountDataStore,
     EventPushActionsStore,
     OpenIdStore,
-    ClientIpStore,
+    ClientIpWorkerStore,
     DeviceStore,
     DeviceInboxStore,
     UserDirectoryStore,
     GroupServerStore,
     UserErasureStore,
-    MonthlyActiveUsersStore,
+    MonthlyActiveUsersWorkerStore,
     StatsStore,
     RelationsStore,
     CensorEventsStore,
@@ -146,6 +146,7 @@ class DataStore(
             extra_tables=[
                 ("user_signature_stream", "stream_id"),
                 ("device_lists_outbound_pokes", "stream_id"),
+                ("device_lists_changes_in_room", "stream_id"),
             ],
         )
@@ -182,17 +183,6 @@ class DataStore(
         super().__init__(database, db_conn, hs)
 
-        device_list_max = self._device_list_id_gen.get_current_token()
-        self._device_list_stream_cache = StreamChangeCache(
-            "DeviceListStreamChangeCache", device_list_max
-        )
-        self._user_signature_stream_cache = StreamChangeCache(
-            "UserSignatureStreamChangeCache", device_list_max
-        )
-        self._device_list_federation_stream_cache = StreamChangeCache(
-            "DeviceListFederationStreamChangeCache", device_list_max
-        )
-
         events_max = self._stream_id_gen.get_current_token()
         curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
             db_conn,
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 06944465..fa732edc 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import re
-from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple
 
 from synapse.appservice import (
     ApplicationService,
@@ -26,10 +26,16 @@ from synapse.appservice import (
 from synapse.config.appservice import load_appservices
 from synapse.events import EventBase
 from synapse.storage._base import db_to_json
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
-from synapse.types import JsonDict
+from synapse.storage.types import Cursor
+from synapse.storage.util.sequence import build_sequence_generator
+from synapse.types import DeviceListUpdates, JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import _CacheContext, cached
@@ -72,9 +78,25 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
         )
         self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
 
+        def get_max_as_txn_id(txn: Cursor) -> int:
+            logger.warning("Falling back to slow query, you should port to postgres")
+            txn.execute(
+                "SELECT COALESCE(max(txn_id), 0) FROM application_services_txns"
+            )
+            return txn.fetchone()[0]  # type: ignore
+
+        self._as_txn_seq_gen = build_sequence_generator(
+            db_conn,
+            database.engine,
+            get_max_as_txn_id,
+            "application_services_txn_id_seq",
+            table="application_services_txns",
+            id_column="txn_id",
+        )
+
         super().__init__(database, db_conn, hs)
 
-    def get_app_services(self):
+    def get_app_services(self) -> List[ApplicationService]:
         return self.services_cache
 
     def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
@@ -217,6 +239,7 @@ class ApplicationServiceTransactionWorkerStore(
         to_device_messages: List[JsonDict],
         one_time_key_counts: TransactionOneTimeKeyCounts,
         unused_fallback_keys: TransactionUnusedFallbackKeys,
+        device_list_summary: DeviceListUpdates,
     ) -> AppServiceTransaction:
         """Atomically creates a new transaction for this application service
         with the given list of events. Ephemeral events are NOT persisted to the
@@ -231,27 +254,14 @@ class ApplicationServiceTransactionWorkerStore(
             appservice devices in the transaction.
             unused_fallback_keys: Lists of unused fallback keys for relevant
             appservice devices in the transaction.
+            device_list_summary: The device list summary to include in the transaction.
 
         Returns:
             A new transaction.
""" - def _create_appservice_txn(txn): - # work out new txn id (highest txn id for this service += 1) - # The highest id may be the last one sent (in which case it is last_txn) - # or it may be the highest in the txns list (which are waiting to be/are - # being sent) - last_txn_id = self._get_last_txn(txn, service.id) - - txn.execute( - "SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?", - (service.id,), - ) - highest_txn_id = txn.fetchone()[0] - if highest_txn_id is None: - highest_txn_id = 0 - - new_txn_id = max(highest_txn_id, last_txn_id) + 1 + def _create_appservice_txn(txn: LoggingTransaction) -> AppServiceTransaction: + new_txn_id = self._as_txn_seq_gen.get_next_id_txn(txn) # Insert new txn into txn table event_ids = json_encoder.encode([e.event_id for e in events]) @@ -268,6 +278,7 @@ class ApplicationServiceTransactionWorkerStore( to_device_messages=to_device_messages, one_time_key_counts=one_time_key_counts, unused_fallback_keys=unused_fallback_keys, + device_list_summary=device_list_summary, ) return await self.db_pool.runInteraction( @@ -283,25 +294,8 @@ class ApplicationServiceTransactionWorkerStore( txn_id: The transaction ID being completed. service: The application service which was sent this transaction. """ - txn_id = int(txn_id) - - def _complete_appservice_txn(txn): - # Debugging query: Make sure the txn being completed is EXACTLY +1 from - # what was there before. If it isn't, we've got problems (e.g. the AS - # has probably missed some events), so whine loudly but still continue, - # since it shouldn't fail completion of the transaction. - last_txn_id = self._get_last_txn(txn, service.id) - if (last_txn_id + 1) != txn_id: - logger.error( - "appservice: Completing a transaction which has an ID > 1 from " - "the last ID sent to this AS. We've either dropped events or " - "sent it to the AS out of order. FIX ME. last_txn=%s " - "completing_txn=%s service_id=%s", - last_txn_id, - txn_id, - service.id, - ) + def _complete_appservice_txn(txn: LoggingTransaction) -> None: # Set current txn_id for AS to 'txn_id' self.db_pool.simple_upsert_txn( txn, @@ -332,7 +326,9 @@ class ApplicationServiceTransactionWorkerStore( An AppServiceTransaction or None. """ - def _get_oldest_unsent_txn(txn): + def _get_oldest_unsent_txn( + txn: LoggingTransaction, + ) -> Optional[Dict[str, Any]]: # Monotonically increasing txn ids, so just select the smallest # one in the txns table (we delete them when they are sent) txn.execute( @@ -359,8 +355,8 @@ class ApplicationServiceTransactionWorkerStore( events = await self.get_events_as_list(event_ids) - # TODO: to-device messages, one-time key counts and unused fallback keys - # are not yet populated for catch-up transactions. + # TODO: to-device messages, one-time key counts, device list summaries and unused + # fallback keys are not yet populated for catch-up transactions. # We likely want to populate those for reliability. 
         return AppServiceTransaction(
             service=service,
@@ -370,21 +366,11 @@ class ApplicationServiceTransactionWorkerStore(
             to_device_messages=[],
             one_time_key_counts={},
             unused_fallback_keys={},
+            device_list_summary=DeviceListUpdates(),
         )
 
-    def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
-        txn.execute(
-            "SELECT last_txn FROM application_services_state WHERE as_id=?",
-            (service_id,),
-        )
-        last_txn_id = txn.fetchone()
-        if last_txn_id is None or last_txn_id[0] is None:  # no row exists
-            return 0
-        else:
-            return int(last_txn_id[0])  # select 'last_txn' col
-
     async def set_appservice_last_pos(self, pos: int) -> None:
-        def set_appservice_last_pos_txn(txn):
+        def set_appservice_last_pos_txn(txn: LoggingTransaction) -> None:
             txn.execute(
                 "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
             )
@@ -398,7 +384,9 @@ class ApplicationServiceTransactionWorkerStore(
     ) -> Tuple[int, List[EventBase]]:
         """Get all new events for an appservice"""
 
-        def get_new_events_for_appservice_txn(txn):
+        def get_new_events_for_appservice_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[int, List[str]]:
             sql = (
                 "SELECT e.stream_ordering, e.event_id"
                 " FROM events AS e"
@@ -430,13 +418,13 @@ class ApplicationServiceTransactionWorkerStore(
     async def get_type_stream_id_for_appservice(
         self, service: ApplicationService, type: str
     ) -> int:
-        if type not in ("read_receipt", "presence", "to_device"):
+        if type not in ("read_receipt", "presence", "to_device", "device_list"):
             raise ValueError(
                 "Expected type to be a valid application stream id type, got %s"
                 % (type,)
             )
 
-        def get_type_stream_id_for_appservice_txn(txn):
+        def get_type_stream_id_for_appservice_txn(txn: LoggingTransaction) -> int:
             stream_id_type = "%s_stream_id" % type
             txn.execute(
                 # We do NOT want to escape `stream_id_type`.
@@ -446,7 +434,8 @@ class ApplicationServiceTransactionWorkerStore(
             )
             last_stream_id = txn.fetchone()
             if last_stream_id is None or last_stream_id[0] is None:  # no row exists
-                return 0
+                # Stream tokens always start from 1, to avoid foot guns around `0` being falsey.
+                return 1
             else:
                 return int(last_stream_id[0])
 
@@ -457,13 +446,13 @@ class ApplicationServiceTransactionWorkerStore(
     async def set_appservice_stream_type_pos(
         self, service: ApplicationService, stream_type: str, pos: Optional[int]
     ) -> None:
-        if stream_type not in ("read_receipt", "presence", "to_device"):
+        if stream_type not in ("read_receipt", "presence", "to_device", "device_list"):
             raise ValueError(
                 "Expected type to be a valid application stream id type, got %s"
                 % (stream_type,)
             )
 
-        def set_appservice_stream_type_pos_txn(txn):
+        def set_appservice_stream_type_pos_txn(txn: LoggingTransaction) -> None:
             stream_id_type = "%s_stream_id" % stream_type
             txn.execute(
                 "UPDATE application_services_state SET %s = ? WHERE as_id=?"
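The appservice changes above replace the fragile `last_txn`/`MAX(txn_id) + 1` bookkeeping with `build_sequence_generator`, which allocates IDs from a native sequence on Postgres and falls back to a table scan on SQLite (hence the warning in `get_max_as_txn_id`). One visible consequence is that transaction IDs become globally monotonic rather than consecutive per service, which is why the old "EXACTLY +1" debugging check is deleted. A minimal standalone sketch of the SQLite fallback path (the class and schema here are illustrative, not Synapse's real helper):

    import sqlite3

    class TxnIdSequence:
        """Allocate monotonically increasing transaction IDs inside a DB txn.

        On Postgres this would be a single
        `SELECT nextval('application_services_txn_id_seq')`; the SQLite
        emulation below scans for the current maximum, which is why the real
        code logs a warning suggesting a port to Postgres.
        """

        def get_next_id_txn(self, txn: sqlite3.Cursor) -> int:
            txn.execute(
                "SELECT COALESCE(MAX(txn_id), 0) + 1 FROM application_services_txns"
            )
            return txn.fetchone()[0]

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE application_services_txns (as_id TEXT, txn_id INTEGER)")
    seq = TxnIdSequence()
    print(seq.get_next_id_txn(conn.cursor()))  # -> 1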
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 8b0c614e..0df160d2 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -25,7 +25,9 @@ from synapse.storage.database import (
     LoggingTransaction,
     make_tuple_comparison_clause,
 )
-from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
+from synapse.storage.databases.main.monthly_active_users import (
+    MonthlyActiveUsersWorkerStore,
+)
 from synapse.types import JsonDict, UserID
 from synapse.util.caches.lrucache import LruCache
@@ -397,7 +399,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
         return updated
 
-class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
+class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorkerStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -406,11 +408,40 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
     ):
         super().__init__(database, db_conn, hs)
 
+        if hs.config.redis.redis_enabled:
+            # If we're using Redis, we can shift this update process off to
+            # the background worker
+            self._update_on_this_worker = hs.config.worker.run_background_tasks
+        else:
+            # If we're NOT using Redis, this must be handled by the master
+            self._update_on_this_worker = hs.get_instance_name() == "master"
+
         self.user_ips_max_age = hs.config.server.user_ips_max_age
 
+        # (user_id, access_token, ip,) -> last_seen
+        self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
+            cache_name="client_ip_last_seen", max_size=50000
+        )
+
         if hs.config.worker.run_background_tasks and self.user_ips_max_age:
             self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
 
+        if self._update_on_this_worker:
+            # This is the designated worker that can write to the client IP
+            # tables.
+
+            # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
+            self._batch_row_update: Dict[
+                Tuple[str, str, str], Tuple[str, Optional[str], int]
+            ] = {}
+
+            self._client_ip_looper = self._clock.looping_call(
+                self._update_client_ips_batch, 5 * 1000
+            )
+            self.hs.get_reactor().addSystemEventTrigger(
+                "before", "shutdown", self._update_client_ips_batch
+            )
+
     @wrap_as_background_process("prune_old_user_ips")
     async def _prune_old_user_ips(self) -> None:
         """Removes entries in user IPs older than the configured period."""
@@ -456,7 +487,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
             "_prune_old_user_ips", _prune_old_user_ips_txn
         )
 
-    async def get_last_client_ip_by_device(
+    async def _get_last_client_ip_by_device_from_database(
         self, user_id: str, device_id: Optional[str]
     ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
         """For each device_id listed, give the user_ip it was last seen on.
@@ -487,7 +518,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
 
         return {(d["user_id"], d["device_id"]): d for d in res}
 
-    async def get_user_ip_and_agents(
+    async def _get_user_ip_and_agents_from_database(
        self, user: UserID, since_ts: int = 0
     ) -> List[LastConnectionInfo]:
         """Fetch the IPs and user agents for a user since the given timestamp.
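Both this store and the monthly-active-users store below now elect a single "designated writer": if Redis (and therefore cross-worker replication) is enabled, the background-tasks worker does the writes; otherwise only the master process does. The election reduces to a small pure function, sketched here standalone (the parameter names are stand-ins for Synapse's config objects):

    def should_update_on_this_worker(
        redis_enabled: bool, runs_background_tasks: bool, instance_name: str
    ) -> bool:
        # With Redis, non-writer workers can ship their updates to the
        # background worker over replication; without it, the master must
        # perform the writes itself.
        if redis_enabled:
            return runs_background_tasks
        return instance_name == "master"

    assert should_update_on_this_worker(True, True, "worker1") is True
    assert should_update_on_this_worker(True, False, "master") is False
    assert should_update_on_this_worker(False, False, "master") is True
    assert should_update_on_this_worker(False, True, "worker1") is False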
@@ -539,34 +570,6 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
             for access_token, ip, user_agent, last_seen in rows
         ]
 
-
-class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
-    def __init__(
-        self,
-        database: DatabasePool,
-        db_conn: LoggingDatabaseConnection,
-        hs: "HomeServer",
-    ):
-
-        # (user_id, access_token, ip,) -> last_seen
-        self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
-            cache_name="client_ip_last_seen", max_size=50000
-        )
-
-        super().__init__(database, db_conn, hs)
-
-        # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
-        self._batch_row_update: Dict[
-            Tuple[str, str, str], Tuple[str, Optional[str], int]
-        ] = {}
-
-        self._client_ip_looper = self._clock.looping_call(
-            self._update_client_ips_batch, 5 * 1000
-        )
-        self.hs.get_reactor().addSystemEventTrigger(
-            "before", "shutdown", self._update_client_ips_batch
-        )
-
     async def insert_client_ip(
         self,
         user_id: str,
@@ -584,17 +587,27 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorkerStore):
             last_seen = self.client_ip_last_seen.get(key)
         except KeyError:
             last_seen = None
-        await self.populate_monthly_active_users(user_id)
+
         # Rate-limited inserts
         if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
             return
 
         self.client_ip_last_seen.set(key, now)
 
-        self._batch_row_update[key] = (user_agent, device_id, now)
+        if self._update_on_this_worker:
+            await self.populate_monthly_active_users(user_id)
+            self._batch_row_update[key] = (user_agent, device_id, now)
+        else:
+            # We are not the designated writer-worker, so stream over replication
+            self.hs.get_replication_command_handler().send_user_ip(
+                user_id, access_token, ip, user_agent, device_id, now
+            )
 
     @wrap_as_background_process("update_client_ips")
     async def _update_client_ips_batch(self) -> None:
+        assert (
+            self._update_on_this_worker
+        ), "This worker is not designated to update client IPs"
 
         # If the DB pool has already terminated, don't try updating
         if not self.db_pool.is_running():
@@ -603,51 +616,57 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorkerStore):
         to_update = self._batch_row_update
         self._batch_row_update = {}
 
-        await self.db_pool.runInteraction(
-            "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
-        )
+        if to_update:
+            await self.db_pool.runInteraction(
+                "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
+            )
 
     def _update_client_ips_batch_txn(
         self,
         txn: LoggingTransaction,
         to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]],
     ) -> None:
-        if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
-            not self.database_engine.can_native_upsert
-        ):
-            self.database_engine.lock_table(txn, "user_ips")
+        assert (
+            self._update_on_this_worker
+        ), "This worker is not designated to update client IPs"
+
+        # Keys and values for the `user_ips` upsert.
+        user_ips_keys = []
+        user_ips_values = []
+
+        # Keys and values for the `devices` update.
+        devices_keys = []
+        devices_values = []
 
         for entry in to_update.items():
             (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
-
-            self.db_pool.simple_upsert_txn(
-                txn,
-                table="user_ips",
-                keyvalues={"user_id": user_id, "access_token": access_token, "ip": ip},
-                values={
-                    "user_agent": user_agent,
-                    "device_id": device_id,
-                    "last_seen": last_seen,
-                },
-                lock=False,
-            )
+            user_ips_keys.append((user_id, access_token, ip))
+            user_ips_values.append((user_agent, device_id, last_seen))
 
             # Technically an access token might not be associated with
             # a device so we need to check.
             if device_id:
-                # this is always an update rather than an upsert: the row should
-                # already exist, and if it doesn't, that may be because it has been
-                # deleted, and we don't want to re-create it.
-                self.db_pool.simple_update_txn(
-                    txn,
-                    table="devices",
-                    keyvalues={"user_id": user_id, "device_id": device_id},
-                    updatevalues={
-                        "user_agent": user_agent,
-                        "last_seen": last_seen,
-                        "ip": ip,
-                    },
-                )
+                devices_keys.append((user_id, device_id))
+                devices_values.append((user_agent, last_seen, ip))
+
+        self.db_pool.simple_upsert_many_txn(
+            txn,
+            table="user_ips",
+            key_names=("user_id", "access_token", "ip"),
+            key_values=user_ips_keys,
+            value_names=("user_agent", "device_id", "last_seen"),
+            value_values=user_ips_values,
+        )
+
+        if devices_values:
+            self.db_pool.simple_update_many_txn(
+                txn,
+                table="devices",
+                key_names=("user_id", "device_id"),
+                key_values=devices_keys,
+                value_names=("user_agent", "last_seen", "ip"),
+                value_values=devices_values,
+            )
 
     async def get_last_client_ip_by_device(
         self, user_id: str, device_id: Optional[str]
@@ -662,7 +681,12 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorkerStore):
             A dictionary mapping a tuple of (user_id, device_id) to dicts, with
             keys giving the column names from the devices table.
         """
-        ret = await super().get_last_client_ip_by_device(user_id, device_id)
+        ret = await self._get_last_client_ip_by_device_from_database(user_id, device_id)
+
+        if not self._update_on_this_worker:
+            # Only the writing-worker has additional in-memory data to enhance
+            # the result
+            return ret
 
         # Update what is retrieved from the database with data which is pending
         # insertion, as if it has already been stored in the database.
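The batch transaction above folds one `simple_upsert_txn` per row into a single `simple_upsert_many_txn`, i.e. one multi-row INSERT ... ON CONFLICT DO UPDATE per flush instead of one statement per client IP. A rough standalone equivalent of the SQL that such a helper ends up issuing (table and column names from the diff; the exact SQL shape is an assumption, shown here with SQLite):

    import sqlite3

    def upsert_many_user_ips(conn, keys, values):
        # One statement for the whole batch instead of one upsert per row.
        sql = """
            INSERT INTO user_ips
                (user_id, access_token, ip, user_agent, device_id, last_seen)
            VALUES (?, ?, ?, ?, ?, ?)
            ON CONFLICT (user_id, access_token, ip)
            DO UPDATE SET user_agent = excluded.user_agent,
                          device_id = excluded.device_id,
                          last_seen = excluded.last_seen
        """
        conn.executemany(sql, [k + v for k, v in zip(keys, values)])

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE user_ips (user_id TEXT, access_token TEXT, ip TEXT,"
        " user_agent TEXT, device_id TEXT, last_seen INTEGER,"
        " UNIQUE (user_id, access_token, ip))"
    )
    upsert_many_user_ips(conn, [("@a:hs", "tok", "1.2.3.4")], [("UA", "DEV", 1000)])
    upsert_many_user_ips(conn, [("@a:hs", "tok", "1.2.3.4")], [("UA2", "DEV", 2000)])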
""" + rows_from_db = await self._get_user_ip_and_agents_from_database(user, since_ts) + + if not self._update_on_this_worker: + # Only the writing-worker has additional in-memory data to enhance + # the result + return rows_from_db + results: Dict[Tuple[str, str], LastConnectionInfo] = { (connection["access_token"], connection["ip"]): connection - for connection in await super().get_user_ip_and_agents(user, since_ts) + for connection in rows_from_db } # Overlay data that is pending insertion on top of the results from the diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 3b3a089b..dc8009b2 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -46,6 +46,7 @@ from synapse.types import JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache +from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -71,6 +72,55 @@ class DeviceWorkerStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) + device_list_max = self._device_list_id_gen.get_current_token() + device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict( + db_conn, + "device_lists_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=device_list_max, + limit=10000, + ) + self._device_list_stream_cache = StreamChangeCache( + "DeviceListStreamChangeCache", + min_device_list_id, + prefilled_cache=device_list_prefill, + ) + + ( + user_signature_stream_prefill, + user_signature_stream_list_id, + ) = self.db_pool.get_cache_dict( + db_conn, + "user_signature_stream", + entity_column="from_user_id", + stream_column="stream_id", + max_value=device_list_max, + limit=1000, + ) + self._user_signature_stream_cache = StreamChangeCache( + "UserSignatureStreamChangeCache", + user_signature_stream_list_id, + prefilled_cache=user_signature_stream_prefill, + ) + + ( + device_list_federation_prefill, + device_list_federation_list_id, + ) = self.db_pool.get_cache_dict( + db_conn, + "device_lists_outbound_pokes", + entity_column="destination", + stream_column="stream_id", + max_value=device_list_max, + limit=10000, + ) + self._device_list_federation_stream_cache = StreamChangeCache( + "DeviceListFederationStreamChangeCache", + device_list_federation_list_id, + prefilled_cache=device_list_federation_prefill, + ) + if hs.config.worker.run_background_tasks: self._clock.looping_call( self._prune_old_outbound_device_pokes, 60 * 60 * 1000 @@ -681,42 +731,64 @@ class DeviceWorkerStore(SQLBaseStore): return self._device_list_stream_cache.get_all_entities_changed(from_key) async def get_users_whose_devices_changed( - self, from_key: int, user_ids: Iterable[str] + self, + from_key: int, + user_ids: Optional[Iterable[str]] = None, + to_key: Optional[int] = None, ) -> Set[str]: """Get set of users whose devices have changed since `from_key` that are in the given list of user_ids. Args: - from_key: The device lists stream token - user_ids: The user IDs to query for devices. + from_key: The minimum device lists stream token to query device list changes for, + exclusive. + user_ids: If provided, only check if these users have changed their device lists. + Otherwise changes from all users are returned. 
@@ -681,42 +731,64 @@ class DeviceWorkerStore(SQLBaseStore):
         return self._device_list_stream_cache.get_all_entities_changed(from_key)
 
     async def get_users_whose_devices_changed(
-        self, from_key: int, user_ids: Iterable[str]
+        self,
+        from_key: int,
+        user_ids: Optional[Iterable[str]] = None,
+        to_key: Optional[int] = None,
     ) -> Set[str]:
         """Get set of users whose devices have changed since `from_key` that
         are in the given list of user_ids.
 
         Args:
-            from_key: The device lists stream token
-            user_ids: The user IDs to query for devices.
+            from_key: The minimum device lists stream token to query device list changes for,
+                exclusive.
+            user_ids: If provided, only check if these users have changed their device lists.
+                Otherwise changes from all users are returned.
+            to_key: The maximum device lists stream token to query device list changes for,
+                inclusive.
 
         Returns:
-            The set of user_ids whose devices have changed since `from_key`
+            The set of user_ids whose devices have changed since `from_key` (exclusive)
+                until `to_key` (inclusive).
         """
-
         # Get set of users who *may* have changed. Users not in the returned
         # list have definitely not changed.
-        to_check = self._device_list_stream_cache.get_entities_changed(
-            user_ids, from_key
-        )
+        if user_ids is None:
+            # Get set of all users that have had device list changes since 'from_key'
+            user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
+                from_key
+            )
+        else:
+            # The same as above, but filter results to only those users in 'user_ids'
+            user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
+                user_ids, from_key
+            )
 
-        if not to_check:
+        if not user_ids_to_check:
             return set()
 
         def _get_users_whose_devices_changed_txn(txn):
             changes = set()
 
-            sql = """
+            stream_id_where_clause = "stream_id > ?"
+            sql_args = [from_key]
+
+            if to_key:
+                stream_id_where_clause += " AND stream_id <= ?"
+                sql_args.append(to_key)
+
+            sql = f"""
                 SELECT DISTINCT user_id FROM device_lists_stream
-                WHERE stream_id > ?
+                WHERE {stream_id_where_clause}
                 AND
             """
 
-            for chunk in batch_iter(to_check, 100):
+            # Query device changes with a batch of users at a time
+            for chunk in batch_iter(user_ids_to_check, 100):
                 clause, args = make_in_list_sql_clause(
                     txn.database_engine, "user_id", chunk
                 )
-                txn.execute(sql + clause, (from_key,) + tuple(args))
+                txn.execute(sql + clause, sql_args + args)
                 changes.update(user_id for user_id, in txn)
 
             return changes
@@ -788,6 +860,7 @@ class DeviceWorkerStore(SQLBaseStore):
                 SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
             ) AS e
             WHERE ? < stream_id AND stream_id <= ?
+            ORDER BY stream_id ASC
             LIMIT ?
         """
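The rewritten query builds its WHERE clause dynamically: an exclusive lower bound is always present, an inclusive upper bound is appended only when `to_key` is given, and the candidate users are folded in as IN-list chunks of 100 to keep each statement bounded. The composition, reduced to a standalone sketch (table name from the diff; helper names here are illustrative):

    from typing import Iterable, List, Optional, Tuple

    def device_change_query(
        from_key: int, to_key: Optional[int], user_chunk: Iterable[str]
    ) -> Tuple[str, List[object]]:
        # Bounded stream window: exclusive lower bound, optional inclusive
        # upper bound.
        where = "stream_id > ?"
        args: List[object] = [from_key]
        if to_key:
            where += " AND stream_id <= ?"
            args.append(to_key)
        users = list(user_chunk)
        placeholders = ", ".join("?" for _ in users)
        sql = (
            "SELECT DISTINCT user_id FROM device_lists_stream"
            f" WHERE {where} AND user_id IN ({placeholders})"
        )
        return sql, args + users

    sql, args = device_change_query(10, 20, ["@a:hs", "@b:hs"])
    # -> ... WHERE stream_id > ? AND stream_id <= ? AND user_id IN (?, ?)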
@@ -1506,7 +1579,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         )
 
     async def add_device_change_to_streams(
-        self, user_id: str, device_ids: Collection[str], hosts: Collection[str]
+        self,
+        user_id: str,
+        device_ids: Collection[str],
+        hosts: Optional[Collection[str]],
+        room_ids: Collection[str],
     ) -> Optional[int]:
         """Persist that a user's devices have been updated, and which hosts
         (if any) should be poked.
@@ -1515,7 +1592,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             user_id: The ID of the user whose device changed.
             device_ids: The IDs of any changed devices. If empty, this function will
                 return None.
-            hosts: The remote destinations that should be notified of the change.
+            hosts: The remote destinations that should be notified of the change. If
+                None then the set of hosts have *not* been calculated, and will be
+                calculated later by a background task.
+            room_ids: The rooms that the user is in
 
         Returns:
             The maximum stream ID of device list updates that were added to the database, or
@@ -1524,34 +1604,62 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         if not device_ids:
             return None
 
-        async with self._device_list_id_gen.get_next_mult(
-            len(device_ids)
-        ) as stream_ids:
-            await self.db_pool.runInteraction(
-                "add_device_change_to_stream",
-                self._add_device_change_to_stream_txn,
+        context = get_active_span_text_map()
+
+        def add_device_changes_txn(
+            txn, stream_ids_for_device_change, stream_ids_for_outbound_pokes
+        ):
+            self._add_device_change_to_stream_txn(
+                txn,
                 user_id,
                 device_ids,
-                stream_ids,
+                stream_ids_for_device_change,
             )
 
-        if not hosts:
-            return stream_ids[-1]
+            self._add_device_outbound_room_poke_txn(
+                txn,
+                user_id,
+                device_ids,
+                room_ids,
+                stream_ids_for_device_change,
+                context,
+                hosts_have_been_calculated=hosts is not None,
+            )
 
-        context = get_active_span_text_map()
-        async with self._device_list_id_gen.get_next_mult(
-            len(hosts) * len(device_ids)
-        ) as stream_ids:
-            await self.db_pool.runInteraction(
-                "add_device_outbound_poke_to_stream",
-                self._add_device_outbound_poke_to_stream_txn,
+            # If the set of hosts to send to has not been calculated yet (and so
+            # `hosts` is None) or there are no `hosts` to send to, then skip
+            # trying to persist them to the DB.
+            if not hosts:
+                return
+
+            self._add_device_outbound_poke_to_stream_txn(
+                txn,
                 user_id,
                 device_ids,
                 hosts,
-                stream_ids,
+                stream_ids_for_outbound_pokes,
                 context,
             )
 
+        # `device_lists_stream` wants a stream ID per device update.
+        num_stream_ids = len(device_ids)
+
+        if hosts:
+            # `device_lists_outbound_pokes` wants a different stream ID for
+            # each row, which is a row per host per device update.
+            num_stream_ids += len(hosts) * len(device_ids)
+
+        async with self._device_list_id_gen.get_next_mult(num_stream_ids) as stream_ids:
+            stream_ids_for_device_change = stream_ids[: len(device_ids)]
+            stream_ids_for_outbound_pokes = stream_ids[len(device_ids) :]
+
+            await self.db_pool.runInteraction(
+                "add_device_change_to_stream",
+                add_device_changes_txn,
+                stream_ids_for_device_change,
+                stream_ids_for_outbound_pokes,
+            )
+
         return stream_ids[-1]
 
     def _add_device_change_to_stream_txn(
@@ -1595,7 +1703,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         user_id: str,
         device_ids: Iterable[str],
         hosts: Collection[str],
-        stream_ids: List[str],
+        stream_ids: List[int],
         context: Dict[str, str],
     ) -> None:
         for host in hosts:
@@ -1606,8 +1714,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             )
 
         now = self._clock.time_msec()
-        next_stream_id = iter(stream_ids)
+        stream_id_iterator = iter(stream_ids)
 
+        encoded_context = json_encoder.encode(context)
         self.db_pool.simple_insert_many_txn(
             txn,
             table="device_lists_outbound_pokes",
@@ -1623,16 +1732,146 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             values=[
                 (
                     destination,
-                    next(next_stream_id),
+                    next(stream_id_iterator),
                     user_id,
                     device_id,
                     False,
                     now,
-                    json_encoder.encode(context)
-                    if whitelisted_homeserver(destination)
-                    else "{}",
+                    encoded_context if whitelisted_homeserver(destination) else "{}",
                 )
                 for destination in hosts
                 for device_id in device_ids
             ],
         )
+
+    def _add_device_outbound_room_poke_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_ids: Iterable[str],
+        room_ids: Collection[str],
+        stream_ids: List[str],
+        context: Dict[str, str],
+        hosts_have_been_calculated: bool,
+    ) -> None:
+        """Record the user in the room has updated their device.
+
+        Args:
+            hosts_have_been_calculated: True if `device_lists_outbound_pokes`
+                has been updated already with the updates.
+        """
+
+        # We only need to convert to outbound pokes if they are our user.
+        converted_to_destinations = (
+            hosts_have_been_calculated or not self.hs.is_mine_id(user_id)
+        )
+
+        encoded_context = json_encoder.encode(context)
+
+        # The `device_lists_changes_in_room.stream_id` column matches the
+        # corresponding `stream_id` of the update in the `device_lists_stream`
+        # table, i.e. all rows persisted for the same device update will have
+        # the same `stream_id` (but different room IDs).
+        self.db_pool.simple_insert_many_txn(
+            txn,
+            table="device_lists_changes_in_room",
+            keys=(
+                "user_id",
+                "device_id",
+                "room_id",
+                "stream_id",
+                "converted_to_destinations",
+                "opentracing_context",
+            ),
+            values=[
+                (
+                    user_id,
+                    device_id,
+                    room_id,
+                    stream_id,
+                    converted_to_destinations,
+                    encoded_context,
+                )
+                for room_id in room_ids
+                for device_id, stream_id in zip(device_ids, stream_ids)
+            ],
+        )
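The rewrite allocates one contiguous block of stream IDs up front and slices it: the first len(device_ids) IDs serve `device_lists_stream` and are reused by every `device_lists_changes_in_room` row for the same device update (the zip above), while the remainder fan out one per (host, device) outbound poke. Schematically, with hypothetical sizes:

    def split_stream_ids(stream_ids, device_ids, hosts):
        # First len(device_ids) IDs are for the per-device change rows;
        # the rest are for per-host-per-device outbound pokes.
        assert len(stream_ids) == len(device_ids) + len(hosts) * len(device_ids)
        return stream_ids[: len(device_ids)], stream_ids[len(device_ids):]

    device_stream_ids, poke_stream_ids = split_stream_ids(
        stream_ids=list(range(100, 106)), device_ids=["D1", "D2"], hosts=["hs1", "hs2"]
    )
    # Every room row for device D1 reuses device_stream_ids[0]:
    rows = [("D1", room, device_stream_ids[0]) for room in ("!r1:hs", "!r2:hs")]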
+ """ + + def get_uncoverted_outbound_room_pokes_txn(txn): + txn.execute(sql, (limit,)) + return txn.fetchall() + + return await self.db_pool.runInteraction( + "get_uncoverted_outbound_room_pokes", get_uncoverted_outbound_room_pokes_txn + ) + + async def add_device_list_outbound_pokes( + self, + user_id: str, + device_id: str, + room_id: str, + stream_id: int, + hosts: Collection[str], + context: Optional[Dict[str, str]], + ) -> None: + """Queue the device update to be sent to the given set of hosts, + calculated from the room ID. + + Marks the associated row in `device_lists_changes_in_room` as handled. + """ + + def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]): + if hosts: + self._add_device_outbound_poke_to_stream_txn( + txn, + user_id=user_id, + device_ids=[device_id], + hosts=hosts, + stream_ids=stream_ids, + context=context, + ) + + self.db_pool.simple_update_txn( + txn, + table="device_lists_changes_in_room", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "stream_id": stream_id, + "room_id": room_id, + }, + updatevalues={"converted_to_destinations": True}, + ) + + if not hosts: + # If there are no hosts then we don't try and generate stream IDs. + return await self.db_pool.runInteraction( + "add_device_list_outbound_pokes", + add_device_list_outbound_pokes_txn, + [], + ) + + async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: + return await self.db_pool.runInteraction( + "add_device_list_outbound_pokes", + add_device_list_outbound_pokes_txn, + stream_ids, + ) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index d2532431..3fcd5f5b 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -197,12 +197,10 @@ class PersistEventsStore: ) persist_event_counter.inc(len(events_and_contexts)) - if stream < 0: - # backfilled events have negative stream orderings, so we don't - # want to set the event_persisted_position to that. - synapse.metrics.event_persisted_position.set( - events_and_contexts[-1][0].internal_metadata.stream_ordering - ) + if not use_negative_stream_ordering: + # we don't want to set the event_persisted_position to a negative + # stream_ordering. 
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index d2532431..3fcd5f5b 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -197,12 +197,10 @@ class PersistEventsStore:
             )
             persist_event_counter.inc(len(events_and_contexts))
 
-            if stream < 0:
-                # backfilled events have negative stream orderings, so we don't
-                # want to set the event_persisted_position to that.
-                synapse.metrics.event_persisted_position.set(
-                    events_and_contexts[-1][0].internal_metadata.stream_ordering
-                )
+            if not use_negative_stream_ordering:
+                # we don't want to set the event_persisted_position to a
+                # negative stream_ordering.
+                synapse.metrics.event_persisted_position.set(stream)
 
             for event, context in events_and_contexts:
                 if context.app_service:
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 59454a47..a60e3f4f 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -22,7 +22,6 @@ from typing import (
     Dict,
     Iterable,
     List,
-    NoReturn,
     Optional,
     Set,
     Tuple,
@@ -1330,10 +1329,9 @@ class EventsWorkerStore(SQLBaseStore):
         return results
 
     @cached(max_entries=100000, tree=True)
-    async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
-        # this only exists for the benefit of the @cachedList descriptor on
-        # _have_seen_events_dict
-        raise NotImplementedError()
+    async def have_seen_event(self, room_id: str, event_id: str) -> bool:
+        res = await self._have_seen_events_dict(((room_id, event_id),))
+        return res[(room_id, event_id)]
 
     def _get_current_state_event_counts_txn(
         self, txn: LoggingTransaction, room_id: str
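In events_worker.py, `have_seen_event` stops being a `NoReturn` stub that existed only to give the `@cachedList` descriptor a cache to attach to, and now actually answers by delegating to the batched `_have_seen_events_dict` lookup, keeping the single-key and batched paths consistent. The pattern, reduced to plain functions (a sketch; Synapse's cache descriptors and the real SELECT are not reproduced):

    from typing import Dict, Iterable, Tuple

    def have_seen_events_dict(
        keys: Iterable[Tuple[str, str]]
    ) -> Dict[Tuple[str, str], bool]:
        # Batched lookup: one query answers many (room_id, event_id) pairs.
        seen = {("!room:hs", "$event1")}  # stand-in for a SELECT over `events`
        return {key: key in seen for key in keys}

    def have_seen_event(room_id: str, event_id: str) -> bool:
        # Single-key wrapper: ask the batch function for exactly one key.
        res = have_seen_events_dict(((room_id, event_id),))
        return res[(room_id, event_id)]

    assert have_seen_event("!room:hs", "$event1") is True
    assert have_seen_event("!room:hs", "$other") is False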
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 21662296..4f1c22c7 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -15,7 +15,6 @@ import logging
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
@@ -36,7 +35,7 @@
 LAST_SEEN_GRANULARITY = 60 * 60 * 1000
 
-class MonthlyActiveUsersWorkerStore(SQLBaseStore):
+class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -47,9 +46,30 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
         self._clock = hs.get_clock()
         self.hs = hs
 
+        if hs.config.redis.redis_enabled:
+            # If we're using Redis, we can shift this update process off to
+            # the background worker
+            self._update_on_this_worker = hs.config.worker.run_background_tasks
+        else:
+            # If we're NOT using Redis, this must be handled by the master
+            self._update_on_this_worker = hs.get_instance_name() == "master"
+
         self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau
         self._max_mau_value = hs.config.server.max_mau_value
 
+        self._mau_stats_only = hs.config.server.mau_stats_only
+
+        if self._update_on_this_worker:
+            # Do not add more reserved users than the total allowable number
+            self.db_pool.new_transaction(
+                db_conn,
+                "initialise_mau_threepids",
+                [],
+                [],
+                self._initialise_reserved_users,
+                hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
+            )
+
     @cached(num_args=0)
     async def get_monthly_active_count(self) -> int:
         """Generates current count of monthly active users
@@ -222,28 +242,6 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
             "reap_monthly_active_users", _reap_users, reserved_users
         )
 
-
-class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore):
-    def __init__(
-        self,
-        database: DatabasePool,
-        db_conn: LoggingDatabaseConnection,
-        hs: "HomeServer",
-    ):
-        super().__init__(database, db_conn, hs)
-
-        self._mau_stats_only = hs.config.server.mau_stats_only
-
-        # Do not add more reserved users than the total allowable number
-        self.db_pool.new_transaction(
-            db_conn,
-            "initialise_mau_threepids",
-            [],
-            [],
-            self._initialise_reserved_users,
-            hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
-        )
-
     def _initialise_reserved_users(
         self, txn: LoggingTransaction, threepids: List[dict]
     ) -> None:
@@ -254,6 +252,9 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
             txn:
             threepids: List of threepid dicts to reserve
         """
+        assert (
+            self._update_on_this_worker
+        ), "This worker is not designated to update MAUs"
 
         # XXX what is this function trying to achieve? It upserts into
         # monthly_active_users for each *registered* reserved mau user, but why?
@@ -287,6 +288,10 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
         Args:
             user_id: user to add/update
         """
+        assert (
+            self._update_on_this_worker
+        ), "This worker is not designated to update MAUs"
+
         # Support user never to be included in MAU stats. Note I can't easily call this
         # from upsert_monthly_active_user_txn because then I need a _txn form of
         # is_support_user which is complicated because I want to cache the result.
@@ -322,6 +327,9 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
             txn (cursor):
             user_id (str): user to add/update
         """
+        assert (
+            self._update_on_this_worker
+        ), "This worker is not designated to update MAUs"
 
         # Am consciously deciding to lock the table on the basis that is ought
         # never be a big table and alternative approaches (batching multiple
@@ -349,6 +357,10 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
         Args:
            user_id(str): the user_id to query
         """
+        assert (
+            self._update_on_this_worker
+        ), "This worker is not designated to update MAUs"
+
         if self._limit_usage_by_mau or self._mau_stats_only:
             # Trial users and guests should not be included as part of MAU group
             is_guest = await self.is_guest(user_id)  # type: ignore[attr-defined]
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index e6f97aee..332e901d 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -98,8 +98,19 @@ class ReceiptsWorkerStore(SQLBaseStore):
         super().__init__(database, db_conn, hs)
 
+        max_receipts_stream_id = self.get_max_receipt_stream_id()
+        receipts_stream_prefill, min_receipts_stream_id = self.db_pool.get_cache_dict(
+            db_conn,
+            "receipts_linearized",
+            entity_column="room_id",
+            stream_column="stream_id",
+            max_value=max_receipts_stream_id,
+            limit=10000,
+        )
         self._receipts_stream_cache = StreamChangeCache(
-            "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
+            "ReceiptsRoomChangeCache",
+            min_receipts_stream_id,
+            prefilled_cache=receipts_stream_prefill,
         )
 
     def get_max_receipt_stream_id(self) -> int:
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 7f3d190e..d43163c2 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -34,7 +34,7 @@
 from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import IdGenerator
 from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import UserID, UserInfo
+from synapse.types import JsonDict, UserID, UserInfo
 from synapse.util.caches.descriptors import cached
 
 if TYPE_CHECKING:
@@ -79,7 +79,7 @@ class TokenLookupResult:
 
     # Make the token owner default to the user ID, which is the common case.
     @token_owner.default
-    def _default_token_owner(self):
+    def _default_token_owner(self) -> str:
         return self.user_id
 
@@ -299,7 +299,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             the account.
         """
 
-        def set_account_validity_for_user_txn(txn):
+        def set_account_validity_for_user_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_update_txn(
                 txn=txn,
                 table="account_validity",
@@ -385,23 +385,25 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             desc="get_renewal_token_for_user",
         )
 
-    async def get_users_expiring_soon(self) -> List[Dict[str, Any]]:
+    async def get_users_expiring_soon(self) -> List[Tuple[str, int]]:
         """Selects users whose account will expire in the [now, now + renew_at] time
         window (see configuration for account_validity for information on what renew_at
         refers to).
 
         Returns:
-            A list of dictionaries, each with a user ID and expiration time (in milliseconds).
+            A list of tuples, each with a user ID and expiration time (in milliseconds).
         """
 
-        def select_users_txn(txn, now_ms, renew_at):
+        def select_users_txn(
+            txn: LoggingTransaction, now_ms: int, renew_at: int
+        ) -> List[Tuple[str, int]]:
             sql = (
                 "SELECT user_id, expiration_ts_ms FROM account_validity"
                 " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
             )
             values = [False, now_ms, renew_at]
             txn.execute(sql, values)
-            return self.db_pool.cursor_to_dict(txn)
+            return cast(List[Tuple[str, int]], txn.fetchall())
 
         return await self.db_pool.runInteraction(
             "get_users_expiring_soon",
@@ -466,7 +468,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             admin: true iff the user is to be a server admin, false otherwise.
         """
 
-        def set_server_admin_txn(txn):
+        def set_server_admin_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_update_one_txn(
                 txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0}
             )
@@ -515,7 +517,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             user_type: type of the user or None for a user without a type.
         """
 
-        def set_user_type_txn(txn):
+        def set_user_type_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_update_one_txn(
                 txn, "users", {"name": user.to_string()}, {"user_type": user_type}
             )
@@ -525,7 +527,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         await self.db_pool.runInteraction("set_user_type", set_user_type_txn)
 
-    def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
+    def _query_for_auth(
+        self, txn: LoggingTransaction, token: str
+    ) -> Optional[TokenLookupResult]:
         sql = """
             SELECT users.name as user_id,
                 users.is_guest,
@@ -582,7 +586,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             "is_support_user", self.is_support_user_txn, user_id
         )
 
-    def is_real_user_txn(self, txn, user_id):
+    def is_real_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool:
         res = self.db_pool.simple_select_one_onecol_txn(
             txn=txn,
             table="users",
@@ -592,7 +596,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         )
         return res is None
 
-    def is_support_user_txn(self, txn, user_id):
+    def is_support_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool:
         res = self.db_pool.simple_select_one_onecol_txn(
             txn=txn,
             table="users",
@@ -609,10 +613,11 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             A mapping of user_id -> password_hash.
""" - def f(txn): + def f(txn: LoggingTransaction) -> Dict[str, str]: sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)" txn.execute(sql, (user_id,)) - return dict(txn) + result = cast(List[Tuple[str, str]], txn.fetchall()) + return dict(result) return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f) @@ -734,7 +739,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): def _replace_user_external_id_txn( txn: LoggingTransaction, - ): + ) -> None: _remove_user_external_ids_txn(txn, user_id) for auth_provider, external_id in record_external_ids: @@ -790,10 +795,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ) return [(r["auth_provider"], r["external_id"]) for r in res] - async def count_all_users(self): + async def count_all_users(self) -> int: """Counts all users registered on the homeserver.""" - def _count_users(txn): + def _count_users(txn: LoggingTransaction) -> int: txn.execute("SELECT COUNT(*) AS users FROM users") rows = self.db_pool.cursor_to_dict(txn) if rows: @@ -810,7 +815,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): who registered on the homeserver in the past 24 hours """ - def _count_daily_user_type(txn): + def _count_daily_user_type(txn: LoggingTransaction) -> Dict[str, int]: yesterday = int(self._clock.time()) - (60 * 60 * 24) sql = """ @@ -835,23 +840,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "count_daily_user_type", _count_daily_user_type ) - async def count_nonbridged_users(self): - def _count_users(txn): + async def count_nonbridged_users(self) -> int: + def _count_users(txn: LoggingTransaction) -> int: txn.execute( """ SELECT COUNT(*) FROM users WHERE appservice_id IS NULL """ ) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count return await self.db_pool.runInteraction("count_users", _count_users) - async def count_real_users(self): + async def count_real_users(self) -> int: """Counts all users without a special user_type registered on the homeserver.""" - def _count_users(txn): + def _count_users(txn: LoggingTransaction) -> int: txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null") rows = self.db_pool.cursor_to_dict(txn) if rows: @@ -888,7 +893,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return user_id def get_user_id_by_threepid_txn( - self, txn, medium: str, address: str + self, txn: LoggingTransaction, medium: str, address: str ) -> Optional[str]: """Returns user id from threepid @@ -925,7 +930,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, ) - async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]: + async def user_get_threepids(self, user_id: str) -> List[Dict[str, Any]]: return await self.db_pool.simple_select_list( "user_threepids", {"user_id": user_id}, @@ -957,7 +962,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): async def add_user_bound_threepid( self, user_id: str, medium: str, address: str, id_server: str - ): + ) -> None: """The server proxied a bind request to the given identity server on behalf of the given user. We need to remember this in case the user asks us to unbind the threepid. 
@@ -1116,7 +1121,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         assert address or sid
 
-        def get_threepid_validation_session_txn(txn):
+        def get_threepid_validation_session_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[Dict[str, Any]]:
             sql = """
                 SELECT address, session_id, medium, client_secret,
                 last_send_attempt, validated_at
@@ -1150,7 +1157,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             session_id: The ID of the session to delete
         """
 
-        def delete_threepid_session_txn(txn):
+        def delete_threepid_session_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_delete_txn(
                 txn,
                 table="threepid_validation_token",
@@ -1170,7 +1177,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     async def cull_expired_threepid_validation_tokens(self) -> None:
         """Remove threepid validation tokens with expiry dates that have passed"""
 
-        def cull_expired_threepid_validation_tokens_txn(txn, ts):
+        def cull_expired_threepid_validation_tokens_txn(
+            txn: LoggingTransaction, ts: int
+        ) -> None:
             sql = """
             DELETE FROM threepid_validation_token WHERE
             expires < ?
@@ -1184,13 +1193,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         )
 
     @wrap_as_background_process("account_validity_set_expiration_dates")
-    async def _set_expiration_date_when_missing(self):
+    async def _set_expiration_date_when_missing(self) -> None:
         """
         Retrieves the list of registered users that don't have an expiration date, and
         adds an expiration date for each of them.
         """
 
-        def select_users_with_no_expiration_date_txn(txn):
+        def select_users_with_no_expiration_date_txn(txn: LoggingTransaction) -> None:
             """Retrieves the list of registered users with no expiration date from the
             database, filtering out deactivated users.
             """
@@ -1213,7 +1222,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             select_users_with_no_expiration_date_txn,
         )
 
-    def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+    def set_expiration_date_for_user_txn(
+        self, txn: LoggingTransaction, user_id: str, use_delta: bool = False
+    ) -> None:
         """Sets an expiration date to the account with the given user ID.
 
         Args:
@@ -1344,7 +1355,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             token: The registration token pending use
         """
 
-        def _set_registration_token_pending_txn(txn):
+        def _set_registration_token_pending_txn(txn: LoggingTransaction) -> None:
             pending = self.db_pool.simple_select_one_onecol_txn(
                 txn,
                 "registration_tokens",
@@ -1358,7 +1369,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 updatevalues={"pending": pending + 1},
             )
 
-        return await self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "set_registration_token_pending", _set_registration_token_pending_txn
         )
 
@@ -1372,7 +1383,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             token: The registration token to be 'used'
         """
 
-        def _use_registration_token_txn(txn):
+        def _use_registration_token_txn(txn: LoggingTransaction) -> None:
             # Normally, res is Optional[Dict[str, Any]].
             # Override type because the return type is only optional if
             # allow_none is True, and we don't want mypy throwing errors
@@ -1398,7 +1409,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 },
             )
 
-        return await self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "use_registration_token", _use_registration_token_txn
         )
 
@@ -1416,7 +1427,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             A list of dicts, each containing details of a token.
""" - def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]): + def select_registration_tokens_txn( + txn: LoggingTransaction, now: int, valid: Optional[bool] + ) -> List[Dict[str, Any]]: if valid is None: # Return all tokens regardless of validity txn.execute("SELECT * FROM registration_tokens") @@ -1523,7 +1536,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): Whether the row was inserted or not. """ - def _create_registration_token_txn(txn): + def _create_registration_token_txn(txn: LoggingTransaction) -> bool: row = self.db_pool.simple_select_one_txn( txn, "registration_tokens", @@ -1570,7 +1583,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): A dict with all info about the token, or None if token doesn't exist. """ - def _update_registration_token_txn(txn): + def _update_registration_token_txn( + txn: LoggingTransaction, + ) -> Optional[Dict[str, Any]]: try: self.db_pool.simple_update_one_txn( txn, @@ -1651,7 +1666,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ) -> Optional[RefreshTokenLookupResult]: """Lookup a refresh token with hints about its validity.""" - def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]: + def _lookup_refresh_token_txn( + txn: LoggingTransaction, + ) -> Optional[RefreshTokenLookupResult]: txn.execute( """ SELECT @@ -1745,6 +1762,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "replace_refresh_token", _replace_refresh_token_txn ) + @cached() + async def is_guest(self, user_id: str) -> bool: + res = await self.db_pool.simple_select_one_onecol( + table="users", + keyvalues={"name": user_id}, + retcol="is_guest", + allow_none=True, + desc="is_guest", + ) + + return res if res else False + class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__( @@ -1795,14 +1824,18 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): unique=False, ) - async def _background_update_set_deactivated_flag(self, progress, batch_size): + async def _background_update_set_deactivated_flag( + self, progress: JsonDict, batch_size: int + ) -> int: """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 for each of them. 
""" last_user = progress.get("user_id", "") - def _background_update_set_deactivated_flag_txn(txn): + def _background_update_set_deactivated_flag_txn( + txn: LoggingTransaction, + ) -> Tuple[bool, int]: txn.execute( """ SELECT @@ -1874,7 +1907,9 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): deactivated, ) - def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool): + def set_user_deactivated_status_txn( + self, txn: LoggingTransaction, user_id: str, deactivated: bool + ) -> None: self.db_pool.simple_update_one_txn( txn=txn, table="users", @@ -1887,18 +1922,6 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,)) - @cached() - async def is_guest(self, user_id: str) -> bool: - res = await self.db_pool.simple_select_one_onecol( - table="users", - keyvalues={"name": user_id}, - retcol="is_guest", - allow_none=True, - desc="is_guest", - ) - - return res if res else False - class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): def __init__( @@ -2005,7 +2028,9 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): return next_id - def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str: + def _set_device_for_access_token_txn( + self, txn: LoggingTransaction, token: str, device_id: str + ) -> str: old_device_id = self.db_pool.simple_select_one_onecol_txn( txn, "access_tokens", {"token": token}, "device_id" ) @@ -2084,7 +2109,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): def _register_user( self, - txn, + txn: LoggingTransaction, user_id: str, password_hash: Optional[str], was_guest: bool, @@ -2094,7 +2119,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): admin: bool, user_type: Optional[str], shadow_banned: bool, - ): + ) -> None: user_id_obj = UserID.from_string(user_id) now = int(self._clock.time()) @@ -2181,7 +2206,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): pointless. Use flush_user separately. 
""" - def user_set_password_hash_txn(txn): + def user_set_password_hash_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_update_one_txn( txn, "users", {"name": user_id}, {"password_hash": password_hash} ) @@ -2204,7 +2229,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): StoreError(404) if user not found """ - def f(txn): + def f(txn: LoggingTransaction) -> None: self.db_pool.simple_update_one_txn( txn, table="users", @@ -2229,7 +2254,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): StoreError(404) if user not found """ - def f(txn): + def f(txn: LoggingTransaction) -> None: self.db_pool.simple_update_one_txn( txn, table="users", @@ -2259,7 +2284,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): A tuple of (token, token id, device id) for each of the deleted tokens """ - def f(txn): + def f(txn: LoggingTransaction) -> List[Tuple[str, int, Optional[str]]]: keyvalues = {"user_id": user_id} if device_id is not None: keyvalues["device_id"] = device_id @@ -2301,7 +2326,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): return await self.db_pool.runInteraction("user_delete_access_tokens", f) async def delete_access_token(self, access_token: str) -> None: - def f(txn): + def f(txn: LoggingTransaction) -> None: self.db_pool.simple_delete_one_txn( txn, table="access_tokens", keyvalues={"token": access_token} ) @@ -2313,7 +2338,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): await self.db_pool.runInteraction("delete_access_token", f) async def delete_refresh_token(self, refresh_token: str) -> None: - def f(txn): + def f(txn: LoggingTransaction) -> None: self.db_pool.simple_delete_one_txn( txn, table="refresh_tokens", keyvalues={"token": refresh_token} ) @@ -2353,7 +2378,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): """ # Insert everything into a transaction in order to run atomically - def validate_threepid_session_txn(txn): + def validate_threepid_session_txn(txn: LoggingTransaction) -> Optional[str]: row = self.db_pool.simple_select_one_txn( txn, table="threepid_validation_session", @@ -2450,7 +2475,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): longer be valid """ - def start_or_continue_validation_session_txn(txn): + def start_or_continue_validation_session_txn(txn: LoggingTransaction) -> None: # Create or update a validation session self.db_pool.simple_upsert_txn( txn, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index b2295fd5..407158ce 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -17,6 +17,7 @@ from typing import ( TYPE_CHECKING, Collection, Dict, + FrozenSet, Iterable, List, Optional, @@ -39,8 +40,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.engines import PostgresEngine -from synapse.storage.relations import AggregationPaginationToken, PaginationChunk -from synapse.types import RoomStreamToken, StreamToken +from synapse.types import JsonDict, RoomStreamToken, StreamToken from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: @@ -49,6 +49,19 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _RelatedEvent: + """ + Contains enough information about a related event in order 
to properly filter + events from ignored users. + """ + + # The event ID of the related event. + event_id: str + # The sender of the related event. + sender: str + + class RelationsWorkerStore(SQLBaseStore): def __init__( self, @@ -73,7 +86,7 @@ class RelationsWorkerStore(SQLBaseStore): direction: str = "b", from_token: Optional[StreamToken] = None, to_token: Optional[StreamToken] = None, - ) -> PaginationChunk: + ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: """Get a list of relations for an event, ordered by topological ordering. Args: @@ -90,8 +103,10 @@ class RelationsWorkerStore(SQLBaseStore): to_token: Fetch rows up to the given token, or up to the end if None. Returns: - List of event IDs that match relations requested. The rows are of - the form `{"event_id": "..."}`. + A tuple of: + A list of related event IDs & their senders. + + The next stream token, if one exists. """ # We don't use `event_id`, it's there so that we can cache based on # it. The `event_id` must match the `event.event_id`. @@ -132,7 +147,7 @@ class RelationsWorkerStore(SQLBaseStore): order = "ASC" sql = """ - SELECT event_id, relation_type, topological_ordering, stream_ordering + SELECT event_id, relation_type, sender, topological_ordering, stream_ordering FROM event_relations INNER JOIN events USING (event_id) WHERE %s @@ -146,7 +161,7 @@ class RelationsWorkerStore(SQLBaseStore): def _get_recent_references_for_event_txn( txn: LoggingTransaction, - ) -> PaginationChunk: + ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: txn.execute(sql, where_args + [limit + 1]) last_topo_id = None @@ -156,9 +171,9 @@ class RelationsWorkerStore(SQLBaseStore): # Do not include edits for redacted events as they leak event # content. if not is_redacted or row[1] != RelationTypes.REPLACE: - events.append({"event_id": row[0]}) - last_topo_id = row[2] - last_stream_id = row[3] + events.append(_RelatedEvent(row[0], row[2])) + last_topo_id = row[3] + last_stream_id = row[4] # If there are more events, generate the next pagination key. next_token = None @@ -179,9 +194,7 @@ class RelationsWorkerStore(SQLBaseStore): groups_key=0, ) - return PaginationChunk( - chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token - ) + return events[:limit], next_token return await self.db_pool.runInteraction( "get_recent_references_for_event", _get_recent_references_for_event_txn @@ -252,15 +265,8 @@ class RelationsWorkerStore(SQLBaseStore): @cached(tree=True) async def get_aggregation_groups_for_event( - self, - event_id: str, - room_id: str, - event_type: Optional[str] = None, - limit: int = 5, - direction: str = "b", - from_token: Optional[AggregationPaginationToken] = None, - to_token: Optional[AggregationPaginationToken] = None, - ) -> PaginationChunk: + self, event_id: str, room_id: str, limit: int = 5 + ) -> List[JsonDict]: """Get a list of annotations on the event, grouped by event type and aggregation key, sorted by count. @@ -270,82 +276,96 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: Fetch events that relate to this event ID. room_id: The room the event belongs to. - event_type: Only fetch events with this event type, if given. limit: Only fetch the `limit` groups. - direction: Whether to fetch the highest count first (`"b"`) or - the lowest count first (`"f"`). - from_token: Fetch rows from the given token, or from the start if None. - to_token: Fetch rows up to the given token, or up to the end if None. Returns: List of groups of annotations that match. 
Each row is a dict with `type`, `key` and `count` fields. """ - where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"] - where_args: List[Union[str, int]] = [ + args = [ event_id, room_id, RelationTypes.ANNOTATION, + limit, ] - if event_type: - where_clause.append("type = ?") - where_args.append(event_type) + sql = """ + SELECT type, aggregation_key, COUNT(DISTINCT sender) + FROM event_relations + INNER JOIN events USING (event_id) + WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? + GROUP BY relation_type, type, aggregation_key + ORDER BY COUNT(*) DESC + LIMIT ? + """ - having_clause = generate_pagination_where_clause( - direction=direction, - column_names=("COUNT(*)", "MAX(stream_ordering)"), - from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type] - to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type] - engine=self.database_engine, + def _get_aggregation_groups_for_event_txn( + txn: LoggingTransaction, + ) -> List[JsonDict]: + txn.execute(sql, args) + + return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn] + + return await self.db_pool.runInteraction( + "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn ) - if direction == "b": - order = "DESC" - else: - order = "ASC" + async def get_aggregation_groups_for_users( + self, + event_id: str, + room_id: str, + limit: int, + users: FrozenSet[str] = frozenset(), + ) -> Dict[Tuple[str, str], int]: + """Fetch the partial aggregations for an event for specific users. - if having_clause: - having_clause = "HAVING " + having_clause - else: - having_clause = "" + This is used, in conjunction with get_aggregation_groups_for_event, to + remove information from the results for ignored users. - sql = """ - SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering) + Args: + event_id: Fetch events that relate to this event ID. + room_id: The room the event belongs to. + limit: Only fetch the `limit` groups. + users: The users to fetch information for. + + Returns: + A map of (event type, aggregation key) to a count of users. + """ + + if not users: + return {} + + args: List[Union[str, int]] = [ + event_id, + room_id, + RelationTypes.ANNOTATION, + ] + + users_sql, users_args = make_in_list_sql_clause( + self.database_engine, "sender", users + ) + args.extend(users_args) + + sql = f""" + SELECT type, aggregation_key, COUNT(DISTINCT sender) FROM event_relations INNER JOIN events USING (event_id) - WHERE {where_clause} + WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql} GROUP BY relation_type, type, aggregation_key - {having_clause} - ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order} + ORDER BY COUNT(*) DESC LIMIT ? 
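Taken together, `get_aggregation_groups_for_event` and `get_aggregation_groups_for_users` let a caller publish aggregate annotation counts with ignored users' contributions removed: the first query returns the public per-(type, key) counts, the second the counts attributable to the ignored senders. A rough pure-Python sketch of that subtraction step (the function below is illustrative, not Synapse's API; the hunk resumes after it):

    from typing import Any, Dict, List, Tuple

    JsonDict = Dict[str, Any]

    def apply_ignored_users(
        groups: List[JsonDict],
        ignored_counts: Dict[Tuple[str, str], int],
    ) -> List[JsonDict]:
        """Subtract ignored users' contributions from public annotation counts."""
        filtered = []
        for group in groups:
            key = (group["type"], group["key"])
            count = group["count"] - ignored_counts.get(key, 0)
            # Drop groups whose only annotators are ignored.
            if count > 0:
                filtered.append({**group, "count": count})
        return filtered

    # Two thumbs-up, one of them from an ignored user:
    groups = [{"type": "m.reaction", "key": "+1", "count": 2}]
    print(apply_ignored_users(groups, {("m.reaction", "+1"): 1}))
    # -> [{'type': 'm.reaction', 'key': '+1', 'count': 1}]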
- """.format( - where_clause=" AND ".join(where_clause), - order=order, - having_clause=having_clause, - ) + """ - def _get_aggregation_groups_for_event_txn( + def _get_aggregation_groups_for_users_txn( txn: LoggingTransaction, - ) -> PaginationChunk: - txn.execute(sql, where_args + [limit + 1]) + ) -> Dict[Tuple[str, str], int]: + txn.execute(sql, args + [limit]) - next_batch = None - events = [] - for row in txn: - events.append({"type": row[0], "key": row[1], "count": row[2]}) - next_batch = AggregationPaginationToken(row[2], row[3]) - - if len(events) <= limit: - next_batch = None - - return PaginationChunk( - chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token - ) + return {(row[0], row[1]): row[2] for row in txn} return await self.db_pool.runInteraction( - "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn + "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn ) @cached() @@ -574,6 +594,67 @@ class RelationsWorkerStore(SQLBaseStore): return summaries + async def get_threaded_messages_per_user( + self, + event_ids: Collection[str], + users: FrozenSet[str] = frozenset(), + ) -> Dict[Tuple[str, str], int]: + """Get the number of threaded replies for a set of users. + + This is used, in conjunction with get_thread_summaries, to calculate an + accurate count of the replies to a thread by subtracting ignored users. + + Args: + event_ids: The events to check for threaded replies. + users: The user to calculate the count of their replies. + + Returns: + A map of the (event_id, sender) to the count of their replies. + """ + if not users: + return {} + + # Fetch the number of threaded replies. + sql = """ + SELECT parent.event_id, child.sender, COUNT(child.event_id) FROM events AS child + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = child.room_id + WHERE + %s + AND %s + AND %s + GROUP BY parent.event_id, child.sender + """ + + def _get_threaded_messages_per_user_txn( + txn: LoggingTransaction, + ) -> Dict[Tuple[str, str], int]: + users_sql, users_args = make_in_list_sql_clause( + self.database_engine, "child.sender", users + ) + events_clause, events_args = make_in_list_sql_clause( + txn.database_engine, "relates_to_id", event_ids + ) + + if self._msc3440_enabled: + relations_clause = "(relation_type = ? OR relation_type = ?)" + relations_args = [RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD] + else: + relations_clause = "relation_type = ?" 
+                relations_args = [RelationTypes.THREAD]
+
+            txn.execute(
+                sql % (users_sql, events_clause, relations_clause),
+                users_args + events_args + relations_args,
+            )
+            return {(row[0], row[1]): row[2] for row in txn}
+
+        return await self.db_pool.runInteraction(
+            "get_threaded_messages_per_user", _get_threaded_messages_per_user_txn
+        )
+
     @cached()
     def get_thread_participated(self, event_id: str, user_id: str) -> bool:
         raise NotImplementedError()
@@ -661,7 +742,7 @@ class RelationsWorkerStore(SQLBaseStore):
             %s;
         """

-        def _get_if_events_have_relations(txn) -> List[str]:
+        def _get_if_events_have_relations(txn: LoggingTransaction) -> List[str]:
             clauses: List[str] = []
             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "relates_to_id", parent_ids
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 3248da53..48e83592 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -361,7 +361,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return None

     async def get_rooms_for_local_user_where_membership_is(
-        self, user_id: str, membership_list: Collection[str]
+        self,
+        user_id: str,
+        membership_list: Collection[str],
+        excluded_rooms: Optional[List[str]] = None,
     ) -> List[RoomsForUser]:
         """Get all the rooms for this *local* user where the membership for this
         user matches one in the membership list.
@@ -372,6 +375,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             user_id: The user ID.
             membership_list: A list of synapse.api.constants.Membership values
                 which the user must be in.
+            excluded_rooms: A list of rooms to ignore.

         Returns:
             The RoomsForUser that the user matches the membership types.
@@ -386,12 +390,19 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             membership_list,
         )

-        # Now we filter out forgotten rooms
-        forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id)
-        return [room for room in rooms if room.room_id not in forgotten_rooms]
+        # Now we filter out forgotten and excluded rooms
+        rooms_to_exclude: Set[str] = await self.get_forgotten_rooms_for_user(user_id)
+
+        if excluded_rooms is not None:
+            rooms_to_exclude.update(set(excluded_rooms))
+
+        return [room for room in rooms if room.room_id not in rooms_to_exclude]

     def _get_rooms_for_local_user_where_membership_is_txn(
-        self, txn, user_id: str, membership_list: List[str]
+        self,
+        txn,
+        user_id: str,
+        membership_list: List[str],
     ) -> List[RoomsForUser]:
         # Paranoia check.
         if not self.hs.is_mine_id(user_id):
@@ -877,7 +888,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             return frozenset(cache.hosts_to_joined_users)

         # Since we'll mutate the cache we need to lock.
-        with (await self._joined_host_linearizer.queue(room_id)):
+        async with self._joined_host_linearizer.queue(room_id):
             if state_entry.state_group == cache.state_group:
                 # Same state group, so nothing to do.  We've already checked for
                 # this above, but the cache may have changed while waiting on
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index 0518b8b9..95148fd2 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -26,7 +26,7 @@ from synapse.util.caches.descriptors import cached, cachedList

 class SignatureWorkerStore(EventsWorkerStore):
     @cached()
-    def get_event_reference_hash(self, event_id):
+    def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]:
         # This is a dummy function to allow get_event_reference_hashes
         # to use its cache
         raise NotImplementedError()
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 28460fd3..ecdc1fdc 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -12,9 +12,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import collections.abc
 import logging
-from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
+
+from frozendict import frozendict

 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@@ -29,7 +30,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, StateMap
+from synapse.types import JsonDict, JsonMapping, StateMap
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedList
@@ -132,7 +133,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         return room_version

-    async def get_room_predecessor(self, room_id: str) -> Optional[dict]:
+    async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
         """Get the predecessor of an upgraded room if it exists.
         Otherwise return None.
@@ -158,9 +159,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         predecessor = create_event.content.get("predecessor", None)

         # Ensure the key is a dictionary
-        if not isinstance(predecessor, collections.abc.Mapping):
+        if not isinstance(predecessor, (dict, frozendict)):
             return None

+        # The keys must be strings since the data is JSON.
         return predecessor

     async def get_create_event_for_room(self, room_id: str) -> EventBase:
@@ -202,7 +204,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             The current state of the room.
         """

-        def _get_current_state_ids_txn(txn):
+        def _get_current_state_ids_txn(txn: LoggingTransaction) -> StateMap[str]:
             txn.execute(
                 """SELECT type, state_key, event_id FROM current_state_events
                 WHERE room_id = ?
@@ -306,8 +308,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         list_name="event_ids",
         num_args=1,
     )
-    async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict:
-        """Returns mapping event_id -> state_group"""
+    async def _get_state_group_for_events(
+        self, event_ids: Collection[str]
+    ) -> Dict[str, int]:
+        """Returns mapping event_id -> state_group.
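The roommember.py hunk above also switches from the deprecated `with (await linearizer.queue(...))` form to a plain async context manager. A toy linearizer with the same usage shape, assuming only the standard library (Synapse's real `Linearizer` lives in `synapse.util.async_helpers` and is richer than this):

    import asyncio
    from contextlib import asynccontextmanager
    from typing import AsyncIterator, Dict

    class ToyLinearizer:
        """Serialises work per key so only one coroutine runs at a time."""

        def __init__(self) -> None:
            self._locks: Dict[str, asyncio.Lock] = {}

        @asynccontextmanager
        async def queue(self, key: str) -> AsyncIterator[None]:
            lock = self._locks.setdefault(key, asyncio.Lock())
            async with lock:
                yield

    async def main() -> None:
        linearizer = ToyLinearizer()

        async def worker(n: int) -> None:
            # Matches the `async with self._joined_host_linearizer.queue(room_id)`
            # shape in the hunk above.
            async with linearizer.queue("!room:example.org"):
                print("worker", n, "holds the room lock")
                await asyncio.sleep(0)

        await asyncio.gather(*(worker(n) for n in range(3)))

    asyncio.run(main())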
+
+        Raises:
+            RuntimeError if the state is unknown at any of the given events
+        """
         rows = await self.db_pool.simple_select_many_batch(
             table="event_to_state_groups",
             column="event_id",
@@ -317,7 +325,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             desc="_get_state_group_for_events",
         )

-        return {row["event_id"]: row["state_group"] for row in rows}
+        res = {row["event_id"]: row["state_group"] for row in rows}
+        for e in event_ids:
+            if e not in res:
+                raise RuntimeError("No state group for unknown or outlier event %s" % e)
+        return res

     async def get_referenced_state_groups(
         self, state_groups: Iterable[int]
@@ -521,7 +533,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
             )

             for user_id in potentially_left_users - joined_users:
-                await self.mark_remote_user_device_list_as_unsubscribed(user_id)
+                await self.mark_remote_user_device_list_as_unsubscribed(user_id)  # type: ignore[attr-defined]

         return batch_size
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 39e1efe3..6d45a8a9 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -36,7 +36,17 @@ what sort order was used:
 """

 import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Collection,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    cast,
+)

 import attr
 from frozendict import frozendict
@@ -585,7 +595,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return ret, key

     async def get_membership_changes_for_user(
-        self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
+        self,
+        user_id: str,
+        from_key: RoomStreamToken,
+        to_key: RoomStreamToken,
+        excluded_rooms: Optional[List[str]] = None,
     ) -> List[EventBase]:
         """Fetch membership events for a given user.
@@ -610,23 +624,29 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             min_from_id = from_key.stream
             max_to_id = to_key.get_max_stream_pos()

+            args: List[Any] = [user_id, min_from_id, max_to_id]
+
+            ignore_room_clause = ""
+            if excluded_rooms is not None and len(excluded_rooms) > 0:
+                ignore_room_clause = "AND e.room_id NOT IN (%s)" % ",".join(
+                    "?" for _ in excluded_rooms
+                )
+                args = args + excluded_rooms
+
             sql = """
                 SELECT m.event_id, instance_name, topological_ordering, stream_ordering
                 FROM events AS e, room_memberships AS m
                 WHERE e.event_id = m.event_id
                     AND m.user_id = ?
                     AND e.stream_ordering > ? AND e.stream_ordering <= ?
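The excluded-rooms filter being built above generates a `NOT IN (...)` clause with one `?` per room, so the room IDs stay parameterised rather than being spliced into the SQL. A standalone sketch of the same pattern with `sqlite3` (table and column names here are illustrative; the SQL of the hunk resumes below with its `%s` placeholder):

    import sqlite3
    from typing import Any, List, Optional

    def membership_event_ids(
        conn: sqlite3.Connection,
        user_id: str,
        excluded_rooms: Optional[List[str]] = None,
    ) -> List[str]:
        args: List[Any] = [user_id]
        ignore_room_clause = ""
        if excluded_rooms:
            # One "?" per excluded room; the IDs travel in the args list.
            ignore_room_clause = "AND room_id NOT IN (%s)" % ",".join(
                "?" for _ in excluded_rooms
            )
            args += excluded_rooms
        sql = f"SELECT event_id FROM memberships WHERE user_id = ? {ignore_room_clause}"
        return [row[0] for row in conn.execute(sql, args)]

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE memberships (event_id TEXT, user_id TEXT, room_id TEXT)")
    conn.executemany(
        "INSERT INTO memberships VALUES (?, ?, ?)",
        [("$1", "@a:x", "!r1:x"), ("$2", "@a:x", "!r2:x")],
    )
    print(membership_event_ids(conn, "@a:x", excluded_rooms=["!r2:x"]))  # -> ['$1']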
+                    %s
                 ORDER BY e.stream_ordering ASC
-            """
-            txn.execute(
-                sql,
-                (
-                    user_id,
-                    min_from_id,
-                    max_to_id,
-                ),
+            """ % (
+                ignore_room_clause,
             )
+            txn.execute(sql, args)
+
             rows = [
                 _EventDictReturn(event_id, None, stream_ordering)
                 for event_id, instance_name, topological_ordering, stream_ordering in txn
@@ -722,7 +742,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             A tuple of (stream ordering, topological ordering, event_id)
         """

-        def _f(txn):
+        def _f(txn: LoggingTransaction) -> Optional[Tuple[int, int, str]]:
             sql = (
                 "SELECT stream_ordering, topological_ordering, event_id"
                 " FROM events"
                 " WHERE room_id = ? AND stream_ordering <= ?"
                 " AND NOT outlier"
                 " ORDER BY stream_ordering DESC"
                 " LIMIT 1"
             )
             txn.execute(sql, (room_id, stream_ordering))
-            return txn.fetchone()
+            return cast(Optional[Tuple[int, int, str]], txn.fetchone())

         return await self.db_pool.runInteraction(
             "get_room_event_before_stream_ordering", _f
         )

-    async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str:
-        """Returns the current token for rooms stream.
+    async def get_current_room_stream_token_for_room_id(
+        self, room_id: Optional[str] = None
+    ) -> RoomStreamToken:
+        """Returns the current position of the rooms stream.

-        By default, it returns the current global stream token. Specifying a
-        `room_id` causes it to return the current room specific topological
-        token.
+        By default, it returns a live token with the current global stream
+        token. Specifying a `room_id` causes it to return a historic token with
+        the room specific topological token.
         """
-        token = self.get_room_max_stream_ordering()
+        stream_ordering = self.get_room_max_stream_ordering()
         if room_id is None:
-            return "s%d" % (token,)
+            return RoomStreamToken(None, stream_ordering)
         else:
             topo = await self.db_pool.runInteraction(
                 "_get_max_topological_txn", self._get_max_topological_txn, room_id
             )
-            return "t%d-%d" % (topo, token)
+            return RoomStreamToken(topo, stream_ordering)

     def get_stream_id_for_event_txn(
         self,
@@ -827,7 +849,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
     @staticmethod
     def _set_before_and_after(
         events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
-    ):
+    ) -> None:
         """Inserts ordering information to events' internal metadata from
         the DB rows.
@@ -973,7 +995,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             the `current_id`).
         """

-        def get_all_new_events_stream_txn(txn):
+        def get_all_new_events_stream_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[int, List[str]]:
             sql = (
                 "SELECT e.stream_ordering, e.event_id"
                 " FROM events AS e"
@@ -1319,7 +1343,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
     async def get_id_for_instance(self, instance_name: str) -> int:
         """Get a unique, immutable ID that corresponds to the given Synapse worker instance."""

-        def _get_id_for_instance_txn(txn):
+        def _get_id_for_instance_txn(txn: LoggingTransaction) -> int:
             instance_id = self.db_pool.simple_select_one_onecol_txn(
                 txn,
                 table="instance_map",
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index c8e508a9..b0f5de67 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -97,7 +97,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
         )

         def get_tag_content(
-            txn: LoggingTransaction, tag_ids
+            txn: LoggingTransaction, tag_ids: List[Tuple[int, str, str]]
         ) -> List[Tuple[int, Tuple[str, str, str]]]:
             sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
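The stream.py hunk above replaces string tokens (`"s123"` for live positions, `"t45-123"` for historic, room-specific ones) with a structured `RoomStreamToken` whose topological part is `None` for live tokens. A toy model of that shape, assuming the `attrs` package, which the diff itself already uses; this mirrors `synapse.types.RoomStreamToken` only loosely, without its real parsing and serialisation:

    from typing import Optional

    import attr

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class ToyRoomStreamToken:
        topological: Optional[int]  # None => "live" token
        stream: int

        def __str__(self) -> str:
            # Reproduces the old string forms for comparison.
            if self.topological is None:
                return "s%d" % (self.stream,)
            return "t%d-%d" % (self.topological, self.stream)

    print(str(ToyRoomStreamToken(None, 123)))  # -> s123
    print(str(ToyRoomStreamToken(45, 123)))    # -> t45-123

Returning the structured token instead of a string lets callers compare and serialise positions explicitly rather than re-parsing `"s..."`/`"t...-..."` strings.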
             results = []
@@ -251,7 +251,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
         return self._account_data_id_gen.get_current_token()

     def _update_revision_txn(
-        self, txn, user_id: str, room_id: str, next_id: int
+        self, txn: LoggingTransaction, user_id: str, room_id: str, next_id: int
     ) -> None:
         """Update the latest revision of the tags for the given user and room.
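Finally, the state.py hunk earlier makes `_get_state_group_for_events` raise instead of silently returning a partial mapping when an event has no known state group. A minimal sketch of that "look up a batch, then fail loudly on gaps" pattern, with an in-memory stand-in for the `event_to_state_groups` table:

    from typing import Collection, Dict

    # Stand-in for the event_to_state_groups table.
    STATE_GROUPS: Dict[str, int] = {"$known:x": 42}

    def get_state_group_for_events(event_ids: Collection[str]) -> Dict[str, int]:
        res = {e: STATE_GROUPS[e] for e in event_ids if e in STATE_GROUPS}
        # Fail loudly for unknown or outlier events instead of returning
        # a partial mapping the caller might not notice.
        for e in event_ids:
            if e not in res:
                raise RuntimeError(
                    "No state group for unknown or outlier event %s" % e
                )
        return res

    print(get_state_group_for_events(["$known:x"]))  # -> {'$known:x': 42}
    # get_state_group_for_events(["$missing:x"]) would raise RuntimeError.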