summaryrefslogtreecommitdiff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
authorAndrej Shadura <andrewsh@debian.org>2022-01-25 17:19:30 +0100
committerAndrej Shadura <andrewsh@debian.org>2022-01-25 17:19:30 +0100
commit0027c02b907486b437772b1cdecbea14d18597d9 (patch)
tree2d229de4d40a5dcd53d1981b34953a4862203e04 /synapse/storage/databases/main
parentd3d4fbf0a1d394c3e39184fbe0348d6b1d8c7219 (diff)
New upstream version 1.51.0
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/account_data.py6
-rw-r--r--synapse/storage/databases/main/deviceinbox.py30
-rw-r--r--synapse/storage/databases/main/devices.py63
-rw-r--r--synapse/storage/databases/main/directory.py6
-rw-r--r--synapse/storage/databases/main/e2e_room_keys.py34
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py48
-rw-r--r--synapse/storage/databases/main/event_push_actions.py21
-rw-r--r--synapse/storage/databases/main/events.py180
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py42
-rw-r--r--synapse/storage/databases/main/presence.py33
-rw-r--r--synapse/storage/databases/main/pusher.py8
-rw-r--r--synapse/storage/databases/main/registration.py18
-rw-r--r--synapse/storage/databases/main/relations.py178
-rw-r--r--synapse/storage/databases/main/room.py18
-rw-r--r--synapse/storage/databases/main/roommember.py6
-rw-r--r--synapse/storage/databases/main/session.py1
-rw-r--r--synapse/storage/databases/main/transactions.py11
-rw-r--r--synapse/storage/databases/main/ui_auth.py12
-rw-r--r--synapse/storage/databases/main/user_directory.py12
19 files changed, 489 insertions, 238 deletions
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 32a553fd..ef475e18 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -450,7 +450,7 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
async def add_account_data_for_user(
self, user_id: str, account_data_type: str, content: JsonDict
) -> int:
- """Add some account_data to a room for a user.
+ """Add some global account_data for a user.
Args:
user_id: The user to add a tag for.
@@ -536,9 +536,9 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
self.db_pool.simple_insert_many_txn(
txn,
table="ignored_users",
+ keys=("ignorer_user_id", "ignored_user_id"),
values=[
- {"ignorer_user_id": user_id, "ignored_user_id": u}
- for u in currently_ignored_users - previously_ignored_users
+ (user_id, u) for u in currently_ignored_users - previously_ignored_users
],
)
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 3682cb6a..4eca9718 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -432,14 +432,21 @@ class DeviceInboxWorkerStore(SQLBaseStore):
self.db_pool.simple_insert_many_txn(
txn,
table="device_federation_outbox",
+ keys=(
+ "destination",
+ "stream_id",
+ "queued_ts",
+ "messages_json",
+ "instance_name",
+ ),
values=[
- {
- "destination": destination,
- "stream_id": stream_id,
- "queued_ts": now_ms,
- "messages_json": json_encoder.encode(edu),
- "instance_name": self._instance_name,
- }
+ (
+ destination,
+ stream_id,
+ now_ms,
+ json_encoder.encode(edu),
+ self._instance_name,
+ )
for destination, edu in remote_messages_by_destination.items()
],
)
@@ -571,14 +578,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
self.db_pool.simple_insert_many_txn(
txn,
table="device_inbox",
+ keys=("user_id", "device_id", "stream_id", "message_json", "instance_name"),
values=[
- {
- "user_id": user_id,
- "device_id": device_id,
- "stream_id": stream_id,
- "message_json": message_json,
- "instance_name": self._instance_name,
- }
+ (user_id, device_id, stream_id, message_json, self._instance_name)
for user_id, messages_by_device in local_by_user_then_device.items()
for device_id, message_json in messages_by_device.items()
],
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index bc7e8760..b2a5cd9a 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -53,6 +53,7 @@ if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
+issue_8631_logger = logging.getLogger("synapse.8631_debug")
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
"drop_device_list_streams_non_unique_indexes"
@@ -229,6 +230,12 @@ class DeviceWorkerStore(SQLBaseStore):
if not updates:
return now_stream_id, []
+ if issue_8631_logger.isEnabledFor(logging.DEBUG):
+ data = {(user, device): stream_id for user, device, stream_id, _ in updates}
+ issue_8631_logger.debug(
+ "device updates need to be sent to %s: %s", destination, data
+ )
+
# get the cross-signing keys of the users in the list, so that we can
# determine which of the device changes were cross-signing keys
users = {r[0] for r in updates}
@@ -365,6 +372,17 @@ class DeviceWorkerStore(SQLBaseStore):
# and remove the length budgeting above.
results.append(("org.matrix.signing_key_update", result))
+ if issue_8631_logger.isEnabledFor(logging.DEBUG):
+ for (user_id, edu) in results:
+ issue_8631_logger.debug(
+ "device update to %s for %s from %s to %s: %s",
+ destination,
+ user_id,
+ from_stream_id,
+ last_processed_stream_id,
+ edu,
+ )
+
return last_processed_stream_id, results
def _get_device_updates_by_remote_txn(
@@ -781,7 +799,7 @@ class DeviceWorkerStore(SQLBaseStore):
@cached(max_entries=10000)
async def get_device_list_last_stream_id_for_remote(
self, user_id: str
- ) -> Optional[Any]:
+ ) -> Optional[str]:
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
@@ -797,7 +815,9 @@ class DeviceWorkerStore(SQLBaseStore):
cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids",
)
- async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]):
+ async def get_device_list_last_stream_id_for_remotes(
+ self, user_ids: Iterable[str]
+ ) -> Dict[str, Optional[str]]:
rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
@@ -1384,6 +1404,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
content: JsonDict,
stream_id: str,
) -> None:
+ """Delete, update or insert a cache entry for this (user, device) pair."""
if content.get("deleted"):
self.db_pool.simple_delete_txn(
txn,
@@ -1443,6 +1464,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def _update_remote_device_list_cache_txn(
self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
) -> None:
+ """Replace the list of cached devices for this user with the given list."""
self.db_pool.simple_delete_txn(
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
)
@@ -1450,12 +1472,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_remote_cache",
+ keys=("user_id", "device_id", "content"),
values=[
- {
- "user_id": user_id,
- "device_id": content["device_id"],
- "content": json_encoder.encode(content),
- }
+ (user_id, content["device_id"], json_encoder.encode(content))
for content in devices
],
)
@@ -1543,8 +1562,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_stream",
+ keys=("stream_id", "user_id", "device_id"),
values=[
- {"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
+ (stream_id, user_id, device_id)
for stream_id, device_id in zip(stream_ids, device_ids)
],
)
@@ -1571,18 +1591,27 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
+ keys=(
+ "destination",
+ "stream_id",
+ "user_id",
+ "device_id",
+ "sent",
+ "ts",
+ "opentracing_context",
+ ),
values=[
- {
- "destination": destination,
- "stream_id": next(next_stream_id),
- "user_id": user_id,
- "device_id": device_id,
- "sent": False,
- "ts": now,
- "opentracing_context": json_encoder.encode(context)
+ (
+ destination,
+ next(next_stream_id),
+ user_id,
+ device_id,
+ False,
+ now,
+ json_encoder.encode(context)
if whitelisted_homeserver(destination)
else "{}",
- }
+ )
for destination in hosts
for device_id in device_ids
],
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index f76c6121..5903fdaf 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -112,10 +112,8 @@ class DirectoryWorkerStore(CacheInvalidationWorkerStore):
self.db_pool.simple_insert_many_txn(
txn,
table="room_alias_servers",
- values=[
- {"room_alias": room_alias.to_string(), "server": server}
- for server in servers
- ],
+ keys=("room_alias", "server"),
+ values=[(room_alias.to_string(), server) for server in servers],
)
self._invalidate_cache_and_stream(
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index 0cb48b9d..b789a588 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -110,16 +110,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
values = []
for (room_id, session_id, room_key) in room_keys:
values.append(
- {
- "user_id": user_id,
- "version": version_int,
- "room_id": room_id,
- "session_id": session_id,
- "first_message_index": room_key["first_message_index"],
- "forwarded_count": room_key["forwarded_count"],
- "is_verified": room_key["is_verified"],
- "session_data": json_encoder.encode(room_key["session_data"]),
- }
+ (
+ user_id,
+ version_int,
+ room_id,
+ session_id,
+ room_key["first_message_index"],
+ room_key["forwarded_count"],
+ room_key["is_verified"],
+ json_encoder.encode(room_key["session_data"]),
+ )
)
log_kv(
{
@@ -131,7 +131,19 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
await self.db_pool.simple_insert_many(
- table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
+ table="e2e_room_keys",
+ keys=(
+ "user_id",
+ "version",
+ "room_id",
+ "session_id",
+ "first_message_index",
+ "forwarded_count",
+ "is_verified",
+ "session_data",
+ ),
+ values=values,
+ desc="add_e2e_room_keys",
)
@trace
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 57b5ffba..1f8447b5 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -50,16 +50,16 @@ if TYPE_CHECKING:
from synapse.server import HomeServer
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class DeviceKeyLookupResult:
"""The type returned by get_e2e_device_keys_and_signatures"""
- display_name = attr.ib(type=Optional[str])
+ display_name: Optional[str]
# the key data from e2e_device_keys_json. Typically includes fields like
# "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
# key) and "signatures" (a map from (user id) to (key id/device_id) to signature.)
- keys = attr.ib(type=Optional[JsonDict])
+ keys: Optional[JsonDict]
class EndToEndKeyBackgroundStore(SQLBaseStore):
@@ -387,15 +387,16 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
self.db_pool.simple_insert_many_txn(
txn,
table="e2e_one_time_keys_json",
+ keys=(
+ "user_id",
+ "device_id",
+ "algorithm",
+ "key_id",
+ "ts_added_ms",
+ "key_json",
+ ),
values=[
- {
- "user_id": user_id,
- "device_id": device_id,
- "algorithm": algorithm,
- "key_id": key_id,
- "ts_added_ms": time_now,
- "key_json": json_bytes,
- }
+ (user_id, device_id, algorithm, key_id, time_now, json_bytes)
for algorithm, key_id, json_bytes in new_keys
],
)
@@ -1186,15 +1187,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"""
await self.db_pool.simple_insert_many(
"e2e_cross_signing_signatures",
- [
- {
- "user_id": user_id,
- "key_id": item.signing_key_id,
- "target_user_id": item.target_user_id,
- "target_device_id": item.target_device_id,
- "signature": item.signature,
- }
+ keys=(
+ "user_id",
+ "key_id",
+ "target_user_id",
+ "target_device_id",
+ "signature",
+ ),
+ values=[
+ (
+ user_id,
+ item.signing_key_id,
+ item.target_user_id,
+ item.target_device_id,
+ item.signature,
+ )
for item in signatures
],
- "add_e2e_signing_key",
+ desc="add_e2e_signing_key",
)
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index a98e6b25..b7c4c622 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -875,14 +875,21 @@ class EventPushActionsWorkerStore(SQLBaseStore):
self.db_pool.simple_insert_many_txn(
txn,
table="event_push_summary",
+ keys=(
+ "user_id",
+ "room_id",
+ "notif_count",
+ "unread_count",
+ "stream_ordering",
+ ),
values=[
- {
- "user_id": user_id,
- "room_id": room_id,
- "notif_count": summary.notif_count,
- "unread_count": summary.unread_count,
- "stream_ordering": summary.stream_ordering,
- }
+ (
+ user_id,
+ room_id,
+ summary.notif_count,
+ summary.unread_count,
+ summary.stream_ordering,
+ )
for ((user_id, room_id), summary) in summaries.items()
if summary.old_user_id is None
],
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index dd255aef..1ae1ebe1 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -39,7 +39,6 @@ from synapse.api.room_versions import RoomVersions
from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
-from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -69,7 +68,7 @@ event_counter = Counter(
)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class DeltaState:
"""Deltas to use to update the `current_state_events` table.
@@ -80,9 +79,9 @@ class DeltaState:
should e.g. be removed from `current_state_events` table.
"""
- to_delete = attr.ib(type=List[Tuple[str, str]])
- to_insert = attr.ib(type=StateMap[str])
- no_longer_in_room = attr.ib(type=bool, default=False)
+ to_delete: List[Tuple[str, str]]
+ to_insert: StateMap[str]
+ no_longer_in_room: bool = False
class PersistEventsStore:
@@ -328,7 +327,6 @@ class PersistEventsStore:
return existing_prevs
- @log_function
def _persist_events_txn(
self,
txn: LoggingTransaction,
@@ -442,12 +440,9 @@ class PersistEventsStore:
self.db_pool.simple_insert_many_txn(
txn,
table="event_auth",
+ keys=("event_id", "room_id", "auth_id"),
values=[
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "auth_id": auth_id,
- }
+ (event.event_id, event.room_id, auth_id)
for event in events
for auth_id in event.auth_event_ids()
if event.is_state()
@@ -675,8 +670,9 @@ class PersistEventsStore:
db_pool.simple_insert_many_txn(
txn,
table="event_auth_chains",
+ keys=("event_id", "chain_id", "sequence_number"),
values=[
- {"event_id": event_id, "chain_id": c_id, "sequence_number": seq}
+ (event_id, c_id, seq)
for event_id, (c_id, seq) in new_chain_tuples.items()
],
)
@@ -782,13 +778,14 @@ class PersistEventsStore:
db_pool.simple_insert_many_txn(
txn,
table="event_auth_chain_links",
+ keys=(
+ "origin_chain_id",
+ "origin_sequence_number",
+ "target_chain_id",
+ "target_sequence_number",
+ ),
values=[
- {
- "origin_chain_id": source_id,
- "origin_sequence_number": source_seq,
- "target_chain_id": target_id,
- "target_sequence_number": target_seq,
- }
+ (source_id, source_seq, target_id, target_seq)
for (
source_id,
source_seq,
@@ -943,20 +940,28 @@ class PersistEventsStore:
txn_id = getattr(event.internal_metadata, "txn_id", None)
if token_id and txn_id:
to_insert.append(
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "user_id": event.sender,
- "token_id": token_id,
- "txn_id": txn_id,
- "inserted_ts": self._clock.time_msec(),
- }
+ (
+ event.event_id,
+ event.room_id,
+ event.sender,
+ token_id,
+ txn_id,
+ self._clock.time_msec(),
+ )
)
if to_insert:
self.db_pool.simple_insert_many_txn(
txn,
table="event_txn_id",
+ keys=(
+ "event_id",
+ "room_id",
+ "user_id",
+ "token_id",
+ "txn_id",
+ "inserted_ts",
+ ),
values=to_insert,
)
@@ -1161,8 +1166,9 @@ class PersistEventsStore:
self.db_pool.simple_insert_many_txn(
txn,
table="event_forward_extremities",
+ keys=("event_id", "room_id"),
values=[
- {"event_id": ev_id, "room_id": room_id}
+ (ev_id, room_id)
for room_id, new_extrem in new_forward_extremities.items()
for ev_id in new_extrem
],
@@ -1174,12 +1180,9 @@ class PersistEventsStore:
self.db_pool.simple_insert_many_txn(
txn,
table="stream_ordering_to_exterm",
+ keys=("room_id", "event_id", "stream_ordering"),
values=[
- {
- "room_id": room_id,
- "event_id": event_id,
- "stream_ordering": max_stream_order,
- }
+ (room_id, event_id, max_stream_order)
for room_id, new_extrem in new_forward_extremities.items()
for event_id in new_extrem
],
@@ -1251,20 +1254,22 @@ class PersistEventsStore:
for room_id, depth in depth_updates.items():
self._update_min_depth_for_room_txn(txn, room_id, depth)
- def _update_outliers_txn(self, txn, events_and_contexts):
+ def _update_outliers_txn(
+ self,
+ txn: LoggingTransaction,
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
+ ) -> List[Tuple[EventBase, EventContext]]:
"""Update any outliers with new event info.
- This turns outliers into ex-outliers (unless the new event was
- rejected).
+ This turns outliers into ex-outliers (unless the new event was rejected), and
+ also removes any other events we have already seen from the list.
Args:
- txn (twisted.enterprise.adbapi.Connection): db connection
- events_and_contexts (list[(EventBase, EventContext)]): events
- we are persisting
+ txn: db connection
+ events_and_contexts: events we are persisting
Returns:
- list[(EventBase, EventContext)] new list, without events which
- are already in the events table.
+ new list, without events which are already in the events table.
"""
txn.execute(
"SELECT event_id, outlier FROM events WHERE event_id in (%s)"
@@ -1272,7 +1277,9 @@ class PersistEventsStore:
[event.event_id for event, _ in events_and_contexts],
)
- have_persisted = {event_id: outlier for event_id, outlier in txn}
+ have_persisted: Dict[str, bool] = {
+ event_id: outlier for event_id, outlier in txn
+ }
to_remove = set()
for event, context in events_and_contexts:
@@ -1282,15 +1289,22 @@ class PersistEventsStore:
to_remove.add(event)
if context.rejected:
- # If the event is rejected then we don't care if the event
- # was an outlier or not.
+ # If the incoming event is rejected then we don't care if the event
+ # was an outlier or not - what we have is at least as good.
continue
outlier_persisted = have_persisted[event.event_id]
if not event.internal_metadata.is_outlier() and outlier_persisted:
# We received a copy of an event that we had already stored as
- # an outlier in the database. We now have some state at that
+ # an outlier in the database. We now have some state at that event
# so we need to update the state_groups table with that state.
+ #
+ # Note that we do not update the stream_ordering of the event in this
+ # scenario. XXX: does this cause bugs? It will mean we won't send such
+ # events down /sync. In general they will be historical events, so that
+ # doesn't matter too much, but that is not always the case.
+
+ logger.info("Updating state for ex-outlier event %s", event.event_id)
# insert into event_to_state_groups.
try:
@@ -1342,7 +1356,7 @@ class PersistEventsStore:
d.pop("redacted_because", None)
return d
- self.db_pool.simple_insert_many_values_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="event_json",
keys=("event_id", "room_id", "internal_metadata", "json", "format_version"),
@@ -1358,7 +1372,7 @@ class PersistEventsStore:
),
)
- self.db_pool.simple_insert_many_values_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="events",
keys=(
@@ -1412,7 +1426,7 @@ class PersistEventsStore:
)
txn.execute(sql + clause, [False] + args)
- self.db_pool.simple_insert_many_values_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="state_events",
keys=("event_id", "room_id", "type", "state_key"),
@@ -1622,14 +1636,9 @@ class PersistEventsStore:
return self.db_pool.simple_insert_many_txn(
txn=txn,
table="event_labels",
+ keys=("event_id", "label", "room_id", "topological_ordering"),
values=[
- {
- "event_id": event_id,
- "label": label,
- "room_id": room_id,
- "topological_ordering": topological_ordering,
- }
- for label in labels
+ (event_id, label, room_id, topological_ordering) for label in labels
],
)
@@ -1657,16 +1666,13 @@ class PersistEventsStore:
vals = []
for event in events:
ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
- vals.append(
- {
- "event_id": event.event_id,
- "algorithm": ref_alg,
- "hash": memoryview(ref_hash_bytes),
- }
- )
+ vals.append((event.event_id, ref_alg, memoryview(ref_hash_bytes)))
self.db_pool.simple_insert_many_txn(
- txn, table="event_reference_hashes", values=vals
+ txn,
+ table="event_reference_hashes",
+ keys=("event_id", "algorithm", "hash"),
+ values=vals,
)
def _store_room_members_txn(
@@ -1689,18 +1695,25 @@ class PersistEventsStore:
self.db_pool.simple_insert_many_txn(
txn,
table="room_memberships",
+ keys=(
+ "event_id",
+ "user_id",
+ "sender",
+ "room_id",
+ "membership",
+ "display_name",
+ "avatar_url",
+ ),
values=[
- {
- "event_id": event.event_id,
- "user_id": event.state_key,
- "sender": event.user_id,
- "room_id": event.room_id,
- "membership": event.membership,
- "display_name": non_null_str_or_none(
- event.content.get("displayname")
- ),
- "avatar_url": non_null_str_or_none(event.content.get("avatar_url")),
- }
+ (
+ event.event_id,
+ event.state_key,
+ event.user_id,
+ event.room_id,
+ event.membership,
+ non_null_str_or_none(event.content.get("displayname")),
+ non_null_str_or_none(event.content.get("avatar_url")),
+ )
for event in events
],
)
@@ -1791,6 +1804,13 @@ class PersistEventsStore:
txn.call_after(
self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
)
+ # It should be safe to only invalidate the cache if the user has not
+ # previously participated in the thread, but that's difficult (and
+ # potentially error-prone) so it is always invalidated.
+ txn.call_after(
+ self.store.get_thread_participated.invalidate,
+ (parent_id, event.room_id, event.sender),
+ )
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
"""Handles keeping track of insertion events and edges/connections.
@@ -2163,13 +2183,9 @@ class PersistEventsStore:
self.db_pool.simple_insert_many_txn(
txn,
table="event_edges",
+ keys=("event_id", "prev_event_id", "room_id", "is_state"),
values=[
- {
- "event_id": ev.event_id,
- "prev_event_id": e_id,
- "room_id": ev.room_id,
- "is_state": False,
- }
+ (ev.event_id, e_id, ev.room_id, False)
for ev in events
for e_id in ev.prev_event_ids()
],
@@ -2226,17 +2242,17 @@ class PersistEventsStore:
)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class _LinkMap:
"""A helper type for tracking links between chains."""
# Stores the set of links as nested maps: source chain ID -> target chain ID
# -> source sequence number -> target sequence number.
- maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict)
+ maps: Dict[int, Dict[int, Dict[int, int]]] = attr.Factory(dict)
# Stores the links that have been added (with new set to true), as tuples of
# `(source chain ID, source sequence no, target chain ID, target sequence no.)`
- additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set)
+ additions: Set[Tuple[int, int, int, int]] = attr.Factory(set)
def add_link(
self,
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index a68f14ba..d5f00596 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -65,22 +65,22 @@ class _BackgroundUpdates:
REPLACE_STREAM_ORDERING_COLUMN = "replace_stream_ordering_column"
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class _CalculateChainCover:
"""Return value for _calculate_chain_cover_txn."""
# The last room_id/depth/stream processed.
- room_id = attr.ib(type=str)
- depth = attr.ib(type=int)
- stream = attr.ib(type=int)
+ room_id: str
+ depth: int
+ stream: int
# Number of rows processed
- processed_count = attr.ib(type=int)
+ processed_count: int
# Map from room_id to last depth/stream processed for each room that we have
# processed all events for (i.e. the rooms we can flip the
# `has_auth_chain_index` for)
- finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]])
+ finished_room_map: Dict[str, Tuple[int, int]]
class EventsBackgroundUpdatesStore(SQLBaseStore):
@@ -684,13 +684,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self.db_pool.simple_insert_many_txn(
txn=txn,
table="event_labels",
+ keys=("event_id", "label", "room_id", "topological_ordering"),
values=[
- {
- "event_id": event_id,
- "label": label,
- "room_id": event_json["room_id"],
- "topological_ordering": event_json["depth"],
- }
+ (
+ event_id,
+ label,
+ event_json["room_id"],
+ event_json["depth"],
+ )
for label in event_json["content"].get(
EventContentFields.LABELS, []
)
@@ -803,29 +804,19 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
if not has_state:
state_events.append(
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "type": event.type,
- "state_key": event.state_key,
- }
+ (event.event_id, event.room_id, event.type, event.state_key)
)
if not has_event_auth:
# Old, dodgy, events may have duplicate auth events, which we
# need to deduplicate as we have a unique constraint.
for auth_id in set(event.auth_event_ids()):
- auth_events.append(
- {
- "room_id": event.room_id,
- "event_id": event.event_id,
- "auth_id": auth_id,
- }
- )
+ auth_events.append((event.event_id, event.room_id, auth_id))
if state_events:
await self.db_pool.simple_insert_many(
table="state_events",
+ keys=("event_id", "room_id", "type", "state_key"),
values=state_events,
desc="_rejected_events_metadata_state_events",
)
@@ -833,6 +824,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
if auth_events:
await self.db_pool.simple_insert_many(
table="event_auth",
+ keys=("event_id", "room_id", "auth_id"),
values=auth_events,
desc="_rejected_events_metadata_event_auth",
)
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index cbf9ec38..4f05811a 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -129,18 +129,29 @@ class PresenceStore(PresenceBackgroundUpdateStore):
self.db_pool.simple_insert_many_txn(
txn,
table="presence_stream",
+ keys=(
+ "stream_id",
+ "user_id",
+ "state",
+ "last_active_ts",
+ "last_federation_update_ts",
+ "last_user_sync_ts",
+ "status_msg",
+ "currently_active",
+ "instance_name",
+ ),
values=[
- {
- "stream_id": stream_id,
- "user_id": state.user_id,
- "state": state.state,
- "last_active_ts": state.last_active_ts,
- "last_federation_update_ts": state.last_federation_update_ts,
- "last_user_sync_ts": state.last_user_sync_ts,
- "status_msg": state.status_msg,
- "currently_active": state.currently_active,
- "instance_name": self._instance_name,
- }
+ (
+ stream_id,
+ state.user_id,
+ state.state,
+ state.last_active_ts,
+ state.last_federation_update_ts,
+ state.last_user_sync_ts,
+ state.status_msg,
+ state.currently_active,
+ self._instance_name,
+ )
for stream_id, state in zip(stream_orderings, presence_states)
],
)
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 747b4f31..cf64cd63 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -561,13 +561,9 @@ class PusherStore(PusherWorkerStore):
self.db_pool.simple_insert_many_txn(
txn,
table="deleted_pushers",
+ keys=("stream_id", "app_id", "pushkey", "user_id"),
values=[
- {
- "stream_id": stream_id,
- "app_id": pusher.app_id,
- "pushkey": pusher.pushkey,
- "user_id": user_id,
- }
+ (stream_id, pusher.app_id, pusher.pushkey, user_id)
for stream_id, pusher in zip(stream_ids, pushers)
],
)
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 4175c82a..aac94fa4 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -51,7 +51,7 @@ class ExternalIDReuseException(Exception):
pass
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, auto_attribs=True)
class TokenLookupResult:
"""Result of looking up an access token.
@@ -69,14 +69,14 @@ class TokenLookupResult:
cached.
"""
- user_id = attr.ib(type=str)
- is_guest = attr.ib(type=bool, default=False)
- shadow_banned = attr.ib(type=bool, default=False)
- token_id = attr.ib(type=Optional[int], default=None)
- device_id = attr.ib(type=Optional[str], default=None)
- valid_until_ms = attr.ib(type=Optional[int], default=None)
- token_owner = attr.ib(type=str)
- token_used = attr.ib(type=bool, default=False)
+ user_id: str
+ is_guest: bool = False
+ shadow_banned: bool = False
+ token_id: Optional[int] = None
+ device_id: Optional[str] = None
+ valid_until_ms: Optional[int] = None
+ token_owner: str = attr.ib()
+ token_used: bool = False
# Make the token owner default to the user ID, which is the common case.
@token_owner.default
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 4ff6aed2..2cb5d06c 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,14 +13,30 @@
# limitations under the License.
import logging
-from typing import List, Optional, Tuple, Union, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Union,
+ cast,
+)
import attr
+from frozendict import frozendict
-from synapse.api.constants import RelationTypes
+from synapse.api.constants import EventTypes, RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+ make_in_list_sql_clause,
+)
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.relations import (
AggregationPaginationToken,
@@ -29,10 +45,24 @@ from synapse.storage.relations import (
)
from synapse.util.caches.descriptors import cached
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class RelationsWorkerStore(SQLBaseStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
+
+ self._msc1849_enabled = hs.config.experimental.msc1849_enabled
+ self._msc3440_enabled = hs.config.experimental.msc3440_enabled
+
@cached(tree=True)
async def get_relations_for_event(
self,
@@ -354,8 +384,7 @@ class RelationsWorkerStore(SQLBaseStore):
async def get_thread_summary(
self, event_id: str, room_id: str
) -> Tuple[int, Optional[EventBase]]:
- """Get the number of threaded replies, the senders of those replies, and
- the latest reply (if any) for the given event.
+ """Get the number of threaded replies and the latest reply (if any) for the given event.
Args:
event_id: Summarize the thread related to this event ID.
@@ -368,7 +397,7 @@ class RelationsWorkerStore(SQLBaseStore):
def _get_thread_summary_txn(
txn: LoggingTransaction,
) -> Tuple[int, Optional[str]]:
- # Fetch the count of threaded events and the latest event ID.
+ # Fetch the latest event ID in the thread.
# TODO Should this only allow m.room.message events.
sql = """
SELECT event_id
@@ -389,6 +418,7 @@ class RelationsWorkerStore(SQLBaseStore):
latest_event_id = row[0]
+ # Fetch the number of threaded replies.
sql = """
SELECT COUNT(event_id)
FROM event_relations
@@ -413,6 +443,44 @@ class RelationsWorkerStore(SQLBaseStore):
return count, latest_event
+ @cached()
+ async def get_thread_participated(
+ self, event_id: str, room_id: str, user_id: str
+ ) -> bool:
+ """Get whether the requesting user participated in a thread.
+
+ This is separate from get_thread_summary since that can be cached across
+ all users while this value is specific to the requeser.
+
+ Args:
+ event_id: The thread related to this event ID.
+ room_id: The room the event belongs to.
+ user_id: The user requesting the summary.
+
+ Returns:
+ True if the requesting user participated in the thread, otherwise false.
+ """
+
+ def _get_thread_summary_txn(txn: LoggingTransaction) -> bool:
+ # Fetch whether the requester has participated or not.
+ sql = """
+ SELECT 1
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE
+ relates_to_id = ?
+ AND room_id = ?
+ AND relation_type = ?
+ AND sender = ?
+ """
+
+ txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id))
+ return bool(txn.fetchone())
+
+ return await self.db_pool.runInteraction(
+ "get_thread_summary", _get_thread_summary_txn
+ )
+
async def events_have_relations(
self,
parent_ids: List[str],
@@ -515,6 +583,104 @@ class RelationsWorkerStore(SQLBaseStore):
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
+ async def _get_bundled_aggregation_for_event(
+ self, event: EventBase, user_id: str
+ ) -> Optional[Dict[str, Any]]:
+ """Generate bundled aggregations for an event.
+
+ Note that this does not use a cache, but depends on cached methods.
+
+ Args:
+ event: The event to calculate bundled aggregations for.
+ user_id: The user requesting the bundled aggregations.
+
+ Returns:
+ The bundled aggregations for an event, if bundled aggregations are
+ enabled and the event can have bundled aggregations.
+ """
+ # State events and redacted events do not get bundled aggregations.
+ if event.is_state() or event.internal_metadata.is_redacted():
+ return None
+
+ # Do not bundle aggregations for an event which represents an edit or an
+ # annotation. It does not make sense for them to have related events.
+ relates_to = event.content.get("m.relates_to")
+ if isinstance(relates_to, (dict, frozendict)):
+ relation_type = relates_to.get("rel_type")
+ if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
+ return None
+
+ event_id = event.event_id
+ room_id = event.room_id
+
+ # The bundled aggregations to include, a mapping of relation type to a
+ # type-specific value. Some types include the direct return type here
+ # while others need more processing during serialization.
+ aggregations: Dict[str, Any] = {}
+
+ annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
+ if annotations.chunk:
+ aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
+
+ references = await self.get_relations_for_event(
+ event_id, room_id, RelationTypes.REFERENCE, direction="f"
+ )
+ if references.chunk:
+ aggregations[RelationTypes.REFERENCE] = references.to_dict()
+
+ edit = None
+ if event.type == EventTypes.Message:
+ edit = await self.get_applicable_edit(event_id, room_id)
+
+ if edit:
+ aggregations[RelationTypes.REPLACE] = edit
+
+ # If this event is the start of a thread, include a summary of the replies.
+ if self._msc3440_enabled:
+ thread_count, latest_thread_event = await self.get_thread_summary(
+ event_id, room_id
+ )
+ participated = await self.get_thread_participated(
+ event_id, room_id, user_id
+ )
+ if latest_thread_event:
+ aggregations[RelationTypes.THREAD] = {
+ "latest_event": latest_thread_event,
+ "count": thread_count,
+ "current_user_participated": participated,
+ }
+
+ # Store the bundled aggregations in the event metadata for later use.
+ return aggregations
+
+ async def get_bundled_aggregations(
+ self,
+ events: Iterable[EventBase],
+ user_id: str,
+ ) -> Dict[str, Dict[str, Any]]:
+ """Generate bundled aggregations for events.
+
+ Args:
+ events: The iterable of events to calculate bundled aggregations for.
+ user_id: The user requesting the bundled aggregations.
+
+ Returns:
+ A map of event ID to the bundled aggregation for the event. Not all
+ events may have bundled aggregations in the results.
+ """
+ # If bundled aggregations are disabled, nothing to do.
+ if not self._msc1849_enabled:
+ return {}
+
+ # TODO Parallelize.
+ results = {}
+ for event in events:
+ event_result = await self._get_bundled_aggregation_for_event(event, user_id)
+ if event_result is not None:
+ results[event.event_id] = event_result
+
+ return results
+
class RelationsStore(RelationsWorkerStore):
pass
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index c0e83785..95167116 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -551,24 +551,24 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
FROM room_stats_state state
INNER JOIN room_stats_current curr USING (room_id)
INNER JOIN rooms USING (room_id)
- %s
- ORDER BY %s %s
+ {where}
+ ORDER BY {order_by} {direction}, state.room_id {direction}
LIMIT ?
OFFSET ?
- """ % (
- where_statement,
- order_by_column,
- "ASC" if order_by_asc else "DESC",
+ """.format(
+ where=where_statement,
+ order_by=order_by_column,
+ direction="ASC" if order_by_asc else "DESC",
)
# Use a nested SELECT statement as SQL can't count(*) with an OFFSET
count_sql = """
SELECT count(*) FROM (
SELECT room_id FROM room_stats_state state
- %s
+ {where}
) AS get_room_ids
- """ % (
- where_statement,
+ """.format(
+ where=where_statement,
)
def _get_rooms_paginate_txn(
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index cda80d65..4489732f 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -1177,18 +1177,18 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
await self.db_pool.runInteraction("forget_membership", f)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class _JoinedHostsCache:
"""The cached data used by the `_get_joined_hosts_cache`."""
# Dict of host to the set of their users in the room at the state group.
- hosts_to_joined_users = attr.ib(type=Dict[str, Set[str]], factory=dict)
+ hosts_to_joined_users: Dict[str, Set[str]] = attr.Factory(dict)
# The state group `hosts_to_joined_users` is derived from. Will be an object
# if the instance is newly created or if the state is not based on a state
# group. (An object is used as a sentinel value to ensure that it never is
# equal to anything else).
- state_group = attr.ib(type=Union[object, int], factory=object)
+ state_group: Union[object, int] = attr.Factory(object)
def __len__(self):
return sum(len(v) for v in self.hosts_to_joined_users.values())
diff --git a/synapse/storage/databases/main/session.py b/synapse/storage/databases/main/session.py
index 5a971204..e8c776b9 100644
--- a/synapse/storage/databases/main/session.py
+++ b/synapse/storage/databases/main/session.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 6c299caf..4b78b4d0 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -560,3 +560,14 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
return await self.db_pool.runInteraction(
"get_destinations_paginate_txn", get_destinations_paginate_txn
)
+
+ async def is_destination_known(self, destination: str) -> bool:
+ """Check if a destination is known to the server."""
+ result = await self.db_pool.simple_select_one_onecol(
+ table="destinations",
+ keyvalues={"destination": destination},
+ retcol="1",
+ allow_none=True,
+ desc="is_destination_known",
+ )
+ return bool(result)
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index a1a1a6a1..2d339b60 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -23,19 +23,19 @@ from synapse.types import JsonDict
from synapse.util import json_encoder, stringutils
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class UIAuthSessionData:
- session_id = attr.ib(type=str)
+ session_id: str
# The dictionary from the client root level, not the 'auth' key.
- clientdict = attr.ib(type=JsonDict)
+ clientdict: JsonDict
# The URI and method the session was intiatied with. These are checked at
# each stage of the authentication to ensure that the asked for operation
# has not changed.
- uri = attr.ib(type=str)
- method = attr.ib(type=str)
+ uri: str
+ method: str
# A string description of the operation that the current authentication is
# authorising.
- description = attr.ib(type=str)
+ description: str
class UIAuthWorkerStore(SQLBaseStore):
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 0f9b8575..f7c778bd 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -105,8 +105,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
GROUP BY room_id
"""
txn.execute(sql)
- rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()]
- self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
+ rooms = list(txn.fetchall())
+ self.db_pool.simple_insert_many_txn(
+ txn, TEMP_TABLE + "_rooms", keys=("room_id", "events"), values=rooms
+ )
del rooms
sql = (
@@ -117,9 +119,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
txn.execute(sql)
txn.execute("SELECT name FROM users")
- users = [{"user_id": x[0]} for x in txn.fetchall()]
+ users = list(txn.fetchall())
- self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
+ self.db_pool.simple_insert_many_txn(
+ txn, TEMP_TABLE + "_users", keys=("user_id",), values=users
+ )
new_pos = await self.get_max_stream_id_in_current_state_deltas()
await self.db_pool.runInteraction(