author     Andrej Shadura <andrewsh@debian.org>  2023-02-19 09:27:32 +0100
committer  Andrej Shadura <andrewsh@debian.org>  2023-02-19 09:27:32 +0100
commit     9480a27c4c98ccaf27bbe363bad0823aee52ed5f (patch)
tree       e9f40bab31df04789556f078d3325c29a3e4af37 /synapse/storage/databases/main
parent     edc94df0f3cbbf133d2f3c8e5b5a93f8acff8f59 (diff)
New upstream version 1.77.0
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--  synapse/storage/databases/main/__init__.py            |   5
-rw-r--r--  synapse/storage/databases/main/account_data.py        | 235
-rw-r--r--  synapse/storage/databases/main/appservice.py          |  14
-rw-r--r--  synapse/storage/databases/main/cache.py               |  12
-rw-r--r--  synapse/storage/databases/main/deviceinbox.py         |  10
-rw-r--r--  synapse/storage/databases/main/devices.py             |  52
-rw-r--r--  synapse/storage/databases/main/end_to_end_keys.py     |   5
-rw-r--r--  synapse/storage/databases/main/event_push_actions.py  |   9
-rw-r--r--  synapse/storage/databases/main/events.py              |   6
-rw-r--r--  synapse/storage/databases/main/events_bg_updates.py   |  12
-rw-r--r--  synapse/storage/databases/main/events_worker.py       | 160
-rw-r--r--  synapse/storage/databases/main/media_repository.py    |   5
-rw-r--r--  synapse/storage/databases/main/presence.py            |  11
-rw-r--r--  synapse/storage/databases/main/push_rule.py           |  13
-rw-r--r--  synapse/storage/databases/main/pusher.py              |   7
-rw-r--r--  synapse/storage/databases/main/receipts.py            |  43
-rw-r--r--  synapse/storage/databases/main/relations.py           |  47
-rw-r--r--  synapse/storage/databases/main/room.py                | 158
-rw-r--r--  synapse/storage/databases/main/roommember.py          |  65
-rw-r--r--  synapse/storage/databases/main/state.py               |  54
-rw-r--r--  synapse/storage/databases/main/stats.py               |  19
-rw-r--r--  synapse/storage/databases/main/stream.py              | 302
-rw-r--r--  synapse/storage/databases/main/tags.py                |  48
-rw-r--r--  synapse/storage/databases/main/transactions.py        |  13
24 files changed, 1006 insertions, 299 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 0e47592b..837dc764 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -17,6 +17,7 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
+from synapse.api.constants import Direction
from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import (
DatabasePool,
@@ -167,7 +168,7 @@ class DataStore(
guests: bool = True,
deactivated: bool = False,
order_by: str = UserSortOrder.NAME.value,
- direction: str = "f",
+ direction: Direction = Direction.FORWARDS,
approved: bool = True,
) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of users from
@@ -197,7 +198,7 @@ class DataStore(
# Set ordering
order_by_column = UserSortOrder(order_by).value
- if direction == "b":
+ if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 07908c41..8a359d7e 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -27,7 +27,7 @@ from typing import (
)
from synapse.api.constants import AccountDataTypes
-from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
+from synapse.replication.tcp.streams import AccountDataStream
from synapse.storage._base import db_to_json
from synapse.storage.database import (
DatabasePool,
@@ -75,6 +75,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
self._account_data_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ notifier=hs.get_replication_notifier(),
stream_name="account_data",
instance_name=self._instance_name,
tables=[
@@ -95,6 +96,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
# SQLite).
self._account_data_id_gen = StreamIdGenerator(
db_conn,
+ hs.get_replication_notifier(),
"room_account_data",
"stream_id",
extra_tables=[("room_tags_revisions", "stream_id")],
@@ -123,7 +125,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
async def get_account_data_for_user(
self, user_id: str
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
- """Get all the client account_data for a user.
+ """
+ Get all the client account_data for a user.
+
+ If experimental MSC3391 support is enabled, any entries with an empty
+ content body are excluded, as this means they have been deleted.
Args:
user_id: The user to get the account_data for.
@@ -135,27 +141,48 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
def get_account_data_for_user_txn(
txn: LoggingTransaction,
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
- rows = self.db_pool.simple_select_list_txn(
- txn,
- "account_data",
- {"user_id": user_id},
- ["account_data_type", "content"],
- )
+ # The 'content != '{}' condition below prevents us from using
+ # `simple_select_list_txn` here, as it doesn't support conditions
+ # other than 'equals'.
+ sql = """
+ SELECT account_data_type, content FROM account_data
+ WHERE user_id = ?
+ """
+
+ # If experimental MSC3391 support is enabled, then account data entries
+ # with an empty content are considered "deleted". So skip adding them to
+ # the results.
+ if self.hs.config.experimental.msc3391_enabled:
+ sql += " AND content != '{}'"
+
+ txn.execute(sql, (user_id,))
+ rows = self.db_pool.cursor_to_dict(txn)
global_account_data = {
row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
- rows = self.db_pool.simple_select_list_txn(
- txn,
- "room_account_data",
- {"user_id": user_id},
- ["room_id", "account_data_type", "content"],
- )
+ # The 'content != '{}' condition below prevents us from using
+ # `simple_select_list_txn` here, as it doesn't support conditions
+ # other than 'equals'.
+ sql = """
+ SELECT room_id, account_data_type, content FROM room_account_data
+ WHERE user_id = ?
+ """
+
+ # If experimental MSC3391 support is enabled, then account data entries
+ # with an empty content are considered "deleted". So skip adding them to
+ # the results.
+ if self.hs.config.experimental.msc3391_enabled:
+ sql += " AND content != '{}'"
+
+ txn.execute(sql, (user_id,))
+ rows = self.db_pool.cursor_to_dict(txn)
by_room: Dict[str, Dict[str, JsonDict]] = {}
for row in rows:
room_data = by_room.setdefault(row["room_id"], {})
+
room_data[row["account_data_type"]] = db_to_json(row["content"])
return global_account_data, by_room
@@ -411,10 +438,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
token: int,
rows: Iterable[Any],
) -> None:
- if stream_name == TagAccountDataStream.NAME:
- self._account_data_id_gen.advance(instance_name, token)
- elif stream_name == AccountDataStream.NAME:
- self._account_data_id_gen.advance(instance_name, token)
+ if stream_name == AccountDataStream.NAME:
for row in rows:
if not row.room_id:
self.get_global_account_data_by_type_for_user.invalidate(
@@ -429,6 +453,13 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
super().process_replication_rows(stream_name, instance_name, token, rows)
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == AccountDataStream.NAME:
+ self._account_data_id_gen.advance(instance_name, token)
+ super().process_replication_position(stream_name, instance_name, token)
+
async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
) -> int:
@@ -469,6 +500,72 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
return self._account_data_id_gen.get_current_token()
+ async def remove_account_data_for_room(
+ self, user_id: str, room_id: str, account_data_type: str
+ ) -> Optional[int]:
+ """Delete the room account data for the user of a given type.
+
+ Args:
+ user_id: The user to remove account_data for.
+ room_id: The room ID to scope the request to.
+ account_data_type: The account data type to delete.
+
+ Returns:
+ The maximum stream position, or None if there was no matching room account
+ data to delete.
+ """
+ assert self._can_write_to_account_data
+ assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
+
+ def _remove_account_data_for_room_txn(
+ txn: LoggingTransaction, next_id: int
+ ) -> bool:
+ """
+ Args:
+ txn: The transaction object.
+ next_id: The stream_id to update any existing rows to.
+
+ Returns:
+ True if an entry in room_account_data had its content set to '{}',
+ otherwise False. This informs callers of whether there actually was an
+ existing room account data entry to delete, or if the call was a no-op.
+ """
+ # We can't use `simple_update` as it doesn't have the ability to specify
+ # where clauses other than '=', which we need for `content != '{}'` below.
+ sql = """
+ UPDATE room_account_data
+ SET stream_id = ?, content = '{}'
+ WHERE user_id = ?
+ AND room_id = ?
+ AND account_data_type = ?
+ AND content != '{}'
+ """
+ txn.execute(
+ sql,
+ (next_id, user_id, room_id, account_data_type),
+ )
+ # Return true if any rows were updated.
+ return txn.rowcount != 0
+
+ async with self._account_data_id_gen.get_next() as next_id:
+ row_updated = await self.db_pool.runInteraction(
+ "remove_account_data_for_room",
+ _remove_account_data_for_room_txn,
+ next_id,
+ )
+
+ if not row_updated:
+ return None
+
+ self._account_data_stream_cache.entity_has_changed(user_id, next_id)
+ self.get_account_data_for_user.invalidate((user_id,))
+ self.get_account_data_for_room.invalidate((user_id, room_id))
+ self.get_account_data_for_room_and_type.prefill(
+ (user_id, room_id, account_data_type), {}
+ )
+
+ return self._account_data_id_gen.get_current_token()
+
async def add_account_data_for_user(
self, user_id: str, account_data_type: str, content: JsonDict
) -> int:
@@ -569,6 +666,108 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
+ async def remove_account_data_for_user(
+ self,
+ user_id: str,
+ account_data_type: str,
+ ) -> Optional[int]:
+ """
+ Delete a single piece of user account data by type.
+
+ A "delete" is performed by updating a potentially existing row in the
+ "account_data" database table for (user_id, account_data_type) and
+ setting its content to "{}".
+
+ Args:
+ user_id: The user ID to modify the account data of.
+ account_data_type: The type to remove.
+
+ Returns:
+ The maximum stream position, or None if there was no matching account data
+ to delete.
+ """
+ assert self._can_write_to_account_data
+ assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
+
+ def _remove_account_data_for_user_txn(
+ txn: LoggingTransaction, next_id: int
+ ) -> bool:
+ """
+ Args:
+ txn: The transaction object.
+ next_id: The stream_id to update any existing rows to.
+
+ Returns:
+ True if an entry in account_data had its content set to '{}', otherwise
+ False. This informs callers of whether there actually was an existing
+ account data entry to delete, or if the call was a no-op.
+ """
+ # We can't use `simple_update` as it doesn't have the ability to specify
+ # where clauses other than '=', which we need for `content != '{}'` below.
+ sql = """
+ UPDATE account_data
+ SET stream_id = ?, content = '{}'
+ WHERE user_id = ?
+ AND account_data_type = ?
+ AND content != '{}'
+ """
+ txn.execute(sql, (next_id, user_id, account_data_type))
+ if txn.rowcount == 0:
+ # We didn't update any rows. This means that there was no matching room
+ # account data entry to delete in the first place.
+ return False
+
+ # Ignored users get denormalized into a separate table as an optimisation.
+ if account_data_type == AccountDataTypes.IGNORED_USER_LIST:
+ # If this method was called with the ignored users account data type, we
+ # simply delete all ignored users.
+
+ # First pull all the users that this user ignores.
+ previously_ignored_users = set(
+ self.db_pool.simple_select_onecol_txn(
+ txn,
+ table="ignored_users",
+ keyvalues={"ignorer_user_id": user_id},
+ retcol="ignored_user_id",
+ )
+ )
+
+ # Then delete them from the database.
+ self.db_pool.simple_delete_txn(
+ txn,
+ table="ignored_users",
+ keyvalues={"ignorer_user_id": user_id},
+ )
+
+ # Invalidate the cache for ignored users which were removed.
+ for ignored_user_id in previously_ignored_users:
+ self._invalidate_cache_and_stream(
+ txn, self.ignored_by, (ignored_user_id,)
+ )
+
+ # Invalidate for this user the cache tracking ignored users.
+ self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
+
+ return True
+
+ async with self._account_data_id_gen.get_next() as next_id:
+ row_updated = await self.db_pool.runInteraction(
+ "remove_account_data_for_user",
+ _remove_account_data_for_user_txn,
+ next_id,
+ )
+
+ if not row_updated:
+ return None
+
+ self._account_data_stream_cache.entity_has_changed(user_id, next_id)
+ self.get_account_data_for_user.invalidate((user_id,))
+ self.get_global_account_data_by_type_for_user.prefill(
+ (user_id, account_data_type), {}
+ )
+
+ return self._account_data_id_gen.get_current_token()
+
async def purge_account_data_for_user(self, user_id: str) -> None:
"""
Removes ALL the account data for a user.
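
Editor's note: the account_data changes above model MSC3391 deletion by setting a row's content to '{}' and filtering such rows out of reads. Below is a standalone sketch of that convention using sqlite3 rather than Synapse's DatabasePool, with the table reduced to the columns involved.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE account_data "
    "(user_id TEXT, account_data_type TEXT, content TEXT, stream_id INTEGER)"
)
conn.execute(
    "INSERT INTO account_data VALUES (?, ?, ?, ?)",
    ("@alice:example.org", "m.push_rules", '{"x": 1}', 1),
)

# "Delete" by blanking the content and bumping the stream id so replication
# and stream caches notice the change.
cur = conn.execute(
    "UPDATE account_data SET stream_id = ?, content = '{}' "
    "WHERE user_id = ? AND account_data_type = ? AND content != '{}'",
    (2, "@alice:example.org", "m.push_rules"),
)
assert cur.rowcount == 1  # rowcount != 0 is how the caller learns a row existed

# Reads skip blanked rows, mirroring the msc3391_enabled branch above.
rows = conn.execute(
    "SELECT account_data_type, content FROM account_data "
    "WHERE user_id = ? AND content != '{}'",
    ("@alice:example.org",),
).fetchall()
assert rows == []
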
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index c2c8018e..5fb152c4 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -14,7 +14,17 @@
# limitations under the License.
import logging
import re
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ List,
+ Optional,
+ Pattern,
+ Sequence,
+ Tuple,
+ cast,
+)
from synapse.appservice import (
ApplicationService,
@@ -257,7 +267,7 @@ class ApplicationServiceTransactionWorkerStore(
async def create_appservice_txn(
self,
service: ApplicationService,
- events: List[EventBase],
+ events: Sequence[EventBase],
ephemeral: List[JsonDict],
to_device_messages: List[JsonDict],
one_time_keys_count: TransactionOneTimeKeysCount,
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index a58668a3..5b664316 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -75,6 +75,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._cache_id_gen = MultiWriterIdGenerator(
db_conn,
database,
+ notifier=hs.get_replication_notifier(),
stream_name="caches",
instance_name=hs.get_instance_name(),
tables=[
@@ -164,9 +165,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
backfilled=True,
)
elif stream_name == CachesStream.NAME:
- if self._cache_id_gen:
- self._cache_id_gen.advance(instance_name, token)
-
for row in rows:
if row.cache_func == CURRENT_STATE_CACHE_NAME:
if row.keys is None:
@@ -182,6 +180,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
super().process_replication_rows(stream_name, instance_name, token, rows)
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == CachesStream.NAME:
+ if self._cache_id_gen:
+ self._cache_id_gen.advance(instance_name, token)
+ super().process_replication_position(stream_name, instance_name, token)
+
def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
data = row.data
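
Editor's note: a recurring pattern in this release, visible in the cache.py hunks above and repeated for most stores below, is that advancing the stream-id tracker moves out of process_replication_rows (which now only invalidates caches for the rows it receives) into a new process_replication_position hook that runs even for row-less batches. A toy sketch of the shape, using illustrative stand-in classes rather than Synapse's real ones:

from typing import Any, Dict, Iterable, List


class ToyIdTracker:
    def __init__(self) -> None:
        self._positions: Dict[str, int] = {}

    def advance(self, instance_name: str, token: int) -> None:
        self._positions[instance_name] = max(
            self._positions.get(instance_name, 0), token
        )

    def current(self, instance_name: str) -> int:
        return self._positions.get(instance_name, 0)


class ToyCacheStore:
    STREAM_NAME = "caches"

    def __init__(self) -> None:
        self._cache_id_gen = ToyIdTracker()
        self.invalidated: List[Any] = []

    def process_replication_rows(
        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
    ) -> None:
        # Row handling only: invalidate whatever the rows describe.
        if stream_name == self.STREAM_NAME:
            self.invalidated.extend(rows)

    def process_replication_position(
        self, stream_name: str, instance_name: str, token: int
    ) -> None:
        # Position handling only: the tracker keeps advancing even when the
        # batch carried no rows at all.
        if stream_name == self.STREAM_NAME:
            self._cache_id_gen.advance(instance_name, token)


store = ToyCacheStore()
store.process_replication_position("caches", "worker1", 7)
assert store._cache_id_gen.current("worker1") == 7
assert store.invalidated == []
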
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 48a54d9c..8e61aba4 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -91,6 +91,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ notifier=hs.get_replication_notifier(),
stream_name="to_device",
instance_name=self._instance_name,
tables=[("device_inbox", "instance_name", "stream_id")],
@@ -101,7 +102,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
else:
self._can_write_to_device = True
self._device_inbox_id_gen = StreamIdGenerator(
- db_conn, "device_inbox", "stream_id"
+ db_conn, hs.get_replication_notifier(), "device_inbox", "stream_id"
)
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
@@ -157,6 +158,13 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
return super().process_replication_rows(stream_name, instance_name, token, rows)
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == ToDeviceStream.NAME:
+ self._device_inbox_id_gen.advance(instance_name, token)
+ super().process_replication_position(stream_name, instance_name, token)
+
def get_to_device_stream_token(self) -> int:
return self._device_inbox_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index a5bb4d40..e8b6cc6b 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -38,7 +38,7 @@ from synapse.logging.opentracing import (
whitelisted_homeserver,
)
from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
+from synapse.replication.tcp.streams._base import DeviceListsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -54,7 +54,7 @@ from synapse.storage.util.id_generators import (
AbstractStreamIdTracker,
StreamIdGenerator,
)
-from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
+from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
@@ -92,12 +92,14 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
# class below that is used on the main process.
self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
+ hs.get_replication_notifier(),
"device_lists_stream",
"stream_id",
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
("device_lists_changes_in_room", "stream_id"),
+ ("device_lists_remote_pending", "stream_id"),
],
is_writer=hs.config.worker.worker_app is None,
)
@@ -162,18 +164,26 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == DeviceListsStream.NAME:
- self._device_list_id_gen.advance(instance_name, token)
self._invalidate_caches_for_devices(token, rows)
- elif stream_name == UserSignatureStream.NAME:
- self._device_list_id_gen.advance(instance_name, token)
- for row in rows:
- self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
+
return super().process_replication_rows(stream_name, instance_name, token, rows)
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == DeviceListsStream.NAME:
+ self._device_list_id_gen.advance(instance_name, token)
+
+ super().process_replication_position(stream_name, instance_name, token)
+
def _invalidate_caches_for_devices(
self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
) -> None:
for row in rows:
+ if row.is_signature:
+ self._user_signature_stream_cache.entity_has_changed(row.entity, token)
+ continue
+
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
@@ -1062,16 +1072,30 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return {row["user_id"] for row in rows}
- async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None:
+ async def mark_remote_users_device_caches_as_stale(
+ self, user_ids: StrCollection
+ ) -> None:
"""Records that the server has reason to believe the cache of the devices
for the remote users is out of date.
"""
- await self.db_pool.simple_upsert(
- table="device_lists_remote_resync",
- keyvalues={"user_id": user_id},
- values={},
- insertion_values={"added_ts": self._clock.time_msec()},
- desc="mark_remote_user_device_cache_as_stale",
+
+ def _mark_remote_users_device_caches_as_stale_txn(
+ txn: LoggingTransaction,
+ ) -> None:
+ # TODO add insertion_values support to simple_upsert_many and use
+ # that!
+ for user_id in user_ids:
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="device_lists_remote_resync",
+ keyvalues={"user_id": user_id},
+ values={},
+ insertion_values={"added_ts": self._clock.time_msec()},
+ )
+
+ await self.db_pool.runInteraction(
+ "mark_remote_users_device_caches_as_stale",
+ _mark_remote_users_device_caches_as_stale_txn,
)
async def mark_remote_user_device_cache_as_valid(self, user_id: str) -> None:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 4c691642..c4ac6c33 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1181,7 +1181,10 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
super().__init__(database, db_conn, hs)
self._cross_signing_id_gen = StreamIdGenerator(
- db_conn, "e2e_cross_signing_keys", "stream_id"
+ db_conn,
+ hs.get_replication_notifier(),
+ "e2e_cross_signing_keys",
+ "stream_id",
)
async def set_e2e_device_keys(
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 7ebe34f7..3a0c370f 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -275,15 +275,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
self.db_pool.updates.register_background_index_update(
- "event_push_summary_unique_index",
- index_name="event_push_summary_unique_index",
- table="event_push_summary",
- columns=["user_id", "room_id"],
- unique=True,
- replaces_index="event_push_summary_user_rm",
- )
-
- self.db_pool.updates.register_background_index_update(
"event_push_summary_unique_index2",
index_name="event_push_summary_unique_index2",
table="event_push_summary",
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 0f097a29..1536937b 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1651,7 +1651,7 @@ class PersistEventsStore:
if self._ephemeral_messages_enabled:
# If there's an expiry timestamp on the event, store it.
expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
- if isinstance(expiry_ts, int) and not event.is_state():
+ if type(expiry_ts) is int and not event.is_state():
self._insert_event_expiry_txn(txn, event.event_id, expiry_ts)
# Insert into the room_memberships table.
@@ -2133,10 +2133,10 @@ class PersistEventsStore:
):
if (
"min_lifetime" in event.content
- and not isinstance(event.content.get("min_lifetime"), int)
+ and type(event.content["min_lifetime"]) is not int
) or (
"max_lifetime" in event.content
- and not isinstance(event.content.get("max_lifetime"), int)
+ and type(event.content["max_lifetime"]) is not int
):
# Ignore the event if one of the value isn't an integer.
return
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 9e31798a..b9d3c36d 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -69,6 +69,8 @@ class _BackgroundUpdates:
EVENTS_POPULATE_STATE_KEY_REJECTIONS = "events_populate_state_key_rejections"
+ EVENTS_JUMP_TO_DATE_INDEX = "events_jump_to_date_index"
+
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _CalculateChainCover:
@@ -260,6 +262,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self._background_events_populate_state_key_rejections,
)
+ # Add an index that would be useful for jumping to date using
+ # get_event_id_for_timestamp.
+ self.db_pool.updates.register_background_index_update(
+ _BackgroundUpdates.EVENTS_JUMP_TO_DATE_INDEX,
+ index_name="events_jump_to_date_idx",
+ table="events",
+ columns=["room_id", "origin_server_ts"],
+ where_clause="NOT outlier",
+ )
+
async def _background_reindex_fields_sender(
self, progress: JsonDict, batch_size: int
) -> int:
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 318fd7dc..6d0ef102 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -38,7 +38,7 @@ from typing_extensions import Literal
from twisted.internet import defer
-from synapse.api.constants import EventTypes
+from synapse.api.constants import Direction, EventTypes
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
@@ -59,8 +59,9 @@ from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
-from synapse.replication.tcp.streams import BackfillStream
+from synapse.replication.tcp.streams import BackfillStream, UnPartialStatedEventStream
from synapse.replication.tcp.streams.events import EventsStream
+from synapse.replication.tcp.streams.partial_state import UnPartialStatedEventStreamRow
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -70,6 +71,7 @@ from synapse.storage.database import (
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import (
+ AbstractStreamIdGenerator,
AbstractStreamIdTracker,
MultiWriterIdGenerator,
StreamIdGenerator,
@@ -108,6 +110,10 @@ event_fetch_ongoing_gauge = Gauge(
)
+class InvalidEventError(Exception):
+ """The event retrieved from the database is invalid and cannot be used."""
+
+
@attr.s(slots=True, auto_attribs=True)
class EventCacheEntry:
event: EventBase
@@ -189,6 +195,7 @@ class EventsWorkerStore(SQLBaseStore):
self._stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ notifier=hs.get_replication_notifier(),
stream_name="events",
instance_name=hs.get_instance_name(),
tables=[("events", "instance_name", "stream_ordering")],
@@ -198,6 +205,7 @@ class EventsWorkerStore(SQLBaseStore):
self._backfill_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ notifier=hs.get_replication_notifier(),
stream_name="backfill",
instance_name=hs.get_instance_name(),
tables=[("events", "instance_name", "stream_ordering")],
@@ -215,12 +223,14 @@ class EventsWorkerStore(SQLBaseStore):
# SQLite).
self._stream_id_gen = StreamIdGenerator(
db_conn,
+ hs.get_replication_notifier(),
"events",
"stream_ordering",
is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
+ hs.get_replication_notifier(),
"events",
"stream_ordering",
step=-1,
@@ -292,6 +302,98 @@ class EventsWorkerStore(SQLBaseStore):
id_column="chain_id",
)
+ self._un_partial_stated_events_stream_id_gen: AbstractStreamIdGenerator
+
+ if isinstance(database.engine, PostgresEngine):
+ self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ notifier=hs.get_replication_notifier(),
+ stream_name="un_partial_stated_event_stream",
+ instance_name=hs.get_instance_name(),
+ tables=[
+ ("un_partial_stated_event_stream", "instance_name", "stream_id")
+ ],
+ sequence_name="un_partial_stated_event_stream_sequence",
+ # TODO(faster_joins, multiple writers) Support multiple writers.
+ writers=["master"],
+ )
+ else:
+ self._un_partial_stated_events_stream_id_gen = StreamIdGenerator(
+ db_conn,
+ hs.get_replication_notifier(),
+ "un_partial_stated_event_stream",
+ "stream_id",
+ )
+
+ def get_un_partial_stated_events_token(self, instance_name: str) -> int:
+ return (
+ self._un_partial_stated_events_stream_id_gen.get_current_token_for_writer(
+ instance_name
+ )
+ )
+
+ async def get_un_partial_stated_events_from_stream(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, Tuple[str, bool]]], int, bool]:
+ """Get updates for the un-partial-stated events replication stream.
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exists
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updates.
+
+ The updates are a list of 2-tuples of stream ID and the row data
+ """
+
+ if last_id == current_id:
+ return [], current_id, False
+
+ def get_un_partial_stated_events_from_stream_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, Tuple[str, bool]]], int, bool]:
+ sql = """
+ SELECT stream_id, event_id, rejection_status_changed
+ FROM un_partial_stated_event_stream
+ WHERE ? < stream_id AND stream_id <= ? AND instance_name = ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+ txn.execute(sql, (last_id, current_id, instance_name, limit))
+ updates = [
+ (
+ row[0],
+ (
+ row[1],
+ bool(row[2]),
+ ),
+ )
+ for row in txn
+ ]
+ limited = False
+ upto_token = current_id
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
+
+ return await self.db_pool.runInteraction(
+ "get_un_partial_stated_events_from_stream",
+ get_un_partial_stated_events_from_stream_txn,
+ )
+
def process_replication_rows(
self,
stream_name: str,
@@ -299,12 +401,29 @@ class EventsWorkerStore(SQLBaseStore):
token: int,
rows: Iterable[Any],
) -> None:
+ if stream_name == UnPartialStatedEventStream.NAME:
+ for row in rows:
+ assert isinstance(row, UnPartialStatedEventStreamRow)
+
+ self.is_partial_state_event.invalidate((row.event_id,))
+
+ if row.rejection_status_changed:
+ # If the partial-stated event became rejected or unrejected
+ # when it wasn't before, we need to invalidate this cache.
+ self._invalidate_local_get_event_cache(row.event_id)
+
+ super().process_replication_rows(stream_name, instance_name, token, rows)
+
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(instance_name, token)
elif stream_name == BackfillStream.NAME:
self._backfill_id_gen.advance(instance_name, -token)
-
- super().process_replication_rows(stream_name, instance_name, token, rows)
+ elif stream_name == UnPartialStatedEventStream.NAME:
+ self._un_partial_stated_events_stream_id_gen.advance(instance_name, token)
+ super().process_replication_position(stream_name, instance_name, token)
async def have_censored_event(self, event_id: str) -> bool:
"""Check if an event has been censored, i.e. if the content of the event has been erased
@@ -1195,7 +1314,7 @@ class EventsWorkerStore(SQLBaseStore):
# invites, so just accept it for all membership events.
#
if d["type"] != EventTypes.Member:
- raise Exception(
+ raise InvalidEventError(
"Room %s for event %s is unknown" % (d["room_id"], event_id)
)
@@ -1660,7 +1779,7 @@ class EventsWorkerStore(SQLBaseStore):
txn: LoggingTransaction,
) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
sql = (
- "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
+ "SELECT out.event_stream_ordering, e.event_id, e.room_id, e.type,"
" se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL,"
" e.outlier"
" FROM events AS e"
@@ -1672,10 +1791,10 @@ class EventsWorkerStore(SQLBaseStore):
" LEFT JOIN event_relations USING (event_id)"
" LEFT JOIN room_memberships USING (event_id)"
" LEFT JOIN rejections USING (event_id)"
- " WHERE ? < event_stream_ordering"
- " AND event_stream_ordering <= ?"
+ " WHERE ? < out.event_stream_ordering"
+ " AND out.event_stream_ordering <= ?"
" AND out.instance_name = ?"
- " ORDER BY event_stream_ordering ASC"
+ " ORDER BY out.event_stream_ordering ASC"
)
txn.execute(sql, (last_id, current_id, instance_name))
@@ -2121,7 +2240,7 @@ class EventsWorkerStore(SQLBaseStore):
)
async def get_event_id_for_timestamp(
- self, room_id: str, timestamp: int, direction: str
+ self, room_id: str, timestamp: int, direction: Direction
) -> Optional[str]:
"""Find the closest event to the given timestamp in the given direction.
@@ -2129,14 +2248,14 @@ class EventsWorkerStore(SQLBaseStore):
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
- direction: ["f"|"b"] to indicate whether we should navigate forward
+ direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.
Returns:
The closest event_id otherwise None if we can't find any event in
the given direction.
"""
- if direction == "b":
+ if direction == Direction.BACKWARDS:
# Find closest event *before* a given timestamp. We use descending
# (which gives values largest to smallest) because we want the
# largest possible timestamp *before* the given timestamp.
@@ -2188,9 +2307,6 @@ class EventsWorkerStore(SQLBaseStore):
return None
- if direction not in ("f", "b"):
- raise ValueError("Unknown direction: %s" % (direction,))
-
return await self.db_pool.runInteraction(
"get_event_id_for_timestamp_txn",
get_event_id_for_timestamp_txn,
@@ -2292,6 +2408,9 @@ class EventsWorkerStore(SQLBaseStore):
This can happen, for example, when resyncing state during a faster join.
+ It is the caller's responsibility to ensure that other workers are
+ sent a notification so that they call `_invalidate_local_get_event_cache()`.
+
Args:
txn:
event_id: ID of event to update
@@ -2330,14 +2449,3 @@ class EventsWorkerStore(SQLBaseStore):
)
self.invalidate_get_event_cache_after_txn(txn, event_id)
-
- # TODO(faster_joins): invalidate the cache on workers. Ideally we'd just
- # call '_send_invalidation_to_replication', but we actually need the other
- # end to call _invalidate_local_get_event_cache() rather than (just)
- # _get_event_cache.invalidate().
- #
- # One solution might be to (somehow) get the workers to call
- # _invalidate_caches_for_event() (though that will invalidate more than
- # strictly necessary).
- #
- # https://github.com/matrix-org/synapse/issues/12994
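
Editor's note: get_un_partial_stated_events_from_stream above follows the convention shared by Synapse's replication stream fetchers: when a full page comes back, the returned token only advances to the last fetched row and the result is flagged as limited so the caller asks again. A small sketch of just that bookkeeping, with simplified stand-in row tuples:

from typing import List, Tuple


def page_updates(
    fetched: List[Tuple[int, str]], current_id: int, limit: int
) -> Tuple[List[Tuple[int, str]], int, bool]:
    # `fetched` plays the role of the rows returned by the LIMITed SQL query.
    updates = fetched[:limit]
    limited = False
    upto_token = current_id
    if len(updates) >= limit:
        # A full page: more rows may remain, so only advance to the last row.
        upto_token = updates[-1][0]
        limited = True
    return updates, upto_token, limited


rows = [(11, "$a"), (12, "$b"), (13, "$c")]
assert page_updates(rows, current_id=20, limit=2) == ([(11, "$a"), (12, "$b")], 12, True)
assert page_updates(rows, current_id=20, limit=5) == (rows, 20, False)
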
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 9b172a64..b202c5eb 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -26,6 +26,7 @@ from typing import (
cast,
)
+from synapse.api.constants import Direction
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -176,7 +177,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
limit: int,
user_id: str,
order_by: str = MediaSortOrder.CREATED_TS.value,
- direction: str = "f",
+ direction: Direction = Direction.FORWARDS,
) -> Tuple[List[Dict[str, Any]], int]:
"""Get a paginated list of metadata for a local piece of media
which an user_id has uploaded
@@ -199,7 +200,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
# Set ordering
order_by_column = MediaSortOrder(order_by).value
- if direction == "b":
+ if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 9769a18a..beb210f8 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -77,6 +77,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
self._presence_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ notifier=hs.get_replication_notifier(),
stream_name="presence_stream",
instance_name=self._instance_name,
tables=[("presence_stream", "instance_name", "stream_id")],
@@ -85,7 +86,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
)
else:
self._presence_id_gen = StreamIdGenerator(
- db_conn, "presence_stream", "stream_id"
+ db_conn, hs.get_replication_notifier(), "presence_stream", "stream_id"
)
self.hs = hs
@@ -439,8 +440,14 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
rows: Iterable[Any],
) -> None:
if stream_name == PresenceStream.NAME:
- self._presence_id_gen.advance(instance_name, token)
for row in rows:
self.presence_stream_cache.entity_has_changed(row.user_id, token)
self._get_presence_for_user.invalidate((row.user_id,))
return super().process_replication_rows(stream_name, instance_name, token, rows)
+
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == PresenceStream.NAME:
+ self._presence_id_gen.advance(instance_name, token)
+ super().process_replication_position(stream_name, instance_name, token)
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index d4c64c46..9b2bbe06 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -86,8 +86,11 @@ def _load_rules(
filtered_rules = FilteredPushRules(
push_rules,
enabled_map,
- msc3664_enabled=experimental_config.msc3664_enabled,
msc1767_enabled=experimental_config.msc1767_enabled,
+ msc3664_enabled=experimental_config.msc3664_enabled,
+ msc3381_polls_enabled=experimental_config.msc3381_polls_enabled,
+ msc3952_intentional_mentions=experimental_config.msc3952_intentional_mentions,
+ msc3958_suppress_edits_enabled=experimental_config.msc3958_supress_edit_notifs,
)
return filtered_rules
@@ -117,6 +120,7 @@ class PushRulesWorkerStore(
# class below that is used on the main process.
self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
+ hs.get_replication_notifier(),
"push_rules_stream",
"stream_id",
is_writer=hs.config.worker.worker_app is None,
@@ -154,6 +158,13 @@ class PushRulesWorkerStore(
self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == PushRulesStream.NAME:
+ self._push_rules_stream_id_gen.advance(instance_name, token)
+ super().process_replication_position(stream_name, instance_name, token)
+
@cached(max_entries=5000)
async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
rows = await self.db_pool.simple_select_list(
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 40fd781a..df53e726 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -62,6 +62,7 @@ class PusherWorkerStore(SQLBaseStore):
# class below that is used on the main process.
self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
+ hs.get_replication_notifier(),
"pushers",
"id",
extra_tables=[("deleted_pushers", "stream_id")],
@@ -111,12 +112,12 @@ class PusherWorkerStore(SQLBaseStore):
def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
- def process_replication_rows(
- self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == PushersStream.NAME:
self._pushers_id_gen.advance(instance_name, token)
- return super().process_replication_rows(stream_name, instance_name, token, rows)
+ super().process_replication_position(stream_name, instance_name, token)
async def get_pushers_by_app_id_and_pushkey(
self, app_id: str, pushkey: str
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index e06725f6..29972d52 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -73,6 +73,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
self._receipts_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ notifier=hs.get_replication_notifier(),
stream_name="receipts",
instance_name=self._instance_name,
tables=[("receipts_linearized", "instance_name", "stream_id")],
@@ -91,6 +92,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
# SQLite).
self._receipts_id_gen = StreamIdGenerator(
db_conn,
+ hs.get_replication_notifier(),
"receipts_linearized",
"stream_id",
is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts,
@@ -588,6 +590,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
return super().process_replication_rows(stream_name, instance_name, token, rows)
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == ReceiptsStream.NAME:
+ self._receipts_id_gen.advance(instance_name, token)
+ super().process_replication_position(stream_name, instance_name, token)
+
def _insert_linearized_receipt_txn(
self,
txn: LoggingTransaction,
@@ -932,10 +941,14 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
receipts."""
def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None:
+ if isinstance(self.database_engine, PostgresEngine):
+ ROW_ID_NAME = "ctid"
+ else:
+ ROW_ID_NAME = "rowid"
+
# Identify any duplicate receipts arising from
# https://github.com/matrix-org/synapse/issues/14406.
- # We expect the following query to use the per-thread receipt index and take
- # less than a minute.
+ # The following query takes less than a minute on matrix.org.
sql = """
SELECT MAX(stream_id), room_id, receipt_type, user_id
FROM receipts_linearized
@@ -947,19 +960,33 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
duplicate_keys = cast(List[Tuple[int, str, str, str]], list(txn))
# Then remove duplicate receipts, keeping the one with the highest
- # `stream_id`. There should only be a single receipt with any given
- # `stream_id`.
- for max_stream_id, room_id, receipt_type, user_id in duplicate_keys:
- sql = """
+ # `stream_id`. Since there might be duplicate rows with the same
+ # `stream_id`, we delete by the ctid instead.
+ for stream_id, room_id, receipt_type, user_id in duplicate_keys:
+ sql = f"""
+ SELECT {ROW_ID_NAME}
+ FROM receipts_linearized
+ WHERE
+ room_id = ? AND
+ receipt_type = ? AND
+ user_id = ? AND
+ thread_id IS NULL AND
+ stream_id = ?
+ LIMIT 1
+ """
+ txn.execute(sql, (room_id, receipt_type, user_id, stream_id))
+ row_id = cast(Tuple[str], txn.fetchone())[0]
+
+ sql = f"""
DELETE FROM receipts_linearized
WHERE
room_id = ? AND
receipt_type = ? AND
user_id = ? AND
thread_id IS NULL AND
- stream_id < ?
+ {ROW_ID_NAME} != ?
"""
- txn.execute(sql, (room_id, receipt_type, user_id, max_stream_id))
+ txn.execute(sql, (room_id, receipt_type, user_id, row_id))
await self.db_pool.runInteraction(
self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME,
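
Editor's note: the background update above deletes duplicate receipts that share a stream_id by keying the DELETE on a physical row id (ctid on PostgreSQL, rowid on SQLite) instead of on stream_id. A standalone sqlite3 sketch of the same two-step pattern, with the table trimmed to the columns involved:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE receipts_linearized "
    "(stream_id INTEGER, room_id TEXT, receipt_type TEXT, user_id TEXT, thread_id TEXT)"
)
# Two duplicate rows with the *same* stream_id, as in matrix-org/synapse#14406.
for _ in range(2):
    conn.execute(
        "INSERT INTO receipts_linearized VALUES (?, ?, ?, ?, NULL)",
        (5, "!room:example.org", "m.read", "@alice:example.org"),
    )

# Step 1: pick one surviving physical row for the duplicated key.
row_id = conn.execute(
    "SELECT rowid FROM receipts_linearized "
    "WHERE room_id = ? AND receipt_type = ? AND user_id = ? "
    "AND thread_id IS NULL AND stream_id = ? LIMIT 1",
    ("!room:example.org", "m.read", "@alice:example.org", 5),
).fetchone()[0]

# Step 2: delete every other row for that key.
conn.execute(
    "DELETE FROM receipts_linearized "
    "WHERE room_id = ? AND receipt_type = ? AND user_id = ? "
    "AND thread_id IS NULL AND rowid != ?",
    ("!room:example.org", "m.read", "@alice:example.org", row_id),
)
assert conn.execute("SELECT COUNT(*) FROM receipts_linearized").fetchone()[0] == 1
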
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index aea96e9d..0018d6f7 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -30,7 +30,7 @@ from typing import (
import attr
-from synapse.api.constants import MAIN_TIMELINE, RelationTypes
+from synapse.api.constants import MAIN_TIMELINE, Direction, RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
@@ -40,9 +40,13 @@ from synapse.storage.database import (
LoggingTransaction,
make_in_list_sql_clause,
)
-from synapse.storage.databases.main.stream import generate_pagination_where_clause
+from synapse.storage.databases.main.stream import (
+ generate_next_token,
+ generate_pagination_bounds,
+ generate_pagination_where_clause,
+)
from synapse.storage.engines import PostgresEngine
-from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
+from synapse.types import JsonDict, StreamKeyType, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
@@ -164,7 +168,7 @@ class RelationsWorkerStore(SQLBaseStore):
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
limit: int = 5,
- direction: str = "b",
+ direction: Direction = Direction.BACKWARDS,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
@@ -177,8 +181,8 @@ class RelationsWorkerStore(SQLBaseStore):
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
limit: Only fetch the most recent `limit` events.
- direction: Whether to fetch the most recent first (`"b"`) or the
- oldest first (`"f"`).
+ direction: Whether to fetch the most recent first (backwards) or the
+ oldest first (forwards).
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.
@@ -207,24 +211,23 @@ class RelationsWorkerStore(SQLBaseStore):
where_clause.append("type = ?")
where_args.append(event_type)
+ order, from_bound, to_bound = generate_pagination_bounds(
+ direction,
+ from_token.room_key if from_token else None,
+ to_token.room_key if to_token else None,
+ )
+
pagination_clause = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
- from_token=from_token.room_key.as_historical_tuple()
- if from_token
- else None,
- to_token=to_token.room_key.as_historical_tuple() if to_token else None,
+ from_token=from_bound,
+ to_token=to_bound,
engine=self.database_engine,
)
if pagination_clause:
where_clause.append(pagination_clause)
- if direction == "b":
- order = "DESC"
- else:
- order = "ASC"
-
sql = """
SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
FROM event_relations
@@ -266,16 +269,9 @@ class RelationsWorkerStore(SQLBaseStore):
topo_orderings = topo_orderings[:limit]
stream_orderings = stream_orderings[:limit]
- topo = topo_orderings[-1]
- token = stream_orderings[-1]
- if direction == "b":
- # Tokens are positions between events.
- # This token points *after* the last event in the chunk.
- # We need it to point to the event before it in the chunk
- # when we are going backwards so we subtract one from the
- # stream part.
- token -= 1
- next_key = RoomStreamToken(topo, token)
+ next_key = generate_next_token(
+ direction, topo_orderings[-1], stream_orderings[-1]
+ )
if from_token:
next_token = from_token.copy_and_replace(
@@ -292,6 +288,7 @@ class RelationsWorkerStore(SQLBaseStore):
to_device_key=0,
device_list_key=0,
groups_key=0,
+ un_partial_stated_rooms_key=0,
)
return events[:limit], next_token
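
Editor's note: the relations hunks above replace inline token arithmetic with the generate_pagination_bounds / generate_next_token helpers from the stream store. A toy reconstruction, based only on the inline code removed above, of what the next-token step has to do; the tuple stands in for Synapse's RoomStreamToken, and Direction is assumed to match the enum used elsewhere in this diff.

from enum import Enum
from typing import Optional, Tuple


class Direction(Enum):
    FORWARDS = "f"
    BACKWARDS = "b"


def next_token(
    direction: Direction, last_topo: Optional[int], last_stream: int
) -> Tuple[Optional[int], int]:
    if direction == Direction.BACKWARDS:
        # Tokens are positions between events; when paginating backwards the
        # token must point *before* the last event returned, so step the
        # stream part back by one.
        last_stream -= 1
    return (last_topo, last_stream)


assert next_token(Direction.BACKWARDS, 7, 42) == (7, 41)
assert next_token(Direction.FORWARDS, 7, 42) == (7, 42)
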
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 78906a5e..644bbb88 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -18,6 +18,7 @@ from abc import abstractmethod
from enum import Enum
from typing import (
TYPE_CHECKING,
+ AbstractSet,
Any,
Awaitable,
Collection,
@@ -25,7 +26,7 @@ from typing import (
List,
Mapping,
Optional,
- Sequence,
+ Set,
Tuple,
Union,
cast,
@@ -34,6 +35,7 @@ from typing import (
import attr
from synapse.api.constants import (
+ Direction,
EventContentFields,
EventTypes,
JoinRules,
@@ -43,6 +45,7 @@ from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase
+from synapse.replication.tcp.streams.partial_state import UnPartialStatedRoomStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -58,9 +61,9 @@ from synapse.storage.util.id_generators import (
MultiWriterIdGenerator,
StreamIdGenerator,
)
-from synapse.types import JsonDict, RetentionPolicy, ThirdPartyInstanceID
+from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID
from synapse.util import json_encoder
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.stringutils import MXC_REGEX
if TYPE_CHECKING:
@@ -106,7 +109,7 @@ class RoomSortOrder(Enum):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PartialStateResyncInfo:
joined_via: Optional[str]
- servers_in_room: List[str] = attr.ib(factory=list)
+ servers_in_room: Set[str] = attr.ib(factory=set)
class RoomWorkerStore(CacheInvalidationWorkerStore):
@@ -126,6 +129,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ notifier=hs.get_replication_notifier(),
stream_name="un_partial_stated_room_stream",
instance_name=self._instance_name,
tables=[
@@ -137,9 +141,19 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
)
else:
self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator(
- db_conn, "un_partial_stated_room_stream", "stream_id"
+ db_conn,
+ hs.get_replication_notifier(),
+ "un_partial_stated_room_stream",
+ "stream_id",
)
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == UnPartialStatedRoomStream.NAME:
+ self._un_partial_stated_rooms_stream_id_gen.advance(instance_name, token)
+ return super().process_replication_position(stream_name, instance_name, token)
+
async def store_room(
self,
room_id: str,
@@ -1179,21 +1193,35 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
get_rooms_for_retention_period_in_range_txn,
)
- @cached(iterable=True)
- async def get_partial_state_servers_at_join(self, room_id: str) -> Sequence[str]:
- """Gets the list of servers in a partial state room at the time we joined it.
+ async def get_partial_state_servers_at_join(
+ self, room_id: str
+ ) -> Optional[AbstractSet[str]]:
+ """Gets the set of servers in a partial state room at the time we joined it.
Returns:
The `servers_in_room` list from the `/send_join` response for partial state
rooms. May not be accurate or complete, as it comes from a remote
homeserver.
- An empty list for full state rooms.
+ `None` for full state rooms.
"""
- return await self.db_pool.simple_select_onecol(
- "partial_state_rooms_servers",
- keyvalues={"room_id": room_id},
- retcol="server_name",
- desc="get_partial_state_servers_at_join",
+ servers_in_room = await self._get_partial_state_servers_at_join(room_id)
+
+ if len(servers_in_room) == 0:
+ return None
+
+ return servers_in_room
+
+ @cached(iterable=True)
+ async def _get_partial_state_servers_at_join(
+ self, room_id: str
+ ) -> AbstractSet[str]:
+ return frozenset(
+ await self.db_pool.simple_select_onecol(
+ "partial_state_rooms_servers",
+ keyvalues={"room_id": room_id},
+ retcol="server_name",
+ desc="get_partial_state_servers_at_join",
+ )
)
async def get_partial_state_room_resync_info(
@@ -1238,11 +1266,11 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
# partial-joined between the two SELECTs, but this is unlikely to happen
# in practice.)
continue
- entry.servers_in_room.append(server_name)
+ entry.servers_in_room.add(server_name)
return room_servers
- @cached()
+ @cached(max_entries=10000)
async def is_partial_state_room(self, room_id: str) -> bool:
"""Checks if this room has partial state.
@@ -1261,6 +1289,27 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
return entry is not None
+ @cachedList(cached_method_name="is_partial_state_room", list_name="room_ids")
+ async def is_partial_state_room_batched(
+ self, room_ids: StrCollection
+ ) -> Mapping[str, bool]:
+ """Checks if the given rooms have partial state.
+
+ Returns true for "partial-state" rooms, which means that the state
+ at events in the room, and `current_state_events`, may not yet be
+ complete.
+ """
+
+ rows: List[Dict[str, str]] = await self.db_pool.simple_select_many_batch(
+ table="partial_state_rooms",
+ column="room_id",
+ iterable=room_ids,
+ retcols=("room_id",),
+ desc="is_partial_state_room_batched",
+ )
+ partial_state_rooms = {row_dict["room_id"] for row_dict in rows}
+ return {room_id: room_id in partial_state_rooms for room_id in room_ids}
+
async def get_join_event_id_and_device_lists_stream_id_for_partial_state(
self, room_id: str
) -> Tuple[str, int]:
@@ -1277,18 +1326,49 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
)
return result["join_event_id"], result["device_lists_stream_id"]
- def get_un_partial_stated_rooms_token(self) -> int:
- # TODO(faster_joins, multiple writers): This is inappropriate if there
- # are multiple writers because workers that don't write often will
- # hold all readers up.
- # (See `MultiWriterIdGenerator.get_persisted_upto_position` for an
- # explanation.)
- return self._un_partial_stated_rooms_stream_id_gen.get_current_token()
+ def get_un_partial_stated_rooms_token(self, instance_name: str) -> int:
+ return self._un_partial_stated_rooms_stream_id_gen.get_current_token_for_writer(
+ instance_name
+ )
+
+ async def get_un_partial_stated_rooms_between(
+ self, last_id: int, current_id: int, room_ids: Collection[str]
+ ) -> Set[str]:
+ """Get all rooms that got un partial stated between `last_id` exclusive and
+ `current_id` inclusive.
+
+ Returns:
+ The list of room ids.
+ """
+
+ if last_id == current_id:
+ return set()
+
+ def _get_un_partial_stated_rooms_between_txn(
+ txn: LoggingTransaction,
+ ) -> Set[str]:
+ sql = """
+ SELECT DISTINCT room_id FROM un_partial_stated_room_stream
+ WHERE ? < stream_id AND stream_id <= ? AND
+ """
+
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", room_ids
+ )
+
+ txn.execute(sql + clause, [last_id, current_id] + args)
+
+ return {r[0] for r in txn}
+
+ return await self.db_pool.runInteraction(
+ "get_un_partial_stated_rooms_between",
+ _get_un_partial_stated_rooms_between_txn,
+ )
async def get_un_partial_stated_rooms_from_stream(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
- """Get updates for caches replication stream.
+ """Get updates for un partial stated rooms replication stream.
Args:
instance_name: The writer we want to fetch updates from. Unused
@@ -1876,7 +1956,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
async def store_partial_state_room(
self,
room_id: str,
- servers: Collection[str],
+ servers: AbstractSet[str],
device_lists_stream_id: int,
joined_via: str,
) -> None:
@@ -1891,11 +1971,13 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
Args:
room_id: the ID of the room
- servers: other servers known to be in the room
+ servers: other servers known to be in the room. must include `joined_via`.
device_lists_stream_id: the device_lists stream ID at the time when we first
joined the room.
joined_via: the server name we requested a partial join from.
"""
+ assert joined_via in servers
+
await self.db_pool.runInteraction(
"store_partial_state_room",
self._store_partial_state_room_txn,
@@ -1909,7 +1991,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
self,
txn: LoggingTransaction,
room_id: str,
- servers: Collection[str],
+ servers: AbstractSet[str],
device_lists_stream_id: int,
joined_via: str,
) -> None:
@@ -1932,7 +2014,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
)
self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
self._invalidate_cache_and_stream(
- txn, self.get_partial_state_servers_at_join, (room_id,)
+ txn, self._get_partial_state_servers_at_join, (room_id,)
)
async def write_partial_state_rooms_join_event_id(
@@ -2139,7 +2221,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
self,
start: int,
limit: int,
- direction: str = "b",
+ direction: Direction = Direction.BACKWARDS,
user_id: Optional[str] = None,
room_id: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
@@ -2148,8 +2230,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
Args:
start: event offset to begin the query from
limit: number of rows to retrieve
- direction: Whether to fetch the most recent first (`"b"`) or the
- oldest first (`"f"`)
+ direction: Whether to fetch the most recent first (backwards) or the
+ oldest first (forwards)
user_id: search for user_id. Ignored if user_id is None
room_id: search for room_id. Ignored if room_id is None
Returns:
@@ -2171,7 +2253,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
filters.append("er.room_id LIKE ?")
args.extend(["%" + room_id + "%"])
- if direction == "b":
+ if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"
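
The switch from the `"b"`/`"f"` strings to the `Direction` constant repeats across the paginated admin queries in this release. A minimal sketch of the mapping, with a local enum standing in for `synapse.api.constants.Direction` (member values assumed):

    from enum import Enum

    class Direction(Enum):
        # Local stand-in; the real constant lives in synapse.api.constants.
        FORWARDS = "f"
        BACKWARDS = "b"

    def sql_order(direction: Direction) -> str:
        # Backwards pagination means newest-first, hence DESC.
        return "DESC" if direction == Direction.BACKWARDS else "ASC"

    assert sql_order(Direction.BACKWARDS) == "DESC"
    assert sql_order(Direction.FORWARDS) == "ASC"
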
@@ -2295,16 +2377,16 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
(room_id,),
)
- async def clear_partial_state_room(self, room_id: str) -> bool:
+ async def clear_partial_state_room(self, room_id: str) -> Optional[int]:
"""Clears the partial state flag for a room.
Args:
room_id: The room whose partial state flag is to be cleared.
Returns:
- `True` if the partial state flag has been cleared successfully.
+ The corresponding stream id for the un-partial-stated rooms stream.
- `False` if the partial state flag could not be cleared because the room
+ `None` if the partial state flag could not be cleared because the room
still contains events with partial state.
"""
try:
@@ -2315,7 +2397,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
room_id,
un_partial_state_room_stream_id,
)
- return True
+ return un_partial_state_room_stream_id
except self.db_pool.engine.module.IntegrityError as e:
# Assume that any `IntegrityError`s are due to partial state events.
logger.info(
@@ -2323,7 +2405,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
room_id,
e,
)
- return False
+ return None
def _clear_partial_state_room_txn(
self,
@@ -2343,7 +2425,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
)
self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
self._invalidate_cache_and_stream(
- txn, self.get_partial_state_servers_at_join, (room_id,)
+ txn, self._get_partial_state_servers_at_join, (room_id,)
)
DatabasePool.simple_insert_txn(
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index f02c1d7e..ea6a5e2f 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -13,8 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from itertools import chain
from typing import (
TYPE_CHECKING,
+ AbstractSet,
Collection,
Dict,
FrozenSet,
@@ -47,7 +49,13 @@ from synapse.storage.roommember import (
ProfileInfo,
RoomsForUser,
)
-from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain_from_id
+from synapse.types import (
+ JsonDict,
+ PersistedEventPosition,
+ StateMap,
+ StrCollection,
+ get_domain_from_id,
+)
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -385,7 +393,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
self,
user_id: str,
membership_list: Collection[str],
- excluded_rooms: Optional[List[str]] = None,
+ excluded_rooms: StrCollection = (),
) -> List[RoomsForUser]:
"""Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
@@ -412,10 +420,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
# Now we filter out forgotten and excluded rooms
- rooms_to_exclude: Set[str] = await self.get_forgotten_rooms_for_user(user_id)
+ rooms_to_exclude = await self.get_forgotten_rooms_for_user(user_id)
if excluded_rooms is not None:
- rooms_to_exclude.update(set(excluded_rooms))
+ # Take a copy to avoid mutating the in-cache set
+ rooms_to_exclude = set(rooms_to_exclude)
+ rooms_to_exclude.update(excluded_rooms)
return [room for room in rooms if room.room_id not in rooms_to_exclude]
@@ -1122,12 +1132,33 @@ class RoomMemberWorkerStore(EventsWorkerStore):
else:
# The cache doesn't match the state group or prev state group,
# so we calculate the result from first principles.
+ #
+ # We need to fetch all hosts joined to the room according to `state` by
+ # inspecting all join memberships in `state`. However, if the `state` is
+ # relatively recent then many of its events are likely to be held in
+ # the current state of the room, which is easily available and likely
+ # cached.
+ #
+ # We therefore compute the set of `state` events not in the
+ # current state and only fetch those.
+ current_memberships = (
+ await self._get_approximate_current_memberships_in_room(room_id)
+ )
+ unknown_state_events = {}
+ joined_users_in_current_state = []
+
+ for (type, state_key), event_id in state.items():
+ if event_id not in current_memberships:
+ unknown_state_events[type, state_key] = event_id
+ elif current_memberships[event_id] == Membership.JOIN:
+ joined_users_in_current_state.append(state_key)
+
joined_user_ids = await self.get_joined_user_ids_from_state(
- room_id, state
+ room_id, unknown_state_events
)
cache.hosts_to_joined_users = {}
- for user_id in joined_user_ids:
+ for user_id in chain(joined_user_ids, joined_users_in_current_state):
host = intern_string(get_domain_from_id(user_id))
cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
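
The comment above describes splitting `state` into events whose membership is already known from the (approximate) current state and events that still need fetching. A self-contained sketch of that split, with plain dicts standing in for Synapse's state map and membership constants:

    from typing import Dict, List, Optional, Tuple

    MEMBERSHIP_JOIN = "join"  # stand-in for Membership.JOIN

    def split_state(
        state: Dict[Tuple[str, str], str],              # (type, state_key) -> event_id
        current_memberships: Dict[str, Optional[str]],  # event_id -> membership or None
    ) -> Tuple[Dict[Tuple[str, str], str], List[str]]:
        unknown_state_events: Dict[Tuple[str, str], str] = {}
        joined_users_in_current_state: List[str] = []
        for (etype, state_key), event_id in state.items():
            if event_id not in current_memberships:
                # Not in the current-state map: must be fetched the slow way.
                unknown_state_events[etype, state_key] = event_id
            elif current_memberships[event_id] == MEMBERSHIP_JOIN:
                # Already known to be a join: classify it locally.
                joined_users_in_current_state.append(state_key)
        return unknown_state_events, joined_users_in_current_state

    unknown, joined = split_state(
        {("m.room.member", "@a:hs"): "$1", ("m.room.member", "@b:hs"): "$2"},
        {"$1": "join"},
    )
    assert joined == ["@a:hs"]
    assert unknown == {("m.room.member", "@b:hs"): "$2"}
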
@@ -1138,6 +1169,26 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return frozenset(cache.hosts_to_joined_users)
+ async def _get_approximate_current_memberships_in_room(
+ self, room_id: str
+ ) -> Mapping[str, Optional[str]]:
+ """Build a map from event id to membership, for all events in the current state.
+
+        The event ids of non-membership events (e.g. `m.room.power_levels`) are present
+ in the result, mapped to values of `None`.
+
+ The result is approximate for partially-joined rooms. It is fully accurate
+ for fully-joined rooms.
+ """
+
+ rows = await self.db_pool.simple_select_list(
+ "current_state_events",
+ keyvalues={"room_id": room_id},
+ retcols=("event_id", "membership"),
+            desc="_get_approximate_current_memberships_in_room",
+ )
+ return {row["event_id"]: row["membership"] for row in rows}
+
@cached(max_entries=10000)
def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache":
return _JoinedHostsCache()
@@ -1169,7 +1220,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return count == 0
@cached()
- async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]:
+ async def get_forgotten_rooms_for_user(self, user_id: str) -> AbstractSet[str]:
"""Gets all rooms the user has forgotten.
Args:
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index c801a93b..ba325d39 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -14,7 +14,7 @@
# limitations under the License.
import collections.abc
import logging
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Set, Tuple
import attr
@@ -24,6 +24,8 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.opentracing import trace
+from synapse.replication.tcp.streams import UnPartialStatedEventStream
+from synapse.replication.tcp.streams.partial_state import UnPartialStatedEventStreamRow
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -80,6 +82,22 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
+ self._instance_name: str = hs.get_instance_name()
+
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: int,
+ rows: Iterable[Any],
+ ) -> None:
+ if stream_name == UnPartialStatedEventStream.NAME:
+ for row in rows:
+ assert isinstance(row, UnPartialStatedEventStreamRow)
+ self._get_state_group_for_event.invalidate((row.event_id,))
+ self.is_partial_state_event.invalidate((row.event_id,))
+
+ super().process_replication_rows(stream_name, instance_name, token, rows)
async def get_room_version(self, room_id: str) -> RoomVersion:
"""Get the room_version of a given room
@@ -404,18 +422,21 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
context: EventContext,
) -> None:
"""Update the state group for a partial state event"""
- await self.db_pool.runInteraction(
- "update_state_for_partial_state_event",
- self._update_state_for_partial_state_event_txn,
- event,
- context,
- )
+ async with self._un_partial_stated_events_stream_id_gen.get_next() as un_partial_state_event_stream_id:
+ await self.db_pool.runInteraction(
+ "update_state_for_partial_state_event",
+ self._update_state_for_partial_state_event_txn,
+ event,
+ context,
+ un_partial_state_event_stream_id,
+ )
def _update_state_for_partial_state_event_txn(
self,
txn: LoggingTransaction,
event: EventBase,
context: EventContext,
+ un_partial_state_event_stream_id: int,
) -> None:
# we shouldn't have any outliers here
assert not event.internal_metadata.is_outlier()
@@ -436,7 +457,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# the event may now be rejected where it was not before, or vice versa,
# in which case we need to update the rejected flags.
- if bool(context.rejected) != (event.rejected_reason is not None):
+ rejection_status_changed = bool(context.rejected) != (
+ event.rejected_reason is not None
+ )
+ if rejection_status_changed:
self.mark_event_rejected_txn(txn, event.event_id, context.rejected)
self.db_pool.simple_delete_one_txn(
@@ -445,8 +469,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
keyvalues={"event_id": event.event_id},
)
- # TODO(faster_joins): need to do something about workers here
- # https://github.com/matrix-org/synapse/issues/12994
txn.call_after(self.is_partial_state_event.invalidate, (event.event_id,))
txn.call_after(
self._get_state_group_for_event.prefill,
@@ -454,6 +476,18 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
state_group,
)
+ self.db_pool.simple_insert_txn(
+ txn,
+ "un_partial_stated_event_stream",
+ {
+ "stream_id": un_partial_state_event_stream_id,
+ "instance_name": self._instance_name,
+ "event_id": event.event_id,
+ "rejection_status_changed": rejection_status_changed,
+ },
+ )
+ txn.call_after(self.hs.get_notifier().on_new_replication_data)
+
class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
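
The notifier poke above is registered with `txn.call_after`, so it only runs once the transaction commits and readers cannot observe the replication notification before the stream row exists. A toy sketch of that deferral pattern (a simplified stand-in, not `LoggingTransaction` itself):

    class SketchTransaction:
        """Queues side effects and only runs them after commit."""

        def __init__(self):
            self._after = []

        def call_after(self, fn, *args):
            self._after.append((fn, args))

        def commit(self):
            for fn, args in self._after:
                fn(*args)

    txn = SketchTransaction()
    txn.call_after(print, "notify replication listeners")
    # ... perform the database writes here ...
    txn.commit()  # only now does the notification fire
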
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 356d4ca7..d7b7d0c3 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -22,13 +22,14 @@ from typing_extensions import Counter
from twisted.internet.defer import DeferredLock
-from synapse.api.constants import EventContentFields, EventTypes, Membership
+from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership
from synapse.api.errors import StoreError
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
+from synapse.storage.databases.main.events_worker import InvalidEventError
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
@@ -554,7 +555,17 @@ class StatsStore(StateDeltasStore):
"get_initial_state_for_room", _fetch_current_state_stats
)
- state_event_map = await self.get_events(event_ids, get_prev_content=False) # type: ignore[attr-defined]
+ try:
+ state_event_map = await self.get_events(event_ids, get_prev_content=False) # type: ignore[attr-defined]
+ except InvalidEventError as e:
+ # If an exception occurs fetching events then the room is broken;
+            # skip processing it to avoid getting stuck on a room.
+ logger.warning(
+ "Failed to fetch events for room %s, skipping stats calculation: %r.",
+ room_id,
+ e,
+ )
+ return
room_state: Dict[str, Union[None, bool, str]] = {
"join_rules": None,
@@ -652,7 +663,7 @@ class StatsStore(StateDeltasStore):
from_ts: Optional[int] = None,
until_ts: Optional[int] = None,
order_by: Optional[str] = UserSortOrder.USER_ID.value,
- direction: Optional[str] = "f",
+ direction: Direction = Direction.FORWARDS,
search_term: Optional[str] = None,
) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of users and their uploaded local media
@@ -703,7 +714,7 @@ class StatsStore(StateDeltasStore):
500, "Incorrect value for order_by provided: %s" % order_by
)
- if direction == "b":
+ if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index cc27ec38..818c4618 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -55,6 +55,7 @@ from typing_extensions import Literal
from twisted.internet import defer
+from synapse.api.constants import Direction
from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -67,7 +68,7 @@ from synapse.storage.database import (
make_in_list_sql_clause,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import PersistedEventPosition, RoomStreamToken
from synapse.util.caches.descriptors import cached
@@ -86,7 +87,6 @@ MAX_STREAM_SIZE = 1000
_STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological"
-
# Used as return values for pagination APIs
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventDictReturn:
@@ -104,7 +104,7 @@ class _EventsAround:
def generate_pagination_where_clause(
- direction: str,
+ direction: Direction,
column_names: Tuple[str, str],
from_token: Optional[Tuple[Optional[int], int]],
to_token: Optional[Tuple[Optional[int], int]],
@@ -130,27 +130,26 @@ def generate_pagination_where_clause(
token, but include those that match the to token.
Args:
- direction: Whether we're paginating backwards("b") or forwards ("f").
+ direction: Whether we're paginating backwards or forwards.
column_names: The column names to bound. Must *not* be user defined as
these get inserted directly into the SQL statement without escapes.
from_token: The start point for the pagination. This is an exclusive
- minimum bound if direction is "f", and an inclusive maximum bound if
- direction is "b".
+ minimum bound if direction is forwards, and an inclusive maximum bound if
+ direction is backwards.
        to_token: The end point for the pagination. This is an inclusive
- maximum bound if direction is "f", and an exclusive minimum bound if
- direction is "b".
+ maximum bound if direction is forwards, and an exclusive minimum bound if
+ direction is backwards.
engine: The database engine to generate the clauses for
Returns:
The sql expression
"""
- assert direction in ("b", "f")
where_clause = []
if from_token:
where_clause.append(
_make_generic_sql_bound(
- bound=">=" if direction == "b" else "<",
+ bound=">=" if direction == Direction.BACKWARDS else "<",
column_names=column_names,
values=from_token,
engine=engine,
@@ -160,7 +159,7 @@ def generate_pagination_where_clause(
if to_token:
where_clause.append(
_make_generic_sql_bound(
- bound="<" if direction == "b" else ">=",
+ bound="<" if direction == Direction.BACKWARDS else ">=",
column_names=column_names,
values=to_token,
engine=engine,
@@ -170,6 +169,104 @@ def generate_pagination_where_clause(
return " AND ".join(where_clause)
+def generate_pagination_bounds(
+ direction: Direction,
+ from_token: Optional[RoomStreamToken],
+ to_token: Optional[RoomStreamToken],
+) -> Tuple[
+ str, Optional[Tuple[Optional[int], int]], Optional[Tuple[Optional[int], int]]
+]:
+ """
+ Generate a start and end point for this page of events.
+
+ Args:
+ direction: Whether pagination is going forwards or backwards.
+ from_token: The token to start pagination at, or None to start at the first value.
+ to_token: The token to end pagination at, or None to not limit the end point.
+
+ Returns:
+ A three tuple of:
+
+ ASC or DESC for sorting of the query.
+
+ The starting position as a tuple of ints representing
+ (topological position, stream position) or None if no from_token was
+ provided. The topological position may be None for live tokens.
+
+ The end position in the same format as the starting position, or None
+ if no to_token was provided.
+ """
+
+ # Tokens really represent positions between elements, but we use
+ # the convention of pointing to the event before the gap. Hence
+ # we have a bit of asymmetry when it comes to equalities.
+ if direction == Direction.BACKWARDS:
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ # The bounds for the stream tokens are complicated by the fact
+ # that we need to handle the instance_map part of the tokens. We do this
+ # by fetching all events between the min stream token and the maximum
+ # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
+ # then filtering the results.
+ from_bound: Optional[Tuple[Optional[int], int]] = None
+ if from_token:
+ if from_token.topological is not None:
+ from_bound = from_token.as_historical_tuple()
+ elif direction == Direction.BACKWARDS:
+ from_bound = (
+ None,
+ from_token.get_max_stream_pos(),
+ )
+ else:
+ from_bound = (
+ None,
+ from_token.stream,
+ )
+
+ to_bound: Optional[Tuple[Optional[int], int]] = None
+ if to_token:
+ if to_token.topological is not None:
+ to_bound = to_token.as_historical_tuple()
+ elif direction == Direction.BACKWARDS:
+ to_bound = (
+ None,
+ to_token.stream,
+ )
+ else:
+ to_bound = (
+ None,
+ to_token.get_max_stream_pos(),
+ )
+
+ return order, from_bound, to_bound
+
+
+def generate_next_token(
+ direction: Direction, last_topo_ordering: int, last_stream_ordering: int
+) -> RoomStreamToken:
+ """
+ Generate the next room stream token based on the currently returned data.
+
+ Args:
+ direction: Whether pagination is going forwards or backwards.
+ last_topo_ordering: The last topological ordering being returned.
+ last_stream_ordering: The last stream ordering being returned.
+
+ Returns:
+ A new RoomStreamToken to return to the client.
+ """
+ if direction == Direction.BACKWARDS:
+ # Tokens are positions between events.
+ # This token points *after* the last event in the chunk.
+ # We need it to point to the event before it in the chunk
+ # when we are going backwards so we subtract one from the
+ # stream part.
+ last_stream_ordering -= 1
+ return RoomStreamToken(last_topo_ordering, last_stream_ordering)
+
+
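
A worked sketch of the bound selection in `generate_pagination_bounds`, using a simplified token in place of `RoomStreamToken` (its `get_max_stream_pos()` is faked as a plain attribute). The point is the asymmetry: for live tokens, the side of the window that must not drop events is widened to the token's maximum stream position.

    from collections import namedtuple

    # Simplified token: topological is None for live tokens.
    Tok = namedtuple("Tok", ["topological", "stream", "max_stream_pos"])

    def bounds(direction_backwards, from_token, to_token):
        order = "DESC" if direction_backwards else "ASC"

        def as_bound(tok, widen):
            if tok is None:
                return None
            if tok.topological is not None:
                return (tok.topological, tok.stream)  # as_historical_tuple()
            return (None, tok.max_stream_pos if widen else tok.stream)

        # Going backwards we widen the *from* side; forwards we widen the *to* side.
        return (
            order,
            as_bound(from_token, widen=direction_backwards),
            as_bound(to_token, widen=not direction_backwards),
        )

    live = Tok(topological=None, stream=10, max_stream_pos=12)
    print(bounds(True, live, None))   # ('DESC', (None, 12), None)
    print(bounds(False, live, None))  # ('ASC', (None, 10), None)
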
def _make_generic_sql_bound(
bound: str,
column_names: Tuple[str, str],
@@ -801,13 +898,66 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
before this stream ordering.
"""
- last_row = await self.get_room_event_before_stream_ordering(
- room_id=room_id,
- stream_ordering=end_token.stream,
+ def get_last_event_in_room_before_stream_ordering_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[str]:
+ # We need to handle the fact that the stream tokens can be vector
+ # clocks. We do this by getting all rows between the minimum and
+ # maximum stream ordering in the token, plus one row less than the
+ # minimum stream ordering. We then filter the results against the
+ # token and return the first row that matches.
+
+ sql = """
+ SELECT * FROM (
+ SELECT instance_name, stream_ordering, topological_ordering, event_id
+ FROM events
+ LEFT JOIN rejections USING (event_id)
+ WHERE room_id = ?
+ AND ? < stream_ordering AND stream_ordering <= ?
+ AND NOT outlier
+ AND rejections.event_id IS NULL
+ ORDER BY stream_ordering DESC
+ ) AS a
+ UNION
+ SELECT * FROM (
+ SELECT instance_name, stream_ordering, topological_ordering, event_id
+ FROM events
+ LEFT JOIN rejections USING (event_id)
+ WHERE room_id = ?
+ AND stream_ordering <= ?
+ AND NOT outlier
+ AND rejections.event_id IS NULL
+ ORDER BY stream_ordering DESC
+ LIMIT 1
+ ) AS b
+ """
+ txn.execute(
+ sql,
+ (
+ room_id,
+ end_token.stream,
+ end_token.get_max_stream_pos(),
+ room_id,
+ end_token.stream,
+ ),
+ )
+
+ for instance_name, stream_ordering, topological_ordering, event_id in txn:
+ if _filter_results(
+ lower_token=None,
+ upper_token=end_token,
+ instance_name=instance_name,
+ topological_ordering=topological_ordering,
+ stream_ordering=stream_ordering,
+ ):
+ return event_id
+
+ return None
+
+ return await self.db_pool.runInteraction(
+ "get_last_event_in_room_before_stream_ordering",
+ get_last_event_in_room_before_stream_ordering_txn,
)
- if last_row:
- return last_row[2]
- return None
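
The transaction above deliberately over-fetches — every row in the token's stream window plus one row at or below the lower bound — and then applies the token check in Python, since the `instance_map` part of a multi-writer token cannot be expressed in the SQL. A small sketch of that over-fetch-then-filter shape, with `token_accepts` standing in for `_filter_results`:

    def last_event_before(rows, stream, max_stream_pos, token_accepts):
        # rows: (instance_name, stream_ordering, topological_ordering, event_id),
        # sorted newest first, mirroring the ORDER BY stream_ordering DESC subqueries.
        window = [r for r in rows if stream < r[1] <= max_stream_pos]
        at_or_below = [r for r in rows if r[1] <= stream][:1]
        for instance_name, stream_ordering, topological_ordering, event_id in window + at_or_below:
            if token_accepts(instance_name, topological_ordering, stream_ordering):
                return event_id
        return None

    # With a predicate that accepts everything, the newest in-window event wins.
    rows = [("main", 12, 7, "$e12"), ("main", 11, 6, "$e11"), ("main", 9, 5, "$e9")]
    print(last_event_before(rows, stream=10, max_stream_pos=12,
                            token_accepts=lambda *a: True))  # $e12
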
async def get_current_room_stream_token_for_room_id(
self, room_id: str
@@ -891,12 +1041,40 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id
stream_key
"""
- sql = (
- "SELECT coalesce(MIN(topological_ordering), 0) FROM events"
- " WHERE room_id = ? AND stream_ordering >= ?"
- )
+ if isinstance(self.database_engine, PostgresEngine):
+ min_function = "LEAST"
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ min_function = "MIN"
+ else:
+ raise RuntimeError(f"Unknown database engine {self.database_engine}")
+
+ # This query used to be
+ # SELECT COALESCE(MIN(topological_ordering), 0) FROM events
+ # WHERE room_id = ? and events.stream_ordering >= {stream_key}
+ # which returns 0 if the stream_key is newer than any event in
+ # the room. That's not wrong, but it seems to interact oddly with backfill,
+ # requiring a second call to /messages to actually backfill from a remote
+ # homeserver.
+ #
+        # Instead, roll back the stream ordering to that after the most recent event in
+ # this room.
+ sql = f"""
+ WITH fallback(max_stream_ordering) AS (
+ SELECT MAX(stream_ordering)
+ FROM events
+ WHERE room_id = ?
+ )
+ SELECT COALESCE(MIN(topological_ordering), 0) FROM events
+ WHERE
+ room_id = ?
+ AND events.stream_ordering >= {min_function}(
+ ?,
+ (SELECT max_stream_ordering FROM fallback)
+ )
+ """
+
row = await self.db_pool.execute(
- "get_current_topological_token", None, sql, room_id, stream_key
+ "get_current_topological_token", None, sql, room_id, room_id, stream_key
)
return row[0][0] if row else 0
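
A runnable illustration of the rewritten topological-token query against sqlite3 (so the scalar minimum is spelled `MIN`, matching the `Sqlite3Engine` branch above). The `events` table is reduced to the three columns the query touches. With a stream key newer than anything in the room, the old query returned 0; the new one falls back to the room's latest event.

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE events (room_id TEXT, stream_ordering INTEGER, topological_ordering INTEGER)"
    )
    conn.executemany(
        "INSERT INTO events VALUES (?, ?, ?)",
        [("!r:hs", 1, 1), ("!r:hs", 2, 2), ("!r:hs", 3, 3)],
    )

    sql = """
        WITH fallback(max_stream_ordering) AS (
            SELECT MAX(stream_ordering) FROM events WHERE room_id = ?
        )
        SELECT COALESCE(MIN(topological_ordering), 0) FROM events
        WHERE room_id = ?
          AND events.stream_ordering >= MIN(?, (SELECT max_stream_ordering FROM fallback))
    """
    stream_key = 10  # newer than any event in the room
    print(conn.execute(sql, ("!r:hs", "!r:hs", stream_key)).fetchone()[0])  # 3, not 0
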
@@ -1022,7 +1200,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn,
room_id,
before_token,
- direction="b",
+ direction=Direction.BACKWARDS,
limit=before_limit,
event_filter=event_filter,
)
@@ -1032,7 +1210,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn,
room_id,
after_token,
- direction="f",
+ direction=Direction.FORWARDS,
limit=after_limit,
event_filter=event_filter,
)
@@ -1195,7 +1373,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id: str,
from_token: RoomStreamToken,
to_token: Optional[RoomStreamToken] = None,
- direction: str = "b",
+ direction: Direction = Direction.BACKWARDS,
limit: int = -1,
event_filter: Optional[Filter] = None,
) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
@@ -1206,8 +1384,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id
from_token: The token used to stream from
to_token: A token which if given limits the results to only those before
- direction: Either 'b' or 'f' to indicate whether we are paginating
- forwards or backwards from `from_key`.
+ direction: Indicates whether we are paginating forwards or backwards
+ from `from_key`.
limit: The maximum number of events to return.
event_filter: If provided filters the events to
those that match the filter.
@@ -1219,47 +1397,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
`to_token`), or `limit` is zero.
"""
- # Tokens really represent positions between elements, but we use
- # the convention of pointing to the event before the gap. Hence
- # we have a bit of asymmetry when it comes to equalities.
args = [False, room_id]
- if direction == "b":
- order = "DESC"
- else:
- order = "ASC"
-
- # The bounds for the stream tokens are complicated by the fact
- # that we need to handle the instance_map part of the tokens. We do this
- # by fetching all events between the min stream token and the maximum
- # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
- # then filtering the results.
- if from_token.topological is not None:
- from_bound: Tuple[Optional[int], int] = from_token.as_historical_tuple()
- elif direction == "b":
- from_bound = (
- None,
- from_token.get_max_stream_pos(),
- )
- else:
- from_bound = (
- None,
- from_token.stream,
- )
- to_bound: Optional[Tuple[Optional[int], int]] = None
- if to_token:
- if to_token.topological is not None:
- to_bound = to_token.as_historical_tuple()
- elif direction == "b":
- to_bound = (
- None,
- to_token.stream,
- )
- else:
- to_bound = (
- None,
- to_token.get_max_stream_pos(),
- )
+ order, from_bound, to_bound = generate_pagination_bounds(
+ direction, from_token, to_token
+ )
bounds = generate_pagination_where_clause(
direction=direction,
@@ -1346,8 +1488,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
_EventDictReturn(event_id, topological_ordering, stream_ordering)
for event_id, instance_name, topological_ordering, stream_ordering in txn
if _filter_results(
- lower_token=to_token if direction == "b" else from_token,
- upper_token=from_token if direction == "b" else to_token,
+ lower_token=to_token
+ if direction == Direction.BACKWARDS
+ else from_token,
+ upper_token=from_token
+ if direction == Direction.BACKWARDS
+ else to_token,
instance_name=instance_name,
topological_ordering=topological_ordering,
stream_ordering=stream_ordering,
@@ -1355,16 +1501,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
][:limit]
if rows:
- topo = rows[-1].topological_ordering
- token = rows[-1].stream_ordering
- if direction == "b":
- # Tokens are positions between events.
- # This token points *after* the last event in the chunk.
- # We need it to point to the event before it in the chunk
- # when we are going backwards so we subtract one from the
- # stream part.
- token -= 1
- next_token = RoomStreamToken(topo, token)
+ assert rows[-1].topological_ordering is not None
+ next_token = generate_next_token(
+ direction, rows[-1].topological_ordering, rows[-1].stream_ordering
+ )
else:
# TODO (erikj): We should work out what to do here instead.
next_token = to_token if to_token else from_token
@@ -1377,7 +1517,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id: str,
from_key: RoomStreamToken,
to_key: Optional[RoomStreamToken] = None,
- direction: str = "b",
+ direction: Direction = Direction.BACKWARDS,
limit: int = -1,
event_filter: Optional[Filter] = None,
) -> Tuple[List[EventBase], RoomStreamToken]:
@@ -1387,8 +1527,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id
from_key: The token used to stream from
to_key: A token which if given limits the results to only those before
- direction: Either 'b' or 'f' to indicate whether we are paginating
- forwards or backwards from `from_key`.
+ direction: Indicates whether we are paginating forwards or backwards
+ from `from_key`.
limit: The maximum number of events to return.
event_filter: If provided filters the events to those that match the filter.
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index b0f5de67..d5500cdd 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -17,7 +17,8 @@
import logging
from typing import Any, Dict, Iterable, List, Tuple, cast
-from synapse.replication.tcp.streams import TagAccountDataStream
+from synapse.api.constants import AccountDataTypes
+from synapse.replication.tcp.streams import AccountDataStream
from synapse.storage._base import db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
@@ -54,7 +55,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
async def get_all_updated_tags(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> Tuple[List[Tuple[int, Tuple[str, str, str]]], int, bool]:
+ ) -> Tuple[List[Tuple[int, str, str]], int, bool]:
"""Get updates for tags replication stream.
Args:
@@ -73,7 +74,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
The token returned can be used in a subsequent call to this
        function to get further updates.
- The updates are a list of 2-tuples of stream ID and the row data
+ The updates are a list of tuples of stream ID, user ID and room ID
"""
if last_id == current_id:
@@ -96,38 +97,13 @@ class TagsWorkerStore(AccountDataWorkerStore):
"get_all_updated_tags", get_all_updated_tags_txn
)
- def get_tag_content(
- txn: LoggingTransaction, tag_ids: List[Tuple[int, str, str]]
- ) -> List[Tuple[int, Tuple[str, str, str]]]:
- sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
- results = []
- for stream_id, user_id, room_id in tag_ids:
- txn.execute(sql, (user_id, room_id))
- tags = []
- for tag, content in txn:
- tags.append(json_encoder.encode(tag) + ":" + content)
- tag_json = "{" + ",".join(tags) + "}"
- results.append((stream_id, (user_id, room_id, tag_json)))
-
- return results
-
- batch_size = 50
- results = []
- for i in range(0, len(tag_ids), batch_size):
- tags = await self.db_pool.runInteraction(
- "get_all_updated_tag_content",
- get_tag_content,
- tag_ids[i : i + batch_size],
- )
- results.extend(tags)
-
limited = False
upto_token = current_id
- if len(results) >= limit:
- upto_token = results[-1][0]
+ if len(tag_ids) >= limit:
+ upto_token = tag_ids[-1][0]
limited = True
- return results, upto_token, limited
+ return tag_ids, upto_token, limited
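
The `limited`/`upto_token` handling above follows the usual replication-stream convention: rows are ordered by stream id, and when a full page comes back the caller is told to resume from the last stream id actually returned rather than from `current_id`. A small sketch of that convention:

    from typing import List, Tuple

    def page_token(rows: List[Tuple[int, str, str]], current_id: int, limit: int):
        # rows are (stream_id, user_id, room_id), ordered by stream_id ascending.
        limited = False
        upto_token = current_id
        if len(rows) >= limit:
            # A full page: more rows may exist, so resume from the last one seen.
            upto_token = rows[-1][0]
            limited = True
        return rows, upto_token, limited

    rows = [(5, "@u:hs", "!a:hs"), (6, "@u:hs", "!b:hs")]
    print(page_token(rows, current_id=10, limit=2))   # (..., 6, True)
    print(page_token(rows, current_id=10, limit=50))  # (..., 10, False)
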
async def get_updated_tags(
self, user_id: str, stream_id: int
@@ -299,11 +275,13 @@ class TagsWorkerStore(AccountDataWorkerStore):
token: int,
rows: Iterable[Any],
) -> None:
- if stream_name == TagAccountDataStream.NAME:
- self._account_data_id_gen.advance(instance_name, token)
+ if stream_name == AccountDataStream.NAME:
for row in rows:
- self.get_tags_for_user.invalidate((row.user_id,))
- self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+ if row.data_type == AccountDataTypes.TAG:
+ self.get_tags_for_user.invalidate((row.user_id,))
+ self._account_data_stream_cache.entity_has_changed(
+ row.user_id, token
+ )
super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index f8c6877e..6b33d809 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, cast
import attr
from canonicaljson import encode_canonical_json
+from synapse.api.constants import Direction
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import db_to_json
from synapse.storage.database import (
@@ -496,7 +497,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
limit: int,
destination: Optional[str] = None,
order_by: str = DestinationSortOrder.DESTINATION.value,
- direction: str = "f",
+ direction: Direction = Direction.FORWARDS,
) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of destinations.
This will return a json list of destinations and the
@@ -518,7 +519,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
) -> Tuple[List[JsonDict], int]:
order_by_column = DestinationSortOrder(order_by).value
- if direction == "b":
+ if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"
@@ -550,7 +551,11 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
)
async def get_destination_rooms_paginate(
- self, destination: str, start: int, limit: int, direction: str = "f"
+ self,
+ destination: str,
+ start: int,
+ limit: int,
+ direction: Direction = Direction.FORWARDS,
) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of destination's rooms.
This will return a json list of rooms and the
@@ -569,7 +574,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
- if direction == "b":
+ if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"