diff options
author | Andrej Shadura <andrewsh@debian.org> | 2022-06-19 15:20:00 +0200 |
---|---|---|
committer | Andrej Shadura <andrewsh@debian.org> | 2022-06-19 15:21:39 +0200 |
commit | 734a8e556ce00029d9d7ab0fed73336d24fa91f3 (patch) | |
tree | b277733532b1b141d534133a4715a2fe765ab533 /synapse/storage/databases | |
parent | 7a966d08c8403bcff00ac636d977097602501a69 (diff) | |
parent | 6dc64c92c6991f09910f3e6db368e6eeb4b1981e (diff) |
Update upstream source from tag 'upstream/1.61.0'
Update to upstream version '1.61.0'
with Debian dir 5b9bb60cc861cbccd0027b7db7acf826071dc6a0
Diffstat (limited to 'synapse/storage/databases')
30 files changed, 1296 insertions, 2246 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 5895b892..11d9d16c 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -26,11 +26,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.stats import UserSortOrder from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine from synapse.storage.types import Cursor -from synapse.storage.util.id_generators import ( - IdGenerator, - MultiWriterIdGenerator, - StreamIdGenerator, -) +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import JsonDict, get_domain_from_id from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -155,12 +151,6 @@ class DataStore( ], ) - self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") - self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") - self._group_updates_id_gen = StreamIdGenerator( - db_conn, "local_group_updates", "stream_id" - ) - self._cache_id_gen: Optional[MultiWriterIdGenerator] if isinstance(self.database_engine, PostgresEngine): # We set the `writers` to an empty list here as we don't care about @@ -203,20 +193,6 @@ class DataStore( prefilled_cache=curr_state_delta_prefill, ) - _group_updates_prefill, min_group_updates_id = self.db_pool.get_cache_dict( - db_conn, - "local_group_updates", - entity_column="user_id", - stream_column="stream_id", - max_value=self._group_updates_id_gen.get_current_token(), - limit=1000, - ) - self._group_updates_stream_cache = StreamChangeCache( - "_group_updates_stream_cache", - min_group_updates_id, - prefilled_cache=_group_updates_prefill, - ) - self._stream_order_on_start = self.get_room_max_stream_ordering() self._min_stream_order_on_start = self.get_room_min_stream_ordering() diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 945707b0..e284454b 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -203,19 +203,29 @@ class ApplicationServiceTransactionWorkerStore( """Get the application service state. Args: - service: The service whose state to set. + service: The service whose state to get. Returns: - An ApplicationServiceState or none. + An ApplicationServiceState, or None if we have yet to attempt any + transactions to the AS. """ - result = await self.db_pool.simple_select_one( + # if we have created transactions for this AS but not yet attempted to send + # them, we will have a row in the table with state=NULL (recording the stream + # positions we have processed up to). + # + # On the other hand, if we have yet to create any transactions for this AS at + # all, then there will be no row for the AS. + # + # In either case, we return None to indicate "we don't yet know the state of + # this AS". + result = await self.db_pool.simple_select_one_onecol( "application_services_state", {"as_id": service.id}, - ["state"], + retcol="state", allow_none=True, desc="get_appservice_state", ) if result: - return ApplicationServiceState(result.get("state")) + return ApplicationServiceState(result) return None async def set_appservice_state( @@ -296,14 +306,6 @@ class ApplicationServiceTransactionWorkerStore( """ def _complete_appservice_txn(txn: LoggingTransaction) -> None: - # Set current txn_id for AS to 'txn_id' - self.db_pool.simple_upsert_txn( - txn, - "application_services_state", - {"as_id": service.id}, - {"last_txn": txn_id}, - ) - # Delete txn self.db_pool.simple_delete_txn( txn, @@ -452,16 +454,15 @@ class ApplicationServiceTransactionWorkerStore( % (stream_type,) ) - def set_appservice_stream_type_pos_txn(txn: LoggingTransaction) -> None: - stream_id_type = "%s_stream_id" % stream_type - txn.execute( - "UPDATE application_services_state SET %s = ? WHERE as_id=?" - % stream_id_type, - (pos, service.id), - ) - - await self.db_pool.runInteraction( - "set_appservice_stream_type_pos", set_appservice_stream_type_pos_txn + # this may be the first time that we're recording any state for this AS, so + # we don't yet know if a row for it exists; hence we have to upsert here. + await self.db_pool.simple_upsert( + table="application_services_state", + keyvalues={"as_id": service.id}, + values={f"{stream_type}_stream_id": pos}, + # no need to lock when emulating upsert: as_id is a unique key + lock=False, + desc="set_appservice_stream_type_pos", ) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index dd4e83a2..1653a6a9 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -57,6 +57,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._instance_name = hs.get_instance_name() + self.db_pool.updates.register_background_index_update( + update_name="cache_invalidation_index_by_instance", + index_name="cache_invalidation_stream_by_instance_instance_index", + table="cache_invalidation_stream_by_instance", + columns=("instance_name", "stream_id"), + psql_only=True, # The table is only on postgres DBs. + ) + async def get_all_updated_caches( self, instance_name: str, last_id: int, current_id: int, limit: int ) -> Tuple[List[Tuple[int, tuple]], int, bool]: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 2df4dd4e..d900064c 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -28,6 +28,7 @@ from typing import ( cast, ) +from synapse.api.constants import EduTypes from synapse.api.errors import Codes, StoreError from synapse.logging.opentracing import ( get_active_span_text_map, @@ -419,7 +420,7 @@ class DeviceWorkerStore(SQLBaseStore): # Add the updated cross-signing keys to the results list for user_id, result in cross_signing_keys_by_user.items(): result["user_id"] = user_id - results.append(("m.signing_key_update", result)) + results.append((EduTypes.SIGNING_KEY_UPDATE, result)) # also send the unstable version # FIXME: remove this when enough servers have upgraded # and remove the length budgeting above. @@ -545,7 +546,7 @@ class DeviceWorkerStore(SQLBaseStore): else: result["deleted"] = True - results.append(("m.device_list_update", result)) + results.append((EduTypes.DEVICE_LIST_UPDATE, result)) return results @@ -1153,6 +1154,45 @@ class DeviceWorkerStore(SQLBaseStore): _prune_txn, ) + async def get_local_devices_not_accessed_since( + self, since_ms: int + ) -> Dict[str, List[str]]: + """Retrieves local devices that haven't been accessed since a given date. + + Args: + since_ms: the timestamp to select on, every device with a last access date + from before that time is returned. + + Returns: + A dictionary with an entry for each user with at least one device matching + the request, which value is a list of the device ID(s) for the corresponding + device(s). + """ + + def get_devices_not_accessed_since_txn( + txn: LoggingTransaction, + ) -> List[Dict[str, str]]: + sql = """ + SELECT user_id, device_id + FROM devices WHERE last_seen < ? AND hidden = FALSE + """ + txn.execute(sql, (since_ms,)) + return self.db_pool.cursor_to_dict(txn) + + rows = await self.db_pool.runInteraction( + "get_devices_not_accessed_since", + get_devices_not_accessed_since_txn, + ) + + devices: Dict[str, List[str]] = {} + for row in rows: + # Remote devices are never stale from our point of view. + if self.hs.is_mine_id(row["user_id"]): + user_devices = devices.setdefault(row["user_id"], []) + user_devices.append(row["device_id"]) + + return devices + class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__( diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index b789a588..af59be6b 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -21,7 +21,7 @@ from synapse.api.errors import StoreError from synapse.logging.opentracing import log_kv, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import LoggingTransaction -from synapse.types import JsonDict, JsonSerializable +from synapse.types import JsonDict, JsonSerializable, StreamKeyType from synapse.util import json_encoder @@ -126,7 +126,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): "message": "Set room key", "room_id": room_id, "session_id": session_id, - "room_key": room_key, + StreamKeyType.ROOM: room_key, } ) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 47102247..eec55b64 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -14,7 +14,17 @@ import itertools import logging from queue import Empty, PriorityQueue -from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + cast, +) import attr from prometheus_client import Counter, Gauge @@ -33,7 +43,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.storage.engines import PostgresEngine -from synapse.storage.types import Cursor +from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache @@ -135,7 +145,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # Check if we have indexed the room so we can use the chain cover # algorithm. - room = await self.get_room(room_id) + room = await self.get_room(room_id) # type: ignore[attr-defined] if room["has_auth_chain_index"]: try: return await self.db_pool.runInteraction( @@ -158,7 +168,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ) def _get_auth_chain_ids_using_cover_index_txn( - self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool + self, + txn: LoggingTransaction, + room_id: str, + event_ids: Collection[str], + include_given: bool, ) -> Set[str]: """Calculates the auth chain IDs using the chain index.""" @@ -215,9 +229,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas chains: Dict[int, int] = {} # Add all linked chains reachable from initial set of chains. - for batch in batch_iter(event_chains, 1000): + for batch2 in batch_iter(event_chains, 1000): clause, args = make_in_list_sql_clause( - txn.database_engine, "origin_chain_id", batch + txn.database_engine, "origin_chain_id", batch2 ) txn.execute(sql % (clause,), args) @@ -297,7 +311,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas front = set(event_ids) while front: - new_front = set() + new_front: Set[str] = set() for chunk in batch_iter(front, 100): # Pull the auth events either from the cache or DB. to_fetch: List[str] = [] # Event IDs to fetch from DB @@ -316,7 +330,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # Note we need to batch up the results by event ID before # adding to the cache. - to_cache = {} + to_cache: Dict[str, List[Tuple[str, int]]] = {} for event_id, auth_event_id, auth_event_depth in txn: to_cache.setdefault(event_id, []).append( (auth_event_id, auth_event_depth) @@ -349,7 +363,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # Check if we have indexed the room so we can use the chain cover # algorithm. - room = await self.get_room(room_id) + room = await self.get_room(room_id) # type: ignore[attr-defined] if room["has_auth_chain_index"]: try: return await self.db_pool.runInteraction( @@ -370,7 +384,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ) def _get_auth_chain_difference_using_cover_index_txn( - self, txn: Cursor, room_id: str, state_sets: List[Set[str]] + self, txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]] ) -> Set[str]: """Calculates the auth chain difference using the chain index. @@ -444,9 +458,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # (We need to take a copy of `seen_chains` as we want to mutate it in # the loop) - for batch in batch_iter(set(seen_chains), 1000): + for batch2 in batch_iter(set(seen_chains), 1000): clause, args = make_in_list_sql_clause( - txn.database_engine, "origin_chain_id", batch + txn.database_engine, "origin_chain_id", batch2 ) txn.execute(sql % (clause,), args) @@ -529,7 +543,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas return result def _get_auth_chain_difference_txn( - self, txn, state_sets: List[Set[str]] + self, txn: LoggingTransaction, state_sets: List[Set[str]] ) -> Set[str]: """Calculates the auth chain difference using a breadth first search. @@ -602,7 +616,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # I think building a temporary list with fetchall is more efficient than # just `search.extend(txn)`, but this is unconfirmed - search.extend(txn.fetchall()) + search.extend(cast(List[Tuple[int, str]], txn.fetchall())) # sort by depth search.sort() @@ -645,7 +659,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # We parse the results and add the to the `found` set and the # cache (note we need to batch up the results by event ID before # adding to the cache). - to_cache = {} + to_cache: Dict[str, List[Tuple[str, int]]] = {} for event_id, auth_event_id, auth_event_depth in txn: to_cache.setdefault(event_id, []).append( (auth_event_id, auth_event_depth) @@ -696,7 +710,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas return {eid for eid, n in event_to_missing_sets.items() if n} async def get_oldest_event_ids_with_depth_in_room( - self, room_id + self, room_id: str ) -> List[Tuple[str, int]]: """Gets the oldest events(backwards extremities) in the room along with the aproximate depth. @@ -713,7 +727,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas List of (event_id, depth) tuples """ - def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id): + def get_oldest_event_ids_with_depth_in_room_txn( + txn: LoggingTransaction, room_id: str + ) -> List[Tuple[str, int]]: # Assemble a dictionary with event_id -> depth for the oldest events # we know of in the room. Backwards extremeties are the oldest # events we know of in the room but we only know of them because @@ -743,7 +759,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas txn.execute(sql, (room_id, False)) - return txn.fetchall() + return cast(List[Tuple[str, int]], txn.fetchall()) return await self.db_pool.runInteraction( "get_oldest_event_ids_with_depth_in_room", @@ -752,7 +768,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ) async def get_insertion_event_backward_extremities_in_room( - self, room_id + self, room_id: str ) -> List[Tuple[str, int]]: """Get the insertion events we know about that we haven't backfilled yet. @@ -768,7 +784,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas List of (event_id, depth) tuples """ - def get_insertion_event_backward_extremities_in_room_txn(txn, room_id): + def get_insertion_event_backward_extremities_in_room_txn( + txn: LoggingTransaction, room_id: str + ) -> List[Tuple[str, int]]: sql = """ SELECT b.event_id, MAX(e.depth) FROM insertion_events as i /* We only want insertion events that are also marked as backwards extremities */ @@ -780,7 +798,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas """ txn.execute(sql, (room_id,)) - return txn.fetchall() + return cast(List[Tuple[str, int]], txn.fetchall()) return await self.db_pool.runInteraction( "get_insertion_event_backward_extremities_in_room", @@ -788,7 +806,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas room_id, ) - async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]: + async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]: """Returns the event ID and depth for the event that has the max depth from a set of event IDs Args: @@ -817,7 +835,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas return max_depth_event_id, current_max_depth - async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[str, int]: + async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]: """Returns the event ID and depth for the event that has the min depth from a set of event IDs Args: @@ -865,7 +883,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id ) - def _get_prev_events_for_room_txn(self, txn, room_id: str): + def _get_prev_events_for_room_txn( + self, txn: LoggingTransaction, room_id: str + ) -> List[str]: # we just use the 10 newest events. Older events will become # prev_events of future events. @@ -896,7 +916,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas sorted by extremity count. """ - def _get_rooms_with_many_extremities_txn(txn): + def _get_rooms_with_many_extremities_txn(txn: LoggingTransaction) -> List[str]: where_clause = "1=1" if room_id_filter: where_clause = "room_id NOT IN (%s)" % ( @@ -937,7 +957,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas "get_min_depth", self._get_min_depth_interaction, room_id ) - def _get_min_depth_interaction(self, txn, room_id): + def _get_min_depth_interaction( + self, txn: LoggingTransaction, room_id: str + ) -> Optional[int]: min_depth = self.db_pool.simple_select_one_onecol_txn( txn, table="room_depth", @@ -966,22 +988,24 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas """ # We want to make the cache more effective, so we clamp to the last # change before the given ordering. - last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) + last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) # type: ignore[attr-defined] # We don't always have a full stream_to_exterm_id table, e.g. after # the upgrade that introduced it, so we make sure we never ask for a # stream_ordering from before a restart - last_change = max(self._stream_order_on_start, last_change) + last_change = max(self._stream_order_on_start, last_change) # type: ignore[attr-defined] # provided the last_change is recent enough, we now clamp the requested # stream_ordering to it. - if last_change > self.stream_ordering_month_ago: + if last_change > self.stream_ordering_month_ago: # type: ignore[attr-defined] stream_ordering = min(last_change, stream_ordering) return await self._get_forward_extremeties_for_room(room_id, stream_ordering) @cached(max_entries=5000, num_args=2) - async def _get_forward_extremeties_for_room(self, room_id, stream_ordering): + async def _get_forward_extremeties_for_room( + self, room_id: str, stream_ordering: int + ) -> List[str]: """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". @@ -989,7 +1013,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas stream_orderings from that point. """ - if stream_ordering <= self.stream_ordering_month_ago: + if stream_ordering <= self.stream_ordering_month_ago: # type: ignore[attr-defined] raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,)) sql = """ @@ -1002,7 +1026,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas WHERE room_id = ? """ - def get_forward_extremeties_for_room_txn(txn): + def get_forward_extremeties_for_room_txn(txn: LoggingTransaction) -> List[str]: txn.execute(sql, (stream_ordering, room_id)) return [event_id for event_id, in txn] @@ -1033,7 +1057,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas INNER JOIN batch_events AS c ON i.next_batch_id = c.batch_id /* Get the depth of the batch start event from the events table */ - INNER JOIN events AS e USING (event_id) + INNER JOIN events AS e ON c.event_id = e.event_id /* Find an insertion event which matches the given event_id */ WHERE i.event_id = ? LIMIT ? @@ -1104,8 +1128,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ] async def get_backfill_events( - self, room_id: str, seed_event_id_list: list, limit: int - ): + self, room_id: str, seed_event_id_list: List[str], limit: int + ) -> List[EventBase]: """Get a list of Events for a given topic that occurred before (and including) the events in seed_event_id_list. Return a list of max size `limit` @@ -1123,10 +1147,19 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ) events = await self.get_events_as_list(event_ids) return sorted( - events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering) + # type-ignore: mypy doesn't like negating the Optional[int] stream_ordering. + # But it's never None, because these events were previously persisted to the DB. + events, + key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering), # type: ignore[operator] ) - def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit): + def _get_backfill_events( + self, + txn: LoggingTransaction, + room_id: str, + seed_event_id_list: List[str], + limit: int, + ) -> Set[str]: """ We want to make sure that we do a breadth-first, "depth" ordered search. We also handle navigating historical branches of history connected by @@ -1139,7 +1172,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas limit, ) - event_id_results = set() + event_id_results: Set[str] = set() # In a PriorityQueue, the lowest valued entries are retrieved first. # We're using depth as the priority in the queue and tie-break based on @@ -1147,7 +1180,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # highest and newest-in-time message. We add events to the queue with a # negative depth so that we process the newest-in-time messages first # going backwards in time. stream_ordering follows the same pattern. - queue = PriorityQueue() + queue: "PriorityQueue[Tuple[int, int, str, str]]" = PriorityQueue() for seed_event_id in seed_event_id_list: event_lookup_result = self.db_pool.simple_select_one_txn( @@ -1253,7 +1286,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas return event_id_results - async def get_missing_events(self, room_id, earliest_events, latest_events, limit): + async def get_missing_events( + self, + room_id: str, + earliest_events: List[str], + latest_events: List[str], + limit: int, + ) -> List[EventBase]: ids = await self.db_pool.runInteraction( "get_missing_events", self._get_missing_events, @@ -1264,25 +1303,29 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ) return await self.get_events_as_list(ids) - def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): + def _get_missing_events( + self, + txn: LoggingTransaction, + room_id: str, + earliest_events: List[str], + latest_events: List[str], + limit: int, + ) -> List[str]: seen_events = set(earliest_events) front = set(latest_events) - seen_events - event_results = [] + event_results: List[str] = [] query = ( "SELECT prev_event_id FROM event_edges " - "WHERE room_id = ? AND event_id = ? AND is_state = ? " + "WHERE event_id = ? AND NOT is_state " "LIMIT ?" ) while front and len(event_results) < limit: new_front = set() for event_id in front: - txn.execute( - query, (room_id, event_id, False, limit - len(event_results)) - ) - + txn.execute(query, (event_id, limit - len(event_results))) new_results = {t[0] for t in txn} - seen_events new_front |= new_results @@ -1311,7 +1354,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas @wrap_as_background_process("delete_old_forward_extrem_cache") async def _delete_old_forward_extrem_cache(self) -> None: - def _delete_old_forward_extrem_cache_txn(txn): + def _delete_old_forward_extrem_cache_txn(txn: LoggingTransaction) -> None: # Delete entries older than a month, while making sure we don't delete # the only entries for a room. sql = """ @@ -1324,7 +1367,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ) AND stream_ordering < ? """ txn.execute( - sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) + sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) # type: ignore[attr-defined] ) await self.db_pool.runInteraction( @@ -1382,7 +1425,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas """ if self.db_pool.engine.supports_returning: - def _remove_received_event_from_staging_txn(txn): + def _remove_received_event_from_staging_txn( + txn: LoggingTransaction, + ) -> Optional[int]: sql = """ DELETE FROM federation_inbound_events_staging WHERE origin = ? AND event_id = ? @@ -1390,21 +1435,24 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas """ txn.execute(sql, (origin, event_id)) - return txn.fetchone() + row = cast(Optional[Tuple[int]], txn.fetchone()) + + if row is None: + return None - row = await self.db_pool.runInteraction( + return row[0] + + return await self.db_pool.runInteraction( "remove_received_event_from_staging", _remove_received_event_from_staging_txn, db_autocommit=True, ) - if row is None: - return None - - return row[0] else: - def _remove_received_event_from_staging_txn(txn): + def _remove_received_event_from_staging_txn( + txn: LoggingTransaction, + ) -> Optional[int]: received_ts = self.db_pool.simple_select_one_onecol_txn( txn, table="federation_inbound_events_staging", @@ -1437,7 +1485,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ) -> Optional[Tuple[str, str]]: """Get the next event ID in the staging area for the given room.""" - def _get_next_staged_event_id_for_room_txn(txn): + def _get_next_staged_event_id_for_room_txn( + txn: LoggingTransaction, + ) -> Optional[Tuple[str, str]]: sql = """ SELECT origin, event_id FROM federation_inbound_events_staging @@ -1448,7 +1498,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas txn.execute(sql, (room_id,)) - return txn.fetchone() + return cast(Optional[Tuple[str, str]], txn.fetchone()) return await self.db_pool.runInteraction( "get_next_staged_event_id_for_room", _get_next_staged_event_id_for_room_txn @@ -1461,7 +1511,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ) -> Optional[Tuple[str, EventBase]]: """Get the next event in the staging area for the given room.""" - def _get_next_staged_event_for_room_txn(txn): + def _get_next_staged_event_for_room_txn( + txn: LoggingTransaction, + ) -> Optional[Tuple[str, str, str]]: sql = """ SELECT event_json, internal_metadata, origin FROM federation_inbound_events_staging @@ -1471,7 +1523,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas """ txn.execute(sql, (room_id,)) - return txn.fetchone() + return cast(Optional[Tuple[str, str, str]], txn.fetchone()) row = await self.db_pool.runInteraction( "get_next_staged_event_for_room", _get_next_staged_event_for_room_txn @@ -1599,18 +1651,20 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ) @wrap_as_background_process("_get_stats_for_federation_staging") - async def _get_stats_for_federation_staging(self): + async def _get_stats_for_federation_staging(self) -> None: """Update the prometheus metrics for the inbound federation staging area.""" - def _get_stats_for_federation_staging_txn(txn): + def _get_stats_for_federation_staging_txn( + txn: LoggingTransaction, + ) -> Tuple[int, int]: txn.execute("SELECT count(*) FROM federation_inbound_events_staging") - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) txn.execute( "SELECT min(received_ts) FROM federation_inbound_events_staging" ) - (received_ts,) = txn.fetchone() + (received_ts,) = cast(Tuple[Optional[int]], txn.fetchone()) # If there is nothing in the staging area default it to 0. age = 0 @@ -1651,19 +1705,21 @@ class EventFederationStore(EventFederationWorkerStore): self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth ) - async def clean_room_for_join(self, room_id): - return await self.db_pool.runInteraction( + async def clean_room_for_join(self, room_id: str) -> None: + await self.db_pool.runInteraction( "clean_room_for_join", self._clean_room_for_join_txn, room_id ) - def _clean_room_for_join_txn(self, txn, room_id): + def _clean_room_for_join_txn(self, txn: LoggingTransaction, room_id: str) -> None: query = "DELETE FROM event_forward_extremities WHERE room_id = ?" txn.execute(query, (room_id,)) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) - async def _background_delete_non_state_event_auth(self, progress, batch_size): - def delete_event_auth(txn): + async def _background_delete_non_state_event_auth( + self, progress: JsonDict, batch_size: int + ) -> int: + def delete_event_auth(txn: LoggingTransaction) -> bool: target_min_stream_id = progress.get("target_min_stream_id_inclusive") max_stream_id = progress.get("max_stream_id_exclusive") diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index b7c4c622..b0199793 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -938,7 +938,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): users can still get a list of recent highlights. Args: - txn: The transcation + txn: The transaction room_id: Room ID to delete from user_id: user ID to delete for stream_ordering: The lowest stream ordering which will diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 2c86a870..17e35cf6 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -36,9 +36,8 @@ from prometheus_client import Counter import synapse.metrics from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.room_versions import RoomVersions -from synapse.crypto.event_signing import compute_event_reference_hash -from synapse.events import EventBase # noqa: F401 -from synapse.events.snapshot import EventContext # noqa: F401 +from synapse.events import EventBase, relation_from_event +from synapse.events.snapshot import EventContext from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -50,7 +49,7 @@ from synapse.storage.databases.main.search import SearchEntry from synapse.storage.engines.postgres import PostgresEngine from synapse.storage.util.id_generators import AbstractStreamIdGenerator from synapse.storage.util.sequence import SequenceGenerator -from synapse.types import StateMap, get_domain_from_id +from synapse.types import JsonDict, StateMap, get_domain_from_id from synapse.util import json_encoder from synapse.util.iterutils import batch_iter, sorted_topologically from synapse.util.stringutils import non_null_str_or_none @@ -130,7 +129,6 @@ class PersistEventsStore: self, events_and_contexts: List[Tuple[EventBase, EventContext]], *, - current_state_for_room: Dict[str, StateMap[str]], state_delta_for_room: Dict[str, DeltaState], new_forward_extremities: Dict[str, Set[str]], use_negative_stream_ordering: bool = False, @@ -141,8 +139,6 @@ class PersistEventsStore: Args: events_and_contexts: - current_state_for_room: Map from room_id to the current state of - the room based on forward extremities state_delta_for_room: Map from room_id to the delta to apply to room state new_forward_extremities: Map from room_id to set of event IDs @@ -217,9 +213,6 @@ class PersistEventsStore: event_counter.labels(event.type, origin_type, origin_entity).inc() - for room_id, new_state in current_state_for_room.items(): - self.store.get_current_state_ids.prefill((room_id,), new_state) - for room_id, latest_event_ids in new_forward_extremities.items(): self.store.get_latest_event_ids_in_room.prefill( (room_id,), list(latest_event_ids) @@ -237,7 +230,9 @@ class PersistEventsStore: """ results: List[str] = [] - def _get_events_which_are_prevs_txn(txn, batch): + def _get_events_which_are_prevs_txn( + txn: LoggingTransaction, batch: Collection[str] + ) -> None: sql = """ SELECT prev_event_id, internal_metadata FROM event_edges @@ -287,7 +282,9 @@ class PersistEventsStore: # and their prev events. existing_prevs = set() - def _get_prevs_before_rejected_txn(txn, batch): + def _get_prevs_before_rejected_txn( + txn: LoggingTransaction, batch: Collection[str] + ) -> None: to_recursively_check = batch while to_recursively_check: @@ -517,7 +514,7 @@ class PersistEventsStore: @classmethod def _add_chain_cover_index( cls, - txn, + txn: LoggingTransaction, db_pool: DatabasePool, event_chain_id_gen: SequenceGenerator, event_to_room_id: Dict[str, str], @@ -811,7 +808,7 @@ class PersistEventsStore: @staticmethod def _allocate_chain_ids( - txn, + txn: LoggingTransaction, db_pool: DatabasePool, event_chain_id_gen: SequenceGenerator, event_to_room_id: Dict[str, str], @@ -945,7 +942,7 @@ class PersistEventsStore: self, txn: LoggingTransaction, events_and_contexts: List[Tuple[EventBase, EventContext]], - ): + ) -> None: """Persist the mapping from transaction IDs to event IDs (if defined).""" to_insert = [] @@ -999,7 +996,7 @@ class PersistEventsStore: txn: LoggingTransaction, state_delta_by_room: Dict[str, DeltaState], stream_id: int, - ): + ) -> None: for room_id, delta_state in state_delta_by_room.items(): to_delete = delta_state.to_delete to_insert = delta_state.to_insert @@ -1157,7 +1154,7 @@ class PersistEventsStore: txn, room_id, members_changed ) - def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str): + def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None: """Update the room version in the database based off current state events. @@ -1191,7 +1188,7 @@ class PersistEventsStore: txn: LoggingTransaction, new_forward_extremities: Dict[str, Set[str]], max_stream_order: int, - ): + ) -> None: for room_id in new_forward_extremities.keys(): self.db_pool.simple_delete_txn( txn, table="event_forward_extremities", keyvalues={"room_id": room_id} @@ -1256,9 +1253,9 @@ class PersistEventsStore: def _update_room_depths_txn( self, - txn, + txn: LoggingTransaction, events_and_contexts: List[Tuple[EventBase, EventContext]], - ): + ) -> None: """Update min_depth for each room Args: @@ -1387,7 +1384,7 @@ class PersistEventsStore: # nothing to do here return - def event_dict(event): + def event_dict(event: EventBase) -> JsonDict: d = event.get_dict() d.pop("redacted", None) d.pop("redacted_because", None) @@ -1478,18 +1475,20 @@ class PersistEventsStore: ), ) - def _store_rejected_events_txn(self, txn, events_and_contexts): + def _store_rejected_events_txn( + self, + txn: LoggingTransaction, + events_and_contexts: List[Tuple[EventBase, EventContext]], + ) -> List[Tuple[EventBase, EventContext]]: """Add rows to the 'rejections' table for received events which were rejected Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting + txn: db connection + events_and_contexts: events we are persisting Returns: - list[(EventBase, EventContext)] new list, without the rejected - events. + new list, without the rejected events. """ # Remove the rejected events from the list now that we've added them # to the events table and the events_json table. @@ -1510,7 +1509,7 @@ class PersistEventsStore: events_and_contexts: List[Tuple[EventBase, EventContext]], all_events_and_contexts: List[Tuple[EventBase, EventContext]], inhibit_local_membership_updates: bool = False, - ): + ) -> None: """Update all the miscellaneous tables for new events Args: @@ -1601,15 +1600,14 @@ class PersistEventsStore: inhibit_local_membership_updates=inhibit_local_membership_updates, ) - # Insert event_reference_hashes table. - self._store_event_reference_hashes_txn( - txn, [event for event, _ in events_and_contexts] - ) - # Prefill the event cache self._add_to_cache(txn, events_and_contexts) - def _add_to_cache(self, txn, events_and_contexts): + def _add_to_cache( + self, + txn: LoggingTransaction, + events_and_contexts: List[Tuple[EventBase, EventContext]], + ) -> None: to_prefill = [] rows = [] @@ -1640,7 +1638,7 @@ class PersistEventsStore: if not row["rejects"] and not row["redacts"]: to_prefill.append(EventCacheEntry(event=event, redacted_event=None)) - def prefill(): + def prefill() -> None: for cache_entry in to_prefill: self.store._get_event_cache.set( (cache_entry.event.event_id,), cache_entry @@ -1670,19 +1668,24 @@ class PersistEventsStore: ) def insert_labels_for_event_txn( - self, txn, event_id, labels, room_id, topological_ordering - ): + self, + txn: LoggingTransaction, + event_id: str, + labels: List[str], + room_id: str, + topological_ordering: int, + ) -> None: """Store the mapping between an event's ID and its labels, with one row per (event_id, label) tuple. Args: - txn (LoggingTransaction): The transaction to execute. - event_id (str): The event's ID. - labels (list[str]): A list of text labels. - room_id (str): The ID of the room the event was sent to. - topological_ordering (int): The position of the event in the room's topology. + txn: The transaction to execute. + event_id: The event's ID. + labels: A list of text labels. + room_id: The ID of the room the event was sent to. + topological_ordering: The position of the event in the room's topology. """ - return self.db_pool.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn=txn, table="event_labels", keys=("event_id", "label", "room_id", "topological_ordering"), @@ -1691,44 +1694,32 @@ class PersistEventsStore: ], ) - def _insert_event_expiry_txn(self, txn, event_id, expiry_ts): + def _insert_event_expiry_txn( + self, txn: LoggingTransaction, event_id: str, expiry_ts: int + ) -> None: """Save the expiry timestamp associated with a given event ID. Args: - txn (LoggingTransaction): The database transaction to use. - event_id (str): The event ID the expiry timestamp is associated with. - expiry_ts (int): The timestamp at which to expire (delete) the event. + txn: The database transaction to use. + event_id: The event ID the expiry timestamp is associated with. + expiry_ts: The timestamp at which to expire (delete) the event. """ - return self.db_pool.simple_insert_txn( + self.db_pool.simple_insert_txn( txn=txn, table="event_expiry", values={"event_id": event_id, "expiry_ts": expiry_ts}, ) - def _store_event_reference_hashes_txn(self, txn, events): - """Store a hash for a PDU - Args: - txn (cursor): - events (list): list of Events. - """ - - vals = [] - for event in events: - ref_alg, ref_hash_bytes = compute_event_reference_hash(event) - vals.append((event.event_id, ref_alg, memoryview(ref_hash_bytes))) - - self.db_pool.simple_insert_many_txn( - txn, - table="event_reference_hashes", - keys=("event_id", "algorithm", "hash"), - values=vals, - ) - def _store_room_members_txn( - self, txn, events, *, inhibit_local_membership_updates: bool = False - ): + self, + txn: LoggingTransaction, + events: List[EventBase], + *, + inhibit_local_membership_updates: bool = False, + ) -> None: """ Store a room member in the database. + Args: txn: The transaction to use. events: List of events to store. @@ -1765,6 +1756,7 @@ class PersistEventsStore: ) for event in events: + assert event.internal_metadata.stream_ordering is not None txn.call_after( self.store._membership_stream_cache.entity_has_changed, event.state_key, @@ -1813,55 +1805,54 @@ class PersistEventsStore: txn: The current database transaction. event: The event which might have relations. """ - relation = event.content.get("m.relates_to") + relation = relation_from_event(event) if not relation: - # No relations - return - - # Relations must have a type and parent event ID. - rel_type = relation.get("rel_type") - if not isinstance(rel_type, str): + # No relation, nothing to do. return - parent_id = relation.get("event_id") - if not isinstance(parent_id, str): - return - - # Annotations have a key field. - aggregation_key = None - if rel_type == RelationTypes.ANNOTATION: - aggregation_key = relation.get("key") - self.db_pool.simple_insert_txn( txn, table="event_relations", values={ "event_id": event.event_id, - "relates_to_id": parent_id, - "relation_type": rel_type, - "aggregation_key": aggregation_key, + "relates_to_id": relation.parent_id, + "relation_type": relation.rel_type, + "aggregation_key": relation.aggregation_key, }, ) - txn.call_after(self.store.get_relations_for_event.invalidate, (parent_id,)) txn.call_after( - self.store.get_aggregation_groups_for_event.invalidate, (parent_id,) + self.store.get_relations_for_event.invalidate, (relation.parent_id,) + ) + txn.call_after( + self.store.get_aggregation_groups_for_event.invalidate, + (relation.parent_id,), + ) + txn.call_after( + self.store.get_mutual_event_relations_for_rel_type.invalidate, + (relation.parent_id,), ) - if rel_type == RelationTypes.REPLACE: - txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) + if relation.rel_type == RelationTypes.REPLACE: + txn.call_after( + self.store.get_applicable_edit.invalidate, (relation.parent_id,) + ) - if rel_type == RelationTypes.THREAD: - txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) + if relation.rel_type == RelationTypes.THREAD: + txn.call_after( + self.store.get_thread_summary.invalidate, (relation.parent_id,) + ) # It should be safe to only invalidate the cache if the user has not # previously participated in the thread, but that's difficult (and # potentially error-prone) so it is always invalidated. txn.call_after( self.store.get_thread_participated.invalidate, - (parent_id, event.sender), + (relation.parent_id, event.sender), ) - def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase): + def _handle_insertion_event( + self, txn: LoggingTransaction, event: EventBase + ) -> None: """Handles keeping track of insertion events and edges/connections. Part of MSC2716. @@ -1922,7 +1913,7 @@ class PersistEventsStore: }, ) - def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase): + def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase) -> None: """Handles inserting the batch edges/connections between the batch event and an insertion event. Part of MSC2716. @@ -2017,30 +2008,39 @@ class PersistEventsStore: self.store._invalidate_cache_and_stream( txn, self.store.get_thread_participated, (redacted_relates_to,) ) + self.store._invalidate_cache_and_stream( + txn, + self.store.get_mutual_event_relations_for_rel_type, + (redacted_relates_to,), + ) self.db_pool.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) - def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase): + def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None: if isinstance(event.content.get("topic"), str): self.store_event_search_txn( txn, event, "content.topic", event.content["topic"] ) - def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase): + def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase) -> None: if isinstance(event.content.get("name"), str): self.store_event_search_txn( txn, event, "content.name", event.content["name"] ) - def _store_room_message_txn(self, txn: LoggingTransaction, event: EventBase): + def _store_room_message_txn( + self, txn: LoggingTransaction, event: EventBase + ) -> None: if isinstance(event.content.get("body"), str): self.store_event_search_txn( txn, event, "content.body", event.content["body"] ) - def _store_retention_policy_for_room_txn(self, txn, event): + def _store_retention_policy_for_room_txn( + self, txn: LoggingTransaction, event: EventBase + ) -> None: if not event.is_state(): logger.debug("Ignoring non-state m.room.retention event") return @@ -2100,8 +2100,11 @@ class PersistEventsStore: ) def _set_push_actions_for_event_and_users_txn( - self, txn, events_and_contexts, all_events_and_contexts - ): + self, + txn: LoggingTransaction, + events_and_contexts: List[Tuple[EventBase, EventContext]], + all_events_and_contexts: List[Tuple[EventBase, EventContext]], + ) -> None: """Handles moving push actions from staging table to main event_push_actions table for all events in `events_and_contexts`. @@ -2109,12 +2112,10 @@ class PersistEventsStore: from the push action staging area. Args: - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - all_events_and_contexts (list[(EventBase, EventContext)]): all - events that we were going to persist. This includes events - we've already persisted, etc, that wouldn't appear in - events_and_context. + events_and_contexts: events we are persisting + all_events_and_contexts: all events that we were going to persist. + This includes events we've already persisted, etc, that wouldn't + appear in events_and_context. """ # Only non outlier events will have push actions associated with them, @@ -2183,7 +2184,9 @@ class PersistEventsStore: ), ) - def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): + def _remove_push_actions_for_event_id_txn( + self, txn: LoggingTransaction, room_id: str, event_id: str + ) -> None: # Sad that we have to blow away the cache for the whole room here txn.call_after( self.store.get_unread_event_push_actions_by_room_for_user.invalidate, @@ -2194,7 +2197,9 @@ class PersistEventsStore: (room_id, event_id), ) - def _store_rejections_txn(self, txn, event_id, reason): + def _store_rejections_txn( + self, txn: LoggingTransaction, event_id: str, reason: str + ) -> None: self.db_pool.simple_insert_txn( txn, table="rejections", @@ -2206,8 +2211,10 @@ class PersistEventsStore: ) def _store_event_state_mappings_txn( - self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]] - ): + self, + txn: LoggingTransaction, + events_and_contexts: Collection[Tuple[EventBase, EventContext]], + ) -> None: state_groups = {} for event, context in events_and_contexts: if event.internal_metadata.is_outlier(): @@ -2264,7 +2271,9 @@ class PersistEventsStore: state_group_id, ) - def _update_min_depth_for_room_txn(self, txn, room_id, depth): + def _update_min_depth_for_room_txn( + self, txn: LoggingTransaction, room_id: str, depth: int + ) -> None: min_depth = self.store._get_min_depth_interaction(txn, room_id) if min_depth is not None and depth >= min_depth: @@ -2277,7 +2286,9 @@ class PersistEventsStore: values={"min_depth": depth}, ) - def _handle_mult_prev_events(self, txn, events): + def _handle_mult_prev_events( + self, txn: LoggingTransaction, events: List[EventBase] + ) -> None: """ For the given event, update the event edges table and forward and backward extremities tables. @@ -2295,7 +2306,9 @@ class PersistEventsStore: self._update_backward_extremeties(txn, events) - def _update_backward_extremeties(self, txn, events): + def _update_backward_extremeties( + self, txn: LoggingTransaction, events: List[EventBase] + ) -> None: """Updates the event_backward_extremities tables based on the new/updated events being persisted. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index a4a604a4..b99b1077 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -14,6 +14,7 @@ import logging import threading +import weakref from enum import Enum, auto from typing import ( TYPE_CHECKING, @@ -23,6 +24,7 @@ from typing import ( Dict, Iterable, List, + MutableMapping, Optional, Set, Tuple, @@ -248,6 +250,12 @@ class EventsWorkerStore(SQLBaseStore): str, ObservableDeferred[Dict[str, EventCacheEntry]] ] = {} + # We keep track of the events we have currently loaded in memory so that + # we can reuse them even if they've been evicted from the cache. We only + # track events that don't need redacting in here (as then we don't need + # to track redaction status). + self._event_ref: MutableMapping[str, EventBase] = weakref.WeakValueDictionary() + self._event_fetch_lock = threading.Condition() self._event_fetch_list: List[ Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"] @@ -723,6 +731,8 @@ class EventsWorkerStore(SQLBaseStore): def _invalidate_get_event_cache(self, event_id: str) -> None: self._get_event_cache.invalidate((event_id,)) + self._event_ref.pop(event_id, None) + self._current_event_fetches.pop(event_id, None) def _get_events_from_cache( self, events: Iterable[str], update_metrics: bool = True @@ -738,13 +748,30 @@ class EventsWorkerStore(SQLBaseStore): event_map = {} for event_id in events: + # First check if it's in the event cache ret = self._get_event_cache.get( (event_id,), None, update_metrics=update_metrics ) - if not ret: + if ret: + event_map[event_id] = ret continue - event_map[event_id] = ret + # Otherwise check if we still have the event in memory. + event = self._event_ref.get(event_id) + if event: + # Reconstruct an event cache entry + + cache_entry = EventCacheEntry( + event=event, + # We don't cache weakrefs to redacted events, so we know + # this is None. + redacted_event=None, + ) + event_map[event_id] = cache_entry + + # We add the entry back into the cache as we want to keep + # recently queried events in the cache. + self._get_event_cache.set((event_id,), cache_entry) return event_map @@ -1124,6 +1151,10 @@ class EventsWorkerStore(SQLBaseStore): self._get_event_cache.set((event_id,), cache_entry) result_map[event_id] = cache_entry + if not redacted_event: + # We only cache references to unredacted events. + self._event_ref[event_id] = original_ev + return result_map async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]: @@ -1325,14 +1356,23 @@ class EventsWorkerStore(SQLBaseStore): Returns: The set of events we have already seen. """ - res = await self._have_seen_events_dict( - (room_id, event_id) for event_id in event_ids - ) - return {eid for ((_rid, eid), have_event) in res.items() if have_event} + + # @cachedList chomps lots of memory if you call it with a big list, so + # we break it down. However, each batch requires its own index scan, so we make + # the batches as big as possible. + + results: Set[str] = set() + for chunk in batch_iter(event_ids, 500): + r = await self._have_seen_events_dict( + [(room_id, event_id) for event_id in chunk] + ) + results.update(eid for ((_rid, eid), have_event) in r.items() if have_event) + + return results @cachedList(cached_method_name="have_seen_event", list_name="keys") async def _have_seen_events_dict( - self, keys: Iterable[Tuple[str, str]] + self, keys: Collection[Tuple[str, str]] ) -> Dict[Tuple[str, str], bool]: """Helper for have_seen_events @@ -1344,11 +1384,12 @@ class EventsWorkerStore(SQLBaseStore): cache_results = { (rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,)) } - results = {x: True for x in cache_results} + results = dict.fromkeys(cache_results, True) + remaining = [k for k in keys if k not in cache_results] + if not remaining: + return results - def have_seen_events_txn( - txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...] - ) -> None: + def have_seen_events_txn(txn: LoggingTransaction) -> None: # we deliberately do *not* query the database for room_id, to make the # query an index-only lookup on `events_event_id_key`. # @@ -1356,21 +1397,17 @@ class EventsWorkerStore(SQLBaseStore): sql = "SELECT event_id FROM events AS e WHERE " clause, args = make_in_list_sql_clause( - txn.database_engine, "e.event_id", [eid for (_rid, eid) in chunk] + txn.database_engine, "e.event_id", [eid for (_rid, eid) in remaining] ) txn.execute(sql + clause, args) found_events = {eid for eid, in txn} - # ... and then we can update the results for each row in the batch - results.update({(rid, eid): (eid in found_events) for (rid, eid) in chunk}) - - # each batch requires its own index scan, so we make the batches as big as - # possible. - for chunk in batch_iter((k for k in keys if k not in cache_results), 500): - await self.db_pool.runInteraction( - "have_seen_events", have_seen_events_txn, chunk + # ... and then we can update the results for each key + results.update( + {(rid, eid): (eid in found_events) for (rid, eid) in remaining} ) + await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn) return results @cached(max_entries=100000, tree=True) @@ -1891,6 +1928,18 @@ class EventsWorkerStore(SQLBaseStore): LIMIT 1 """ + # We consider any forward extremity as the latest in the room and + # not a forward gap. + # + # To expand, even though there is technically a gap at the front of + # the room where the forward extremities are, we consider those the + # latest messages in the room so asking other homeservers for more + # is useless. The new latest messages will just be federated as + # usual. + txn.execute(forward_extremity_query, (event.room_id, event.event_id)) + if txn.fetchone(): + return False + # Check to see whether the event in question is already referenced # by another event. If we don't see any edges, we're next to a # forward gap. @@ -1899,8 +1948,7 @@ class EventsWorkerStore(SQLBaseStore): /* Check to make sure the event referencing our event in question is not rejected */ LEFT JOIN rejections ON event_edges.event_id = rejections.event_id WHERE - event_edges.room_id = ? - AND event_edges.prev_event_id = ? + event_edges.prev_event_id = ? /* It's not a valid edge if the event referencing our event in * question is rejected. */ @@ -1908,25 +1956,11 @@ class EventsWorkerStore(SQLBaseStore): LIMIT 1 """ - # We consider any forward extremity as the latest in the room and - # not a forward gap. - # - # To expand, even though there is technically a gap at the front of - # the room where the forward extremities are, we consider those the - # latest messages in the room so asking other homeservers for more - # is useless. The new latest messages will just be federated as - # usual. - txn.execute(forward_extremity_query, (event.room_id, event.event_id)) - forward_extremities = txn.fetchall() - if len(forward_extremities): - return False - # If there are no forward edges to the event in question (another # event hasn't referenced this event in their prev_events), then we # assume there is a forward gap in the history. - txn.execute(forward_edge_query, (event.room_id, event.event_id)) - forward_edges = txn.fetchall() - if not len(forward_edges): + txn.execute(forward_edge_query, (event.event_id,)) + if not txn.fetchone(): return True return False diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 04efad9e..c15a7136 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -13,1417 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast +from typing import TYPE_CHECKING -from typing_extensions import TypedDict - -from synapse.api.errors import SynapseError -from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import ( - DatabasePool, - LoggingDatabaseConnection, - LoggingTransaction, -) -from synapse.types import JsonDict -from synapse.util import json_encoder +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection if TYPE_CHECKING: from synapse.server import HomeServer -# The category ID for the "default" category. We don't store as null in the -# database to avoid the fun of null != null -_DEFAULT_CATEGORY_ID = "" -_DEFAULT_ROLE_ID = "" - - -# A room in a group. -class _RoomInGroup(TypedDict): - room_id: str - is_public: bool - -class GroupServerWorkerStore(SQLBaseStore): +class GroupServerStore(SQLBaseStore): def __init__( self, database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", ): - database.updates.register_background_index_update( - update_name="local_group_updates_index", - index_name="local_group_updates_stream_id_index", - table="local_group_updates", - columns=("stream_id",), - unique=True, - ) + # Register a legacy groups background update as a no-op. + database.updates.register_noop_background_update("local_group_updates_index") super().__init__(database, db_conn, hs) - - async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]: - return await self.db_pool.simple_select_one( - table="groups", - keyvalues={"group_id": group_id}, - retcols=( - "name", - "short_description", - "long_description", - "avatar_url", - "is_public", - "join_policy", - ), - allow_none=True, - desc="get_group", - ) - - async def get_users_in_group( - self, group_id: str, include_private: bool = False - ) -> List[Dict[str, Any]]: - # TODO: Pagination - - keyvalues: JsonDict = {"group_id": group_id} - if not include_private: - keyvalues["is_public"] = True - - return await self.db_pool.simple_select_list( - table="group_users", - keyvalues=keyvalues, - retcols=("user_id", "is_public", "is_admin"), - desc="get_users_in_group", - ) - - async def get_invited_users_in_group(self, group_id: str) -> List[str]: - # TODO: Pagination - - return await self.db_pool.simple_select_onecol( - table="group_invites", - keyvalues={"group_id": group_id}, - retcol="user_id", - desc="get_invited_users_in_group", - ) - - async def get_rooms_in_group( - self, group_id: str, include_private: bool = False - ) -> List[_RoomInGroup]: - """Retrieve the rooms that belong to a given group. Does not return rooms that - lack members. - - Args: - group_id: The ID of the group to query for rooms - include_private: Whether to return private rooms in results - - Returns: - A list of dictionaries, each in the form of: - - { - "room_id": "!a_room_id:example.com", # The ID of the room - "is_public": False # Whether this is a public room or not - } - """ - - # TODO: Pagination - - def _get_rooms_in_group_txn(txn: LoggingTransaction) -> List[_RoomInGroup]: - sql = """ - SELECT room_id, is_public FROM group_rooms - WHERE group_id = ? - AND room_id IN ( - SELECT group_rooms.room_id FROM group_rooms - LEFT JOIN room_stats_current ON - group_rooms.room_id = room_stats_current.room_id - AND joined_members > 0 - AND local_users_in_room > 0 - LEFT JOIN rooms ON - group_rooms.room_id = rooms.room_id - AND (room_version <> '') = ? - ) - """ - args = [group_id, False] - - if not include_private: - sql += " AND is_public = ?" - args += [True] - - txn.execute(sql, args) - - return [ - {"room_id": room_id, "is_public": is_public} - for room_id, is_public in txn - ] - - return await self.db_pool.runInteraction( - "get_rooms_in_group", _get_rooms_in_group_txn - ) - - async def get_rooms_for_summary_by_category( - self, - group_id: str, - include_private: bool = False, - ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: - """Get the rooms and categories that should be included in a summary request - - Args: - group_id: The ID of the group to query the summary for - include_private: Whether to return private rooms in results - - Returns: - A tuple containing: - - * A list of dictionaries with the keys: - * "room_id": str, the room ID - * "is_public": bool, whether the room is public - * "category_id": str|None, the category ID if set, else None - * "order": int, the sort order of rooms - - * A dictionary with the key: - * category_id (str): a dictionary with the keys: - * "is_public": bool, whether the category is public - * "profile": str, the category profile - * "order": int, the sort order of rooms in this category - """ - - def _get_rooms_for_summary_txn( - txn: LoggingTransaction, - ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: - keyvalues: JsonDict = {"group_id": group_id} - if not include_private: - keyvalues["is_public"] = True - - sql = """ - SELECT room_id, is_public, category_id, room_order - FROM group_summary_rooms - WHERE group_id = ? - AND room_id IN ( - SELECT group_rooms.room_id FROM group_rooms - LEFT JOIN room_stats_current ON - group_rooms.room_id = room_stats_current.room_id - AND joined_members > 0 - AND local_users_in_room > 0 - LEFT JOIN rooms ON - group_rooms.room_id = rooms.room_id - AND (room_version <> '') = ? - ) - """ - - if not include_private: - sql += " AND is_public = ?" - txn.execute(sql, (group_id, False, True)) - else: - txn.execute(sql, (group_id, False)) - - rooms = [ - { - "room_id": row[0], - "is_public": row[1], - "category_id": row[2] if row[2] != _DEFAULT_CATEGORY_ID else None, - "order": row[3], - } - for row in txn - ] - - sql = """ - SELECT category_id, is_public, profile, cat_order - FROM group_summary_room_categories - INNER JOIN group_room_categories USING (group_id, category_id) - WHERE group_id = ? - """ - - if not include_private: - sql += " AND is_public = ?" - txn.execute(sql, (group_id, True)) - else: - txn.execute(sql, (group_id,)) - - categories = { - row[0]: { - "is_public": row[1], - "profile": db_to_json(row[2]), - "order": row[3], - } - for row in txn - } - - return rooms, categories - - return await self.db_pool.runInteraction( - "get_rooms_for_summary", _get_rooms_for_summary_txn - ) - - async def get_group_categories(self, group_id: str) -> JsonDict: - rows = await self.db_pool.simple_select_list( - table="group_room_categories", - keyvalues={"group_id": group_id}, - retcols=("category_id", "is_public", "profile"), - desc="get_group_categories", - ) - - return { - row["category_id"]: { - "is_public": row["is_public"], - "profile": db_to_json(row["profile"]), - } - for row in rows - } - - async def get_group_category(self, group_id: str, category_id: str) -> JsonDict: - category = await self.db_pool.simple_select_one( - table="group_room_categories", - keyvalues={"group_id": group_id, "category_id": category_id}, - retcols=("is_public", "profile"), - desc="get_group_category", - ) - - category["profile"] = db_to_json(category["profile"]) - - return category - - async def get_group_roles(self, group_id: str) -> JsonDict: - rows = await self.db_pool.simple_select_list( - table="group_roles", - keyvalues={"group_id": group_id}, - retcols=("role_id", "is_public", "profile"), - desc="get_group_roles", - ) - - return { - row["role_id"]: { - "is_public": row["is_public"], - "profile": db_to_json(row["profile"]), - } - for row in rows - } - - async def get_group_role(self, group_id: str, role_id: str) -> JsonDict: - role = await self.db_pool.simple_select_one( - table="group_roles", - keyvalues={"group_id": group_id, "role_id": role_id}, - retcols=("is_public", "profile"), - desc="get_group_role", - ) - - role["profile"] = db_to_json(role["profile"]) - - return role - - async def get_local_groups_for_room(self, room_id: str) -> List[str]: - """Get all of the local group that contain a given room - Args: - room_id: The ID of a room - Returns: - A list of group ids containing this room - """ - return await self.db_pool.simple_select_onecol( - table="group_rooms", - keyvalues={"room_id": room_id}, - retcol="group_id", - desc="get_local_groups_for_room", - ) - - async def get_users_for_summary_by_role( - self, group_id: str, include_private: bool = False - ) -> Tuple[List[JsonDict], JsonDict]: - """Get the users and roles that should be included in a summary request - - Returns: - ([users], [roles]) - """ - - def _get_users_for_summary_txn( - txn: LoggingTransaction, - ) -> Tuple[List[JsonDict], JsonDict]: - keyvalues: JsonDict = {"group_id": group_id} - if not include_private: - keyvalues["is_public"] = True - - sql = """ - SELECT user_id, is_public, role_id, user_order - FROM group_summary_users - WHERE group_id = ? - """ - - if not include_private: - sql += " AND is_public = ?" - txn.execute(sql, (group_id, True)) - else: - txn.execute(sql, (group_id,)) - - users = [ - { - "user_id": row[0], - "is_public": row[1], - "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None, - "order": row[3], - } - for row in txn - ] - - sql = """ - SELECT role_id, is_public, profile, role_order - FROM group_summary_roles - INNER JOIN group_roles USING (group_id, role_id) - WHERE group_id = ? - """ - - if not include_private: - sql += " AND is_public = ?" - txn.execute(sql, (group_id, True)) - else: - txn.execute(sql, (group_id,)) - - roles = { - row[0]: { - "is_public": row[1], - "profile": db_to_json(row[2]), - "order": row[3], - } - for row in txn - } - - return users, roles - - return await self.db_pool.runInteraction( - "get_users_for_summary_by_role", _get_users_for_summary_txn - ) - - async def is_user_in_group(self, user_id: str, group_id: str) -> bool: - result = await self.db_pool.simple_select_one_onecol( - table="group_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcol="user_id", - allow_none=True, - desc="is_user_in_group", - ) - return bool(result) - - async def is_user_admin_in_group( - self, group_id: str, user_id: str - ) -> Optional[bool]: - return await self.db_pool.simple_select_one_onecol( - table="group_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcol="is_admin", - allow_none=True, - desc="is_user_admin_in_group", - ) - - async def is_user_invited_to_local_group( - self, group_id: str, user_id: str - ) -> Optional[bool]: - """Has the group server invited a user?""" - return await self.db_pool.simple_select_one_onecol( - table="group_invites", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcol="user_id", - desc="is_user_invited_to_local_group", - allow_none=True, - ) - - async def get_users_membership_info_in_group( - self, group_id: str, user_id: str - ) -> JsonDict: - """Get a dict describing the membership of a user in a group. - - Example if joined: - - { - "membership": "join", - "is_public": True, - "is_privileged": False, - } - - Returns: - An empty dict if the user is not join/invite/etc - """ - - def _get_users_membership_in_group_txn(txn: LoggingTransaction) -> JsonDict: - row = self.db_pool.simple_select_one_txn( - txn, - table="group_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcols=("is_admin", "is_public"), - allow_none=True, - ) - - if row: - return { - "membership": "join", - "is_public": row["is_public"], - "is_privileged": row["is_admin"], - } - - row = self.db_pool.simple_select_one_onecol_txn( - txn, - table="group_invites", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcol="user_id", - allow_none=True, - ) - - if row: - return {"membership": "invite"} - - return {} - - return await self.db_pool.runInteraction( - "get_users_membership_info_in_group", _get_users_membership_in_group_txn - ) - - async def get_publicised_groups_for_user(self, user_id: str) -> List[str]: - """Get all groups a user is publicising""" - return await self.db_pool.simple_select_onecol( - table="local_group_membership", - keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, - retcol="group_id", - desc="get_publicised_groups_for_user", - ) - - async def get_attestations_need_renewals( - self, valid_until_ms: int - ) -> List[Dict[str, Any]]: - """Get all attestations that need to be renewed until givent time""" - - def _get_attestations_need_renewals_txn( - txn: LoggingTransaction, - ) -> List[Dict[str, Any]]: - sql = """ - SELECT group_id, user_id FROM group_attestations_renewals - WHERE valid_until_ms <= ? - """ - txn.execute(sql, (valid_until_ms,)) - return self.db_pool.cursor_to_dict(txn) - - return await self.db_pool.runInteraction( - "get_attestations_need_renewals", _get_attestations_need_renewals_txn - ) - - async def get_remote_attestation( - self, group_id: str, user_id: str - ) -> Optional[JsonDict]: - """Get the attestation that proves the remote agrees that the user is - in the group. - """ - row = await self.db_pool.simple_select_one( - table="group_attestations_remote", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcols=("valid_until_ms", "attestation_json"), - desc="get_remote_attestation", - allow_none=True, - ) - - now = int(self._clock.time_msec()) - if row and now < row["valid_until_ms"]: - return db_to_json(row["attestation_json"]) - - return None - - async def get_joined_groups(self, user_id: str) -> List[str]: - return await self.db_pool.simple_select_onecol( - table="local_group_membership", - keyvalues={"user_id": user_id, "membership": "join"}, - retcol="group_id", - desc="get_joined_groups", - ) - - async def get_all_groups_for_user( - self, user_id: str, now_token: int - ) -> List[JsonDict]: - def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]: - sql = """ - SELECT group_id, type, membership, u.content - FROM local_group_updates AS u - INNER JOIN local_group_membership USING (group_id, user_id) - WHERE user_id = ? AND membership != 'leave' - AND stream_id <= ? - """ - txn.execute(sql, (user_id, now_token)) - return [ - { - "group_id": row[0], - "type": row[1], - "membership": row[2], - "content": db_to_json(row[3]), - } - for row in txn - ] - - return await self.db_pool.runInteraction( - "get_all_groups_for_user", _get_all_groups_for_user_txn - ) - - async def get_groups_changes_for_user( - self, user_id: str, from_token: int, to_token: int - ) -> List[JsonDict]: - has_changed = self._group_updates_stream_cache.has_entity_changed( # type: ignore[attr-defined] - user_id, from_token - ) - if not has_changed: - return [] - - def _get_groups_changes_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]: - sql = """ - SELECT group_id, membership, type, u.content - FROM local_group_updates AS u - INNER JOIN local_group_membership USING (group_id, user_id) - WHERE user_id = ? AND ? < stream_id AND stream_id <= ? - """ - txn.execute(sql, (user_id, from_token, to_token)) - return [ - { - "group_id": group_id, - "membership": membership, - "type": gtype, - "content": db_to_json(content_json), - } - for group_id, membership, gtype, content_json in txn - ] - - return await self.db_pool.runInteraction( - "get_groups_changes_for_user", _get_groups_changes_for_user_txn - ) - - async def get_all_groups_changes( - self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, tuple]], int, bool]: - """Get updates for groups replication stream. - - Args: - instance_name: The writer we want to fetch updates from. Unused - here since there is only ever one writer. - last_id: The token to fetch updates from. Exclusive. - current_id: The token to fetch updates up to. Inclusive. - limit: The requested limit for the number of rows to return. The - function may return more or fewer rows. - - Returns: - A tuple consisting of: the updates, a token to use to fetch - subsequent updates, and whether we returned fewer rows than exists - between the requested tokens due to the limit. - - The token returned can be used in a subsequent call to this - function to get further updatees. - - The updates are a list of 2-tuples of stream ID and the row data - """ - - last_id = int(last_id) - has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) # type: ignore[attr-defined] - - if not has_changed: - return [], current_id, False - - def _get_all_groups_changes_txn( - txn: LoggingTransaction, - ) -> Tuple[List[Tuple[int, tuple]], int, bool]: - sql = """ - SELECT stream_id, group_id, user_id, type, content - FROM local_group_updates - WHERE ? < stream_id AND stream_id <= ? - LIMIT ? - """ - txn.execute(sql, (last_id, current_id, limit)) - updates = cast( - List[Tuple[int, tuple]], - [ - (stream_id, (group_id, user_id, gtype, db_to_json(content_json))) - for stream_id, group_id, user_id, gtype, content_json in txn - ], - ) - - limited = False - upto_token = current_id - if len(updates) >= limit: - upto_token = updates[-1][0] - limited = True - - return updates, upto_token, limited - - return await self.db_pool.runInteraction( - "get_all_groups_changes", _get_all_groups_changes_txn - ) - - -class GroupServerStore(GroupServerWorkerStore): - async def set_group_join_policy(self, group_id: str, join_policy: str) -> None: - """Set the join policy of a group. - - join_policy can be one of: - * "invite" - * "open" - """ - await self.db_pool.simple_update_one( - table="groups", - keyvalues={"group_id": group_id}, - updatevalues={"join_policy": join_policy}, - desc="set_group_join_policy", - ) - - async def add_room_to_summary( - self, - group_id: str, - room_id: str, - category_id: Optional[str], - order: Optional[int], - is_public: Optional[bool], - ) -> None: - """Add (or update) room's entry in summary. - - Args: - group_id - room_id - category_id: If not None then adds the category to the end of - the summary if its not already there. - order: If not None inserts the room at that position, e.g. an order - of 1 will put the room first. Otherwise, the room gets added to - the end. - is_public - """ - await self.db_pool.runInteraction( - "add_room_to_summary", - self._add_room_to_summary_txn, - group_id, - room_id, - category_id, - order, - is_public, - ) - - def _add_room_to_summary_txn( - self, - txn: LoggingTransaction, - group_id: str, - room_id: str, - category_id: Optional[str], - order: Optional[int], - is_public: Optional[bool], - ) -> None: - """Add (or update) room's entry in summary. - - Args: - txn - group_id - room_id - category_id: If not None then adds the category to the end of - the summary if its not already there. - order: If not None inserts the room at that position, e.g. an order - of 1 will put the room first. Otherwise, the room gets added to - the end. - is_public - """ - room_in_group = self.db_pool.simple_select_one_onecol_txn( - txn, - table="group_rooms", - keyvalues={"group_id": group_id, "room_id": room_id}, - retcol="room_id", - allow_none=True, - ) - if not room_in_group: - raise SynapseError(400, "room not in group") - - if category_id is None: - category_id = _DEFAULT_CATEGORY_ID - else: - cat_exists = self.db_pool.simple_select_one_onecol_txn( - txn, - table="group_room_categories", - keyvalues={"group_id": group_id, "category_id": category_id}, - retcol="group_id", - allow_none=True, - ) - if not cat_exists: - raise SynapseError(400, "Category doesn't exist") - - # TODO: Check category is part of summary already - cat_exists = self.db_pool.simple_select_one_onecol_txn( - txn, - table="group_summary_room_categories", - keyvalues={"group_id": group_id, "category_id": category_id}, - retcol="group_id", - allow_none=True, - ) - if not cat_exists: - # If not, add it with an order larger than all others - txn.execute( - """ - INSERT INTO group_summary_room_categories - (group_id, category_id, cat_order) - SELECT ?, ?, COALESCE(MAX(cat_order), 0) + 1 - FROM group_summary_room_categories - WHERE group_id = ? AND category_id = ? - """, - (group_id, category_id, group_id, category_id), - ) - - existing = self.db_pool.simple_select_one_txn( - txn, - table="group_summary_rooms", - keyvalues={ - "group_id": group_id, - "room_id": room_id, - "category_id": category_id, - }, - retcols=("room_order", "is_public"), - allow_none=True, - ) - - if order is not None: - # Shuffle other room orders that come after the given order - sql = """ - UPDATE group_summary_rooms SET room_order = room_order + 1 - WHERE group_id = ? AND category_id = ? AND room_order >= ? - """ - txn.execute(sql, (group_id, category_id, order)) - elif not existing: - sql = """ - SELECT COALESCE(MAX(room_order), 0) + 1 FROM group_summary_rooms - WHERE group_id = ? AND category_id = ? - """ - txn.execute(sql, (group_id, category_id)) - (order,) = cast(Tuple[int], txn.fetchone()) - - if existing: - to_update = {} - if order is not None: - to_update["room_order"] = order - if is_public is not None: - to_update["is_public"] = is_public - self.db_pool.simple_update_txn( - txn, - table="group_summary_rooms", - keyvalues={ - "group_id": group_id, - "category_id": category_id, - "room_id": room_id, - }, - updatevalues=to_update, - ) - else: - if is_public is None: - is_public = True - - self.db_pool.simple_insert_txn( - txn, - table="group_summary_rooms", - values={ - "group_id": group_id, - "category_id": category_id, - "room_id": room_id, - "room_order": order, - "is_public": is_public, - }, - ) - - async def remove_room_from_summary( - self, group_id: str, room_id: str, category_id: Optional[str] - ) -> int: - if category_id is None: - category_id = _DEFAULT_CATEGORY_ID - - return await self.db_pool.simple_delete( - table="group_summary_rooms", - keyvalues={ - "group_id": group_id, - "category_id": category_id, - "room_id": room_id, - }, - desc="remove_room_from_summary", - ) - - async def upsert_group_category( - self, - group_id: str, - category_id: str, - profile: Optional[JsonDict], - is_public: Optional[bool], - ) -> None: - """Add/update room category for group""" - insertion_values: JsonDict = {} - update_values: JsonDict = {"category_id": category_id} # This cannot be empty - - if profile is None: - insertion_values["profile"] = "{}" - else: - update_values["profile"] = json_encoder.encode(profile) - - if is_public is None: - insertion_values["is_public"] = True - else: - update_values["is_public"] = is_public - - await self.db_pool.simple_upsert( - table="group_room_categories", - keyvalues={"group_id": group_id, "category_id": category_id}, - values=update_values, - insertion_values=insertion_values, - desc="upsert_group_category", - ) - - async def remove_group_category(self, group_id: str, category_id: str) -> int: - return await self.db_pool.simple_delete( - table="group_room_categories", - keyvalues={"group_id": group_id, "category_id": category_id}, - desc="remove_group_category", - ) - - async def upsert_group_role( - self, - group_id: str, - role_id: str, - profile: Optional[JsonDict], - is_public: Optional[bool], - ) -> None: - """Add/remove user role""" - insertion_values: JsonDict = {} - update_values: JsonDict = {"role_id": role_id} # This cannot be empty - - if profile is None: - insertion_values["profile"] = "{}" - else: - update_values["profile"] = json_encoder.encode(profile) - - if is_public is None: - insertion_values["is_public"] = True - else: - update_values["is_public"] = is_public - - await self.db_pool.simple_upsert( - table="group_roles", - keyvalues={"group_id": group_id, "role_id": role_id}, - values=update_values, - insertion_values=insertion_values, - desc="upsert_group_role", - ) - - async def remove_group_role(self, group_id: str, role_id: str) -> int: - return await self.db_pool.simple_delete( - table="group_roles", - keyvalues={"group_id": group_id, "role_id": role_id}, - desc="remove_group_role", - ) - - async def add_user_to_summary( - self, - group_id: str, - user_id: str, - role_id: Optional[str], - order: Optional[int], - is_public: Optional[bool], - ) -> None: - """Add (or update) user's entry in summary. - - Args: - group_id - user_id - role_id: If not None then adds the role to the end of the summary if - its not already there. - order: If not None inserts the user at that position, e.g. an order - of 1 will put the user first. Otherwise, the user gets added to - the end. - is_public - """ - await self.db_pool.runInteraction( - "add_user_to_summary", - self._add_user_to_summary_txn, - group_id, - user_id, - role_id, - order, - is_public, - ) - - def _add_user_to_summary_txn( - self, - txn: LoggingTransaction, - group_id: str, - user_id: str, - role_id: Optional[str], - order: Optional[int], - is_public: Optional[bool], - ) -> None: - """Add (or update) user's entry in summary. - - Args: - txn - group_id - user_id - role_id: If not None then adds the role to the end of the summary if - its not already there. - order: If not None inserts the user at that position, e.g. an order - of 1 will put the user first. Otherwise, the user gets added to - the end. - is_public - """ - user_in_group = self.db_pool.simple_select_one_onecol_txn( - txn, - table="group_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcol="user_id", - allow_none=True, - ) - if not user_in_group: - raise SynapseError(400, "user not in group") - - if role_id is None: - role_id = _DEFAULT_ROLE_ID - else: - role_exists = self.db_pool.simple_select_one_onecol_txn( - txn, - table="group_roles", - keyvalues={"group_id": group_id, "role_id": role_id}, - retcol="group_id", - allow_none=True, - ) - if not role_exists: - raise SynapseError(400, "Role doesn't exist") - - # TODO: Check role is part of the summary already - role_exists = self.db_pool.simple_select_one_onecol_txn( - txn, - table="group_summary_roles", - keyvalues={"group_id": group_id, "role_id": role_id}, - retcol="group_id", - allow_none=True, - ) - if not role_exists: - # If not, add it with an order larger than all others - txn.execute( - """ - INSERT INTO group_summary_roles - (group_id, role_id, role_order) - SELECT ?, ?, COALESCE(MAX(role_order), 0) + 1 - FROM group_summary_roles - WHERE group_id = ? AND role_id = ? - """, - (group_id, role_id, group_id, role_id), - ) - - existing = self.db_pool.simple_select_one_txn( - txn, - table="group_summary_users", - keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id}, - retcols=("user_order", "is_public"), - allow_none=True, - ) - - if order is not None: - # Shuffle other users orders that come after the given order - sql = """ - UPDATE group_summary_users SET user_order = user_order + 1 - WHERE group_id = ? AND role_id = ? AND user_order >= ? - """ - txn.execute(sql, (group_id, role_id, order)) - elif not existing: - sql = """ - SELECT COALESCE(MAX(user_order), 0) + 1 FROM group_summary_users - WHERE group_id = ? AND role_id = ? - """ - txn.execute(sql, (group_id, role_id)) - (order,) = cast(Tuple[int], txn.fetchone()) - - if existing: - to_update = {} - if order is not None: - to_update["user_order"] = order - if is_public is not None: - to_update["is_public"] = is_public - self.db_pool.simple_update_txn( - txn, - table="group_summary_users", - keyvalues={ - "group_id": group_id, - "role_id": role_id, - "user_id": user_id, - }, - updatevalues=to_update, - ) - else: - if is_public is None: - is_public = True - - self.db_pool.simple_insert_txn( - txn, - table="group_summary_users", - values={ - "group_id": group_id, - "role_id": role_id, - "user_id": user_id, - "user_order": order, - "is_public": is_public, - }, - ) - - async def remove_user_from_summary( - self, group_id: str, user_id: str, role_id: Optional[str] - ) -> int: - if role_id is None: - role_id = _DEFAULT_ROLE_ID - - return await self.db_pool.simple_delete( - table="group_summary_users", - keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id}, - desc="remove_user_from_summary", - ) - - async def add_group_invite(self, group_id: str, user_id: str) -> None: - """Record that the group server has invited a user""" - await self.db_pool.simple_insert( - table="group_invites", - values={"group_id": group_id, "user_id": user_id}, - desc="add_group_invite", - ) - - async def add_user_to_group( - self, - group_id: str, - user_id: str, - is_admin: bool = False, - is_public: bool = True, - local_attestation: Optional[dict] = None, - remote_attestation: Optional[dict] = None, - ) -> None: - """Add a user to the group server. - - Args: - group_id - user_id - is_admin - is_public - local_attestation: The attestation the GS created to give to the remote - server. Optional if the user and group are on the same server - remote_attestation: The attestation given to GS by remote server. - Optional if the user and group are on the same server - """ - - def _add_user_to_group_txn(txn: LoggingTransaction) -> None: - self.db_pool.simple_insert_txn( - txn, - table="group_users", - values={ - "group_id": group_id, - "user_id": user_id, - "is_admin": is_admin, - "is_public": is_public, - }, - ) - - self.db_pool.simple_delete_txn( - txn, - table="group_invites", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - - if local_attestation: - self.db_pool.simple_insert_txn( - txn, - table="group_attestations_renewals", - values={ - "group_id": group_id, - "user_id": user_id, - "valid_until_ms": local_attestation["valid_until_ms"], - }, - ) - if remote_attestation: - self.db_pool.simple_insert_txn( - txn, - table="group_attestations_remote", - values={ - "group_id": group_id, - "user_id": user_id, - "valid_until_ms": remote_attestation["valid_until_ms"], - "attestation_json": json_encoder.encode(remote_attestation), - }, - ) - - await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn) - - async def remove_user_from_group(self, group_id: str, user_id: str) -> None: - def _remove_user_from_group_txn(txn: LoggingTransaction) -> None: - self.db_pool.simple_delete_txn( - txn, - table="group_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - self.db_pool.simple_delete_txn( - txn, - table="group_invites", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - self.db_pool.simple_delete_txn( - txn, - table="group_attestations_renewals", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - self.db_pool.simple_delete_txn( - txn, - table="group_attestations_remote", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - self.db_pool.simple_delete_txn( - txn, - table="group_summary_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - - await self.db_pool.runInteraction( - "remove_user_from_group", _remove_user_from_group_txn - ) - - async def add_room_to_group( - self, group_id: str, room_id: str, is_public: bool - ) -> None: - await self.db_pool.simple_insert( - table="group_rooms", - values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, - desc="add_room_to_group", - ) - - async def update_room_in_group_visibility( - self, group_id: str, room_id: str, is_public: bool - ) -> int: - return await self.db_pool.simple_update( - table="group_rooms", - keyvalues={"group_id": group_id, "room_id": room_id}, - updatevalues={"is_public": is_public}, - desc="update_room_in_group_visibility", - ) - - async def remove_room_from_group(self, group_id: str, room_id: str) -> None: - def _remove_room_from_group_txn(txn: LoggingTransaction) -> None: - self.db_pool.simple_delete_txn( - txn, - table="group_rooms", - keyvalues={"group_id": group_id, "room_id": room_id}, - ) - - self.db_pool.simple_delete_txn( - txn, - table="group_summary_rooms", - keyvalues={"group_id": group_id, "room_id": room_id}, - ) - - await self.db_pool.runInteraction( - "remove_room_from_group", _remove_room_from_group_txn - ) - - async def update_group_publicity( - self, group_id: str, user_id: str, publicise: bool - ) -> None: - """Update whether the user is publicising their membership of the group""" - await self.db_pool.simple_update_one( - table="local_group_membership", - keyvalues={"group_id": group_id, "user_id": user_id}, - updatevalues={"is_publicised": publicise}, - desc="update_group_publicity", - ) - - async def register_user_group_membership( - self, - group_id: str, - user_id: str, - membership: str, - is_admin: bool = False, - content: Optional[JsonDict] = None, - local_attestation: Optional[dict] = None, - remote_attestation: Optional[dict] = None, - is_publicised: bool = False, - ) -> int: - """Registers that a local user is a member of a (local or remote) group. - - Args: - group_id: The group the member is being added to. - user_id: THe user ID to add to the group. - membership: The type of group membership. - is_admin: Whether the user should be added as a group admin. - content: Content of the membership, e.g. includes the inviter - if the user has been invited. - local_attestation: If remote group then store the fact that we - have given out an attestation, else None. - remote_attestation: If remote group then store the remote - attestation from the group, else None. - is_publicised: Whether this should be publicised. - """ - - content = content or {} - - def _register_user_group_membership_txn( - txn: LoggingTransaction, next_id: int - ) -> int: - # TODO: Upsert? - self.db_pool.simple_delete_txn( - txn, - table="local_group_membership", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - self.db_pool.simple_insert_txn( - txn, - table="local_group_membership", - values={ - "group_id": group_id, - "user_id": user_id, - "is_admin": is_admin, - "membership": membership, - "is_publicised": is_publicised, - "content": json_encoder.encode(content), - }, - ) - - self.db_pool.simple_insert_txn( - txn, - table="local_group_updates", - values={ - "stream_id": next_id, - "group_id": group_id, - "user_id": user_id, - "type": "membership", - "content": json_encoder.encode( - {"membership": membership, "content": content} - ), - }, - ) - self._group_updates_stream_cache.entity_has_changed(user_id, next_id) # type: ignore[attr-defined] - - # TODO: Insert profile to ensure it comes down stream if its a join. - - if membership == "join": - if local_attestation: - self.db_pool.simple_insert_txn( - txn, - table="group_attestations_renewals", - values={ - "group_id": group_id, - "user_id": user_id, - "valid_until_ms": local_attestation["valid_until_ms"], - }, - ) - if remote_attestation: - self.db_pool.simple_insert_txn( - txn, - table="group_attestations_remote", - values={ - "group_id": group_id, - "user_id": user_id, - "valid_until_ms": remote_attestation["valid_until_ms"], - "attestation_json": json_encoder.encode(remote_attestation), - }, - ) - else: - self.db_pool.simple_delete_txn( - txn, - table="group_attestations_renewals", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - self.db_pool.simple_delete_txn( - txn, - table="group_attestations_remote", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - - return next_id - - async with self._group_updates_id_gen.get_next() as next_id: # type: ignore[attr-defined] - res = await self.db_pool.runInteraction( - "register_user_group_membership", - _register_user_group_membership_txn, - next_id, - ) - return res - - async def create_group( - self, - group_id: str, - user_id: str, - name: str, - avatar_url: str, - short_description: str, - long_description: str, - ) -> None: - await self.db_pool.simple_insert( - table="groups", - values={ - "group_id": group_id, - "name": name, - "avatar_url": avatar_url, - "short_description": short_description, - "long_description": long_description, - "is_public": True, - }, - desc="create_group", - ) - - async def update_group_profile(self, group_id: str, profile: JsonDict) -> None: - await self.db_pool.simple_update_one( - table="groups", - keyvalues={"group_id": group_id}, - updatevalues=profile, - desc="update_group_profile", - ) - - async def update_attestation_renewal( - self, group_id: str, user_id: str, attestation: dict - ) -> None: - """Update an attestation that we have renewed""" - await self.db_pool.simple_update_one( - table="group_attestations_renewals", - keyvalues={"group_id": group_id, "user_id": user_id}, - updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, - desc="update_attestation_renewal", - ) - - async def update_remote_attestion( - self, group_id: str, user_id: str, attestation: dict - ) -> None: - """Update an attestation that a remote has renewed""" - await self.db_pool.simple_update_one( - table="group_attestations_remote", - keyvalues={"group_id": group_id, "user_id": user_id}, - updatevalues={ - "valid_until_ms": attestation["valid_until_ms"], - "attestation_json": json_encoder.encode(attestation), - }, - desc="update_remote_attestion", - ) - - async def remove_attestation_renewal(self, group_id: str, user_id: str) -> int: - """Remove an attestation that we thought we should renew, but actually - shouldn't. Ideally this would never get called as we would never - incorrectly try and do attestations for local users on local groups. - - Args: - group_id - user_id - """ - return await self.db_pool.simple_delete( - table="group_attestations_renewals", - keyvalues={"group_id": group_id, "user_id": user_id}, - desc="remove_attestation_renewal", - ) - - def get_group_stream_token(self) -> int: - return self._group_updates_id_gen.get_current_token() # type: ignore[attr-defined] - - async def delete_group(self, group_id: str) -> None: - """Deletes a group fully from the database. - - Args: - group_id: The group ID to delete. - """ - - def _delete_group_txn(txn: LoggingTransaction) -> None: - tables = [ - "groups", - "group_users", - "group_invites", - "group_rooms", - "group_summary_rooms", - "group_summary_room_categories", - "group_room_categories", - "group_summary_users", - "group_summary_roles", - "group_roles", - "group_attestations_renewals", - "group_attestations_remote", - ] - - for table in tables: - self.db_pool.simple_delete_txn( - txn, table=table, keyvalues={"group_id": group_id} - ) - - await self.db_pool.runInteraction("delete_group", _delete_group_txn) diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py index bedacaf0..2d7633fb 100644 --- a/synapse/storage/databases/main/lock.py +++ b/synapse/storage/databases/main/lock.py @@ -13,7 +13,7 @@ # limitations under the License. import logging from types import TracebackType -from typing import TYPE_CHECKING, Optional, Tuple, Type +from typing import TYPE_CHECKING, Optional, Set, Tuple, Type from weakref import WeakValueDictionary from twisted.internet.interfaces import IReactorCore @@ -84,6 +84,8 @@ class LockStore(SQLBaseStore): self._on_shutdown, ) + self._acquiring_locks: Set[Tuple[str, str]] = set() + @wrap_as_background_process("LockStore._on_shutdown") async def _on_shutdown(self) -> None: """Called when the server is shutting down""" @@ -103,6 +105,21 @@ class LockStore(SQLBaseStore): context manager if the lock is successfully acquired, which *must* be used (otherwise the lock will leak). """ + if (lock_name, lock_key) in self._acquiring_locks: + return None + try: + self._acquiring_locks.add((lock_name, lock_key)) + return await self._try_acquire_lock(lock_name, lock_key) + finally: + self._acquiring_locks.discard((lock_name, lock_key)) + + async def _try_acquire_lock( + self, lock_name: str, lock_key: str + ) -> Optional["Lock"]: + """Try to acquire a lock for the given name/key. Will return an async + context manager if the lock is successfully acquired, which *must* be + used (otherwise the lock will leak). + """ # Check if this process has taken out a lock and if it's still valid. lock = self._live_tokens.get((lock_name, lock_key)) diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 40ac377c..d028be16 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -251,12 +251,36 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "get_local_media_by_user_paginate_txn", get_local_media_by_user_paginate_txn ) - async def get_local_media_before( + async def get_local_media_ids( self, before_ts: int, size_gt: int, keep_profiles: bool, + include_quarantined_media: bool, + include_protected_media: bool, ) -> List[str]: + """ + Retrieve a list of media IDs from the local media store. + + Args: + before_ts: Only retrieve IDs from media that was either last accessed + (or if never accessed, created) before the given UNIX timestamp in ms. + size_gt: Only retrieve IDs from media that has a size (in bytes) greater than + the given integer. + keep_profiles: If True, exclude media IDs from the results that are used in the + following situations: + * global profile user avatar + * per-room profile user avatar + * room avatar + * a user's avatar in the user directory + include_quarantined_media: If False, exclude media IDs from the results that have + been marked as quarantined. + include_protected_media: If False, exclude media IDs from the results that have + been marked as protected from quarantine. + + Returns: + A list of local media IDs. + """ # to find files that have never been accessed (last_access_ts IS NULL) # compare with `created_ts` @@ -278,10 +302,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): WHERE profiles.avatar_url = '{media_prefix}' || lmr.media_id) AND NOT EXISTS (SELECT 1 - FROM groups - WHERE groups.avatar_url = '{media_prefix}' || lmr.media_id) - AND NOT EXISTS - (SELECT 1 FROM room_memberships WHERE room_memberships.avatar_url = '{media_prefix}' || lmr.media_id) AND NOT EXISTS @@ -298,12 +318,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) sql += sql_keep - def _get_local_media_before_txn(txn: LoggingTransaction) -> List[str]: + if include_quarantined_media is False: + # Do not include media that has been quarantined + sql += """ + AND quarantined_by IS NULL + """ + + if include_protected_media is False: + # Do not include media that has been protected from quarantine + sql += """ + AND NOT safe_from_quarantine + """ + + def _get_local_media_ids_txn(txn: LoggingTransaction) -> List[str]: txn.execute(sql, (before_ts, before_ts, size_gt)) return [row[0] for row in txn] return await self.db_pool.runInteraction( - "get_local_media_before", _get_local_media_before_txn + "get_local_media_ids", _get_local_media_ids_txn ) async def store_local_media( @@ -603,15 +635,37 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_remote_media_thumbnail", ) - async def get_remote_media_before(self, before_ts: int) -> List[Dict[str, str]]: + async def get_remote_media_ids( + self, before_ts: int, include_quarantined_media: bool + ) -> List[Dict[str, str]]: + """ + Retrieve a list of server name, media ID tuples from the remote media cache. + + Args: + before_ts: Only retrieve IDs from media that was either last accessed + (or if never accessed, created) before the given UNIX timestamp in ms. + include_quarantined_media: If False, exclude media IDs from the results that have + been marked as quarantined. + + Returns: + A list of tuples containing: + * The server name of homeserver where the media originates from, + * The ID of the media. + """ sql = ( "SELECT media_origin, media_id, filesystem_id" " FROM remote_media_cache" " WHERE last_access_ts < ?" ) + if include_quarantined_media is False: + # Only include media that has not been quarantined + sql += """ + AND quarantined_by IS NULL + """ + return await self.db_pool.execute( - "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts + "get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts ) async def delete_remote_media(self, media_origin: str, media_id: str) -> None: diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index 1480a0f0..14294a0b 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -14,12 +14,16 @@ import calendar import logging import time -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING, Dict, List, Tuple, cast from synapse.metrics import GaugeBucketCollector from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) @@ -71,8 +75,8 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): self._last_user_visit_update = self._get_start_of_day() @wrap_as_background_process("read_forward_extremities") - async def _read_forward_extremities(self): - def fetch(txn): + async def _read_forward_extremities(self) -> None: + def fetch(txn: LoggingTransaction) -> List[Tuple[int, int]]: txn.execute( """ SELECT t1.c, t2.c @@ -85,7 +89,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): ) t2 ON t1.room_id = t2.room_id """ ) - return txn.fetchall() + return cast(List[Tuple[int, int]], txn.fetchall()) res = await self.db_pool.runInteraction("read_forward_extremities", fetch) @@ -95,7 +99,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): (x[0] - 1) * x[1] for x in res if x[1] ) - async def count_daily_e2ee_messages(self): + async def count_daily_e2ee_messages(self) -> int: """ Returns an estimate of the number of messages sent in the last day. @@ -103,20 +107,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): call to this function, it will return None. """ - def _count_messages(txn): + def _count_messages(txn: LoggingTransaction) -> int: sql = """ SELECT COUNT(*) FROM events WHERE type = 'm.room.encrypted' AND stream_ordering > ? """ txn.execute(sql, (self.stream_ordering_day_ago,)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages) - async def count_daily_sent_e2ee_messages(self): - def _count_messages(txn): + async def count_daily_sent_e2ee_messages(self) -> int: + def _count_messages(txn: LoggingTransaction) -> int: # This is good enough as if you have silly characters in your own # hostname then that's your own fault. like_clause = "%:" + self.hs.hostname @@ -129,29 +133,29 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): """ txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count return await self.db_pool.runInteraction( "count_daily_sent_e2ee_messages", _count_messages ) - async def count_daily_active_e2ee_rooms(self): - def _count(txn): + async def count_daily_active_e2ee_rooms(self) -> int: + def _count(txn: LoggingTransaction) -> int: sql = """ SELECT COUNT(DISTINCT room_id) FROM events WHERE type = 'm.room.encrypted' AND stream_ordering > ? """ txn.execute(sql, (self.stream_ordering_day_ago,)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count return await self.db_pool.runInteraction( "count_daily_active_e2ee_rooms", _count ) - async def count_daily_messages(self): + async def count_daily_messages(self) -> int: """ Returns an estimate of the number of messages sent in the last day. @@ -159,20 +163,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): call to this function, it will return None. """ - def _count_messages(txn): + def _count_messages(txn: LoggingTransaction) -> int: sql = """ SELECT COUNT(*) FROM events WHERE type = 'm.room.message' AND stream_ordering > ? """ txn.execute(sql, (self.stream_ordering_day_ago,)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count return await self.db_pool.runInteraction("count_messages", _count_messages) - async def count_daily_sent_messages(self): - def _count_messages(txn): + async def count_daily_sent_messages(self) -> int: + def _count_messages(txn: LoggingTransaction) -> int: # This is good enough as if you have silly characters in your own # hostname then that's your own fault. like_clause = "%:" + self.hs.hostname @@ -185,22 +189,22 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): """ txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count return await self.db_pool.runInteraction( "count_daily_sent_messages", _count_messages ) - async def count_daily_active_rooms(self): - def _count(txn): + async def count_daily_active_rooms(self) -> int: + def _count(txn: LoggingTransaction) -> int: sql = """ SELECT COUNT(DISTINCT room_id) FROM events WHERE type = 'm.room.message' AND stream_ordering > ? """ txn.execute(sql, (self.stream_ordering_day_ago,)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count return await self.db_pool.runInteraction("count_daily_active_rooms", _count) @@ -226,7 +230,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): "count_monthly_users", self._count_users, thirty_days_ago ) - def _count_users(self, txn, time_from): + def _count_users(self, txn: LoggingTransaction, time_from: int) -> int: """ Returns number of users seen in the past time_from period """ @@ -238,7 +242,10 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): ) u """ txn.execute(sql, (time_from,)) - (count,) = txn.fetchone() + # Mypy knows that fetchone() might return None if there are no rows. + # We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always + # returns exactly one row. + (count,) = cast(Tuple[int], txn.fetchone()) return count async def count_r30_users(self) -> Dict[str, int]: @@ -252,7 +259,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): A mapping of counts globally as well as broken out by platform. """ - def _count_r30_users(txn): + def _count_r30_users(txn: LoggingTransaction) -> Dict[str, int]: thirty_days_in_secs = 86400 * 30 now = int(self._clock.time()) thirty_days_ago_in_secs = now - thirty_days_in_secs @@ -317,7 +324,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) results["all"] = count return results @@ -344,7 +351,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): - "web" (any web application -- it's not possible to distinguish Element Web here) """ - def _count_r30v2_users(txn): + def _count_r30v2_users(txn: LoggingTransaction) -> Dict[str, int]: thirty_days_in_secs = 86400 * 30 now = int(self._clock.time()) sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs @@ -441,11 +448,8 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): thirty_days_in_secs * 1000, ), ) - row = txn.fetchone() - if row is None: - results["all"] = 0 - else: - results["all"] = row[0] + (count,) = cast(Tuple[int], txn.fetchone()) + results["all"] = count return results @@ -453,7 +457,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): "count_r30v2_users", _count_r30v2_users ) - def _get_start_of_day(self): + def _get_start_of_day(self) -> int: """ Returns millisecond unixtime for start of UTC day. """ @@ -467,7 +471,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): Generates daily visit data for use in cohort/ retention analysis """ - def _generate_user_daily_visits(txn): + def _generate_user_daily_visits(txn: LoggingTransaction) -> None: logger.info("Calling _generate_user_daily_visits") today_start = self._get_start_of_day() a_day_in_milliseconds = 24 * 60 * 60 * 1000 diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index 5beb8f1d..9a63f953 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -122,6 +122,51 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): "count_users_by_service", _count_users_by_service ) + async def get_monthly_active_users_by_service( + self, start_timestamp: Optional[int] = None, end_timestamp: Optional[int] = None + ) -> List[Tuple[str, str]]: + """Generates list of monthly active users and their services. + Please see "get_monthly_active_count_by_service" docstring for more details + about services. + + Arguments: + start_timestamp: If specified, only include users that were first active + at or after this point + end_timestamp: If specified, only include users that were first active + at or before this point + + Returns: + A list of tuples (appservice_id, user_id). "native" is emitted as the + appservice for users that don't come from appservices (i.e. native Matrix + users). + + """ + if start_timestamp is not None and end_timestamp is not None: + where_clause = 'WHERE "timestamp" >= ? and "timestamp" <= ?' + query_params = [start_timestamp, end_timestamp] + elif start_timestamp is not None: + where_clause = 'WHERE "timestamp" >= ?' + query_params = [start_timestamp] + elif end_timestamp is not None: + where_clause = 'WHERE "timestamp" <= ?' + query_params = [end_timestamp] + else: + where_clause = "" + query_params = [] + + def _list_users(txn: LoggingTransaction) -> List[Tuple[str, str]]: + sql = f""" + SELECT COALESCE(appservice_id, 'native'), user_id + FROM monthly_active_users + LEFT JOIN users ON monthly_active_users.user_id=users.name + {where_clause}; + """ + + txn.execute(sql, query_params) + return cast(List[Tuple[str, str]], txn.fetchall()) + + return await self.db_pool.runInteraction("list_users", _list_users) + async def get_registered_reserved_users(self) -> List[str]: """Of the reserved threepids defined in config, retrieve those that are associated with registered users diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index b47c5114..9769a18a 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, cast +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, cast from synapse.api.presence import PresenceState, UserPresenceState from synapse.replication.tcp.streams import PresenceStream @@ -22,6 +22,7 @@ from synapse.storage.database import ( LoggingDatabaseConnection, LoggingTransaction, ) +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.types import Connection from synapse.storage.util.id_generators import ( @@ -56,7 +57,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore): ) -class PresenceStore(PresenceBackgroundUpdateStore): +class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore): def __init__( self, database: DatabasePool, @@ -281,20 +282,30 @@ class PresenceStore(PresenceBackgroundUpdateStore): True if the user should have full presence sent to them, False otherwise. """ - def _should_user_receive_full_presence_with_token_txn( - txn: LoggingTransaction, - ) -> bool: - sql = """ - SELECT 1 FROM users_to_send_full_presence_to - WHERE user_id = ? - AND presence_stream_id >= ? - """ - txn.execute(sql, (user_id, from_token)) - return bool(txn.fetchone()) + token = await self._get_full_presence_stream_token_for_user(user_id) + if token is None: + return False - return await self.db_pool.runInteraction( - "should_user_receive_full_presence_with_token", - _should_user_receive_full_presence_with_token_txn, + return from_token <= token + + @cached() + async def _get_full_presence_stream_token_for_user( + self, user_id: str + ) -> Optional[int]: + """Get the presence token corresponding to the last full presence update + for this user. + + If the user presents a sync token with a presence stream token at least + as old as the result, then we need to send them a full presence update. + + If this user has never needed a full presence update, returns `None`. + """ + return await self.db_pool.simple_select_one_onecol( + table="users_to_send_full_presence_to", + keyvalues={"user_id": user_id}, + retcol="presence_stream_id", + allow_none=True, + desc="_get_full_presence_stream_token_for_user", ) async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None: @@ -307,18 +318,28 @@ class PresenceStore(PresenceBackgroundUpdateStore): # Add user entries to the table, updating the presence_stream_id column if the user already # exists in the table. presence_stream_id = self._presence_id_gen.get_current_token() - await self.db_pool.simple_upsert_many( - table="users_to_send_full_presence_to", - key_names=("user_id",), - key_values=[(user_id,) for user_id in user_ids], - value_names=("presence_stream_id",), - # We save the current presence stream ID token along with the user ID entry so - # that when a user /sync's, even if they syncing multiple times across separate - # devices at different times, each device will receive full presence once - when - # the presence stream ID in their sync token is less than the one in the table - # for their user ID. - value_values=[(presence_stream_id,) for _ in user_ids], - desc="add_users_to_send_full_presence_to", + + def _add_users_to_send_full_presence_to(txn: LoggingTransaction) -> None: + self.db_pool.simple_upsert_many_txn( + txn, + table="users_to_send_full_presence_to", + key_names=("user_id",), + key_values=[(user_id,) for user_id in user_ids], + value_names=("presence_stream_id",), + # We save the current presence stream ID token along with the user ID entry so + # that when a user /sync's, even if they syncing multiple times across separate + # devices at different times, each device will receive full presence once - when + # the presence stream ID in their sync token is less than the one in the table + # for their user ID. + value_values=[(presence_stream_id,) for _ in user_ids], + ) + for user_id in user_ids: + self._invalidate_cache_and_stream( + txn, self._get_full_presence_stream_token_for_user, (user_id,) + ) + + return await self.db_pool.runInteraction( + "add_users_to_send_full_presence_to", _add_users_to_send_full_presence_to ) async def get_presence_for_all_users( diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index e197b720..a1747f04 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -11,11 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional +from typing import Optional from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo @@ -55,17 +54,6 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_profile_avatar_url", ) - async def get_from_remote_profile_cache( - self, user_id: str - ) -> Optional[Dict[str, Any]]: - return await self.db_pool.simple_select_one( - table="remote_profile_cache", - keyvalues={"user_id": user_id}, - retcols=("displayname", "avatar_url"), - allow_none=True, - desc="get_from_remote_profile_cache", - ) - async def create_profile(self, user_localpart: str) -> None: await self.db_pool.simple_insert( table="profiles", values={"user_id": user_localpart}, desc="create_profile" @@ -91,97 +79,6 @@ class ProfileWorkerStore(SQLBaseStore): desc="set_profile_avatar_url", ) - async def update_remote_profile_cache( - self, user_id: str, displayname: Optional[str], avatar_url: Optional[str] - ) -> int: - return await self.db_pool.simple_update( - table="remote_profile_cache", - keyvalues={"user_id": user_id}, - updatevalues={ - "displayname": displayname, - "avatar_url": avatar_url, - "last_check": self._clock.time_msec(), - }, - desc="update_remote_profile_cache", - ) - - async def maybe_delete_remote_profile_cache(self, user_id: str) -> None: - """Check if we still care about the remote user's profile, and if we - don't then remove their profile from the cache - """ - subscribed = await self.is_subscribed_remote_profile_for_user(user_id) - if not subscribed: - await self.db_pool.simple_delete( - table="remote_profile_cache", - keyvalues={"user_id": user_id}, - desc="delete_remote_profile_cache", - ) - - async def is_subscribed_remote_profile_for_user(self, user_id: str) -> bool: - """Check whether we are interested in a remote user's profile.""" - res: Optional[str] = await self.db_pool.simple_select_one_onecol( - table="group_users", - keyvalues={"user_id": user_id}, - retcol="user_id", - allow_none=True, - desc="should_update_remote_profile_cache_for_user", - ) - - if res: - return True - - res = await self.db_pool.simple_select_one_onecol( - table="group_invites", - keyvalues={"user_id": user_id}, - retcol="user_id", - allow_none=True, - desc="should_update_remote_profile_cache_for_user", - ) - - if res: - return True - return False - - async def get_remote_profile_cache_entries_that_expire( - self, last_checked: int - ) -> List[Dict[str, str]]: - """Get all users who haven't been checked since `last_checked`""" - - def _get_remote_profile_cache_entries_that_expire_txn( - txn: LoggingTransaction, - ) -> List[Dict[str, str]]: - sql = """ - SELECT user_id, displayname, avatar_url - FROM remote_profile_cache - WHERE last_check < ? - """ - - txn.execute(sql, (last_checked,)) - - return self.db_pool.cursor_to_dict(txn) - - return await self.db_pool.runInteraction( - "get_remote_profile_cache_entries_that_expire", - _get_remote_profile_cache_entries_that_expire_txn, - ) - class ProfileStore(ProfileWorkerStore): - async def add_remote_profile_cache( - self, user_id: str, displayname: str, avatar_url: str - ) -> None: - """Ensure we are caching the remote user's profiles. - - This should only be called when `is_subscribed_remote_profile_for_user` - would return true for the user. - """ - await self.db_pool.simple_upsert( - table="remote_profile_cache", - keyvalues={"user_id": user_id}, - values={ - "displayname": displayname, - "avatar_url": avatar_url, - "last_check": self._clock.time_msec(), - }, - desc="add_remote_profile_cache", - ) + pass diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index bfc85b3a..ba385f9f 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -69,7 +69,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): # event_forward_extremities # event_json # event_push_actions - # event_reference_hashes # event_relations # event_search # event_to_state_groups @@ -220,7 +219,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "event_auth", "event_edges", "event_forward_extremities", - "event_reference_hashes", "event_relations", "event_search", "rejections", @@ -324,12 +322,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): ) def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]: - # We *immediately* delete the room from the rooms table. This ensures - # that we don't race when persisting events (as that transaction checks - # that the room exists). - txn.execute("DELETE FROM rooms WHERE room_id = ?", (room_id,)) - - # Next, we fetch all the state groups that should be deleted, before + # First, fetch all the state groups that should be deleted, before # we delete that information. txn.execute( """ @@ -369,7 +362,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "event_edges", "event_json", "event_push_actions_staging", - "event_reference_hashes", "event_relations", "event_to_state_groups", "event_auth_chains", @@ -390,7 +382,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): (room_id,), ) - # and finally, the tables with an index on room_id (or no useful index) + # next, the tables with an index on room_id (or no useful index) for table in ( "current_state_events", "destination_rooms", @@ -398,8 +390,12 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "event_forward_extremities", "event_push_actions", "event_search", + "partial_state_events", "events", - "group_rooms", + "federation_inbound_events_staging", + "local_current_membership", + "partial_state_rooms_servers", + "partial_state_rooms", "receipts_graph", "receipts_linearized", "room_aliases", @@ -416,10 +412,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "e2e_room_keys", "event_push_summary", "pusher_throttle", - "group_summary_rooms", "room_account_data", "room_tags", - "local_current_membership", + # "rooms" happens last, to keep the foreign keys in the other tables + # happy + "rooms", ): logger.info("[purge] removing %s from %s", room_id, table) txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,)) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 4ed913e2..d5aefe02 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -14,14 +14,18 @@ # limitations under the License. import abc import logging -from typing import TYPE_CHECKING, Dict, List, Tuple, Union +from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast from synapse.api.errors import StoreError from synapse.config.homeserver import ExperimentalConfig from synapse.push.baserules import list_with_base_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.pusher import PusherWorkerStore @@ -30,9 +34,12 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.storage.util.id_generators import ( + AbstractStreamIdGenerator, AbstractStreamIdTracker, + IdGenerator, StreamIdGenerator, ) +from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -54,10 +61,19 @@ def _is_experimental_rule_enabled( and not experimental_config.msc3786_enabled ): return False + if ( + rule_id == "global/underride/.org.matrix.msc3772.thread_reply" + and not experimental_config.msc3772_enabled + ): + return False return True -def _load_rules(rawrules, enabled_map, experimental_config: ExperimentalConfig): +def _load_rules( + rawrules: List[JsonDict], + enabled_map: Dict[str, bool], + experimental_config: ExperimentalConfig, +) -> List[JsonDict]: ruleslist = [] for rawrule in rawrules: rule = dict(rawrule) @@ -137,7 +153,7 @@ class PushRulesWorkerStore( ) @abc.abstractmethod - def get_max_push_rules_stream_id(self): + def get_max_push_rules_stream_id(self) -> int: """Get the position of the push rules stream. Returns: @@ -146,7 +162,7 @@ class PushRulesWorkerStore( raise NotImplementedError() @cached(max_entries=5000) - async def get_push_rules_for_user(self, user_id): + async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]: rows = await self.db_pool.simple_select_list( table="push_rules", keyvalues={"user_name": user_id}, @@ -158,7 +174,7 @@ class PushRulesWorkerStore( "conditions", "actions", ), - desc="get_push_rules_enabled_for_user", + desc="get_push_rules_for_user", ) rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) @@ -168,14 +184,14 @@ class PushRulesWorkerStore( return _load_rules(rows, enabled_map, self.hs.config.experimental) @cached(max_entries=5000) - async def get_push_rules_enabled_for_user(self, user_id) -> Dict[str, bool]: + async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: results = await self.db_pool.simple_select_list( table="push_rules_enable", keyvalues={"user_name": user_id}, - retcols=("user_name", "rule_id", "enabled"), + retcols=("rule_id", "enabled"), desc="get_push_rules_enabled_for_user", ) - return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} + return {r["rule_id"]: bool(r["enabled"]) for r in results} async def have_push_rules_changed_for_user( self, user_id: str, last_id: int @@ -184,29 +200,27 @@ class PushRulesWorkerStore( return False else: - def have_push_rules_changed_txn(txn): + def have_push_rules_changed_txn(txn: LoggingTransaction) -> bool: sql = ( "SELECT COUNT(stream_id) FROM push_rules_stream" " WHERE user_id = ? AND ? < stream_id" ) txn.execute(sql, (user_id, last_id)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return bool(count) return await self.db_pool.runInteraction( "have_push_rules_changed", have_push_rules_changed_txn ) - @cachedList( - cached_method_name="get_push_rules_for_user", - list_name="user_ids", - num_args=1, - ) - async def bulk_get_push_rules(self, user_ids): + @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids") + async def bulk_get_push_rules( + self, user_ids: Collection[str] + ) -> Dict[str, List[JsonDict]]: if not user_ids: return {} - results = {user_id: [] for user_id in user_ids} + results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids} rows = await self.db_pool.simple_select_many_batch( table="push_rules", @@ -230,67 +244,16 @@ class PushRulesWorkerStore( return results - async def copy_push_rule_from_room_to_room( - self, new_room_id: str, user_id: str, rule: dict - ) -> None: - """Copy a single push rule from one room to another for a specific user. - - Args: - new_room_id: ID of the new room. - user_id : ID of user the push rule belongs to. - rule: A push rule. - """ - # Create new rule id - rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1]) - new_rule_id = rule_id_scope + "/" + new_room_id - - # Change room id in each condition - for condition in rule.get("conditions", []): - if condition.get("key") == "room_id": - condition["pattern"] = new_room_id - - # Add the rule for the new room - await self.add_push_rule( - user_id=user_id, - rule_id=new_rule_id, - priority_class=rule["priority_class"], - conditions=rule["conditions"], - actions=rule["actions"], - ) - - async def copy_push_rules_from_room_to_room_for_user( - self, old_room_id: str, new_room_id: str, user_id: str - ) -> None: - """Copy all of the push rules from one room to another for a specific - user. - - Args: - old_room_id: ID of the old room. - new_room_id: ID of the new room. - user_id: ID of user to copy push rules for. - """ - # Retrieve push rules for this user - user_push_rules = await self.get_push_rules_for_user(user_id) - - # Get rules relating to the old room and copy them to the new room - for rule in user_push_rules: - conditions = rule.get("conditions", []) - if any( - (c.get("key") == "room_id" and c.get("pattern") == old_room_id) - for c in conditions - ): - await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule) - @cachedList( - cached_method_name="get_push_rules_enabled_for_user", - list_name="user_ids", - num_args=1, + cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids" ) - async def bulk_get_push_rules_enabled(self, user_ids): + async def bulk_get_push_rules_enabled( + self, user_ids: Collection[str] + ) -> Dict[str, Dict[str, bool]]: if not user_ids: return {} - results = {user_id: {} for user_id in user_ids} + results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids} rows = await self.db_pool.simple_select_many_batch( table="push_rules_enable", @@ -306,7 +269,7 @@ class PushRulesWorkerStore( async def get_all_push_rule_updates( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]: """Get updates for push_rules replication stream. Args: @@ -331,7 +294,9 @@ class PushRulesWorkerStore( if last_id == current_id: return [], current_id, False - def get_all_push_rule_updates_txn(txn): + def get_all_push_rule_updates_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]: sql = """ SELECT stream_id, user_id FROM push_rules_stream @@ -340,7 +305,10 @@ class PushRulesWorkerStore( LIMIT ? """ txn.execute(sql, (last_id, current_id, limit)) - updates = [(stream_id, (user_id,)) for stream_id, user_id in txn] + updates = cast( + List[Tuple[int, Tuple[str]]], + [(stream_id, (user_id,)) for stream_id, user_id in txn], + ) limited = False upper_bound = current_id @@ -356,15 +324,30 @@ class PushRulesWorkerStore( class PushRuleStore(PushRulesWorkerStore): + # Because we have write access, this will be a StreamIdGenerator + # (see PushRulesWorkerStore.__init__) + _push_rules_stream_id_gen: AbstractStreamIdGenerator + + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") + self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") + async def add_push_rule( self, - user_id, - rule_id, - priority_class, - conditions, - actions, - before=None, - after=None, + user_id: str, + rule_id: str, + priority_class: int, + conditions: List[Dict[str, str]], + actions: List[Union[JsonDict, str]], + before: Optional[str] = None, + after: Optional[str] = None, ) -> None: conditions_json = json_encoder.encode(conditions) actions_json = json_encoder.encode(actions) @@ -400,17 +383,17 @@ class PushRuleStore(PushRulesWorkerStore): def _add_push_rule_relative_txn( self, - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - conditions_json, - actions_json, - before, - after, - ): + txn: LoggingTransaction, + stream_id: int, + event_stream_ordering: int, + user_id: str, + rule_id: str, + priority_class: int, + conditions_json: str, + actions_json: str, + before: str, + after: str, + ) -> None: # Lock the table since otherwise we'll have annoying races between the # SELECT here and the UPSERT below. self.database_engine.lock_table(txn, "push_rules") @@ -470,15 +453,15 @@ class PushRuleStore(PushRulesWorkerStore): def _add_push_rule_highest_priority_txn( self, - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - conditions_json, - actions_json, - ): + txn: LoggingTransaction, + stream_id: int, + event_stream_ordering: int, + user_id: str, + rule_id: str, + priority_class: int, + conditions_json: str, + actions_json: str, + ) -> None: # Lock the table since otherwise we'll have annoying races between the # SELECT here and the UPSERT below. self.database_engine.lock_table(txn, "push_rules") @@ -510,17 +493,17 @@ class PushRuleStore(PushRulesWorkerStore): def _upsert_push_rule_txn( self, - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - priority, - conditions_json, - actions_json, - update_stream=True, - ): + txn: LoggingTransaction, + stream_id: int, + event_stream_ordering: int, + user_id: str, + rule_id: str, + priority_class: int, + priority: int, + conditions_json: str, + actions_json: str, + update_stream: bool = True, + ) -> None: """Specialised version of simple_upsert_txn that picks a push_rule_id using the _push_rule_id_gen if it needs to insert the rule. It assumes that the "push_rules" table is locked""" @@ -600,7 +583,11 @@ class PushRuleStore(PushRulesWorkerStore): rule_id: The rule_id of the rule to be deleted """ - def delete_push_rule_txn(txn, stream_id, event_stream_ordering): + def delete_push_rule_txn( + txn: LoggingTransaction, + stream_id: int, + event_stream_ordering: int, + ) -> None: # we don't use simple_delete_one_txn because that would fail if the # user did not have a push_rule_enable row. self.db_pool.simple_delete_txn( @@ -661,14 +648,14 @@ class PushRuleStore(PushRulesWorkerStore): def _set_push_rule_enabled_txn( self, - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - enabled, - is_default_rule, - ): + txn: LoggingTransaction, + stream_id: int, + event_stream_ordering: int, + user_id: str, + rule_id: str, + enabled: bool, + is_default_rule: bool, + ) -> None: new_id = self._push_rules_enable_id_gen.get_next() if not is_default_rule: @@ -740,7 +727,11 @@ class PushRuleStore(PushRulesWorkerStore): """ actions_json = json_encoder.encode(actions) - def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): + def set_push_rule_actions_txn( + txn: LoggingTransaction, + stream_id: int, + event_stream_ordering: int, + ) -> None: if is_default_rule: # Add a dummy rule to the rules table with the user specified # actions. @@ -794,8 +785,15 @@ class PushRuleStore(PushRulesWorkerStore): ) def _insert_push_rules_update_txn( - self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None - ): + self, + txn: LoggingTransaction, + stream_id: int, + event_stream_ordering: int, + user_id: str, + rule_id: str, + op: str, + data: Optional[JsonDict] = None, + ) -> None: values = { "stream_id": stream_id, "event_stream_ordering": event_stream_ordering, @@ -814,5 +812,56 @@ class PushRuleStore(PushRulesWorkerStore): self.push_rules_stream_cache.entity_has_changed, user_id, stream_id ) - def get_max_push_rules_stream_id(self): + def get_max_push_rules_stream_id(self) -> int: return self._push_rules_stream_id_gen.get_current_token() + + async def copy_push_rule_from_room_to_room( + self, new_room_id: str, user_id: str, rule: dict + ) -> None: + """Copy a single push rule from one room to another for a specific user. + + Args: + new_room_id: ID of the new room. + user_id : ID of user the push rule belongs to. + rule: A push rule. + """ + # Create new rule id + rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1]) + new_rule_id = rule_id_scope + "/" + new_room_id + + # Change room id in each condition + for condition in rule.get("conditions", []): + if condition.get("key") == "room_id": + condition["pattern"] = new_room_id + + # Add the rule for the new room + await self.add_push_rule( + user_id=user_id, + rule_id=new_rule_id, + priority_class=rule["priority_class"], + conditions=rule["conditions"], + actions=rule["actions"], + ) + + async def copy_push_rules_from_room_to_room_for_user( + self, old_room_id: str, new_room_id: str, user_id: str + ) -> None: + """Copy all of the push rules from one room to another for a specific + user. + + Args: + old_room_id: ID of the old room. + new_room_id: ID of the new room. + user_id: ID of user to copy push rules for. + """ + # Retrieve push rules for this user + user_push_rules = await self.get_push_rules_for_user(user_id) + + # Get rules relating to the old room and copy them to the new room + for rule in user_push_rules: + conditions = rule.get("conditions", []) + if any( + (c.get("key") == "room_id" and c.get("pattern") == old_room_id) + for c in conditions + ): + await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule) diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 91286c9b..bd0cfa7f 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -91,12 +91,6 @@ class PusherWorkerStore(SQLBaseStore): yield PusherConfig(**r) - async def user_has_pusher(self, user_id: str) -> bool: - ret = await self.db_pool.simple_select_one_onecol( - "pushers", {"user_name": user_id}, "id", allow_none=True - ) - return ret is not None - async def get_pushers_by_app_id_and_pushkey( self, app_id: str, pushkey: str ) -> Iterator[PusherConfig]: diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index d035969a..21e954cc 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -26,7 +26,7 @@ from typing import ( cast, ) -from synapse.api.constants import ReceiptTypes +from synapse.api.constants import EduTypes, ReceiptTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -363,7 +363,7 @@ class ReceiptsWorkerStore(SQLBaseStore): row["user_id"] ] = db_to_json(row["data"]) - return [{"type": "m.receipt", "room_id": room_id, "content": content}] + return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}] @cachedList( cached_method_name="_get_linearized_receipts_for_room", @@ -411,7 +411,7 @@ class ReceiptsWorkerStore(SQLBaseStore): # receipts by room, event and type. room_event = results.setdefault( row["room_id"], - {"type": "m.receipt", "room_id": row["room_id"], "content": {}}, + {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}}, ) # The content is of the form: @@ -476,7 +476,7 @@ class ReceiptsWorkerStore(SQLBaseStore): # receipts by room, event and type. room_event = results.setdefault( row["room_id"], - {"type": "m.receipt", "room_id": row["room_id"], "content": {}}, + {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}}, ) # The content is of the form: @@ -597,7 +597,7 @@ class ReceiptsWorkerStore(SQLBaseStore): return super().process_replication_rows(stream_name, instance_name, token, rows) - def insert_linearized_receipt_txn( + def _insert_linearized_receipt_txn( self, txn: LoggingTransaction, room_id: str, @@ -673,8 +673,11 @@ class ReceiptsWorkerStore(SQLBaseStore): lock=False, ) + # When updating a local users read receipt, remove any push actions + # which resulted from the receipt's event and all earlier events. if ( - receipt_type in (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE) + self.hs.is_mine_id(user_id) + and receipt_type in (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE) and stream_ordering is not None ): self._remove_old_push_actions_before_txn( # type: ignore[attr-defined] @@ -683,6 +686,44 @@ class ReceiptsWorkerStore(SQLBaseStore): return rx_ts + def _graph_to_linear( + self, txn: LoggingTransaction, room_id: str, event_ids: List[str] + ) -> str: + """ + Generate a linearized event from a list of events (i.e. a list of forward + extremities in the room). + + This should allow for calculation of the correct read receipt even if + servers have different event ordering. + + Args: + txn: The transaction + room_id: The room ID the events are in. + event_ids: The list of event IDs to linearize. + + Returns: + The linearized event ID. + """ + # TODO: Make this better. + clause, args = make_in_list_sql_clause( + self.database_engine, "event_id", event_ids + ) + + sql = """ + SELECT event_id WHERE room_id = ? AND stream_ordering IN ( + SELECT max(stream_ordering) WHERE %s + ) + """ % ( + clause, + ) + + txn.execute(sql, [room_id] + list(args)) + rows = txn.fetchall() + if rows: + return rows[0][0] + else: + raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) + async def insert_receipt( self, room_id: str, @@ -709,35 +750,14 @@ class ReceiptsWorkerStore(SQLBaseStore): linearized_event_id = event_ids[0] else: # we need to points in graph -> linearized form. - # TODO: Make this better. - def graph_to_linear(txn: LoggingTransaction) -> str: - clause, args = make_in_list_sql_clause( - self.database_engine, "event_id", event_ids - ) - - sql = """ - SELECT event_id WHERE room_id = ? AND stream_ordering IN ( - SELECT max(stream_ordering) WHERE %s - ) - """ % ( - clause, - ) - - txn.execute(sql, [room_id] + list(args)) - rows = txn.fetchall() - if rows: - return rows[0][0] - else: - raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) - linearized_event_id = await self.db_pool.runInteraction( - "insert_receipt_conv", graph_to_linear + "insert_receipt_conv", self._graph_to_linear, room_id, event_ids ) async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined] event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", - self.insert_linearized_receipt_txn, + self._insert_linearized_receipt_txn, room_id, receipt_type, user_id, @@ -758,25 +778,9 @@ class ReceiptsWorkerStore(SQLBaseStore): now - event_ts, ) - await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data) - - max_persisted_id = self._receipts_id_gen.get_current_token() - - return stream_id, max_persisted_id - - async def insert_graph_receipt( - self, - room_id: str, - receipt_type: str, - user_id: str, - event_ids: List[str], - data: JsonDict, - ) -> None: - assert self._can_write_to_receipts - await self.db_pool.runInteraction( "insert_graph_receipt", - self.insert_graph_receipt_txn, + self._insert_graph_receipt_txn, room_id, receipt_type, user_id, @@ -784,7 +788,11 @@ class ReceiptsWorkerStore(SQLBaseStore): data, ) - def insert_graph_receipt_txn( + max_persisted_id = self._receipts_id_gen.get_current_token() + + return stream_id, max_persisted_id + + def _insert_graph_receipt_txn( self, txn: LoggingTransaction, room_id: str, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 484976ca..b457bc18 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -34,7 +34,7 @@ from synapse.storage._base import SQLBaseStore from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.engines import PostgresEngine -from synapse.types import JsonDict, RoomStreamToken, StreamToken +from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken from synapse.util.caches.descriptors import cached, cachedList logger = logging.getLogger(__name__) @@ -161,7 +161,9 @@ class RelationsWorkerStore(SQLBaseStore): if len(events) > limit and last_topo_id and last_stream_id: next_key = RoomStreamToken(last_topo_id, last_stream_id) if from_token: - next_token = from_token.copy_and_replace("room_key", next_key) + next_token = from_token.copy_and_replace( + StreamKeyType.ROOM, next_key + ) else: next_token = StreamToken( room_key=next_key, @@ -765,6 +767,59 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) + @cached(iterable=True) + async def get_mutual_event_relations_for_rel_type( + self, event_id: str, relation_type: str + ) -> Set[Tuple[str, str]]: + raise NotImplementedError() + + @cachedList( + cached_method_name="get_mutual_event_relations_for_rel_type", + list_name="relation_types", + ) + async def get_mutual_event_relations( + self, event_id: str, relation_types: Collection[str] + ) -> Dict[str, Set[Tuple[str, str]]]: + """ + Fetch event metadata for events which related to the same event as the given event. + + If the given event has no relation information, returns an empty dictionary. + + Args: + event_id: The event ID which is targeted by relations. + relation_types: The relation types to check for mutual relations. + + Returns: + A dictionary of relation type to: + A set of tuples of: + The sender + The event type + """ + rel_type_sql, rel_type_args = make_in_list_sql_clause( + self.database_engine, "relation_type", relation_types + ) + + sql = f""" + SELECT DISTINCT relation_type, sender, type FROM event_relations + INNER JOIN events USING (event_id) + WHERE relates_to_id = ? AND {rel_type_sql} + """ + + def _get_event_relations( + txn: LoggingTransaction, + ) -> Dict[str, Set[Tuple[str, str]]]: + txn.execute(sql, [event_id] + rel_type_args) + result: Dict[str, Set[Tuple[str, str]]] = { + rel_type: set() for rel_type in relation_types + } + for rel_type, sender, type in txn.fetchall(): + result[rel_type].add((sender, type)) + return result + + return await self.db_pool.runInteraction( + "get_event_relations", _get_event_relations + ) + class RelationsStore(RelationsWorkerStore): pass diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 87e9482c..68d4fc2e 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -23,6 +23,7 @@ from typing import ( Collection, Dict, List, + Mapping, Optional, Tuple, Union, @@ -45,7 +46,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.types import Cursor from synapse.storage.util.id_generators import IdGenerator -from synapse.types import JsonDict, ThirdPartyInstanceID +from synapse.types import JsonDict, RetentionPolicy, ThirdPartyInstanceID from synapse.util import json_encoder from synapse.util.caches.descriptors import cached from synapse.util.stringutils import MXC_REGEX @@ -233,24 +234,23 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): UNION SELECT room_id from appservice_room_list """ - sql = """ + sql = f""" SELECT COUNT(*) FROM ( - %(published_sql)s + {published_sql} ) published INNER JOIN room_stats_state USING (room_id) INNER JOIN room_stats_current USING (room_id) WHERE ( - join_rules = 'public' OR join_rules = '%(knock_join_rule)s' + join_rules = '{JoinRules.PUBLIC}' + OR join_rules = '{JoinRules.KNOCK}' + OR join_rules = '{JoinRules.KNOCK_RESTRICTED}' OR history_visibility = 'world_readable' ) AND joined_members > 0 - """ % { - "published_sql": published_sql, - "knock_join_rule": JoinRules.KNOCK, - } + """ txn.execute(sql, query_args) return cast(Tuple[int], txn.fetchone())[0] @@ -369,29 +369,29 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): if where_clauses: where_clause = " AND " + " AND ".join(where_clauses) - sql = """ + dir = "DESC" if forwards else "ASC" + sql = f""" SELECT room_id, name, topic, canonical_alias, joined_members, avatar, history_visibility, guest_access, join_rules FROM ( - %(published_sql)s + {published_sql} ) published INNER JOIN room_stats_state USING (room_id) INNER JOIN room_stats_current USING (room_id) WHERE ( - join_rules = 'public' OR join_rules = '%(knock_join_rule)s' + join_rules = '{JoinRules.PUBLIC}' + OR join_rules = '{JoinRules.KNOCK}' + OR join_rules = '{JoinRules.KNOCK_RESTRICTED}' OR history_visibility = 'world_readable' ) AND joined_members > 0 - %(where_clause)s - ORDER BY joined_members %(dir)s, room_id %(dir)s - """ % { - "published_sql": published_sql, - "where_clause": where_clause, - "dir": "DESC" if forwards else "ASC", - "knock_join_rule": JoinRules.KNOCK, - } + {where_clause} + ORDER BY + joined_members {dir}, + room_id {dir} + """ if limit is not None: query_args.append(limit) @@ -699,7 +699,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn) @cached() - async def get_retention_policy_for_room(self, room_id: str) -> Dict[str, int]: + async def get_retention_policy_for_room(self, room_id: str) -> RetentionPolicy: """Get the retention policy for a given room. If no retention policy has been found for this room, returns a policy defined @@ -707,12 +707,20 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): the 'max_lifetime' if no default policy has been defined in the server's configuration). + If support for retention policies is disabled, a policy with a 'min_lifetime' and + 'max_lifetime' of None is returned. + Args: room_id: The ID of the room to get the retention policy of. Returns: A dict containing "min_lifetime" and "max_lifetime" for this room. """ + # If the room retention feature is disabled, return a policy with no minimum nor + # maximum. This prevents incorrectly filtering out events when sending to + # the client. + if not self.config.retention.retention_enabled: + return RetentionPolicy() def get_retention_policy_for_room_txn( txn: LoggingTransaction, @@ -736,10 +744,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): # If we don't know this room ID, ret will be None, in this case return the default # policy. if not ret: - return { - "min_lifetime": self.config.retention.retention_default_min_lifetime, - "max_lifetime": self.config.retention.retention_default_max_lifetime, - } + return RetentionPolicy( + min_lifetime=self.config.retention.retention_default_min_lifetime, + max_lifetime=self.config.retention.retention_default_max_lifetime, + ) min_lifetime = ret[0]["min_lifetime"] max_lifetime = ret[0]["max_lifetime"] @@ -754,10 +762,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): if max_lifetime is None: max_lifetime = self.config.retention.retention_default_max_lifetime - return { - "min_lifetime": min_lifetime, - "max_lifetime": max_lifetime, - } + return RetentionPolicy( + min_lifetime=min_lifetime, + max_lifetime=max_lifetime, + ) async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]: """Retrieves all the local and remote media MXC URIs in a given room @@ -994,7 +1002,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): async def get_rooms_for_retention_period_in_range( self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False - ) -> Dict[str, Dict[str, Optional[int]]]: + ) -> Dict[str, RetentionPolicy]: """Retrieves all of the rooms within the given retention range. Optionally includes the rooms which don't have a retention policy. @@ -1016,7 +1024,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): def get_rooms_for_retention_period_in_range_txn( txn: LoggingTransaction, - ) -> Dict[str, Dict[str, Optional[int]]]: + ) -> Dict[str, RetentionPolicy]: range_conditions = [] args = [] @@ -1047,10 +1055,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): rooms_dict = {} for row in rows: - rooms_dict[row["room_id"]] = { - "min_lifetime": row["min_lifetime"], - "max_lifetime": row["max_lifetime"], - } + rooms_dict[row["room_id"]] = RetentionPolicy( + min_lifetime=row["min_lifetime"], + max_lifetime=row["max_lifetime"], + ) if include_null: # If required, do a second query that retrieves all of the rooms we know @@ -1065,10 +1073,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): # policy in its state), add it with a null policy. for row in rows: if row["room_id"] not in rooms_dict: - rooms_dict[row["room_id"]] = { - "min_lifetime": None, - "max_lifetime": None, - } + rooms_dict[row["room_id"]] = RetentionPolicy() return rooms_dict @@ -1077,6 +1082,32 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): get_rooms_for_retention_period_in_range_txn, ) + async def get_partial_state_rooms_and_servers( + self, + ) -> Mapping[str, Collection[str]]: + """Get all rooms containing events with partial state, and the servers known + to be in the room. + + Returns: + A dictionary of rooms with partial state, with room IDs as keys and + lists of servers in rooms as values. + """ + room_servers: Dict[str, List[str]] = {} + + rows = await self.db_pool.simple_select_list( + "partial_state_rooms_servers", + keyvalues=None, + retcols=("room_id", "server_name"), + desc="get_partial_state_rooms", + ) + + for row in rows: + room_id = row["room_id"] + server_name = row["server_name"] + room_servers.setdefault(room_id, []).append(server_name) + + return room_servers + async def clear_partial_state_room(self, room_id: str) -> bool: # this can race with incoming events, so we watch out for FK errors. # TODO(faster_joins): this still doesn't completely fix the race, since the persist process @@ -1108,6 +1139,24 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): keyvalues={"room_id": room_id}, ) + async def is_partial_state_room(self, room_id: str) -> bool: + """Checks if this room has partial state. + + Returns true if this is a "partial-state" room, which means that the state + at events in the room, and `current_state_events`, may not yet be + complete. + """ + + entry = await self.db_pool.simple_select_one_onecol( + table="partial_state_rooms", + keyvalues={"room_id": room_id}, + retcol="room_id", + allow_none=True, + desc="is_partial_state_room", + ) + + return entry is not None + class _BackgroundUpdates: REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 48e83592..31bc8c56 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -15,6 +15,7 @@ import logging from typing import ( TYPE_CHECKING, + Callable, Collection, Dict, FrozenSet, @@ -37,7 +38,12 @@ from synapse.metrics.background_process_metrics import ( wrap_as_background_process, ) from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import Sqlite3Engine from synapse.storage.roommember import ( @@ -46,7 +52,7 @@ from synapse.storage.roommember import ( ProfileInfo, RoomsForUser, ) -from synapse.types import PersistedEventPosition, get_domain_from_id +from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList @@ -115,7 +121,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) @wrap_as_background_process("_count_known_servers") - async def _count_known_servers(self): + async def _count_known_servers(self) -> int: """ Count the servers that this server knows about. @@ -123,7 +129,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): `synapse_federation_known_servers` LaterGauge to collect. """ - def _transact(txn): + def _transact(txn: LoggingTransaction) -> int: if isinstance(self.database_engine, Sqlite3Engine): query = """ SELECT COUNT(DISTINCT substr(out.user_id, pos+1)) @@ -150,7 +156,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): self._known_servers_count = max([count, 1]) return self._known_servers_count - def _check_safe_current_state_events_membership_updated_txn(self, txn): + def _check_safe_current_state_events_membership_updated_txn( + self, txn: LoggingTransaction + ) -> None: """Checks if it is safe to assume the new current_state_events membership column is up to date """ @@ -182,7 +190,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): "get_users_in_room", self.get_users_in_room_txn, room_id ) - def get_users_in_room_txn(self, txn, room_id: str) -> List[str]: + def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]: # If we can assume current_state_events.membership is up to date # then we can avoid a join, which is a Very Good Thing given how # frequently this function gets called. @@ -222,7 +230,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): A mapping from user ID to ProfileInfo. """ - def _get_users_in_room_with_profiles(txn) -> Dict[str, ProfileInfo]: + def _get_users_in_room_with_profiles( + txn: LoggingTransaction, + ) -> Dict[str, ProfileInfo]: sql = """ SELECT state_key, display_name, avatar_url FROM room_memberships as m INNER JOIN current_state_events as c @@ -250,7 +260,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): dict of membership states, pointing to a MemberSummary named tuple. """ - def _get_room_summary_txn(txn): + def _get_room_summary_txn( + txn: LoggingTransaction, + ) -> Dict[str, MemberSummary]: # first get counts. # We do this all in one transaction to keep the cache small. # FIXME: get rid of this when we have room_stats @@ -279,7 +291,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """ txn.execute(sql, (room_id,)) - res = {} + res: Dict[str, MemberSummary] = {} for count, membership in txn: res.setdefault(membership, MemberSummary([], count)) @@ -400,7 +412,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): def _get_rooms_for_local_user_where_membership_is_txn( self, - txn, + txn: LoggingTransaction, user_id: str, membership_list: List[str], ) -> List[RoomsForUser]: @@ -488,7 +500,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) def _get_rooms_for_user_with_stream_ordering_txn( - self, txn, user_id: str + self, txn: LoggingTransaction, user_id: str ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]: # We use `current_state_events` here and not `local_current_membership` # as a) this gets called with remote users and b) this only gets called @@ -542,7 +554,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) def _get_rooms_for_users_with_stream_ordering_txn( - self, txn, user_ids: Collection[str] + self, txn: LoggingTransaction, user_ids: Collection[str] ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: clause, args = make_in_list_sql_clause( @@ -575,7 +587,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(sql, [Membership.JOIN] + args) - result = {user_id: set() for user_id in user_ids} + result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = { + user_id: set() for user_id in user_ids + } for user_id, room_id, instance, stream_id in txn: result[user_id].add( GetRoomsForUserWithStreamOrdering( @@ -595,7 +609,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): if not user_ids: return set() - def _get_users_server_still_shares_room_with_txn(txn): + def _get_users_server_still_shares_room_with_txn( + txn: LoggingTransaction, + ) -> Set[str]: sql = """ SELECT state_key FROM current_state_events WHERE @@ -619,7 +635,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) async def get_rooms_for_user( - self, user_id: str, on_invalidate=None + self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None ) -> FrozenSet[str]: """Returns a set of room_ids the user is currently joined to. @@ -654,10 +670,34 @@ class RoomMemberWorkerStore(EventsWorkerStore): return user_who_share_room + @cached(cache_context=True, iterable=True) + async def get_mutual_rooms_between_users( + self, user_ids: FrozenSet[str], cache_context: _CacheContext + ) -> FrozenSet[str]: + """ + Returns the set of rooms that all users in `user_ids` share. + + Args: + user_ids: A frozen set of all users to investigate and return + overlapping joined rooms for. + cache_context + """ + shared_room_ids: Optional[FrozenSet[str]] = None + for user_id in user_ids: + room_ids = await self.get_rooms_for_user( + user_id, on_invalidate=cache_context.invalidate + ) + if shared_room_ids is not None: + shared_room_ids &= room_ids + else: + shared_room_ids = room_ids + + return shared_room_ids or frozenset() + async def get_joined_users_from_context( self, event: EventBase, context: EventContext ) -> Dict[str, ProfileInfo]: - state_group = context.state_group + state_group: Union[object, int] = context.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group @@ -666,14 +706,16 @@ class RoomMemberWorkerStore(EventsWorkerStore): state_group = object() current_state_ids = await context.get_current_state_ids() + assert current_state_ids is not None + assert state_group is not None return await self._get_joined_users_from_context( event.room_id, state_group, current_state_ids, event=event, context=context ) async def get_joined_users_from_state( - self, room_id, state_entry + self, room_id: str, state_entry: "_StateCacheEntry" ) -> Dict[str, ProfileInfo]: - state_group = state_entry.state_group + state_group: Union[object, int] = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group @@ -681,6 +723,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # To do this we set the state_group to a new object as object() != object() state_group = object() + assert state_group is not None with Measure(self._clock, "get_joined_users_from_state"): return await self._get_joined_users_from_context( room_id, state_group, state_entry.state, context=state_entry @@ -689,12 +732,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000) async def _get_joined_users_from_context( self, - room_id, - state_group, - current_state_ids, - cache_context, - event=None, - context=None, + room_id: str, + state_group: Union[object, int], + current_state_ids: StateMap[str], + cache_context: _CacheContext, + event: Optional[EventBase] = None, + context: Optional[Union[EventContext, "_StateCacheEntry"]] = None, ) -> Dict[str, ProfileInfo]: # We don't use `state_group`, it's there so that we can cache based # on it. However, it's important that it's never None, since two current_states @@ -765,14 +808,18 @@ class RoomMemberWorkerStore(EventsWorkerStore): return users_in_room @cached(max_entries=10000) - def _get_joined_profile_from_event_id(self, event_id): + def _get_joined_profile_from_event_id( + self, event_id: str + ) -> Optional[Tuple[str, ProfileInfo]]: raise NotImplementedError() @cachedList( cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids", ) - async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]): + async def _get_joined_profiles_from_event_ids( + self, event_ids: Iterable[str] + ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]: """For given set of member event_ids check if they point to a join event and if so return the associated user and profile info. @@ -780,8 +827,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): event_ids: The member event IDs to lookup Returns: - dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID - to `user_id` and ProfileInfo (or None if not join event). + Map from event ID to `user_id` and ProfileInfo (or None if not join event). """ rows = await self.db_pool.simple_select_many_batch( @@ -847,8 +893,47 @@ class RoomMemberWorkerStore(EventsWorkerStore): return True - async def get_joined_hosts(self, room_id: str, state_entry): - state_group = state_entry.state_group + @cached(iterable=True, max_entries=10000) + async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: + """Get current hosts in room based on current state.""" + + # First we check if we already have `get_users_in_room` in the cache, as + # we can just calculate result from that + users = self.get_users_in_room.cache.get_immediate( + (room_id,), None, update_metrics=False + ) + if users is not None: + return {get_domain_from_id(u) for u in users} + + if isinstance(self.database_engine, Sqlite3Engine): + # If we're using SQLite then let's just always use + # `get_users_in_room` rather than funky SQL. + users = await self.get_users_in_room(room_id) + return {get_domain_from_id(u) for u in users} + + # For PostgreSQL we can use a regex to pull out the domains from the + # joined users in `current_state_events` via regex. + + def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]: + sql = """ + SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$') + FROM current_state_events + WHERE + type = 'm.room.member' + AND membership = 'join' + AND room_id = ? + """ + txn.execute(sql, (room_id,)) + return {d for d, in txn} + + return await self.db_pool.runInteraction( + "get_current_hosts_in_room", get_current_hosts_in_room_txn + ) + + async def get_joined_hosts( + self, room_id: str, state_entry: "_StateCacheEntry" + ) -> FrozenSet[str]: + state_group: Union[object, int] = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group @@ -856,6 +941,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # To do this we set the state_group to a new object as object() != object() state_group = object() + assert state_group is not None with Measure(self._clock, "get_joined_hosts"): return await self._get_joined_hosts( room_id, state_group, state_entry=state_entry @@ -863,7 +949,10 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cached(num_args=2, max_entries=10000, iterable=True) async def _get_joined_hosts( - self, room_id: str, state_group: int, state_entry: "_StateCacheEntry" + self, + room_id: str, + state_group: Union[object, int], + state_entry: "_StateCacheEntry", ) -> FrozenSet[str]: # We don't use `state_group`, it's there so that we can cache based on # it. However, its important that its never None, since two @@ -881,7 +970,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # `get_joined_hosts` is called with the "current" state group for the # room, and so consecutive calls will be for consecutive state groups # which point to the previous state group. - cache = await self._get_joined_hosts_cache(room_id) + cache = await self._get_joined_hosts_cache(room_id) # type: ignore[misc] # If the state group in the cache matches, we already have the data we need. if state_entry.state_group == cache.state_group: @@ -897,6 +986,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): elif state_entry.prev_group == cache.state_group: # The cached work is for the previous state group, so we work out # the delta. + assert state_entry.delta_ids is not None for (typ, state_key), event_id in state_entry.delta_ids.items(): if typ != EventTypes.Member: continue @@ -942,7 +1032,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): Returns False if they have since re-joined.""" - def f(txn): + def f(txn: LoggingTransaction) -> int: sql = ( "SELECT" " COUNT(*)" @@ -973,7 +1063,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): The forgotten rooms. """ - def _get_forgotten_rooms_for_user_txn(txn): + def _get_forgotten_rooms_for_user_txn(txn: LoggingTransaction) -> Set[str]: # This is a slightly convoluted query that first looks up all rooms # that the user has forgotten in the past, then rechecks that list # to see if any have subsequently been updated. This is done so that @@ -1076,7 +1166,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): clause, ) - def _is_local_host_in_room_ignoring_users_txn(txn): + def _is_local_host_in_room_ignoring_users_txn( + txn: LoggingTransaction, + ) -> bool: txn.execute(sql, (room_id, Membership.JOIN, *args)) return bool(txn.fetchone()) @@ -1110,15 +1202,17 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): where_clause="forgotten = 1", ) - async def _background_add_membership_profile(self, progress, batch_size): + async def _background_add_membership_profile( + self, progress: JsonDict, batch_size: int + ) -> int: target_min_stream_id = progress.get( - "target_min_stream_id_inclusive", self._min_stream_order_on_start + "target_min_stream_id_inclusive", self._min_stream_order_on_start # type: ignore[attr-defined] ) max_stream_id = progress.get( - "max_stream_id_exclusive", self._stream_order_on_start + 1 + "max_stream_id_exclusive", self._stream_order_on_start + 1 # type: ignore[attr-defined] ) - def add_membership_profile_txn(txn): + def add_membership_profile_txn(txn: LoggingTransaction) -> int: sql = """ SELECT stream_ordering, event_id, events.room_id, event_json.json FROM events @@ -1182,13 +1276,17 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): return result - async def _background_current_state_membership(self, progress, batch_size): + async def _background_current_state_membership( + self, progress: JsonDict, batch_size: int + ) -> int: """Update the new membership column on current_state_events. This works by iterating over all rooms in alphebetical order. """ - def _background_current_state_membership_txn(txn, last_processed_room): + def _background_current_state_membership_txn( + txn: LoggingTransaction, last_processed_room: str + ) -> Tuple[int, bool]: processed = 0 while processed < batch_size: txn.execute( @@ -1242,7 +1340,11 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): return row_count -class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): +class RoomMemberStore( + RoomMemberWorkerStore, + RoomMemberBackgroundUpdateStore, + CacheInvalidationWorkerStore, +): def __init__( self, database: DatabasePool, @@ -1254,7 +1356,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): async def forget(self, user_id: str, room_id: str) -> None: """Indicate that user_id wishes to discard history for room_id.""" - def f(txn): + def f(txn: LoggingTransaction) -> None: sql = ( "UPDATE" " room_memberships" @@ -1288,5 +1390,5 @@ class _JoinedHostsCache: # equal to anything else). state_group: Union[object, int] = attr.Factory(object) - def __len__(self): + def __len__(self) -> int: return sum(len(v) for v in self.hosts_to_joined_users.values()) diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 3c49e7ec..78e0773b 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -14,7 +14,7 @@ import logging import re -from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set +from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple import attr @@ -27,7 +27,7 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.types import JsonDict if TYPE_CHECKING: @@ -149,7 +149,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings ) - async def _background_reindex_search(self, progress, batch_size): + async def _background_reindex_search( + self, progress: JsonDict, batch_size: int + ) -> int: # we work through the events table from highest stream id to lowest target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] @@ -157,7 +159,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): TYPES = ["m.room.name", "m.room.message", "m.room.topic"] - def reindex_search_txn(txn): + def reindex_search_txn(txn: LoggingTransaction) -> int: sql = ( "SELECT stream_ordering, event_id, room_id, type, json, " " origin_server_ts FROM events" @@ -255,12 +257,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): return result - async def _background_reindex_gin_search(self, progress, batch_size): + async def _background_reindex_gin_search( + self, progress: JsonDict, batch_size: int + ) -> int: """This handles old synapses which used GIST indexes, if any; converting them back to be GIN as per the actual schema. """ - def create_index(conn): + def create_index(conn: LoggingDatabaseConnection) -> None: conn.rollback() # we have to set autocommit, because postgres refuses to @@ -299,7 +303,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): ) return 1 - async def _background_reindex_search_order(self, progress, batch_size): + async def _background_reindex_search_order( + self, progress: JsonDict, batch_size: int + ) -> int: target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) @@ -307,7 +313,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): if not have_added_index: - def create_index(conn): + def create_index(conn: LoggingDatabaseConnection) -> None: conn.rollback() conn.set_session(autocommit=True) c = conn.cursor() @@ -336,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): pg, ) - def reindex_search_txn(txn): + def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]: sql = ( "UPDATE event_search AS es SET stream_ordering = e.stream_ordering," " origin_server_ts = e.origin_server_ts" @@ -644,7 +650,8 @@ class SearchStore(SearchBackgroundUpdateStore): else: raise Exception("Unrecognized database engine") - args.append(limit) + # mypy expects to append only a `str`, not an `int` + args.append(limit) # type: ignore[arg-type] results = await self.db_pool.execute( "search_rooms", self.db_pool.cursor_to_dict, sql, *args @@ -705,7 +712,7 @@ class SearchStore(SearchBackgroundUpdateStore): A set of strings. """ - def f(txn): + def f(txn: LoggingTransaction) -> Set[str]: highlight_words = set() for event in events: # As a hack we simply join values of all possible keys. This is @@ -759,11 +766,11 @@ class SearchStore(SearchBackgroundUpdateStore): return await self.db_pool.runInteraction("_find_highlights", f) -def _to_postgres_options(options_dict): +def _to_postgres_options(options_dict: JsonDict) -> str: return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),) -def _parse_query(database_engine, search_term): +def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str: """Takes a plain unicode string from the user and converts it into a form that can be passed to database. We use this so that we can add prefix matching, which isn't something diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 18ae8aee..bdd00273 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -16,6 +16,8 @@ import collections.abc import logging from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple +import attr + from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion @@ -26,6 +28,7 @@ from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + make_in_list_sql_clause, ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore @@ -33,6 +36,7 @@ from synapse.storage.state import StateFilter from synapse.types import JsonDict, JsonMapping, StateMap from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.iterutils import batch_iter if TYPE_CHECKING: from synapse.server import HomeServer @@ -43,6 +47,16 @@ logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EventMetadata: + """Returned by `get_metadata_for_events`""" + + room_id: str + event_type: str + state_key: Optional[str] + rejection_reason: Optional[str] + + def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion: v = KNOWN_ROOM_VERSIONS.get(room_version_id) if not v: @@ -133,6 +147,57 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return room_version + async def get_metadata_for_events( + self, event_ids: Collection[str] + ) -> Dict[str, EventMetadata]: + """Get some metadata (room_id, type, state_key) for the given events. + + This method is a faster alternative than fetching the full events from + the DB, and should be used when the full event is not needed. + + Returns metadata for rejected and redacted events. Events that have not + been persisted are omitted from the returned dict. + """ + + def get_metadata_for_events_txn( + txn: LoggingTransaction, + batch_ids: Collection[str], + ) -> Dict[str, EventMetadata]: + clause, args = make_in_list_sql_clause( + self.database_engine, "e.event_id", batch_ids + ) + + sql = f""" + SELECT e.event_id, e.room_id, e.type, se.state_key, r.reason + FROM events AS e + LEFT JOIN state_events se USING (event_id) + LEFT JOIN rejections r USING (event_id) + WHERE {clause} + """ + + txn.execute(sql, args) + return { + event_id: EventMetadata( + room_id=room_id, + event_type=event_type, + state_key=state_key, + rejection_reason=rejection_reason, + ) + for event_id, room_id, event_type, state_key, rejection_reason in txn + } + + result_map: Dict[str, EventMetadata] = {} + for batch_ids in batch_iter(event_ids, 1000): + result_map.update( + await self.db_pool.runInteraction( + "get_metadata_for_events", + get_metadata_for_events_txn, + batch_ids=batch_ids, + ) + ) + + return result_map + async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]: """Get the predecessor of an upgraded room if it exists. Otherwise return None. @@ -177,7 +242,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Raises: NotFoundError if the room is unknown """ - state_ids = await self.get_current_state_ids(room_id) + state_ids = await self.get_partial_current_state_ids(room_id) if not state_ids: raise NotFoundError(f"Current state for room {room_id} is empty") @@ -193,10 +258,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return create_event @cached(max_entries=100000, iterable=True) - async def get_current_state_ids(self, room_id: str) -> StateMap[str]: + async def get_partial_current_state_ids(self, room_id: str) -> StateMap[str]: """Get the current state event ids for a room based on the current_state_events table. + This may be the partial state if we're lazy joining the room. + Args: room_id: The room to get the state IDs of. @@ -215,17 +282,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn} return await self.db_pool.runInteraction( - "get_current_state_ids", _get_current_state_ids_txn + "get_partial_current_state_ids", _get_current_state_ids_txn ) # FIXME: how should this be cached? - async def get_filtered_current_state_ids( + async def get_partial_filtered_current_state_ids( self, room_id: str, state_filter: Optional[StateFilter] = None ) -> StateMap[str]: """Get the current state event of a given type for a room based on the current_state_events table. This may not be as up-to-date as the result of doing a fresh state resolution as per state_handler.get_current_state + This may be the partial state if we're lazy joining the room. + Args: room_id state_filter: The state filter used to fetch state @@ -241,7 +310,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): if not where_clause: # We delegate to the cached version - return await self.get_current_state_ids(room_id) + return await self.get_partial_current_state_ids(room_id) def _get_filtered_current_state_ids_txn( txn: LoggingTransaction, @@ -269,30 +338,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn ) - async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]: - """Get canonical alias for room, if any - - Args: - room_id: The room ID - - Returns: - The canonical alias, if any - """ - - state = await self.get_filtered_current_state_ids( - room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) - ) - - event_id = state.get((EventTypes.CanonicalAlias, "")) - if not event_id: - return None - - event = await self.get_event(event_id, allow_none=True) - if not event: - return None - - return event.content.get("canonical_alias") - @cached(max_entries=50000) async def _get_state_group_for_event(self, event_id: str) -> Optional[int]: return await self.db_pool.simple_select_one_onecol( diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py index 188afec3..445213e1 100644 --- a/synapse/storage/databases/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py @@ -27,7 +27,7 @@ class StateDeltasStore(SQLBaseStore): # attribute. TODO: can we get static analysis to enforce this? _curr_state_delta_stream_cache: StreamChangeCache - async def get_current_state_deltas( + async def get_partial_current_state_deltas( self, prev_stream_id: int, max_stream_id: int ) -> Tuple[int, List[Dict[str, Any]]]: """Fetch a list of room state changes since the given stream id @@ -42,6 +42,8 @@ class StateDeltasStore(SQLBaseStore): - prev_event_id (str|None): previous event_id for this state key. None if it's new state. + This may be the partial state if we're lazy joining the room. + Args: prev_stream_id: point to get changes since (exclusive) max_stream_id: the point that we know has been correctly persisted diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 0373af86..8e88784d 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -765,15 +765,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): self, room_id: str, end_token: RoomStreamToken, - ) -> Optional[EventBase]: - """Returns the last event in a room at or before a stream ordering + ) -> Optional[str]: + """Returns the ID of the last event in a room at or before a stream ordering Args: room_id end_token: The token used to stream from Returns: - The most recent event. + The ID of the most recent event, or None if there are no events in the room + before this stream ordering. """ last_row = await self.get_room_event_before_stream_ordering( @@ -781,37 +782,28 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): stream_ordering=end_token.stream, ) if last_row: - _, _, event_id = last_row - event = await self.get_event(event_id, get_prev_content=True) - return event - + return last_row[2] return None async def get_current_room_stream_token_for_room_id( - self, room_id: Optional[str] = None + self, room_id: str ) -> RoomStreamToken: - """Returns the current position of the rooms stream. - - By default, it returns a live token with the current global stream - token. Specifying a `room_id` causes it to return a historic token with - the room specific topological token. - """ + """Returns the current position of the rooms stream (historic token).""" stream_ordering = self.get_room_max_stream_ordering() - if room_id is None: - return RoomStreamToken(None, stream_ordering) - else: - topo = await self.db_pool.runInteraction( - "_get_max_topological_txn", self._get_max_topological_txn, room_id - ) - return RoomStreamToken(topo, stream_ordering) + topo = await self.db_pool.runInteraction( + "_get_max_topological_txn", self._get_max_topological_txn, room_id + ) + return RoomStreamToken(topo, stream_ordering) def get_stream_id_for_event_txn( self, txn: LoggingTransaction, event_id: str, - allow_none=False, - ) -> int: - return self.db_pool.simple_select_one_onecol_txn( + allow_none: bool = False, + ) -> Optional[int]: + # Type ignore: we pass keyvalues a Dict[str, str]; the function wants + # Dict[str, Any]. I think mypy is unhappy because Dict is invariant? + return self.db_pool.simple_select_one_onecol_txn( # type: ignore[call-overload] txn=txn, table="events", keyvalues={"event_id": event_id}, @@ -873,7 +865,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) rows = txn.fetchall() - return rows[0][0] if rows else 0 + # An aggregate function like MAX() will always return one row per group + # so we can safely rely on the lookup here. For example, when a we + # lookup a `room_id` which does not exist, `rows` will look like + # `[(None,)]` + return rows[0][0] if rows[0][0] is not None else 0 @staticmethod def _set_before_and_after( diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 028db69a..ddb25b5c 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -441,7 +441,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): (EventTypes.RoomHistoryVisibility, ""), ) - current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined] + # Getting the partial state is fine, as we're not looking at membership + # events. + current_state_ids = await self.get_partial_filtered_current_state_ids( # type: ignore[attr-defined] room_id, StateFilter.from_types(types_to_filter) ) @@ -729,49 +731,6 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): users.update(rows) return list(users) - async def get_mutual_rooms_for_users( - self, user_id: str, other_user_id: str - ) -> Set[str]: - """ - Returns the rooms that a local user shares with another local or remote user. - - Args: - user_id: The MXID of a local user - other_user_id: The MXID of the other user - - Returns: - A set of room ID's that the users share. - """ - - def _get_mutual_rooms_for_users_txn( - txn: LoggingTransaction, - ) -> List[Dict[str, str]]: - txn.execute( - """ - SELECT p1.room_id - FROM users_in_public_rooms as p1 - INNER JOIN users_in_public_rooms as p2 - ON p1.room_id = p2.room_id - AND p1.user_id = ? - AND p2.user_id = ? - UNION - SELECT room_id - FROM users_who_share_private_rooms - WHERE - user_id = ? - AND other_user_id = ? - """, - (user_id, other_user_id, user_id, other_user_id), - ) - rows = self.db_pool.cursor_to_dict(txn) - return rows - - rows = await self.db_pool.runInteraction( - "get_mutual_rooms_for_users", _get_mutual_rooms_for_users_txn - ) - - return {row["room_id"] for row in rows} - async def get_user_directory_stream_pos(self) -> Optional[int]: """ Get the stream ID of the user directory stream. diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 5de70f31..fa9eadac 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -195,6 +195,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx" + STATE_GROUP_EDGES_UNIQUE_INDEX_UPDATE_NAME = "state_group_edges_unique_idx" def __init__( self, @@ -217,6 +218,21 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): columns=["room_id"], ) + # `state_group_edges` can cause severe performance issues if duplicate + # rows are introduced, which can accidentally be done by well-meaning + # server admins when trying to restore a database dump, etc. + # See https://github.com/matrix-org/synapse/issues/11779. + # Introduce a unique index to guard against that. + self.db_pool.updates.register_background_index_update( + self.STATE_GROUP_EDGES_UNIQUE_INDEX_UPDATE_NAME, + index_name="state_group_edges_unique_idx", + table="state_group_edges", + columns=["state_group", "prev_state_group"], + unique=True, + # The old index was on (state_group) and was not unique. + replaces_index="state_group_edges_idx", + ) + async def _background_deduplicate_state( self, progress: dict, batch_size: int ) -> int: diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 7614d76a..609a2b88 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -189,7 +189,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): group: int, state_filter: StateFilter, ) -> Tuple[MutableStateMap[str], bool]: - """Checks if group is in cache. See `_get_state_for_groups` + """Checks if group is in cache. See `get_state_for_groups` Args: cache: the state group cache to use |