author     Andrej Shadura <andrewsh@debian.org>  2022-04-22 20:34:38 +0200
committer  Andrej Shadura <andrewsh@debian.org>  2022-04-22 20:34:38 +0200
commit     1b9b92888056ce7fd1f3a010ca7afd5c3963d44e (patch)
tree       716cb361eb4332eca7b147f35c87ceb29b3ac958 /synapse/storage/databases
parent     02eb467b57bf21597094de52232c93d5f4a38b7d (diff)
New upstream version 1.57.1
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--  synapse/storage/databases/main/__init__.py              20
-rw-r--r--  synapse/storage/databases/main/appservice.py           107
-rw-r--r--  synapse/storage/databases/main/client_ips.py           167
-rw-r--r--  synapse/storage/databases/main/devices.py              315
-rw-r--r--  synapse/storage/databases/main/events.py                10
-rw-r--r--  synapse/storage/databases/main/events_worker.py          8
-rw-r--r--  synapse/storage/databases/main/monthly_active_users.py  60
-rw-r--r--  synapse/storage/databases/main/receipts.py              13
-rw-r--r--  synapse/storage/databases/main/registration.py         157
-rw-r--r--  synapse/storage/databases/main/relations.py            227
-rw-r--r--  synapse/storage/databases/main/roommember.py            23
-rw-r--r--  synapse/storage/databases/main/signatures.py             2
-rw-r--r--  synapse/storage/databases/main/state.py                 32
-rw-r--r--  synapse/storage/databases/main/stream.py                70
-rw-r--r--  synapse/storage/databases/main/tags.py                   4
15 files changed, 818 insertions, 397 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index f024761b..951031af 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -33,7 +33,7 @@ from .account_data import AccountDataStore
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
from .cache import CacheInvalidationWorkerStore
from .censor_events import CensorEventsStore
-from .client_ips import ClientIpStore
+from .client_ips import ClientIpWorkerStore
from .deviceinbox import DeviceInboxStore
from .devices import DeviceStore
from .directory import DirectoryStore
@@ -49,7 +49,7 @@ from .keys import KeyStore
from .lock import LockStore
from .media_repository import MediaRepositoryStore
from .metrics import ServerMetricsStore
-from .monthly_active_users import MonthlyActiveUsersStore
+from .monthly_active_users import MonthlyActiveUsersWorkerStore
from .openid import OpenIdStore
from .presence import PresenceStore
from .profile import ProfileStore
@@ -112,13 +112,13 @@ class DataStore(
AccountDataStore,
EventPushActionsStore,
OpenIdStore,
- ClientIpStore,
+ ClientIpWorkerStore,
DeviceStore,
DeviceInboxStore,
UserDirectoryStore,
GroupServerStore,
UserErasureStore,
- MonthlyActiveUsersStore,
+ MonthlyActiveUsersWorkerStore,
StatsStore,
RelationsStore,
CensorEventsStore,
@@ -146,6 +146,7 @@ class DataStore(
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
+ ("device_lists_changes_in_room", "stream_id"),
],
)
@@ -182,17 +183,6 @@ class DataStore(
super().__init__(database, db_conn, hs)
- device_list_max = self._device_list_id_gen.get_current_token()
- self._device_list_stream_cache = StreamChangeCache(
- "DeviceListStreamChangeCache", device_list_max
- )
- self._user_signature_stream_cache = StreamChangeCache(
- "UserSignatureStreamChangeCache", device_list_max
- )
- self._device_list_federation_stream_cache = StreamChangeCache(
- "DeviceListFederationStreamChangeCache", device_list_max
- )
-
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
db_conn,
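
The three StreamChangeCaches removed from DataStore here are rebuilt, prefilled, inside DeviceWorkerStore later in this diff. As an illustrative sketch (not part of the patch) of that prefill pattern, assuming db_pool, db_conn and the current token come from the surrounding store:

from synapse.util.caches.stream_change_cache import StreamChangeCache

def build_prefilled_device_list_cache(db_pool, db_conn, current_token: int) -> StreamChangeCache:
    # Load the most recent (user_id -> stream_id) pairs so the cache starts
    # warm; anything older than min_stream_id is treated as "maybe changed".
    prefill, min_stream_id = db_pool.get_cache_dict(
        db_conn,
        "device_lists_stream",
        entity_column="user_id",
        stream_column="stream_id",
        max_value=current_token,
        limit=10000,
    )
    return StreamChangeCache(
        "DeviceListStreamChangeCache",
        min_stream_id,
        prefilled_cache=prefill,
    )
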
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 06944465..fa732edc 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import re
-from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple
from synapse.appservice import (
ApplicationService,
@@ -26,10 +26,16 @@ from synapse.appservice import (
from synapse.config.appservice import load_appservices
from synapse.events import EventBase
from synapse.storage._base import db_to_json
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
-from synapse.types import JsonDict
+from synapse.storage.types import Cursor
+from synapse.storage.util.sequence import build_sequence_generator
+from synapse.types import DeviceListUpdates, JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import _CacheContext, cached
@@ -72,9 +78,25 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
)
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
+ def get_max_as_txn_id(txn: Cursor) -> int:
+ logger.warning("Falling back to slow query, you should port to postgres")
+ txn.execute(
+ "SELECT COALESCE(max(txn_id), 0) FROM application_services_txns"
+ )
+ return txn.fetchone()[0] # type: ignore
+
+ self._as_txn_seq_gen = build_sequence_generator(
+ db_conn,
+ database.engine,
+ get_max_as_txn_id,
+ "application_services_txn_id_seq",
+ table="application_services_txns",
+ id_column="txn_id",
+ )
+
super().__init__(database, db_conn, hs)
- def get_app_services(self):
+ def get_app_services(self) -> List[ApplicationService]:
return self.services_cache
def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
@@ -217,6 +239,7 @@ class ApplicationServiceTransactionWorkerStore(
to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys,
+ device_list_summary: DeviceListUpdates,
) -> AppServiceTransaction:
"""Atomically creates a new transaction for this application service
with the given list of events. Ephemeral events are NOT persisted to the
@@ -231,27 +254,14 @@ class ApplicationServiceTransactionWorkerStore(
appservice devices in the transaction.
unused_fallback_keys: Lists of unused fallback keys for relevant
appservice devices in the transaction.
+ device_list_summary: The device list summary to include in the transaction.
Returns:
A new transaction.
"""
- def _create_appservice_txn(txn):
- # work out new txn id (highest txn id for this service += 1)
- # The highest id may be the last one sent (in which case it is last_txn)
- # or it may be the highest in the txns list (which are waiting to be/are
- # being sent)
- last_txn_id = self._get_last_txn(txn, service.id)
-
- txn.execute(
- "SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
- (service.id,),
- )
- highest_txn_id = txn.fetchone()[0]
- if highest_txn_id is None:
- highest_txn_id = 0
-
- new_txn_id = max(highest_txn_id, last_txn_id) + 1
+ def _create_appservice_txn(txn: LoggingTransaction) -> AppServiceTransaction:
+ new_txn_id = self._as_txn_seq_gen.get_next_id_txn(txn)
# Insert new txn into txn table
event_ids = json_encoder.encode([e.event_id for e in events])
@@ -268,6 +278,7 @@ class ApplicationServiceTransactionWorkerStore(
to_device_messages=to_device_messages,
one_time_key_counts=one_time_key_counts,
unused_fallback_keys=unused_fallback_keys,
+ device_list_summary=device_list_summary,
)
return await self.db_pool.runInteraction(
@@ -283,25 +294,8 @@ class ApplicationServiceTransactionWorkerStore(
txn_id: The transaction ID being completed.
service: The application service which was sent this transaction.
"""
- txn_id = int(txn_id)
-
- def _complete_appservice_txn(txn):
- # Debugging query: Make sure the txn being completed is EXACTLY +1 from
- # what was there before. If it isn't, we've got problems (e.g. the AS
- # has probably missed some events), so whine loudly but still continue,
- # since it shouldn't fail completion of the transaction.
- last_txn_id = self._get_last_txn(txn, service.id)
- if (last_txn_id + 1) != txn_id:
- logger.error(
- "appservice: Completing a transaction which has an ID > 1 from "
- "the last ID sent to this AS. We've either dropped events or "
- "sent it to the AS out of order. FIX ME. last_txn=%s "
- "completing_txn=%s service_id=%s",
- last_txn_id,
- txn_id,
- service.id,
- )
+ def _complete_appservice_txn(txn: LoggingTransaction) -> None:
# Set current txn_id for AS to 'txn_id'
self.db_pool.simple_upsert_txn(
txn,
@@ -332,7 +326,9 @@ class ApplicationServiceTransactionWorkerStore(
An AppServiceTransaction or None.
"""
- def _get_oldest_unsent_txn(txn):
+ def _get_oldest_unsent_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[Dict[str, Any]]:
# Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent)
txn.execute(
@@ -359,8 +355,8 @@ class ApplicationServiceTransactionWorkerStore(
events = await self.get_events_as_list(event_ids)
- # TODO: to-device messages, one-time key counts and unused fallback keys
- # are not yet populated for catch-up transactions.
+ # TODO: to-device messages, one-time key counts, device list summaries and unused
+ # fallback keys are not yet populated for catch-up transactions.
# We likely want to populate those for reliability.
return AppServiceTransaction(
service=service,
@@ -370,21 +366,11 @@ class ApplicationServiceTransactionWorkerStore(
to_device_messages=[],
one_time_key_counts={},
unused_fallback_keys={},
+ device_list_summary=DeviceListUpdates(),
)
- def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
- txn.execute(
- "SELECT last_txn FROM application_services_state WHERE as_id=?",
- (service_id,),
- )
- last_txn_id = txn.fetchone()
- if last_txn_id is None or last_txn_id[0] is None: # no row exists
- return 0
- else:
- return int(last_txn_id[0]) # select 'last_txn' col
-
async def set_appservice_last_pos(self, pos: int) -> None:
- def set_appservice_last_pos_txn(txn):
+ def set_appservice_last_pos_txn(txn: LoggingTransaction) -> None:
txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
)
@@ -398,7 +384,9 @@ class ApplicationServiceTransactionWorkerStore(
) -> Tuple[int, List[EventBase]]:
"""Get all new events for an appservice"""
- def get_new_events_for_appservice_txn(txn):
+ def get_new_events_for_appservice_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[int, List[str]]:
sql = (
"SELECT e.stream_ordering, e.event_id"
" FROM events AS e"
@@ -430,13 +418,13 @@ class ApplicationServiceTransactionWorkerStore(
async def get_type_stream_id_for_appservice(
self, service: ApplicationService, type: str
) -> int:
- if type not in ("read_receipt", "presence", "to_device"):
+ if type not in ("read_receipt", "presence", "to_device", "device_list"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (type,)
)
- def get_type_stream_id_for_appservice_txn(txn):
+ def get_type_stream_id_for_appservice_txn(txn: LoggingTransaction) -> int:
stream_id_type = "%s_stream_id" % type
txn.execute(
# We do NOT want to escape `stream_id_type`.
@@ -446,7 +434,8 @@ class ApplicationServiceTransactionWorkerStore(
)
last_stream_id = txn.fetchone()
if last_stream_id is None or last_stream_id[0] is None: # no row exists
- return 0
+ # Stream tokens always start from 1, to avoid foot guns around `0` being falsey.
+ return 1
else:
return int(last_stream_id[0])
@@ -457,13 +446,13 @@ class ApplicationServiceTransactionWorkerStore(
async def set_appservice_stream_type_pos(
self, service: ApplicationService, stream_type: str, pos: Optional[int]
) -> None:
- if stream_type not in ("read_receipt", "presence", "to_device"):
+ if stream_type not in ("read_receipt", "presence", "to_device", "device_list"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (stream_type,)
)
- def set_appservice_stream_type_pos_txn(txn):
+ def set_appservice_stream_type_pos_txn(txn: LoggingTransaction) -> None:
stream_id_type = "%s_stream_id" % stream_type
txn.execute(
"UPDATE application_services_state SET %s = ? WHERE as_id=?"
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 8b0c614e..0df160d2 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -25,7 +25,9 @@ from synapse.storage.database import (
LoggingTransaction,
make_tuple_comparison_clause,
)
-from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
+from synapse.storage.databases.main.monthly_active_users import (
+ MonthlyActiveUsersWorkerStore,
+)
from synapse.types import JsonDict, UserID
from synapse.util.caches.lrucache import LruCache
@@ -397,7 +399,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return updated
-class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
+class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -406,11 +408,40 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
):
super().__init__(database, db_conn, hs)
+ if hs.config.redis.redis_enabled:
+ # If we're using Redis, we can shift this update process off to
+ # the background worker
+ self._update_on_this_worker = hs.config.worker.run_background_tasks
+ else:
+ # If we're NOT using Redis, this must be handled by the master
+ self._update_on_this_worker = hs.get_instance_name() == "master"
+
self.user_ips_max_age = hs.config.server.user_ips_max_age
+ # (user_id, access_token, ip,) -> last_seen
+ self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
+ cache_name="client_ip_last_seen", max_size=50000
+ )
+
if hs.config.worker.run_background_tasks and self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
+ if self._update_on_this_worker:
+ # This is the designated worker that can write to the client IP
+ # tables.
+
+ # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
+ self._batch_row_update: Dict[
+ Tuple[str, str, str], Tuple[str, Optional[str], int]
+ ] = {}
+
+ self._client_ip_looper = self._clock.looping_call(
+ self._update_client_ips_batch, 5 * 1000
+ )
+ self.hs.get_reactor().addSystemEventTrigger(
+ "before", "shutdown", self._update_client_ips_batch
+ )
+
@wrap_as_background_process("prune_old_user_ips")
async def _prune_old_user_ips(self) -> None:
"""Removes entries in user IPs older than the configured period."""
@@ -456,7 +487,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
"_prune_old_user_ips", _prune_old_user_ips_txn
)
- async def get_last_client_ip_by_device(
+ async def _get_last_client_ip_by_device_from_database(
self, user_id: str, device_id: Optional[str]
) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
"""For each device_id listed, give the user_ip it was last seen on.
@@ -487,7 +518,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
return {(d["user_id"], d["device_id"]): d for d in res}
- async def get_user_ip_and_agents(
+ async def _get_user_ip_and_agents_from_database(
self, user: UserID, since_ts: int = 0
) -> List[LastConnectionInfo]:
"""Fetch the IPs and user agents for a user since the given timestamp.
@@ -539,34 +570,6 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
for access_token, ip, user_agent, last_seen in rows
]
-
-class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
- def __init__(
- self,
- database: DatabasePool,
- db_conn: LoggingDatabaseConnection,
- hs: "HomeServer",
- ):
-
- # (user_id, access_token, ip,) -> last_seen
- self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
- cache_name="client_ip_last_seen", max_size=50000
- )
-
- super().__init__(database, db_conn, hs)
-
- # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
- self._batch_row_update: Dict[
- Tuple[str, str, str], Tuple[str, Optional[str], int]
- ] = {}
-
- self._client_ip_looper = self._clock.looping_call(
- self._update_client_ips_batch, 5 * 1000
- )
- self.hs.get_reactor().addSystemEventTrigger(
- "before", "shutdown", self._update_client_ips_batch
- )
-
async def insert_client_ip(
self,
user_id: str,
@@ -584,17 +587,27 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
last_seen = self.client_ip_last_seen.get(key)
except KeyError:
last_seen = None
- await self.populate_monthly_active_users(user_id)
+
# Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return
self.client_ip_last_seen.set(key, now)
- self._batch_row_update[key] = (user_agent, device_id, now)
+ if self._update_on_this_worker:
+ await self.populate_monthly_active_users(user_id)
+ self._batch_row_update[key] = (user_agent, device_id, now)
+ else:
+ # We are not the designated writer-worker, so stream over replication
+ self.hs.get_replication_command_handler().send_user_ip(
+ user_id, access_token, ip, user_agent, device_id, now
+ )
@wrap_as_background_process("update_client_ips")
async def _update_client_ips_batch(self) -> None:
+ assert (
+ self._update_on_this_worker
+ ), "This worker is not designated to update client IPs"
# If the DB pool has already terminated, don't try updating
if not self.db_pool.is_running():
@@ -603,51 +616,57 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
to_update = self._batch_row_update
self._batch_row_update = {}
- await self.db_pool.runInteraction(
- "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
- )
+ if to_update:
+ await self.db_pool.runInteraction(
+ "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
+ )
def _update_client_ips_batch_txn(
self,
txn: LoggingTransaction,
to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]],
) -> None:
- if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
- not self.database_engine.can_native_upsert
- ):
- self.database_engine.lock_table(txn, "user_ips")
+ assert (
+ self._update_on_this_worker
+ ), "This worker is not designated to update client IPs"
+
+ # Keys and values for the `user_ips` upsert.
+ user_ips_keys = []
+ user_ips_values = []
+
+ # Keys and values for the `devices` update.
+ devices_keys = []
+ devices_values = []
for entry in to_update.items():
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
-
- self.db_pool.simple_upsert_txn(
- txn,
- table="user_ips",
- keyvalues={"user_id": user_id, "access_token": access_token, "ip": ip},
- values={
- "user_agent": user_agent,
- "device_id": device_id,
- "last_seen": last_seen,
- },
- lock=False,
- )
+ user_ips_keys.append((user_id, access_token, ip))
+ user_ips_values.append((user_agent, device_id, last_seen))
# Technically an access token might not be associated with
# a device so we need to check.
if device_id:
- # this is always an update rather than an upsert: the row should
- # already exist, and if it doesn't, that may be because it has been
- # deleted, and we don't want to re-create it.
- self.db_pool.simple_update_txn(
- txn,
- table="devices",
- keyvalues={"user_id": user_id, "device_id": device_id},
- updatevalues={
- "user_agent": user_agent,
- "last_seen": last_seen,
- "ip": ip,
- },
- )
+ devices_keys.append((user_id, device_id))
+ devices_values.append((user_agent, last_seen, ip))
+
+ self.db_pool.simple_upsert_many_txn(
+ txn,
+ table="user_ips",
+ key_names=("user_id", "access_token", "ip"),
+ key_values=user_ips_keys,
+ value_names=("user_agent", "device_id", "last_seen"),
+ value_values=user_ips_values,
+ )
+
+ if devices_values:
+ self.db_pool.simple_update_many_txn(
+ txn,
+ table="devices",
+ key_names=("user_id", "device_id"),
+ key_values=devices_keys,
+ value_names=("user_agent", "last_seen", "ip"),
+ value_values=devices_values,
+ )
async def get_last_client_ip_by_device(
self, user_id: str, device_id: Optional[str]
@@ -662,7 +681,12 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
A dictionary mapping a tuple of (user_id, device_id) to dicts, with
keys giving the column names from the devices table.
"""
- ret = await super().get_last_client_ip_by_device(user_id, device_id)
+ ret = await self._get_last_client_ip_by_device_from_database(user_id, device_id)
+
+ if not self._update_on_this_worker:
+ # Only the writing-worker has additional in-memory data to enhance
+ # the result
+ return ret
# Update what is retrieved from the database with data which is pending
# insertion, as if it has already been stored in the database.
@@ -707,9 +731,16 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
Only the latest user agent for each access token and IP address combination
is available.
"""
+ rows_from_db = await self._get_user_ip_and_agents_from_database(user, since_ts)
+
+ if not self._update_on_this_worker:
+ # Only the writing-worker has additional in-memory data to enhance
+ # the result
+ return rows_from_db
+
results: Dict[Tuple[str, str], LastConnectionInfo] = {
(connection["access_token"], connection["ip"]): connection
- for connection in await super().get_user_ip_and_agents(user, since_ts)
+ for connection in rows_from_db
}
# Overlay data that is pending insertion on top of the results from the
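
Taken together, the changes in this file make client-IP writes follow a single designated writer: with Redis the batching lives on the background-tasks worker, without Redis it stays on the master, and every other worker forwards the update over replication. A rough sketch of that decision (hs stands in for the usual HomeServer object):

def updates_client_ips_on_this_worker(hs) -> bool:
    # With Redis any worker can reach the designated background worker, so
    # the in-memory batch and looping call live there; without Redis the
    # master has to do the writes itself.
    if hs.config.redis.redis_enabled:
        return hs.config.worker.run_background_tasks
    return hs.get_instance_name() == "master"

# Workers that are not the designated writer hand the update off instead:
#   hs.get_replication_command_handler().send_user_ip(
#       user_id, access_token, ip, user_agent, device_id, now
#   )
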
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 3b3a089b..dc8009b2 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -46,6 +46,7 @@ from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@@ -71,6 +72,55 @@ class DeviceWorkerStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
+ device_list_max = self._device_list_id_gen.get_current_token()
+ device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
+ db_conn,
+ "device_lists_stream",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=device_list_max,
+ limit=10000,
+ )
+ self._device_list_stream_cache = StreamChangeCache(
+ "DeviceListStreamChangeCache",
+ min_device_list_id,
+ prefilled_cache=device_list_prefill,
+ )
+
+ (
+ user_signature_stream_prefill,
+ user_signature_stream_list_id,
+ ) = self.db_pool.get_cache_dict(
+ db_conn,
+ "user_signature_stream",
+ entity_column="from_user_id",
+ stream_column="stream_id",
+ max_value=device_list_max,
+ limit=1000,
+ )
+ self._user_signature_stream_cache = StreamChangeCache(
+ "UserSignatureStreamChangeCache",
+ user_signature_stream_list_id,
+ prefilled_cache=user_signature_stream_prefill,
+ )
+
+ (
+ device_list_federation_prefill,
+ device_list_federation_list_id,
+ ) = self.db_pool.get_cache_dict(
+ db_conn,
+ "device_lists_outbound_pokes",
+ entity_column="destination",
+ stream_column="stream_id",
+ max_value=device_list_max,
+ limit=10000,
+ )
+ self._device_list_federation_stream_cache = StreamChangeCache(
+ "DeviceListFederationStreamChangeCache",
+ device_list_federation_list_id,
+ prefilled_cache=device_list_federation_prefill,
+ )
+
if hs.config.worker.run_background_tasks:
self._clock.looping_call(
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
@@ -681,42 +731,64 @@ class DeviceWorkerStore(SQLBaseStore):
return self._device_list_stream_cache.get_all_entities_changed(from_key)
async def get_users_whose_devices_changed(
- self, from_key: int, user_ids: Iterable[str]
+ self,
+ from_key: int,
+ user_ids: Optional[Iterable[str]] = None,
+ to_key: Optional[int] = None,
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
are in the given list of user_ids.
Args:
- from_key: The device lists stream token
- user_ids: The user IDs to query for devices.
+ from_key: The minimum device lists stream token to query device list changes for,
+ exclusive.
+ user_ids: If provided, only check if these users have changed their device lists.
+ Otherwise changes from all users are returned.
+ to_key: The maximum device lists stream token to query device list changes for,
+ inclusive.
Returns:
- The set of user_ids whose devices have changed since `from_key`
+ The set of user_ids whose devices have changed since `from_key` (exclusive)
+ until `to_key` (inclusive).
"""
-
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
- to_check = self._device_list_stream_cache.get_entities_changed(
- user_ids, from_key
- )
+ if user_ids is None:
+ # Get set of all users that have had device list changes since 'from_key'
+ user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
+ from_key
+ )
+ else:
+ # The same as above, but filter results to only those users in 'user_ids'
+ user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
+ user_ids, from_key
+ )
- if not to_check:
+ if not user_ids_to_check:
return set()
def _get_users_whose_devices_changed_txn(txn):
changes = set()
- sql = """
+ stream_id_where_clause = "stream_id > ?"
+ sql_args = [from_key]
+
+ if to_key:
+ stream_id_where_clause += " AND stream_id <= ?"
+ sql_args.append(to_key)
+
+ sql = f"""
SELECT DISTINCT user_id FROM device_lists_stream
- WHERE stream_id > ?
+ WHERE {stream_id_where_clause}
AND
"""
- for chunk in batch_iter(to_check, 100):
+ # Query device changes with a batch of users at a time
+ for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
)
- txn.execute(sql + clause, (from_key,) + tuple(args))
+ txn.execute(sql + clause, sql_args + args)
changes.update(user_id for user_id, in txn)
return changes
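
The reworked get_users_whose_devices_changed treats from_key as exclusive and the optional to_key as inclusive, only adding the upper bound when one is supplied. A small standalone sketch of that clause construction and the arguments it produces:

from typing import List, Optional, Tuple

def build_stream_id_clause(from_key: int, to_key: Optional[int] = None) -> Tuple[str, List[int]]:
    # from_key is exclusive, to_key (when given) is inclusive.
    clause = "stream_id > ?"
    args = [from_key]
    if to_key:
        clause += " AND stream_id <= ?"
        args.append(to_key)
    return clause, args

# build_stream_id_clause(10)      -> ("stream_id > ?", [10])
# build_stream_id_clause(10, 20)  -> ("stream_id > ? AND stream_id <= ?", [10, 20])
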
@@ -788,6 +860,7 @@ class DeviceWorkerStore(SQLBaseStore):
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
) AS e
WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
LIMIT ?
"""
@@ -1506,7 +1579,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
async def add_device_change_to_streams(
- self, user_id: str, device_ids: Collection[str], hosts: Collection[str]
+ self,
+ user_id: str,
+ device_ids: Collection[str],
+ hosts: Optional[Collection[str]],
+ room_ids: Collection[str],
) -> Optional[int]:
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
@@ -1515,7 +1592,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: The ID of the user whose device changed.
device_ids: The IDs of any changed devices. If empty, this function will
return None.
- hosts: The remote destinations that should be notified of the change.
+ hosts: The remote destinations that should be notified of the change. If
+ None then the set of hosts has *not* been calculated, and will be
+ calculated later by a background task.
+ room_ids: The rooms that the user is in
Returns:
The maximum stream ID of device list updates that were added to the database, or
@@ -1524,34 +1604,62 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not device_ids:
return None
- async with self._device_list_id_gen.get_next_mult(
- len(device_ids)
- ) as stream_ids:
- await self.db_pool.runInteraction(
- "add_device_change_to_stream",
- self._add_device_change_to_stream_txn,
+ context = get_active_span_text_map()
+
+ def add_device_changes_txn(
+ txn, stream_ids_for_device_change, stream_ids_for_outbound_pokes
+ ):
+ self._add_device_change_to_stream_txn(
+ txn,
user_id,
device_ids,
- stream_ids,
+ stream_ids_for_device_change,
)
- if not hosts:
- return stream_ids[-1]
+ self._add_device_outbound_room_poke_txn(
+ txn,
+ user_id,
+ device_ids,
+ room_ids,
+ stream_ids_for_device_change,
+ context,
+ hosts_have_been_calculated=hosts is not None,
+ )
- context = get_active_span_text_map()
- async with self._device_list_id_gen.get_next_mult(
- len(hosts) * len(device_ids)
- ) as stream_ids:
- await self.db_pool.runInteraction(
- "add_device_outbound_poke_to_stream",
- self._add_device_outbound_poke_to_stream_txn,
+ # If the set of hosts to send to has not been calculated yet (and so
+ # `hosts` is None) or there are no `hosts` to send to, then skip
+ # trying to persist them to the DB.
+ if not hosts:
+ return
+
+ self._add_device_outbound_poke_to_stream_txn(
+ txn,
user_id,
device_ids,
hosts,
- stream_ids,
+ stream_ids_for_outbound_pokes,
context,
)
+ # `device_lists_stream` wants a stream ID per device update.
+ num_stream_ids = len(device_ids)
+
+ if hosts:
+ # `device_lists_outbound_pokes` wants a different stream ID for
+ # each row, which is a row per host per device update.
+ num_stream_ids += len(hosts) * len(device_ids)
+
+ async with self._device_list_id_gen.get_next_mult(num_stream_ids) as stream_ids:
+ stream_ids_for_device_change = stream_ids[: len(device_ids)]
+ stream_ids_for_outbound_pokes = stream_ids[len(device_ids) :]
+
+ await self.db_pool.runInteraction(
+ "add_device_change_to_stream",
+ add_device_changes_txn,
+ stream_ids_for_device_change,
+ stream_ids_for_outbound_pokes,
+ )
+
return stream_ids[-1]
def _add_device_change_to_stream_txn(
@@ -1595,7 +1703,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str,
device_ids: Iterable[str],
hosts: Collection[str],
- stream_ids: List[str],
+ stream_ids: List[int],
context: Dict[str, str],
) -> None:
for host in hosts:
@@ -1606,8 +1714,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
now = self._clock.time_msec()
- next_stream_id = iter(stream_ids)
+ stream_id_iterator = iter(stream_ids)
+ encoded_context = json_encoder.encode(context)
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
@@ -1623,16 +1732,146 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
values=[
(
destination,
- next(next_stream_id),
+ next(stream_id_iterator),
user_id,
device_id,
False,
now,
- json_encoder.encode(context)
- if whitelisted_homeserver(destination)
- else "{}",
+ encoded_context if whitelisted_homeserver(destination) else "{}",
)
for destination in hosts
for device_id in device_ids
],
)
+
+ def _add_device_outbound_room_poke_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_ids: Iterable[str],
+ room_ids: Collection[str],
+ stream_ids: List[str],
+ context: Dict[str, str],
+ hosts_have_been_calculated: bool,
+ ) -> None:
+ """Record the user in the room has updated their device.
+
+ Args:
+ hosts_have_been_calculated: True if `device_lists_outbound_pokes`
+ has been updated already with the updates.
+ """
+
+ # We only need to convert to outbound pokes if they are our user.
+ converted_to_destinations = (
+ hosts_have_been_calculated or not self.hs.is_mine_id(user_id)
+ )
+
+ encoded_context = json_encoder.encode(context)
+
+ # The `device_lists_changes_in_room.stream_id` column matches the
+ # corresponding `stream_id` of the update in the `device_lists_stream`
+ # table, i.e. all rows persisted for the same device update will have
+ # the same `stream_id` (but different room IDs).
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="device_lists_changes_in_room",
+ keys=(
+ "user_id",
+ "device_id",
+ "room_id",
+ "stream_id",
+ "converted_to_destinations",
+ "opentracing_context",
+ ),
+ values=[
+ (
+ user_id,
+ device_id,
+ room_id,
+ stream_id,
+ converted_to_destinations,
+ encoded_context,
+ )
+ for room_id in room_ids
+ for device_id, stream_id in zip(device_ids, stream_ids)
+ ],
+ )
+
+ async def get_uncoverted_outbound_room_pokes(
+ self, limit: int = 10
+ ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
+ """Get device list changes by room that have not yet been handled and
+ written to `device_lists_outbound_pokes`.
+
+ Returns:
+ A list of user ID, device ID, room ID, stream ID and optional opentracing context.
+ """
+
+ sql = """
+ SELECT user_id, device_id, room_id, stream_id, opentracing_context
+ FROM device_lists_changes_in_room
+ WHERE NOT converted_to_destinations
+ ORDER BY stream_id
+ LIMIT ?
+ """
+
+ def get_uncoverted_outbound_room_pokes_txn(txn):
+ txn.execute(sql, (limit,))
+ return txn.fetchall()
+
+ return await self.db_pool.runInteraction(
+ "get_uncoverted_outbound_room_pokes", get_uncoverted_outbound_room_pokes_txn
+ )
+
+ async def add_device_list_outbound_pokes(
+ self,
+ user_id: str,
+ device_id: str,
+ room_id: str,
+ stream_id: int,
+ hosts: Collection[str],
+ context: Optional[Dict[str, str]],
+ ) -> None:
+ """Queue the device update to be sent to the given set of hosts,
+ calculated from the room ID.
+
+ Marks the associated row in `device_lists_changes_in_room` as handled.
+ """
+
+ def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]):
+ if hosts:
+ self._add_device_outbound_poke_to_stream_txn(
+ txn,
+ user_id=user_id,
+ device_ids=[device_id],
+ hosts=hosts,
+ stream_ids=stream_ids,
+ context=context,
+ )
+
+ self.db_pool.simple_update_txn(
+ txn,
+ table="device_lists_changes_in_room",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "stream_id": stream_id,
+ "room_id": room_id,
+ },
+ updatevalues={"converted_to_destinations": True},
+ )
+
+ if not hosts:
+ # If there are no hosts then we don't try and generate stream IDs.
+ return await self.db_pool.runInteraction(
+ "add_device_list_outbound_pokes",
+ add_device_list_outbound_pokes_txn,
+ [],
+ )
+
+ async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
+ return await self.db_pool.runInteraction(
+ "add_device_list_outbound_pokes",
+ add_device_list_outbound_pokes_txn,
+ stream_ids,
+ )
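
One detail worth spelling out: add_device_change_to_streams now reserves all of its stream IDs in a single get_next_mult call, one per device for device_lists_stream plus one per (host, device) pair for the outbound pokes, then slices the block. A worked sketch of that arithmetic (illustrative only):

from typing import Collection, List, Optional, Tuple

def split_device_change_stream_ids(
    stream_ids: List[int],
    device_ids: Collection[str],
    hosts: Optional[Collection[str]],
) -> Tuple[List[int], List[int]]:
    # e.g. 2 devices and 3 hosts -> 2 + 3 * 2 = 8 IDs reserved up front.
    expected = len(device_ids) + (len(hosts) * len(device_ids) if hosts else 0)
    assert len(stream_ids) == expected

    for_device_change = stream_ids[: len(device_ids)]
    for_outbound_pokes = stream_ids[len(device_ids):]
    return for_device_change, for_outbound_pokes
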
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index d2532431..3fcd5f5b 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -197,12 +197,10 @@ class PersistEventsStore:
)
persist_event_counter.inc(len(events_and_contexts))
- if stream < 0:
- # backfilled events have negative stream orderings, so we don't
- # want to set the event_persisted_position to that.
- synapse.metrics.event_persisted_position.set(
- events_and_contexts[-1][0].internal_metadata.stream_ordering
- )
+ if not use_negative_stream_ordering:
+ # we don't want to set the event_persisted_position to a negative
+ # stream_ordering.
+ synapse.metrics.event_persisted_position.set(stream)
for event, context in events_and_contexts:
if context.app_service:
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 59454a47..a60e3f4f 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -22,7 +22,6 @@ from typing import (
Dict,
Iterable,
List,
- NoReturn,
Optional,
Set,
Tuple,
@@ -1330,10 +1329,9 @@ class EventsWorkerStore(SQLBaseStore):
return results
@cached(max_entries=100000, tree=True)
- async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
- # this only exists for the benefit of the @cachedList descriptor on
- # _have_seen_events_dict
- raise NotImplementedError()
+ async def have_seen_event(self, room_id: str, event_id: str) -> bool:
+ res = await self._have_seen_events_dict(((room_id, event_id),))
+ return res[(room_id, event_id)]
def _get_current_state_event_counts_txn(
self, txn: LoggingTransaction, room_id: str
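
have_seen_event stops being a NotImplementedError stub that existed only to give the @cachedList on _have_seen_events_dict a @cached counterpart; it now answers single lookups by delegating to the batch method. A loose sketch of that pairing (the batch body is elided):

@cached(max_entries=100000, tree=True)
async def have_seen_event(self, room_id: str, event_id: str) -> bool:
    # A single lookup is just a one-element batch against the list-cached variant.
    res = await self._have_seen_events_dict(((room_id, event_id),))
    return res[(room_id, event_id)]

@cachedList(cached_method_name="have_seen_event", list_name="keys")
async def _have_seen_events_dict(self, keys):
    # Returns {(room_id, event_id): bool} for every requested key.
    ...
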
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 21662296..4f1c22c7 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -15,7 +15,6 @@ import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
@@ -36,7 +35,7 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 60 * 60 * 1000
-class MonthlyActiveUsersWorkerStore(SQLBaseStore):
+class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -47,9 +46,30 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
self._clock = hs.get_clock()
self.hs = hs
+ if hs.config.redis.redis_enabled:
+ # If we're using Redis, we can shift this update process off to
+ # the background worker
+ self._update_on_this_worker = hs.config.worker.run_background_tasks
+ else:
+ # If we're NOT using Redis, this must be handled by the master
+ self._update_on_this_worker = hs.get_instance_name() == "master"
+
self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau
self._max_mau_value = hs.config.server.max_mau_value
+ self._mau_stats_only = hs.config.server.mau_stats_only
+
+ if self._update_on_this_worker:
+ # Do not add more reserved users than the total allowable number
+ self.db_pool.new_transaction(
+ db_conn,
+ "initialise_mau_threepids",
+ [],
+ [],
+ self._initialise_reserved_users,
+ hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
+ )
+
@cached(num_args=0)
async def get_monthly_active_count(self) -> int:
"""Generates current count of monthly active users
@@ -222,28 +242,6 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"reap_monthly_active_users", _reap_users, reserved_users
)
-
-class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore):
- def __init__(
- self,
- database: DatabasePool,
- db_conn: LoggingDatabaseConnection,
- hs: "HomeServer",
- ):
- super().__init__(database, db_conn, hs)
-
- self._mau_stats_only = hs.config.server.mau_stats_only
-
- # Do not add more reserved users than the total allowable number
- self.db_pool.new_transaction(
- db_conn,
- "initialise_mau_threepids",
- [],
- [],
- self._initialise_reserved_users,
- hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
- )
-
def _initialise_reserved_users(
self, txn: LoggingTransaction, threepids: List[dict]
) -> None:
@@ -254,6 +252,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
txn:
threepids: List of threepid dicts to reserve
"""
+ assert (
+ self._update_on_this_worker
+ ), "This worker is not designated to update MAUs"
# XXX what is this function trying to achieve? It upserts into
# monthly_active_users for each *registered* reserved mau user, but why?
@@ -287,6 +288,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
Args:
user_id: user to add/update
"""
+ assert (
+ self._update_on_this_worker
+ ), "This worker is not designated to update MAUs"
+
# Support user never to be included in MAU stats. Note I can't easily call this
# from upsert_monthly_active_user_txn because then I need a _txn form of
# is_support_user which is complicated because I want to cache the result.
@@ -322,6 +327,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
txn (cursor):
user_id (str): user to add/update
"""
+ assert (
+ self._update_on_this_worker
+ ), "This worker is not designated to update MAUs"
# Am consciously deciding to lock the table on the basis that it ought
# never be a big table and alternative approaches (batching multiple
@@ -349,6 +357,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
Args:
user_id(str): the user_id to query
"""
+ assert (
+ self._update_on_this_worker
+ ), "This worker is not designated to update MAUs"
+
if self._limit_usage_by_mau or self._mau_stats_only:
# Trial users and guests should not be included as part of MAU group
is_guest = await self.is_guest(user_id) # type: ignore[attr-defined]
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index e6f97aee..332e901d 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -98,8 +98,19 @@ class ReceiptsWorkerStore(SQLBaseStore):
super().__init__(database, db_conn, hs)
+ max_receipts_stream_id = self.get_max_receipt_stream_id()
+ receipts_stream_prefill, min_receipts_stream_id = self.db_pool.get_cache_dict(
+ db_conn,
+ "receipts_linearized",
+ entity_column="room_id",
+ stream_column="stream_id",
+ max_value=max_receipts_stream_id,
+ limit=10000,
+ )
self._receipts_stream_cache = StreamChangeCache(
- "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
+ "ReceiptsRoomChangeCache",
+ min_receipts_stream_id,
+ prefilled_cache=receipts_stream_prefill,
)
def get_max_receipt_stream_id(self) -> int:
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 7f3d190e..d43163c2 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -34,7 +34,7 @@ from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import IdGenerator
from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import UserID, UserInfo
+from synapse.types import JsonDict, UserID, UserInfo
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
@@ -79,7 +79,7 @@ class TokenLookupResult:
# Make the token owner default to the user ID, which is the common case.
@token_owner.default
- def _default_token_owner(self):
+ def _default_token_owner(self) -> str:
return self.user_id
@@ -299,7 +299,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
the account.
"""
- def set_account_validity_for_user_txn(txn):
+ def set_account_validity_for_user_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_txn(
txn=txn,
table="account_validity",
@@ -385,23 +385,25 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="get_renewal_token_for_user",
)
- async def get_users_expiring_soon(self) -> List[Dict[str, Any]]:
+ async def get_users_expiring_soon(self) -> List[Tuple[str, int]]:
"""Selects users whose account will expire in the [now, now + renew_at] time
window (see configuration for account_validity for information on what renew_at
refers to).
Returns:
- A list of dictionaries, each with a user ID and expiration time (in milliseconds).
+ A list of tuples, each with a user ID and expiration time (in milliseconds).
"""
- def select_users_txn(txn, now_ms, renew_at):
+ def select_users_txn(
+ txn: LoggingTransaction, now_ms: int, renew_at: int
+ ) -> List[Tuple[str, int]]:
sql = (
"SELECT user_id, expiration_ts_ms FROM account_validity"
" WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
)
values = [False, now_ms, renew_at]
txn.execute(sql, values)
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[Tuple[str, int]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_users_expiring_soon",
@@ -466,7 +468,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
admin: true iff the user is to be a server admin, false otherwise.
"""
- def set_server_admin_txn(txn):
+ def set_server_admin_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn(
txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0}
)
@@ -515,7 +517,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
user_type: type of the user or None for a user without a type.
"""
- def set_user_type_txn(txn):
+ def set_user_type_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn(
txn, "users", {"name": user.to_string()}, {"user_type": user_type}
)
@@ -525,7 +527,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_user_type", set_user_type_txn)
- def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
+ def _query_for_auth(
+ self, txn: LoggingTransaction, token: str
+ ) -> Optional[TokenLookupResult]:
sql = """
SELECT users.name as user_id,
users.is_guest,
@@ -582,7 +586,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"is_support_user", self.is_support_user_txn, user_id
)
- def is_real_user_txn(self, txn, user_id):
+ def is_real_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool:
res = self.db_pool.simple_select_one_onecol_txn(
txn=txn,
table="users",
@@ -592,7 +596,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
return res is None
- def is_support_user_txn(self, txn, user_id):
+ def is_support_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool:
res = self.db_pool.simple_select_one_onecol_txn(
txn=txn,
table="users",
@@ -609,10 +613,11 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
A mapping of user_id -> password_hash.
"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> Dict[str, str]:
sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)"
txn.execute(sql, (user_id,))
- return dict(txn)
+ result = cast(List[Tuple[str, str]], txn.fetchall())
+ return dict(result)
return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
@@ -734,7 +739,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def _replace_user_external_id_txn(
txn: LoggingTransaction,
- ):
+ ) -> None:
_remove_user_external_ids_txn(txn, user_id)
for auth_provider, external_id in record_external_ids:
@@ -790,10 +795,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
return [(r["auth_provider"], r["external_id"]) for r in res]
- async def count_all_users(self):
+ async def count_all_users(self) -> int:
"""Counts all users registered on the homeserver."""
- def _count_users(txn):
+ def _count_users(txn: LoggingTransaction) -> int:
txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.db_pool.cursor_to_dict(txn)
if rows:
@@ -810,7 +815,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
who registered on the homeserver in the past 24 hours
"""
- def _count_daily_user_type(txn):
+ def _count_daily_user_type(txn: LoggingTransaction) -> Dict[str, int]:
yesterday = int(self._clock.time()) - (60 * 60 * 24)
sql = """
@@ -835,23 +840,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"count_daily_user_type", _count_daily_user_type
)
- async def count_nonbridged_users(self):
- def _count_users(txn):
+ async def count_nonbridged_users(self) -> int:
+ def _count_users(txn: LoggingTransaction) -> int:
txn.execute(
"""
SELECT COUNT(*) FROM users
WHERE appservice_id IS NULL
"""
)
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_users", _count_users)
- async def count_real_users(self):
+ async def count_real_users(self) -> int:
"""Counts all users without a special user_type registered on the homeserver."""
- def _count_users(txn):
+ def _count_users(txn: LoggingTransaction) -> int:
txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
rows = self.db_pool.cursor_to_dict(txn)
if rows:
@@ -888,7 +893,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return user_id
def get_user_id_by_threepid_txn(
- self, txn, medium: str, address: str
+ self, txn: LoggingTransaction, medium: str, address: str
) -> Optional[str]:
"""Returns user id from threepid
@@ -925,7 +930,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
)
- async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]:
+ async def user_get_threepids(self, user_id: str) -> List[Dict[str, Any]]:
return await self.db_pool.simple_select_list(
"user_threepids",
{"user_id": user_id},
@@ -957,7 +962,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
async def add_user_bound_threepid(
self, user_id: str, medium: str, address: str, id_server: str
- ):
+ ) -> None:
"""The server proxied a bind request to the given identity server on
behalf of the given user. We need to remember this in case the user
asks us to unbind the threepid.
@@ -1116,7 +1121,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
assert address or sid
- def get_threepid_validation_session_txn(txn):
+ def get_threepid_validation_session_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[Dict[str, Any]]:
sql = """
SELECT address, session_id, medium, client_secret,
last_send_attempt, validated_at
@@ -1150,7 +1157,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
session_id: The ID of the session to delete
"""
- def delete_threepid_session_txn(txn):
+ def delete_threepid_session_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn,
table="threepid_validation_token",
@@ -1170,7 +1177,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
async def cull_expired_threepid_validation_tokens(self) -> None:
"""Remove threepid validation tokens with expiry dates that have passed"""
- def cull_expired_threepid_validation_tokens_txn(txn, ts):
+ def cull_expired_threepid_validation_tokens_txn(
+ txn: LoggingTransaction, ts: int
+ ) -> None:
sql = """
DELETE FROM threepid_validation_token WHERE
expires < ?
@@ -1184,13 +1193,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
@wrap_as_background_process("account_validity_set_expiration_dates")
- async def _set_expiration_date_when_missing(self):
+ async def _set_expiration_date_when_missing(self) -> None:
"""
Retrieves the list of registered users that don't have an expiration date, and
adds an expiration date for each of them.
"""
- def select_users_with_no_expiration_date_txn(txn):
+ def select_users_with_no_expiration_date_txn(txn: LoggingTransaction) -> None:
"""Retrieves the list of registered users with no expiration date from the
database, filtering out deactivated users.
"""
@@ -1213,7 +1222,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
select_users_with_no_expiration_date_txn,
)
- def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+ def set_expiration_date_for_user_txn(
+ self, txn: LoggingTransaction, user_id: str, use_delta: bool = False
+ ) -> None:
"""Sets an expiration date to the account with the given user ID.
Args:
@@ -1344,7 +1355,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
token: The registration token pending use
"""
- def _set_registration_token_pending_txn(txn):
+ def _set_registration_token_pending_txn(txn: LoggingTransaction) -> None:
pending = self.db_pool.simple_select_one_onecol_txn(
txn,
"registration_tokens",
@@ -1358,7 +1369,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
updatevalues={"pending": pending + 1},
)
- return await self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"set_registration_token_pending", _set_registration_token_pending_txn
)
@@ -1372,7 +1383,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
token: The registration token to be 'used'
"""
- def _use_registration_token_txn(txn):
+ def _use_registration_token_txn(txn: LoggingTransaction) -> None:
# Normally, res is Optional[Dict[str, Any]].
# Override type because the return type is only optional if
# allow_none is True, and we don't want mypy throwing errors
@@ -1398,7 +1409,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
},
)
- return await self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"use_registration_token", _use_registration_token_txn
)
@@ -1416,7 +1427,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
A list of dicts, each containing details of a token.
"""
- def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]):
+ def select_registration_tokens_txn(
+ txn: LoggingTransaction, now: int, valid: Optional[bool]
+ ) -> List[Dict[str, Any]]:
if valid is None:
# Return all tokens regardless of validity
txn.execute("SELECT * FROM registration_tokens")
@@ -1523,7 +1536,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Whether the row was inserted or not.
"""
- def _create_registration_token_txn(txn):
+ def _create_registration_token_txn(txn: LoggingTransaction) -> bool:
row = self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
@@ -1570,7 +1583,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
A dict with all info about the token, or None if token doesn't exist.
"""
- def _update_registration_token_txn(txn):
+ def _update_registration_token_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[Dict[str, Any]]:
try:
self.db_pool.simple_update_one_txn(
txn,
@@ -1651,7 +1666,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
) -> Optional[RefreshTokenLookupResult]:
"""Lookup a refresh token with hints about its validity."""
- def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]:
+ def _lookup_refresh_token_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[RefreshTokenLookupResult]:
txn.execute(
"""
SELECT
@@ -1745,6 +1762,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"replace_refresh_token", _replace_refresh_token_txn
)
+ @cached()
+ async def is_guest(self, user_id: str) -> bool:
+ res = await self.db_pool.simple_select_one_onecol(
+ table="users",
+ keyvalues={"name": user_id},
+ retcol="is_guest",
+ allow_none=True,
+ desc="is_guest",
+ )
+
+ return res if res else False
+
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(
@@ -1795,14 +1824,18 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
unique=False,
)
- async def _background_update_set_deactivated_flag(self, progress, batch_size):
+ async def _background_update_set_deactivated_flag(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
for each of them.
"""
last_user = progress.get("user_id", "")
- def _background_update_set_deactivated_flag_txn(txn):
+ def _background_update_set_deactivated_flag_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[bool, int]:
txn.execute(
"""
SELECT
@@ -1874,7 +1907,9 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
deactivated,
)
- def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool):
+ def set_user_deactivated_status_txn(
+ self, txn: LoggingTransaction, user_id: str, deactivated: bool
+ ) -> None:
self.db_pool.simple_update_one_txn(
txn=txn,
table="users",
@@ -1887,18 +1922,6 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
- @cached()
- async def is_guest(self, user_id: str) -> bool:
- res = await self.db_pool.simple_select_one_onecol(
- table="users",
- keyvalues={"name": user_id},
- retcol="is_guest",
- allow_none=True,
- desc="is_guest",
- )
-
- return res if res else False
-
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
def __init__(
@@ -2005,7 +2028,9 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
return next_id
- def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
+ def _set_device_for_access_token_txn(
+ self, txn: LoggingTransaction, token: str, device_id: str
+ ) -> str:
old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn, "access_tokens", {"token": token}, "device_id"
)
@@ -2084,7 +2109,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
def _register_user(
self,
- txn,
+ txn: LoggingTransaction,
user_id: str,
password_hash: Optional[str],
was_guest: bool,
@@ -2094,7 +2119,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
admin: bool,
user_type: Optional[str],
shadow_banned: bool,
- ):
+ ) -> None:
user_id_obj = UserID.from_string(user_id)
now = int(self._clock.time())
@@ -2181,7 +2206,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
pointless. Use flush_user separately.
"""
- def user_set_password_hash_txn(txn):
+ def user_set_password_hash_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn(
txn, "users", {"name": user_id}, {"password_hash": password_hash}
)
@@ -2204,7 +2229,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
StoreError(404) if user not found
"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn(
txn,
table="users",
@@ -2229,7 +2254,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
StoreError(404) if user not found
"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn(
txn,
table="users",
@@ -2259,7 +2284,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
A tuple of (token, token id, device id) for each of the deleted tokens
"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> List[Tuple[str, int, Optional[str]]]:
keyvalues = {"user_id": user_id}
if device_id is not None:
keyvalues["device_id"] = device_id
@@ -2301,7 +2326,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
return await self.db_pool.runInteraction("user_delete_access_tokens", f)
async def delete_access_token(self, access_token: str) -> None:
- def f(txn):
+ def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token}
)
@@ -2313,7 +2338,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
await self.db_pool.runInteraction("delete_access_token", f)
async def delete_refresh_token(self, refresh_token: str) -> None:
- def f(txn):
+ def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_one_txn(
txn, table="refresh_tokens", keyvalues={"token": refresh_token}
)
@@ -2353,7 +2378,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"""
# Insert everything into a transaction in order to run atomically
- def validate_threepid_session_txn(txn):
+ def validate_threepid_session_txn(txn: LoggingTransaction) -> Optional[str]:
row = self.db_pool.simple_select_one_txn(
txn,
table="threepid_validation_session",
@@ -2450,7 +2475,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
longer be valid
"""
- def start_or_continue_validation_session_txn(txn):
+ def start_or_continue_validation_session_txn(txn: LoggingTransaction) -> None:
# Create or update a validation session
self.db_pool.simple_upsert_txn(
txn,
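
A recurring pattern in the registration.py hunks above is annotating the inner transaction callback with LoggingTransaction and an explicit return type before handing it to runInteraction. A minimal sketch of that shape, using the same db_pool helpers that appear in this diff and assuming the usual "from synapse.storage.database import LoggingTransaction" import (the "widgets" table and method name are illustrative, not part of Synapse):

    async def delete_widget(self, widget_id: str) -> None:
        def delete_widget_txn(txn: LoggingTransaction) -> None:
            # The annotations let mypy check both the row handling inside
            # the transaction and the value ultimately returned.
            self.db_pool.simple_delete_one_txn(
                txn, table="widgets", keyvalues={"id": widget_id}
            )

        await self.db_pool.runInteraction("delete_widget", delete_widget_txn)
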
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index b2295fd5..407158ce 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -17,6 +17,7 @@ from typing import (
TYPE_CHECKING,
Collection,
Dict,
+ FrozenSet,
Iterable,
List,
Optional,
@@ -39,8 +40,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
-from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
-from synapse.types import RoomStreamToken, StreamToken
+from synapse.types import JsonDict, RoomStreamToken, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
@@ -49,6 +49,19 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _RelatedEvent:
+ """
+ Contains enough information about a related event in order to properly filter
+ events from ignored users.
+ """
+
+ # The event ID of the related event.
+ event_id: str
+ # The sender of the related event.
+ sender: str
+
+
class RelationsWorkerStore(SQLBaseStore):
def __init__(
self,
@@ -73,7 +86,7 @@ class RelationsWorkerStore(SQLBaseStore):
direction: str = "b",
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
- ) -> PaginationChunk:
+ ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
"""Get a list of relations for an event, ordered by topological ordering.
Args:
@@ -90,8 +103,10 @@ class RelationsWorkerStore(SQLBaseStore):
to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
- List of event IDs that match relations requested. The rows are of
- the form `{"event_id": "..."}`.
+ A tuple of:
+ A list of related event IDs & their senders.
+
+ The next stream token, if one exists.
"""
# We don't use `event_id`, it's there so that we can cache based on
# it. The `event_id` must match the `event.event_id`.
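
Returning senders alongside event IDs is what lets callers strip relations sent by ignored users before bundling aggregations. A rough caller-side sketch, with the argument list abbreviated and the ignored_users set assumed to come from elsewhere (neither is part of this diff):

    related_events, next_token = await store.get_relations_for_event(
        event_id, event, room_id  # ...other filters elided
    )

    # Keep only relations whose sender the requesting user has not ignored.
    visible_event_ids = [
        r.event_id for r in related_events if r.sender not in ignored_users
    ]
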
@@ -132,7 +147,7 @@ class RelationsWorkerStore(SQLBaseStore):
order = "ASC"
sql = """
- SELECT event_id, relation_type, topological_ordering, stream_ordering
+ SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
FROM event_relations
INNER JOIN events USING (event_id)
WHERE %s
@@ -146,7 +161,7 @@ class RelationsWorkerStore(SQLBaseStore):
def _get_recent_references_for_event_txn(
txn: LoggingTransaction,
- ) -> PaginationChunk:
+ ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
txn.execute(sql, where_args + [limit + 1])
last_topo_id = None
@@ -156,9 +171,9 @@ class RelationsWorkerStore(SQLBaseStore):
# Do not include edits for redacted events as they leak event
# content.
if not is_redacted or row[1] != RelationTypes.REPLACE:
- events.append({"event_id": row[0]})
- last_topo_id = row[2]
- last_stream_id = row[3]
+ events.append(_RelatedEvent(row[0], row[2]))
+ last_topo_id = row[3]
+ last_stream_id = row[4]
# If there are more events, generate the next pagination key.
next_token = None
@@ -179,9 +194,7 @@ class RelationsWorkerStore(SQLBaseStore):
groups_key=0,
)
- return PaginationChunk(
- chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token
- )
+ return events[:limit], next_token
return await self.db_pool.runInteraction(
"get_recent_references_for_event", _get_recent_references_for_event_txn
@@ -252,15 +265,8 @@ class RelationsWorkerStore(SQLBaseStore):
@cached(tree=True)
async def get_aggregation_groups_for_event(
- self,
- event_id: str,
- room_id: str,
- event_type: Optional[str] = None,
- limit: int = 5,
- direction: str = "b",
- from_token: Optional[AggregationPaginationToken] = None,
- to_token: Optional[AggregationPaginationToken] = None,
- ) -> PaginationChunk:
+ self, event_id: str, room_id: str, limit: int = 5
+ ) -> List[JsonDict]:
"""Get a list of annotations on the event, grouped by event type and
aggregation key, sorted by count.
@@ -270,82 +276,96 @@ class RelationsWorkerStore(SQLBaseStore):
Args:
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
- event_type: Only fetch events with this event type, if given.
limit: Only fetch the `limit` groups.
- direction: Whether to fetch the highest count first (`"b"`) or
- the lowest count first (`"f"`).
- from_token: Fetch rows from the given token, or from the start if None.
- to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
List of groups of annotations that match. Each row is a dict with
`type`, `key` and `count` fields.
"""
- where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"]
- where_args: List[Union[str, int]] = [
+ args = [
event_id,
room_id,
RelationTypes.ANNOTATION,
+ limit,
]
- if event_type:
- where_clause.append("type = ?")
- where_args.append(event_type)
+ sql = """
+ SELECT type, aggregation_key, COUNT(DISTINCT sender)
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE relates_to_id = ? AND room_id = ? AND relation_type = ?
+ GROUP BY relation_type, type, aggregation_key
+ ORDER BY COUNT(*) DESC
+ LIMIT ?
+ """
- having_clause = generate_pagination_where_clause(
- direction=direction,
- column_names=("COUNT(*)", "MAX(stream_ordering)"),
- from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
- to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
- engine=self.database_engine,
+ def _get_aggregation_groups_for_event_txn(
+ txn: LoggingTransaction,
+ ) -> List[JsonDict]:
+ txn.execute(sql, args)
+
+ return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn]
+
+ return await self.db_pool.runInteraction(
+ "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
)
- if direction == "b":
- order = "DESC"
- else:
- order = "ASC"
+ async def get_aggregation_groups_for_users(
+ self,
+ event_id: str,
+ room_id: str,
+ limit: int,
+ users: FrozenSet[str] = frozenset(),
+ ) -> Dict[Tuple[str, str], int]:
+ """Fetch the partial aggregations for an event for specific users.
- if having_clause:
- having_clause = "HAVING " + having_clause
- else:
- having_clause = ""
+ This is used, in conjunction with get_aggregation_groups_for_event, to
+ remove information from the results for ignored users.
- sql = """
- SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering)
+ Args:
+ event_id: Fetch events that relate to this event ID.
+ room_id: The room the event belongs to.
+ limit: Only fetch the `limit` groups.
+ users: The users to fetch information for.
+
+ Returns:
+ A map of (event type, aggregation key) to a count of users.
+ """
+
+ if not users:
+ return {}
+
+ args: List[Union[str, int]] = [
+ event_id,
+ room_id,
+ RelationTypes.ANNOTATION,
+ ]
+
+ users_sql, users_args = make_in_list_sql_clause(
+ self.database_engine, "sender", users
+ )
+ args.extend(users_args)
+
+ sql = f"""
+ SELECT type, aggregation_key, COUNT(DISTINCT sender)
FROM event_relations
INNER JOIN events USING (event_id)
- WHERE {where_clause}
+ WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql}
GROUP BY relation_type, type, aggregation_key
- {having_clause}
- ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
+ ORDER BY COUNT(*) DESC
LIMIT ?
- """.format(
- where_clause=" AND ".join(where_clause),
- order=order,
- having_clause=having_clause,
- )
+ """
- def _get_aggregation_groups_for_event_txn(
+ def _get_aggregation_groups_for_users_txn(
txn: LoggingTransaction,
- ) -> PaginationChunk:
- txn.execute(sql, where_args + [limit + 1])
+ ) -> Dict[Tuple[str, str], int]:
+ txn.execute(sql, args + [limit])
- next_batch = None
- events = []
- for row in txn:
- events.append({"type": row[0], "key": row[1], "count": row[2]})
- next_batch = AggregationPaginationToken(row[2], row[3])
-
- if len(events) <= limit:
- next_batch = None
-
- return PaginationChunk(
- chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
- )
+ return {(row[0], row[1]): row[2] for row in txn}
return await self.db_pool.runInteraction(
- "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
+ "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn
)
@cached()
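
As the docstring says, get_aggregation_groups_for_users is meant to be combined with the cacheable get_aggregation_groups_for_event result: counts contributed by ignored senders are subtracted from the overall groups. A hedged caller-side sketch (the ignored_users set is assumed, not part of this diff):

    aggregations = await store.get_aggregation_groups_for_event(event_id, room_id)
    ignored = await store.get_aggregation_groups_for_users(
        event_id, room_id, limit=5, users=ignored_users
    )

    filtered = []
    for group in aggregations:
        # Subtract annotations from ignored users; drop groups that empty out.
        count = group["count"] - ignored.get((group["type"], group["key"]), 0)
        if count > 0:
            filtered.append({**group, "count": count})
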
@@ -574,6 +594,67 @@ class RelationsWorkerStore(SQLBaseStore):
return summaries
+ async def get_threaded_messages_per_user(
+ self,
+ event_ids: Collection[str],
+ users: FrozenSet[str] = frozenset(),
+ ) -> Dict[Tuple[str, str], int]:
+ """Get the number of threaded replies for a set of users.
+
+ This is used, in conjunction with get_thread_summaries, to calculate an
+ accurate count of the replies to a thread by subtracting the replies of ignored users.
+
+ Args:
+ event_ids: The events to check for threaded replies.
+ users: The users whose reply counts should be calculated.
+
+ Returns:
+ A map of (event_id, sender) to the number of replies from that sender.
+ """
+ if not users:
+ return {}
+
+ # Fetch the number of threaded replies.
+ sql = """
+ SELECT parent.event_id, child.sender, COUNT(child.event_id) FROM events AS child
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS parent ON
+ parent.event_id = relates_to_id
+ AND parent.room_id = child.room_id
+ WHERE
+ %s
+ AND %s
+ AND %s
+ GROUP BY parent.event_id, child.sender
+ """
+
+ def _get_threaded_messages_per_user_txn(
+ txn: LoggingTransaction,
+ ) -> Dict[Tuple[str, str], int]:
+ users_sql, users_args = make_in_list_sql_clause(
+ self.database_engine, "child.sender", users
+ )
+ events_clause, events_args = make_in_list_sql_clause(
+ txn.database_engine, "relates_to_id", event_ids
+ )
+
+ if self._msc3440_enabled:
+ relations_clause = "(relation_type = ? OR relation_type = ?)"
+ relations_args = [RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD]
+ else:
+ relations_clause = "relation_type = ?"
+ relations_args = [RelationTypes.THREAD]
+
+ txn.execute(
+ sql % (users_sql, events_clause, relations_clause),
+ users_args + events_args + relations_args,
+ )
+ return {(row[0], row[1]): row[2] for row in txn}
+
+ return await self.db_pool.runInteraction(
+ "get_threaded_messages_per_user", _get_threaded_messages_per_user_txn
+ )
+
@cached()
def get_thread_participated(self, event_id: str, user_id: str) -> bool:
raise NotImplementedError()
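
get_threaded_messages_per_user plays the same role for threads: per-(thread root, sender) reply counts that a caller can subtract from the totals in get_thread_summaries when the sender is ignored. A rough sketch that relies only on the documented return shape:

    ignored_replies = await store.get_threaded_messages_per_user(
        event_ids, users=ignored_users
    )

    # Collapse the per-sender counts into a per-thread total of ignored
    # replies, ready to subtract from each thread's overall reply count.
    ignored_per_thread: dict = {}
    for (thread_root, _sender), count in ignored_replies.items():
        ignored_per_thread[thread_root] = ignored_per_thread.get(thread_root, 0) + count
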
@@ -661,7 +742,7 @@ class RelationsWorkerStore(SQLBaseStore):
%s;
"""
- def _get_if_events_have_relations(txn) -> List[str]:
+ def _get_if_events_have_relations(txn: LoggingTransaction) -> List[str]:
clauses: List[str] = []
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", parent_ids
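
Several of the new queries above build their IN-lists through make_in_list_sql_clause instead of hand-rolled string formatting. As used in this diff it returns a (clause, args) pair, presumably expanding to "sender = ANY(?)" with a single array argument on PostgreSQL and to "sender IN (?, ?, ...)" on SQLite, which is then spliced into the SQL and the argument list:

    users_sql, users_args = make_in_list_sql_clause(
        self.database_engine, "sender", users
    )
    sql = f"SELECT COUNT(*) FROM event_relations WHERE {users_sql}"
    txn.execute(sql, users_args)
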
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 3248da53..48e83592 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -361,7 +361,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return None
async def get_rooms_for_local_user_where_membership_is(
- self, user_id: str, membership_list: Collection[str]
+ self,
+ user_id: str,
+ membership_list: Collection[str],
+ excluded_rooms: Optional[List[str]] = None,
) -> List[RoomsForUser]:
"""Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
@@ -372,6 +375,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
user_id: The user ID.
membership_list: A list of synapse.api.constants.Membership
values which the user must be in.
+ excluded_rooms: A list of rooms to ignore.
Returns:
The RoomsForUser for each room where the user's membership matches one of the given types.
@@ -386,12 +390,19 @@ class RoomMemberWorkerStore(EventsWorkerStore):
membership_list,
)
- # Now we filter out forgotten rooms
- forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id)
- return [room for room in rooms if room.room_id not in forgotten_rooms]
+ # Now we filter out forgotten and excluded rooms
+ rooms_to_exclude: Set[str] = await self.get_forgotten_rooms_for_user(user_id)
+
+ if excluded_rooms is not None:
+ rooms_to_exclude.update(set(excluded_rooms))
+
+ return [room for room in rooms if room.room_id not in rooms_to_exclude]
def _get_rooms_for_local_user_where_membership_is_txn(
- self, txn, user_id: str, membership_list: List[str]
+ self,
+ txn,
+ user_id: str,
+ membership_list: List[str],
) -> List[RoomsForUser]:
# Paranoia check.
if not self.hs.is_mine_id(user_id):
@@ -877,7 +888,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return frozenset(cache.hosts_to_joined_users)
# Since we'll mutate the cache we need to lock.
- with (await self._joined_host_linearizer.queue(room_id)):
+ async with self._joined_host_linearizer.queue(room_id):
if state_entry.state_group == cache.state_group:
# Same state group, so nothing to do. We've already checked for
# this above, but the cache may have changed while waiting on
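
The new excluded_rooms parameter simply extends the existing forgotten-rooms filter, so a caller can hide additional rooms (for example rooms excluded from sync) without touching the query. A hedged usage sketch with illustrative values:

    rooms = await store.get_rooms_for_local_user_where_membership_is(
        user_id,
        membership_list=(Membership.JOIN, Membership.INVITE),
        excluded_rooms=["!excluded:example.org"],
    )
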
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index 0518b8b9..95148fd2 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -26,7 +26,7 @@ from synapse.util.caches.descriptors import cached, cachedList
class SignatureWorkerStore(EventsWorkerStore):
@cached()
- def get_event_reference_hash(self, event_id):
+ def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]:
# This is a dummy function to allow get_event_reference_hashes
# to use its cache
raise NotImplementedError()
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 28460fd3..ecdc1fdc 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -12,9 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import collections.abc
import logging
-from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
+
+from frozendict import frozendict
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@@ -29,7 +30,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, StateMap
+from synapse.types import JsonDict, JsonMapping, StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
@@ -132,7 +133,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return room_version
- async def get_room_predecessor(self, room_id: str) -> Optional[dict]:
+ async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
"""Get the predecessor of an upgraded room if it exists.
Otherwise return None.
@@ -158,9 +159,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
predecessor = create_event.content.get("predecessor", None)
# Ensure the key is a dictionary
- if not isinstance(predecessor, collections.abc.Mapping):
+ if not isinstance(predecessor, (dict, frozendict)):
return None
+ # The keys must be strings since the data is JSON.
return predecessor
async def get_create_event_for_room(self, room_id: str) -> EventBase:
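
For context, the predecessor dict being validated here comes from the m.room.create content of an upgraded room; per the Matrix spec (not this diff) it is expected to look roughly like:

    predecessor = {
        "room_id": "!old_room:example.org",      # the room this one replaced
        "event_id": "$last_event_in_old_room",   # last known event in the old room
    }
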
@@ -202,7 +204,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
The current state of the room.
"""
- def _get_current_state_ids_txn(txn):
+ def _get_current_state_ids_txn(txn: LoggingTransaction) -> StateMap[str]:
txn.execute(
"""SELECT type, state_key, event_id FROM current_state_events
WHERE room_id = ?
@@ -306,8 +308,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
list_name="event_ids",
num_args=1,
)
- async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict:
- """Returns mapping event_id -> state_group"""
+ async def _get_state_group_for_events(
+ self, event_ids: Collection[str]
+ ) -> Dict[str, int]:
+ """Returns mapping event_id -> state_group.
+
+ Raises:
+ RuntimeError if the state is unknown at any of the given events
+ """
rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
@@ -317,7 +325,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="_get_state_group_for_events",
)
- return {row["event_id"]: row["state_group"] for row in rows}
+ res = {row["event_id"]: row["state_group"] for row in rows}
+ for e in event_ids:
+ if e not in res:
+ raise RuntimeError("No state group for unknown or outlier event %s" % e)
+ return res
async def get_referenced_state_groups(
self, state_groups: Iterable[int]
@@ -521,7 +533,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
)
for user_id in potentially_left_users - joined_users:
- await self.mark_remote_user_device_list_as_unsubscribed(user_id)
+ await self.mark_remote_user_device_list_as_unsubscribed(user_id) # type: ignore[attr-defined]
return batch_size
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 39e1efe3..6d45a8a9 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -36,7 +36,17 @@ what sort order was used:
"""
import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ cast,
+)
import attr
from frozendict import frozendict
@@ -585,7 +595,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
async def get_membership_changes_for_user(
- self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
+ self,
+ user_id: str,
+ from_key: RoomStreamToken,
+ to_key: RoomStreamToken,
+ excluded_rooms: Optional[List[str]] = None,
) -> List[EventBase]:
"""Fetch membership events for a given user.
@@ -610,23 +624,29 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
min_from_id = from_key.stream
max_to_id = to_key.get_max_stream_pos()
+ args: List[Any] = [user_id, min_from_id, max_to_id]
+
+ ignore_room_clause = ""
+ if excluded_rooms is not None and len(excluded_rooms) > 0:
+ ignore_room_clause = "AND e.room_id NOT IN (%s)" % ",".join(
+ "?" for _ in excluded_rooms
+ )
+ args = args + excluded_rooms
+
sql = """
SELECT m.event_id, instance_name, topological_ordering, stream_ordering
FROM events AS e, room_memberships AS m
WHERE e.event_id = m.event_id
AND m.user_id = ?
AND e.stream_ordering > ? AND e.stream_ordering <= ?
+ %s
ORDER BY e.stream_ordering ASC
- """
- txn.execute(
- sql,
- (
- user_id,
- min_from_id,
- max_to_id,
- ),
+ """ % (
+ ignore_room_clause,
)
+ txn.execute(sql, args)
+
rows = [
_EventDictReturn(event_id, None, stream_ordering)
for event_id, instance_name, topological_ordering, stream_ordering in txn
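
With the optional exclusion in place, two excluded rooms expand to a clause like "AND e.room_id NOT IN (?,?)", with the room IDs appended after the existing user and stream arguments. Illustrative values:

    excluded_rooms = ["!a:example.org", "!b:example.org"]
    ignore_room_clause = "AND e.room_id NOT IN (%s)" % ",".join(
        "?" for _ in excluded_rooms
    )
    args = [user_id, min_from_id, max_to_id] + excluded_rooms
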
@@ -722,7 +742,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
A tuple of (stream ordering, topological ordering, event_id)
"""
- def _f(txn):
+ def _f(txn: LoggingTransaction) -> Optional[Tuple[int, int, str]]:
sql = (
"SELECT stream_ordering, topological_ordering, event_id"
" FROM events"
@@ -732,27 +752,29 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
" LIMIT 1"
)
txn.execute(sql, (room_id, stream_ordering))
- return txn.fetchone()
+ return cast(Optional[Tuple[int, int, str]], txn.fetchone())
return await self.db_pool.runInteraction(
"get_room_event_before_stream_ordering", _f
)
- async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str:
- """Returns the current token for rooms stream.
+ async def get_current_room_stream_token_for_room_id(
+ self, room_id: Optional[str] = None
+ ) -> RoomStreamToken:
+ """Returns the current position of the rooms stream.
- By default, it returns the current global stream token. Specifying a
- `room_id` causes it to return the current room specific topological
- token.
+ By default, it returns a live token with the current global stream
+ token. Specifying a `room_id` causes it to return a historic token with
+ the room specific topological token.
"""
- token = self.get_room_max_stream_ordering()
+ stream_ordering = self.get_room_max_stream_ordering()
if room_id is None:
- return "s%d" % (token,)
+ return RoomStreamToken(None, stream_ordering)
else:
topo = await self.db_pool.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn, room_id
)
- return "t%d-%d" % (topo, token)
+ return RoomStreamToken(topo, stream_ordering)
def get_stream_id_for_event_txn(
self,
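
The removed "s%d" and "t%d-%d" strings are exactly the two renderings that the structured token replaces, so the correspondence is (values as in this hunk, serialisation elided):

    # Global ("live") position, previously rendered as "s%d":
    RoomStreamToken(None, stream_ordering)

    # Room-specific ("historic") position, previously rendered as "t%d-%d":
    RoomStreamToken(topo, stream_ordering)
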
@@ -827,7 +849,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
@staticmethod
def _set_before_and_after(
events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
- ):
+ ) -> None:
"""Inserts ordering information to events' internal metadata from
the DB rows.
@@ -973,7 +995,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
the `current_id`).
"""
- def get_all_new_events_stream_txn(txn):
+ def get_all_new_events_stream_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[int, List[str]]:
sql = (
"SELECT e.stream_ordering, e.event_id"
" FROM events AS e"
@@ -1319,7 +1343,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_id_for_instance(self, instance_name: str) -> int:
"""Get a unique, immutable ID that corresponds to the given Synapse worker instance."""
- def _get_id_for_instance_txn(txn):
+ def _get_id_for_instance_txn(txn: LoggingTransaction) -> int:
instance_id = self.db_pool.simple_select_one_onecol_txn(
txn,
table="instance_map",
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index c8e508a9..b0f5de67 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -97,7 +97,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
)
def get_tag_content(
- txn: LoggingTransaction, tag_ids
+ txn: LoggingTransaction, tag_ids: List[Tuple[int, str, str]]
) -> List[Tuple[int, Tuple[str, str, str]]]:
sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
results = []
@@ -251,7 +251,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
return self._account_data_id_gen.get_current_token()
def _update_revision_txn(
- self, txn, user_id: str, room_id: str, next_id: int
+ self, txn: LoggingTransaction, user_id: str, room_id: str, next_id: int
) -> None:
"""Update the latest revision of the tags for the given user and room.