path: root/synapse/storage
author     Andrej Shadura <andrewsh@debian.org>  2022-04-22 20:34:38 +0200
committer  Andrej Shadura <andrewsh@debian.org>  2022-04-22 20:34:38 +0200
commit     1b9b92888056ce7fd1f3a010ca7afd5c3963d44e (patch)
tree       716cb361eb4332eca7b147f35c87ceb29b3ac958 /synapse/storage
parent     02eb467b57bf21597094de52232c93d5f4a38b7d (diff)
New upstream version 1.57.1
Diffstat (limited to 'synapse/storage')
-rw-r--r--  synapse/storage/database.py | 167
-rw-r--r--  synapse/storage/databases/main/__init__.py | 20
-rw-r--r--  synapse/storage/databases/main/appservice.py | 107
-rw-r--r--  synapse/storage/databases/main/client_ips.py | 167
-rw-r--r--  synapse/storage/databases/main/devices.py | 315
-rw-r--r--  synapse/storage/databases/main/events.py | 10
-rw-r--r--  synapse/storage/databases/main/events_worker.py | 8
-rw-r--r--  synapse/storage/databases/main/monthly_active_users.py | 60
-rw-r--r--  synapse/storage/databases/main/receipts.py | 13
-rw-r--r--  synapse/storage/databases/main/registration.py | 157
-rw-r--r--  synapse/storage/databases/main/relations.py | 227
-rw-r--r--  synapse/storage/databases/main/roommember.py | 23
-rw-r--r--  synapse/storage/databases/main/signatures.py | 2
-rw-r--r--  synapse/storage/databases/main/state.py | 32
-rw-r--r--  synapse/storage/databases/main/stream.py | 70
-rw-r--r--  synapse/storage/databases/main/tags.py | 4
-rw-r--r--  synapse/storage/relations.py | 84
-rw-r--r--  synapse/storage/schema/__init__.py | 6
-rw-r--r--  synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql | 23
-rw-r--r--  synapse/storage/schema/main/delta/69/01as_txn_seq.py | 44
-rw-r--r--  synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql | 38
-rw-r--r--  synapse/storage/state.py | 20
-rw-r--r--  synapse/storage/types.py | 1
23 files changed, 1094 insertions, 504 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 3ef2bdd7..5eb545c8 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -241,9 +241,17 @@ class LoggingTransaction:
self.exception_callbacks = exception_callbacks
def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any):
- """Call the given callback on the main twisted thread after the
- transaction has finished. Used to invalidate the caches on the
- correct thread.
+ """Call the given callback on the main twisted thread after the transaction has
+ finished.
+
+ Mostly used to invalidate the caches on the correct thread.
+
+ Note that transactions may be retried a few times if they encounter database
+ errors such as serialization failures. Callbacks given to `call_after`
+ will accumulate across transaction attempts and will _all_ be called once a
+ transaction attempt succeeds, regardless of whether previous transaction
+ attempts failed. Otherwise, if all transaction attempts fail, all
+ `call_on_exception` callbacks will be run instead.
"""
# if self.after_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
@@ -254,6 +262,15 @@ class LoggingTransaction:
def call_on_exception(
self, callback: Callable[..., object], *args: Any, **kwargs: Any
):
+ """Call the given callback on the main twisted thread after the transaction has
+ failed.
+
+ Note that transactions may be retried a few times if they encounter database
+ errors such as serialization failures. Callbacks given to `call_on_exception`
+ will accumulate across transaction attempts and will _all_ be called once the
+ final transaction attempt fails. No `call_on_exception` callbacks will be run
+ if any transaction attempt succeeds.
+ """
# if self.exception_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
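The retry semantics spelled out in these two docstrings can be summarised by a small
standalone sketch; `run_with_retries` is an illustrative helper, not part of Synapse's API:

from typing import Callable, List

def run_with_retries(
    attempt: Callable[[List[Callable[[], None]], List[Callable[[], None]]], None],
    max_attempts: int = 3,
) -> None:
    after_callbacks: List[Callable[[], None]] = []
    exception_callbacks: List[Callable[[], None]] = []
    for i in range(max_attempts):
        try:
            # Each attempt may register further callbacks; they accumulate across retries.
            attempt(after_callbacks, exception_callbacks)
        except Exception:
            if i == max_attempts - 1:
                # Every attempt failed: run *all* accumulated exception callbacks.
                for cb in exception_callbacks:
                    cb()
                raise
            continue
        # An attempt succeeded: run *all* accumulated after-callbacks, including
        # those registered by attempts that subsequently failed.
        for cb in after_callbacks:
            cb()
        return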
@@ -1251,6 +1268,7 @@ class DatabasePool:
value_names: Collection[str],
value_values: Collection[Collection[Any]],
desc: str,
+ lock: bool = True,
) -> None:
"""
Upsert, many times.
@@ -1262,6 +1280,8 @@ class DatabasePool:
value_names: The value column names
value_values: A list of each row's value column values.
Ignored if value_names is empty.
+ lock: True to lock the table when doing the upsert. Unused if the database engine
+ supports native upserts.
"""
# We can autocommit if we are going to use native upserts
@@ -1269,7 +1289,7 @@ class DatabasePool:
self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables
)
- return await self.runInteraction(
+ await self.runInteraction(
desc,
self.simple_upsert_many_txn,
table,
@@ -1277,6 +1297,7 @@ class DatabasePool:
key_values,
value_names,
value_values,
+ lock=lock,
db_autocommit=autocommit,
)
@@ -1288,6 +1309,7 @@ class DatabasePool:
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
value_values: Iterable[Iterable[Any]],
+ lock: bool = True,
) -> None:
"""
Upsert, many times.
@@ -1299,6 +1321,8 @@ class DatabasePool:
value_names: The value column names
value_values: A list of each row's value column values.
Ignored if value_names is empty.
+ lock: True to lock the table when doing the upsert. Unused if the database engine
+ supports native upserts.
"""
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
return self.simple_upsert_many_txn_native_upsert(
@@ -1306,7 +1330,7 @@ class DatabasePool:
)
else:
return self.simple_upsert_many_txn_emulated(
- txn, table, key_names, key_values, value_names, value_values
+ txn, table, key_names, key_values, value_names, value_values, lock=lock
)
def simple_upsert_many_txn_emulated(
@@ -1317,6 +1341,7 @@ class DatabasePool:
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
value_values: Iterable[Iterable[Any]],
+ lock: bool = True,
) -> None:
"""
Upsert, many times, but without native UPSERT support or batching.
@@ -1328,17 +1353,24 @@ class DatabasePool:
value_names: The value column names
value_values: A list of each row's value column values.
Ignored if value_names is empty.
+ lock: True to lock the table when doing the upsert.
"""
# No value columns, therefore make a blank list so that the following
# zip() works correctly.
if not value_names:
value_values = [() for x in range(len(key_values))]
+ if lock:
+ # Lock the table just once, to prevent it being done once per row.
+ # Note that, according to Postgres' documentation, once obtained,
+ # the lock is held for the remainder of the current transaction.
+            self.engine.lock_table(txn, table)
+
for keyv, valv in zip(key_values, value_values):
_keys = {x: y for x, y in zip(key_names, keyv)}
_vals = {x: y for x, y in zip(value_names, valv)}
- self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
+ self.simple_upsert_txn_emulated(txn, table, _keys, _vals, lock=False)
def simple_upsert_many_txn_native_upsert(
self,
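For reference, a hypothetical call showing how the plumbed-through `lock` argument reaches the
batched upsert entry point (`store` stands for any object exposing the `DatabasePool` helpers
above; the table and column values are illustrative):

async def upsert_client_ips(store) -> None:
    # Hypothetical usage sketch of simple_upsert_many with the new `lock` argument.
    await store.db_pool.simple_upsert_many(
        table="user_ips",
        key_names=("user_id", "access_token", "ip"),
        key_values=[("@alice:example.org", "syt_token", "10.0.0.1")],
        value_names=("user_agent", "device_id", "last_seen"),
        value_values=[("Mozilla/5.0", "DEVICE1", 1650000000000)],
        desc="upsert_client_ips",
        # Only honoured by the emulated path; ignored when native upserts are available.
        lock=True,
    )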
@@ -1775,6 +1807,86 @@ class DatabasePool:
return txn.rowcount
+ async def simple_update_many(
+ self,
+ table: str,
+ key_names: Collection[str],
+ key_values: Collection[Iterable[Any]],
+ value_names: Collection[str],
+ value_values: Iterable[Iterable[Any]],
+ desc: str,
+ ) -> None:
+ """
+ Update, many times, using batching where possible.
+ If the keys don't match anything, nothing will be updated.
+
+ Args:
+ table: The table to update
+ key_names: The key column names.
+ key_values: A list of each row's key column values.
+ value_names: The names of value columns to update.
+ value_values: A list of each row's value column values.
+ """
+
+ await self.runInteraction(
+ desc,
+ self.simple_update_many_txn,
+ table,
+ key_names,
+ key_values,
+ value_names,
+ value_values,
+ )
+
+ @staticmethod
+ def simple_update_many_txn(
+ txn: LoggingTransaction,
+ table: str,
+ key_names: Collection[str],
+ key_values: Collection[Iterable[Any]],
+ value_names: Collection[str],
+ value_values: Collection[Iterable[Any]],
+ ) -> None:
+ """
+ Update, many times, using batching where possible.
+ If the keys don't match anything, nothing will be updated.
+
+ Args:
+ table: The table to update
+ key_names: The key column names.
+ key_values: A list of each row's key column values.
+ value_names: The names of value columns to update.
+ value_values: A list of each row's value column values.
+ """
+
+ if len(value_values) != len(key_values):
+ raise ValueError(
+ f"{len(key_values)} key rows and {len(value_values)} value rows: should be the same number."
+ )
+
+ # List of tuples of (value values, then key values)
+ # (This matches the order needed for the query)
+ args = [tuple(x) + tuple(y) for x, y in zip(value_values, key_values)]
+
+
+ # 'col1 = ?, col2 = ?, ...'
+ set_clause = ", ".join(f"{n} = ?" for n in value_names)
+
+ if key_names:
+ # 'WHERE col3 = ? AND col4 = ? AND col5 = ?'
+ where_clause = "WHERE " + (" AND ".join(f"{n} = ?" for n in key_names))
+ else:
+ where_clause = ""
+
+ # UPDATE mytable SET col1 = ?, col2 = ? WHERE col3 = ? AND col4 = ?
+ sql = f"""
+ UPDATE {table} SET {set_clause} {where_clause}
+ """
+
+ txn.execute_batch(sql, args)
+
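A worked example of the statement and argument rows that `simple_update_many_txn` builds, for
hypothetical `devices`-table columns:

key_names = ("user_id", "device_id")
key_values = [("@alice:example.org", "DEV1"), ("@bob:example.org", "DEV2")]
value_names = ("user_agent", "last_seen")
value_values = [("Mozilla/5.0", 1650000000000), ("curl/7.68.0", 1650000001000)]

set_clause = ", ".join(f"{n} = ?" for n in value_names)
where_clause = "WHERE " + " AND ".join(f"{n} = ?" for n in key_names)
sql = f"UPDATE devices SET {set_clause} {where_clause}"
args = [tuple(vv) + tuple(kv) for vv, kv in zip(value_values, key_values)]

assert sql == (
    "UPDATE devices SET user_agent = ?, last_seen = ? WHERE user_id = ? AND device_id = ?"
)
assert args == [
    ("Mozilla/5.0", 1650000000000, "@alice:example.org", "DEV1"),
    ("curl/7.68.0", 1650000001000, "@bob:example.org", "DEV2"),
]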
async def simple_update_one(
self,
table: str,
@@ -2013,29 +2125,40 @@ class DatabasePool:
max_value: int,
limit: int = 100000,
) -> Tuple[Dict[Any, int], int]:
- # Fetch a mapping of room_id -> max stream position for "recent" rooms.
- # It doesn't really matter how many we get, the StreamChangeCache will
- # do the right thing to ensure it respects the max size of cache.
- sql = (
- "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
- " WHERE %(stream)s > ? - %(limit)s"
- " GROUP BY %(entity)s"
- ) % {
- "table": table,
- "entity": entity_column,
- "stream": stream_column,
- "limit": limit,
- }
+ """Gets roughly the last N changes in the given stream table as a
+ map from entity to the stream ID of the most recent change.
+
+ Also returns the minimum stream ID.
+ """
+
+ # This may return many rows for the same entity, but the `limit` is only
+ # a suggestion so we don't care that much.
+ #
+ # Note: Some stream tables can have multiple rows with the same stream
+ # ID. Instead of handling this with complicated SQL, we instead simply
+ # add one to the returned minimum stream ID to ensure correctness.
+ sql = f"""
+ SELECT {entity_column}, {stream_column}
+ FROM {table}
+ ORDER BY {stream_column} DESC
+ LIMIT ?
+ """
txn = db_conn.cursor(txn_name="get_cache_dict")
- txn.execute(sql, (int(max_value),))
+ txn.execute(sql, (limit,))
- cache = {row[0]: int(row[1]) for row in txn}
+ # The rows come out in reverse stream ID order, so we want to keep the
+ # stream ID of the first row for each entity.
+ cache: Dict[Any, int] = {}
+ for row in txn:
+ cache.setdefault(row[0], int(row[1]))
txn.close()
if cache:
- min_val = min(cache.values())
+ # We add one here as we don't know if we have all rows for the
+ # minimum stream ID.
+ min_val = min(cache.values()) + 1
else:
min_val = max_value
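The de-duplication and the `min + 1` adjustment can be checked with a handful of made-up rows,
ordered by descending stream ID as the rewritten query returns them:

from typing import Any, Dict, List, Tuple

max_value = 105
rows: List[Tuple[str, int]] = [
    ("!room_a:example.org", 105),
    ("!room_b:example.org", 104),
    ("!room_a:example.org", 103),  # older change for the same entity: ignored
    ("!room_c:example.org", 101),
]

cache: Dict[Any, int] = {}
for entity, stream_id in rows:
    cache.setdefault(entity, stream_id)  # keep the first (newest) stream ID seen

if cache:
    # The LIMIT may have cut off other rows sharing the smallest stream ID, so
    # the earliest position that is definitely complete is one past it.
    min_val = min(cache.values()) + 1
else:
    min_val = max_value

assert cache == {
    "!room_a:example.org": 105,
    "!room_b:example.org": 104,
    "!room_c:example.org": 101,
}
assert min_val == 102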
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index f024761b..951031af 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -33,7 +33,7 @@ from .account_data import AccountDataStore
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
from .cache import CacheInvalidationWorkerStore
from .censor_events import CensorEventsStore
-from .client_ips import ClientIpStore
+from .client_ips import ClientIpWorkerStore
from .deviceinbox import DeviceInboxStore
from .devices import DeviceStore
from .directory import DirectoryStore
@@ -49,7 +49,7 @@ from .keys import KeyStore
from .lock import LockStore
from .media_repository import MediaRepositoryStore
from .metrics import ServerMetricsStore
-from .monthly_active_users import MonthlyActiveUsersStore
+from .monthly_active_users import MonthlyActiveUsersWorkerStore
from .openid import OpenIdStore
from .presence import PresenceStore
from .profile import ProfileStore
@@ -112,13 +112,13 @@ class DataStore(
AccountDataStore,
EventPushActionsStore,
OpenIdStore,
- ClientIpStore,
+ ClientIpWorkerStore,
DeviceStore,
DeviceInboxStore,
UserDirectoryStore,
GroupServerStore,
UserErasureStore,
- MonthlyActiveUsersStore,
+ MonthlyActiveUsersWorkerStore,
StatsStore,
RelationsStore,
CensorEventsStore,
@@ -146,6 +146,7 @@ class DataStore(
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
+ ("device_lists_changes_in_room", "stream_id"),
],
)
@@ -182,17 +183,6 @@ class DataStore(
super().__init__(database, db_conn, hs)
- device_list_max = self._device_list_id_gen.get_current_token()
- self._device_list_stream_cache = StreamChangeCache(
- "DeviceListStreamChangeCache", device_list_max
- )
- self._user_signature_stream_cache = StreamChangeCache(
- "UserSignatureStreamChangeCache", device_list_max
- )
- self._device_list_federation_stream_cache = StreamChangeCache(
- "DeviceListFederationStreamChangeCache", device_list_max
- )
-
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
db_conn,
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 06944465..fa732edc 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import re
-from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple
from synapse.appservice import (
ApplicationService,
@@ -26,10 +26,16 @@ from synapse.appservice import (
from synapse.config.appservice import load_appservices
from synapse.events import EventBase
from synapse.storage._base import db_to_json
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
-from synapse.types import JsonDict
+from synapse.storage.types import Cursor
+from synapse.storage.util.sequence import build_sequence_generator
+from synapse.types import DeviceListUpdates, JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import _CacheContext, cached
@@ -72,9 +78,25 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
)
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
+ def get_max_as_txn_id(txn: Cursor) -> int:
+ logger.warning("Falling back to slow query, you should port to postgres")
+ txn.execute(
+ "SELECT COALESCE(max(txn_id), 0) FROM application_services_txns"
+ )
+ return txn.fetchone()[0] # type: ignore
+
+ self._as_txn_seq_gen = build_sequence_generator(
+ db_conn,
+ database.engine,
+ get_max_as_txn_id,
+ "application_services_txn_id_seq",
+ table="application_services_txns",
+ id_column="txn_id",
+ )
+
super().__init__(database, db_conn, hs)
- def get_app_services(self):
+ def get_app_services(self) -> List[ApplicationService]:
return self.services_cache
def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
@@ -217,6 +239,7 @@ class ApplicationServiceTransactionWorkerStore(
to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys,
+ device_list_summary: DeviceListUpdates,
) -> AppServiceTransaction:
"""Atomically creates a new transaction for this application service
with the given list of events. Ephemeral events are NOT persisted to the
@@ -231,27 +254,14 @@ class ApplicationServiceTransactionWorkerStore(
appservice devices in the transaction.
unused_fallback_keys: Lists of unused fallback keys for relevant
appservice devices in the transaction.
+ device_list_summary: The device list summary to include in the transaction.
Returns:
A new transaction.
"""
- def _create_appservice_txn(txn):
- # work out new txn id (highest txn id for this service += 1)
- # The highest id may be the last one sent (in which case it is last_txn)
- # or it may be the highest in the txns list (which are waiting to be/are
- # being sent)
- last_txn_id = self._get_last_txn(txn, service.id)
-
- txn.execute(
- "SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
- (service.id,),
- )
- highest_txn_id = txn.fetchone()[0]
- if highest_txn_id is None:
- highest_txn_id = 0
-
- new_txn_id = max(highest_txn_id, last_txn_id) + 1
+ def _create_appservice_txn(txn: LoggingTransaction) -> AppServiceTransaction:
+ new_txn_id = self._as_txn_seq_gen.get_next_id_txn(txn)
# Insert new txn into txn table
event_ids = json_encoder.encode([e.event_id for e in events])
@@ -268,6 +278,7 @@ class ApplicationServiceTransactionWorkerStore(
to_device_messages=to_device_messages,
one_time_key_counts=one_time_key_counts,
unused_fallback_keys=unused_fallback_keys,
+ device_list_summary=device_list_summary,
)
return await self.db_pool.runInteraction(
@@ -283,25 +294,8 @@ class ApplicationServiceTransactionWorkerStore(
txn_id: The transaction ID being completed.
service: The application service which was sent this transaction.
"""
- txn_id = int(txn_id)
-
- def _complete_appservice_txn(txn):
- # Debugging query: Make sure the txn being completed is EXACTLY +1 from
- # what was there before. If it isn't, we've got problems (e.g. the AS
- # has probably missed some events), so whine loudly but still continue,
- # since it shouldn't fail completion of the transaction.
- last_txn_id = self._get_last_txn(txn, service.id)
- if (last_txn_id + 1) != txn_id:
- logger.error(
- "appservice: Completing a transaction which has an ID > 1 from "
- "the last ID sent to this AS. We've either dropped events or "
- "sent it to the AS out of order. FIX ME. last_txn=%s "
- "completing_txn=%s service_id=%s",
- last_txn_id,
- txn_id,
- service.id,
- )
+ def _complete_appservice_txn(txn: LoggingTransaction) -> None:
# Set current txn_id for AS to 'txn_id'
self.db_pool.simple_upsert_txn(
txn,
@@ -332,7 +326,9 @@ class ApplicationServiceTransactionWorkerStore(
An AppServiceTransaction or None.
"""
- def _get_oldest_unsent_txn(txn):
+ def _get_oldest_unsent_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[Dict[str, Any]]:
# Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent)
txn.execute(
@@ -359,8 +355,8 @@ class ApplicationServiceTransactionWorkerStore(
events = await self.get_events_as_list(event_ids)
- # TODO: to-device messages, one-time key counts and unused fallback keys
- # are not yet populated for catch-up transactions.
+ # TODO: to-device messages, one-time key counts, device list summaries and unused
+ # fallback keys are not yet populated for catch-up transactions.
# We likely want to populate those for reliability.
return AppServiceTransaction(
service=service,
@@ -370,21 +366,11 @@ class ApplicationServiceTransactionWorkerStore(
to_device_messages=[],
one_time_key_counts={},
unused_fallback_keys={},
+ device_list_summary=DeviceListUpdates(),
)
- def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
- txn.execute(
- "SELECT last_txn FROM application_services_state WHERE as_id=?",
- (service_id,),
- )
- last_txn_id = txn.fetchone()
- if last_txn_id is None or last_txn_id[0] is None: # no row exists
- return 0
- else:
- return int(last_txn_id[0]) # select 'last_txn' col
-
async def set_appservice_last_pos(self, pos: int) -> None:
- def set_appservice_last_pos_txn(txn):
+ def set_appservice_last_pos_txn(txn: LoggingTransaction) -> None:
txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
)
@@ -398,7 +384,9 @@ class ApplicationServiceTransactionWorkerStore(
) -> Tuple[int, List[EventBase]]:
"""Get all new events for an appservice"""
- def get_new_events_for_appservice_txn(txn):
+ def get_new_events_for_appservice_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[int, List[str]]:
sql = (
"SELECT e.stream_ordering, e.event_id"
" FROM events AS e"
@@ -430,13 +418,13 @@ class ApplicationServiceTransactionWorkerStore(
async def get_type_stream_id_for_appservice(
self, service: ApplicationService, type: str
) -> int:
- if type not in ("read_receipt", "presence", "to_device"):
+ if type not in ("read_receipt", "presence", "to_device", "device_list"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (type,)
)
- def get_type_stream_id_for_appservice_txn(txn):
+ def get_type_stream_id_for_appservice_txn(txn: LoggingTransaction) -> int:
stream_id_type = "%s_stream_id" % type
txn.execute(
# We do NOT want to escape `stream_id_type`.
@@ -446,7 +434,8 @@ class ApplicationServiceTransactionWorkerStore(
)
last_stream_id = txn.fetchone()
if last_stream_id is None or last_stream_id[0] is None: # no row exists
- return 0
+ # Stream tokens always start from 1, to avoid foot guns around `0` being falsey.
+ return 1
else:
return int(last_stream_id[0])
@@ -457,13 +446,13 @@ class ApplicationServiceTransactionWorkerStore(
async def set_appservice_stream_type_pos(
self, service: ApplicationService, stream_type: str, pos: Optional[int]
) -> None:
- if stream_type not in ("read_receipt", "presence", "to_device"):
+ if stream_type not in ("read_receipt", "presence", "to_device", "device_list"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (stream_type,)
)
- def set_appservice_stream_type_pos_txn(txn):
+ def set_appservice_stream_type_pos_txn(txn: LoggingTransaction) -> None:
stream_id_type = "%s_stream_id" % stream_type
txn.execute(
"UPDATE application_services_state SET %s = ? WHERE as_id=?"
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 8b0c614e..0df160d2 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -25,7 +25,9 @@ from synapse.storage.database import (
LoggingTransaction,
make_tuple_comparison_clause,
)
-from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
+from synapse.storage.databases.main.monthly_active_users import (
+ MonthlyActiveUsersWorkerStore,
+)
from synapse.types import JsonDict, UserID
from synapse.util.caches.lrucache import LruCache
@@ -397,7 +399,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return updated
-class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
+class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -406,11 +408,40 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
):
super().__init__(database, db_conn, hs)
+ if hs.config.redis.redis_enabled:
+ # If we're using Redis, we can shift this update process off to
+ # the background worker
+ self._update_on_this_worker = hs.config.worker.run_background_tasks
+ else:
+ # If we're NOT using Redis, this must be handled by the master
+ self._update_on_this_worker = hs.get_instance_name() == "master"
+
self.user_ips_max_age = hs.config.server.user_ips_max_age
+ # (user_id, access_token, ip,) -> last_seen
+ self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
+ cache_name="client_ip_last_seen", max_size=50000
+ )
+
if hs.config.worker.run_background_tasks and self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
+ if self._update_on_this_worker:
+ # This is the designated worker that can write to the client IP
+ # tables.
+
+ # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
+ self._batch_row_update: Dict[
+ Tuple[str, str, str], Tuple[str, Optional[str], int]
+ ] = {}
+
+ self._client_ip_looper = self._clock.looping_call(
+ self._update_client_ips_batch, 5 * 1000
+ )
+ self.hs.get_reactor().addSystemEventTrigger(
+ "before", "shutdown", self._update_client_ips_batch
+ )
+
@wrap_as_background_process("prune_old_user_ips")
async def _prune_old_user_ips(self) -> None:
"""Removes entries in user IPs older than the configured period."""
@@ -456,7 +487,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
"_prune_old_user_ips", _prune_old_user_ips_txn
)
- async def get_last_client_ip_by_device(
+ async def _get_last_client_ip_by_device_from_database(
self, user_id: str, device_id: Optional[str]
) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
"""For each device_id listed, give the user_ip it was last seen on.
@@ -487,7 +518,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
return {(d["user_id"], d["device_id"]): d for d in res}
- async def get_user_ip_and_agents(
+ async def _get_user_ip_and_agents_from_database(
self, user: UserID, since_ts: int = 0
) -> List[LastConnectionInfo]:
"""Fetch the IPs and user agents for a user since the given timestamp.
@@ -539,34 +570,6 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
for access_token, ip, user_agent, last_seen in rows
]
-
-class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
- def __init__(
- self,
- database: DatabasePool,
- db_conn: LoggingDatabaseConnection,
- hs: "HomeServer",
- ):
-
- # (user_id, access_token, ip,) -> last_seen
- self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
- cache_name="client_ip_last_seen", max_size=50000
- )
-
- super().__init__(database, db_conn, hs)
-
- # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
- self._batch_row_update: Dict[
- Tuple[str, str, str], Tuple[str, Optional[str], int]
- ] = {}
-
- self._client_ip_looper = self._clock.looping_call(
- self._update_client_ips_batch, 5 * 1000
- )
- self.hs.get_reactor().addSystemEventTrigger(
- "before", "shutdown", self._update_client_ips_batch
- )
-
async def insert_client_ip(
self,
user_id: str,
@@ -584,17 +587,27 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
last_seen = self.client_ip_last_seen.get(key)
except KeyError:
last_seen = None
- await self.populate_monthly_active_users(user_id)
+
# Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return
self.client_ip_last_seen.set(key, now)
- self._batch_row_update[key] = (user_agent, device_id, now)
+ if self._update_on_this_worker:
+ await self.populate_monthly_active_users(user_id)
+ self._batch_row_update[key] = (user_agent, device_id, now)
+ else:
+ # We are not the designated writer-worker, so stream over replication
+ self.hs.get_replication_command_handler().send_user_ip(
+ user_id, access_token, ip, user_agent, device_id, now
+ )
@wrap_as_background_process("update_client_ips")
async def _update_client_ips_batch(self) -> None:
+ assert (
+ self._update_on_this_worker
+ ), "This worker is not designated to update client IPs"
# If the DB pool has already terminated, don't try updating
if not self.db_pool.is_running():
@@ -603,51 +616,57 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
to_update = self._batch_row_update
self._batch_row_update = {}
- await self.db_pool.runInteraction(
- "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
- )
+ if to_update:
+ await self.db_pool.runInteraction(
+ "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
+ )
def _update_client_ips_batch_txn(
self,
txn: LoggingTransaction,
to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]],
) -> None:
- if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
- not self.database_engine.can_native_upsert
- ):
- self.database_engine.lock_table(txn, "user_ips")
+ assert (
+ self._update_on_this_worker
+ ), "This worker is not designated to update client IPs"
+
+ # Keys and values for the `user_ips` upsert.
+ user_ips_keys = []
+ user_ips_values = []
+
+ # Keys and values for the `devices` update.
+ devices_keys = []
+ devices_values = []
for entry in to_update.items():
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
-
- self.db_pool.simple_upsert_txn(
- txn,
- table="user_ips",
- keyvalues={"user_id": user_id, "access_token": access_token, "ip": ip},
- values={
- "user_agent": user_agent,
- "device_id": device_id,
- "last_seen": last_seen,
- },
- lock=False,
- )
+ user_ips_keys.append((user_id, access_token, ip))
+ user_ips_values.append((user_agent, device_id, last_seen))
# Technically an access token might not be associated with
# a device so we need to check.
if device_id:
- # this is always an update rather than an upsert: the row should
- # already exist, and if it doesn't, that may be because it has been
- # deleted, and we don't want to re-create it.
- self.db_pool.simple_update_txn(
- txn,
- table="devices",
- keyvalues={"user_id": user_id, "device_id": device_id},
- updatevalues={
- "user_agent": user_agent,
- "last_seen": last_seen,
- "ip": ip,
- },
- )
+ devices_keys.append((user_id, device_id))
+ devices_values.append((user_agent, last_seen, ip))
+
+ self.db_pool.simple_upsert_many_txn(
+ txn,
+ table="user_ips",
+ key_names=("user_id", "access_token", "ip"),
+ key_values=user_ips_keys,
+ value_names=("user_agent", "device_id", "last_seen"),
+ value_values=user_ips_values,
+ )
+
+ if devices_values:
+ self.db_pool.simple_update_many_txn(
+ txn,
+ table="devices",
+ key_names=("user_id", "device_id"),
+ key_values=devices_keys,
+ value_names=("user_agent", "last_seen", "ip"),
+ value_values=devices_values,
+ )
async def get_last_client_ip_by_device(
self, user_id: str, device_id: Optional[str]
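The shape of the batched write can be seen with a tiny hypothetical `to_update` mapping,
following the `(user_id, access_token, ip) -> (user_agent, device_id, last_seen)` layout noted
in the diff:

to_update = {
    ("@alice:example.org", "token_a", "10.0.0.1"): ("Mozilla/5.0", "DEV1", 1650000000000),
    ("@bob:example.org", "token_b", "10.0.0.2"): ("curl/7.68.0", None, 1650000001000),
}

user_ips_keys, user_ips_values = [], []
devices_keys, devices_values = [], []

for (user_id, access_token, ip), (user_agent, device_id, last_seen) in to_update.items():
    user_ips_keys.append((user_id, access_token, ip))
    user_ips_values.append((user_agent, device_id, last_seen))
    if device_id:
        # Tokens without a device (e.g. some appservice tokens) only touch user_ips.
        devices_keys.append((user_id, device_id))
        devices_values.append((user_agent, last_seen, ip))

# One upsert batch for user_ips, one update batch for devices.
assert len(user_ips_keys) == 2
assert devices_keys == [("@alice:example.org", "DEV1")]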
@@ -662,7 +681,12 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
A dictionary mapping a tuple of (user_id, device_id) to dicts, with
keys giving the column names from the devices table.
"""
- ret = await super().get_last_client_ip_by_device(user_id, device_id)
+ ret = await self._get_last_client_ip_by_device_from_database(user_id, device_id)
+
+ if not self._update_on_this_worker:
+ # Only the writing-worker has additional in-memory data to enhance
+ # the result
+ return ret
# Update what is retrieved from the database with data which is pending
# insertion, as if it has already been stored in the database.
@@ -707,9 +731,16 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
Only the latest user agent for each access token and IP address combination
is available.
"""
+ rows_from_db = await self._get_user_ip_and_agents_from_database(user, since_ts)
+
+ if not self._update_on_this_worker:
+ # Only the writing-worker has additional in-memory data to enhance
+ # the result
+ return rows_from_db
+
results: Dict[Tuple[str, str], LastConnectionInfo] = {
(connection["access_token"], connection["ip"]): connection
- for connection in await super().get_user_ip_and_agents(user, since_ts)
+ for connection in rows_from_db
}
# Overlay data that is pending insertion on top of the results from the
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 3b3a089b..dc8009b2 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -46,6 +46,7 @@ from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@@ -71,6 +72,55 @@ class DeviceWorkerStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
+ device_list_max = self._device_list_id_gen.get_current_token()
+ device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
+ db_conn,
+ "device_lists_stream",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=device_list_max,
+ limit=10000,
+ )
+ self._device_list_stream_cache = StreamChangeCache(
+ "DeviceListStreamChangeCache",
+ min_device_list_id,
+ prefilled_cache=device_list_prefill,
+ )
+
+ (
+ user_signature_stream_prefill,
+ user_signature_stream_list_id,
+ ) = self.db_pool.get_cache_dict(
+ db_conn,
+ "user_signature_stream",
+ entity_column="from_user_id",
+ stream_column="stream_id",
+ max_value=device_list_max,
+ limit=1000,
+ )
+ self._user_signature_stream_cache = StreamChangeCache(
+ "UserSignatureStreamChangeCache",
+ user_signature_stream_list_id,
+ prefilled_cache=user_signature_stream_prefill,
+ )
+
+ (
+ device_list_federation_prefill,
+ device_list_federation_list_id,
+ ) = self.db_pool.get_cache_dict(
+ db_conn,
+ "device_lists_outbound_pokes",
+ entity_column="destination",
+ stream_column="stream_id",
+ max_value=device_list_max,
+ limit=10000,
+ )
+ self._device_list_federation_stream_cache = StreamChangeCache(
+ "DeviceListFederationStreamChangeCache",
+ device_list_federation_list_id,
+ prefilled_cache=device_list_federation_prefill,
+ )
+
if hs.config.worker.run_background_tasks:
self._clock.looping_call(
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
@@ -681,42 +731,64 @@ class DeviceWorkerStore(SQLBaseStore):
return self._device_list_stream_cache.get_all_entities_changed(from_key)
async def get_users_whose_devices_changed(
- self, from_key: int, user_ids: Iterable[str]
+ self,
+ from_key: int,
+ user_ids: Optional[Iterable[str]] = None,
+ to_key: Optional[int] = None,
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
are in the given list of user_ids.
Args:
- from_key: The device lists stream token
- user_ids: The user IDs to query for devices.
+ from_key: The minimum device lists stream token to query device list changes for,
+ exclusive.
+ user_ids: If provided, only check if these users have changed their device lists.
+ Otherwise changes from all users are returned.
+ to_key: The maximum device lists stream token to query device list changes for,
+ inclusive.
Returns:
- The set of user_ids whose devices have changed since `from_key`
+ The set of user_ids whose devices have changed since `from_key` (exclusive)
+ until `to_key` (inclusive).
"""
-
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
- to_check = self._device_list_stream_cache.get_entities_changed(
- user_ids, from_key
- )
+ if user_ids is None:
+ # Get set of all users that have had device list changes since 'from_key'
+ user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
+ from_key
+ )
+ else:
+ # The same as above, but filter results to only those users in 'user_ids'
+ user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
+ user_ids, from_key
+ )
- if not to_check:
+ if not user_ids_to_check:
return set()
def _get_users_whose_devices_changed_txn(txn):
changes = set()
- sql = """
+ stream_id_where_clause = "stream_id > ?"
+ sql_args = [from_key]
+
+ if to_key:
+ stream_id_where_clause += " AND stream_id <= ?"
+ sql_args.append(to_key)
+
+ sql = f"""
SELECT DISTINCT user_id FROM device_lists_stream
- WHERE stream_id > ?
+ WHERE {stream_id_where_clause}
AND
"""
- for chunk in batch_iter(to_check, 100):
+ # Query device changes with a batch of users at a time
+ for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
)
- txn.execute(sql + clause, (from_key,) + tuple(args))
+ txn.execute(sql + clause, sql_args + args)
changes.update(user_id for user_id, in txn)
return changes
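For a concrete pair of tokens, the stream-bound clause built above comes out as follows (the IN
clause normally produced by `make_in_list_sql_clause` is spelled out by hand here):

from_key, to_key = 100, 120

stream_id_where_clause = "stream_id > ?"
sql_args = [from_key]
if to_key:
    stream_id_where_clause += " AND stream_id <= ?"
    sql_args.append(to_key)

sql = (
    "SELECT DISTINCT user_id FROM device_lists_stream "
    f"WHERE {stream_id_where_clause} AND "
)

# Hand-rolled stand-in for make_in_list_sql_clause, for one chunk of users.
chunk = ["@alice:example.org", "@bob:example.org"]
clause = "user_id IN (" + ", ".join("?" for _ in chunk) + ")"
args = sql_args + chunk

assert sql + clause == (
    "SELECT DISTINCT user_id FROM device_lists_stream "
    "WHERE stream_id > ? AND stream_id <= ? AND user_id IN (?, ?)"
)
assert args == [100, 120, "@alice:example.org", "@bob:example.org"]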
@@ -788,6 +860,7 @@ class DeviceWorkerStore(SQLBaseStore):
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
) AS e
WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
LIMIT ?
"""
@@ -1506,7 +1579,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
async def add_device_change_to_streams(
- self, user_id: str, device_ids: Collection[str], hosts: Collection[str]
+ self,
+ user_id: str,
+ device_ids: Collection[str],
+ hosts: Optional[Collection[str]],
+ room_ids: Collection[str],
) -> Optional[int]:
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
@@ -1515,7 +1592,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: The ID of the user whose device changed.
device_ids: The IDs of any changed devices. If empty, this function will
return None.
- hosts: The remote destinations that should be notified of the change.
+ hosts: The remote destinations that should be notified of the change. If
+ None then the set of hosts have *not* been calculated, and will be
+ calculated later by a background task.
+ room_ids: The rooms that the user is in
Returns:
The maximum stream ID of device list updates that were added to the database, or
@@ -1524,34 +1604,62 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not device_ids:
return None
- async with self._device_list_id_gen.get_next_mult(
- len(device_ids)
- ) as stream_ids:
- await self.db_pool.runInteraction(
- "add_device_change_to_stream",
- self._add_device_change_to_stream_txn,
+ context = get_active_span_text_map()
+
+ def add_device_changes_txn(
+ txn, stream_ids_for_device_change, stream_ids_for_outbound_pokes
+ ):
+ self._add_device_change_to_stream_txn(
+ txn,
user_id,
device_ids,
- stream_ids,
+ stream_ids_for_device_change,
)
- if not hosts:
- return stream_ids[-1]
+ self._add_device_outbound_room_poke_txn(
+ txn,
+ user_id,
+ device_ids,
+ room_ids,
+ stream_ids_for_device_change,
+ context,
+ hosts_have_been_calculated=hosts is not None,
+ )
- context = get_active_span_text_map()
- async with self._device_list_id_gen.get_next_mult(
- len(hosts) * len(device_ids)
- ) as stream_ids:
- await self.db_pool.runInteraction(
- "add_device_outbound_poke_to_stream",
- self._add_device_outbound_poke_to_stream_txn,
+ # If the set of hosts to send to has not been calculated yet (and so
+ # `hosts` is None) or there are no `hosts` to send to, then skip
+ # trying to persist them to the DB.
+ if not hosts:
+ return
+
+ self._add_device_outbound_poke_to_stream_txn(
+ txn,
user_id,
device_ids,
hosts,
- stream_ids,
+ stream_ids_for_outbound_pokes,
context,
)
+ # `device_lists_stream` wants a stream ID per device update.
+ num_stream_ids = len(device_ids)
+
+ if hosts:
+ # `device_lists_outbound_pokes` wants a different stream ID for
+ # each row, which is a row per host per device update.
+ num_stream_ids += len(hosts) * len(device_ids)
+
+ async with self._device_list_id_gen.get_next_mult(num_stream_ids) as stream_ids:
+ stream_ids_for_device_change = stream_ids[: len(device_ids)]
+ stream_ids_for_outbound_pokes = stream_ids[len(device_ids) :]
+
+ await self.db_pool.runInteraction(
+ "add_device_change_to_stream",
+ add_device_changes_txn,
+ stream_ids_for_device_change,
+ stream_ids_for_outbound_pokes,
+ )
+
return stream_ids[-1]
def _add_device_change_to_stream_txn(
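The stream-ID accounting described in the comments above, spelled out for a small update (two
devices and three destination hosts; the numbers are arbitrary):

device_ids = ["DEV1", "DEV2"]
hosts = ["matrix.example.com", "other.example.net", "third.example.org"]

num_stream_ids = len(device_ids)  # one ID per row in device_lists_stream
if hosts:
    # ...plus one ID per (host, device) row in device_lists_outbound_pokes.
    num_stream_ids += len(hosts) * len(device_ids)

stream_ids = list(range(1000, 1000 + num_stream_ids))  # stand-in for the ID generator
stream_ids_for_device_change = stream_ids[: len(device_ids)]
stream_ids_for_outbound_pokes = stream_ids[len(device_ids) :]

assert num_stream_ids == 8
assert stream_ids_for_device_change == [1000, 1001]
assert len(stream_ids_for_outbound_pokes) == 6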
@@ -1595,7 +1703,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str,
device_ids: Iterable[str],
hosts: Collection[str],
- stream_ids: List[str],
+ stream_ids: List[int],
context: Dict[str, str],
) -> None:
for host in hosts:
@@ -1606,8 +1714,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
now = self._clock.time_msec()
- next_stream_id = iter(stream_ids)
+ stream_id_iterator = iter(stream_ids)
+ encoded_context = json_encoder.encode(context)
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
@@ -1623,16 +1732,146 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
values=[
(
destination,
- next(next_stream_id),
+ next(stream_id_iterator),
user_id,
device_id,
False,
now,
- json_encoder.encode(context)
- if whitelisted_homeserver(destination)
- else "{}",
+ encoded_context if whitelisted_homeserver(destination) else "{}",
)
for destination in hosts
for device_id in device_ids
],
)
+
+ def _add_device_outbound_room_poke_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_ids: Iterable[str],
+ room_ids: Collection[str],
+ stream_ids: List[str],
+ context: Dict[str, str],
+ hosts_have_been_calculated: bool,
+ ) -> None:
+ """Record the user in the room has updated their device.
+
+ Args:
+ hosts_have_been_calculated: True if `device_lists_outbound_pokes`
+ has been updated already with the updates.
+ """
+
+ # We only need to convert to outbound pokes if they are our user.
+ converted_to_destinations = (
+ hosts_have_been_calculated or not self.hs.is_mine_id(user_id)
+ )
+
+ encoded_context = json_encoder.encode(context)
+
+ # The `device_lists_changes_in_room.stream_id` column matches the
+ # corresponding `stream_id` of the update in the `device_lists_stream`
+ # table, i.e. all rows persisted for the same device update will have
+ # the same `stream_id` (but different room IDs).
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="device_lists_changes_in_room",
+ keys=(
+ "user_id",
+ "device_id",
+ "room_id",
+ "stream_id",
+ "converted_to_destinations",
+ "opentracing_context",
+ ),
+ values=[
+ (
+ user_id,
+ device_id,
+ room_id,
+ stream_id,
+ converted_to_destinations,
+ encoded_context,
+ )
+ for room_id in room_ids
+ for device_id, stream_id in zip(device_ids, stream_ids)
+ ],
+ )
+
+ async def get_uncoverted_outbound_room_pokes(
+ self, limit: int = 10
+ ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
+ """Get device list changes by room that have not yet been handled and
+ written to `device_lists_outbound_pokes`.
+
+ Returns:
+ A list of user ID, device ID, room ID, stream ID and optional opentracing context.
+ """
+
+ sql = """
+ SELECT user_id, device_id, room_id, stream_id, opentracing_context
+ FROM device_lists_changes_in_room
+ WHERE NOT converted_to_destinations
+ ORDER BY stream_id
+ LIMIT ?
+ """
+
+ def get_uncoverted_outbound_room_pokes_txn(txn):
+ txn.execute(sql, (limit,))
+ return txn.fetchall()
+
+ return await self.db_pool.runInteraction(
+ "get_uncoverted_outbound_room_pokes", get_uncoverted_outbound_room_pokes_txn
+ )
+
+ async def add_device_list_outbound_pokes(
+ self,
+ user_id: str,
+ device_id: str,
+ room_id: str,
+ stream_id: int,
+ hosts: Collection[str],
+ context: Optional[Dict[str, str]],
+ ) -> None:
+ """Queue the device update to be sent to the given set of hosts,
+ calculated from the room ID.
+
+ Marks the associated row in `device_lists_changes_in_room` as handled.
+ """
+
+ def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]):
+ if hosts:
+ self._add_device_outbound_poke_to_stream_txn(
+ txn,
+ user_id=user_id,
+ device_ids=[device_id],
+ hosts=hosts,
+ stream_ids=stream_ids,
+ context=context,
+ )
+
+ self.db_pool.simple_update_txn(
+ txn,
+ table="device_lists_changes_in_room",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "stream_id": stream_id,
+ "room_id": room_id,
+ },
+ updatevalues={"converted_to_destinations": True},
+ )
+
+ if not hosts:
+ # If there are no hosts then we don't try and generate stream IDs.
+ return await self.db_pool.runInteraction(
+ "add_device_list_outbound_pokes",
+ add_device_list_outbound_pokes_txn,
+ [],
+ )
+
+ async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
+ return await self.db_pool.runInteraction(
+ "add_device_list_outbound_pokes",
+ add_device_list_outbound_pokes_txn,
+ stream_ids,
+ )
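Taken together, `get_uncoverted_outbound_room_pokes` and `add_device_list_outbound_pokes` form a
two-step pipeline. A hedged sketch of the background side follows; `get_hosts_for_room` is a
hypothetical stand-in for whatever resolves a room to its remote destinations, and the row
layout follows the docstrings above:

async def convert_outbound_room_pokes(store, get_hosts_for_room) -> None:
    # `store` is assumed to expose the two DeviceStore helpers added in this diff.
    rows = await store.get_uncoverted_outbound_room_pokes(limit=10)
    for user_id, device_id, room_id, stream_id, context in rows:
        hosts = await get_hosts_for_room(room_id)
        # Writes the per-host pokes (if any) and marks the
        # device_lists_changes_in_room row as converted.
        await store.add_device_list_outbound_pokes(
            user_id=user_id,
            device_id=device_id,
            room_id=room_id,
            stream_id=stream_id,
            hosts=hosts,
            context=context,
        )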
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index d2532431..3fcd5f5b 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -197,12 +197,10 @@ class PersistEventsStore:
)
persist_event_counter.inc(len(events_and_contexts))
- if stream < 0:
- # backfilled events have negative stream orderings, so we don't
- # want to set the event_persisted_position to that.
- synapse.metrics.event_persisted_position.set(
- events_and_contexts[-1][0].internal_metadata.stream_ordering
- )
+ if not use_negative_stream_ordering:
+ # we don't want to set the event_persisted_position to a negative
+ # stream_ordering.
+ synapse.metrics.event_persisted_position.set(stream)
for event, context in events_and_contexts:
if context.app_service:
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 59454a47..a60e3f4f 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -22,7 +22,6 @@ from typing import (
Dict,
Iterable,
List,
- NoReturn,
Optional,
Set,
Tuple,
@@ -1330,10 +1329,9 @@ class EventsWorkerStore(SQLBaseStore):
return results
@cached(max_entries=100000, tree=True)
- async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
- # this only exists for the benefit of the @cachedList descriptor on
- # _have_seen_events_dict
- raise NotImplementedError()
+ async def have_seen_event(self, room_id: str, event_id: str) -> bool:
+ res = await self._have_seen_events_dict(((room_id, event_id),))
+ return res[(room_id, event_id)]
def _get_current_state_event_counts_txn(
self, txn: LoggingTransaction, room_id: str
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 21662296..4f1c22c7 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -15,7 +15,6 @@ import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
@@ -36,7 +35,7 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 60 * 60 * 1000
-class MonthlyActiveUsersWorkerStore(SQLBaseStore):
+class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -47,9 +46,30 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
self._clock = hs.get_clock()
self.hs = hs
+ if hs.config.redis.redis_enabled:
+ # If we're using Redis, we can shift this update process off to
+ # the background worker
+ self._update_on_this_worker = hs.config.worker.run_background_tasks
+ else:
+ # If we're NOT using Redis, this must be handled by the master
+ self._update_on_this_worker = hs.get_instance_name() == "master"
+
self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau
self._max_mau_value = hs.config.server.max_mau_value
+ self._mau_stats_only = hs.config.server.mau_stats_only
+
+ if self._update_on_this_worker:
+ # Do not add more reserved users than the total allowable number
+ self.db_pool.new_transaction(
+ db_conn,
+ "initialise_mau_threepids",
+ [],
+ [],
+ self._initialise_reserved_users,
+ hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
+ )
+
@cached(num_args=0)
async def get_monthly_active_count(self) -> int:
"""Generates current count of monthly active users
@@ -222,28 +242,6 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"reap_monthly_active_users", _reap_users, reserved_users
)
-
-class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore):
- def __init__(
- self,
- database: DatabasePool,
- db_conn: LoggingDatabaseConnection,
- hs: "HomeServer",
- ):
- super().__init__(database, db_conn, hs)
-
- self._mau_stats_only = hs.config.server.mau_stats_only
-
- # Do not add more reserved users than the total allowable number
- self.db_pool.new_transaction(
- db_conn,
- "initialise_mau_threepids",
- [],
- [],
- self._initialise_reserved_users,
- hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
- )
-
def _initialise_reserved_users(
self, txn: LoggingTransaction, threepids: List[dict]
) -> None:
@@ -254,6 +252,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
txn:
threepids: List of threepid dicts to reserve
"""
+ assert (
+ self._update_on_this_worker
+ ), "This worker is not designated to update MAUs"
# XXX what is this function trying to achieve? It upserts into
# monthly_active_users for each *registered* reserved mau user, but why?
@@ -287,6 +288,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
Args:
user_id: user to add/update
"""
+ assert (
+ self._update_on_this_worker
+ ), "This worker is not designated to update MAUs"
+
# Support user never to be included in MAU stats. Note I can't easily call this
# from upsert_monthly_active_user_txn because then I need a _txn form of
# is_support_user which is complicated because I want to cache the result.
@@ -322,6 +327,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
txn (cursor):
user_id (str): user to add/update
"""
+ assert (
+ self._update_on_this_worker
+ ), "This worker is not designated to update MAUs"
# Am consciously deciding to lock the table on the basis that is ought
# never be a big table and alternative approaches (batching multiple
@@ -349,6 +357,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
Args:
user_id(str): the user_id to query
"""
+ assert (
+ self._update_on_this_worker
+ ), "This worker is not designated to update MAUs"
+
if self._limit_usage_by_mau or self._mau_stats_only:
# Trial users and guests should not be included as part of MAU group
is_guest = await self.is_guest(user_id) # type: ignore[attr-defined]
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index e6f97aee..332e901d 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -98,8 +98,19 @@ class ReceiptsWorkerStore(SQLBaseStore):
super().__init__(database, db_conn, hs)
+ max_receipts_stream_id = self.get_max_receipt_stream_id()
+ receipts_stream_prefill, min_receipts_stream_id = self.db_pool.get_cache_dict(
+ db_conn,
+ "receipts_linearized",
+ entity_column="room_id",
+ stream_column="stream_id",
+ max_value=max_receipts_stream_id,
+ limit=10000,
+ )
self._receipts_stream_cache = StreamChangeCache(
- "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
+ "ReceiptsRoomChangeCache",
+ min_receipts_stream_id,
+ prefilled_cache=receipts_stream_prefill,
)
def get_max_receipt_stream_id(self) -> int:
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 7f3d190e..d43163c2 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -34,7 +34,7 @@ from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import IdGenerator
from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import UserID, UserInfo
+from synapse.types import JsonDict, UserID, UserInfo
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
@@ -79,7 +79,7 @@ class TokenLookupResult:
# Make the token owner default to the user ID, which is the common case.
@token_owner.default
- def _default_token_owner(self):
+ def _default_token_owner(self) -> str:
return self.user_id
@@ -299,7 +299,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
the account.
"""
- def set_account_validity_for_user_txn(txn):
+ def set_account_validity_for_user_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_txn(
txn=txn,
table="account_validity",
@@ -385,23 +385,25 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="get_renewal_token_for_user",
)
- async def get_users_expiring_soon(self) -> List[Dict[str, Any]]:
+ async def get_users_expiring_soon(self) -> List[Tuple[str, int]]:
"""Selects users whose account will expire in the [now, now + renew_at] time
window (see configuration for account_validity for information on what renew_at
refers to).
Returns:
- A list of dictionaries, each with a user ID and expiration time (in milliseconds).
+ A list of tuples, each with a user ID and expiration time (in milliseconds).
"""
- def select_users_txn(txn, now_ms, renew_at):
+ def select_users_txn(
+ txn: LoggingTransaction, now_ms: int, renew_at: int
+ ) -> List[Tuple[str, int]]:
sql = (
"SELECT user_id, expiration_ts_ms FROM account_validity"
" WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
)
values = [False, now_ms, renew_at]
txn.execute(sql, values)
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[Tuple[str, int]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_users_expiring_soon",
@@ -466,7 +468,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
admin: true iff the user is to be a server admin, false otherwise.
"""
- def set_server_admin_txn(txn):
+ def set_server_admin_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn(
txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0}
)
@@ -515,7 +517,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
user_type: type of the user or None for a user without a type.
"""
- def set_user_type_txn(txn):
+ def set_user_type_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn(
txn, "users", {"name": user.to_string()}, {"user_type": user_type}
)
@@ -525,7 +527,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_user_type", set_user_type_txn)
- def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
+ def _query_for_auth(
+ self, txn: LoggingTransaction, token: str
+ ) -> Optional[TokenLookupResult]:
sql = """
SELECT users.name as user_id,
users.is_guest,
@@ -582,7 +586,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"is_support_user", self.is_support_user_txn, user_id
)
- def is_real_user_txn(self, txn, user_id):
+ def is_real_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool:
res = self.db_pool.simple_select_one_onecol_txn(
txn=txn,
table="users",
@@ -592,7 +596,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
return res is None
- def is_support_user_txn(self, txn, user_id):
+ def is_support_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool:
res = self.db_pool.simple_select_one_onecol_txn(
txn=txn,
table="users",
@@ -609,10 +613,11 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
A mapping of user_id -> password_hash.
"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> Dict[str, str]:
sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)"
txn.execute(sql, (user_id,))
- return dict(txn)
+ result = cast(List[Tuple[str, str]], txn.fetchall())
+ return dict(result)
return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
@@ -734,7 +739,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def _replace_user_external_id_txn(
txn: LoggingTransaction,
- ):
+ ) -> None:
_remove_user_external_ids_txn(txn, user_id)
for auth_provider, external_id in record_external_ids:
@@ -790,10 +795,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
return [(r["auth_provider"], r["external_id"]) for r in res]
- async def count_all_users(self):
+ async def count_all_users(self) -> int:
"""Counts all users registered on the homeserver."""
- def _count_users(txn):
+ def _count_users(txn: LoggingTransaction) -> int:
txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.db_pool.cursor_to_dict(txn)
if rows:
@@ -810,7 +815,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
who registered on the homeserver in the past 24 hours
"""
- def _count_daily_user_type(txn):
+ def _count_daily_user_type(txn: LoggingTransaction) -> Dict[str, int]:
yesterday = int(self._clock.time()) - (60 * 60 * 24)
sql = """
@@ -835,23 +840,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"count_daily_user_type", _count_daily_user_type
)
- async def count_nonbridged_users(self):
- def _count_users(txn):
+ async def count_nonbridged_users(self) -> int:
+ def _count_users(txn: LoggingTransaction) -> int:
txn.execute(
"""
SELECT COUNT(*) FROM users
WHERE appservice_id IS NULL
"""
)
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_users", _count_users)
- async def count_real_users(self):
+ async def count_real_users(self) -> int:
"""Counts all users without a special user_type registered on the homeserver."""
- def _count_users(txn):
+ def _count_users(txn: LoggingTransaction) -> int:
txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
rows = self.db_pool.cursor_to_dict(txn)
if rows:
@@ -888,7 +893,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return user_id
def get_user_id_by_threepid_txn(
- self, txn, medium: str, address: str
+ self, txn: LoggingTransaction, medium: str, address: str
) -> Optional[str]:
"""Returns user id from threepid
@@ -925,7 +930,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
)
- async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]:
+ async def user_get_threepids(self, user_id: str) -> List[Dict[str, Any]]:
return await self.db_pool.simple_select_list(
"user_threepids",
{"user_id": user_id},
@@ -957,7 +962,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
async def add_user_bound_threepid(
self, user_id: str, medium: str, address: str, id_server: str
- ):
+ ) -> None:
"""The server proxied a bind request to the given identity server on
behalf of the given user. We need to remember this in case the user
asks us to unbind the threepid.
@@ -1116,7 +1121,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
assert address or sid
- def get_threepid_validation_session_txn(txn):
+ def get_threepid_validation_session_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[Dict[str, Any]]:
sql = """
SELECT address, session_id, medium, client_secret,
last_send_attempt, validated_at
@@ -1150,7 +1157,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
session_id: The ID of the session to delete
"""
- def delete_threepid_session_txn(txn):
+ def delete_threepid_session_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn,
table="threepid_validation_token",
@@ -1170,7 +1177,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
async def cull_expired_threepid_validation_tokens(self) -> None:
"""Remove threepid validation tokens with expiry dates that have passed"""
- def cull_expired_threepid_validation_tokens_txn(txn, ts):
+ def cull_expired_threepid_validation_tokens_txn(
+ txn: LoggingTransaction, ts: int
+ ) -> None:
sql = """
DELETE FROM threepid_validation_token WHERE
expires < ?
@@ -1184,13 +1193,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
@wrap_as_background_process("account_validity_set_expiration_dates")
- async def _set_expiration_date_when_missing(self):
+ async def _set_expiration_date_when_missing(self) -> None:
"""
Retrieves the list of registered users that don't have an expiration date, and
adds an expiration date for each of them.
"""
- def select_users_with_no_expiration_date_txn(txn):
+ def select_users_with_no_expiration_date_txn(txn: LoggingTransaction) -> None:
"""Retrieves the list of registered users with no expiration date from the
database, filtering out deactivated users.
"""
@@ -1213,7 +1222,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
select_users_with_no_expiration_date_txn,
)
- def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+ def set_expiration_date_for_user_txn(
+ self, txn: LoggingTransaction, user_id: str, use_delta: bool = False
+ ) -> None:
"""Sets an expiration date to the account with the given user ID.
Args:
@@ -1344,7 +1355,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
token: The registration token pending use
"""
- def _set_registration_token_pending_txn(txn):
+ def _set_registration_token_pending_txn(txn: LoggingTransaction) -> None:
pending = self.db_pool.simple_select_one_onecol_txn(
txn,
"registration_tokens",
@@ -1358,7 +1369,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
updatevalues={"pending": pending + 1},
)
- return await self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"set_registration_token_pending", _set_registration_token_pending_txn
)
@@ -1372,7 +1383,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
token: The registration token to be 'used'
"""
- def _use_registration_token_txn(txn):
+ def _use_registration_token_txn(txn: LoggingTransaction) -> None:
# Normally, res is Optional[Dict[str, Any]].
# Override type because the return type is only optional if
# allow_none is True, and we don't want mypy throwing errors
@@ -1398,7 +1409,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
},
)
- return await self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"use_registration_token", _use_registration_token_txn
)
@@ -1416,7 +1427,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
A list of dicts, each containing details of a token.
"""
- def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]):
+ def select_registration_tokens_txn(
+ txn: LoggingTransaction, now: int, valid: Optional[bool]
+ ) -> List[Dict[str, Any]]:
if valid is None:
# Return all tokens regardless of validity
txn.execute("SELECT * FROM registration_tokens")
@@ -1523,7 +1536,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Whether the row was inserted or not.
"""
- def _create_registration_token_txn(txn):
+ def _create_registration_token_txn(txn: LoggingTransaction) -> bool:
row = self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
@@ -1570,7 +1583,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
A dict with all info about the token, or None if token doesn't exist.
"""
- def _update_registration_token_txn(txn):
+ def _update_registration_token_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[Dict[str, Any]]:
try:
self.db_pool.simple_update_one_txn(
txn,
@@ -1651,7 +1666,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
) -> Optional[RefreshTokenLookupResult]:
"""Lookup a refresh token with hints about its validity."""
- def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]:
+ def _lookup_refresh_token_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[RefreshTokenLookupResult]:
txn.execute(
"""
SELECT
@@ -1745,6 +1762,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"replace_refresh_token", _replace_refresh_token_txn
)
+ @cached()
+ async def is_guest(self, user_id: str) -> bool:
+ res = await self.db_pool.simple_select_one_onecol(
+ table="users",
+ keyvalues={"name": user_id},
+ retcol="is_guest",
+ allow_none=True,
+ desc="is_guest",
+ )
+
+ return res if res else False
+
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(
@@ -1795,14 +1824,18 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
unique=False,
)
- async def _background_update_set_deactivated_flag(self, progress, batch_size):
+ async def _background_update_set_deactivated_flag(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
for each of them.
"""
last_user = progress.get("user_id", "")
- def _background_update_set_deactivated_flag_txn(txn):
+ def _background_update_set_deactivated_flag_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[bool, int]:
txn.execute(
"""
SELECT
@@ -1874,7 +1907,9 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
deactivated,
)
- def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool):
+ def set_user_deactivated_status_txn(
+ self, txn: LoggingTransaction, user_id: str, deactivated: bool
+ ) -> None:
self.db_pool.simple_update_one_txn(
txn=txn,
table="users",
@@ -1887,18 +1922,6 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
- @cached()
- async def is_guest(self, user_id: str) -> bool:
- res = await self.db_pool.simple_select_one_onecol(
- table="users",
- keyvalues={"name": user_id},
- retcol="is_guest",
- allow_none=True,
- desc="is_guest",
- )
-
- return res if res else False
-
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
def __init__(
@@ -2005,7 +2028,9 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
return next_id
- def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
+ def _set_device_for_access_token_txn(
+ self, txn: LoggingTransaction, token: str, device_id: str
+ ) -> str:
old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn, "access_tokens", {"token": token}, "device_id"
)
@@ -2084,7 +2109,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
def _register_user(
self,
- txn,
+ txn: LoggingTransaction,
user_id: str,
password_hash: Optional[str],
was_guest: bool,
@@ -2094,7 +2119,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
admin: bool,
user_type: Optional[str],
shadow_banned: bool,
- ):
+ ) -> None:
user_id_obj = UserID.from_string(user_id)
now = int(self._clock.time())
@@ -2181,7 +2206,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
pointless. Use flush_user separately.
"""
- def user_set_password_hash_txn(txn):
+ def user_set_password_hash_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn(
txn, "users", {"name": user_id}, {"password_hash": password_hash}
)
@@ -2204,7 +2229,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
StoreError(404) if user not found
"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn(
txn,
table="users",
@@ -2229,7 +2254,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
StoreError(404) if user not found
"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn(
txn,
table="users",
@@ -2259,7 +2284,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
A tuple of (token, token id, device id) for each of the deleted tokens
"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> List[Tuple[str, int, Optional[str]]]:
keyvalues = {"user_id": user_id}
if device_id is not None:
keyvalues["device_id"] = device_id
@@ -2301,7 +2326,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
return await self.db_pool.runInteraction("user_delete_access_tokens", f)
async def delete_access_token(self, access_token: str) -> None:
- def f(txn):
+ def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token}
)
@@ -2313,7 +2338,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
await self.db_pool.runInteraction("delete_access_token", f)
async def delete_refresh_token(self, refresh_token: str) -> None:
- def f(txn):
+ def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_one_txn(
txn, table="refresh_tokens", keyvalues={"token": refresh_token}
)
@@ -2353,7 +2378,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"""
# Insert everything into a transaction in order to run atomically
- def validate_threepid_session_txn(txn):
+ def validate_threepid_session_txn(txn: LoggingTransaction) -> Optional[str]:
row = self.db_pool.simple_select_one_txn(
txn,
table="threepid_validation_session",
@@ -2450,7 +2475,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
longer be valid
"""
- def start_or_continue_validation_session_txn(txn):
+ def start_or_continue_validation_session_txn(txn: LoggingTransaction) -> None:
# Create or update a validation session
self.db_pool.simple_upsert_txn(
txn,
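The registration changes above are largely about typing the transaction callbacks: each inner txn function now accepts a LoggingTransaction and declares its return type, and loosely typed cursor results are narrowed with cast() before use. A minimal sketch of that pattern, outside any real store (the table and function names below are illustrative, not taken from this diff):

    from typing import Tuple, cast

    from synapse.storage.database import LoggingTransaction


    def _count_example_rows_txn(txn: LoggingTransaction) -> int:
        # fetchone() is loosely typed by the driver, so narrow the row shape
        # with cast() before unpacking, mirroring count_nonbridged_users above.
        txn.execute("SELECT COUNT(*) FROM users")
        (count,) = cast(Tuple[int], txn.fetchone())
        return count

    # Callers run the typed callback through the pool, e.g.
    #     count = await self.db_pool.runInteraction("count_example_rows", _count_example_rows_txn)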
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index b2295fd5..407158ce 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -17,6 +17,7 @@ from typing import (
TYPE_CHECKING,
Collection,
Dict,
+ FrozenSet,
Iterable,
List,
Optional,
@@ -39,8 +40,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
-from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
-from synapse.types import RoomStreamToken, StreamToken
+from synapse.types import JsonDict, RoomStreamToken, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
@@ -49,6 +49,19 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _RelatedEvent:
+ """
+ Contains enough information about a related event in order to properly filter
+ events from ignored users.
+ """
+
+ # The event ID of the related event.
+ event_id: str
+ # The sender of the related event.
+ sender: str
+
+
class RelationsWorkerStore(SQLBaseStore):
def __init__(
self,
@@ -73,7 +86,7 @@ class RelationsWorkerStore(SQLBaseStore):
direction: str = "b",
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
- ) -> PaginationChunk:
+ ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
"""Get a list of relations for an event, ordered by topological ordering.
Args:
@@ -90,8 +103,10 @@ class RelationsWorkerStore(SQLBaseStore):
to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
- List of event IDs that match relations requested. The rows are of
- the form `{"event_id": "..."}`.
+ A tuple of:
+ A list of related event IDs & their senders.
+
+ The next stream token, if one exists.
"""
# We don't use `event_id`, it's there so that we can cache based on
# it. The `event_id` must match the `event.event_id`.
@@ -132,7 +147,7 @@ class RelationsWorkerStore(SQLBaseStore):
order = "ASC"
sql = """
- SELECT event_id, relation_type, topological_ordering, stream_ordering
+ SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
FROM event_relations
INNER JOIN events USING (event_id)
WHERE %s
@@ -146,7 +161,7 @@ class RelationsWorkerStore(SQLBaseStore):
def _get_recent_references_for_event_txn(
txn: LoggingTransaction,
- ) -> PaginationChunk:
+ ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
txn.execute(sql, where_args + [limit + 1])
last_topo_id = None
@@ -156,9 +171,9 @@ class RelationsWorkerStore(SQLBaseStore):
# Do not include edits for redacted events as they leak event
# content.
if not is_redacted or row[1] != RelationTypes.REPLACE:
- events.append({"event_id": row[0]})
- last_topo_id = row[2]
- last_stream_id = row[3]
+ events.append(_RelatedEvent(row[0], row[2]))
+ last_topo_id = row[3]
+ last_stream_id = row[4]
# If there are more events, generate the next pagination key.
next_token = None
@@ -179,9 +194,7 @@ class RelationsWorkerStore(SQLBaseStore):
groups_key=0,
)
- return PaginationChunk(
- chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token
- )
+ return events[:limit], next_token
return await self.db_pool.runInteraction(
"get_recent_references_for_event", _get_recent_references_for_event_txn
@@ -252,15 +265,8 @@ class RelationsWorkerStore(SQLBaseStore):
@cached(tree=True)
async def get_aggregation_groups_for_event(
- self,
- event_id: str,
- room_id: str,
- event_type: Optional[str] = None,
- limit: int = 5,
- direction: str = "b",
- from_token: Optional[AggregationPaginationToken] = None,
- to_token: Optional[AggregationPaginationToken] = None,
- ) -> PaginationChunk:
+ self, event_id: str, room_id: str, limit: int = 5
+ ) -> List[JsonDict]:
"""Get a list of annotations on the event, grouped by event type and
aggregation key, sorted by count.
@@ -270,82 +276,96 @@ class RelationsWorkerStore(SQLBaseStore):
Args:
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
- event_type: Only fetch events with this event type, if given.
limit: Only fetch the `limit` groups.
- direction: Whether to fetch the highest count first (`"b"`) or
- the lowest count first (`"f"`).
- from_token: Fetch rows from the given token, or from the start if None.
- to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
List of groups of annotations that match. Each row is a dict with
`type`, `key` and `count` fields.
"""
- where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"]
- where_args: List[Union[str, int]] = [
+ args = [
event_id,
room_id,
RelationTypes.ANNOTATION,
+ limit,
]
- if event_type:
- where_clause.append("type = ?")
- where_args.append(event_type)
+ sql = """
+ SELECT type, aggregation_key, COUNT(DISTINCT sender)
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE relates_to_id = ? AND room_id = ? AND relation_type = ?
+ GROUP BY relation_type, type, aggregation_key
+ ORDER BY COUNT(*) DESC
+ LIMIT ?
+ """
- having_clause = generate_pagination_where_clause(
- direction=direction,
- column_names=("COUNT(*)", "MAX(stream_ordering)"),
- from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
- to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
- engine=self.database_engine,
+ def _get_aggregation_groups_for_event_txn(
+ txn: LoggingTransaction,
+ ) -> List[JsonDict]:
+ txn.execute(sql, args)
+
+ return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn]
+
+ return await self.db_pool.runInteraction(
+ "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
)
- if direction == "b":
- order = "DESC"
- else:
- order = "ASC"
+ async def get_aggregation_groups_for_users(
+ self,
+ event_id: str,
+ room_id: str,
+ limit: int,
+ users: FrozenSet[str] = frozenset(),
+ ) -> Dict[Tuple[str, str], int]:
+ """Fetch the partial aggregations for an event for specific users.
- if having_clause:
- having_clause = "HAVING " + having_clause
- else:
- having_clause = ""
+ This is used, in conjunction with get_aggregation_groups_for_event, to
+ remove information from the results for ignored users.
- sql = """
- SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering)
+ Args:
+ event_id: Fetch events that relate to this event ID.
+ room_id: The room the event belongs to.
+ limit: Only fetch the `limit` groups.
+ users: The users to fetch information for.
+
+ Returns:
+ A map of (event type, aggregation key) to a count of users.
+ """
+
+ if not users:
+ return {}
+
+ args: List[Union[str, int]] = [
+ event_id,
+ room_id,
+ RelationTypes.ANNOTATION,
+ ]
+
+ users_sql, users_args = make_in_list_sql_clause(
+ self.database_engine, "sender", users
+ )
+ args.extend(users_args)
+
+ sql = f"""
+ SELECT type, aggregation_key, COUNT(DISTINCT sender)
FROM event_relations
INNER JOIN events USING (event_id)
- WHERE {where_clause}
+ WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql}
GROUP BY relation_type, type, aggregation_key
- {having_clause}
- ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
+ ORDER BY COUNT(*) DESC
LIMIT ?
- """.format(
- where_clause=" AND ".join(where_clause),
- order=order,
- having_clause=having_clause,
- )
+ """
- def _get_aggregation_groups_for_event_txn(
+ def _get_aggregation_groups_for_users_txn(
txn: LoggingTransaction,
- ) -> PaginationChunk:
- txn.execute(sql, where_args + [limit + 1])
+ ) -> Dict[Tuple[str, str], int]:
+ txn.execute(sql, args + [limit])
- next_batch = None
- events = []
- for row in txn:
- events.append({"type": row[0], "key": row[1], "count": row[2]})
- next_batch = AggregationPaginationToken(row[2], row[3])
-
- if len(events) <= limit:
- next_batch = None
-
- return PaginationChunk(
- chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
- )
+ return {(row[0], row[1]): row[2] for row in txn}
return await self.db_pool.runInteraction(
- "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
+ "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn
)
@cached()
@@ -574,6 +594,67 @@ class RelationsWorkerStore(SQLBaseStore):
return summaries
+ async def get_threaded_messages_per_user(
+ self,
+ event_ids: Collection[str],
+ users: FrozenSet[str] = frozenset(),
+ ) -> Dict[Tuple[str, str], int]:
+ """Get the number of threaded replies for a set of users.
+
+ This is used, in conjunction with get_thread_summaries, to calculate an
+ accurate count of the replies to a thread by subtracting ignored users.
+
+ Args:
+ event_ids: The events to check for threaded replies.
+            users: The users to count replies for.
+
+        Returns:
+            A map of (event_id, sender) to the count of their replies.
+ """
+ if not users:
+ return {}
+
+ # Fetch the number of threaded replies.
+ sql = """
+ SELECT parent.event_id, child.sender, COUNT(child.event_id) FROM events AS child
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS parent ON
+ parent.event_id = relates_to_id
+ AND parent.room_id = child.room_id
+ WHERE
+ %s
+ AND %s
+ AND %s
+ GROUP BY parent.event_id, child.sender
+ """
+
+ def _get_threaded_messages_per_user_txn(
+ txn: LoggingTransaction,
+ ) -> Dict[Tuple[str, str], int]:
+ users_sql, users_args = make_in_list_sql_clause(
+ self.database_engine, "child.sender", users
+ )
+ events_clause, events_args = make_in_list_sql_clause(
+ txn.database_engine, "relates_to_id", event_ids
+ )
+
+ if self._msc3440_enabled:
+ relations_clause = "(relation_type = ? OR relation_type = ?)"
+ relations_args = [RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD]
+ else:
+ relations_clause = "relation_type = ?"
+ relations_args = [RelationTypes.THREAD]
+
+ txn.execute(
+ sql % (users_sql, events_clause, relations_clause),
+ users_args + events_args + relations_args,
+ )
+ return {(row[0], row[1]): row[2] for row in txn}
+
+ return await self.db_pool.runInteraction(
+ "get_threaded_messages_per_user", _get_threaded_messages_per_user_txn
+ )
+
@cached()
def get_thread_participated(self, event_id: str, user_id: str) -> bool:
raise NotImplementedError()
@@ -661,7 +742,7 @@ class RelationsWorkerStore(SQLBaseStore):
%s;
"""
- def _get_if_events_have_relations(txn) -> List[str]:
+ def _get_if_events_have_relations(txn: LoggingTransaction) -> List[str]:
clauses: List[str] = []
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", parent_ids
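The new get_aggregation_groups_for_users query is intended to be combined with the cached get_aggregation_groups_for_event results so that reactions from ignored users can be subtracted without invalidating the per-event cache. A rough caller-side sketch of that combination (the helper below is illustrative glue, not code from this diff):

    from typing import Dict, List, Tuple

    from synapse.types import JsonDict


    def subtract_ignored_aggregations(
        aggregations: List[JsonDict],
        ignored_counts: Dict[Tuple[str, str], int],
    ) -> List[JsonDict]:
        """Remove ignored users' contributions from annotation groups.

        `aggregations` is the output of get_aggregation_groups_for_event;
        `ignored_counts` is the output of get_aggregation_groups_for_users
        called with the requester's ignored users.
        """
        result = []
        for group in aggregations:
            key = (group["type"], group["key"])
            count = group["count"] - ignored_counts.get(key, 0)
            if count > 0:
                result.append({"type": group["type"], "key": group["key"], "count": count})
        return result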
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 3248da53..48e83592 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -361,7 +361,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return None
async def get_rooms_for_local_user_where_membership_is(
- self, user_id: str, membership_list: Collection[str]
+ self,
+ user_id: str,
+ membership_list: Collection[str],
+ excluded_rooms: Optional[List[str]] = None,
) -> List[RoomsForUser]:
"""Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
@@ -372,6 +375,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
user_id: The user ID.
membership_list: A list of synapse.api.constants.Membership
values which the user must be in.
+ excluded_rooms: A list of rooms to ignore.
Returns:
The RoomsForUser that the user matches the membership types.
@@ -386,12 +390,19 @@ class RoomMemberWorkerStore(EventsWorkerStore):
membership_list,
)
- # Now we filter out forgotten rooms
- forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id)
- return [room for room in rooms if room.room_id not in forgotten_rooms]
+ # Now we filter out forgotten and excluded rooms
+ rooms_to_exclude: Set[str] = await self.get_forgotten_rooms_for_user(user_id)
+
+ if excluded_rooms is not None:
+ rooms_to_exclude.update(set(excluded_rooms))
+
+ return [room for room in rooms if room.room_id not in rooms_to_exclude]
def _get_rooms_for_local_user_where_membership_is_txn(
- self, txn, user_id: str, membership_list: List[str]
+ self,
+ txn,
+ user_id: str,
+ membership_list: List[str],
) -> List[RoomsForUser]:
# Paranoia check.
if not self.hs.is_mine_id(user_id):
@@ -877,7 +888,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return frozenset(cache.hosts_to_joined_users)
# Since we'll mutate the cache we need to lock.
- with (await self._joined_host_linearizer.queue(room_id)):
+ async with self._joined_host_linearizer.queue(room_id):
if state_entry.state_group == cache.state_group:
# Same state group, so nothing to do. We've already checked for
# this above, but the cache may have changed while waiting on
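The extra excluded_rooms argument lets callers drop specific rooms in the same pass that already filters out forgotten rooms. A hedged usage sketch (Membership.JOIN/INVITE are real constants; the surrounding helper is invented for illustration):

    from typing import List

    from synapse.api.constants import Membership


    async def joined_or_invited_room_ids(
        store, user_id: str, excluded_rooms: List[str]
    ) -> List[str]:
        # Joined and invited rooms, minus anything forgotten or explicitly excluded.
        rooms = await store.get_rooms_for_local_user_where_membership_is(
            user_id,
            membership_list=(Membership.JOIN, Membership.INVITE),
            excluded_rooms=excluded_rooms,
        )
        return [r.room_id for r in rooms]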
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index 0518b8b9..95148fd2 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -26,7 +26,7 @@ from synapse.util.caches.descriptors import cached, cachedList
class SignatureWorkerStore(EventsWorkerStore):
@cached()
- def get_event_reference_hash(self, event_id):
+ def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]:
# This is a dummy function to allow get_event_reference_hashes
# to use its cache
raise NotImplementedError()
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 28460fd3..ecdc1fdc 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -12,9 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import collections.abc
import logging
-from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
+
+from frozendict import frozendict
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@@ -29,7 +30,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, StateMap
+from synapse.types import JsonDict, JsonMapping, StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
@@ -132,7 +133,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return room_version
- async def get_room_predecessor(self, room_id: str) -> Optional[dict]:
+ async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
"""Get the predecessor of an upgraded room if it exists.
Otherwise return None.
@@ -158,9 +159,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
predecessor = create_event.content.get("predecessor", None)
# Ensure the key is a dictionary
- if not isinstance(predecessor, collections.abc.Mapping):
+ if not isinstance(predecessor, (dict, frozendict)):
return None
+ # The keys must be strings since the data is JSON.
return predecessor
async def get_create_event_for_room(self, room_id: str) -> EventBase:
@@ -202,7 +204,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
The current state of the room.
"""
- def _get_current_state_ids_txn(txn):
+ def _get_current_state_ids_txn(txn: LoggingTransaction) -> StateMap[str]:
txn.execute(
"""SELECT type, state_key, event_id FROM current_state_events
WHERE room_id = ?
@@ -306,8 +308,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
list_name="event_ids",
num_args=1,
)
- async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict:
- """Returns mapping event_id -> state_group"""
+ async def _get_state_group_for_events(
+ self, event_ids: Collection[str]
+ ) -> Dict[str, int]:
+ """Returns mapping event_id -> state_group.
+
+ Raises:
+ RuntimeError if the state is unknown at any of the given events
+ """
rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
@@ -317,7 +325,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="_get_state_group_for_events",
)
- return {row["event_id"]: row["state_group"] for row in rows}
+ res = {row["event_id"]: row["state_group"] for row in rows}
+ for e in event_ids:
+ if e not in res:
+ raise RuntimeError("No state group for unknown or outlier event %s" % e)
+ return res
async def get_referenced_state_groups(
self, state_groups: Iterable[int]
@@ -521,7 +533,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
)
for user_id in potentially_left_users - joined_users:
- await self.mark_remote_user_device_list_as_unsubscribed(user_id)
+ await self.mark_remote_user_device_list_as_unsubscribed(user_id) # type: ignore[attr-defined]
return batch_size
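Because _get_state_group_for_events now raises instead of silently omitting unknown events, callers that may be handed outliers need to guard the lookup. A minimal sketch of such a guard (the wrapper function is illustrative only):

    from typing import Collection, Dict, Optional


    async def state_groups_or_none(
        store, event_ids: Collection[str]
    ) -> Optional[Dict[str, int]]:
        try:
            return await store._get_state_group_for_events(event_ids)
        except RuntimeError:
            # At least one of the events is an outlier or unknown, so it has
            # no state group.
            return None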
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 39e1efe3..6d45a8a9 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -36,7 +36,17 @@ what sort order was used:
"""
import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ cast,
+)
import attr
from frozendict import frozendict
@@ -585,7 +595,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
async def get_membership_changes_for_user(
- self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
+ self,
+ user_id: str,
+ from_key: RoomStreamToken,
+ to_key: RoomStreamToken,
+ excluded_rooms: Optional[List[str]] = None,
) -> List[EventBase]:
"""Fetch membership events for a given user.
@@ -610,23 +624,29 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
min_from_id = from_key.stream
max_to_id = to_key.get_max_stream_pos()
+ args: List[Any] = [user_id, min_from_id, max_to_id]
+
+ ignore_room_clause = ""
+ if excluded_rooms is not None and len(excluded_rooms) > 0:
+ ignore_room_clause = "AND e.room_id NOT IN (%s)" % ",".join(
+ "?" for _ in excluded_rooms
+ )
+ args = args + excluded_rooms
+
sql = """
SELECT m.event_id, instance_name, topological_ordering, stream_ordering
FROM events AS e, room_memberships AS m
WHERE e.event_id = m.event_id
AND m.user_id = ?
AND e.stream_ordering > ? AND e.stream_ordering <= ?
+ %s
ORDER BY e.stream_ordering ASC
- """
- txn.execute(
- sql,
- (
- user_id,
- min_from_id,
- max_to_id,
- ),
+ """ % (
+ ignore_room_clause,
)
+ txn.execute(sql, args)
+
rows = [
_EventDictReturn(event_id, None, stream_ordering)
for event_id, instance_name, topological_ordering, stream_ordering in txn
@@ -722,7 +742,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
A tuple of (stream ordering, topological ordering, event_id)
"""
- def _f(txn):
+ def _f(txn: LoggingTransaction) -> Optional[Tuple[int, int, str]]:
sql = (
"SELECT stream_ordering, topological_ordering, event_id"
" FROM events"
@@ -732,27 +752,29 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
" LIMIT 1"
)
txn.execute(sql, (room_id, stream_ordering))
- return txn.fetchone()
+ return cast(Optional[Tuple[int, int, str]], txn.fetchone())
return await self.db_pool.runInteraction(
"get_room_event_before_stream_ordering", _f
)
- async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str:
- """Returns the current token for rooms stream.
+ async def get_current_room_stream_token_for_room_id(
+ self, room_id: Optional[str] = None
+ ) -> RoomStreamToken:
+ """Returns the current position of the rooms stream.
- By default, it returns the current global stream token. Specifying a
- `room_id` causes it to return the current room specific topological
- token.
+ By default, it returns a live token with the current global stream
+ token. Specifying a `room_id` causes it to return a historic token with
+ the room specific topological token.
"""
- token = self.get_room_max_stream_ordering()
+ stream_ordering = self.get_room_max_stream_ordering()
if room_id is None:
- return "s%d" % (token,)
+ return RoomStreamToken(None, stream_ordering)
else:
topo = await self.db_pool.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn, room_id
)
- return "t%d-%d" % (topo, token)
+ return RoomStreamToken(topo, stream_ordering)
def get_stream_id_for_event_txn(
self,
@@ -827,7 +849,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
@staticmethod
def _set_before_and_after(
events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
- ):
+ ) -> None:
"""Inserts ordering information to events' internal metadata from
the DB rows.
@@ -973,7 +995,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
the `current_id`).
"""
- def get_all_new_events_stream_txn(txn):
+ def get_all_new_events_stream_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[int, List[str]]:
sql = (
"SELECT e.stream_ordering, e.event_id"
" FROM events AS e"
@@ -1319,7 +1343,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_id_for_instance(self, instance_name: str) -> int:
"""Get a unique, immutable ID that corresponds to the given Synapse worker instance."""
- def _get_id_for_instance_txn(txn):
+ def _get_id_for_instance_txn(txn: LoggingTransaction) -> int:
instance_id = self.db_pool.simple_select_one_onecol_txn(
txn,
table="instance_map",
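Replacing get_room_events_max_id changes the return type from a hand-built string to a structured RoomStreamToken: the old "s%d" form corresponds to a live token and "t%d-%d" to a historic token. A small sketch of how a caller could still obtain the string form, assuming the usual to_string helper on stream tokens (that helper is not part of this diff):

    from typing import Optional


    async def room_token_as_string(store, room_id: Optional[str] = None) -> str:
        # A live token serialises as "s<stream>"; a historic token (when a
        # room_id is given) as "t<topological>-<stream>", matching the strings
        # the removed get_room_events_max_id used to build by hand.
        token = await store.get_current_room_stream_token_for_room_id(room_id)
        return await token.to_string(store)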
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index c8e508a9..b0f5de67 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -97,7 +97,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
)
def get_tag_content(
- txn: LoggingTransaction, tag_ids
+ txn: LoggingTransaction, tag_ids: List[Tuple[int, str, str]]
) -> List[Tuple[int, Tuple[str, str, str]]]:
sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
results = []
@@ -251,7 +251,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
return self._account_data_id_gen.get_current_token()
def _update_revision_txn(
- self, txn, user_id: str, room_id: str, next_id: int
+ self, txn: LoggingTransaction, user_id: str, room_id: str, next_id: int
) -> None:
"""Update the latest revision of the tags for the given user and room.
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
deleted file mode 100644
index fba27015..00000000
--- a/synapse/storage/relations.py
+++ /dev/null
@@ -1,84 +0,0 @@
-# Copyright 2019 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
-
-import attr
-
-from synapse.api.errors import SynapseError
-from synapse.types import JsonDict
-
-if TYPE_CHECKING:
- from synapse.storage.databases.main import DataStore
-
-logger = logging.getLogger(__name__)
-
-
-@attr.s(slots=True, auto_attribs=True)
-class PaginationChunk:
- """Returned by relation pagination APIs.
-
- Attributes:
- chunk: The rows returned by pagination
- next_batch: Token to fetch next set of results with, if
- None then there are no more results.
- prev_batch: Token to fetch previous set of results with, if
- None then there are no previous results.
- """
-
- chunk: List[JsonDict]
- next_batch: Optional[Any] = None
- prev_batch: Optional[Any] = None
-
- async def to_dict(self, store: "DataStore") -> Dict[str, Any]:
- d = {"chunk": self.chunk}
-
- if self.next_batch:
- d["next_batch"] = await self.next_batch.to_string(store)
-
- if self.prev_batch:
- d["prev_batch"] = await self.prev_batch.to_string(store)
-
- return d
-
-
-@attr.s(frozen=True, slots=True, auto_attribs=True)
-class AggregationPaginationToken:
- """Pagination token for relation aggregation pagination API.
-
- As the results are order by count and then MAX(stream_ordering) of the
- aggregation groups, we can just use them as our pagination token.
-
- Attributes:
- count: The count of relations in the boundary group.
- stream: The MAX stream ordering in the boundary group.
- """
-
- count: int
- stream: int
-
- @staticmethod
- def from_string(string: str) -> "AggregationPaginationToken":
- try:
- c, s = string.split("-")
- return AggregationPaginationToken(int(c), int(s))
- except ValueError:
- raise SynapseError(400, "Invalid aggregation pagination token")
-
- async def to_string(self, store: "DataStore") -> str:
- return "%d-%d" % (self.count, self.stream)
-
- def as_tuple(self) -> Tuple[Any, ...]:
- return attr.astuple(self)
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 7b21c1b9..151f2aa9 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-SCHEMA_VERSION = 68 # remember to update the list below when updating
+SCHEMA_VERSION = 69 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
@@ -58,6 +58,10 @@ Changes in SCHEMA_VERSION = 68:
- event_reference_hashes is no longer read.
- `events` has `state_key` and `rejection_reason` columns, which are populated for
new events.
+
+Changes in SCHEMA_VERSION = 69:
+ - We now write to the `device_lists_changes_in_room` table.
+ - Use a sequence to generate future `application_services_txns.txn_id`s
"""
diff --git a/synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql b/synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql
new file mode 100644
index 00000000..7590e34b
--- /dev/null
+++ b/synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql
@@ -0,0 +1,23 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add a column to track what device list changes stream id that this application
+-- service has been caught up to.
+
+-- We explicitly don't set this field as "NOT NULL", as having NULL as a possible
+-- state is useful for determining if we've ever sent traffic for a stream type
+-- to an appservice. See https://github.com/matrix-org/synapse/issues/10836 for
+-- one way this can be used.
+ALTER TABLE application_services_state ADD COLUMN device_list_stream_id BIGINT;
\ No newline at end of file
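Because the new column is nullable, NULL distinguishes "never sent this stream type to the appservice" from "caught up to stream position 0". A hedged sketch of how that might be read (the helper name and wiring are illustrative, not part of this change):

    from typing import Optional

    from synapse.storage.database import LoggingTransaction


    def _get_appservice_device_list_position_txn(
        txn: LoggingTransaction, as_id: str
    ) -> Optional[int]:
        # NULL (-> None) means we have never sent device list changes to this
        # appservice; an integer is the stream id it has been caught up to.
        txn.execute(
            "SELECT device_list_stream_id FROM application_services_state WHERE as_id = ?",
            (as_id,),
        )
        row = txn.fetchone()
        return row[0] if row else None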
diff --git a/synapse/storage/schema/main/delta/69/01as_txn_seq.py b/synapse/storage/schema/main/delta/69/01as_txn_seq.py
new file mode 100644
index 00000000..24bd4b39
--- /dev/null
+++ b/synapse/storage/schema/main/delta/69/01as_txn_seq.py
@@ -0,0 +1,44 @@
+# Copyright 2022 Beeper
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+Adds a postgres SEQUENCE for generating application service transaction IDs.
+"""
+
+from synapse.storage.engines import PostgresEngine
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ if isinstance(database_engine, PostgresEngine):
+ # If we already have some AS TXNs we want to start from the current
+ # maximum value. There are two potential places this is stored - the
+ # actual TXNs themselves *and* the AS state table. At time of migration
+ # it is possible the TXNs table is empty so we must include the AS state
+ # last_txn as a potential option, and pick the maximum.
+
+ cur.execute("SELECT COALESCE(max(txn_id), 0) FROM application_services_txns")
+ row = cur.fetchone()
+ txn_max = row[0]
+
+ cur.execute("SELECT COALESCE(max(last_txn), 0) FROM application_services_state")
+ row = cur.fetchone()
+ last_txn_max = row[0]
+
+ start_val = max(last_txn_max, txn_max) + 1
+
+ cur.execute(
+ "CREATE SEQUENCE application_services_txn_id_seq START WITH %s",
+ (start_val,),
+ )
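Once the migration has created the sequence (seeded past any existing transaction IDs in either table), new application service transaction IDs can be drawn from it on Postgres instead of incrementing last_txn by hand. A minimal sketch of drawing the next value (illustrative, not the actual appservice store code):

    from synapse.storage.database import LoggingTransaction


    def _next_as_txn_id_txn(txn: LoggingTransaction) -> int:
        # nextval() atomically increments and returns the sequence created by
        # the migration above, so concurrent workers never hand out the same
        # txn_id.
        txn.execute("SELECT nextval('application_services_txn_id_seq')")
        row = txn.fetchone()
        assert row is not None
        return row[0]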
diff --git a/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql b/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql
new file mode 100644
index 00000000..b5b1782b
--- /dev/null
+++ b/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql
@@ -0,0 +1,38 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE device_lists_changes_in_room (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+
+ -- This initially matches `device_lists_stream.stream_id`. Note that we
+ -- delete older values from `device_lists_stream`, so we can't use a foreign
+ -- constraint here.
+ --
+ -- The table will contain rows with the same `stream_id` but different
+ -- `room_id`, as for each device update we store a row per room the user is
+ -- joined to. Therefore `(stream_id, room_id)` gives a unique index.
+ stream_id BIGINT NOT NULL,
+
+ -- We have a background process which goes through this table and converts
+ -- entries into rows in `device_lists_outbound_pokes`. Once we have processed
+ -- a row, we mark it as such by setting `converted_to_destinations=TRUE`.
+ converted_to_destinations BOOLEAN NOT NULL,
+ opentracing_context TEXT
+);
+
+CREATE UNIQUE INDEX device_lists_changes_in_stream_id ON device_lists_changes_in_room(stream_id, room_id);
+CREATE INDEX device_lists_changes_in_stream_id_unconverted ON device_lists_changes_in_room(stream_id) WHERE NOT converted_to_destinations;
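The partial index on unconverted rows exists so the background job that turns these entries into device_lists_outbound_pokes rows can cheaply find its backlog. A hedged sketch of the kind of read it would perform (the batch size and function name are assumptions):

    from synapse.storage.database import LoggingTransaction


    def _get_unconverted_device_list_changes_txn(
        txn: LoggingTransaction, limit: int = 100
    ):
        # The WHERE NOT converted_to_destinations predicate matches the partial
        # index added above, so this scan stays cheap as the table grows.
        txn.execute(
            """
            SELECT user_id, device_id, room_id, stream_id, opentracing_context
            FROM device_lists_changes_in_room
            WHERE NOT converted_to_destinations
            ORDER BY stream_id
            LIMIT ?
            """,
            (limit,),
        )
        return txn.fetchall()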
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 86f1a537..cda194e8 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -571,6 +571,10 @@ class StateGroupStorage:
Returns:
dict of state_group_id -> (dict of (type, state_key) -> event id)
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (ie they are outliers or unknown)
"""
if not event_ids:
return {}
@@ -659,6 +663,10 @@ class StateGroupStorage:
Returns:
A dict of (event_id) -> (type, state_key) -> [state_events]
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (ie they are outliers or unknown)
"""
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
@@ -696,6 +704,10 @@ class StateGroupStorage:
Returns:
A dict from event_id -> (type, state_key) -> event_id
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (ie they are outliers or unknown)
"""
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
@@ -723,6 +735,10 @@ class StateGroupStorage:
Returns:
A dict from (type, state_key) -> state_event
+
+ Raises:
+ RuntimeError if we don't have a state group for the event (ie it is an
+ outlier or is unknown)
"""
state_map = await self.get_state_for_events(
[event_id], state_filter or StateFilter.all()
@@ -741,6 +757,10 @@ class StateGroupStorage:
Returns:
A dict from (type, state_key) -> state_event_id
+
+ Raises:
+ RuntimeError if we don't have a state group for the event (ie it is an
+ outlier or is unknown)
"""
state_map = await self.get_state_ids_for_events(
[event_id], state_filter or StateFilter.all()
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 57f4883b..d7d6f1d9 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -45,6 +45,7 @@ class Cursor(Protocol):
Sequence[
# Note that this is an approximate typing based on sqlite3 and other
# drivers, and may not be entirely accurate.
+ # FWIW, the DBAPI 2 spec is: https://peps.python.org/pep-0249/#description
Tuple[
str,
Optional[Any],