author     Andrej Shadura <andrewsh@debian.org>  2022-06-19 15:20:00 +0200
committer  Andrej Shadura <andrewsh@debian.org>  2022-06-19 15:21:39 +0200
commit     734a8e556ce00029d9d7ab0fed73336d24fa91f3 (patch)
tree       b277733532b1b141d534133a4715a2fe765ab533 /synapse/storage/databases
parent     7a966d08c8403bcff00ac636d977097602501a69 (diff)
parent     6dc64c92c6991f09910f3e6db368e6eeb4b1981e (diff)
Update upstream source from tag 'upstream/1.61.0'
Update to upstream version '1.61.0' with Debian dir 5b9bb60cc861cbccd0027b7db7acf826071dc6a0
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--  synapse/storage/databases/main/__init__.py | 26
-rw-r--r--  synapse/storage/databases/main/appservice.py | 47
-rw-r--r--  synapse/storage/databases/main/cache.py | 8
-rw-r--r--  synapse/storage/databases/main/devices.py | 44
-rw-r--r--  synapse/storage/databases/main/e2e_room_keys.py | 4
-rw-r--r--  synapse/storage/databases/main/event_federation.py | 196
-rw-r--r--  synapse/storage/databases/main/event_push_actions.py | 2
-rw-r--r--  synapse/storage/databases/main/events.py | 245
-rw-r--r--  synapse/storage/databases/main/events_worker.py | 110
-rw-r--r--  synapse/storage/databases/main/group_server.py | 1407
-rw-r--r--  synapse/storage/databases/main/lock.py | 19
-rw-r--r--  synapse/storage/databases/main/media_repository.py | 72
-rw-r--r--  synapse/storage/databases/main/metrics.py | 74
-rw-r--r--  synapse/storage/databases/main/monthly_active_users.py | 45
-rw-r--r--  synapse/storage/databases/main/presence.py | 75
-rw-r--r--  synapse/storage/databases/main/profile.py | 107
-rw-r--r--  synapse/storage/databases/main/purge_events.py | 23
-rw-r--r--  synapse/storage/databases/main/push_rule.py | 305
-rw-r--r--  synapse/storage/databases/main/pusher.py | 6
-rw-r--r--  synapse/storage/databases/main/receipts.py | 102
-rw-r--r--  synapse/storage/databases/main/relations.py | 59
-rw-r--r--  synapse/storage/databases/main/room.py | 125
-rw-r--r--  synapse/storage/databases/main/roommember.py | 190
-rw-r--r--  synapse/storage/databases/main/search.py | 33
-rw-r--r--  synapse/storage/databases/main/state.py | 103
-rw-r--r--  synapse/storage/databases/main/state_deltas.py | 4
-rw-r--r--  synapse/storage/databases/main/stream.py | 46
-rw-r--r--  synapse/storage/databases/main/user_directory.py | 47
-rw-r--r--  synapse/storage/databases/state/bg_updates.py | 16
-rw-r--r--  synapse/storage/databases/state/store.py | 2
30 files changed, 1296 insertions, 2246 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 5895b892..11d9d16c 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -26,11 +26,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import (
- IdGenerator,
- MultiWriterIdGenerator,
- StreamIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -155,12 +151,6 @@ class DataStore(
],
)
- self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
- self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
- self._group_updates_id_gen = StreamIdGenerator(
- db_conn, "local_group_updates", "stream_id"
- )
-
self._cache_id_gen: Optional[MultiWriterIdGenerator]
if isinstance(self.database_engine, PostgresEngine):
# We set the `writers` to an empty list here as we don't care about
@@ -203,20 +193,6 @@ class DataStore(
prefilled_cache=curr_state_delta_prefill,
)
- _group_updates_prefill, min_group_updates_id = self.db_pool.get_cache_dict(
- db_conn,
- "local_group_updates",
- entity_column="user_id",
- stream_column="stream_id",
- max_value=self._group_updates_id_gen.get_current_token(),
- limit=1000,
- )
- self._group_updates_stream_cache = StreamChangeCache(
- "_group_updates_stream_cache",
- min_group_updates_id,
- prefilled_cache=_group_updates_prefill,
- )
-
self._stream_order_on_start = self.get_room_max_stream_ordering()
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 945707b0..e284454b 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -203,19 +203,29 @@ class ApplicationServiceTransactionWorkerStore(
"""Get the application service state.
Args:
- service: The service whose state to set.
+ service: The service whose state to get.
Returns:
- An ApplicationServiceState or none.
+ An ApplicationServiceState, or None if we have yet to attempt any
+ transactions to the AS.
"""
- result = await self.db_pool.simple_select_one(
+ # if we have created transactions for this AS but not yet attempted to send
+ # them, we will have a row in the table with state=NULL (recording the stream
+ # positions we have processed up to).
+ #
+ # On the other hand, if we have yet to create any transactions for this AS at
+ # all, then there will be no row for the AS.
+ #
+ # In either case, we return None to indicate "we don't yet know the state of
+ # this AS".
+ result = await self.db_pool.simple_select_one_onecol(
"application_services_state",
{"as_id": service.id},
- ["state"],
+ retcol="state",
allow_none=True,
desc="get_appservice_state",
)
if result:
- return ApplicationServiceState(result.get("state"))
+ return ApplicationServiceState(result)
return None
async def set_appservice_state(
@@ -296,14 +306,6 @@ class ApplicationServiceTransactionWorkerStore(
"""
def _complete_appservice_txn(txn: LoggingTransaction) -> None:
- # Set current txn_id for AS to 'txn_id'
- self.db_pool.simple_upsert_txn(
- txn,
- "application_services_state",
- {"as_id": service.id},
- {"last_txn": txn_id},
- )
-
# Delete txn
self.db_pool.simple_delete_txn(
txn,
@@ -452,16 +454,15 @@ class ApplicationServiceTransactionWorkerStore(
% (stream_type,)
)
- def set_appservice_stream_type_pos_txn(txn: LoggingTransaction) -> None:
- stream_id_type = "%s_stream_id" % stream_type
- txn.execute(
- "UPDATE application_services_state SET %s = ? WHERE as_id=?"
- % stream_id_type,
- (pos, service.id),
- )
-
- await self.db_pool.runInteraction(
- "set_appservice_stream_type_pos", set_appservice_stream_type_pos_txn
+ # this may be the first time that we're recording any state for this AS, so
+ # we don't yet know if a row for it exists; hence we have to upsert here.
+ await self.db_pool.simple_upsert(
+ table="application_services_state",
+ keyvalues={"as_id": service.id},
+ values={f"{stream_type}_stream_id": pos},
+ # no need to lock when emulating upsert: as_id is a unique key
+ lock=False,
+ desc="set_appservice_stream_type_pos",
)
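For illustration only (none of this appears in the diff): the per-stream-type column name in the upsert above is built with an f-string, so it always targets a concrete column of application_services_state. "read_receipt" is an assumed example stream type, not taken from the change.

    # Hypothetical example of deriving the column that the upsert writes to.
    stream_type = "read_receipt"          # assumed example value
    column = f"{stream_type}_stream_id"
    print(column)                         # -> read_receipt_stream_id

Because as_id is a unique key, the first call for a given AS inserts the row and later calls simply update the relevant stream position, which is why the previous UPDATE-only query could be replaced.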
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index dd4e83a2..1653a6a9 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -57,6 +57,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._instance_name = hs.get_instance_name()
+ self.db_pool.updates.register_background_index_update(
+ update_name="cache_invalidation_index_by_instance",
+ index_name="cache_invalidation_stream_by_instance_instance_index",
+ table="cache_invalidation_stream_by_instance",
+ columns=("instance_name", "stream_id"),
+ psql_only=True, # The table is only on postgres DBs.
+ )
+
async def get_all_updated_caches(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 2df4dd4e..d900064c 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -28,6 +28,7 @@ from typing import (
cast,
)
+from synapse.api.constants import EduTypes
from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (
get_active_span_text_map,
@@ -419,7 +420,7 @@ class DeviceWorkerStore(SQLBaseStore):
# Add the updated cross-signing keys to the results list
for user_id, result in cross_signing_keys_by_user.items():
result["user_id"] = user_id
- results.append(("m.signing_key_update", result))
+ results.append((EduTypes.SIGNING_KEY_UPDATE, result))
# also send the unstable version
# FIXME: remove this when enough servers have upgraded
# and remove the length budgeting above.
@@ -545,7 +546,7 @@ class DeviceWorkerStore(SQLBaseStore):
else:
result["deleted"] = True
- results.append(("m.device_list_update", result))
+ results.append((EduTypes.DEVICE_LIST_UPDATE, result))
return results
@@ -1153,6 +1154,45 @@ class DeviceWorkerStore(SQLBaseStore):
_prune_txn,
)
+ async def get_local_devices_not_accessed_since(
+ self, since_ms: int
+ ) -> Dict[str, List[str]]:
+ """Retrieves local devices that haven't been accessed since a given date.
+
+ Args:
+ since_ms: the timestamp to select on, every device with a last access date
+ from before that time is returned.
+
+ Returns:
+ A dictionary with an entry for each user with at least one device matching
+ the request, which value is a list of the device ID(s) for the corresponding
+ device(s).
+ """
+
+ def get_devices_not_accessed_since_txn(
+ txn: LoggingTransaction,
+ ) -> List[Dict[str, str]]:
+ sql = """
+ SELECT user_id, device_id
+ FROM devices WHERE last_seen < ? AND hidden = FALSE
+ """
+ txn.execute(sql, (since_ms,))
+ return self.db_pool.cursor_to_dict(txn)
+
+ rows = await self.db_pool.runInteraction(
+ "get_devices_not_accessed_since",
+ get_devices_not_accessed_since_txn,
+ )
+
+ devices: Dict[str, List[str]] = {}
+ for row in rows:
+ # Remote devices are never stale from our point of view.
+ if self.hs.is_mine_id(row["user_id"]):
+ user_devices = devices.setdefault(row["user_id"], [])
+ user_devices.append(row["device_id"])
+
+ return devices
+
class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(
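As a hedged illustration of the return shape documented above (the user IDs and device IDs below are invented), the new get_local_devices_not_accessed_since helper yields a mapping from each local user with at least one matching device to that user's stale device IDs:

    # Hypothetical result for a cut-off of "not accessed in the last 30 days".
    stale_devices = {
        "@alice:example.org": ["ABCDEFGHIJ"],
        "@bob:example.org": ["KLMNOPQRST", "UVWXYZ0123"],
    }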
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index b789a588..af59be6b 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -21,7 +21,7 @@ from synapse.api.errors import StoreError
from synapse.logging.opentracing import log_kv, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
-from synapse.types import JsonDict, JsonSerializable
+from synapse.types import JsonDict, JsonSerializable, StreamKeyType
from synapse.util import json_encoder
@@ -126,7 +126,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"message": "Set room key",
"room_id": room_id,
"session_id": session_id,
- "room_key": room_key,
+ StreamKeyType.ROOM: room_key,
}
)
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 47102247..eec55b64 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -14,7 +14,17 @@
import itertools
import logging
from queue import Empty, PriorityQueue
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ cast,
+)
import attr
from prometheus_client import Counter, Gauge
@@ -33,7 +43,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
@@ -135,7 +145,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Check if we have indexed the room so we can use the chain cover
# algorithm.
- room = await self.get_room(room_id)
+ room = await self.get_room(room_id) # type: ignore[attr-defined]
if room["has_auth_chain_index"]:
try:
return await self.db_pool.runInteraction(
@@ -158,7 +168,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
def _get_auth_chain_ids_using_cover_index_txn(
- self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ event_ids: Collection[str],
+ include_given: bool,
) -> Set[str]:
"""Calculates the auth chain IDs using the chain index."""
@@ -215,9 +229,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
chains: Dict[int, int] = {}
# Add all linked chains reachable from initial set of chains.
- for batch in batch_iter(event_chains, 1000):
+ for batch2 in batch_iter(event_chains, 1000):
clause, args = make_in_list_sql_clause(
- txn.database_engine, "origin_chain_id", batch
+ txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)
@@ -297,7 +311,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
front = set(event_ids)
while front:
- new_front = set()
+ new_front: Set[str] = set()
for chunk in batch_iter(front, 100):
# Pull the auth events either from the cache or DB.
to_fetch: List[str] = [] # Event IDs to fetch from DB
@@ -316,7 +330,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Note we need to batch up the results by event ID before
# adding to the cache.
- to_cache = {}
+ to_cache: Dict[str, List[Tuple[str, int]]] = {}
for event_id, auth_event_id, auth_event_depth in txn:
to_cache.setdefault(event_id, []).append(
(auth_event_id, auth_event_depth)
@@ -349,7 +363,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Check if we have indexed the room so we can use the chain cover
# algorithm.
- room = await self.get_room(room_id)
+ room = await self.get_room(room_id) # type: ignore[attr-defined]
if room["has_auth_chain_index"]:
try:
return await self.db_pool.runInteraction(
@@ -370,7 +384,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
def _get_auth_chain_difference_using_cover_index_txn(
- self, txn: Cursor, room_id: str, state_sets: List[Set[str]]
+ self, txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]]
) -> Set[str]:
"""Calculates the auth chain difference using the chain index.
@@ -444,9 +458,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# (We need to take a copy of `seen_chains` as we want to mutate it in
# the loop)
- for batch in batch_iter(set(seen_chains), 1000):
+ for batch2 in batch_iter(set(seen_chains), 1000):
clause, args = make_in_list_sql_clause(
- txn.database_engine, "origin_chain_id", batch
+ txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)
@@ -529,7 +543,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return result
def _get_auth_chain_difference_txn(
- self, txn, state_sets: List[Set[str]]
+ self, txn: LoggingTransaction, state_sets: List[Set[str]]
) -> Set[str]:
"""Calculates the auth chain difference using a breadth first search.
@@ -602,7 +616,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# I think building a temporary list with fetchall is more efficient than
# just `search.extend(txn)`, but this is unconfirmed
- search.extend(txn.fetchall())
+ search.extend(cast(List[Tuple[int, str]], txn.fetchall()))
# sort by depth
search.sort()
@@ -645,7 +659,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# We parse the results and add the to the `found` set and the
# cache (note we need to batch up the results by event ID before
# adding to the cache).
- to_cache = {}
+ to_cache: Dict[str, List[Tuple[str, int]]] = {}
for event_id, auth_event_id, auth_event_depth in txn:
to_cache.setdefault(event_id, []).append(
(auth_event_id, auth_event_depth)
@@ -696,7 +710,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return {eid for eid, n in event_to_missing_sets.items() if n}
async def get_oldest_event_ids_with_depth_in_room(
- self, room_id
+ self, room_id: str
) -> List[Tuple[str, int]]:
"""Gets the oldest events(backwards extremities) in the room along with the
aproximate depth.
@@ -713,7 +727,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
List of (event_id, depth) tuples
"""
- def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id):
+ def get_oldest_event_ids_with_depth_in_room_txn(
+ txn: LoggingTransaction, room_id: str
+ ) -> List[Tuple[str, int]]:
# Assemble a dictionary with event_id -> depth for the oldest events
# we know of in the room. Backwards extremeties are the oldest
# events we know of in the room but we only know of them because
@@ -743,7 +759,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
txn.execute(sql, (room_id, False))
- return txn.fetchall()
+ return cast(List[Tuple[str, int]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_oldest_event_ids_with_depth_in_room",
@@ -752,7 +768,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
async def get_insertion_event_backward_extremities_in_room(
- self, room_id
+ self, room_id: str
) -> List[Tuple[str, int]]:
"""Get the insertion events we know about that we haven't backfilled yet.
@@ -768,7 +784,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
List of (event_id, depth) tuples
"""
- def get_insertion_event_backward_extremities_in_room_txn(txn, room_id):
+ def get_insertion_event_backward_extremities_in_room_txn(
+ txn: LoggingTransaction, room_id: str
+ ) -> List[Tuple[str, int]]:
sql = """
SELECT b.event_id, MAX(e.depth) FROM insertion_events as i
/* We only want insertion events that are also marked as backwards extremities */
@@ -780,7 +798,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
txn.execute(sql, (room_id,))
- return txn.fetchall()
+ return cast(List[Tuple[str, int]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_insertion_event_backward_extremities_in_room",
@@ -788,7 +806,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
room_id,
)
- async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
+ async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
"""Returns the event ID and depth for the event that has the max depth from a set of event IDs
Args:
@@ -817,7 +835,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return max_depth_event_id, current_max_depth
- async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
+ async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
"""Returns the event ID and depth for the event that has the min depth from a set of event IDs
Args:
@@ -865,7 +883,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
)
- def _get_prev_events_for_room_txn(self, txn, room_id: str):
+ def _get_prev_events_for_room_txn(
+ self, txn: LoggingTransaction, room_id: str
+ ) -> List[str]:
# we just use the 10 newest events. Older events will become
# prev_events of future events.
@@ -896,7 +916,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
sorted by extremity count.
"""
- def _get_rooms_with_many_extremities_txn(txn):
+ def _get_rooms_with_many_extremities_txn(txn: LoggingTransaction) -> List[str]:
where_clause = "1=1"
if room_id_filter:
where_clause = "room_id NOT IN (%s)" % (
@@ -937,7 +957,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"get_min_depth", self._get_min_depth_interaction, room_id
)
- def _get_min_depth_interaction(self, txn, room_id):
+ def _get_min_depth_interaction(
+ self, txn: LoggingTransaction, room_id: str
+ ) -> Optional[int]:
min_depth = self.db_pool.simple_select_one_onecol_txn(
txn,
table="room_depth",
@@ -966,22 +988,24 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
# We want to make the cache more effective, so we clamp to the last
# change before the given ordering.
- last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)
+ last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) # type: ignore[attr-defined]
# We don't always have a full stream_to_exterm_id table, e.g. after
# the upgrade that introduced it, so we make sure we never ask for a
# stream_ordering from before a restart
- last_change = max(self._stream_order_on_start, last_change)
+ last_change = max(self._stream_order_on_start, last_change) # type: ignore[attr-defined]
# provided the last_change is recent enough, we now clamp the requested
# stream_ordering to it.
- if last_change > self.stream_ordering_month_ago:
+ if last_change > self.stream_ordering_month_ago: # type: ignore[attr-defined]
stream_ordering = min(last_change, stream_ordering)
return await self._get_forward_extremeties_for_room(room_id, stream_ordering)
@cached(max_entries=5000, num_args=2)
- async def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
+ async def _get_forward_extremeties_for_room(
+ self, room_id: str, stream_ordering: int
+ ) -> List[str]:
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".
@@ -989,7 +1013,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
stream_orderings from that point.
"""
- if stream_ordering <= self.stream_ordering_month_ago:
+ if stream_ordering <= self.stream_ordering_month_ago: # type: ignore[attr-defined]
raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
sql = """
@@ -1002,7 +1026,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
WHERE room_id = ?
"""
- def get_forward_extremeties_for_room_txn(txn):
+ def get_forward_extremeties_for_room_txn(txn: LoggingTransaction) -> List[str]:
txn.execute(sql, (stream_ordering, room_id))
return [event_id for event_id, in txn]
@@ -1033,7 +1057,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
INNER JOIN batch_events AS c
ON i.next_batch_id = c.batch_id
/* Get the depth of the batch start event from the events table */
- INNER JOIN events AS e USING (event_id)
+ INNER JOIN events AS e ON c.event_id = e.event_id
/* Find an insertion event which matches the given event_id */
WHERE i.event_id = ?
LIMIT ?
@@ -1104,8 +1128,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
]
async def get_backfill_events(
- self, room_id: str, seed_event_id_list: list, limit: int
- ):
+ self, room_id: str, seed_event_id_list: List[str], limit: int
+ ) -> List[EventBase]:
"""Get a list of Events for a given topic that occurred before (and
including) the events in seed_event_id_list. Return a list of max size `limit`
@@ -1123,10 +1147,19 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
events = await self.get_events_as_list(event_ids)
return sorted(
- events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering)
+ # type-ignore: mypy doesn't like negating the Optional[int] stream_ordering.
+ # But it's never None, because these events were previously persisted to the DB.
+ events,
+ key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering), # type: ignore[operator]
)
- def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit):
+ def _get_backfill_events(
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ seed_event_id_list: List[str],
+ limit: int,
+ ) -> Set[str]:
"""
We want to make sure that we do a breadth-first, "depth" ordered search.
We also handle navigating historical branches of history connected by
@@ -1139,7 +1172,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
limit,
)
- event_id_results = set()
+ event_id_results: Set[str] = set()
# In a PriorityQueue, the lowest valued entries are retrieved first.
# We're using depth as the priority in the queue and tie-break based on
@@ -1147,7 +1180,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# highest and newest-in-time message. We add events to the queue with a
# negative depth so that we process the newest-in-time messages first
# going backwards in time. stream_ordering follows the same pattern.
- queue = PriorityQueue()
+ queue: "PriorityQueue[Tuple[int, int, str, str]]" = PriorityQueue()
for seed_event_id in seed_event_id_list:
event_lookup_result = self.db_pool.simple_select_one_txn(
@@ -1253,7 +1286,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return event_id_results
- async def get_missing_events(self, room_id, earliest_events, latest_events, limit):
+ async def get_missing_events(
+ self,
+ room_id: str,
+ earliest_events: List[str],
+ latest_events: List[str],
+ limit: int,
+ ) -> List[EventBase]:
ids = await self.db_pool.runInteraction(
"get_missing_events",
self._get_missing_events,
@@ -1264,25 +1303,29 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
return await self.get_events_as_list(ids)
- def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
+ def _get_missing_events(
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ earliest_events: List[str],
+ latest_events: List[str],
+ limit: int,
+ ) -> List[str]:
seen_events = set(earliest_events)
front = set(latest_events) - seen_events
- event_results = []
+ event_results: List[str] = []
query = (
"SELECT prev_event_id FROM event_edges "
- "WHERE room_id = ? AND event_id = ? AND is_state = ? "
+ "WHERE event_id = ? AND NOT is_state "
"LIMIT ?"
)
while front and len(event_results) < limit:
new_front = set()
for event_id in front:
- txn.execute(
- query, (room_id, event_id, False, limit - len(event_results))
- )
-
+ txn.execute(query, (event_id, limit - len(event_results)))
new_results = {t[0] for t in txn} - seen_events
new_front |= new_results
@@ -1311,7 +1354,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
@wrap_as_background_process("delete_old_forward_extrem_cache")
async def _delete_old_forward_extrem_cache(self) -> None:
- def _delete_old_forward_extrem_cache_txn(txn):
+ def _delete_old_forward_extrem_cache_txn(txn: LoggingTransaction) -> None:
# Delete entries older than a month, while making sure we don't delete
# the only entries for a room.
sql = """
@@ -1324,7 +1367,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) AND stream_ordering < ?
"""
txn.execute(
- sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
+ sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) # type: ignore[attr-defined]
)
await self.db_pool.runInteraction(
@@ -1382,7 +1425,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
if self.db_pool.engine.supports_returning:
- def _remove_received_event_from_staging_txn(txn):
+ def _remove_received_event_from_staging_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[int]:
sql = """
DELETE FROM federation_inbound_events_staging
WHERE origin = ? AND event_id = ?
@@ -1390,21 +1435,24 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
txn.execute(sql, (origin, event_id))
- return txn.fetchone()
+ row = cast(Optional[Tuple[int]], txn.fetchone())
+
+ if row is None:
+ return None
- row = await self.db_pool.runInteraction(
+ return row[0]
+
+ return await self.db_pool.runInteraction(
"remove_received_event_from_staging",
_remove_received_event_from_staging_txn,
db_autocommit=True,
)
- if row is None:
- return None
-
- return row[0]
else:
- def _remove_received_event_from_staging_txn(txn):
+ def _remove_received_event_from_staging_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[int]:
received_ts = self.db_pool.simple_select_one_onecol_txn(
txn,
table="federation_inbound_events_staging",
@@ -1437,7 +1485,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) -> Optional[Tuple[str, str]]:
"""Get the next event ID in the staging area for the given room."""
- def _get_next_staged_event_id_for_room_txn(txn):
+ def _get_next_staged_event_id_for_room_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[Tuple[str, str]]:
sql = """
SELECT origin, event_id
FROM federation_inbound_events_staging
@@ -1448,7 +1498,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
txn.execute(sql, (room_id,))
- return txn.fetchone()
+ return cast(Optional[Tuple[str, str]], txn.fetchone())
return await self.db_pool.runInteraction(
"get_next_staged_event_id_for_room", _get_next_staged_event_id_for_room_txn
@@ -1461,7 +1511,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) -> Optional[Tuple[str, EventBase]]:
"""Get the next event in the staging area for the given room."""
- def _get_next_staged_event_for_room_txn(txn):
+ def _get_next_staged_event_for_room_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[Tuple[str, str, str]]:
sql = """
SELECT event_json, internal_metadata, origin
FROM federation_inbound_events_staging
@@ -1471,7 +1523,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
txn.execute(sql, (room_id,))
- return txn.fetchone()
+ return cast(Optional[Tuple[str, str, str]], txn.fetchone())
row = await self.db_pool.runInteraction(
"get_next_staged_event_for_room", _get_next_staged_event_for_room_txn
@@ -1599,18 +1651,20 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
@wrap_as_background_process("_get_stats_for_federation_staging")
- async def _get_stats_for_federation_staging(self):
+ async def _get_stats_for_federation_staging(self) -> None:
"""Update the prometheus metrics for the inbound federation staging area."""
- def _get_stats_for_federation_staging_txn(txn):
+ def _get_stats_for_federation_staging_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[int, int]:
txn.execute("SELECT count(*) FROM federation_inbound_events_staging")
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
txn.execute(
"SELECT min(received_ts) FROM federation_inbound_events_staging"
)
- (received_ts,) = txn.fetchone()
+ (received_ts,) = cast(Tuple[Optional[int]], txn.fetchone())
# If there is nothing in the staging area default it to 0.
age = 0
@@ -1651,19 +1705,21 @@ class EventFederationStore(EventFederationWorkerStore):
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
)
- async def clean_room_for_join(self, room_id):
- return await self.db_pool.runInteraction(
+ async def clean_room_for_join(self, room_id: str) -> None:
+ await self.db_pool.runInteraction(
"clean_room_for_join", self._clean_room_for_join_txn, room_id
)
- def _clean_room_for_join_txn(self, txn, room_id):
+ def _clean_room_for_join_txn(self, txn: LoggingTransaction, room_id: str) -> None:
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
- async def _background_delete_non_state_event_auth(self, progress, batch_size):
- def delete_event_auth(txn):
+ async def _background_delete_non_state_event_auth(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ def delete_event_auth(txn: LoggingTransaction) -> bool:
target_min_stream_id = progress.get("target_min_stream_id_inclusive")
max_stream_id = progress.get("max_stream_id_exclusive")
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index b7c4c622..b0199793 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -938,7 +938,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
users can still get a list of recent highlights.
Args:
- txn: The transcation
+ txn: The transaction
room_id: Room ID to delete from
user_id: user ID to delete for
stream_ordering: The lowest stream ordering which will
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 2c86a870..17e35cf6 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -36,9 +36,8 @@ from prometheus_client import Counter
import synapse.metrics
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.room_versions import RoomVersions
-from synapse.crypto.event_signing import compute_event_reference_hash
-from synapse.events import EventBase # noqa: F401
-from synapse.events.snapshot import EventContext # noqa: F401
+from synapse.events import EventBase, relation_from_event
+from synapse.events.snapshot import EventContext
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -50,7 +49,7 @@ from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.engines.postgres import PostgresEngine
from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
-from synapse.types import StateMap, get_domain_from_id
+from synapse.types import JsonDict, StateMap, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter, sorted_topologically
from synapse.util.stringutils import non_null_str_or_none
@@ -130,7 +129,6 @@ class PersistEventsStore:
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
*,
- current_state_for_room: Dict[str, StateMap[str]],
state_delta_for_room: Dict[str, DeltaState],
new_forward_extremities: Dict[str, Set[str]],
use_negative_stream_ordering: bool = False,
@@ -141,8 +139,6 @@ class PersistEventsStore:
Args:
events_and_contexts:
- current_state_for_room: Map from room_id to the current state of
- the room based on forward extremities
state_delta_for_room: Map from room_id to the delta to apply to
room state
new_forward_extremities: Map from room_id to set of event IDs
@@ -217,9 +213,6 @@ class PersistEventsStore:
event_counter.labels(event.type, origin_type, origin_entity).inc()
- for room_id, new_state in current_state_for_room.items():
- self.store.get_current_state_ids.prefill((room_id,), new_state)
-
for room_id, latest_event_ids in new_forward_extremities.items():
self.store.get_latest_event_ids_in_room.prefill(
(room_id,), list(latest_event_ids)
@@ -237,7 +230,9 @@ class PersistEventsStore:
"""
results: List[str] = []
- def _get_events_which_are_prevs_txn(txn, batch):
+ def _get_events_which_are_prevs_txn(
+ txn: LoggingTransaction, batch: Collection[str]
+ ) -> None:
sql = """
SELECT prev_event_id, internal_metadata
FROM event_edges
@@ -287,7 +282,9 @@ class PersistEventsStore:
# and their prev events.
existing_prevs = set()
- def _get_prevs_before_rejected_txn(txn, batch):
+ def _get_prevs_before_rejected_txn(
+ txn: LoggingTransaction, batch: Collection[str]
+ ) -> None:
to_recursively_check = batch
while to_recursively_check:
@@ -517,7 +514,7 @@ class PersistEventsStore:
@classmethod
def _add_chain_cover_index(
cls,
- txn,
+ txn: LoggingTransaction,
db_pool: DatabasePool,
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
@@ -811,7 +808,7 @@ class PersistEventsStore:
@staticmethod
def _allocate_chain_ids(
- txn,
+ txn: LoggingTransaction,
db_pool: DatabasePool,
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
@@ -945,7 +942,7 @@ class PersistEventsStore:
self,
txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]],
- ):
+ ) -> None:
"""Persist the mapping from transaction IDs to event IDs (if defined)."""
to_insert = []
@@ -999,7 +996,7 @@ class PersistEventsStore:
txn: LoggingTransaction,
state_delta_by_room: Dict[str, DeltaState],
stream_id: int,
- ):
+ ) -> None:
for room_id, delta_state in state_delta_by_room.items():
to_delete = delta_state.to_delete
to_insert = delta_state.to_insert
@@ -1157,7 +1154,7 @@ class PersistEventsStore:
txn, room_id, members_changed
)
- def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str):
+ def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None:
"""Update the room version in the database based off current state
events.
@@ -1191,7 +1188,7 @@ class PersistEventsStore:
txn: LoggingTransaction,
new_forward_extremities: Dict[str, Set[str]],
max_stream_order: int,
- ):
+ ) -> None:
for room_id in new_forward_extremities.keys():
self.db_pool.simple_delete_txn(
txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
@@ -1256,9 +1253,9 @@ class PersistEventsStore:
def _update_room_depths_txn(
self,
- txn,
+ txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]],
- ):
+ ) -> None:
"""Update min_depth for each room
Args:
@@ -1387,7 +1384,7 @@ class PersistEventsStore:
# nothing to do here
return
- def event_dict(event):
+ def event_dict(event: EventBase) -> JsonDict:
d = event.get_dict()
d.pop("redacted", None)
d.pop("redacted_because", None)
@@ -1478,18 +1475,20 @@ class PersistEventsStore:
),
)
- def _store_rejected_events_txn(self, txn, events_and_contexts):
+ def _store_rejected_events_txn(
+ self,
+ txn: LoggingTransaction,
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
+ ) -> List[Tuple[EventBase, EventContext]]:
"""Add rows to the 'rejections' table for received events which were
rejected
Args:
- txn (twisted.enterprise.adbapi.Connection): db connection
- events_and_contexts (list[(EventBase, EventContext)]): events
- we are persisting
+ txn: db connection
+ events_and_contexts: events we are persisting
Returns:
- list[(EventBase, EventContext)] new list, without the rejected
- events.
+ new list, without the rejected events.
"""
# Remove the rejected events from the list now that we've added them
# to the events table and the events_json table.
@@ -1510,7 +1509,7 @@ class PersistEventsStore:
events_and_contexts: List[Tuple[EventBase, EventContext]],
all_events_and_contexts: List[Tuple[EventBase, EventContext]],
inhibit_local_membership_updates: bool = False,
- ):
+ ) -> None:
"""Update all the miscellaneous tables for new events
Args:
@@ -1601,15 +1600,14 @@ class PersistEventsStore:
inhibit_local_membership_updates=inhibit_local_membership_updates,
)
- # Insert event_reference_hashes table.
- self._store_event_reference_hashes_txn(
- txn, [event for event, _ in events_and_contexts]
- )
-
# Prefill the event cache
self._add_to_cache(txn, events_and_contexts)
- def _add_to_cache(self, txn, events_and_contexts):
+ def _add_to_cache(
+ self,
+ txn: LoggingTransaction,
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
+ ) -> None:
to_prefill = []
rows = []
@@ -1640,7 +1638,7 @@ class PersistEventsStore:
if not row["rejects"] and not row["redacts"]:
to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
- def prefill():
+ def prefill() -> None:
for cache_entry in to_prefill:
self.store._get_event_cache.set(
(cache_entry.event.event_id,), cache_entry
@@ -1670,19 +1668,24 @@ class PersistEventsStore:
)
def insert_labels_for_event_txn(
- self, txn, event_id, labels, room_id, topological_ordering
- ):
+ self,
+ txn: LoggingTransaction,
+ event_id: str,
+ labels: List[str],
+ room_id: str,
+ topological_ordering: int,
+ ) -> None:
"""Store the mapping between an event's ID and its labels, with one row per
(event_id, label) tuple.
Args:
- txn (LoggingTransaction): The transaction to execute.
- event_id (str): The event's ID.
- labels (list[str]): A list of text labels.
- room_id (str): The ID of the room the event was sent to.
- topological_ordering (int): The position of the event in the room's topology.
+ txn: The transaction to execute.
+ event_id: The event's ID.
+ labels: A list of text labels.
+ room_id: The ID of the room the event was sent to.
+ topological_ordering: The position of the event in the room's topology.
"""
- return self.db_pool.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn=txn,
table="event_labels",
keys=("event_id", "label", "room_id", "topological_ordering"),
@@ -1691,44 +1694,32 @@ class PersistEventsStore:
],
)
- def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
+ def _insert_event_expiry_txn(
+ self, txn: LoggingTransaction, event_id: str, expiry_ts: int
+ ) -> None:
"""Save the expiry timestamp associated with a given event ID.
Args:
- txn (LoggingTransaction): The database transaction to use.
- event_id (str): The event ID the expiry timestamp is associated with.
- expiry_ts (int): The timestamp at which to expire (delete) the event.
+ txn: The database transaction to use.
+ event_id: The event ID the expiry timestamp is associated with.
+ expiry_ts: The timestamp at which to expire (delete) the event.
"""
- return self.db_pool.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn=txn,
table="event_expiry",
values={"event_id": event_id, "expiry_ts": expiry_ts},
)
- def _store_event_reference_hashes_txn(self, txn, events):
- """Store a hash for a PDU
- Args:
- txn (cursor):
- events (list): list of Events.
- """
-
- vals = []
- for event in events:
- ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
- vals.append((event.event_id, ref_alg, memoryview(ref_hash_bytes)))
-
- self.db_pool.simple_insert_many_txn(
- txn,
- table="event_reference_hashes",
- keys=("event_id", "algorithm", "hash"),
- values=vals,
- )
-
def _store_room_members_txn(
- self, txn, events, *, inhibit_local_membership_updates: bool = False
- ):
+ self,
+ txn: LoggingTransaction,
+ events: List[EventBase],
+ *,
+ inhibit_local_membership_updates: bool = False,
+ ) -> None:
"""
Store a room member in the database.
+
Args:
txn: The transaction to use.
events: List of events to store.
@@ -1765,6 +1756,7 @@ class PersistEventsStore:
)
for event in events:
+ assert event.internal_metadata.stream_ordering is not None
txn.call_after(
self.store._membership_stream_cache.entity_has_changed,
event.state_key,
@@ -1813,55 +1805,54 @@ class PersistEventsStore:
txn: The current database transaction.
event: The event which might have relations.
"""
- relation = event.content.get("m.relates_to")
+ relation = relation_from_event(event)
if not relation:
- # No relations
- return
-
- # Relations must have a type and parent event ID.
- rel_type = relation.get("rel_type")
- if not isinstance(rel_type, str):
+ # No relation, nothing to do.
return
- parent_id = relation.get("event_id")
- if not isinstance(parent_id, str):
- return
-
- # Annotations have a key field.
- aggregation_key = None
- if rel_type == RelationTypes.ANNOTATION:
- aggregation_key = relation.get("key")
-
self.db_pool.simple_insert_txn(
txn,
table="event_relations",
values={
"event_id": event.event_id,
- "relates_to_id": parent_id,
- "relation_type": rel_type,
- "aggregation_key": aggregation_key,
+ "relates_to_id": relation.parent_id,
+ "relation_type": relation.rel_type,
+ "aggregation_key": relation.aggregation_key,
},
)
- txn.call_after(self.store.get_relations_for_event.invalidate, (parent_id,))
txn.call_after(
- self.store.get_aggregation_groups_for_event.invalidate, (parent_id,)
+ self.store.get_relations_for_event.invalidate, (relation.parent_id,)
+ )
+ txn.call_after(
+ self.store.get_aggregation_groups_for_event.invalidate,
+ (relation.parent_id,),
+ )
+ txn.call_after(
+ self.store.get_mutual_event_relations_for_rel_type.invalidate,
+ (relation.parent_id,),
)
- if rel_type == RelationTypes.REPLACE:
- txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
+ if relation.rel_type == RelationTypes.REPLACE:
+ txn.call_after(
+ self.store.get_applicable_edit.invalidate, (relation.parent_id,)
+ )
- if rel_type == RelationTypes.THREAD:
- txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
+ if relation.rel_type == RelationTypes.THREAD:
+ txn.call_after(
+ self.store.get_thread_summary.invalidate, (relation.parent_id,)
+ )
# It should be safe to only invalidate the cache if the user has not
# previously participated in the thread, but that's difficult (and
# potentially error-prone) so it is always invalidated.
txn.call_after(
self.store.get_thread_participated.invalidate,
- (parent_id, event.sender),
+ (relation.parent_id, event.sender),
)
- def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
+ def _handle_insertion_event(
+ self, txn: LoggingTransaction, event: EventBase
+ ) -> None:
"""Handles keeping track of insertion events and edges/connections.
Part of MSC2716.
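For context, a sketch of the event content that relation_from_event now parses in the hunk above (the parent event ID is invented; this is not from the diff). The helper exposes these fields as relation.rel_type, relation.parent_id and, for annotations, relation.aggregation_key:

    # An annotation ("reaction") relation as it appears in an event's content.
    content = {
        "m.relates_to": {
            "rel_type": "m.annotation",      # -> relation.rel_type
            "event_id": "$parent_event_id",  # -> relation.parent_id (invented ID)
            "key": "👍",                     # -> relation.aggregation_key
        }
    }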
@@ -1922,7 +1913,7 @@ class PersistEventsStore:
},
)
- def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase):
+ def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase) -> None:
"""Handles inserting the batch edges/connections between the batch event
and an insertion event. Part of MSC2716.
@@ -2017,30 +2008,39 @@ class PersistEventsStore:
self.store._invalidate_cache_and_stream(
txn, self.store.get_thread_participated, (redacted_relates_to,)
)
+ self.store._invalidate_cache_and_stream(
+ txn,
+ self.store.get_mutual_event_relations_for_rel_type,
+ (redacted_relates_to,),
+ )
self.db_pool.simple_delete_txn(
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
)
- def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase):
+ def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
if isinstance(event.content.get("topic"), str):
self.store_event_search_txn(
txn, event, "content.topic", event.content["topic"]
)
- def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase):
+ def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
if isinstance(event.content.get("name"), str):
self.store_event_search_txn(
txn, event, "content.name", event.content["name"]
)
- def _store_room_message_txn(self, txn: LoggingTransaction, event: EventBase):
+ def _store_room_message_txn(
+ self, txn: LoggingTransaction, event: EventBase
+ ) -> None:
if isinstance(event.content.get("body"), str):
self.store_event_search_txn(
txn, event, "content.body", event.content["body"]
)
- def _store_retention_policy_for_room_txn(self, txn, event):
+ def _store_retention_policy_for_room_txn(
+ self, txn: LoggingTransaction, event: EventBase
+ ) -> None:
if not event.is_state():
logger.debug("Ignoring non-state m.room.retention event")
return
@@ -2100,8 +2100,11 @@ class PersistEventsStore:
)
def _set_push_actions_for_event_and_users_txn(
- self, txn, events_and_contexts, all_events_and_contexts
- ):
+ self,
+ txn: LoggingTransaction,
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
+ all_events_and_contexts: List[Tuple[EventBase, EventContext]],
+ ) -> None:
"""Handles moving push actions from staging table to main
event_push_actions table for all events in `events_and_contexts`.
@@ -2109,12 +2112,10 @@ class PersistEventsStore:
from the push action staging area.
Args:
- events_and_contexts (list[(EventBase, EventContext)]): events
- we are persisting
- all_events_and_contexts (list[(EventBase, EventContext)]): all
- events that we were going to persist. This includes events
- we've already persisted, etc, that wouldn't appear in
- events_and_context.
+ events_and_contexts: events we are persisting
+ all_events_and_contexts: all events that we were going to persist.
+ This includes events we've already persisted, etc, that wouldn't
+ appear in events_and_context.
"""
# Only non outlier events will have push actions associated with them,
@@ -2183,7 +2184,9 @@ class PersistEventsStore:
),
)
- def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
+ def _remove_push_actions_for_event_id_txn(
+ self, txn: LoggingTransaction, room_id: str, event_id: str
+ ) -> None:
# Sad that we have to blow away the cache for the whole room here
txn.call_after(
self.store.get_unread_event_push_actions_by_room_for_user.invalidate,
@@ -2194,7 +2197,9 @@ class PersistEventsStore:
(room_id, event_id),
)
- def _store_rejections_txn(self, txn, event_id, reason):
+ def _store_rejections_txn(
+ self, txn: LoggingTransaction, event_id: str, reason: str
+ ) -> None:
self.db_pool.simple_insert_txn(
txn,
table="rejections",
@@ -2206,8 +2211,10 @@ class PersistEventsStore:
)
def _store_event_state_mappings_txn(
- self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
- ):
+ self,
+ txn: LoggingTransaction,
+ events_and_contexts: Collection[Tuple[EventBase, EventContext]],
+ ) -> None:
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
@@ -2264,7 +2271,9 @@ class PersistEventsStore:
state_group_id,
)
- def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+ def _update_min_depth_for_room_txn(
+ self, txn: LoggingTransaction, room_id: str, depth: int
+ ) -> None:
min_depth = self.store._get_min_depth_interaction(txn, room_id)
if min_depth is not None and depth >= min_depth:
@@ -2277,7 +2286,9 @@ class PersistEventsStore:
values={"min_depth": depth},
)
- def _handle_mult_prev_events(self, txn, events):
+ def _handle_mult_prev_events(
+ self, txn: LoggingTransaction, events: List[EventBase]
+ ) -> None:
"""
For the given event, update the event edges table and forward and
backward extremities tables.
@@ -2295,7 +2306,9 @@ class PersistEventsStore:
self._update_backward_extremeties(txn, events)
- def _update_backward_extremeties(self, txn, events):
+ def _update_backward_extremeties(
+ self, txn: LoggingTransaction, events: List[EventBase]
+ ) -> None:
"""Updates the event_backward_extremities tables based on the new/updated
events being persisted.
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index a4a604a4..b99b1077 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -14,6 +14,7 @@
import logging
import threading
+import weakref
from enum import Enum, auto
from typing import (
TYPE_CHECKING,
@@ -23,6 +24,7 @@ from typing import (
Dict,
Iterable,
List,
+ MutableMapping,
Optional,
Set,
Tuple,
@@ -248,6 +250,12 @@ class EventsWorkerStore(SQLBaseStore):
str, ObservableDeferred[Dict[str, EventCacheEntry]]
] = {}
+ # We keep track of the events we have currently loaded in memory so that
+ # we can reuse them even if they've been evicted from the cache. We only
+ # track events that don't need redacting in here (as then we don't need
+ # to track redaction status).
+ self._event_ref: MutableMapping[str, EventBase] = weakref.WeakValueDictionary()
+
self._event_fetch_lock = threading.Condition()
self._event_fetch_list: List[
Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
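A minimal, self-contained sketch of the WeakValueDictionary pattern described in the comment above (the Event class and event ID are invented; this is not Synapse code). Entries stay retrievable only while something else still holds a strong reference, so an event evicted from the LRU cache can still be reused while an in-flight request is holding on to it:

    import weakref

    class Event:
        def __init__(self, event_id: str) -> None:
            self.event_id = event_id

    event_ref: "weakref.WeakValueDictionary[str, Event]" = weakref.WeakValueDictionary()

    ev = Event("$abc")                     # strong reference, e.g. held by a request
    event_ref["$abc"] = ev
    assert event_ref.get("$abc") is ev     # reusable even if the LRU cache evicted it

    del ev                                 # last strong reference dropped...
    assert event_ref.get("$abc") is None   # ...entry vanishes (immediately on CPython)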
@@ -723,6 +731,8 @@ class EventsWorkerStore(SQLBaseStore):
def _invalidate_get_event_cache(self, event_id: str) -> None:
self._get_event_cache.invalidate((event_id,))
+ self._event_ref.pop(event_id, None)
+ self._current_event_fetches.pop(event_id, None)
def _get_events_from_cache(
self, events: Iterable[str], update_metrics: bool = True
@@ -738,13 +748,30 @@ class EventsWorkerStore(SQLBaseStore):
event_map = {}
for event_id in events:
+ # First check if it's in the event cache
ret = self._get_event_cache.get(
(event_id,), None, update_metrics=update_metrics
)
- if not ret:
+ if ret:
+ event_map[event_id] = ret
continue
- event_map[event_id] = ret
+ # Otherwise check if we still have the event in memory.
+ event = self._event_ref.get(event_id)
+ if event:
+ # Reconstruct an event cache entry
+
+ cache_entry = EventCacheEntry(
+ event=event,
+ # We don't cache weakrefs to redacted events, so we know
+ # this is None.
+ redacted_event=None,
+ )
+ event_map[event_id] = cache_entry
+
+ # We add the entry back into the cache as we want to keep
+ # recently queried events in the cache.
+ self._get_event_cache.set((event_id,), cache_entry)
return event_map
@@ -1124,6 +1151,10 @@ class EventsWorkerStore(SQLBaseStore):
self._get_event_cache.set((event_id,), cache_entry)
result_map[event_id] = cache_entry
+ if not redacted_event:
+ # We only cache references to unredacted events.
+ self._event_ref[event_id] = original_ev
+
return result_map
async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
@@ -1325,14 +1356,23 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
The set of events we have already seen.
"""
- res = await self._have_seen_events_dict(
- (room_id, event_id) for event_id in event_ids
- )
- return {eid for ((_rid, eid), have_event) in res.items() if have_event}
+
+ # @cachedList chomps lots of memory if you call it with a big list, so
+ # we break it down. However, each batch requires its own index scan, so we make
+ # the batches as big as possible.
+
+ results: Set[str] = set()
+ for chunk in batch_iter(event_ids, 500):
+ r = await self._have_seen_events_dict(
+ [(room_id, event_id) for event_id in chunk]
+ )
+ results.update(eid for ((_rid, eid), have_event) in r.items() if have_event)
+
+ return results
@cachedList(cached_method_name="have_seen_event", list_name="keys")
async def _have_seen_events_dict(
- self, keys: Iterable[Tuple[str, str]]
+ self, keys: Collection[Tuple[str, str]]
) -> Dict[Tuple[str, str], bool]:
"""Helper for have_seen_events
@@ -1344,11 +1384,12 @@ class EventsWorkerStore(SQLBaseStore):
cache_results = {
(rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,))
}
- results = {x: True for x in cache_results}
+ results = dict.fromkeys(cache_results, True)
+ remaining = [k for k in keys if k not in cache_results]
+ if not remaining:
+ return results
- def have_seen_events_txn(
- txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...]
- ) -> None:
+ def have_seen_events_txn(txn: LoggingTransaction) -> None:
# we deliberately do *not* query the database for room_id, to make the
# query an index-only lookup on `events_event_id_key`.
#
@@ -1356,21 +1397,17 @@ class EventsWorkerStore(SQLBaseStore):
sql = "SELECT event_id FROM events AS e WHERE "
clause, args = make_in_list_sql_clause(
- txn.database_engine, "e.event_id", [eid for (_rid, eid) in chunk]
+ txn.database_engine, "e.event_id", [eid for (_rid, eid) in remaining]
)
txn.execute(sql + clause, args)
found_events = {eid for eid, in txn}
- # ... and then we can update the results for each row in the batch
- results.update({(rid, eid): (eid in found_events) for (rid, eid) in chunk})
-
- # each batch requires its own index scan, so we make the batches as big as
- # possible.
- for chunk in batch_iter((k for k in keys if k not in cache_results), 500):
- await self.db_pool.runInteraction(
- "have_seen_events", have_seen_events_txn, chunk
+ # ... and then we can update the results for each key
+ results.update(
+ {(rid, eid): (eid in found_events) for (rid, eid) in remaining}
)
+ await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn)
return results
@cached(max_entries=100000, tree=True)
@@ -1891,6 +1928,18 @@ class EventsWorkerStore(SQLBaseStore):
LIMIT 1
"""
+ # We consider any forward extremity as the latest in the room and
+ # not a forward gap.
+ #
+ # To expand, even though there is technically a gap at the front of
+ # the room where the forward extremities are, we consider those the
+ # latest messages in the room so asking other homeservers for more
+ # is useless. The new latest messages will just be federated as
+ # usual.
+ txn.execute(forward_extremity_query, (event.room_id, event.event_id))
+ if txn.fetchone():
+ return False
+
# Check to see whether the event in question is already referenced
# by another event. If we don't see any edges, we're next to a
# forward gap.
@@ -1899,8 +1948,7 @@ class EventsWorkerStore(SQLBaseStore):
/* Check to make sure the event referencing our event in question is not rejected */
LEFT JOIN rejections ON event_edges.event_id = rejections.event_id
WHERE
- event_edges.room_id = ?
- AND event_edges.prev_event_id = ?
+ event_edges.prev_event_id = ?
/* It's not a valid edge if the event referencing our event in
* question is rejected.
*/
@@ -1908,25 +1956,11 @@ class EventsWorkerStore(SQLBaseStore):
LIMIT 1
"""
- # We consider any forward extremity as the latest in the room and
- # not a forward gap.
- #
- # To expand, even though there is technically a gap at the front of
- # the room where the forward extremities are, we consider those the
- # latest messages in the room so asking other homeservers for more
- # is useless. The new latest messages will just be federated as
- # usual.
- txn.execute(forward_extremity_query, (event.room_id, event.event_id))
- forward_extremities = txn.fetchall()
- if len(forward_extremities):
- return False
-
# If there are no forward edges to the event in question (another
# event hasn't referenced this event in their prev_events), then we
# assume there is a forward gap in the history.
- txn.execute(forward_edge_query, (event.room_id, event.event_id))
- forward_edges = txn.fetchall()
- if not len(forward_edges):
+ txn.execute(forward_edge_query, (event.event_id,))
+ if not txn.fetchone():
return True
return False
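
Taken together, the two hunks above reorder and slim the gap check. The resulting control flow is roughly the following condensation (the function name is abbreviated here; the two query strings are the ones defined earlier in this method):

    def _is_next_to_forward_gap(txn, event) -> bool:
        # A forward extremity is by definition the newest thing we hold for the
        # room, so it is never treated as sitting next to a forward gap.
        txn.execute(forward_extremity_query, (event.room_id, event.event_id))
        if txn.fetchone():
            return False
        # Otherwise look for a non-rejected event that lists this event in its
        # prev_events; the WHERE clause now filters on prev_event_id alone.
        txn.execute(forward_edge_query, (event.event_id,))
        return txn.fetchone() is None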
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 04efad9e..c15a7136 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -13,1417 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING
-from typing_extensions import TypedDict
-
-from synapse.api.errors import SynapseError
-from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import (
- DatabasePool,
- LoggingDatabaseConnection,
- LoggingTransaction,
-)
-from synapse.types import JsonDict
-from synapse.util import json_encoder
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
if TYPE_CHECKING:
from synapse.server import HomeServer
-# The category ID for the "default" category. We don't store as null in the
-# database to avoid the fun of null != null
-_DEFAULT_CATEGORY_ID = ""
-_DEFAULT_ROLE_ID = ""
-
-
-# A room in a group.
-class _RoomInGroup(TypedDict):
- room_id: str
- is_public: bool
-
-class GroupServerWorkerStore(SQLBaseStore):
+class GroupServerStore(SQLBaseStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
- database.updates.register_background_index_update(
- update_name="local_group_updates_index",
- index_name="local_group_updates_stream_id_index",
- table="local_group_updates",
- columns=("stream_id",),
- unique=True,
- )
+ # Register a legacy groups background update as a no-op.
+ database.updates.register_noop_background_update("local_group_updates_index")
super().__init__(database, db_conn, hs)
-
- async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
- return await self.db_pool.simple_select_one(
- table="groups",
- keyvalues={"group_id": group_id},
- retcols=(
- "name",
- "short_description",
- "long_description",
- "avatar_url",
- "is_public",
- "join_policy",
- ),
- allow_none=True,
- desc="get_group",
- )
-
- async def get_users_in_group(
- self, group_id: str, include_private: bool = False
- ) -> List[Dict[str, Any]]:
- # TODO: Pagination
-
- keyvalues: JsonDict = {"group_id": group_id}
- if not include_private:
- keyvalues["is_public"] = True
-
- return await self.db_pool.simple_select_list(
- table="group_users",
- keyvalues=keyvalues,
- retcols=("user_id", "is_public", "is_admin"),
- desc="get_users_in_group",
- )
-
- async def get_invited_users_in_group(self, group_id: str) -> List[str]:
- # TODO: Pagination
-
- return await self.db_pool.simple_select_onecol(
- table="group_invites",
- keyvalues={"group_id": group_id},
- retcol="user_id",
- desc="get_invited_users_in_group",
- )
-
- async def get_rooms_in_group(
- self, group_id: str, include_private: bool = False
- ) -> List[_RoomInGroup]:
- """Retrieve the rooms that belong to a given group. Does not return rooms that
- lack members.
-
- Args:
- group_id: The ID of the group to query for rooms
- include_private: Whether to return private rooms in results
-
- Returns:
- A list of dictionaries, each in the form of:
-
- {
- "room_id": "!a_room_id:example.com", # The ID of the room
- "is_public": False # Whether this is a public room or not
- }
- """
-
- # TODO: Pagination
-
- def _get_rooms_in_group_txn(txn: LoggingTransaction) -> List[_RoomInGroup]:
- sql = """
- SELECT room_id, is_public FROM group_rooms
- WHERE group_id = ?
- AND room_id IN (
- SELECT group_rooms.room_id FROM group_rooms
- LEFT JOIN room_stats_current ON
- group_rooms.room_id = room_stats_current.room_id
- AND joined_members > 0
- AND local_users_in_room > 0
- LEFT JOIN rooms ON
- group_rooms.room_id = rooms.room_id
- AND (room_version <> '') = ?
- )
- """
- args = [group_id, False]
-
- if not include_private:
- sql += " AND is_public = ?"
- args += [True]
-
- txn.execute(sql, args)
-
- return [
- {"room_id": room_id, "is_public": is_public}
- for room_id, is_public in txn
- ]
-
- return await self.db_pool.runInteraction(
- "get_rooms_in_group", _get_rooms_in_group_txn
- )
-
- async def get_rooms_for_summary_by_category(
- self,
- group_id: str,
- include_private: bool = False,
- ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
- """Get the rooms and categories that should be included in a summary request
-
- Args:
- group_id: The ID of the group to query the summary for
- include_private: Whether to return private rooms in results
-
- Returns:
- A tuple containing:
-
- * A list of dictionaries with the keys:
- * "room_id": str, the room ID
- * "is_public": bool, whether the room is public
- * "category_id": str|None, the category ID if set, else None
- * "order": int, the sort order of rooms
-
- * A dictionary with the key:
- * category_id (str): a dictionary with the keys:
- * "is_public": bool, whether the category is public
- * "profile": str, the category profile
- * "order": int, the sort order of rooms in this category
- """
-
- def _get_rooms_for_summary_txn(
- txn: LoggingTransaction,
- ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
- keyvalues: JsonDict = {"group_id": group_id}
- if not include_private:
- keyvalues["is_public"] = True
-
- sql = """
- SELECT room_id, is_public, category_id, room_order
- FROM group_summary_rooms
- WHERE group_id = ?
- AND room_id IN (
- SELECT group_rooms.room_id FROM group_rooms
- LEFT JOIN room_stats_current ON
- group_rooms.room_id = room_stats_current.room_id
- AND joined_members > 0
- AND local_users_in_room > 0
- LEFT JOIN rooms ON
- group_rooms.room_id = rooms.room_id
- AND (room_version <> '') = ?
- )
- """
-
- if not include_private:
- sql += " AND is_public = ?"
- txn.execute(sql, (group_id, False, True))
- else:
- txn.execute(sql, (group_id, False))
-
- rooms = [
- {
- "room_id": row[0],
- "is_public": row[1],
- "category_id": row[2] if row[2] != _DEFAULT_CATEGORY_ID else None,
- "order": row[3],
- }
- for row in txn
- ]
-
- sql = """
- SELECT category_id, is_public, profile, cat_order
- FROM group_summary_room_categories
- INNER JOIN group_room_categories USING (group_id, category_id)
- WHERE group_id = ?
- """
-
- if not include_private:
- sql += " AND is_public = ?"
- txn.execute(sql, (group_id, True))
- else:
- txn.execute(sql, (group_id,))
-
- categories = {
- row[0]: {
- "is_public": row[1],
- "profile": db_to_json(row[2]),
- "order": row[3],
- }
- for row in txn
- }
-
- return rooms, categories
-
- return await self.db_pool.runInteraction(
- "get_rooms_for_summary", _get_rooms_for_summary_txn
- )
-
- async def get_group_categories(self, group_id: str) -> JsonDict:
- rows = await self.db_pool.simple_select_list(
- table="group_room_categories",
- keyvalues={"group_id": group_id},
- retcols=("category_id", "is_public", "profile"),
- desc="get_group_categories",
- )
-
- return {
- row["category_id"]: {
- "is_public": row["is_public"],
- "profile": db_to_json(row["profile"]),
- }
- for row in rows
- }
-
- async def get_group_category(self, group_id: str, category_id: str) -> JsonDict:
- category = await self.db_pool.simple_select_one(
- table="group_room_categories",
- keyvalues={"group_id": group_id, "category_id": category_id},
- retcols=("is_public", "profile"),
- desc="get_group_category",
- )
-
- category["profile"] = db_to_json(category["profile"])
-
- return category
-
- async def get_group_roles(self, group_id: str) -> JsonDict:
- rows = await self.db_pool.simple_select_list(
- table="group_roles",
- keyvalues={"group_id": group_id},
- retcols=("role_id", "is_public", "profile"),
- desc="get_group_roles",
- )
-
- return {
- row["role_id"]: {
- "is_public": row["is_public"],
- "profile": db_to_json(row["profile"]),
- }
- for row in rows
- }
-
- async def get_group_role(self, group_id: str, role_id: str) -> JsonDict:
- role = await self.db_pool.simple_select_one(
- table="group_roles",
- keyvalues={"group_id": group_id, "role_id": role_id},
- retcols=("is_public", "profile"),
- desc="get_group_role",
- )
-
- role["profile"] = db_to_json(role["profile"])
-
- return role
-
- async def get_local_groups_for_room(self, room_id: str) -> List[str]:
- """Get all of the local group that contain a given room
- Args:
- room_id: The ID of a room
- Returns:
- A list of group ids containing this room
- """
- return await self.db_pool.simple_select_onecol(
- table="group_rooms",
- keyvalues={"room_id": room_id},
- retcol="group_id",
- desc="get_local_groups_for_room",
- )
-
- async def get_users_for_summary_by_role(
- self, group_id: str, include_private: bool = False
- ) -> Tuple[List[JsonDict], JsonDict]:
- """Get the users and roles that should be included in a summary request
-
- Returns:
- ([users], [roles])
- """
-
- def _get_users_for_summary_txn(
- txn: LoggingTransaction,
- ) -> Tuple[List[JsonDict], JsonDict]:
- keyvalues: JsonDict = {"group_id": group_id}
- if not include_private:
- keyvalues["is_public"] = True
-
- sql = """
- SELECT user_id, is_public, role_id, user_order
- FROM group_summary_users
- WHERE group_id = ?
- """
-
- if not include_private:
- sql += " AND is_public = ?"
- txn.execute(sql, (group_id, True))
- else:
- txn.execute(sql, (group_id,))
-
- users = [
- {
- "user_id": row[0],
- "is_public": row[1],
- "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None,
- "order": row[3],
- }
- for row in txn
- ]
-
- sql = """
- SELECT role_id, is_public, profile, role_order
- FROM group_summary_roles
- INNER JOIN group_roles USING (group_id, role_id)
- WHERE group_id = ?
- """
-
- if not include_private:
- sql += " AND is_public = ?"
- txn.execute(sql, (group_id, True))
- else:
- txn.execute(sql, (group_id,))
-
- roles = {
- row[0]: {
- "is_public": row[1],
- "profile": db_to_json(row[2]),
- "order": row[3],
- }
- for row in txn
- }
-
- return users, roles
-
- return await self.db_pool.runInteraction(
- "get_users_for_summary_by_role", _get_users_for_summary_txn
- )
-
- async def is_user_in_group(self, user_id: str, group_id: str) -> bool:
- result = await self.db_pool.simple_select_one_onecol(
- table="group_users",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcol="user_id",
- allow_none=True,
- desc="is_user_in_group",
- )
- return bool(result)
-
- async def is_user_admin_in_group(
- self, group_id: str, user_id: str
- ) -> Optional[bool]:
- return await self.db_pool.simple_select_one_onecol(
- table="group_users",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcol="is_admin",
- allow_none=True,
- desc="is_user_admin_in_group",
- )
-
- async def is_user_invited_to_local_group(
- self, group_id: str, user_id: str
- ) -> Optional[bool]:
- """Has the group server invited a user?"""
- return await self.db_pool.simple_select_one_onecol(
- table="group_invites",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcol="user_id",
- desc="is_user_invited_to_local_group",
- allow_none=True,
- )
-
- async def get_users_membership_info_in_group(
- self, group_id: str, user_id: str
- ) -> JsonDict:
- """Get a dict describing the membership of a user in a group.
-
- Example if joined:
-
- {
- "membership": "join",
- "is_public": True,
- "is_privileged": False,
- }
-
- Returns:
- An empty dict if the user is not join/invite/etc
- """
-
- def _get_users_membership_in_group_txn(txn: LoggingTransaction) -> JsonDict:
- row = self.db_pool.simple_select_one_txn(
- txn,
- table="group_users",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcols=("is_admin", "is_public"),
- allow_none=True,
- )
-
- if row:
- return {
- "membership": "join",
- "is_public": row["is_public"],
- "is_privileged": row["is_admin"],
- }
-
- row = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="group_invites",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcol="user_id",
- allow_none=True,
- )
-
- if row:
- return {"membership": "invite"}
-
- return {}
-
- return await self.db_pool.runInteraction(
- "get_users_membership_info_in_group", _get_users_membership_in_group_txn
- )
-
- async def get_publicised_groups_for_user(self, user_id: str) -> List[str]:
- """Get all groups a user is publicising"""
- return await self.db_pool.simple_select_onecol(
- table="local_group_membership",
- keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
- retcol="group_id",
- desc="get_publicised_groups_for_user",
- )
-
- async def get_attestations_need_renewals(
- self, valid_until_ms: int
- ) -> List[Dict[str, Any]]:
- """Get all attestations that need to be renewed until givent time"""
-
- def _get_attestations_need_renewals_txn(
- txn: LoggingTransaction,
- ) -> List[Dict[str, Any]]:
- sql = """
- SELECT group_id, user_id FROM group_attestations_renewals
- WHERE valid_until_ms <= ?
- """
- txn.execute(sql, (valid_until_ms,))
- return self.db_pool.cursor_to_dict(txn)
-
- return await self.db_pool.runInteraction(
- "get_attestations_need_renewals", _get_attestations_need_renewals_txn
- )
-
- async def get_remote_attestation(
- self, group_id: str, user_id: str
- ) -> Optional[JsonDict]:
- """Get the attestation that proves the remote agrees that the user is
- in the group.
- """
- row = await self.db_pool.simple_select_one(
- table="group_attestations_remote",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcols=("valid_until_ms", "attestation_json"),
- desc="get_remote_attestation",
- allow_none=True,
- )
-
- now = int(self._clock.time_msec())
- if row and now < row["valid_until_ms"]:
- return db_to_json(row["attestation_json"])
-
- return None
-
- async def get_joined_groups(self, user_id: str) -> List[str]:
- return await self.db_pool.simple_select_onecol(
- table="local_group_membership",
- keyvalues={"user_id": user_id, "membership": "join"},
- retcol="group_id",
- desc="get_joined_groups",
- )
-
- async def get_all_groups_for_user(
- self, user_id: str, now_token: int
- ) -> List[JsonDict]:
- def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
- sql = """
- SELECT group_id, type, membership, u.content
- FROM local_group_updates AS u
- INNER JOIN local_group_membership USING (group_id, user_id)
- WHERE user_id = ? AND membership != 'leave'
- AND stream_id <= ?
- """
- txn.execute(sql, (user_id, now_token))
- return [
- {
- "group_id": row[0],
- "type": row[1],
- "membership": row[2],
- "content": db_to_json(row[3]),
- }
- for row in txn
- ]
-
- return await self.db_pool.runInteraction(
- "get_all_groups_for_user", _get_all_groups_for_user_txn
- )
-
- async def get_groups_changes_for_user(
- self, user_id: str, from_token: int, to_token: int
- ) -> List[JsonDict]:
- has_changed = self._group_updates_stream_cache.has_entity_changed( # type: ignore[attr-defined]
- user_id, from_token
- )
- if not has_changed:
- return []
-
- def _get_groups_changes_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
- sql = """
- SELECT group_id, membership, type, u.content
- FROM local_group_updates AS u
- INNER JOIN local_group_membership USING (group_id, user_id)
- WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
- """
- txn.execute(sql, (user_id, from_token, to_token))
- return [
- {
- "group_id": group_id,
- "membership": membership,
- "type": gtype,
- "content": db_to_json(content_json),
- }
- for group_id, membership, gtype, content_json in txn
- ]
-
- return await self.db_pool.runInteraction(
- "get_groups_changes_for_user", _get_groups_changes_for_user_txn
- )
-
- async def get_all_groups_changes(
- self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
- """Get updates for groups replication stream.
-
- Args:
- instance_name: The writer we want to fetch updates from. Unused
- here since there is only ever one writer.
- last_id: The token to fetch updates from. Exclusive.
- current_id: The token to fetch updates up to. Inclusive.
- limit: The requested limit for the number of rows to return. The
- function may return more or fewer rows.
-
- Returns:
- A tuple consisting of: the updates, a token to use to fetch
- subsequent updates, and whether we returned fewer rows than exists
- between the requested tokens due to the limit.
-
- The token returned can be used in a subsequent call to this
- function to get further updatees.
-
- The updates are a list of 2-tuples of stream ID and the row data
- """
-
- last_id = int(last_id)
- has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) # type: ignore[attr-defined]
-
- if not has_changed:
- return [], current_id, False
-
- def _get_all_groups_changes_txn(
- txn: LoggingTransaction,
- ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
- sql = """
- SELECT stream_id, group_id, user_id, type, content
- FROM local_group_updates
- WHERE ? < stream_id AND stream_id <= ?
- LIMIT ?
- """
- txn.execute(sql, (last_id, current_id, limit))
- updates = cast(
- List[Tuple[int, tuple]],
- [
- (stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
- for stream_id, group_id, user_id, gtype, content_json in txn
- ],
- )
-
- limited = False
- upto_token = current_id
- if len(updates) >= limit:
- upto_token = updates[-1][0]
- limited = True
-
- return updates, upto_token, limited
-
- return await self.db_pool.runInteraction(
- "get_all_groups_changes", _get_all_groups_changes_txn
- )
-
-
-class GroupServerStore(GroupServerWorkerStore):
- async def set_group_join_policy(self, group_id: str, join_policy: str) -> None:
- """Set the join policy of a group.
-
- join_policy can be one of:
- * "invite"
- * "open"
- """
- await self.db_pool.simple_update_one(
- table="groups",
- keyvalues={"group_id": group_id},
- updatevalues={"join_policy": join_policy},
- desc="set_group_join_policy",
- )
-
- async def add_room_to_summary(
- self,
- group_id: str,
- room_id: str,
- category_id: Optional[str],
- order: Optional[int],
- is_public: Optional[bool],
- ) -> None:
- """Add (or update) room's entry in summary.
-
- Args:
- group_id
- room_id
- category_id: If not None then adds the category to the end of
- the summary if its not already there.
- order: If not None inserts the room at that position, e.g. an order
- of 1 will put the room first. Otherwise, the room gets added to
- the end.
- is_public
- """
- await self.db_pool.runInteraction(
- "add_room_to_summary",
- self._add_room_to_summary_txn,
- group_id,
- room_id,
- category_id,
- order,
- is_public,
- )
-
- def _add_room_to_summary_txn(
- self,
- txn: LoggingTransaction,
- group_id: str,
- room_id: str,
- category_id: Optional[str],
- order: Optional[int],
- is_public: Optional[bool],
- ) -> None:
- """Add (or update) room's entry in summary.
-
- Args:
- txn
- group_id
- room_id
- category_id: If not None then adds the category to the end of
- the summary if its not already there.
- order: If not None inserts the room at that position, e.g. an order
- of 1 will put the room first. Otherwise, the room gets added to
- the end.
- is_public
- """
- room_in_group = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="group_rooms",
- keyvalues={"group_id": group_id, "room_id": room_id},
- retcol="room_id",
- allow_none=True,
- )
- if not room_in_group:
- raise SynapseError(400, "room not in group")
-
- if category_id is None:
- category_id = _DEFAULT_CATEGORY_ID
- else:
- cat_exists = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="group_room_categories",
- keyvalues={"group_id": group_id, "category_id": category_id},
- retcol="group_id",
- allow_none=True,
- )
- if not cat_exists:
- raise SynapseError(400, "Category doesn't exist")
-
- # TODO: Check category is part of summary already
- cat_exists = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="group_summary_room_categories",
- keyvalues={"group_id": group_id, "category_id": category_id},
- retcol="group_id",
- allow_none=True,
- )
- if not cat_exists:
- # If not, add it with an order larger than all others
- txn.execute(
- """
- INSERT INTO group_summary_room_categories
- (group_id, category_id, cat_order)
- SELECT ?, ?, COALESCE(MAX(cat_order), 0) + 1
- FROM group_summary_room_categories
- WHERE group_id = ? AND category_id = ?
- """,
- (group_id, category_id, group_id, category_id),
- )
-
- existing = self.db_pool.simple_select_one_txn(
- txn,
- table="group_summary_rooms",
- keyvalues={
- "group_id": group_id,
- "room_id": room_id,
- "category_id": category_id,
- },
- retcols=("room_order", "is_public"),
- allow_none=True,
- )
-
- if order is not None:
- # Shuffle other room orders that come after the given order
- sql = """
- UPDATE group_summary_rooms SET room_order = room_order + 1
- WHERE group_id = ? AND category_id = ? AND room_order >= ?
- """
- txn.execute(sql, (group_id, category_id, order))
- elif not existing:
- sql = """
- SELECT COALESCE(MAX(room_order), 0) + 1 FROM group_summary_rooms
- WHERE group_id = ? AND category_id = ?
- """
- txn.execute(sql, (group_id, category_id))
- (order,) = cast(Tuple[int], txn.fetchone())
-
- if existing:
- to_update = {}
- if order is not None:
- to_update["room_order"] = order
- if is_public is not None:
- to_update["is_public"] = is_public
- self.db_pool.simple_update_txn(
- txn,
- table="group_summary_rooms",
- keyvalues={
- "group_id": group_id,
- "category_id": category_id,
- "room_id": room_id,
- },
- updatevalues=to_update,
- )
- else:
- if is_public is None:
- is_public = True
-
- self.db_pool.simple_insert_txn(
- txn,
- table="group_summary_rooms",
- values={
- "group_id": group_id,
- "category_id": category_id,
- "room_id": room_id,
- "room_order": order,
- "is_public": is_public,
- },
- )
-
- async def remove_room_from_summary(
- self, group_id: str, room_id: str, category_id: Optional[str]
- ) -> int:
- if category_id is None:
- category_id = _DEFAULT_CATEGORY_ID
-
- return await self.db_pool.simple_delete(
- table="group_summary_rooms",
- keyvalues={
- "group_id": group_id,
- "category_id": category_id,
- "room_id": room_id,
- },
- desc="remove_room_from_summary",
- )
-
- async def upsert_group_category(
- self,
- group_id: str,
- category_id: str,
- profile: Optional[JsonDict],
- is_public: Optional[bool],
- ) -> None:
- """Add/update room category for group"""
- insertion_values: JsonDict = {}
- update_values: JsonDict = {"category_id": category_id} # This cannot be empty
-
- if profile is None:
- insertion_values["profile"] = "{}"
- else:
- update_values["profile"] = json_encoder.encode(profile)
-
- if is_public is None:
- insertion_values["is_public"] = True
- else:
- update_values["is_public"] = is_public
-
- await self.db_pool.simple_upsert(
- table="group_room_categories",
- keyvalues={"group_id": group_id, "category_id": category_id},
- values=update_values,
- insertion_values=insertion_values,
- desc="upsert_group_category",
- )
-
- async def remove_group_category(self, group_id: str, category_id: str) -> int:
- return await self.db_pool.simple_delete(
- table="group_room_categories",
- keyvalues={"group_id": group_id, "category_id": category_id},
- desc="remove_group_category",
- )
-
- async def upsert_group_role(
- self,
- group_id: str,
- role_id: str,
- profile: Optional[JsonDict],
- is_public: Optional[bool],
- ) -> None:
- """Add/remove user role"""
- insertion_values: JsonDict = {}
- update_values: JsonDict = {"role_id": role_id} # This cannot be empty
-
- if profile is None:
- insertion_values["profile"] = "{}"
- else:
- update_values["profile"] = json_encoder.encode(profile)
-
- if is_public is None:
- insertion_values["is_public"] = True
- else:
- update_values["is_public"] = is_public
-
- await self.db_pool.simple_upsert(
- table="group_roles",
- keyvalues={"group_id": group_id, "role_id": role_id},
- values=update_values,
- insertion_values=insertion_values,
- desc="upsert_group_role",
- )
-
- async def remove_group_role(self, group_id: str, role_id: str) -> int:
- return await self.db_pool.simple_delete(
- table="group_roles",
- keyvalues={"group_id": group_id, "role_id": role_id},
- desc="remove_group_role",
- )
-
- async def add_user_to_summary(
- self,
- group_id: str,
- user_id: str,
- role_id: Optional[str],
- order: Optional[int],
- is_public: Optional[bool],
- ) -> None:
- """Add (or update) user's entry in summary.
-
- Args:
- group_id
- user_id
- role_id: If not None then adds the role to the end of the summary if
- its not already there.
- order: If not None inserts the user at that position, e.g. an order
- of 1 will put the user first. Otherwise, the user gets added to
- the end.
- is_public
- """
- await self.db_pool.runInteraction(
- "add_user_to_summary",
- self._add_user_to_summary_txn,
- group_id,
- user_id,
- role_id,
- order,
- is_public,
- )
-
- def _add_user_to_summary_txn(
- self,
- txn: LoggingTransaction,
- group_id: str,
- user_id: str,
- role_id: Optional[str],
- order: Optional[int],
- is_public: Optional[bool],
- ) -> None:
- """Add (or update) user's entry in summary.
-
- Args:
- txn
- group_id
- user_id
- role_id: If not None then adds the role to the end of the summary if
- its not already there.
- order: If not None inserts the user at that position, e.g. an order
- of 1 will put the user first. Otherwise, the user gets added to
- the end.
- is_public
- """
- user_in_group = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="group_users",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcol="user_id",
- allow_none=True,
- )
- if not user_in_group:
- raise SynapseError(400, "user not in group")
-
- if role_id is None:
- role_id = _DEFAULT_ROLE_ID
- else:
- role_exists = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="group_roles",
- keyvalues={"group_id": group_id, "role_id": role_id},
- retcol="group_id",
- allow_none=True,
- )
- if not role_exists:
- raise SynapseError(400, "Role doesn't exist")
-
- # TODO: Check role is part of the summary already
- role_exists = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="group_summary_roles",
- keyvalues={"group_id": group_id, "role_id": role_id},
- retcol="group_id",
- allow_none=True,
- )
- if not role_exists:
- # If not, add it with an order larger than all others
- txn.execute(
- """
- INSERT INTO group_summary_roles
- (group_id, role_id, role_order)
- SELECT ?, ?, COALESCE(MAX(role_order), 0) + 1
- FROM group_summary_roles
- WHERE group_id = ? AND role_id = ?
- """,
- (group_id, role_id, group_id, role_id),
- )
-
- existing = self.db_pool.simple_select_one_txn(
- txn,
- table="group_summary_users",
- keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id},
- retcols=("user_order", "is_public"),
- allow_none=True,
- )
-
- if order is not None:
- # Shuffle other users orders that come after the given order
- sql = """
- UPDATE group_summary_users SET user_order = user_order + 1
- WHERE group_id = ? AND role_id = ? AND user_order >= ?
- """
- txn.execute(sql, (group_id, role_id, order))
- elif not existing:
- sql = """
- SELECT COALESCE(MAX(user_order), 0) + 1 FROM group_summary_users
- WHERE group_id = ? AND role_id = ?
- """
- txn.execute(sql, (group_id, role_id))
- (order,) = cast(Tuple[int], txn.fetchone())
-
- if existing:
- to_update = {}
- if order is not None:
- to_update["user_order"] = order
- if is_public is not None:
- to_update["is_public"] = is_public
- self.db_pool.simple_update_txn(
- txn,
- table="group_summary_users",
- keyvalues={
- "group_id": group_id,
- "role_id": role_id,
- "user_id": user_id,
- },
- updatevalues=to_update,
- )
- else:
- if is_public is None:
- is_public = True
-
- self.db_pool.simple_insert_txn(
- txn,
- table="group_summary_users",
- values={
- "group_id": group_id,
- "role_id": role_id,
- "user_id": user_id,
- "user_order": order,
- "is_public": is_public,
- },
- )
-
- async def remove_user_from_summary(
- self, group_id: str, user_id: str, role_id: Optional[str]
- ) -> int:
- if role_id is None:
- role_id = _DEFAULT_ROLE_ID
-
- return await self.db_pool.simple_delete(
- table="group_summary_users",
- keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id},
- desc="remove_user_from_summary",
- )
-
- async def add_group_invite(self, group_id: str, user_id: str) -> None:
- """Record that the group server has invited a user"""
- await self.db_pool.simple_insert(
- table="group_invites",
- values={"group_id": group_id, "user_id": user_id},
- desc="add_group_invite",
- )
-
- async def add_user_to_group(
- self,
- group_id: str,
- user_id: str,
- is_admin: bool = False,
- is_public: bool = True,
- local_attestation: Optional[dict] = None,
- remote_attestation: Optional[dict] = None,
- ) -> None:
- """Add a user to the group server.
-
- Args:
- group_id
- user_id
- is_admin
- is_public
- local_attestation: The attestation the GS created to give to the remote
- server. Optional if the user and group are on the same server
- remote_attestation: The attestation given to GS by remote server.
- Optional if the user and group are on the same server
- """
-
- def _add_user_to_group_txn(txn: LoggingTransaction) -> None:
- self.db_pool.simple_insert_txn(
- txn,
- table="group_users",
- values={
- "group_id": group_id,
- "user_id": user_id,
- "is_admin": is_admin,
- "is_public": is_public,
- },
- )
-
- self.db_pool.simple_delete_txn(
- txn,
- table="group_invites",
- keyvalues={"group_id": group_id, "user_id": user_id},
- )
-
- if local_attestation:
- self.db_pool.simple_insert_txn(
- txn,
- table="group_attestations_renewals",
- values={
- "group_id": group_id,
- "user_id": user_id,
- "valid_until_ms": local_attestation["valid_until_ms"],
- },
- )
- if remote_attestation:
- self.db_pool.simple_insert_txn(
- txn,
- table="group_attestations_remote",
- values={
- "group_id": group_id,
- "user_id": user_id,
- "valid_until_ms": remote_attestation["valid_until_ms"],
- "attestation_json": json_encoder.encode(remote_attestation),
- },
- )
-
- await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
-
- async def remove_user_from_group(self, group_id: str, user_id: str) -> None:
- def _remove_user_from_group_txn(txn: LoggingTransaction) -> None:
- self.db_pool.simple_delete_txn(
- txn,
- table="group_users",
- keyvalues={"group_id": group_id, "user_id": user_id},
- )
- self.db_pool.simple_delete_txn(
- txn,
- table="group_invites",
- keyvalues={"group_id": group_id, "user_id": user_id},
- )
- self.db_pool.simple_delete_txn(
- txn,
- table="group_attestations_renewals",
- keyvalues={"group_id": group_id, "user_id": user_id},
- )
- self.db_pool.simple_delete_txn(
- txn,
- table="group_attestations_remote",
- keyvalues={"group_id": group_id, "user_id": user_id},
- )
- self.db_pool.simple_delete_txn(
- txn,
- table="group_summary_users",
- keyvalues={"group_id": group_id, "user_id": user_id},
- )
-
- await self.db_pool.runInteraction(
- "remove_user_from_group", _remove_user_from_group_txn
- )
-
- async def add_room_to_group(
- self, group_id: str, room_id: str, is_public: bool
- ) -> None:
- await self.db_pool.simple_insert(
- table="group_rooms",
- values={"group_id": group_id, "room_id": room_id, "is_public": is_public},
- desc="add_room_to_group",
- )
-
- async def update_room_in_group_visibility(
- self, group_id: str, room_id: str, is_public: bool
- ) -> int:
- return await self.db_pool.simple_update(
- table="group_rooms",
- keyvalues={"group_id": group_id, "room_id": room_id},
- updatevalues={"is_public": is_public},
- desc="update_room_in_group_visibility",
- )
-
- async def remove_room_from_group(self, group_id: str, room_id: str) -> None:
- def _remove_room_from_group_txn(txn: LoggingTransaction) -> None:
- self.db_pool.simple_delete_txn(
- txn,
- table="group_rooms",
- keyvalues={"group_id": group_id, "room_id": room_id},
- )
-
- self.db_pool.simple_delete_txn(
- txn,
- table="group_summary_rooms",
- keyvalues={"group_id": group_id, "room_id": room_id},
- )
-
- await self.db_pool.runInteraction(
- "remove_room_from_group", _remove_room_from_group_txn
- )
-
- async def update_group_publicity(
- self, group_id: str, user_id: str, publicise: bool
- ) -> None:
- """Update whether the user is publicising their membership of the group"""
- await self.db_pool.simple_update_one(
- table="local_group_membership",
- keyvalues={"group_id": group_id, "user_id": user_id},
- updatevalues={"is_publicised": publicise},
- desc="update_group_publicity",
- )
-
- async def register_user_group_membership(
- self,
- group_id: str,
- user_id: str,
- membership: str,
- is_admin: bool = False,
- content: Optional[JsonDict] = None,
- local_attestation: Optional[dict] = None,
- remote_attestation: Optional[dict] = None,
- is_publicised: bool = False,
- ) -> int:
- """Registers that a local user is a member of a (local or remote) group.
-
- Args:
- group_id: The group the member is being added to.
- user_id: THe user ID to add to the group.
- membership: The type of group membership.
- is_admin: Whether the user should be added as a group admin.
- content: Content of the membership, e.g. includes the inviter
- if the user has been invited.
- local_attestation: If remote group then store the fact that we
- have given out an attestation, else None.
- remote_attestation: If remote group then store the remote
- attestation from the group, else None.
- is_publicised: Whether this should be publicised.
- """
-
- content = content or {}
-
- def _register_user_group_membership_txn(
- txn: LoggingTransaction, next_id: int
- ) -> int:
- # TODO: Upsert?
- self.db_pool.simple_delete_txn(
- txn,
- table="local_group_membership",
- keyvalues={"group_id": group_id, "user_id": user_id},
- )
- self.db_pool.simple_insert_txn(
- txn,
- table="local_group_membership",
- values={
- "group_id": group_id,
- "user_id": user_id,
- "is_admin": is_admin,
- "membership": membership,
- "is_publicised": is_publicised,
- "content": json_encoder.encode(content),
- },
- )
-
- self.db_pool.simple_insert_txn(
- txn,
- table="local_group_updates",
- values={
- "stream_id": next_id,
- "group_id": group_id,
- "user_id": user_id,
- "type": "membership",
- "content": json_encoder.encode(
- {"membership": membership, "content": content}
- ),
- },
- )
- self._group_updates_stream_cache.entity_has_changed(user_id, next_id) # type: ignore[attr-defined]
-
- # TODO: Insert profile to ensure it comes down stream if its a join.
-
- if membership == "join":
- if local_attestation:
- self.db_pool.simple_insert_txn(
- txn,
- table="group_attestations_renewals",
- values={
- "group_id": group_id,
- "user_id": user_id,
- "valid_until_ms": local_attestation["valid_until_ms"],
- },
- )
- if remote_attestation:
- self.db_pool.simple_insert_txn(
- txn,
- table="group_attestations_remote",
- values={
- "group_id": group_id,
- "user_id": user_id,
- "valid_until_ms": remote_attestation["valid_until_ms"],
- "attestation_json": json_encoder.encode(remote_attestation),
- },
- )
- else:
- self.db_pool.simple_delete_txn(
- txn,
- table="group_attestations_renewals",
- keyvalues={"group_id": group_id, "user_id": user_id},
- )
- self.db_pool.simple_delete_txn(
- txn,
- table="group_attestations_remote",
- keyvalues={"group_id": group_id, "user_id": user_id},
- )
-
- return next_id
-
- async with self._group_updates_id_gen.get_next() as next_id: # type: ignore[attr-defined]
- res = await self.db_pool.runInteraction(
- "register_user_group_membership",
- _register_user_group_membership_txn,
- next_id,
- )
- return res
-
- async def create_group(
- self,
- group_id: str,
- user_id: str,
- name: str,
- avatar_url: str,
- short_description: str,
- long_description: str,
- ) -> None:
- await self.db_pool.simple_insert(
- table="groups",
- values={
- "group_id": group_id,
- "name": name,
- "avatar_url": avatar_url,
- "short_description": short_description,
- "long_description": long_description,
- "is_public": True,
- },
- desc="create_group",
- )
-
- async def update_group_profile(self, group_id: str, profile: JsonDict) -> None:
- await self.db_pool.simple_update_one(
- table="groups",
- keyvalues={"group_id": group_id},
- updatevalues=profile,
- desc="update_group_profile",
- )
-
- async def update_attestation_renewal(
- self, group_id: str, user_id: str, attestation: dict
- ) -> None:
- """Update an attestation that we have renewed"""
- await self.db_pool.simple_update_one(
- table="group_attestations_renewals",
- keyvalues={"group_id": group_id, "user_id": user_id},
- updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
- desc="update_attestation_renewal",
- )
-
- async def update_remote_attestion(
- self, group_id: str, user_id: str, attestation: dict
- ) -> None:
- """Update an attestation that a remote has renewed"""
- await self.db_pool.simple_update_one(
- table="group_attestations_remote",
- keyvalues={"group_id": group_id, "user_id": user_id},
- updatevalues={
- "valid_until_ms": attestation["valid_until_ms"],
- "attestation_json": json_encoder.encode(attestation),
- },
- desc="update_remote_attestion",
- )
-
- async def remove_attestation_renewal(self, group_id: str, user_id: str) -> int:
- """Remove an attestation that we thought we should renew, but actually
- shouldn't. Ideally this would never get called as we would never
- incorrectly try and do attestations for local users on local groups.
-
- Args:
- group_id
- user_id
- """
- return await self.db_pool.simple_delete(
- table="group_attestations_renewals",
- keyvalues={"group_id": group_id, "user_id": user_id},
- desc="remove_attestation_renewal",
- )
-
- def get_group_stream_token(self) -> int:
- return self._group_updates_id_gen.get_current_token() # type: ignore[attr-defined]
-
- async def delete_group(self, group_id: str) -> None:
- """Deletes a group fully from the database.
-
- Args:
- group_id: The group ID to delete.
- """
-
- def _delete_group_txn(txn: LoggingTransaction) -> None:
- tables = [
- "groups",
- "group_users",
- "group_invites",
- "group_rooms",
- "group_summary_rooms",
- "group_summary_room_categories",
- "group_room_categories",
- "group_summary_users",
- "group_summary_roles",
- "group_roles",
- "group_attestations_renewals",
- "group_attestations_remote",
- ]
-
- for table in tables:
- self.db_pool.simple_delete_txn(
- txn, table=table, keyvalues={"group_id": group_id}
- )
-
- await self.db_pool.runInteraction("delete_group", _delete_group_txn)
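
One non-deletion detail in the hunk above: the slimmed-down GroupServerStore replaces the old index registration with register_noop_background_update("local_group_updates_index"), so homeservers that still have that legacy update queued can complete it harmlessly rather than failing to find a handler. Assuming the helper simply registers a handler that ends the update immediately, a no-op registration is roughly equivalent to this sketch (an assumption about the helper, not code from the diff):

    async def _noop_update(progress: dict, batch_size: int) -> int:
        # Nothing left to do: the groups tables this index targeted are gone.
        await self.db_pool.updates._end_background_update("local_group_updates_index")
        return 1

    self.db_pool.updates.register_background_update_handler(
        "local_group_updates_index", _noop_update
    )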
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index bedacaf0..2d7633fb 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
from types import TracebackType
-from typing import TYPE_CHECKING, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Optional, Set, Tuple, Type
from weakref import WeakValueDictionary
from twisted.internet.interfaces import IReactorCore
@@ -84,6 +84,8 @@ class LockStore(SQLBaseStore):
self._on_shutdown,
)
+ self._acquiring_locks: Set[Tuple[str, str]] = set()
+
@wrap_as_background_process("LockStore._on_shutdown")
async def _on_shutdown(self) -> None:
"""Called when the server is shutting down"""
@@ -103,6 +105,21 @@ class LockStore(SQLBaseStore):
context manager if the lock is successfully acquired, which *must* be
used (otherwise the lock will leak).
"""
+ if (lock_name, lock_key) in self._acquiring_locks:
+ return None
+ try:
+ self._acquiring_locks.add((lock_name, lock_key))
+ return await self._try_acquire_lock(lock_name, lock_key)
+ finally:
+ self._acquiring_locks.discard((lock_name, lock_key))
+
+ async def _try_acquire_lock(
+ self, lock_name: str, lock_key: str
+ ) -> Optional["Lock"]:
+ """Try to acquire a lock for the given name/key. Will return an async
+ context manager if the lock is successfully acquired, which *must* be
+ used (otherwise the lock will leak).
+ """
# Check if this process has taken out a lock and if it's still valid.
lock = self._live_tokens.get((lock_name, lock_key))
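
The new _acquiring_locks guard means a second caller racing the same (lock_name, lock_key) on this worker gets None straight back instead of stacking a second acquisition attempt on the database. Callers already have to handle None, as in this hypothetical usage of the existing API (the lock name is an assumption):

    lock = await store.try_acquire_lock("purge_room", room_id)
    if lock is None:
        # Another worker holds the lock, or a concurrent call on this worker is
        # still mid-acquisition - back off and retry later.
        return

    async with lock:
        ...  # do the guarded work; the lock is dropped (and stops renewing) on exit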
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 40ac377c..d028be16 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -251,12 +251,36 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"get_local_media_by_user_paginate_txn", get_local_media_by_user_paginate_txn
)
- async def get_local_media_before(
+ async def get_local_media_ids(
self,
before_ts: int,
size_gt: int,
keep_profiles: bool,
+ include_quarantined_media: bool,
+ include_protected_media: bool,
) -> List[str]:
+ """
+ Retrieve a list of media IDs from the local media store.
+
+ Args:
+ before_ts: Only retrieve IDs from media that was either last accessed
+ (or if never accessed, created) before the given UNIX timestamp in ms.
+ size_gt: Only retrieve IDs from media that has a size (in bytes) greater than
+ the given integer.
+ keep_profiles: If True, exclude media IDs from the results that are used in the
+ following situations:
+ * global profile user avatar
+ * per-room profile user avatar
+ * room avatar
+ * a user's avatar in the user directory
+ include_quarantined_media: If False, exclude media IDs from the results that have
+ been marked as quarantined.
+ include_protected_media: If False, exclude media IDs from the results that have
+ been marked as protected from quarantine.
+
+ Returns:
+ A list of local media IDs.
+ """
# to find files that have never been accessed (last_access_ts IS NULL)
# compare with `created_ts`
@@ -278,10 +302,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
WHERE profiles.avatar_url = '{media_prefix}' || lmr.media_id)
AND NOT EXISTS
(SELECT 1
- FROM groups
- WHERE groups.avatar_url = '{media_prefix}' || lmr.media_id)
- AND NOT EXISTS
- (SELECT 1
FROM room_memberships
WHERE room_memberships.avatar_url = '{media_prefix}' || lmr.media_id)
AND NOT EXISTS
@@ -298,12 +318,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
sql += sql_keep
- def _get_local_media_before_txn(txn: LoggingTransaction) -> List[str]:
+ if include_quarantined_media is False:
+ # Do not include media that has been quarantined
+ sql += """
+ AND quarantined_by IS NULL
+ """
+
+ if include_protected_media is False:
+ # Do not include media that has been protected from quarantine
+ sql += """
+ AND NOT safe_from_quarantine
+ """
+
+ def _get_local_media_ids_txn(txn: LoggingTransaction) -> List[str]:
txn.execute(sql, (before_ts, before_ts, size_gt))
return [row[0] for row in txn]
return await self.db_pool.runInteraction(
- "get_local_media_before", _get_local_media_before_txn
+ "get_local_media_ids", _get_local_media_ids_txn
)
async def store_local_media(
@@ -603,15 +635,37 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_remote_media_thumbnail",
)
- async def get_remote_media_before(self, before_ts: int) -> List[Dict[str, str]]:
+ async def get_remote_media_ids(
+ self, before_ts: int, include_quarantined_media: bool
+ ) -> List[Dict[str, str]]:
+ """
+ Retrieve a list of server name / media ID entries from the remote media cache.
+
+ Args:
+ before_ts: Only retrieve IDs from media that was either last accessed
+ (or if never accessed, created) before the given UNIX timestamp in ms.
+ include_quarantined_media: If False, exclude media IDs from the results that have
+ been marked as quarantined.
+
+ Returns:
+ A list of tuples containing:
+ * The server name of homeserver where the media originates from,
+ * The ID of the media.
+ """
sql = (
"SELECT media_origin, media_id, filesystem_id"
" FROM remote_media_cache"
" WHERE last_access_ts < ?"
)
+ if include_quarantined_media is False:
+ # Only include media that has not been quarantined
+ sql += """
+ AND quarantined_by IS NULL
+ """
+
return await self.db_pool.execute(
- "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
+ "get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts
)
async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
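
With the renamed helpers, callers can now decide whether quarantined or quarantine-protected media is eligible for deletion. A hypothetical retention job might call them like this (the cutoff variables are assumptions; the keyword arguments follow the new signatures above):

    cutoff_ts = now_ms - retention_period_ms

    local_ids = await store.get_local_media_ids(
        before_ts=cutoff_ts,
        size_gt=0,
        keep_profiles=True,
        include_quarantined_media=False,  # leave quarantined media untouched
        include_protected_media=False,    # likewise media protected from quarantine
    )
    remote_entries = await store.get_remote_media_ids(
        before_ts=cutoff_ts, include_quarantined_media=False
    )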
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index 1480a0f0..14294a0b 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -14,12 +14,16 @@
import calendar
import logging
import time
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING, Dict, List, Tuple, cast
from synapse.metrics import GaugeBucketCollector
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
@@ -71,8 +75,8 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
self._last_user_visit_update = self._get_start_of_day()
@wrap_as_background_process("read_forward_extremities")
- async def _read_forward_extremities(self):
- def fetch(txn):
+ async def _read_forward_extremities(self) -> None:
+ def fetch(txn: LoggingTransaction) -> List[Tuple[int, int]]:
txn.execute(
"""
SELECT t1.c, t2.c
@@ -85,7 +89,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
) t2 ON t1.room_id = t2.room_id
"""
)
- return txn.fetchall()
+ return cast(List[Tuple[int, int]], txn.fetchall())
res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
@@ -95,7 +99,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(x[0] - 1) * x[1] for x in res if x[1]
)
- async def count_daily_e2ee_messages(self):
+ async def count_daily_e2ee_messages(self) -> int:
"""
Returns an estimate of the number of messages sent in the last day.
@@ -103,20 +107,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
call to this function, it will return None.
"""
- def _count_messages(txn):
+ def _count_messages(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(*) FROM events
WHERE type = 'm.room.encrypted'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
- async def count_daily_sent_e2ee_messages(self):
- def _count_messages(txn):
+ async def count_daily_sent_e2ee_messages(self) -> int:
+ def _count_messages(txn: LoggingTransaction) -> int:
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
like_clause = "%:" + self.hs.hostname
@@ -129,29 +133,29 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"""
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction(
"count_daily_sent_e2ee_messages", _count_messages
)
- async def count_daily_active_e2ee_rooms(self):
- def _count(txn):
+ async def count_daily_active_e2ee_rooms(self) -> int:
+ def _count(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
WHERE type = 'm.room.encrypted'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction(
"count_daily_active_e2ee_rooms", _count
)
- async def count_daily_messages(self):
+ async def count_daily_messages(self) -> int:
"""
Returns an estimate of the number of messages sent in the last day.
@@ -159,20 +163,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
call to this function, it will return None.
"""
- def _count_messages(txn):
+ def _count_messages(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(*) FROM events
WHERE type = 'm.room.message'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_messages", _count_messages)
- async def count_daily_sent_messages(self):
- def _count_messages(txn):
+ async def count_daily_sent_messages(self) -> int:
+ def _count_messages(txn: LoggingTransaction) -> int:
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
like_clause = "%:" + self.hs.hostname
@@ -185,22 +189,22 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"""
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction(
"count_daily_sent_messages", _count_messages
)
- async def count_daily_active_rooms(self):
- def _count(txn):
+ async def count_daily_active_rooms(self) -> int:
+ def _count(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
WHERE type = 'm.room.message'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
@@ -226,7 +230,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_monthly_users", self._count_users, thirty_days_ago
)
- def _count_users(self, txn, time_from):
+ def _count_users(self, txn: LoggingTransaction, time_from: int) -> int:
"""
Returns number of users seen in the past time_from period
"""
@@ -238,7 +242,10 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
) u
"""
txn.execute(sql, (time_from,))
- (count,) = txn.fetchone()
+ # Mypy knows that fetchone() might return None if there are no rows.
+ # We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
+ # returns exactly one row.
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
async def count_r30_users(self) -> Dict[str, int]:
@@ -252,7 +259,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
A mapping of counts globally as well as broken out by platform.
"""
- def _count_r30_users(txn):
+ def _count_r30_users(txn: LoggingTransaction) -> Dict[str, int]:
thirty_days_in_secs = 86400 * 30
now = int(self._clock.time())
thirty_days_ago_in_secs = now - thirty_days_in_secs
@@ -317,7 +324,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
results["all"] = count
return results
@@ -344,7 +351,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
- "web" (any web application -- it's not possible to distinguish Element Web here)
"""
- def _count_r30v2_users(txn):
+ def _count_r30v2_users(txn: LoggingTransaction) -> Dict[str, int]:
thirty_days_in_secs = 86400 * 30
now = int(self._clock.time())
sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs
@@ -441,11 +448,8 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
thirty_days_in_secs * 1000,
),
)
- row = txn.fetchone()
- if row is None:
- results["all"] = 0
- else:
- results["all"] = row[0]
+ (count,) = cast(Tuple[int], txn.fetchone())
+ results["all"] = count
return results
@@ -453,7 +457,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_r30v2_users", _count_r30v2_users
)
- def _get_start_of_day(self):
+ def _get_start_of_day(self) -> int:
"""
Returns millisecond unixtime for start of UTC day.
"""
@@ -467,7 +471,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
Generates daily visit data for use in cohort/ retention analysis
"""
- def _generate_user_daily_visits(txn):
+ def _generate_user_daily_visits(txn: LoggingTransaction) -> None:
logger.info("Calling _generate_user_daily_visits")
today_start = self._get_start_of_day()
a_day_in_milliseconds = 24 * 60 * 60 * 1000
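
The repeated cast(Tuple[int], txn.fetchone()) is the same narrowing trick each time; in isolation it looks like this minimal sketch (table and column names are placeholders):

    from typing import Tuple, cast

    def _count_since(txn, cutoff: int) -> int:
        txn.execute("SELECT COUNT(*) FROM events WHERE stream_ordering > ?", (cutoff,))
        # fetchone() is Optional[...] as far as mypy can tell, but COUNT(*) without
        # GROUP BY always yields exactly one row, so the cast is safe.
        (count,) = cast(Tuple[int], txn.fetchone())
        return count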
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 5beb8f1d..9a63f953 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -122,6 +122,51 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
"count_users_by_service", _count_users_by_service
)
+ async def get_monthly_active_users_by_service(
+ self, start_timestamp: Optional[int] = None, end_timestamp: Optional[int] = None
+ ) -> List[Tuple[str, str]]:
+ """Generates list of monthly active users and their services.
+ Please see "get_monthly_active_count_by_service" docstring for more details
+ about services.
+
+ Arguments:
+ start_timestamp: If specified, only include users that were first active
+ at or after this point
+ end_timestamp: If specified, only include users that were first active
+ at or before this point
+
+ Returns:
+ A list of tuples (appservice_id, user_id). "native" is emitted as the
+ appservice for users that don't come from appservices (i.e. native Matrix
+ users).
+
+ """
+ if start_timestamp is not None and end_timestamp is not None:
+ where_clause = 'WHERE "timestamp" >= ? and "timestamp" <= ?'
+ query_params = [start_timestamp, end_timestamp]
+ elif start_timestamp is not None:
+ where_clause = 'WHERE "timestamp" >= ?'
+ query_params = [start_timestamp]
+ elif end_timestamp is not None:
+ where_clause = 'WHERE "timestamp" <= ?'
+ query_params = [end_timestamp]
+ else:
+ where_clause = ""
+ query_params = []
+
+ def _list_users(txn: LoggingTransaction) -> List[Tuple[str, str]]:
+ sql = f"""
+ SELECT COALESCE(appservice_id, 'native'), user_id
+ FROM monthly_active_users
+ LEFT JOIN users ON monthly_active_users.user_id=users.name
+ {where_clause};
+ """
+
+ txn.execute(sql, query_params)
+ return cast(List[Tuple[str, str]], txn.fetchall())
+
+ return await self.db_pool.runInteraction("list_users", _list_users)
+
async def get_registered_reserved_users(self) -> List[str]:
"""Of the reserved threepids defined in config, retrieve those that are associated
with registered users
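
A hypothetical caller of the new query, e.g. stats reporting that only wants users first active in the last day (the timestamp arithmetic and variable names are assumptions):

    one_day_ago_ms = now_ms - 24 * 60 * 60 * 1000
    rows = await store.get_monthly_active_users_by_service(start_timestamp=one_day_ago_ms)
    for appservice_id, user_id in rows:
        ...  # appservice_id is "native" for users not registered via an appservice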
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index b47c5114..9769a18a 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, cast
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, cast
from synapse.api.presence import PresenceState, UserPresenceState
from synapse.replication.tcp.streams import PresenceStream
@@ -22,6 +22,7 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import (
@@ -56,7 +57,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore):
)
-class PresenceStore(PresenceBackgroundUpdateStore):
+class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -281,20 +282,30 @@ class PresenceStore(PresenceBackgroundUpdateStore):
True if the user should have full presence sent to them, False otherwise.
"""
- def _should_user_receive_full_presence_with_token_txn(
- txn: LoggingTransaction,
- ) -> bool:
- sql = """
- SELECT 1 FROM users_to_send_full_presence_to
- WHERE user_id = ?
- AND presence_stream_id >= ?
- """
- txn.execute(sql, (user_id, from_token))
- return bool(txn.fetchone())
+ token = await self._get_full_presence_stream_token_for_user(user_id)
+ if token is None:
+ return False
- return await self.db_pool.runInteraction(
- "should_user_receive_full_presence_with_token",
- _should_user_receive_full_presence_with_token_txn,
+ return from_token <= token
+
+ @cached()
+ async def _get_full_presence_stream_token_for_user(
+ self, user_id: str
+ ) -> Optional[int]:
+ """Get the presence token corresponding to the last full presence update
+ for this user.
+
+ If the user presents a sync token with a presence stream token at least
+ as old as the result, then we need to send them a full presence update.
+
+ If this user has never needed a full presence update, returns `None`.
+ """
+ return await self.db_pool.simple_select_one_onecol(
+ table="users_to_send_full_presence_to",
+ keyvalues={"user_id": user_id},
+ retcol="presence_stream_id",
+ allow_none=True,
+ desc="_get_full_presence_stream_token_for_user",
)
async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None:
@@ -307,18 +318,28 @@ class PresenceStore(PresenceBackgroundUpdateStore):
# Add user entries to the table, updating the presence_stream_id column if the user already
# exists in the table.
presence_stream_id = self._presence_id_gen.get_current_token()
- await self.db_pool.simple_upsert_many(
- table="users_to_send_full_presence_to",
- key_names=("user_id",),
- key_values=[(user_id,) for user_id in user_ids],
- value_names=("presence_stream_id",),
- # We save the current presence stream ID token along with the user ID entry so
- # that when a user /sync's, even if they syncing multiple times across separate
- # devices at different times, each device will receive full presence once - when
- # the presence stream ID in their sync token is less than the one in the table
- # for their user ID.
- value_values=[(presence_stream_id,) for _ in user_ids],
- desc="add_users_to_send_full_presence_to",
+
+ def _add_users_to_send_full_presence_to(txn: LoggingTransaction) -> None:
+ self.db_pool.simple_upsert_many_txn(
+ txn,
+ table="users_to_send_full_presence_to",
+ key_names=("user_id",),
+ key_values=[(user_id,) for user_id in user_ids],
+ value_names=("presence_stream_id",),
+ # We save the current presence stream ID token along with the user ID entry so
+                # that when a user /sync's, even if they sync multiple times across separate
+ # devices at different times, each device will receive full presence once - when
+ # the presence stream ID in their sync token is less than the one in the table
+ # for their user ID.
+ value_values=[(presence_stream_id,) for _ in user_ids],
+ )
+ for user_id in user_ids:
+ self._invalidate_cache_and_stream(
+ txn, self._get_full_presence_stream_token_for_user, (user_id,)
+ )
+
+ return await self.db_pool.runInteraction(
+ "add_users_to_send_full_presence_to", _add_users_to_send_full_presence_to
)
async def get_presence_for_all_users(
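The rewritten read path above boils down to a single comparison against the cached per-user stream id; a minimal pure-Python restatement of that decision, with the token values invented:

    from typing import Optional

    def needs_full_presence(from_token: int, stored_token: Optional[int]) -> bool:
        # No row in users_to_send_full_presence_to -> no full update is owed.
        if stored_token is None:
            return False
        # Otherwise a full update is owed while the client's sync token is not
        # newer than the stream id recorded for the user.
        return from_token <= stored_token

    # needs_full_presence(5, 7) -> True; needs_full_presence(8, 7) -> False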
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index e197b720..a1747f04 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -11,11 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional
+from typing import Optional
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.roommember import ProfileInfo
@@ -55,17 +54,6 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_avatar_url",
)
- async def get_from_remote_profile_cache(
- self, user_id: str
- ) -> Optional[Dict[str, Any]]:
- return await self.db_pool.simple_select_one(
- table="remote_profile_cache",
- keyvalues={"user_id": user_id},
- retcols=("displayname", "avatar_url"),
- allow_none=True,
- desc="get_from_remote_profile_cache",
- )
-
async def create_profile(self, user_localpart: str) -> None:
await self.db_pool.simple_insert(
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
@@ -91,97 +79,6 @@ class ProfileWorkerStore(SQLBaseStore):
desc="set_profile_avatar_url",
)
- async def update_remote_profile_cache(
- self, user_id: str, displayname: Optional[str], avatar_url: Optional[str]
- ) -> int:
- return await self.db_pool.simple_update(
- table="remote_profile_cache",
- keyvalues={"user_id": user_id},
- updatevalues={
- "displayname": displayname,
- "avatar_url": avatar_url,
- "last_check": self._clock.time_msec(),
- },
- desc="update_remote_profile_cache",
- )
-
- async def maybe_delete_remote_profile_cache(self, user_id: str) -> None:
- """Check if we still care about the remote user's profile, and if we
- don't then remove their profile from the cache
- """
- subscribed = await self.is_subscribed_remote_profile_for_user(user_id)
- if not subscribed:
- await self.db_pool.simple_delete(
- table="remote_profile_cache",
- keyvalues={"user_id": user_id},
- desc="delete_remote_profile_cache",
- )
-
- async def is_subscribed_remote_profile_for_user(self, user_id: str) -> bool:
- """Check whether we are interested in a remote user's profile."""
- res: Optional[str] = await self.db_pool.simple_select_one_onecol(
- table="group_users",
- keyvalues={"user_id": user_id},
- retcol="user_id",
- allow_none=True,
- desc="should_update_remote_profile_cache_for_user",
- )
-
- if res:
- return True
-
- res = await self.db_pool.simple_select_one_onecol(
- table="group_invites",
- keyvalues={"user_id": user_id},
- retcol="user_id",
- allow_none=True,
- desc="should_update_remote_profile_cache_for_user",
- )
-
- if res:
- return True
- return False
-
- async def get_remote_profile_cache_entries_that_expire(
- self, last_checked: int
- ) -> List[Dict[str, str]]:
- """Get all users who haven't been checked since `last_checked`"""
-
- def _get_remote_profile_cache_entries_that_expire_txn(
- txn: LoggingTransaction,
- ) -> List[Dict[str, str]]:
- sql = """
- SELECT user_id, displayname, avatar_url
- FROM remote_profile_cache
- WHERE last_check < ?
- """
-
- txn.execute(sql, (last_checked,))
-
- return self.db_pool.cursor_to_dict(txn)
-
- return await self.db_pool.runInteraction(
- "get_remote_profile_cache_entries_that_expire",
- _get_remote_profile_cache_entries_that_expire_txn,
- )
-
class ProfileStore(ProfileWorkerStore):
- async def add_remote_profile_cache(
- self, user_id: str, displayname: str, avatar_url: str
- ) -> None:
- """Ensure we are caching the remote user's profiles.
-
- This should only be called when `is_subscribed_remote_profile_for_user`
- would return true for the user.
- """
- await self.db_pool.simple_upsert(
- table="remote_profile_cache",
- keyvalues={"user_id": user_id},
- values={
- "displayname": displayname,
- "avatar_url": avatar_url,
- "last_check": self._clock.time_msec(),
- },
- desc="add_remote_profile_cache",
- )
+ pass
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index bfc85b3a..ba385f9f 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -69,7 +69,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
# event_forward_extremities
# event_json
# event_push_actions
- # event_reference_hashes
# event_relations
# event_search
# event_to_state_groups
@@ -220,7 +219,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"event_auth",
"event_edges",
"event_forward_extremities",
- "event_reference_hashes",
"event_relations",
"event_search",
"rejections",
@@ -324,12 +322,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
)
def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]:
- # We *immediately* delete the room from the rooms table. This ensures
- # that we don't race when persisting events (as that transaction checks
- # that the room exists).
- txn.execute("DELETE FROM rooms WHERE room_id = ?", (room_id,))
-
- # Next, we fetch all the state groups that should be deleted, before
+ # First, fetch all the state groups that should be deleted, before
# we delete that information.
txn.execute(
"""
@@ -369,7 +362,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"event_edges",
"event_json",
"event_push_actions_staging",
- "event_reference_hashes",
"event_relations",
"event_to_state_groups",
"event_auth_chains",
@@ -390,7 +382,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
(room_id,),
)
- # and finally, the tables with an index on room_id (or no useful index)
+ # next, the tables with an index on room_id (or no useful index)
for table in (
"current_state_events",
"destination_rooms",
@@ -398,8 +390,12 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"event_forward_extremities",
"event_push_actions",
"event_search",
+ "partial_state_events",
"events",
- "group_rooms",
+ "federation_inbound_events_staging",
+ "local_current_membership",
+ "partial_state_rooms_servers",
+ "partial_state_rooms",
"receipts_graph",
"receipts_linearized",
"room_aliases",
@@ -416,10 +412,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"e2e_room_keys",
"event_push_summary",
"pusher_throttle",
- "group_summary_rooms",
"room_account_data",
"room_tags",
- "local_current_membership",
+ # "rooms" happens last, to keep the foreign keys in the other tables
+ # happy
+ "rooms",
):
logger.info("[purge] removing %s from %s", room_id, table)
txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 4ed913e2..d5aefe02 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -14,14 +14,18 @@
# limitations under the License.
import abc
import logging
-from typing import TYPE_CHECKING, Dict, List, Tuple, Union
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast
from synapse.api.errors import StoreError
from synapse.config.homeserver import ExperimentalConfig
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.pusher import PusherWorkerStore
@@ -30,9 +34,12 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import (
+ AbstractStreamIdGenerator,
AbstractStreamIdTracker,
+ IdGenerator,
StreamIdGenerator,
)
+from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -54,10 +61,19 @@ def _is_experimental_rule_enabled(
and not experimental_config.msc3786_enabled
):
return False
+ if (
+ rule_id == "global/underride/.org.matrix.msc3772.thread_reply"
+ and not experimental_config.msc3772_enabled
+ ):
+ return False
return True
-def _load_rules(rawrules, enabled_map, experimental_config: ExperimentalConfig):
+def _load_rules(
+ rawrules: List[JsonDict],
+ enabled_map: Dict[str, bool],
+ experimental_config: ExperimentalConfig,
+) -> List[JsonDict]:
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
@@ -137,7 +153,7 @@ class PushRulesWorkerStore(
)
@abc.abstractmethod
- def get_max_push_rules_stream_id(self):
+ def get_max_push_rules_stream_id(self) -> int:
"""Get the position of the push rules stream.
Returns:
@@ -146,7 +162,7 @@ class PushRulesWorkerStore(
raise NotImplementedError()
@cached(max_entries=5000)
- async def get_push_rules_for_user(self, user_id):
+ async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]:
rows = await self.db_pool.simple_select_list(
table="push_rules",
keyvalues={"user_name": user_id},
@@ -158,7 +174,7 @@ class PushRulesWorkerStore(
"conditions",
"actions",
),
- desc="get_push_rules_enabled_for_user",
+ desc="get_push_rules_for_user",
)
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
@@ -168,14 +184,14 @@ class PushRulesWorkerStore(
return _load_rules(rows, enabled_map, self.hs.config.experimental)
@cached(max_entries=5000)
- async def get_push_rules_enabled_for_user(self, user_id) -> Dict[str, bool]:
+ async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
results = await self.db_pool.simple_select_list(
table="push_rules_enable",
keyvalues={"user_name": user_id},
- retcols=("user_name", "rule_id", "enabled"),
+ retcols=("rule_id", "enabled"),
desc="get_push_rules_enabled_for_user",
)
- return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
+ return {r["rule_id"]: bool(r["enabled"]) for r in results}
async def have_push_rules_changed_for_user(
self, user_id: str, last_id: int
@@ -184,29 +200,27 @@ class PushRulesWorkerStore(
return False
else:
- def have_push_rules_changed_txn(txn):
+ def have_push_rules_changed_txn(txn: LoggingTransaction) -> bool:
sql = (
"SELECT COUNT(stream_id) FROM push_rules_stream"
" WHERE user_id = ? AND ? < stream_id"
)
txn.execute(sql, (user_id, last_id))
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return bool(count)
return await self.db_pool.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
)
- @cachedList(
- cached_method_name="get_push_rules_for_user",
- list_name="user_ids",
- num_args=1,
- )
- async def bulk_get_push_rules(self, user_ids):
+ @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids")
+ async def bulk_get_push_rules(
+ self, user_ids: Collection[str]
+ ) -> Dict[str, List[JsonDict]]:
if not user_ids:
return {}
- results = {user_id: [] for user_id in user_ids}
+ results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
rows = await self.db_pool.simple_select_many_batch(
table="push_rules",
@@ -230,67 +244,16 @@ class PushRulesWorkerStore(
return results
- async def copy_push_rule_from_room_to_room(
- self, new_room_id: str, user_id: str, rule: dict
- ) -> None:
- """Copy a single push rule from one room to another for a specific user.
-
- Args:
- new_room_id: ID of the new room.
- user_id : ID of user the push rule belongs to.
- rule: A push rule.
- """
- # Create new rule id
- rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
- new_rule_id = rule_id_scope + "/" + new_room_id
-
- # Change room id in each condition
- for condition in rule.get("conditions", []):
- if condition.get("key") == "room_id":
- condition["pattern"] = new_room_id
-
- # Add the rule for the new room
- await self.add_push_rule(
- user_id=user_id,
- rule_id=new_rule_id,
- priority_class=rule["priority_class"],
- conditions=rule["conditions"],
- actions=rule["actions"],
- )
-
- async def copy_push_rules_from_room_to_room_for_user(
- self, old_room_id: str, new_room_id: str, user_id: str
- ) -> None:
- """Copy all of the push rules from one room to another for a specific
- user.
-
- Args:
- old_room_id: ID of the old room.
- new_room_id: ID of the new room.
- user_id: ID of user to copy push rules for.
- """
- # Retrieve push rules for this user
- user_push_rules = await self.get_push_rules_for_user(user_id)
-
- # Get rules relating to the old room and copy them to the new room
- for rule in user_push_rules:
- conditions = rule.get("conditions", [])
- if any(
- (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
- for c in conditions
- ):
- await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
-
@cachedList(
- cached_method_name="get_push_rules_enabled_for_user",
- list_name="user_ids",
- num_args=1,
+ cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids"
)
- async def bulk_get_push_rules_enabled(self, user_ids):
+ async def bulk_get_push_rules_enabled(
+ self, user_ids: Collection[str]
+ ) -> Dict[str, Dict[str, bool]]:
if not user_ids:
return {}
- results = {user_id: {} for user_id in user_ids}
+ results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids}
rows = await self.db_pool.simple_select_many_batch(
table="push_rules_enable",
@@ -306,7 +269,7 @@ class PushRulesWorkerStore(
async def get_all_push_rule_updates(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
"""Get updates for push_rules replication stream.
Args:
@@ -331,7 +294,9 @@ class PushRulesWorkerStore(
if last_id == current_id:
return [], current_id, False
- def get_all_push_rule_updates_txn(txn):
+ def get_all_push_rule_updates_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
sql = """
SELECT stream_id, user_id
FROM push_rules_stream
@@ -340,7 +305,10 @@ class PushRulesWorkerStore(
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
- updates = [(stream_id, (user_id,)) for stream_id, user_id in txn]
+ updates = cast(
+ List[Tuple[int, Tuple[str]]],
+ [(stream_id, (user_id,)) for stream_id, user_id in txn],
+ )
limited = False
upper_bound = current_id
@@ -356,15 +324,30 @@ class PushRulesWorkerStore(
class PushRuleStore(PushRulesWorkerStore):
+ # Because we have write access, this will be a StreamIdGenerator
+ # (see PushRulesWorkerStore.__init__)
+ _push_rules_stream_id_gen: AbstractStreamIdGenerator
+
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
+
+ self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
+ self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
+
async def add_push_rule(
self,
- user_id,
- rule_id,
- priority_class,
- conditions,
- actions,
- before=None,
- after=None,
+ user_id: str,
+ rule_id: str,
+ priority_class: int,
+ conditions: List[Dict[str, str]],
+ actions: List[Union[JsonDict, str]],
+ before: Optional[str] = None,
+ after: Optional[str] = None,
) -> None:
conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions)
@@ -400,17 +383,17 @@ class PushRuleStore(PushRulesWorkerStore):
def _add_push_rule_relative_txn(
self,
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- priority_class,
- conditions_json,
- actions_json,
- before,
- after,
- ):
+ txn: LoggingTransaction,
+ stream_id: int,
+ event_stream_ordering: int,
+ user_id: str,
+ rule_id: str,
+ priority_class: int,
+ conditions_json: str,
+ actions_json: str,
+ before: str,
+ after: str,
+ ) -> None:
# Lock the table since otherwise we'll have annoying races between the
# SELECT here and the UPSERT below.
self.database_engine.lock_table(txn, "push_rules")
@@ -470,15 +453,15 @@ class PushRuleStore(PushRulesWorkerStore):
def _add_push_rule_highest_priority_txn(
self,
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- priority_class,
- conditions_json,
- actions_json,
- ):
+ txn: LoggingTransaction,
+ stream_id: int,
+ event_stream_ordering: int,
+ user_id: str,
+ rule_id: str,
+ priority_class: int,
+ conditions_json: str,
+ actions_json: str,
+ ) -> None:
# Lock the table since otherwise we'll have annoying races between the
# SELECT here and the UPSERT below.
self.database_engine.lock_table(txn, "push_rules")
@@ -510,17 +493,17 @@ class PushRuleStore(PushRulesWorkerStore):
def _upsert_push_rule_txn(
self,
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- priority_class,
- priority,
- conditions_json,
- actions_json,
- update_stream=True,
- ):
+ txn: LoggingTransaction,
+ stream_id: int,
+ event_stream_ordering: int,
+ user_id: str,
+ rule_id: str,
+ priority_class: int,
+ priority: int,
+ conditions_json: str,
+ actions_json: str,
+ update_stream: bool = True,
+ ) -> None:
"""Specialised version of simple_upsert_txn that picks a push_rule_id
using the _push_rule_id_gen if it needs to insert the rule. It assumes
that the "push_rules" table is locked"""
@@ -600,7 +583,11 @@ class PushRuleStore(PushRulesWorkerStore):
rule_id: The rule_id of the rule to be deleted
"""
- def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
+ def delete_push_rule_txn(
+ txn: LoggingTransaction,
+ stream_id: int,
+ event_stream_ordering: int,
+ ) -> None:
# we don't use simple_delete_one_txn because that would fail if the
# user did not have a push_rule_enable row.
self.db_pool.simple_delete_txn(
@@ -661,14 +648,14 @@ class PushRuleStore(PushRulesWorkerStore):
def _set_push_rule_enabled_txn(
self,
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- enabled,
- is_default_rule,
- ):
+ txn: LoggingTransaction,
+ stream_id: int,
+ event_stream_ordering: int,
+ user_id: str,
+ rule_id: str,
+ enabled: bool,
+ is_default_rule: bool,
+ ) -> None:
new_id = self._push_rules_enable_id_gen.get_next()
if not is_default_rule:
@@ -740,7 +727,11 @@ class PushRuleStore(PushRulesWorkerStore):
"""
actions_json = json_encoder.encode(actions)
- def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
+ def set_push_rule_actions_txn(
+ txn: LoggingTransaction,
+ stream_id: int,
+ event_stream_ordering: int,
+ ) -> None:
if is_default_rule:
# Add a dummy rule to the rules table with the user specified
# actions.
@@ -794,8 +785,15 @@ class PushRuleStore(PushRulesWorkerStore):
)
def _insert_push_rules_update_txn(
- self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
- ):
+ self,
+ txn: LoggingTransaction,
+ stream_id: int,
+ event_stream_ordering: int,
+ user_id: str,
+ rule_id: str,
+ op: str,
+ data: Optional[JsonDict] = None,
+ ) -> None:
values = {
"stream_id": stream_id,
"event_stream_ordering": event_stream_ordering,
@@ -814,5 +812,56 @@ class PushRuleStore(PushRulesWorkerStore):
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)
- def get_max_push_rules_stream_id(self):
+ def get_max_push_rules_stream_id(self) -> int:
return self._push_rules_stream_id_gen.get_current_token()
+
+ async def copy_push_rule_from_room_to_room(
+ self, new_room_id: str, user_id: str, rule: dict
+ ) -> None:
+ """Copy a single push rule from one room to another for a specific user.
+
+ Args:
+ new_room_id: ID of the new room.
+            user_id: ID of the user the push rule belongs to.
+ rule: A push rule.
+ """
+ # Create new rule id
+ rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
+ new_rule_id = rule_id_scope + "/" + new_room_id
+
+ # Change room id in each condition
+ for condition in rule.get("conditions", []):
+ if condition.get("key") == "room_id":
+ condition["pattern"] = new_room_id
+
+ # Add the rule for the new room
+ await self.add_push_rule(
+ user_id=user_id,
+ rule_id=new_rule_id,
+ priority_class=rule["priority_class"],
+ conditions=rule["conditions"],
+ actions=rule["actions"],
+ )
+
+ async def copy_push_rules_from_room_to_room_for_user(
+ self, old_room_id: str, new_room_id: str, user_id: str
+ ) -> None:
+ """Copy all of the push rules from one room to another for a specific
+ user.
+
+ Args:
+ old_room_id: ID of the old room.
+ new_room_id: ID of the new room.
+ user_id: ID of user to copy push rules for.
+ """
+ # Retrieve push rules for this user
+ user_push_rules = await self.get_push_rules_for_user(user_id)
+
+ # Get rules relating to the old room and copy them to the new room
+ for rule in user_push_rules:
+ conditions = rule.get("conditions", [])
+ if any(
+ (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
+ for c in conditions
+ ):
+ await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
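The copy helpers relocated above rewrite a room-specific rule so that it targets the new room; a worked example of that transformation, with the rule content and room IDs invented:

    rule = {
        "rule_id": "global/room/!old:example.org",
        "priority_class": 0,
        "conditions": [
            {"kind": "event_match", "key": "room_id", "pattern": "!old:example.org"}
        ],
        "actions": ["dont_notify"],
    }
    new_room_id = "!new:example.org"

    # Keep the scope, swap the trailing room ID ...
    rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])  # "global/room"
    new_rule_id = rule_id_scope + "/" + new_room_id            # "global/room/!new:example.org"

    # ... and point any room_id condition at the new room.
    for condition in rule.get("conditions", []):
        if condition.get("key") == "room_id":
            condition["pattern"] = new_room_id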
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 91286c9b..bd0cfa7f 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -91,12 +91,6 @@ class PusherWorkerStore(SQLBaseStore):
yield PusherConfig(**r)
- async def user_has_pusher(self, user_id: str) -> bool:
- ret = await self.db_pool.simple_select_one_onecol(
- "pushers", {"user_name": user_id}, "id", allow_none=True
- )
- return ret is not None
-
async def get_pushers_by_app_id_and_pushkey(
self, app_id: str, pushkey: str
) -> Iterator[PusherConfig]:
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index d035969a..21e954cc 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -26,7 +26,7 @@ from typing import (
cast,
)
-from synapse.api.constants import ReceiptTypes
+from synapse.api.constants import EduTypes, ReceiptTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -363,7 +363,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
row["user_id"]
] = db_to_json(row["data"])
- return [{"type": "m.receipt", "room_id": room_id, "content": content}]
+ return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}]
@cachedList(
cached_method_name="_get_linearized_receipts_for_room",
@@ -411,7 +411,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
# receipts by room, event and type.
room_event = results.setdefault(
row["room_id"],
- {"type": "m.receipt", "room_id": row["room_id"], "content": {}},
+ {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
)
# The content is of the form:
@@ -476,7 +476,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
# receipts by room, event and type.
room_event = results.setdefault(
row["room_id"],
- {"type": "m.receipt", "room_id": row["room_id"], "content": {}},
+ {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
)
# The content is of the form:
@@ -597,7 +597,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return super().process_replication_rows(stream_name, instance_name, token, rows)
- def insert_linearized_receipt_txn(
+ def _insert_linearized_receipt_txn(
self,
txn: LoggingTransaction,
room_id: str,
@@ -673,8 +673,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
lock=False,
)
+        # When updating a local user's read receipt, remove any push actions
+ # which resulted from the receipt's event and all earlier events.
if (
- receipt_type in (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE)
+ self.hs.is_mine_id(user_id)
+ and receipt_type in (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE)
and stream_ordering is not None
):
self._remove_old_push_actions_before_txn( # type: ignore[attr-defined]
@@ -683,6 +686,44 @@ class ReceiptsWorkerStore(SQLBaseStore):
return rx_ts
+ def _graph_to_linear(
+ self, txn: LoggingTransaction, room_id: str, event_ids: List[str]
+ ) -> str:
+ """
+ Generate a linearized event from a list of events (i.e. a list of forward
+ extremities in the room).
+
+ This should allow for calculation of the correct read receipt even if
+ servers have different event ordering.
+
+ Args:
+ txn: The transaction
+ room_id: The room ID the events are in.
+ event_ids: The list of event IDs to linearize.
+
+ Returns:
+ The linearized event ID.
+ """
+ # TODO: Make this better.
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "event_id", event_ids
+ )
+
+ sql = """
+ SELECT event_id WHERE room_id = ? AND stream_ordering IN (
+ SELECT max(stream_ordering) WHERE %s
+ )
+ """ % (
+ clause,
+ )
+
+ txn.execute(sql, [room_id] + list(args))
+ rows = txn.fetchall()
+ if rows:
+ return rows[0][0]
+ else:
+ raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
+
async def insert_receipt(
self,
room_id: str,
@@ -709,35 +750,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
linearized_event_id = event_ids[0]
else:
             # we need to map points in the graph to linearized form.
- # TODO: Make this better.
- def graph_to_linear(txn: LoggingTransaction) -> str:
- clause, args = make_in_list_sql_clause(
- self.database_engine, "event_id", event_ids
- )
-
- sql = """
- SELECT event_id WHERE room_id = ? AND stream_ordering IN (
- SELECT max(stream_ordering) WHERE %s
- )
- """ % (
- clause,
- )
-
- txn.execute(sql, [room_id] + list(args))
- rows = txn.fetchall()
- if rows:
- return rows[0][0]
- else:
- raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
-
linearized_event_id = await self.db_pool.runInteraction(
- "insert_receipt_conv", graph_to_linear
+ "insert_receipt_conv", self._graph_to_linear, room_id, event_ids
)
async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
- self.insert_linearized_receipt_txn,
+ self._insert_linearized_receipt_txn,
room_id,
receipt_type,
user_id,
@@ -758,25 +778,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
now - event_ts,
)
- await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
-
- max_persisted_id = self._receipts_id_gen.get_current_token()
-
- return stream_id, max_persisted_id
-
- async def insert_graph_receipt(
- self,
- room_id: str,
- receipt_type: str,
- user_id: str,
- event_ids: List[str],
- data: JsonDict,
- ) -> None:
- assert self._can_write_to_receipts
-
await self.db_pool.runInteraction(
"insert_graph_receipt",
- self.insert_graph_receipt_txn,
+ self._insert_graph_receipt_txn,
room_id,
receipt_type,
user_id,
@@ -784,7 +788,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
data,
)
- def insert_graph_receipt_txn(
+ max_persisted_id = self._receipts_id_gen.get_current_token()
+
+ return stream_id, max_persisted_id
+
+ def _insert_graph_receipt_txn(
self,
txn: LoggingTransaction,
room_id: str,
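An in-memory analogue of what the relocated _graph_to_linear computes: among the supplied forward extremities, the receipt is attributed to the event this server persisted last. The mapping below is an invented stand-in for the events table.

    from typing import Dict

    def pick_linearized_event(stream_orderings: Dict[str, int]) -> str:
        if not stream_orderings:
            raise RuntimeError("no event_ids supplied")
        # Highest stream_ordering wins, matching the max(stream_ordering)
        # subquery in the SQL above.
        return max(stream_orderings, key=stream_orderings.__getitem__)

    # pick_linearized_event({"$a": 10, "$b": 42}) == "$b"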
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 484976ca..b457bc18 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -34,7 +34,7 @@ from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
-from synapse.types import JsonDict, RoomStreamToken, StreamToken
+from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
logger = logging.getLogger(__name__)
@@ -161,7 +161,9 @@ class RelationsWorkerStore(SQLBaseStore):
if len(events) > limit and last_topo_id and last_stream_id:
next_key = RoomStreamToken(last_topo_id, last_stream_id)
if from_token:
- next_token = from_token.copy_and_replace("room_key", next_key)
+ next_token = from_token.copy_and_replace(
+ StreamKeyType.ROOM, next_key
+ )
else:
next_token = StreamToken(
room_key=next_key,
@@ -765,6 +767,59 @@ class RelationsWorkerStore(SQLBaseStore):
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
+ @cached(iterable=True)
+ async def get_mutual_event_relations_for_rel_type(
+ self, event_id: str, relation_type: str
+ ) -> Set[Tuple[str, str]]:
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="get_mutual_event_relations_for_rel_type",
+ list_name="relation_types",
+ )
+ async def get_mutual_event_relations(
+ self, event_id: str, relation_types: Collection[str]
+ ) -> Dict[str, Set[Tuple[str, str]]]:
+ """
+        Fetch event metadata for events which relate to the same event as the given event.
+
+ If the given event has no relation information, returns an empty dictionary.
+
+ Args:
+ event_id: The event ID which is targeted by relations.
+ relation_types: The relation types to check for mutual relations.
+
+ Returns:
+ A dictionary of relation type to:
+ A set of tuples of:
+ The sender
+ The event type
+ """
+ rel_type_sql, rel_type_args = make_in_list_sql_clause(
+ self.database_engine, "relation_type", relation_types
+ )
+
+ sql = f"""
+ SELECT DISTINCT relation_type, sender, type FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE relates_to_id = ? AND {rel_type_sql}
+ """
+
+ def _get_event_relations(
+ txn: LoggingTransaction,
+ ) -> Dict[str, Set[Tuple[str, str]]]:
+ txn.execute(sql, [event_id] + rel_type_args)
+ result: Dict[str, Set[Tuple[str, str]]] = {
+ rel_type: set() for rel_type in relation_types
+ }
+ for rel_type, sender, type in txn.fetchall():
+ result[rel_type].add((sender, type))
+ return result
+
+ return await self.db_pool.runInteraction(
+ "get_event_relations", _get_event_relations
+ )
+
class RelationsStore(RelationsWorkerStore):
pass
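A sketch of how a caller might consume the new get_mutual_event_relations method; the store handle, event ID and relation types are assumptions, while the result shape (relation type mapped to a set of (sender, event type) pairs) comes from the hunk.

    from typing import Set, Tuple

    async def annotators_of(store, event_id: str) -> Set[Tuple[str, str]]:
        # Relation types with no matching events come back as empty sets.
        relations = await store.get_mutual_event_relations(
            event_id, {"m.annotation", "m.thread"}
        )
        return relations["m.annotation"]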
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 87e9482c..68d4fc2e 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -23,6 +23,7 @@ from typing import (
Collection,
Dict,
List,
+ Mapping,
Optional,
Tuple,
Union,
@@ -45,7 +46,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import IdGenerator
-from synapse.types import JsonDict, ThirdPartyInstanceID
+from synapse.types import JsonDict, RetentionPolicy, ThirdPartyInstanceID
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.stringutils import MXC_REGEX
@@ -233,24 +234,23 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
UNION SELECT room_id from appservice_room_list
"""
- sql = """
+ sql = f"""
SELECT
COUNT(*)
FROM (
- %(published_sql)s
+ {published_sql}
) published
INNER JOIN room_stats_state USING (room_id)
INNER JOIN room_stats_current USING (room_id)
WHERE
(
- join_rules = 'public' OR join_rules = '%(knock_join_rule)s'
+ join_rules = '{JoinRules.PUBLIC}'
+ OR join_rules = '{JoinRules.KNOCK}'
+ OR join_rules = '{JoinRules.KNOCK_RESTRICTED}'
OR history_visibility = 'world_readable'
)
AND joined_members > 0
- """ % {
- "published_sql": published_sql,
- "knock_join_rule": JoinRules.KNOCK,
- }
+ """
txn.execute(sql, query_args)
return cast(Tuple[int], txn.fetchone())[0]
@@ -369,29 +369,29 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
if where_clauses:
where_clause = " AND " + " AND ".join(where_clauses)
- sql = """
+ dir = "DESC" if forwards else "ASC"
+ sql = f"""
SELECT
room_id, name, topic, canonical_alias, joined_members,
avatar, history_visibility, guest_access, join_rules
FROM (
- %(published_sql)s
+ {published_sql}
) published
INNER JOIN room_stats_state USING (room_id)
INNER JOIN room_stats_current USING (room_id)
WHERE
(
- join_rules = 'public' OR join_rules = '%(knock_join_rule)s'
+ join_rules = '{JoinRules.PUBLIC}'
+ OR join_rules = '{JoinRules.KNOCK}'
+ OR join_rules = '{JoinRules.KNOCK_RESTRICTED}'
OR history_visibility = 'world_readable'
)
AND joined_members > 0
- %(where_clause)s
- ORDER BY joined_members %(dir)s, room_id %(dir)s
- """ % {
- "published_sql": published_sql,
- "where_clause": where_clause,
- "dir": "DESC" if forwards else "ASC",
- "knock_join_rule": JoinRules.KNOCK,
- }
+ {where_clause}
+ ORDER BY
+ joined_members {dir},
+ room_id {dir}
+ """
if limit is not None:
query_args.append(limit)
@@ -699,7 +699,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn)
@cached()
- async def get_retention_policy_for_room(self, room_id: str) -> Dict[str, int]:
+ async def get_retention_policy_for_room(self, room_id: str) -> RetentionPolicy:
"""Get the retention policy for a given room.
If no retention policy has been found for this room, returns a policy defined
@@ -707,12 +707,20 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
the 'max_lifetime' if no default policy has been defined in the server's
configuration).
+ If support for retention policies is disabled, a policy with a 'min_lifetime' and
+ 'max_lifetime' of None is returned.
+
Args:
room_id: The ID of the room to get the retention policy of.
Returns:
A dict containing "min_lifetime" and "max_lifetime" for this room.
"""
+        # If the room retention feature is disabled, return a policy with no minimum or
+ # maximum. This prevents incorrectly filtering out events when sending to
+ # the client.
+ if not self.config.retention.retention_enabled:
+ return RetentionPolicy()
def get_retention_policy_for_room_txn(
txn: LoggingTransaction,
@@ -736,10 +744,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
# If we don't know this room ID, ret will be None, in this case return the default
# policy.
if not ret:
- return {
- "min_lifetime": self.config.retention.retention_default_min_lifetime,
- "max_lifetime": self.config.retention.retention_default_max_lifetime,
- }
+ return RetentionPolicy(
+ min_lifetime=self.config.retention.retention_default_min_lifetime,
+ max_lifetime=self.config.retention.retention_default_max_lifetime,
+ )
min_lifetime = ret[0]["min_lifetime"]
max_lifetime = ret[0]["max_lifetime"]
@@ -754,10 +762,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
if max_lifetime is None:
max_lifetime = self.config.retention.retention_default_max_lifetime
- return {
- "min_lifetime": min_lifetime,
- "max_lifetime": max_lifetime,
- }
+ return RetentionPolicy(
+ min_lifetime=min_lifetime,
+ max_lifetime=max_lifetime,
+ )
async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
"""Retrieves all the local and remote media MXC URIs in a given room
@@ -994,7 +1002,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
async def get_rooms_for_retention_period_in_range(
self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
- ) -> Dict[str, Dict[str, Optional[int]]]:
+ ) -> Dict[str, RetentionPolicy]:
"""Retrieves all of the rooms within the given retention range.
Optionally includes the rooms which don't have a retention policy.
@@ -1016,7 +1024,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def get_rooms_for_retention_period_in_range_txn(
txn: LoggingTransaction,
- ) -> Dict[str, Dict[str, Optional[int]]]:
+ ) -> Dict[str, RetentionPolicy]:
range_conditions = []
args = []
@@ -1047,10 +1055,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
rooms_dict = {}
for row in rows:
- rooms_dict[row["room_id"]] = {
- "min_lifetime": row["min_lifetime"],
- "max_lifetime": row["max_lifetime"],
- }
+ rooms_dict[row["room_id"]] = RetentionPolicy(
+ min_lifetime=row["min_lifetime"],
+ max_lifetime=row["max_lifetime"],
+ )
if include_null:
# If required, do a second query that retrieves all of the rooms we know
@@ -1065,10 +1073,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
# policy in its state), add it with a null policy.
for row in rows:
if row["room_id"] not in rooms_dict:
- rooms_dict[row["room_id"]] = {
- "min_lifetime": None,
- "max_lifetime": None,
- }
+ rooms_dict[row["room_id"]] = RetentionPolicy()
return rooms_dict
@@ -1077,6 +1082,32 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
get_rooms_for_retention_period_in_range_txn,
)
+ async def get_partial_state_rooms_and_servers(
+ self,
+ ) -> Mapping[str, Collection[str]]:
+ """Get all rooms containing events with partial state, and the servers known
+ to be in the room.
+
+ Returns:
+ A dictionary of rooms with partial state, with room IDs as keys and
+ lists of servers in rooms as values.
+ """
+ room_servers: Dict[str, List[str]] = {}
+
+ rows = await self.db_pool.simple_select_list(
+ "partial_state_rooms_servers",
+ keyvalues=None,
+ retcols=("room_id", "server_name"),
+ desc="get_partial_state_rooms",
+ )
+
+ for row in rows:
+ room_id = row["room_id"]
+ server_name = row["server_name"]
+ room_servers.setdefault(room_id, []).append(server_name)
+
+ return room_servers
+
async def clear_partial_state_room(self, room_id: str) -> bool:
# this can race with incoming events, so we watch out for FK errors.
# TODO(faster_joins): this still doesn't completely fix the race, since the persist process
@@ -1108,6 +1139,24 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
keyvalues={"room_id": room_id},
)
+ async def is_partial_state_room(self, room_id: str) -> bool:
+ """Checks if this room has partial state.
+
+ Returns true if this is a "partial-state" room, which means that the state
+ at events in the room, and `current_state_events`, may not yet be
+ complete.
+ """
+
+ entry = await self.db_pool.simple_select_one_onecol(
+ table="partial_state_rooms",
+ keyvalues={"room_id": room_id},
+ retcol="room_id",
+ allow_none=True,
+ desc="is_partial_state_room",
+ )
+
+ return entry is not None
+
class _BackgroundUpdates:
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
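A condensed approximation of the retention fallback order introduced above; RetentionPolicy and the retention config attribute names come from the hunk, while the helper and its row argument are illustrative.

    from typing import Optional

    from synapse.types import RetentionPolicy

    def effective_policy(
        retention_enabled: bool,
        row: Optional[dict],
        default_min: Optional[int],
        default_max: Optional[int],
    ) -> RetentionPolicy:
        if not retention_enabled:
            # Feature disabled: an all-None policy, so no events get filtered.
            return RetentionPolicy()
        if not row:
            # Unknown room / no retention state event: server defaults apply.
            return RetentionPolicy(min_lifetime=default_min, max_lifetime=default_max)
        # Known policy: fields left unset fall back to the defaults.
        return RetentionPolicy(
            min_lifetime=row["min_lifetime"] if row["min_lifetime"] is not None else default_min,
            max_lifetime=row["max_lifetime"] if row["max_lifetime"] is not None else default_max,
        )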
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 48e83592..31bc8c56 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -15,6 +15,7 @@
import logging
from typing import (
TYPE_CHECKING,
+ Callable,
Collection,
Dict,
FrozenSet,
@@ -37,7 +38,12 @@ from synapse.metrics.background_process_metrics import (
wrap_as_background_process,
)
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import Sqlite3Engine
from synapse.storage.roommember import (
@@ -46,7 +52,7 @@ from synapse.storage.roommember import (
ProfileInfo,
RoomsForUser,
)
-from synapse.types import PersistedEventPosition, get_domain_from_id
+from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -115,7 +121,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@wrap_as_background_process("_count_known_servers")
- async def _count_known_servers(self):
+ async def _count_known_servers(self) -> int:
"""
Count the servers that this server knows about.
@@ -123,7 +129,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
`synapse_federation_known_servers` LaterGauge to collect.
"""
- def _transact(txn):
+ def _transact(txn: LoggingTransaction) -> int:
if isinstance(self.database_engine, Sqlite3Engine):
query = """
SELECT COUNT(DISTINCT substr(out.user_id, pos+1))
@@ -150,7 +156,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
self._known_servers_count = max([count, 1])
return self._known_servers_count
- def _check_safe_current_state_events_membership_updated_txn(self, txn):
+ def _check_safe_current_state_events_membership_updated_txn(
+ self, txn: LoggingTransaction
+ ) -> None:
"""Checks if it is safe to assume the new current_state_events
membership column is up to date
"""
@@ -182,7 +190,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"get_users_in_room", self.get_users_in_room_txn, room_id
)
- def get_users_in_room_txn(self, txn, room_id: str) -> List[str]:
+ def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
# If we can assume current_state_events.membership is up to date
# then we can avoid a join, which is a Very Good Thing given how
# frequently this function gets called.
@@ -222,7 +230,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
A mapping from user ID to ProfileInfo.
"""
- def _get_users_in_room_with_profiles(txn) -> Dict[str, ProfileInfo]:
+ def _get_users_in_room_with_profiles(
+ txn: LoggingTransaction,
+ ) -> Dict[str, ProfileInfo]:
sql = """
SELECT state_key, display_name, avatar_url FROM room_memberships as m
INNER JOIN current_state_events as c
@@ -250,7 +260,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
dict of membership states, pointing to a MemberSummary named tuple.
"""
- def _get_room_summary_txn(txn):
+ def _get_room_summary_txn(
+ txn: LoggingTransaction,
+ ) -> Dict[str, MemberSummary]:
# first get counts.
# We do this all in one transaction to keep the cache small.
# FIXME: get rid of this when we have room_stats
@@ -279,7 +291,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""
txn.execute(sql, (room_id,))
- res = {}
+ res: Dict[str, MemberSummary] = {}
for count, membership in txn:
res.setdefault(membership, MemberSummary([], count))
@@ -400,7 +412,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
def _get_rooms_for_local_user_where_membership_is_txn(
self,
- txn,
+ txn: LoggingTransaction,
user_id: str,
membership_list: List[str],
) -> List[RoomsForUser]:
@@ -488,7 +500,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
def _get_rooms_for_user_with_stream_ordering_txn(
- self, txn, user_id: str
+ self, txn: LoggingTransaction, user_id: str
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
# We use `current_state_events` here and not `local_current_membership`
# as a) this gets called with remote users and b) this only gets called
@@ -542,7 +554,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
def _get_rooms_for_users_with_stream_ordering_txn(
- self, txn, user_ids: Collection[str]
+ self, txn: LoggingTransaction, user_ids: Collection[str]
) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
clause, args = make_in_list_sql_clause(
@@ -575,7 +587,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, [Membership.JOIN] + args)
- result = {user_id: set() for user_id in user_ids}
+ result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = {
+ user_id: set() for user_id in user_ids
+ }
for user_id, room_id, instance, stream_id in txn:
result[user_id].add(
GetRoomsForUserWithStreamOrdering(
@@ -595,7 +609,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
if not user_ids:
return set()
- def _get_users_server_still_shares_room_with_txn(txn):
+ def _get_users_server_still_shares_room_with_txn(
+ txn: LoggingTransaction,
+ ) -> Set[str]:
sql = """
SELECT state_key FROM current_state_events
WHERE
@@ -619,7 +635,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
async def get_rooms_for_user(
- self, user_id: str, on_invalidate=None
+ self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None
) -> FrozenSet[str]:
"""Returns a set of room_ids the user is currently joined to.
@@ -654,10 +670,34 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return user_who_share_room
+ @cached(cache_context=True, iterable=True)
+ async def get_mutual_rooms_between_users(
+ self, user_ids: FrozenSet[str], cache_context: _CacheContext
+ ) -> FrozenSet[str]:
+ """
+ Returns the set of rooms that all users in `user_ids` share.
+
+ Args:
+ user_ids: A frozen set of all users to investigate and return
+ overlapping joined rooms for.
+ cache_context
+ """
+ shared_room_ids: Optional[FrozenSet[str]] = None
+ for user_id in user_ids:
+ room_ids = await self.get_rooms_for_user(
+ user_id, on_invalidate=cache_context.invalidate
+ )
+ if shared_room_ids is not None:
+ shared_room_ids &= room_ids
+ else:
+ shared_room_ids = room_ids
+
+ return shared_room_ids or frozenset()
+
async def get_joined_users_from_context(
self, event: EventBase, context: EventContext
) -> Dict[str, ProfileInfo]:
- state_group = context.state_group
+ state_group: Union[object, int] = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@@ -666,14 +706,16 @@ class RoomMemberWorkerStore(EventsWorkerStore):
state_group = object()
current_state_ids = await context.get_current_state_ids()
+ assert current_state_ids is not None
+ assert state_group is not None
return await self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids, event=event, context=context
)
async def get_joined_users_from_state(
- self, room_id, state_entry
+ self, room_id: str, state_entry: "_StateCacheEntry"
) -> Dict[str, ProfileInfo]:
- state_group = state_entry.state_group
+ state_group: Union[object, int] = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@@ -681,6 +723,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
+ assert state_group is not None
with Measure(self._clock, "get_joined_users_from_state"):
return await self._get_joined_users_from_context(
room_id, state_group, state_entry.state, context=state_entry
@@ -689,12 +732,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
async def _get_joined_users_from_context(
self,
- room_id,
- state_group,
- current_state_ids,
- cache_context,
- event=None,
- context=None,
+ room_id: str,
+ state_group: Union[object, int],
+ current_state_ids: StateMap[str],
+ cache_context: _CacheContext,
+ event: Optional[EventBase] = None,
+ context: Optional[Union[EventContext, "_StateCacheEntry"]] = None,
) -> Dict[str, ProfileInfo]:
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
@@ -765,14 +808,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return users_in_room
@cached(max_entries=10000)
- def _get_joined_profile_from_event_id(self, event_id):
+ def _get_joined_profile_from_event_id(
+ self, event_id: str
+ ) -> Optional[Tuple[str, ProfileInfo]]:
raise NotImplementedError()
@cachedList(
cached_method_name="_get_joined_profile_from_event_id",
list_name="event_ids",
)
- async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
+ async def _get_joined_profiles_from_event_ids(
+ self, event_ids: Iterable[str]
+ ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]:
"""For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info.
@@ -780,8 +827,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
event_ids: The member event IDs to lookup
Returns:
- dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
- to `user_id` and ProfileInfo (or None if not join event).
+ Map from event ID to `user_id` and ProfileInfo (or None if not join event).
"""
rows = await self.db_pool.simple_select_many_batch(
@@ -847,8 +893,47 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True
- async def get_joined_hosts(self, room_id: str, state_entry):
- state_group = state_entry.state_group
+ @cached(iterable=True, max_entries=10000)
+ async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+ """Get current hosts in room based on current state."""
+
+ # First we check if we already have `get_users_in_room` in the cache, as
+        # we can just calculate the result from that
+ users = self.get_users_in_room.cache.get_immediate(
+ (room_id,), None, update_metrics=False
+ )
+ if users is not None:
+ return {get_domain_from_id(u) for u in users}
+
+ if isinstance(self.database_engine, Sqlite3Engine):
+ # If we're using SQLite then let's just always use
+ # `get_users_in_room` rather than funky SQL.
+ users = await self.get_users_in_room(room_id)
+ return {get_domain_from_id(u) for u in users}
+
+ # For PostgreSQL we can use a regex to pull out the domains from the
+        # joined users in `current_state_events`.
+
+ def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
+ sql = """
+ SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
+ FROM current_state_events
+ WHERE
+ type = 'm.room.member'
+ AND membership = 'join'
+ AND room_id = ?
+ """
+ txn.execute(sql, (room_id,))
+ return {d for d, in txn}
+
+ return await self.db_pool.runInteraction(
+ "get_current_hosts_in_room", get_current_hosts_in_room_txn
+ )
+
+ async def get_joined_hosts(
+ self, room_id: str, state_entry: "_StateCacheEntry"
+ ) -> FrozenSet[str]:
+ state_group: Union[object, int] = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@@ -856,6 +941,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
+ assert state_group is not None
with Measure(self._clock, "get_joined_hosts"):
return await self._get_joined_hosts(
room_id, state_group, state_entry=state_entry
@@ -863,7 +949,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(num_args=2, max_entries=10000, iterable=True)
async def _get_joined_hosts(
- self, room_id: str, state_group: int, state_entry: "_StateCacheEntry"
+ self,
+ room_id: str,
+ state_group: Union[object, int],
+ state_entry: "_StateCacheEntry",
) -> FrozenSet[str]:
# We don't use `state_group`, it's there so that we can cache based on
         # it. However, it's important that it's never None, since two
@@ -881,7 +970,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# `get_joined_hosts` is called with the "current" state group for the
# room, and so consecutive calls will be for consecutive state groups
# which point to the previous state group.
- cache = await self._get_joined_hosts_cache(room_id)
+ cache = await self._get_joined_hosts_cache(room_id) # type: ignore[misc]
# If the state group in the cache matches, we already have the data we need.
if state_entry.state_group == cache.state_group:
@@ -897,6 +986,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
elif state_entry.prev_group == cache.state_group:
# The cached work is for the previous state group, so we work out
# the delta.
+ assert state_entry.delta_ids is not None
for (typ, state_key), event_id in state_entry.delta_ids.items():
if typ != EventTypes.Member:
continue
@@ -942,7 +1032,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
Returns False if they have since re-joined."""
- def f(txn):
+ def f(txn: LoggingTransaction) -> int:
sql = (
"SELECT"
" COUNT(*)"
@@ -973,7 +1063,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
The forgotten rooms.
"""
- def _get_forgotten_rooms_for_user_txn(txn):
+ def _get_forgotten_rooms_for_user_txn(txn: LoggingTransaction) -> Set[str]:
# This is a slightly convoluted query that first looks up all rooms
# that the user has forgotten in the past, then rechecks that list
# to see if any have subsequently been updated. This is done so that
@@ -1076,7 +1166,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
clause,
)
- def _is_local_host_in_room_ignoring_users_txn(txn):
+ def _is_local_host_in_room_ignoring_users_txn(
+ txn: LoggingTransaction,
+ ) -> bool:
txn.execute(sql, (room_id, Membership.JOIN, *args))
return bool(txn.fetchone())
@@ -1110,15 +1202,17 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
where_clause="forgotten = 1",
)
- async def _background_add_membership_profile(self, progress, batch_size):
+ async def _background_add_membership_profile(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
target_min_stream_id = progress.get(
- "target_min_stream_id_inclusive", self._min_stream_order_on_start
+ "target_min_stream_id_inclusive", self._min_stream_order_on_start # type: ignore[attr-defined]
)
max_stream_id = progress.get(
- "max_stream_id_exclusive", self._stream_order_on_start + 1
+ "max_stream_id_exclusive", self._stream_order_on_start + 1 # type: ignore[attr-defined]
)
- def add_membership_profile_txn(txn):
+ def add_membership_profile_txn(txn: LoggingTransaction) -> int:
sql = """
SELECT stream_ordering, event_id, events.room_id, event_json.json
FROM events
@@ -1182,13 +1276,17 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
return result
- async def _background_current_state_membership(self, progress, batch_size):
+ async def _background_current_state_membership(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""Update the new membership column on current_state_events.
         This works by iterating over all rooms in alphabetical order.
"""
- def _background_current_state_membership_txn(txn, last_processed_room):
+ def _background_current_state_membership_txn(
+ txn: LoggingTransaction, last_processed_room: str
+ ) -> Tuple[int, bool]:
processed = 0
while processed < batch_size:
txn.execute(
@@ -1242,7 +1340,11 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
return row_count
-class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
+class RoomMemberStore(
+ RoomMemberWorkerStore,
+ RoomMemberBackgroundUpdateStore,
+ CacheInvalidationWorkerStore,
+):
def __init__(
self,
database: DatabasePool,
@@ -1254,7 +1356,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
async def forget(self, user_id: str, room_id: str) -> None:
"""Indicate that user_id wishes to discard history for room_id."""
- def f(txn):
+ def f(txn: LoggingTransaction) -> None:
sql = (
"UPDATE"
" room_memberships"
@@ -1288,5 +1390,5 @@ class _JoinedHostsCache:
# equal to anything else).
state_group: Union[object, int] = attr.Factory(object)
- def __len__(self):
+ def __len__(self) -> int:
return sum(len(v) for v in self.hosts_to_joined_users.values())
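
A minimal, self-contained sketch (not Synapse's actual classes or signatures) of the caching idea behind `_get_joined_hosts` and `_JoinedHostsCache` above: the joined hosts for a room are cached against a state group, and a call for the successor state group only applies the membership delta instead of recomputing from scratch. To keep the example runnable on its own, the delta values here are membership strings rather than the event IDs the real code resolves.

from typing import Dict, Optional, Set, Tuple

class JoinedHostsCacheSketch:
    """Per-room cache of joined users grouped by server, keyed by state group."""

    def __init__(self) -> None:
        self.state_group: Optional[int] = None
        self.hosts_to_joined_users: Dict[str, Set[str]] = {}

    def apply_membership_delta(
        self,
        new_state_group: int,
        delta_ids: Dict[Tuple[str, str], Optional[str]],
    ) -> None:
        # delta_ids maps (event_type, state_key) -> membership; only
        # m.room.member entries affect the joined-hosts mapping.
        for (etype, user_id), membership in delta_ids.items():
            if etype != "m.room.member":
                continue
            host = user_id.split(":", 1)[1]
            users = self.hosts_to_joined_users.setdefault(host, set())
            if membership == "join":
                users.add(user_id)
            else:
                users.discard(user_id)
                if not users:
                    del self.hosts_to_joined_users[host]
        self.state_group = new_state_group

If a request arrives for a state group that is neither the cached one nor its direct successor, the real store falls back to rebuilding the mapping from the full state, which is why the `@cached` key includes the state group.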
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 3c49e7ec..78e0773b 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -14,7 +14,7 @@
import logging
import re
-from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple
import attr
@@ -27,7 +27,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict
if TYPE_CHECKING:
@@ -149,7 +149,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings
)
- async def _background_reindex_search(self, progress, batch_size):
+ async def _background_reindex_search(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
# we work through the events table from highest stream id to lowest
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
@@ -157,7 +159,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
- def reindex_search_txn(txn):
+ def reindex_search_txn(txn: LoggingTransaction) -> int:
sql = (
"SELECT stream_ordering, event_id, room_id, type, json, "
" origin_server_ts FROM events"
@@ -255,12 +257,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
return result
- async def _background_reindex_gin_search(self, progress, batch_size):
+ async def _background_reindex_gin_search(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""This handles old synapses which used GIST indexes, if any;
converting them back to be GIN as per the actual schema.
"""
- def create_index(conn):
+ def create_index(conn: LoggingDatabaseConnection) -> None:
conn.rollback()
# we have to set autocommit, because postgres refuses to
@@ -299,7 +303,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
)
return 1
- async def _background_reindex_search_order(self, progress, batch_size):
+ async def _background_reindex_search_order(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
@@ -307,7 +313,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
if not have_added_index:
- def create_index(conn):
+ def create_index(conn: LoggingDatabaseConnection) -> None:
conn.rollback()
conn.set_session(autocommit=True)
c = conn.cursor()
@@ -336,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
pg,
)
- def reindex_search_txn(txn):
+ def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]:
sql = (
"UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
" origin_server_ts = e.origin_server_ts"
@@ -644,7 +650,8 @@ class SearchStore(SearchBackgroundUpdateStore):
else:
raise Exception("Unrecognized database engine")
- args.append(limit)
+ # mypy expects to append only a `str`, not an `int`
+ args.append(limit) # type: ignore[arg-type]
results = await self.db_pool.execute(
"search_rooms", self.db_pool.cursor_to_dict, sql, *args
@@ -705,7 +712,7 @@ class SearchStore(SearchBackgroundUpdateStore):
A set of strings.
"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> Set[str]:
highlight_words = set()
for event in events:
# As a hack we simply join values of all possible keys. This is
@@ -759,11 +766,11 @@ class SearchStore(SearchBackgroundUpdateStore):
return await self.db_pool.runInteraction("_find_highlights", f)
-def _to_postgres_options(options_dict):
+def _to_postgres_options(options_dict: JsonDict) -> str:
return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
-def _parse_query(database_engine, search_term):
+def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to the database.
We use this so that we can add prefix matching, which isn't something
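
As a rough approximation of what `_parse_query` does (the real helper takes a `BaseDatabaseEngine` and handles more edge cases), the idea is to tokenise the user's search term and add the engine-specific prefix-match markers. This sketch uses a plain boolean instead of the engine object and is not the upstream implementation:

import re
from typing import List

def parse_query_sketch(search_term: str, is_postgres: bool) -> str:
    # Split the free-text search term into word tokens.
    words: List[str] = re.findall(r"[\w\-]+", search_term, flags=re.UNICODE)
    if is_postgres:
        # PostgreSQL tsquery syntax: AND the terms, ":*" marks a prefix match.
        return " & ".join(f"{word}:*" for word in words)
    # SQLite FTS syntax: space-separated terms, "*" marks a prefix match.
    return " ".join(f"{word}*" for word in words)

print(parse_query_sketch("matrix synapse", is_postgres=True))   # matrix:* & synapse:*
print(parse_query_sketch("matrix synapse", is_postgres=False))  # matrix* synapse*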
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 18ae8aee..bdd00273 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -16,6 +16,8 @@ import collections.abc
import logging
from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
+import attr
+
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@@ -26,6 +28,7 @@ from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
+ make_in_list_sql_clause,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
@@ -33,6 +36,7 @@ from synapse.storage.state import StateFilter
from synapse.types import JsonDict, JsonMapping, StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -43,6 +47,16 @@ logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class EventMetadata:
+ """Returned by `get_metadata_for_events`"""
+
+ room_id: str
+ event_type: str
+ state_key: Optional[str]
+ rejection_reason: Optional[str]
+
+
def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion:
v = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not v:
@@ -133,6 +147,57 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return room_version
+ async def get_metadata_for_events(
+ self, event_ids: Collection[str]
+ ) -> Dict[str, EventMetadata]:
+ """Get some metadata (room_id, type, state_key) for the given events.
+
+ This method is a faster alternative to fetching the full events from
+ the DB, and should be used when the full event is not needed.
+
+ Returns metadata for rejected and redacted events. Events that have not
+ been persisted are omitted from the returned dict.
+ """
+
+ def get_metadata_for_events_txn(
+ txn: LoggingTransaction,
+ batch_ids: Collection[str],
+ ) -> Dict[str, EventMetadata]:
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "e.event_id", batch_ids
+ )
+
+ sql = f"""
+ SELECT e.event_id, e.room_id, e.type, se.state_key, r.reason
+ FROM events AS e
+ LEFT JOIN state_events se USING (event_id)
+ LEFT JOIN rejections r USING (event_id)
+ WHERE {clause}
+ """
+
+ txn.execute(sql, args)
+ return {
+ event_id: EventMetadata(
+ room_id=room_id,
+ event_type=event_type,
+ state_key=state_key,
+ rejection_reason=rejection_reason,
+ )
+ for event_id, room_id, event_type, state_key, rejection_reason in txn
+ }
+
+ result_map: Dict[str, EventMetadata] = {}
+ for batch_ids in batch_iter(event_ids, 1000):
+ result_map.update(
+ await self.db_pool.runInteraction(
+ "get_metadata_for_events",
+ get_metadata_for_events_txn,
+ batch_ids=batch_ids,
+ )
+ )
+
+ return result_map
+
async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
"""Get the predecessor of an upgraded room if it exists.
Otherwise return None.
@@ -177,7 +242,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Raises:
NotFoundError if the room is unknown
"""
- state_ids = await self.get_current_state_ids(room_id)
+ state_ids = await self.get_partial_current_state_ids(room_id)
if not state_ids:
raise NotFoundError(f"Current state for room {room_id} is empty")
@@ -193,10 +258,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return create_event
@cached(max_entries=100000, iterable=True)
- async def get_current_state_ids(self, room_id: str) -> StateMap[str]:
+ async def get_partial_current_state_ids(self, room_id: str) -> StateMap[str]:
"""Get the current state event ids for a room based on the
current_state_events table.
+ This may be the partial state if we're lazy joining the room.
+
Args:
room_id: The room to get the state IDs of.
@@ -215,17 +282,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
return await self.db_pool.runInteraction(
- "get_current_state_ids", _get_current_state_ids_txn
+ "get_partial_current_state_ids", _get_current_state_ids_txn
)
# FIXME: how should this be cached?
- async def get_filtered_current_state_ids(
+ async def get_partial_filtered_current_state_ids(
self, room_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
"""Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result
of doing a fresh state resolution as per state_handler.get_current_state
+ This may be the partial state if we're lazy joining the room.
+
Args:
room_id
state_filter: The state filter used to fetch state
@@ -241,7 +310,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
if not where_clause:
# We delegate to the cached version
- return await self.get_current_state_ids(room_id)
+ return await self.get_partial_current_state_ids(room_id)
def _get_filtered_current_state_ids_txn(
txn: LoggingTransaction,
@@ -269,30 +338,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)
- async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
- """Get canonical alias for room, if any
-
- Args:
- room_id: The room ID
-
- Returns:
- The canonical alias, if any
- """
-
- state = await self.get_filtered_current_state_ids(
- room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
- )
-
- event_id = state.get((EventTypes.CanonicalAlias, ""))
- if not event_id:
- return None
-
- event = await self.get_event(event_id, allow_none=True)
- if not event:
- return None
-
- return event.content.get("canonical_alias")
-
@cached(max_entries=50000)
async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
return await self.db_pool.simple_select_one_onecol(
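
The `get_metadata_for_events` helper added above combines two patterns: chunking a potentially large collection of IDs with `batch_iter`, and building a parameterised `IN` clause per chunk. A standalone sketch of that pattern, using plain sqlite3 and an assumed `events(event_id, room_id)` table instead of Synapse's database pool and helpers:

import sqlite3
from itertools import islice
from typing import Dict, Iterable, Iterator, Tuple

def batch_iter(iterable: Iterable[str], size: int) -> Iterator[Tuple[str, ...]]:
    # Yield successive chunks of at most `size` items.
    it = iter(iterable)
    return iter(lambda: tuple(islice(it, size)), ())

def fetch_room_ids(conn: sqlite3.Connection, event_ids: Iterable[str]) -> Dict[str, str]:
    result: Dict[str, str] = {}
    for chunk in batch_iter(event_ids, 1000):
        placeholders = ", ".join("?" for _ in chunk)
        rows = conn.execute(
            f"SELECT event_id, room_id FROM events WHERE event_id IN ({placeholders})",
            chunk,
        )
        result.update(dict(rows))
    return result

Batching keeps each statement's parameter list bounded, which matters because some engines limit the number of bound parameters per query.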
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 188afec3..445213e1 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -27,7 +27,7 @@ class StateDeltasStore(SQLBaseStore):
# attribute. TODO: can we get static analysis to enforce this?
_curr_state_delta_stream_cache: StreamChangeCache
- async def get_current_state_deltas(
+ async def get_partial_current_state_deltas(
self, prev_stream_id: int, max_stream_id: int
) -> Tuple[int, List[Dict[str, Any]]]:
"""Fetch a list of room state changes since the given stream id
@@ -42,6 +42,8 @@ class StateDeltasStore(SQLBaseStore):
- prev_event_id (str|None): previous event_id for this state key. None
if it's new state.
+ This may be the partial state if we're lazy joining the room.
+
Args:
prev_stream_id: point to get changes since (exclusive)
max_stream_id: the point that we know has been correctly persisted
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 0373af86..8e88784d 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -765,15 +765,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self,
room_id: str,
end_token: RoomStreamToken,
- ) -> Optional[EventBase]:
- """Returns the last event in a room at or before a stream ordering
+ ) -> Optional[str]:
+ """Returns the ID of the last event in a room at or before a stream ordering
Args:
room_id
end_token: The token used to stream from
Returns:
- The most recent event.
+ The ID of the most recent event, or None if there are no events in the room
+ before this stream ordering.
"""
last_row = await self.get_room_event_before_stream_ordering(
@@ -781,37 +782,28 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
stream_ordering=end_token.stream,
)
if last_row:
- _, _, event_id = last_row
- event = await self.get_event(event_id, get_prev_content=True)
- return event
-
+ return last_row[2]
return None
async def get_current_room_stream_token_for_room_id(
- self, room_id: Optional[str] = None
+ self, room_id: str
) -> RoomStreamToken:
- """Returns the current position of the rooms stream.
-
- By default, it returns a live token with the current global stream
- token. Specifying a `room_id` causes it to return a historic token with
- the room specific topological token.
- """
+ """Returns the current position of the rooms stream (historic token)."""
stream_ordering = self.get_room_max_stream_ordering()
- if room_id is None:
- return RoomStreamToken(None, stream_ordering)
- else:
- topo = await self.db_pool.runInteraction(
- "_get_max_topological_txn", self._get_max_topological_txn, room_id
- )
- return RoomStreamToken(topo, stream_ordering)
+ topo = await self.db_pool.runInteraction(
+ "_get_max_topological_txn", self._get_max_topological_txn, room_id
+ )
+ return RoomStreamToken(topo, stream_ordering)
def get_stream_id_for_event_txn(
self,
txn: LoggingTransaction,
event_id: str,
- allow_none=False,
- ) -> int:
- return self.db_pool.simple_select_one_onecol_txn(
+ allow_none: bool = False,
+ ) -> Optional[int]:
+ # Type ignore: we pass keyvalues a Dict[str, str]; the function wants
+ # Dict[str, Any]. I think mypy is unhappy because Dict is invariant?
+ return self.db_pool.simple_select_one_onecol_txn( # type: ignore[call-overload]
txn=txn,
table="events",
keyvalues={"event_id": event_id},
@@ -873,7 +865,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
rows = txn.fetchall()
- return rows[0][0] if rows else 0
+ # An aggregate function like MAX() will always return one row per group
+        # so we can safely rely on the lookup here. For example, when we
+        # look up a `room_id` which does not exist, `rows` will look like
+ # `[(None,)]`
+ return rows[0][0] if rows[0][0] is not None else 0
@staticmethod
def _set_before_and_after(
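
The new comment in the hunk above is about aggregate semantics: `MAX()` always yields exactly one row, containing NULL when nothing matches, so the code must test the value rather than the presence of a row. A tiny sqlite3 demonstration of that behaviour (purely illustrative; the table here is a stand-in):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (room_id TEXT, topological_ordering INTEGER)")

row = conn.execute(
    "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
    ("!unknown:example.org",),
).fetchone()

# Exactly one row comes back even though the room has no events...
assert row == (None,)
# ...so the caller must check the value, not the row count.
topo = row[0] if row[0] is not None else 0
print(topo)  # 0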
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 028db69a..ddb25b5c 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -441,7 +441,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
(EventTypes.RoomHistoryVisibility, ""),
)
- current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined]
+ # Getting the partial state is fine, as we're not looking at membership
+ # events.
+ current_state_ids = await self.get_partial_filtered_current_state_ids( # type: ignore[attr-defined]
room_id, StateFilter.from_types(types_to_filter)
)
@@ -729,49 +731,6 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
- async def get_mutual_rooms_for_users(
- self, user_id: str, other_user_id: str
- ) -> Set[str]:
- """
- Returns the rooms that a local user shares with another local or remote user.
-
- Args:
- user_id: The MXID of a local user
- other_user_id: The MXID of the other user
-
- Returns:
- A set of room ID's that the users share.
- """
-
- def _get_mutual_rooms_for_users_txn(
- txn: LoggingTransaction,
- ) -> List[Dict[str, str]]:
- txn.execute(
- """
- SELECT p1.room_id
- FROM users_in_public_rooms as p1
- INNER JOIN users_in_public_rooms as p2
- ON p1.room_id = p2.room_id
- AND p1.user_id = ?
- AND p2.user_id = ?
- UNION
- SELECT room_id
- FROM users_who_share_private_rooms
- WHERE
- user_id = ?
- AND other_user_id = ?
- """,
- (user_id, other_user_id, user_id, other_user_id),
- )
- rows = self.db_pool.cursor_to_dict(txn)
- return rows
-
- rows = await self.db_pool.runInteraction(
- "get_mutual_rooms_for_users", _get_mutual_rooms_for_users_txn
- )
-
- return {row["room_id"] for row in rows}
-
async def get_user_directory_stream_pos(self) -> Optional[int]:
"""
Get the stream ID of the user directory stream.
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index 5de70f31..fa9eadac 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -195,6 +195,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
+ STATE_GROUP_EDGES_UNIQUE_INDEX_UPDATE_NAME = "state_group_edges_unique_idx"
def __init__(
self,
@@ -217,6 +218,21 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
columns=["room_id"],
)
+ # `state_group_edges` can cause severe performance issues if duplicate
+ # rows are introduced, which can accidentally be done by well-meaning
+ # server admins when trying to restore a database dump, etc.
+ # See https://github.com/matrix-org/synapse/issues/11779.
+ # Introduce a unique index to guard against that.
+ self.db_pool.updates.register_background_index_update(
+ self.STATE_GROUP_EDGES_UNIQUE_INDEX_UPDATE_NAME,
+ index_name="state_group_edges_unique_idx",
+ table="state_group_edges",
+ columns=["state_group", "prev_state_group"],
+ unique=True,
+ # The old index was on (state_group) and was not unique.
+ replaces_index="state_group_edges_idx",
+ )
+
async def _background_deduplicate_state(
self, progress: dict, batch_size: int
) -> int:
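
To illustrate what the new unique index buys (sqlite3 is used here only for the demonstration; in production the schema runs on PostgreSQL and the index is built through the background-update machinery shown above), re-inserting an existing `state_group_edges` row now fails loudly instead of silently duplicating the edge:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE state_group_edges (state_group BIGINT, prev_state_group BIGINT)"
)
conn.execute(
    "CREATE UNIQUE INDEX state_group_edges_unique_idx "
    "ON state_group_edges (state_group, prev_state_group)"
)

conn.execute("INSERT INTO state_group_edges VALUES (2, 1)")
try:
    # Replaying the same row (e.g. re-importing a database dump) is rejected.
    conn.execute("INSERT INTO state_group_edges VALUES (2, 1)")
except sqlite3.IntegrityError as exc:
    print("duplicate edge rejected:", exc)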
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 7614d76a..609a2b88 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -189,7 +189,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
group: int,
state_filter: StateFilter,
) -> Tuple[MutableStateMap[str], bool]:
- """Checks if group is in cache. See `_get_state_for_groups`
+ """Checks if group is in cache. See `get_state_for_groups`
Args:
cache: the state group cache to use