summaryrefslogtreecommitdiff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
authorAndrej Shadura <andrewsh@debian.org>2022-03-21 19:25:18 +0100
committerAndrej Shadura <andrewsh@debian.org>2022-03-21 19:25:18 +0100
commit5d8241ddfec4abdb690c84aeb49ca47dca78fc97 (patch)
tree24a93572d20a9f3d94e323777dbd4302932e6f0f /synapse/storage/databases
parent85ec4b0c69e373dfcc6a8b0ddee58875c84dcc7b (diff)
New upstream version 1.54.0
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/__init__.py3
-rw-r--r--synapse/storage/databases/main/appservice.py33
-rw-r--r--synapse/storage/databases/main/devices.py10
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py112
-rw-r--r--synapse/storage/databases/main/events.py93
-rw-r--r--synapse/storage/databases/main/events_worker.py32
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py2
-rw-r--r--synapse/storage/databases/main/presence.py61
-rw-r--r--synapse/storage/databases/main/purge_events.py13
-rw-r--r--synapse/storage/databases/main/registration.py39
-rw-r--r--synapse/storage/databases/main/relations.py26
-rw-r--r--synapse/storage/databases/main/room.py41
-rw-r--r--synapse/storage/databases/main/roommember.py62
-rw-r--r--synapse/storage/databases/main/search.py43
-rw-r--r--synapse/storage/databases/main/state.py27
-rw-r--r--synapse/storage/databases/main/user_directory.py22
16 files changed, 516 insertions, 103 deletions
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index cfe887b7..ce3d1d4e 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -24,6 +24,7 @@ from synapse.storage.prepare_database import prepare_database
if TYPE_CHECKING:
from synapse.server import HomeServer
+ from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@@ -44,7 +45,7 @@ class Databases(Generic[DataStoreT]):
"""
databases: List[DatabasePool]
- main: DataStoreT
+ main: "DataStore" # FIXME: #11165: actually an instance of `main_store_class`
state: StateGroupDataStore
persist_events: Optional[PersistEventsStore]
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 304814af..06944465 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -20,14 +20,18 @@ from synapse.appservice import (
ApplicationService,
ApplicationServiceState,
AppServiceTransaction,
+ TransactionOneTimeKeyCounts,
+ TransactionUnusedFallbackKeys,
)
from synapse.config.appservice import load_appservices
from synapse.events import EventBase
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage._base import db_to_json
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.types import JsonDict
from synapse.util import json_encoder
+from synapse.util.caches.descriptors import _CacheContext, cached
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -56,7 +60,7 @@ def _make_exclusive_regex(
return exclusive_user_pattern
-class ApplicationServiceWorkerStore(SQLBaseStore):
+class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -124,6 +128,18 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
return service
return None
+ @cached(iterable=True, cache_context=True)
+ async def get_app_service_users_in_room(
+ self,
+ room_id: str,
+ app_service: "ApplicationService",
+ cache_context: _CacheContext,
+ ) -> List[str]:
+ users_in_room = await self.get_users_in_room(
+ room_id, on_invalidate=cache_context.invalidate
+ )
+ return list(filter(app_service.is_interested_in_user, users_in_room))
+
class ApplicationServiceStore(ApplicationServiceWorkerStore):
# This is currently empty due to there not being any AS storage functions
@@ -199,6 +215,8 @@ class ApplicationServiceTransactionWorkerStore(
events: List[EventBase],
ephemeral: List[JsonDict],
to_device_messages: List[JsonDict],
+ one_time_key_counts: TransactionOneTimeKeyCounts,
+ unused_fallback_keys: TransactionUnusedFallbackKeys,
) -> AppServiceTransaction:
"""Atomically creates a new transaction for this application service
with the given list of events. Ephemeral events are NOT persisted to the
@@ -209,6 +227,10 @@ class ApplicationServiceTransactionWorkerStore(
events: A list of persistent events to put in the transaction.
ephemeral: A list of ephemeral events to put in the transaction.
to_device_messages: A list of to-device messages to put in the transaction.
+ one_time_key_counts: Counts of remaining one-time keys for relevant
+ appservice devices in the transaction.
+ unused_fallback_keys: Lists of unused fallback keys for relevant
+ appservice devices in the transaction.
Returns:
A new transaction.
@@ -244,6 +266,8 @@ class ApplicationServiceTransactionWorkerStore(
events=events,
ephemeral=ephemeral,
to_device_messages=to_device_messages,
+ one_time_key_counts=one_time_key_counts,
+ unused_fallback_keys=unused_fallback_keys,
)
return await self.db_pool.runInteraction(
@@ -335,12 +359,17 @@ class ApplicationServiceTransactionWorkerStore(
events = await self.get_events_as_list(event_ids)
+ # TODO: to-device messages, one-time key counts and unused fallback keys
+ # are not yet populated for catch-up transactions.
+ # We likely want to populate those for reliability.
return AppServiceTransaction(
service=service,
id=entry["txn_id"],
events=events,
ephemeral=[],
to_device_messages=[],
+ one_time_key_counts={},
+ unused_fallback_keys={},
)
def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 8d845fe9..3b3a089b 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -670,6 +670,16 @@ class DeviceWorkerStore(SQLBaseStore):
device["device_id"]: db_to_json(device["content"]) for device in devices
}
+ def get_cached_device_list_changes(
+ self,
+ from_key: int,
+ ) -> Optional[Set[str]]:
+ """Get set of users whose devices have changed since `from_key`, or None
+ if that information is not in our cache.
+ """
+
+ return self._device_list_stream_cache.get_all_entities_changed(from_key)
+
async def get_users_whose_devices_changed(
self, from_key: int, user_ids: Iterable[str]
) -> Set[str]:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 1f8447b5..9b293475 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -29,6 +29,10 @@ import attr
from canonicaljson import encode_canonical_json
from synapse.api.constants import DeviceKeyAlgorithms
+from synapse.appservice import (
+ TransactionOneTimeKeyCounts,
+ TransactionUnusedFallbackKeys,
+)
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
@@ -439,6 +443,114 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
+ async def count_bulk_e2e_one_time_keys_for_as(
+ self, user_ids: Collection[str]
+ ) -> TransactionOneTimeKeyCounts:
+ """
+ Counts, in bulk, the one-time keys for all the users specified.
+ Intended to be used by application services for populating OTK counts in
+ transactions.
+
+ Return structure is of the shape:
+ user_id -> device_id -> algorithm -> count
+ Empty algorithm -> count dicts are created if needed to represent a
+ lack of unused one-time keys.
+ """
+
+ def _count_bulk_e2e_one_time_keys_txn(
+ txn: LoggingTransaction,
+ ) -> TransactionOneTimeKeyCounts:
+ user_in_where_clause, user_parameters = make_in_list_sql_clause(
+ self.database_engine, "user_id", user_ids
+ )
+ sql = f"""
+ SELECT user_id, device_id, algorithm, COUNT(key_id)
+ FROM devices
+ LEFT JOIN e2e_one_time_keys_json USING (user_id, device_id)
+ WHERE {user_in_where_clause}
+ GROUP BY user_id, device_id, algorithm
+ """
+ txn.execute(sql, user_parameters)
+
+ result: TransactionOneTimeKeyCounts = {}
+
+ for user_id, device_id, algorithm, count in txn:
+ # We deliberately construct empty dictionaries for
+ # users and devices without any unused one-time keys.
+ # We *could* omit these empty dicts if there have been no
+ # changes since the last transaction, but we currently don't
+ # do any change tracking!
+ device_count_by_algo = result.setdefault(user_id, {}).setdefault(
+ device_id, {}
+ )
+ if algorithm is not None:
+ # algorithm will be None if this device has no keys.
+ device_count_by_algo[algorithm] = count
+
+ return result
+
+ return await self.db_pool.runInteraction(
+ "count_bulk_e2e_one_time_keys", _count_bulk_e2e_one_time_keys_txn
+ )
+
+ async def get_e2e_bulk_unused_fallback_key_types(
+ self, user_ids: Collection[str]
+ ) -> TransactionUnusedFallbackKeys:
+ """
+ Finds, in bulk, the types of unused fallback keys for all the users specified.
+ Intended to be used by application services for populating unused fallback
+ keys in transactions.
+
+ Return structure is of the shape:
+ user_id -> device_id -> algorithms
+ Empty lists are created for devices if there are no unused fallback
+ keys. This matches the response structure of MSC3202.
+ """
+ if len(user_ids) == 0:
+ return {}
+
+ def _get_bulk_e2e_unused_fallback_keys_txn(
+ txn: LoggingTransaction,
+ ) -> TransactionUnusedFallbackKeys:
+ user_in_where_clause, user_parameters = make_in_list_sql_clause(
+ self.database_engine, "devices.user_id", user_ids
+ )
+ # We can't use USING here because we require the `.used` condition
+ # to be part of the JOIN condition so that we generate empty lists
+ # when all keys are used (as opposed to just when there are no keys at all).
+ sql = f"""
+ SELECT devices.user_id, devices.device_id, algorithm
+ FROM devices
+ LEFT JOIN e2e_fallback_keys_json AS fallback_keys
+ ON devices.user_id = fallback_keys.user_id
+ AND devices.device_id = fallback_keys.device_id
+ AND NOT fallback_keys.used
+ WHERE
+ {user_in_where_clause}
+ """
+ txn.execute(sql, user_parameters)
+
+ result: TransactionUnusedFallbackKeys = {}
+
+ for user_id, device_id, algorithm in txn:
+ # We deliberately construct empty dictionaries and lists for
+ # users and devices without any unused fallback keys.
+ # We *could* omit these empty dicts if there have been no
+ # changes since the last transaction, but we currently don't
+ # do any change tracking!
+ device_unused_keys = result.setdefault(user_id, {}).setdefault(
+ device_id, []
+ )
+ if algorithm is not None:
+ # algorithm will be None if this device has no keys.
+ device_unused_keys.append(algorithm)
+
+ return result
+
+ return await self.db_pool.runInteraction(
+ "_get_bulk_e2e_unused_fallback_keys", _get_bulk_e2e_unused_fallback_keys_txn
+ )
+
async def set_e2e_fallback_keys(
self, user_id: str, device_id: str, fallback_keys: JsonDict
) -> None:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 5246fcca..ca2a9ba9 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -130,7 +130,7 @@ class PersistEventsStore:
*,
current_state_for_room: Dict[str, StateMap[str]],
state_delta_for_room: Dict[str, DeltaState],
- new_forward_extremeties: Dict[str, List[str]],
+ new_forward_extremities: Dict[str, Set[str]],
use_negative_stream_ordering: bool = False,
inhibit_local_membership_updates: bool = False,
) -> None:
@@ -143,7 +143,7 @@ class PersistEventsStore:
the room based on forward extremities
state_delta_for_room: Map from room_id to the delta to apply to
room state
- new_forward_extremities: Map from room_id to list of event IDs
+ new_forward_extremities: Map from room_id to set of event IDs
that are the new forward extremities of the room.
use_negative_stream_ordering: Whether to start stream_ordering on
the negative side and decrement. This should be set as True
@@ -193,7 +193,7 @@ class PersistEventsStore:
events_and_contexts=events_and_contexts,
inhibit_local_membership_updates=inhibit_local_membership_updates,
state_delta_for_room=state_delta_for_room,
- new_forward_extremeties=new_forward_extremeties,
+ new_forward_extremities=new_forward_extremities,
)
persist_event_counter.inc(len(events_and_contexts))
@@ -220,7 +220,7 @@ class PersistEventsStore:
for room_id, new_state in current_state_for_room.items():
self.store.get_current_state_ids.prefill((room_id,), new_state)
- for room_id, latest_event_ids in new_forward_extremeties.items():
+ for room_id, latest_event_ids in new_forward_extremities.items():
self.store.get_latest_event_ids_in_room.prefill(
(room_id,), list(latest_event_ids)
)
@@ -334,8 +334,8 @@ class PersistEventsStore:
events_and_contexts: List[Tuple[EventBase, EventContext]],
inhibit_local_membership_updates: bool = False,
state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
- new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
- ):
+ new_forward_extremities: Optional[Dict[str, Set[str]]] = None,
+ ) -> None:
"""Insert some number of room events into the necessary database tables.
Rejected events are only inserted into the events table, the events_json table,
@@ -353,13 +353,13 @@ class PersistEventsStore:
from the database. This is useful when retrying due to
IntegrityError.
state_delta_for_room: The current-state delta for each room.
- new_forward_extremetie: The new forward extremities for each room.
+ new_forward_extremities: The new forward extremities for each room.
For each room, a list of the event ids which are the forward
extremities.
"""
state_delta_for_room = state_delta_for_room or {}
- new_forward_extremeties = new_forward_extremeties or {}
+ new_forward_extremities = new_forward_extremities or {}
all_events_and_contexts = events_and_contexts
@@ -372,7 +372,7 @@ class PersistEventsStore:
self._update_forward_extremities_txn(
txn,
- new_forward_extremities=new_forward_extremeties,
+ new_forward_extremities=new_forward_extremities,
max_stream_order=max_stream_order,
)
@@ -975,6 +975,17 @@ class PersistEventsStore:
to_delete = delta_state.to_delete
to_insert = delta_state.to_insert
+ # Figure out the changes of membership to invalidate the
+ # `get_rooms_for_user` cache.
+ # We find out which membership events we may have deleted
+ # and which we have added, then we invalidate the caches for all
+ # those users.
+ members_changed = {
+ state_key
+ for ev_type, state_key in itertools.chain(to_delete, to_insert)
+ if ev_type == EventTypes.Member
+ }
+
if delta_state.no_longer_in_room:
# Server is no longer in the room so we delete the room from
# current_state_events, being careful we've already updated the
@@ -993,6 +1004,11 @@ class PersistEventsStore:
"""
txn.execute(sql, (stream_id, self._instance_name, room_id))
+ # We also want to invalidate the membership caches for users
+ # that were in the room.
+ users_in_room = self.store.get_users_in_room_txn(txn, room_id)
+ members_changed.update(users_in_room)
+
self.db_pool.simple_delete_txn(
txn,
table="current_state_events",
@@ -1102,17 +1118,6 @@ class PersistEventsStore:
# Invalidate the various caches
- # Figure out the changes of membership to invalidate the
- # `get_rooms_for_user` cache.
- # We find out which membership events we may have deleted
- # and which we have added, then we invalidate the caches for all
- # those users.
- members_changed = {
- state_key
- for ev_type, state_key in itertools.chain(to_delete, to_insert)
- if ev_type == EventTypes.Member
- }
-
for member in members_changed:
txn.call_after(
self.store.get_rooms_for_user_with_stream_ordering.invalidate,
@@ -1153,7 +1158,10 @@ class PersistEventsStore:
)
def _update_forward_extremities_txn(
- self, txn, new_forward_extremities, max_stream_order
+ self,
+ txn: LoggingTransaction,
+ new_forward_extremities: Dict[str, Set[str]],
+ max_stream_order: int,
):
for room_id in new_forward_extremities.keys():
self.db_pool.simple_delete_txn(
@@ -1468,10 +1476,10 @@ class PersistEventsStore:
def _update_metadata_tables_txn(
self,
- txn,
+ txn: LoggingTransaction,
*,
- events_and_contexts,
- all_events_and_contexts,
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
+ all_events_and_contexts: List[Tuple[EventBase, EventContext]],
inhibit_local_membership_updates: bool = False,
):
"""Update all the miscellaneous tables for new events
@@ -1948,20 +1956,20 @@ class PersistEventsStore:
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
)
- def _store_room_topic_txn(self, txn, event):
- if hasattr(event, "content") and "topic" in event.content:
+ def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase):
+ if isinstance(event.content.get("topic"), str):
self.store_event_search_txn(
txn, event, "content.topic", event.content["topic"]
)
- def _store_room_name_txn(self, txn, event):
- if hasattr(event, "content") and "name" in event.content:
+ def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase):
+ if isinstance(event.content.get("name"), str):
self.store_event_search_txn(
txn, event, "content.name", event.content["name"]
)
- def _store_room_message_txn(self, txn, event):
- if hasattr(event, "content") and "body" in event.content:
+ def _store_room_message_txn(self, txn: LoggingTransaction, event: EventBase):
+ if isinstance(event.content.get("body"), str):
self.store_event_search_txn(
txn, event, "content.body", event.content["body"]
)
@@ -2137,6 +2145,14 @@ class PersistEventsStore:
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
+ # double-check that we don't have any events that claim to be outliers
+ # *and* have partial state (which is meaningless: we should have no
+ # state at all for an outlier)
+ if context.partial_state:
+ raise ValueError(
+ "Outlier event %s claims to have partial state", event.event_id
+ )
+
continue
# if the event was rejected, just give it the same state as its
@@ -2147,6 +2163,23 @@ class PersistEventsStore:
state_groups[event.event_id] = context.state_group
+ # if we have partial state for these events, record the fact. (This happens
+ # here rather than in _store_event_txn because it also needs to happen when
+ # we de-outlier an event.)
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="partial_state_events",
+ keys=("room_id", "event_id"),
+ values=[
+ (
+ event.room_id,
+ event.event_id,
+ )
+ for event, ctx in events_and_contexts
+ if ctx.partial_state
+ ],
+ )
+
self.db_pool.simple_upsert_many_txn(
txn,
table="event_to_state_groups",
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 8d428704..26784f75 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -408,7 +408,7 @@ class EventsWorkerStore(SQLBaseStore):
include the previous states content in the unsigned field.
allow_rejected: If True, return rejected events. Otherwise,
- omits rejeted events from the response.
+ omits rejected events from the response.
Returns:
A mapping from event_id to event.
@@ -1854,7 +1854,7 @@ class EventsWorkerStore(SQLBaseStore):
forward_edge_query = """
SELECT 1 FROM event_edges
/* Check to make sure the event referencing our event in question is not rejected */
- LEFT JOIN rejections ON event_edges.event_id == rejections.event_id
+ LEFT JOIN rejections ON event_edges.event_id = rejections.event_id
WHERE
event_edges.room_id = ?
AND event_edges.prev_event_id = ?
@@ -1953,3 +1953,31 @@ class EventsWorkerStore(SQLBaseStore):
"get_event_id_for_timestamp_txn",
get_event_id_for_timestamp_txn,
)
+
+ @cachedList("is_partial_state_event", list_name="event_ids")
+ async def get_partial_state_events(
+ self, event_ids: Collection[str]
+ ) -> Dict[str, bool]:
+ """Checks which of the given events have partial state"""
+ result = await self.db_pool.simple_select_many_batch(
+ table="partial_state_events",
+ column="event_id",
+ iterable=event_ids,
+ retcols=["event_id"],
+ desc="get_partial_state_events",
+ )
+ # convert the result to a dict, to make @cachedList work
+ partial = {r["event_id"] for r in result}
+ return {e_id: e_id in partial for e_id in event_ids}
+
+ @cached()
+ async def is_partial_state_event(self, event_id: str) -> bool:
+ """Checks if the given event has partial state"""
+ result = await self.db_pool.simple_select_one_onecol(
+ table="partial_state_events",
+ keyvalues={"event_id": event_id},
+ retcol="1",
+ allow_none=True,
+ desc="is_partial_state_event",
+ )
+ return result is not None
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 8f09dd8e..e9a0cdc6 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -112,7 +112,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
for tp in self.hs.config.server.mau_limits_reserved_threepids[
: self.hs.config.server.max_mau_value
]:
- user_id = await self.hs.get_datastore().get_user_id_by_threepid(
+ user_id = await self.hs.get_datastores().main.get_user_id_by_threepid(
tp["medium"], canonicalise_email(tp["address"])
)
if user_id:
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 4f05811a..d3c46116 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -12,15 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast
from synapse.api.presence import PresenceState, UserPresenceState
from synapse.replication.tcp.streams import PresenceStream
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdGenerator,
+ MultiWriterIdGenerator,
+ StreamIdGenerator,
+)
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.iterutils import batch_iter
@@ -35,7 +43,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore):
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
- ):
+ ) -> None:
super().__init__(database, db_conn, hs)
# Used by `PresenceStore._get_active_presence()`
@@ -54,11 +62,14 @@ class PresenceStore(PresenceBackgroundUpdateStore):
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
- ):
+ ) -> None:
super().__init__(database, db_conn, hs)
+ self._instance_name = hs.get_instance_name()
+ self._presence_id_gen: AbstractStreamIdGenerator
+
self._can_persist_presence = (
- hs.get_instance_name() in hs.config.worker.writers.presence
+ self._instance_name in hs.config.worker.writers.presence
)
if isinstance(database.engine, PostgresEngine):
@@ -109,7 +120,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
return stream_orderings[-1], self._presence_id_gen.get_current_token()
- def _update_presence_txn(self, txn, stream_orderings, presence_states):
+ def _update_presence_txn(
+ self, txn: LoggingTransaction, stream_orderings, presence_states
+ ) -> None:
for stream_id, state in zip(stream_orderings, presence_states):
txn.call_after(
self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
@@ -183,19 +196,23 @@ class PresenceStore(PresenceBackgroundUpdateStore):
if last_id == current_id:
return [], current_id, False
- def get_all_presence_updates_txn(txn):
+ def get_all_presence_updates_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, list]], int, bool]:
sql = """
SELECT stream_id, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts,
- status_msg,
- currently_active
+ status_msg, currently_active
FROM presence_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
- updates = [(row[0], row[1:]) for row in txn]
+ updates = cast(
+ List[Tuple[int, list]],
+ [(row[0], row[1:]) for row in txn],
+ )
upper_bound = current_id
limited = False
@@ -210,7 +227,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
)
@cached()
- def _get_presence_for_user(self, user_id):
+ def _get_presence_for_user(self, user_id: str) -> None:
raise NotImplementedError()
@cachedList(
@@ -218,7 +235,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
list_name="user_ids",
num_args=1,
)
- async def get_presence_for_users(self, user_ids):
+ async def get_presence_for_users(
+ self, user_ids: Iterable[str]
+ ) -> Dict[str, UserPresenceState]:
rows = await self.db_pool.simple_select_many_batch(
table="presence_stream",
column="user_id",
@@ -257,7 +276,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
True if the user should have full presence sent to them, False otherwise.
"""
- def _should_user_receive_full_presence_with_token_txn(txn):
+ def _should_user_receive_full_presence_with_token_txn(
+ txn: LoggingTransaction,
+ ) -> bool:
sql = """
SELECT 1 FROM users_to_send_full_presence_to
WHERE user_id = ?
@@ -271,7 +292,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
_should_user_receive_full_presence_with_token_txn,
)
- async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]):
+ async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None:
"""Adds to the list of users who should receive a full snapshot of presence
upon their next sync.
@@ -353,10 +374,10 @@ class PresenceStore(PresenceBackgroundUpdateStore):
return users_to_state
- def get_current_presence_token(self):
+ def get_current_presence_token(self) -> int:
return self._presence_id_gen.get_current_token()
- def _get_active_presence(self, db_conn: Connection):
+ def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]:
"""Fetch non-offline presence from the database so that we can register
the appropriate time outs.
"""
@@ -379,12 +400,12 @@ class PresenceStore(PresenceBackgroundUpdateStore):
return [UserPresenceState(**row) for row in rows]
- def take_presence_startup_info(self):
+ def take_presence_startup_info(self) -> List[UserPresenceState]:
active_on_startup = self._presence_on_startup
- self._presence_on_startup = None
+ self._presence_on_startup = []
return active_on_startup
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(self, stream_name, instance_name, token, rows) -> None:
if stream_name == PresenceStream.NAME:
self._presence_id_gen.advance(instance_name, token)
for row in rows:
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index e87a8fb8..2e3818e4 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -13,9 +13,10 @@
# limitations under the License.
import logging
-from typing import Any, List, Set, Tuple
+from typing import Any, List, Set, Tuple, cast
from synapse.api.errors import SynapseError
+from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.storage.databases.main.state import StateGroupWorkerStore
from synapse.types import RoomStreamToken
@@ -55,7 +56,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
)
def _purge_history_txn(
- self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ token: RoomStreamToken,
+ delete_local_events: bool,
) -> Set[int]:
# Tables that should be pruned:
# event_auth
@@ -273,7 +278,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
""",
(room_id,),
)
- (min_depth,) = txn.fetchone()
+ (min_depth,) = cast(Tuple[int], txn.fetchone())
logger.info("[purge] updating room_depth to %d", min_depth)
@@ -318,7 +323,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"purge_room", self._purge_room_txn, room_id
)
- def _purge_room_txn(self, txn, room_id: str) -> List[int]:
+ def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]:
# First we fetch all the state groups that should be deleted, before
# we delete that information.
txn.execute(
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index aac94fa4..dc666523 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -622,10 +622,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
) -> None:
"""Record a mapping from an external user id to a mxid
+ See notes in _record_user_external_id_txn about what constitutes valid data.
+
Args:
auth_provider: identifier for the remote auth provider
external_id: id on that system
user_id: complete mxid that it is mapped to
+
Raises:
ExternalIDReuseException if the new external_id could not be mapped.
"""
@@ -648,6 +651,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
external_id: str,
user_id: str,
) -> None:
+ """
+ Record a mapping from an external user id to a mxid.
+
+ Note that the auth provider IDs (and the external IDs) are not validated
+ against configured IdPs as Synapse does not know its relationship to
+ external systems. For example, it might be useful to pre-configure users
+ before enabling a new IdP or an IdP might be temporarily offline, but
+ still valid.
+
+ Args:
+ txn: The database transaction.
+ auth_provider: identifier for the remote auth provider
+ external_id: id on that system
+ user_id: complete mxid that it is mapped to
+ """
self.db_pool.simple_insert_txn(
txn,
@@ -687,10 +705,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""Replace mappings from external user ids to a mxid in a single transaction.
All mappings are deleted and the new ones are created.
+ See notes in _record_user_external_id_txn about what constitutes valid data.
+
Args:
record_external_ids:
List with tuple of auth_provider and external_id to record
user_id: complete mxid that it is mapped to
+
Raises:
ExternalIDReuseException if the new external_id could not be mapped.
"""
@@ -1660,7 +1681,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
user_id=row[1],
device_id=row[2],
next_token_id=row[3],
- has_next_refresh_token_been_refreshed=row[4],
+ # SQLite returns 0 or 1 for false/true, so convert to a bool.
+ has_next_refresh_token_been_refreshed=bool(row[4]),
# This column is nullable, ensure it's a boolean
has_next_access_token_been_used=(row[5] or False),
expiry_ts=row[6],
@@ -1676,12 +1698,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Set the successor of a refresh token, removing the existing successor
if any.
+ This also deletes the predecessor refresh and access tokens,
+ since they cannot be valid anymore.
+
Args:
token_id: ID of the refresh token to update.
next_token_id: ID of its successor.
"""
- def _replace_refresh_token_txn(txn) -> None:
+ def _replace_refresh_token_txn(txn: LoggingTransaction) -> None:
# First check if there was an existing refresh token
old_next_token_id = self.db_pool.simple_select_one_onecol_txn(
txn,
@@ -1707,6 +1732,16 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
{"id": old_next_token_id},
)
+ # Delete the previous refresh token, since we only want to keep the
+ # last 2 refresh tokens in the database.
+ # (The predecessor of the latest refresh token is still useful in
+ # case the refresh was interrupted and the client re-uses the old
+ # one.)
+ # This cascades to delete the associated access token.
+ self.db_pool.simple_delete_txn(
+ txn, "refresh_tokens", {"next_token_id": token_id}
+ )
+
await self.db_pool.runInteraction(
"replace_refresh_token", _replace_refresh_token_txn
)
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index e2c27e59..36aa1092 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -53,8 +53,13 @@ logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation:
+ # The latest event in the thread.
latest_event: EventBase
+ # The latest edit to the latest event in the thread.
+ latest_edit: Optional[EventBase]
+ # The total number of events in the thread.
count: int
+ # True if the current user has sent an event to the thread.
current_user_participated: bool
@@ -461,8 +466,8 @@ class RelationsWorkerStore(SQLBaseStore):
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
async def _get_thread_summaries(
self, event_ids: Collection[str]
- ) -> Dict[str, Optional[Tuple[int, EventBase]]]:
- """Get the number of threaded replies and the latest reply (if any) for the given event.
+ ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
+ """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
Args:
event_ids: Summarize the thread related to this event ID.
@@ -471,8 +476,10 @@ class RelationsWorkerStore(SQLBaseStore):
A map of the thread summary each event. A missing event implies there
are no threaded replies.
- Each summary includes the number of items in the thread and the most
- recent response.
+ Each summary is a tuple of:
+ The number of events in the thread.
+ The most recent event in the thread.
+ The most recent edit to the most recent event in the thread, if applicable.
"""
def _get_thread_summaries_txn(
@@ -482,7 +489,7 @@ class RelationsWorkerStore(SQLBaseStore):
# TODO Should this only allow m.room.message events.
if isinstance(self.database_engine, PostgresEngine):
# The `DISTINCT ON` clause will pick the *first* row it encounters,
- # so ordering by topologica ordering + stream ordering desc will
+ # so ordering by topological ordering + stream ordering desc will
# ensure we get the latest event in the thread.
sql = """
SELECT DISTINCT ON (parent.event_id) parent.event_id, child.event_id FROM events AS child
@@ -558,6 +565,9 @@ class RelationsWorkerStore(SQLBaseStore):
latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
+ # Check to see if any of those events are edited.
+ latest_edits = await self._get_applicable_edits(latest_event_ids.values())
+
# Map to the event IDs to the thread summary.
#
# There might not be a summary due to there not being a thread or
@@ -568,7 +578,8 @@ class RelationsWorkerStore(SQLBaseStore):
summary = None
if latest_event:
- summary = (counts[parent_event_id], latest_event)
+ latest_edit = latest_edits.get(latest_event_id)
+ summary = (counts[parent_event_id], latest_event, latest_edit)
summaries[parent_event_id] = summary
return summaries
@@ -828,11 +839,12 @@ class RelationsWorkerStore(SQLBaseStore):
)
for event_id, summary in summaries.items():
if summary:
- thread_count, latest_thread_event = summary
+ thread_count, latest_thread_event, edit = summary
results.setdefault(
event_id, BundledAggregations()
).thread = _ThreadAggregation(
latest_event=latest_thread_event,
+ latest_edit=edit,
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 95167116..94068940 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -20,6 +20,7 @@ from typing import (
TYPE_CHECKING,
Any,
Awaitable,
+ Collection,
Dict,
List,
Optional,
@@ -1498,7 +1499,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
async def upsert_room_on_join(
- self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase]
+ self, room_id: str, room_version: RoomVersion, state_events: List[EventBase]
) -> None:
"""Ensure that the room is stored in the table
@@ -1511,7 +1512,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
has_auth_chain_index = await self.has_auth_chain_index(room_id)
create_event = None
- for e in auth_events:
+ for e in state_events:
if (e.type, e.state_key) == (EventTypes.Create, ""):
create_event = e
break
@@ -1543,6 +1544,42 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
lock=False,
)
+ async def store_partial_state_room(
+ self,
+ room_id: str,
+ servers: Collection[str],
+ ) -> None:
+ """Mark the given room as containing events with partial state
+
+ Args:
+ room_id: the ID of the room
+ servers: other servers known to be in the room
+ """
+ await self.db_pool.runInteraction(
+ "store_partial_state_room",
+ self._store_partial_state_room_txn,
+ room_id,
+ servers,
+ )
+
+ @staticmethod
+ def _store_partial_state_room_txn(
+ txn: LoggingTransaction, room_id: str, servers: Collection[str]
+ ) -> None:
+ DatabasePool.simple_insert_txn(
+ txn,
+ table="partial_state_rooms",
+ values={
+ "room_id": room_id,
+ },
+ )
+ DatabasePool.simple_insert_many_txn(
+ txn,
+ table="partial_state_rooms_servers",
+ keys=("room_id", "server_name"),
+ values=((room_id, s) for s in servers),
+ )
+
async def maybe_store_room_on_outlier_membership(
self, room_id: str, room_version: RoomVersion
) -> None:
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 4489732f..e48ec5f4 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -504,6 +504,68 @@ class RoomMemberWorkerStore(EventsWorkerStore):
for room_id, instance, stream_id in txn
)
+ @cachedList(
+ cached_method_name="get_rooms_for_user_with_stream_ordering",
+ list_name="user_ids",
+ )
+ async def get_rooms_for_users_with_stream_ordering(
+ self, user_ids: Collection[str]
+ ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
+ """A batched version of `get_rooms_for_user_with_stream_ordering`.
+
+ Returns:
+ Map from user_id to set of rooms that is currently in.
+ """
+ return await self.db_pool.runInteraction(
+ "get_rooms_for_users_with_stream_ordering",
+ self._get_rooms_for_users_with_stream_ordering_txn,
+ user_ids,
+ )
+
+ def _get_rooms_for_users_with_stream_ordering_txn(
+ self, txn, user_ids: Collection[str]
+ ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
+
+ clause, args = make_in_list_sql_clause(
+ self.database_engine,
+ "c.state_key",
+ user_ids,
+ )
+
+ if self._current_state_events_membership_up_to_date:
+ sql = f"""
+ SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
+ FROM current_state_events AS c
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ c.type = 'm.room.member'
+ AND c.membership = ?
+ AND {clause}
+ """
+ else:
+ sql = f"""
+ SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
+ FROM current_state_events AS c
+ INNER JOIN room_memberships AS m USING (room_id, event_id)
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ c.type = 'm.room.member'
+ AND m.membership = ?
+ AND {clause}
+ """
+
+ txn.execute(sql, [Membership.JOIN] + args)
+
+ result = {user_id: set() for user_id in user_ids}
+ for user_id, room_id, instance, stream_id in txn:
+ result[user_id].add(
+ GetRoomsForUserWithStreamOrdering(
+ room_id, PersistedEventPosition(instance, stream_id)
+ )
+ )
+
+ return {user_id: frozenset(v) for user_id, v in result.items()}
+
async def get_users_server_still_shares_room_with(
self, user_ids: Collection[str]
) -> Set[str]:
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 2d085a57..e23b1190 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -28,6 +28,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -114,6 +115,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
+ EVENT_SEARCH_DELETE_NON_STRINGS = "event_search_sqlite_delete_non_strings"
def __init__(
self,
@@ -146,6 +148,10 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search
)
+ self.db_pool.updates.register_background_update_handler(
+ self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings
+ )
+
async def _background_reindex_search(self, progress, batch_size):
# we work through the events table from highest stream id to lowest
target_min_stream_id = progress["target_min_stream_id_inclusive"]
@@ -371,6 +377,27 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
return num_rows
+ async def _background_delete_non_strings(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """Deletes rows with non-string `value`s from `event_search` if using sqlite.
+
+ Prior to Synapse 1.44.0, malformed events received over federation could cause integers
+ to be inserted into the `event_search` table when using sqlite.
+ """
+
+ def delete_non_strings_txn(txn: LoggingTransaction) -> None:
+ txn.execute("DELETE FROM event_search WHERE typeof(value) != 'text'")
+
+ await self.db_pool.runInteraction(
+ self.EVENT_SEARCH_DELETE_NON_STRINGS, delete_non_strings_txn
+ )
+
+ await self.db_pool.updates._end_background_update(
+ self.EVENT_SEARCH_DELETE_NON_STRINGS
+ )
+ return 1
+
class SearchStore(SearchBackgroundUpdateStore):
def __init__(
@@ -381,17 +408,19 @@ class SearchStore(SearchBackgroundUpdateStore):
):
super().__init__(database, db_conn, hs)
- async def search_msgs(self, room_ids, search_term, keys):
+ async def search_msgs(
+ self, room_ids: Collection[str], search_term: str, keys: Iterable[str]
+ ) -> JsonDict:
"""Performs a full text search over events with given keys.
Args:
- room_ids (list): List of room ids to search in
- search_term (str): Search term to search for
- keys (list): List of keys to search in, currently supports
+ room_ids: List of room ids to search in
+ search_term: Search term to search for
+ keys: List of keys to search in, currently supports
"content.body", "content.name", "content.topic"
Returns:
- list of dicts
+ Dictionary of results
"""
clauses = []
@@ -499,10 +528,10 @@ class SearchStore(SearchBackgroundUpdateStore):
self,
room_ids: Collection[str],
search_term: str,
- keys: List[str],
+ keys: Iterable[str],
limit,
pagination_token: Optional[str] = None,
- ) -> List[dict]:
+ ) -> JsonDict:
"""Performs a full text search over events with given keys.
Args:
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 2fb3e651..417aef1d 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -42,6 +42,16 @@ logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
+def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion:
+ v = KNOWN_ROOM_VERSIONS.get(room_version_id)
+ if not v:
+ raise UnsupportedRoomVersionError(
+ "Room %s uses a room version %s which is no longer supported"
+ % (room_id, room_version_id)
+ )
+ return v
+
+
# this inherits from EventsWorkerStore because it calls self.get_events
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers."""
@@ -62,11 +72,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Typically this happens if support for the room's version has been
removed from Synapse.
"""
- return await self.db_pool.runInteraction(
- "get_room_version_txn",
- self.get_room_version_txn,
- room_id,
- )
+ room_version_id = await self.get_room_version_id(room_id)
+ return _retrieve_and_check_room_version(room_id, room_version_id)
def get_room_version_txn(
self, txn: LoggingTransaction, room_id: str
@@ -82,15 +89,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
removed from Synapse.
"""
room_version_id = self.get_room_version_id_txn(txn, room_id)
- v = KNOWN_ROOM_VERSIONS.get(room_version_id)
-
- if not v:
- raise UnsupportedRoomVersionError(
- "Room %s uses a room version %s which is no longer supported"
- % (room_id, room_version_id)
- )
-
- return v
+ return _retrieve_and_check_room_version(room_id, room_version_id)
@cached(max_entries=10000)
async def get_room_version_id(self, room_id: str) -> str:
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index f7c778bd..e7fddd24 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -58,7 +58,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
- ):
+ ) -> None:
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
@@ -234,10 +234,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
processed_event_count = 0
for room_id, event_count in rooms_to_work_on:
- is_in_room = await self.is_host_joined(room_id, self.server_name)
+ is_in_room = await self.is_host_joined(room_id, self.server_name) # type: ignore[attr-defined]
if is_in_room:
- users_with_profile = await self.get_users_in_room_with_profiles(room_id)
+ users_with_profile = await self.get_users_in_room_with_profiles(room_id) # type: ignore[attr-defined]
# Throw away users excluded from the directory.
users_with_profile = {
user_id: profile
@@ -368,7 +368,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
for user_id in users_to_work_on:
if await self.should_include_local_user_in_dir(user_id):
- profile = await self.get_profileinfo(get_localpart_from_id(user_id))
+ profile = await self.get_profileinfo(get_localpart_from_id(user_id)) # type: ignore[attr-defined]
await self.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)
@@ -397,7 +397,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# technically it could be DM-able. In the future, this could potentially
# be configurable per-appservice whether the appservice sender can be
# contacted.
- if self.get_app_service_by_user_id(user) is not None:
+ if self.get_app_service_by_user_id(user) is not None: # type: ignore[attr-defined]
return False
# We're opting to exclude appservice users (anyone matching the user
@@ -405,17 +405,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# they could be DM-able. In the future, this could potentially
# be configurable per-appservice whether the appservice users can be
# contacted.
- if self.get_if_app_services_interested_in_user(user):
+ if self.get_if_app_services_interested_in_user(user): # type: ignore[attr-defined]
# TODO we might want to make this configurable for each app service
return False
# Support users are for diagnostics and should not appear in the user directory.
- if await self.is_support_user(user):
+ if await self.is_support_user(user): # type: ignore[attr-defined]
return False
# Deactivated users aren't contactable, so should not appear in the user directory.
try:
- if await self.get_user_deactivated_status(user):
+ if await self.get_user_deactivated_status(user): # type: ignore[attr-defined]
return False
except StoreError:
# No such user in the users table. No need to do this when calling
@@ -433,20 +433,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
(EventTypes.RoomHistoryVisibility, ""),
)
- current_state_ids = await self.get_filtered_current_state_ids(
+ current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined]
room_id, StateFilter.from_types(types_to_filter)
)
join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
if join_rules_id:
- join_rule_ev = await self.get_event(join_rules_id, allow_none=True)
+ join_rule_ev = await self.get_event(join_rules_id, allow_none=True) # type: ignore[attr-defined]
if join_rule_ev:
if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
return True
hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
if hist_vis_id:
- hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)
+ hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) # type: ignore[attr-defined]
if hist_vis_ev:
if (
hist_vis_ev.content.get("history_visibility")