summaryrefslogtreecommitdiff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
authorAndrej Shadura <andrewsh@debian.org>2021-12-14 17:36:00 +0100
committerAndrej Shadura <andrewsh@debian.org>2021-12-14 17:36:00 +0100
commite894f36b081e694a99c1829d05a2382edd8c4533 (patch)
tree094399067207ada06f538d472bf44b608e984cac /synapse/storage/databases/main
parent7eba982950257d89142cd5a793320be25c5d669a (diff)
New upstream version 1.49.0
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/appservice.py6
-rw-r--r--synapse/storage/databases/main/devices.py50
-rw-r--r--synapse/storage/databases/main/event_federation.py4
-rw-r--r--synapse/storage/databases/main/event_push_actions.py19
-rw-r--r--synapse/storage/databases/main/events.py107
-rw-r--r--synapse/storage/databases/main/events_worker.py581
-rw-r--r--synapse/storage/databases/main/purge_events.py2
-rw-r--r--synapse/storage/databases/main/push_rule.py11
-rw-r--r--synapse/storage/databases/main/registration.py28
-rw-r--r--synapse/storage/databases/main/roommember.py4
-rw-r--r--synapse/storage/databases/main/stream.py15
-rw-r--r--synapse/storage/databases/main/transactions.py70
12 files changed, 717 insertions, 180 deletions
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index baec35ee..4a883dc1 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -143,7 +143,7 @@ class ApplicationServiceTransactionWorkerStore(
A list of ApplicationServices, which may be empty.
"""
results = await self.db_pool.simple_select_list(
- "application_services_state", {"state": state}, ["as_id"]
+ "application_services_state", {"state": state.value}, ["as_id"]
)
# NB: This assumes this class is linked with ApplicationServiceStore
as_list = self.get_app_services()
@@ -173,7 +173,7 @@ class ApplicationServiceTransactionWorkerStore(
desc="get_appservice_state",
)
if result:
- return result.get("state")
+ return ApplicationServiceState(result.get("state"))
return None
async def set_appservice_state(
@@ -186,7 +186,7 @@ class ApplicationServiceTransactionWorkerStore(
state: The connectivity state to apply.
"""
await self.db_pool.simple_upsert(
- "application_services_state", {"as_id": service.id}, {"state": state}
+ "application_services_state", {"as_id": service.id}, {"state": state.value}
)
async def create_appservice_txn(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9ccc66e5..d5a4a661 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -139,6 +139,27 @@ class DeviceWorkerStore(SQLBaseStore):
return {d["device_id"]: d for d in devices}
+ async def get_devices_by_auth_provider_session_id(
+ self, auth_provider_id: str, auth_provider_session_id: str
+ ) -> List[Dict[str, Any]]:
+ """Retrieve the list of devices associated with a SSO IdP session ID.
+
+ Args:
+ auth_provider_id: The SSO IdP ID as defined in the server config
+ auth_provider_session_id: The session ID within the IdP
+ Returns:
+ A list of dicts containing the device_id and the user_id of each device
+ """
+ return await self.db_pool.simple_select_list(
+ table="device_auth_providers",
+ keyvalues={
+ "auth_provider_id": auth_provider_id,
+ "auth_provider_session_id": auth_provider_session_id,
+ },
+ retcols=("user_id", "device_id"),
+ desc="get_devices_by_auth_provider_session_id",
+ )
+
@trace
async def get_device_updates_by_remote(
self, destination: str, from_stream_id: int, limit: int
@@ -1070,7 +1091,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
async def store_device(
- self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
+ self,
+ user_id: str,
+ device_id: str,
+ initial_device_display_name: Optional[str],
+ auth_provider_id: Optional[str] = None,
+ auth_provider_session_id: Optional[str] = None,
) -> bool:
"""Ensure the given device is known; add it to the store if not
@@ -1079,6 +1105,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_id: id of device
initial_device_display_name: initial displayname of the device.
Ignored if device exists.
+ auth_provider_id: The SSO IdP the user used, if any.
+ auth_provider_session_id: The session ID (sid) got from a OIDC login.
Returns:
Whether the device was inserted or an existing device existed with that ID.
@@ -1115,6 +1143,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if hidden:
raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
+ if auth_provider_id and auth_provider_session_id:
+ await self.db_pool.simple_insert(
+ "device_auth_providers",
+ values={
+ "user_id": user_id,
+ "device_id": device_id,
+ "auth_provider_id": auth_provider_id,
+ "auth_provider_session_id": auth_provider_session_id,
+ },
+ desc="store_device_auth_provider",
+ )
+
self.device_id_exists_cache.set(key, True)
return inserted
except StoreError:
@@ -1168,6 +1208,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
keyvalues={"user_id": user_id},
)
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="device_auth_providers",
+ column="device_id",
+ values=device_ids,
+ keyvalues={"user_id": user_id},
+ )
+
await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id))
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ef5d1ef0..9580a407 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1552,9 +1552,9 @@ class EventFederationStore(EventFederationWorkerStore):
DELETE FROM event_auth
WHERE event_id IN (
SELECT event_id FROM events
- LEFT JOIN state_events USING (room_id, event_id)
+ LEFT JOIN state_events AS se USING (room_id, event_id)
WHERE ? <= stream_ordering AND stream_ordering < ?
- AND state_key IS null
+ AND se.state_key IS null
)
"""
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index d957e770..3efdd0c9 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -16,6 +16,7 @@ import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import attr
+from typing_extensions import TypedDict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -37,6 +38,20 @@ DEFAULT_HIGHLIGHT_ACTION = [
]
+class BasePushAction(TypedDict):
+ event_id: str
+ actions: List[Union[dict, str]]
+
+
+class HttpPushAction(BasePushAction):
+ room_id: str
+ stream_ordering: int
+
+
+class EmailPushAction(HttpPushAction):
+ received_ts: Optional[int]
+
+
def _serialize_action(actions, is_highlight):
"""Custom serializer for actions. This allows us to "compress" common actions.
@@ -221,7 +236,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
min_stream_ordering: int,
max_stream_ordering: int,
limit: int = 20,
- ) -> List[dict]:
+ ) -> List[HttpPushAction]:
"""Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the httppusher.
@@ -326,7 +341,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
min_stream_ordering: int,
max_stream_ordering: int,
limit: int = 20,
- ) -> List[dict]:
+ ) -> List[EmailPushAction]:
"""Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the emailpusher
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 06832221..4e528612 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -15,7 +15,7 @@
# limitations under the License.
import itertools
import logging
-from collections import OrderedDict, namedtuple
+from collections import OrderedDict
from typing import (
TYPE_CHECKING,
Any,
@@ -41,9 +41,10 @@ from synapse.events.snapshot import EventContext # noqa: F401
from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util import json_encoder
@@ -64,9 +65,6 @@ event_counter = Counter(
)
-_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
-
-
@attr.s(slots=True)
class DeltaState:
"""Deltas to use to update the `current_state_events` table.
@@ -108,23 +106,30 @@ class PersistEventsStore:
self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
- # Ideally we'd move these ID gens here, unfortunately some other ID
- # generators are chained off them so doing so is a bit of a PITA.
- self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
- self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
-
# This should only exist on instances that are configured to write
assert (
hs.get_instance_name() in hs.config.worker.writers.events
), "Can only instantiate EventsStore on master"
+ # Since we have been configured to write, we ought to have id generators,
+ # rather than id trackers.
+ assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator)
+ assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator)
+
+ # Ideally we'd move these ID gens here, unfortunately some other ID
+ # generators are chained off them so doing so is a bit of a PITA.
+ self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
+ self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
+
async def _persist_events_and_state_updates(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
+ *,
current_state_for_room: Dict[str, StateMap[str]],
state_delta_for_room: Dict[str, DeltaState],
new_forward_extremeties: Dict[str, List[str]],
- backfilled: bool = False,
+ use_negative_stream_ordering: bool = False,
+ inhibit_local_membership_updates: bool = False,
) -> None:
"""Persist a set of events alongside updates to the current state and
forward extremities tables.
@@ -137,7 +142,14 @@ class PersistEventsStore:
room state
new_forward_extremities: Map from room_id to list of event IDs
that are the new forward extremities of the room.
- backfilled
+ use_negative_stream_ordering: Whether to start stream_ordering on
+ the negative side and decrement. This should be set as True
+ for backfilled events because backfilled events get a negative
+ stream ordering so they don't come down incremental `/sync`.
+ inhibit_local_membership_updates: Stop the local_current_membership
+ from being updated by these events. This should be set to True
+ for backfilled events because backfilled events in the past do
+ not affect the current local state.
Returns:
Resolves when the events have been persisted
@@ -159,7 +171,7 @@ class PersistEventsStore:
#
# Note: Multiple instances of this function cannot be in flight at
# the same time for the same room.
- if backfilled:
+ if use_negative_stream_ordering:
stream_ordering_manager = self._backfill_id_gen.get_next_mult(
len(events_and_contexts)
)
@@ -176,13 +188,13 @@ class PersistEventsStore:
"persist_events",
self._persist_events_txn,
events_and_contexts=events_and_contexts,
- backfilled=backfilled,
+ inhibit_local_membership_updates=inhibit_local_membership_updates,
state_delta_for_room=state_delta_for_room,
new_forward_extremeties=new_forward_extremeties,
)
persist_event_counter.inc(len(events_and_contexts))
- if not backfilled:
+ if stream < 0:
# backfilled events have negative stream orderings, so we don't
# want to set the event_persisted_position to that.
synapse.metrics.event_persisted_position.set(
@@ -316,8 +328,9 @@ class PersistEventsStore:
def _persist_events_txn(
self,
txn: LoggingTransaction,
+ *,
events_and_contexts: List[Tuple[EventBase, EventContext]],
- backfilled: bool,
+ inhibit_local_membership_updates: bool = False,
state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
):
@@ -330,7 +343,10 @@ class PersistEventsStore:
Args:
txn
events_and_contexts: events to persist
- backfilled: True if the events were backfilled
+ inhibit_local_membership_updates: Stop the local_current_membership
+ from being updated by these events. This should be set to True
+ for backfilled events because backfilled events in the past do
+ not affect the current local state.
delete_existing True to purge existing table rows for the events
from the database. This is useful when retrying due to
IntegrityError.
@@ -363,9 +379,7 @@ class PersistEventsStore:
events_and_contexts
)
- self._update_room_depths_txn(
- txn, events_and_contexts=events_and_contexts, backfilled=backfilled
- )
+ self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts)
# _update_outliers_txn filters out any events which have already been
# persisted, and returns the filtered list.
@@ -398,7 +412,7 @@ class PersistEventsStore:
txn,
events_and_contexts=events_and_contexts,
all_events_and_contexts=all_events_and_contexts,
- backfilled=backfilled,
+ inhibit_local_membership_updates=inhibit_local_membership_updates,
)
# We call this last as it assumes we've inserted the events into
@@ -561,9 +575,9 @@ class PersistEventsStore:
# fetch their auth event info.
while missing_auth_chains:
sql = """
- SELECT event_id, events.type, state_key, chain_id, sequence_number
+ SELECT event_id, events.type, se.state_key, chain_id, sequence_number
FROM events
- INNER JOIN state_events USING (event_id)
+ INNER JOIN state_events AS se USING (event_id)
LEFT JOIN event_auth_chains USING (event_id)
WHERE
"""
@@ -1200,7 +1214,6 @@ class PersistEventsStore:
self,
txn,
events_and_contexts: List[Tuple[EventBase, EventContext]],
- backfilled: bool,
):
"""Update min_depth for each room
@@ -1208,13 +1221,18 @@ class PersistEventsStore:
txn (twisted.enterprise.adbapi.Connection): db connection
events_and_contexts (list[(EventBase, EventContext)]): events
we are persisting
- backfilled (bool): True if the events were backfilled
"""
depth_updates: Dict[str, int] = {}
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
- if not backfilled:
+ # Then update the `stream_ordering` position to mark the latest
+ # event as the front of the room. This should not be done for
+ # backfilled events because backfilled events have negative
+ # stream_ordering and happened in the past so we know that we don't
+ # need to update the stream_ordering tip/front for the room.
+ assert event.internal_metadata.stream_ordering is not None
+ if event.internal_metadata.stream_ordering >= 0:
txn.call_after(
self.store._events_stream_cache.entity_has_changed,
event.room_id,
@@ -1427,7 +1445,12 @@ class PersistEventsStore:
return [ec for ec in events_and_contexts if ec[0] not in to_remove]
def _update_metadata_tables_txn(
- self, txn, events_and_contexts, all_events_and_contexts, backfilled
+ self,
+ txn,
+ *,
+ events_and_contexts,
+ all_events_and_contexts,
+ inhibit_local_membership_updates: bool = False,
):
"""Update all the miscellaneous tables for new events
@@ -1439,7 +1462,10 @@ class PersistEventsStore:
events that we were going to persist. This includes events
we've already persisted, etc, that wouldn't appear in
events_and_context.
- backfilled (bool): True if the events were backfilled
+ inhibit_local_membership_updates: Stop the local_current_membership
+ from being updated by these events. This should be set to True
+ for backfilled events because backfilled events in the past do
+ not affect the current local state.
"""
# Insert all the push actions into the event_push_actions table.
@@ -1513,7 +1539,7 @@ class PersistEventsStore:
for event, _ in events_and_contexts
if event.type == EventTypes.Member
],
- backfilled=backfilled,
+ inhibit_local_membership_updates=inhibit_local_membership_updates,
)
# Insert event_reference_hashes table.
@@ -1553,11 +1579,13 @@ class PersistEventsStore:
for row in rows:
event = ev_map[row["event_id"]]
if not row["rejects"] and not row["redacts"]:
- to_prefill.append(_EventCacheEntry(event=event, redacted_event=None))
+ to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
def prefill():
for cache_entry in to_prefill:
- self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
+ self.store._get_event_cache.set(
+ (cache_entry.event.event_id,), cache_entry
+ )
txn.call_after(prefill)
@@ -1638,8 +1666,19 @@ class PersistEventsStore:
txn, table="event_reference_hashes", values=vals
)
- def _store_room_members_txn(self, txn, events, backfilled):
- """Store a room member in the database."""
+ def _store_room_members_txn(
+ self, txn, events, *, inhibit_local_membership_updates: bool = False
+ ):
+ """
+ Store a room member in the database.
+ Args:
+ txn: The transaction to use.
+ events: List of events to store.
+ inhibit_local_membership_updates: Stop the local_current_membership
+ from being updated by these events. This should be set to True
+ for backfilled events because backfilled events in the past do
+ not affect the current local state.
+ """
def non_null_str_or_none(val: Any) -> Optional[str]:
return val if isinstance(val, str) and "\u0000" not in val else None
@@ -1682,7 +1721,7 @@ class PersistEventsStore:
# band membership", like a remote invite or a rejection of a remote invite.
if (
self.is_mine_id(event.state_key)
- and not backfilled
+ and not inhibit_local_membership_updates
and event.internal_metadata.is_outlier()
and event.internal_metadata.is_out_of_band_membership()
):
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c6bf316d..c7b660ac 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -15,14 +15,18 @@
import logging
import threading
from typing import (
+ TYPE_CHECKING,
+ Any,
Collection,
Container,
Dict,
Iterable,
List,
+ NoReturn,
Optional,
Set,
Tuple,
+ cast,
overload,
)
@@ -38,6 +42,7 @@ from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
+ RoomVersion,
RoomVersions,
)
from synapse.events import EventBase, make_event_from_dict
@@ -56,10 +61,18 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdTracker,
+ MultiWriterIdGenerator,
+ StreamIdGenerator,
+)
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
@@ -69,10 +82,13 @@ from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
-# These values are used in the `enqueus_event` and `_do_fetch` methods to
+# These values are used in the `enqueue_event` and `_fetch_loop` methods to
# control how we batch/bulk fetch events from the database.
# The values are plucked out of thing air to make initial sync run faster
# on jki.re
@@ -89,7 +105,7 @@ event_fetch_ongoing_gauge = Gauge(
@attr.s(slots=True, auto_attribs=True)
-class _EventCacheEntry:
+class EventCacheEntry:
event: EventBase
redacted_event: Optional[EventBase]
@@ -129,7 +145,7 @@ class _EventRow:
json: str
internal_metadata: str
format_version: Optional[int]
- room_version_id: Optional[int]
+ room_version_id: Optional[str]
rejected_reason: Optional[str]
redactions: List[str]
outlier: bool
@@ -153,9 +169,16 @@ class EventsWorkerStore(SQLBaseStore):
# options controlling this.
USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
+ self._stream_id_gen: AbstractStreamIdTracker
+ self._backfill_id_gen: AbstractStreamIdTracker
if isinstance(database.engine, PostgresEngine):
# If we're using Postgres than we can use `MultiWriterIdGenerator`
# regardless of whether this process writes to the streams or not.
@@ -214,7 +237,7 @@ class EventsWorkerStore(SQLBaseStore):
5 * 60 * 1000,
)
- self._get_event_cache = LruCache(
+ self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
cache_name="*getEvent*",
max_size=hs.config.caches.event_cache_size,
)
@@ -223,19 +246,21 @@ class EventsWorkerStore(SQLBaseStore):
# ID to cache entry. Note that the returned dict may not have the
# requested event in it if the event isn't in the DB.
self._current_event_fetches: Dict[
- str, ObservableDeferred[Dict[str, _EventCacheEntry]]
+ str, ObservableDeferred[Dict[str, EventCacheEntry]]
] = {}
self._event_fetch_lock = threading.Condition()
- self._event_fetch_list = []
+ self._event_fetch_list: List[
+ Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
+ ] = []
self._event_fetch_ongoing = 0
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
# We define this sequence here so that it can be referenced from both
# the DataStore and PersistEventStore.
- def get_chain_id_txn(txn):
+ def get_chain_id_txn(txn: Cursor) -> int:
txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
- return txn.fetchone()[0]
+ return cast(Tuple[int], txn.fetchone())[0]
self.event_chain_id_gen = build_sequence_generator(
db_conn,
@@ -246,7 +271,13 @@ class EventsWorkerStore(SQLBaseStore):
id_column="chain_id",
)
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: int,
+ rows: Iterable[Any],
+ ) -> None:
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(instance_name, token)
elif stream_name == BackfillStream.NAME:
@@ -280,10 +311,10 @@ class EventsWorkerStore(SQLBaseStore):
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
- get_prev_content: bool = False,
- allow_rejected: bool = False,
- allow_none: Literal[False] = False,
- check_room_id: Optional[str] = None,
+ get_prev_content: bool = ...,
+ allow_rejected: bool = ...,
+ allow_none: Literal[False] = ...,
+ check_room_id: Optional[str] = ...,
) -> EventBase:
...
@@ -292,10 +323,10 @@ class EventsWorkerStore(SQLBaseStore):
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
- get_prev_content: bool = False,
- allow_rejected: bool = False,
- allow_none: Literal[True] = False,
- check_room_id: Optional[str] = None,
+ get_prev_content: bool = ...,
+ allow_rejected: bool = ...,
+ allow_none: Literal[True] = ...,
+ check_room_id: Optional[str] = ...,
) -> Optional[EventBase]:
...
@@ -357,7 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_events(
self,
- event_ids: Iterable[str],
+ event_ids: Collection[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
@@ -544,7 +575,7 @@ class EventsWorkerStore(SQLBaseStore):
async def _get_events_from_cache_or_db(
self, event_ids: Iterable[str], allow_rejected: bool = False
- ) -> Dict[str, _EventCacheEntry]:
+ ) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the cache or the database.
If events are pulled from the database, they will be cached for future lookups.
@@ -578,7 +609,7 @@ class EventsWorkerStore(SQLBaseStore):
# same dict into itself N times).
already_fetching_ids: Set[str] = set()
already_fetching_deferreds: Set[
- ObservableDeferred[Dict[str, _EventCacheEntry]]
+ ObservableDeferred[Dict[str, EventCacheEntry]]
] = set()
for event_id in missing_events_ids:
@@ -601,8 +632,8 @@ class EventsWorkerStore(SQLBaseStore):
# function returning more events than requested, but that can happen
# already due to `_get_events_from_db`).
fetching_deferred: ObservableDeferred[
- Dict[str, _EventCacheEntry]
- ] = ObservableDeferred(defer.Deferred())
+ Dict[str, EventCacheEntry]
+ ] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
for event_id in missing_events_ids:
self._current_event_fetches[event_id] = fetching_deferred
@@ -658,12 +689,12 @@ class EventsWorkerStore(SQLBaseStore):
return event_entry_map
- def _invalidate_get_event_cache(self, event_id):
+ def _invalidate_get_event_cache(self, event_id: str) -> None:
self._get_event_cache.invalidate((event_id,))
def _get_events_from_cache(
self, events: Iterable[str], update_metrics: bool = True
- ) -> Dict[str, _EventCacheEntry]:
+ ) -> Dict[str, EventCacheEntry]:
"""Fetch events from the caches.
May return rejected events.
@@ -736,38 +767,123 @@ class EventsWorkerStore(SQLBaseStore):
for e in state_to_include.values()
]
- def _do_fetch(self, conn: Connection) -> None:
+ def _maybe_start_fetch_thread(self) -> None:
+ """Starts an event fetch thread if we are not yet at the maximum number."""
+ with self._event_fetch_lock:
+ if (
+ self._event_fetch_list
+ and self._event_fetch_ongoing < EVENT_QUEUE_THREADS
+ ):
+ self._event_fetch_ongoing += 1
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+ # `_event_fetch_ongoing` is decremented in `_fetch_thread`.
+ should_start = True
+ else:
+ should_start = False
+
+ if should_start:
+ run_as_background_process("fetch_events", self._fetch_thread)
+
+ async def _fetch_thread(self) -> None:
+ """Services requests for events from `_event_fetch_list`."""
+ exc = None
+ try:
+ await self.db_pool.runWithConnection(self._fetch_loop)
+ except BaseException as e:
+ exc = e
+ raise
+ finally:
+ should_restart = False
+ event_fetches_to_fail = []
+ with self._event_fetch_lock:
+ self._event_fetch_ongoing -= 1
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+
+ # There may still be work remaining in `_event_fetch_list` if we
+ # failed, or it was added in between us deciding to exit and
+ # decrementing `_event_fetch_ongoing`.
+ if self._event_fetch_list:
+ if exc is None:
+ # We decided to exit, but then some more work was added
+ # before `_event_fetch_ongoing` was decremented.
+ # If a new event fetch thread was not started, we should
+ # restart ourselves since the remaining event fetch threads
+ # may take a while to get around to the new work.
+ #
+ # Unfortunately it is not possible to tell whether a new
+ # event fetch thread was started, so we restart
+ # unconditionally. If we are unlucky, we will end up with
+ # an idle fetch thread, but it will time out after
+ # `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds
+ # in any case.
+ #
+ # Note that multiple fetch threads may run down this path at
+ # the same time.
+ should_restart = True
+ elif isinstance(exc, Exception):
+ if self._event_fetch_ongoing == 0:
+ # We were the last remaining fetcher and failed.
+ # Fail any outstanding fetches since no one else will
+ # handle them.
+ event_fetches_to_fail = self._event_fetch_list
+ self._event_fetch_list = []
+ else:
+ # We weren't the last remaining fetcher, so another
+ # fetcher will pick up the work. This will either happen
+ # after their existing work, however long that takes,
+ # or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if
+ # they are idle.
+ pass
+ else:
+ # The exception is a `SystemExit`, `KeyboardInterrupt` or
+ # `GeneratorExit`. Don't try to do anything clever here.
+ pass
+
+ if should_restart:
+ # We exited cleanly but noticed more work.
+ self._maybe_start_fetch_thread()
+
+ if event_fetches_to_fail:
+ # We were the last remaining fetcher and failed.
+ # Fail any outstanding fetches since no one else will handle them.
+ assert exc is not None
+ with PreserveLoggingContext():
+ for _, deferred in event_fetches_to_fail:
+ deferred.errback(exc)
+
+ def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None:
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
- try:
- i = 0
- while True:
- with self._event_fetch_lock:
- event_list = self._event_fetch_list
- self._event_fetch_list = []
-
- if not event_list:
- single_threaded = self.database_engine.single_threaded
- if (
- not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
- or single_threaded
- or i > EVENT_QUEUE_ITERATIONS
- ):
- break
- else:
- self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
- i += 1
- continue
- i = 0
+ i = 0
+ while True:
+ with self._event_fetch_lock:
+ event_list = self._event_fetch_list
+ self._event_fetch_list = []
+
+ if not event_list:
+ # There are no requests waiting. If we haven't yet reached the
+ # maximum iteration limit, wait for some more requests to turn up.
+ # Otherwise, bail out.
+ single_threaded = self.database_engine.single_threaded
+ if (
+ not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+ or single_threaded
+ or i > EVENT_QUEUE_ITERATIONS
+ ):
+ return
+
+ self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+ i += 1
+ continue
+ i = 0
- self._fetch_event_list(conn, event_list)
- finally:
- self._event_fetch_ongoing -= 1
- event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+ self._fetch_event_list(conn, event_list)
def _fetch_event_list(
- self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
+ self,
+ conn: LoggingDatabaseConnection,
+ event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]],
) -> None:
"""Handle a load of requests from the _event_fetch_list queue
@@ -794,7 +910,7 @@ class EventsWorkerStore(SQLBaseStore):
)
# We only want to resolve deferreds from the main thread
- def fire():
+ def fire() -> None:
for _, d in event_list:
d.callback(row_dict)
@@ -804,18 +920,16 @@ class EventsWorkerStore(SQLBaseStore):
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
- def fire(evs, exc):
- for _, d in evs:
- if not d.called:
- with PreserveLoggingContext():
- d.errback(exc)
+ def fire_errback(exc: Exception) -> None:
+ for _, d in event_list:
+ d.errback(exc)
with PreserveLoggingContext():
- self.hs.get_reactor().callFromThread(fire, event_list, e)
+ self.hs.get_reactor().callFromThread(fire_errback, e)
async def _get_events_from_db(
- self, event_ids: Iterable[str]
- ) -> Dict[str, _EventCacheEntry]:
+ self, event_ids: Collection[str]
+ ) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the database.
May return rejected events.
@@ -831,29 +945,29 @@ class EventsWorkerStore(SQLBaseStore):
map from event id to result. May return extra events which
weren't asked for.
"""
- fetched_events = {}
+ fetched_event_ids: Set[str] = set()
+ fetched_events: Dict[str, _EventRow] = {}
events_to_fetch = event_ids
while events_to_fetch:
row_map = await self._enqueue_events(events_to_fetch)
# we need to recursively fetch any redactions of those events
- redaction_ids = set()
+ redaction_ids: Set[str] = set()
for event_id in events_to_fetch:
row = row_map.get(event_id)
- fetched_events[event_id] = row
+ fetched_event_ids.add(event_id)
if row:
+ fetched_events[event_id] = row
redaction_ids.update(row.redactions)
- events_to_fetch = redaction_ids.difference(fetched_events.keys())
+ events_to_fetch = redaction_ids.difference(fetched_event_ids)
if events_to_fetch:
logger.debug("Also fetching redaction events %s", events_to_fetch)
# build a map from event_id to EventBase
- event_map = {}
+ event_map: Dict[str, EventBase] = {}
for event_id, row in fetched_events.items():
- if not row:
- continue
assert row.event_id == event_id
rejected_reason = row.rejected_reason
@@ -881,6 +995,7 @@ class EventsWorkerStore(SQLBaseStore):
room_version_id = row.room_version_id
+ room_version: Optional[RoomVersion]
if not room_version_id:
# this should only happen for out-of-band membership events which
# arrived before #6983 landed. For all other events, we should have
@@ -951,14 +1066,14 @@ class EventsWorkerStore(SQLBaseStore):
# finally, we can decide whether each one needs redacting, and build
# the cache entries.
- result_map = {}
+ result_map: Dict[str, EventCacheEntry] = {}
for event_id, original_ev in event_map.items():
redactions = fetched_events[event_id].redactions
redacted_event = self._maybe_redact_event_row(
original_ev, redactions, event_map
)
- cache_entry = _EventCacheEntry(
+ cache_entry = EventCacheEntry(
event=original_ev, redacted_event=redacted_event
)
@@ -967,7 +1082,7 @@ class EventsWorkerStore(SQLBaseStore):
return result_map
- async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
+ async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
@@ -980,23 +1095,12 @@ class EventsWorkerStore(SQLBaseStore):
that weren't requested.
"""
- events_d = defer.Deferred()
+ events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred()
with self._event_fetch_lock:
self._event_fetch_list.append((events, events_d))
-
self._event_fetch_lock.notify()
- if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
- self._event_fetch_ongoing += 1
- event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
- should_start = True
- else:
- should_start = False
-
- if should_start:
- run_as_background_process(
- "fetch_events", self.db_pool.runWithConnection, self._do_fetch
- )
+ self._maybe_start_fetch_thread()
logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
@@ -1146,7 +1250,7 @@ class EventsWorkerStore(SQLBaseStore):
# no valid redaction found for this event
return None
- async def have_events_in_timeline(self, event_ids):
+ async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]:
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
@@ -1175,7 +1279,7 @@ class EventsWorkerStore(SQLBaseStore):
event_ids: events we are looking for
Returns:
- set[str]: The events we have already seen.
+ The set of events we have already seen.
"""
res = await self._have_seen_events_dict(
(room_id, event_id) for event_id in event_ids
@@ -1198,7 +1302,9 @@ class EventsWorkerStore(SQLBaseStore):
}
results = {x: True for x in cache_results}
- def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]):
+ def have_seen_events_txn(
+ txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...]
+ ) -> None:
# we deliberately do *not* query the database for room_id, to make the
# query an index-only lookup on `events_event_id_key`.
#
@@ -1224,12 +1330,14 @@ class EventsWorkerStore(SQLBaseStore):
return results
@cached(max_entries=100000, tree=True)
- async def have_seen_event(self, room_id: str, event_id: str):
+ async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
# this only exists for the benefit of the @cachedList descriptor on
# _have_seen_events_dict
raise NotImplementedError()
- def _get_current_state_event_counts_txn(self, txn, room_id):
+ def _get_current_state_event_counts_txn(
+ self, txn: LoggingTransaction, room_id: str
+ ) -> int:
"""
See get_current_state_event_counts.
"""
@@ -1254,7 +1362,7 @@ class EventsWorkerStore(SQLBaseStore):
room_id,
)
- async def get_room_complexity(self, room_id):
+ async def get_room_complexity(self, room_id: str) -> Dict[str, float]:
"""
Get a rough approximation of the complexity of the room. This is used by
remote servers to decide whether they wish to join the room or not.
@@ -1262,10 +1370,10 @@ class EventsWorkerStore(SQLBaseStore):
more resources.
Args:
- room_id (str)
+ room_id: The room ID to query.
Returns:
- dict[str:int] of complexity version to complexity.
+ dict[str:float] of complexity version to complexity.
"""
state_events = await self.get_current_state_event_counts(room_id)
@@ -1275,13 +1383,13 @@ class EventsWorkerStore(SQLBaseStore):
return {"v1": complexity_v1}
- def get_current_events_token(self):
+ def get_current_events_token(self) -> int:
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
async def get_all_new_forward_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> List[Tuple]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
"""Returns new events, for the Events replication stream
Args:
@@ -1295,13 +1403,15 @@ class EventsWorkerStore(SQLBaseStore):
EventsStreamRow.
"""
- def get_all_new_forward_event_rows(txn):
+ def get_all_new_forward_event_rows(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+ " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN state_events AS se USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" LEFT JOIN room_memberships USING (event_id)"
" LEFT JOIN rejections USING (event_id)"
@@ -1311,7 +1421,9 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (last_id, current_id, instance_name, limit))
- return txn.fetchall()
+ return cast(
+ List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+ )
return await self.db_pool.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
@@ -1319,7 +1431,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_ex_outlier_stream_rows(
self, instance_name: str, last_id: int, current_id: int
- ) -> List[Tuple]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
"""Returns de-outliered events, for the Events replication stream
Args:
@@ -1332,14 +1444,16 @@ class EventsWorkerStore(SQLBaseStore):
EventsStreamRow.
"""
- def get_ex_outlier_stream_rows_txn(txn):
+ def get_ex_outlier_stream_rows_txn(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+ " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN state_events AS se USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" LEFT JOIN room_memberships USING (event_id)"
" LEFT JOIN rejections USING (event_id)"
@@ -1350,7 +1464,9 @@ class EventsWorkerStore(SQLBaseStore):
)
txn.execute(sql, (last_id, current_id, instance_name))
- return txn.fetchall()
+ return cast(
+ List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+ )
return await self.db_pool.runInteraction(
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
@@ -1358,7 +1474,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_new_backfill_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> Tuple[List[Tuple[int, list]], int, bool]:
+ ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
"""Get updates for backfill replication stream, including all new
backfilled events and events that have gone from being outliers to not.
@@ -1386,13 +1502,15 @@ class EventsWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
- def get_all_new_backfill_event_rows(txn):
+ def get_all_new_backfill_event_rows(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
+ " se.state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN state_events AS se USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
" AND instance_name = ?"
@@ -1400,7 +1518,15 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (-last_id, -current_id, instance_name, limit))
- new_event_updates = [(row[0], row[1:]) for row in txn]
+ new_event_updates: List[
+ Tuple[int, Tuple[str, str, str, str, str, str]]
+ ] = []
+ row: Tuple[int, str, str, str, str, str, str]
+ # Type safety: iterating over `txn` yields `Tuple`, i.e.
+ # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+ # variadic tuple to a fixed length tuple and flags it up as an error.
+ for row in txn: # type: ignore[assignment]
+ new_event_updates.append((row[0], row[1:]))
limited = False
if len(new_event_updates) == limit:
@@ -1411,11 +1537,11 @@ class EventsWorkerStore(SQLBaseStore):
sql = (
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
+ " se.state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN state_events AS se USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
@@ -1423,7 +1549,11 @@ class EventsWorkerStore(SQLBaseStore):
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_id, -upper_bound, instance_name))
- new_event_updates.extend((row[0], row[1:]) for row in txn)
+ # Type safety: iterating over `txn` yields `Tuple`, i.e.
+ # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+ # variadic tuple to a fixed length tuple and flags it up as an error.
+ for row in txn: # type: ignore[assignment]
+ new_event_updates.append((row[0], row[1:]))
if len(new_event_updates) >= limit:
upper_bound = new_event_updates[-1][0]
@@ -1437,7 +1567,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_updated_current_state_deltas(
self, instance_name: str, from_token: int, to_token: int, target_row_count: int
- ) -> Tuple[List[Tuple], int, bool]:
+ ) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]:
"""Fetch updates from current_state_delta_stream
Args:
@@ -1457,7 +1587,9 @@ class EventsWorkerStore(SQLBaseStore):
* `limited` is whether there are more updates to fetch.
"""
- def get_all_updated_current_state_deltas_txn(txn):
+ def get_all_updated_current_state_deltas_txn(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str, str, str]]:
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
@@ -1466,21 +1598,23 @@ class EventsWorkerStore(SQLBaseStore):
ORDER BY stream_id ASC LIMIT ?
"""
txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
- return txn.fetchall()
+ return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
- def get_deltas_for_stream_id_txn(txn, stream_id):
+ def get_deltas_for_stream_id_txn(
+ txn: LoggingTransaction, stream_id: int
+ ) -> List[Tuple[int, str, str, str, str]]:
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE stream_id = ?
"""
txn.execute(sql, [stream_id])
- return txn.fetchall()
+ return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
# we need to make sure that, for every stream id in the results, we get *all*
# the rows with that stream id.
- rows: List[Tuple] = await self.db_pool.runInteraction(
+ rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
)
@@ -1509,14 +1643,14 @@ class EventsWorkerStore(SQLBaseStore):
return rows, to_token, True
- async def is_event_after(self, event_id1, event_id2):
+ async def is_event_after(self, event_id1: str, event_id2: str) -> bool:
"""Returns True if event_id1 is after event_id2 in the stream"""
to_1, so_1 = await self.get_event_ordering(event_id1)
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
@cached(max_entries=5000)
- async def get_event_ordering(self, event_id):
+ async def get_event_ordering(self, event_id: str) -> Tuple[int, int]:
res = await self.db_pool.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
@@ -1539,7 +1673,9 @@ class EventsWorkerStore(SQLBaseStore):
None otherwise.
"""
- def get_next_event_to_expire_txn(txn):
+ def get_next_event_to_expire_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[Tuple[str, int]]:
txn.execute(
"""
SELECT event_id, expiry_ts FROM event_expiry
@@ -1547,7 +1683,7 @@ class EventsWorkerStore(SQLBaseStore):
"""
)
- return txn.fetchone()
+ return cast(Optional[Tuple[str, int]], txn.fetchone())
return await self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
@@ -1611,10 +1747,10 @@ class EventsWorkerStore(SQLBaseStore):
return mapping
@wrap_as_background_process("_cleanup_old_transaction_ids")
- async def _cleanup_old_transaction_ids(self):
+ async def _cleanup_old_transaction_ids(self) -> None:
"""Cleans out transaction id mappings older than 24hrs."""
- def _cleanup_old_transaction_ids_txn(txn):
+ def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
sql = """
DELETE FROM event_txn_id
WHERE inserted_ts < ?
@@ -1626,3 +1762,198 @@ class EventsWorkerStore(SQLBaseStore):
"_cleanup_old_transaction_ids",
_cleanup_old_transaction_ids_txn,
)
+
+ async def is_event_next_to_backward_gap(self, event: EventBase) -> bool:
+ """Check if the given event is next to a backward gap of missing events.
+ <latest messages> A(False)--->B(False)--->C(True)---> <gap, unknown events> <oldest messages>
+
+ Args:
+ room_id: room where the event lives
+ event_id: event to check
+
+ Returns:
+ Boolean indicating whether it's an extremity
+ """
+
+ def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool:
+ # If the event in question has any of its prev_events listed as a
+ # backward extremity, it's next to a gap.
+ #
+ # We can't just check the backward edges in `event_edges` because
+ # when we persist events, we will also record the prev_events as
+ # edges to the event in question regardless of whether we have those
+ # prev_events yet. We need to check whether those prev_events are
+ # backward extremities, also known as gaps, that need to be
+ # backfilled.
+ backward_extremity_query = """
+ SELECT 1 FROM event_backward_extremities
+ WHERE
+ room_id = ?
+ AND %s
+ LIMIT 1
+ """
+
+ # If the event in question is a backward extremity or has any of its
+ # prev_events listed as a backward extremity, it's next to a
+ # backward gap.
+ clause, args = make_in_list_sql_clause(
+ self.database_engine,
+ "event_id",
+ [event.event_id] + list(event.prev_event_ids()),
+ )
+
+ txn.execute(backward_extremity_query % (clause,), [event.room_id] + args)
+ backward_extremities = txn.fetchall()
+
+ # We consider any backward extremity as a backward gap
+ if len(backward_extremities):
+ return True
+
+ return False
+
+ return await self.db_pool.runInteraction(
+ "is_event_next_to_backward_gap_txn",
+ is_event_next_to_backward_gap_txn,
+ )
+
+ async def is_event_next_to_forward_gap(self, event: EventBase) -> bool:
+ """Check if the given event is next to a forward gap of missing events.
+ The gap in front of the latest events is not considered a gap.
+ <latest messages> A(False)--->B(False)--->C(False)---> <gap, unknown events> <oldest messages>
+ <latest messages> A(False)--->B(False)---> <gap, unknown events> --->D(True)--->E(False) <oldest messages>
+
+ Args:
+ room_id: room where the event lives
+ event_id: event to check
+
+ Returns:
+ Boolean indicating whether it's an extremity
+ """
+
+ def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool:
+ # If the event in question is a forward extremity, we will just
+ # consider any potential forward gap as not a gap since it's one of
+ # the latest events in the room.
+ #
+ # `event_forward_extremities` does not include backfilled or outlier
+ # events so we can't rely on it to find forward gaps. We can only
+ # use it to determine whether a message is the latest in the room.
+ #
+ # We can't combine this query with the `forward_edge_query` below
+ # because if the event in question has no forward edges (isn't
+ # referenced by any other event's prev_events) but is in
+ # `event_forward_extremities`, we don't want to return 0 rows and
+ # say it's next to a gap.
+ forward_extremity_query = """
+ SELECT 1 FROM event_forward_extremities
+ WHERE
+ room_id = ?
+ AND event_id = ?
+ LIMIT 1
+ """
+
+ # Check to see whether the event in question is already referenced
+ # by another event. If we don't see any edges, we're next to a
+ # forward gap.
+ forward_edge_query = """
+ SELECT 1 FROM event_edges
+ /* Check to make sure the event referencing our event in question is not rejected */
+ LEFT JOIN rejections ON event_edges.event_id == rejections.event_id
+ WHERE
+ event_edges.room_id = ?
+ AND event_edges.prev_event_id = ?
+ /* It's not a valid edge if the event referencing our event in
+ * question is rejected.
+ */
+ AND rejections.event_id IS NULL
+ LIMIT 1
+ """
+
+ # We consider any forward extremity as the latest in the room and
+ # not a forward gap.
+ #
+ # To expand, even though there is technically a gap at the front of
+ # the room where the forward extremities are, we consider those the
+ # latest messages in the room so asking other homeservers for more
+ # is useless. The new latest messages will just be federated as
+ # usual.
+ txn.execute(forward_extremity_query, (event.room_id, event.event_id))
+ forward_extremities = txn.fetchall()
+ if len(forward_extremities):
+ return False
+
+ # If there are no forward edges to the event in question (another
+ # event hasn't referenced this event in their prev_events), then we
+ # assume there is a forward gap in the history.
+ txn.execute(forward_edge_query, (event.room_id, event.event_id))
+ forward_edges = txn.fetchall()
+ if not len(forward_edges):
+ return True
+
+ return False
+
+ return await self.db_pool.runInteraction(
+ "is_event_next_to_gap_txn",
+ is_event_next_to_gap_txn,
+ )
+
+ async def get_event_id_for_timestamp(
+ self, room_id: str, timestamp: int, direction: str
+ ) -> Optional[str]:
+ """Find the closest event to the given timestamp in the given direction.
+
+ Args:
+ room_id: Room to fetch the event from
+ timestamp: The point in time (inclusive) we should navigate from in
+ the given direction to find the closest event.
+ direction: ["f"|"b"] to indicate whether we should navigate forward
+ or backward from the given timestamp to find the closest event.
+
+ Returns:
+ The closest event_id otherwise None if we can't find any event in
+ the given direction.
+ """
+
+ sql_template = """
+ SELECT event_id FROM events
+ LEFT JOIN rejections USING (event_id)
+ WHERE
+ origin_server_ts %s ?
+ AND room_id = ?
+ /* Make sure event is not rejected */
+ AND rejections.event_id IS NULL
+ ORDER BY origin_server_ts %s
+ LIMIT 1;
+ """
+
+ def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]:
+ if direction == "b":
+ # Find closest event *before* a given timestamp. We use descending
+ # (which gives values largest to smallest) because we want the
+ # largest possible timestamp *before* the given timestamp.
+ comparison_operator = "<="
+ order = "DESC"
+ else:
+ # Find closest event *after* a given timestamp. We use ascending
+ # (which gives values smallest to largest) because we want the
+ # closest possible timestamp *after* the given timestamp.
+ comparison_operator = ">="
+ order = "ASC"
+
+ txn.execute(
+ sql_template % (comparison_operator, order), (timestamp, room_id)
+ )
+ row = txn.fetchone()
+ if row:
+ (event_id,) = row
+ return event_id
+
+ return None
+
+ if direction not in ("f", "b"):
+ raise ValueError("Unknown direction: %s" % (direction,))
+
+ return await self.db_pool.runInteraction(
+ "get_event_id_for_timestamp_txn",
+ get_event_id_for_timestamp_txn,
+ )
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 3eb30944..91b0576b 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -118,7 +118,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
logger.info("[purge] looking for events to delete")
- should_delete_expr = "state_key IS NULL"
+ should_delete_expr = "state_events.state_key IS NULL"
should_delete_params: Tuple[Any, ...] = ()
if not delete_local_events:
should_delete_expr += " AND event_id NOT LIKE ?"
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index fa782023..3b632673 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -28,7 +28,10 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdTracker,
+ StreamIdGenerator,
+)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -82,9 +85,9 @@ class PushRulesWorkerStore(
super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
- self._push_rules_stream_id_gen: Union[
- StreamIdGenerator, SlavedIdTracker
- ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
+ self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+ db_conn, "push_rules_stream", "stream_id"
+ )
else:
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id"
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 0e8c1686..e1ddf069 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -106,6 +106,15 @@ class RefreshTokenLookupResult:
has_next_access_token_been_used: bool
"""True if the next access token was already used at least once."""
+ expiry_ts: Optional[int]
+ """The time at which the refresh token expires and can not be used.
+ If None, the refresh token doesn't expire."""
+
+ ultimate_session_expiry_ts: Optional[int]
+ """The time at which the session comes to an end and can no longer be
+ refreshed.
+ If None, the session can be refreshed indefinitely."""
+
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
@@ -1626,8 +1635,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
rt.user_id,
rt.device_id,
rt.next_token_id,
- (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed,
- at.used has_next_access_token_been_used
+ (nrt.next_token_id IS NOT NULL) AS has_next_refresh_token_been_refreshed,
+ at.used AS has_next_access_token_been_used,
+ rt.expiry_ts,
+ rt.ultimate_session_expiry_ts
FROM refresh_tokens rt
LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id
LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id
@@ -1648,6 +1659,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
has_next_refresh_token_been_refreshed=row[4],
# This column is nullable, ensure it's a boolean
has_next_access_token_been_used=(row[5] or False),
+ expiry_ts=row[6],
+ ultimate_session_expiry_ts=row[7],
)
return await self.db_pool.runInteraction(
@@ -1915,6 +1928,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
user_id: str,
token: str,
device_id: Optional[str],
+ expiry_ts: Optional[int],
+ ultimate_session_expiry_ts: Optional[int],
) -> int:
"""Adds a refresh token for the given user.
@@ -1922,6 +1937,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
user_id: The user ID.
token: The new access token to add.
device_id: ID of the device to associate with the refresh token.
+ expiry_ts (milliseconds since the epoch): Time after which the
+ refresh token cannot be used.
+ If None, the refresh token never expires until it has been used.
+ ultimate_session_expiry_ts (milliseconds since the epoch):
+ Time at which the session will end and can not be extended any
+ further.
+ If None, the session can be refreshed indefinitely.
Raises:
StoreError if there was a problem adding this.
Returns:
@@ -1937,6 +1959,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"device_id": device_id,
"token": token,
"next_token_id": None,
+ "expiry_ts": expiry_ts,
+ "ultimate_session_expiry_ts": ultimate_session_expiry_ts,
},
desc="add_refresh_token_to_user",
)
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 033a9831..6b2a8d06 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -476,7 +476,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
INNER JOIN events AS e USING (room_id, event_id)
WHERE
c.type = 'm.room.member'
- AND state_key = ?
+ AND c.state_key = ?
AND c.membership = ?
"""
else:
@@ -487,7 +487,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
INNER JOIN events AS e USING (room_id, event_id)
WHERE
c.type = 'm.room.member'
- AND state_key = ?
+ AND c.state_key = ?
AND m.membership = ?
"""
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 42dc807d..57aab552 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -497,7 +497,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
oldest `limit` events.
Returns:
- The list of events (in ascending order) and the token from the start
+ The list of events (in ascending stream order) and the token from the start
of the chunk of events returned.
"""
if from_key == to_key:
@@ -510,7 +510,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
if not has_changed:
return [], from_key
- def f(txn):
+ def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
# To handle tokens with a non-empty instance_map we fetch more
# results than necessary and then filter down
min_from_id = from_key.stream
@@ -565,6 +565,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
async def get_membership_changes_for_user(
self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
) -> List[EventBase]:
+ """Fetch membership events for a given user.
+
+ All such events whose stream ordering `s` lies in the range
+ `from_key < s <= to_key` are returned. Events are ordered by ascending stream
+ order.
+ """
+ # Start by ruling out cases where a DB query is not necessary.
if from_key == to_key:
return []
@@ -575,7 +582,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
if not has_changed:
return []
- def f(txn):
+ def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
# To handle tokens with a non-empty instance_map we fetch more
# results than necessary and then filter down
min_from_id = from_key.stream
@@ -634,7 +641,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
Returns:
A list of events and a token pointing to the start of the returned
- events. The events returned are in ascending order.
+ events. The events returned are in ascending topological order.
"""
rows, token = await self.get_recent_event_ids_for_room(
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index d7dc1f73..16228225 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -14,6 +14,7 @@
import logging
from collections import namedtuple
+from enum import Enum
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
import attr
@@ -44,6 +45,16 @@ _UpdateTransactionRow = namedtuple(
)
+class DestinationSortOrder(Enum):
+ """Enum to define the sorting method used when returning destinations."""
+
+ DESTINATION = "destination"
+ RETRY_LAST_TS = "retry_last_ts"
+ RETTRY_INTERVAL = "retry_interval"
+ FAILURE_TS = "failure_ts"
+ LAST_SUCCESSFUL_STREAM_ORDERING = "last_successful_stream_ordering"
+
+
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DestinationRetryTimings:
"""The current destination retry timing info for a remote server."""
@@ -480,3 +491,62 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
destinations = [row[0] for row in txn]
return destinations
+
+ async def get_destinations_paginate(
+ self,
+ start: int,
+ limit: int,
+ destination: Optional[str] = None,
+ order_by: str = DestinationSortOrder.DESTINATION.value,
+ direction: str = "f",
+ ) -> Tuple[List[JsonDict], int]:
+ """Function to retrieve a paginated list of destinations.
+ This will return a json list of destinations and the
+ total number of destinations matching the filter criteria.
+
+ Args:
+ start: start number to begin the query from
+ limit: number of rows to retrieve
+ destination: search string in destination
+ order_by: the sort order of the returned list
+ direction: sort ascending or descending
+ Returns:
+ A tuple of a list of mappings from destination to information
+ and a count of total destinations.
+ """
+
+ def get_destinations_paginate_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[JsonDict], int]:
+ order_by_column = DestinationSortOrder(order_by).value
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ args = []
+ where_statement = ""
+ if destination:
+ args.extend(["%" + destination.lower() + "%"])
+ where_statement = "WHERE LOWER(destination) LIKE ?"
+
+ sql_base = f"FROM destinations {where_statement} "
+ sql = f"SELECT COUNT(*) as total_destinations {sql_base}"
+ txn.execute(sql, args)
+ count = txn.fetchone()[0]
+
+ sql = f"""
+ SELECT destination, retry_last_ts, retry_interval, failure_ts,
+ last_successful_stream_ordering
+ {sql_base}
+ ORDER BY {order_by_column} {order}, destination ASC
+ LIMIT ? OFFSET ?
+ """
+ txn.execute(sql, args + [limit, start])
+ destinations = self.db_pool.cursor_to_dict(txn)
+ return destinations, count
+
+ return await self.db_pool.runInteraction(
+ "get_destinations_paginate_txn", get_destinations_paginate_txn
+ )