author     Andrej Shadura <andrewsh@debian.org>    2022-02-08 23:16:11 +0100
committer  Andrej Shadura <andrewsh@debian.org>    2022-02-08 23:16:11 +0100
commit     c67eee54fd48a2e25a91e34498ad01c5f902d53c (patch)
tree       25c0fa9e55e01bf86dae6a0395f20d868364eddc /synapse/storage/databases
parent     0027c02b907486b437772b1cdecbea14d18597d9 (diff)
New upstream version 1.52.0
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--  synapse/storage/databases/main/account_data.py     | 83
-rw-r--r--  synapse/storage/databases/main/appservice.py       |  2
-rw-r--r--  synapse/storage/databases/main/event_federation.py |  2
-rw-r--r--  synapse/storage/databases/main/events.py           |  7
-rw-r--r--  synapse/storage/databases/main/purge_events.py     |  1
-rw-r--r--  synapse/storage/databases/main/relations.py        | 65
-rw-r--r--  synapse/storage/databases/main/signatures.py       | 54
-rw-r--r--  synapse/storage/databases/main/stream.py           | 22
-rw-r--r--  synapse/storage/databases/main/transactions.py     | 48
9 files changed, 209 insertions(+), 75 deletions(-)
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index ef475e18..5bfa408f 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -26,6 +26,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
@@ -44,7 +45,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class AccountDataWorkerStore(CacheInvalidationWorkerStore):
+class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -158,9 +159,9 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
"get_account_data_for_user", get_account_data_for_user_txn
)
- @cached(num_args=2, max_entries=5000)
+ @cached(num_args=2, max_entries=5000, tree=True)
async def get_global_account_data_by_type_for_user(
- self, data_type: str, user_id: str
+ self, user_id: str, data_type: str
) -> Optional[JsonDict]:
"""
Returns:
@@ -179,7 +180,7 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
else:
return None
- @cached(num_args=2)
+ @cached(num_args=2, tree=True)
async def get_account_data_for_room(
self, user_id: str, room_id: str
) -> Dict[str, JsonDict]:
@@ -210,7 +211,7 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
"get_account_data_for_room", get_account_data_for_room_txn
)
- @cached(num_args=3, max_entries=5000)
+ @cached(num_args=3, max_entries=5000, tree=True)
async def get_account_data_for_room_and_type(
self, user_id: str, room_id: str, account_data_type: str
) -> Optional[JsonDict]:
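The hunks above reorder the cache keys so user_id comes first and switch the caches to tree=True. A tree cache stores each key tuple as a path of nested nodes, so invalidating the prefix (user_id,) evicts every entry for that user in one step, whatever the data_type or room_id — plain caches only support exact-key invalidation. The purge_account_data_for_user method added below relies on exactly this. A minimal sketch of the idea, using a plain nested dict rather than Synapse's actual TreeCache:

    # Illustrative toy only, not synapse.util.caches.treecache.
    from typing import Any, Dict, Tuple

    class ToyTreeCache:
        def __init__(self) -> None:
            self._root: Dict[Any, Any] = {}

        def set(self, key: Tuple[Any, ...], value: Any) -> None:
            node = self._root
            for part in key[:-1]:
                node = node.setdefault(part, {})
            node[key[-1]] = value

        def invalidate(self, key_prefix: Tuple[Any, ...]) -> None:
            # Dropping an interior node evicts everything under the prefix.
            node = self._root
            for part in key_prefix[:-1]:
                node = node.get(part, {})
            node.pop(key_prefix[-1], None)

    cache = ToyTreeCache()
    cache.set(("@alice:example.org", "m.push_rules"), {"enabled": True})
    cache.set(("@alice:example.org", "m.ignored_user_list"), {})
    cache.invalidate(("@alice:example.org",))  # evicts both entries at once
    assert cache._root == {}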
@@ -392,7 +393,7 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
for row in rows:
if not row.room_id:
self.get_global_account_data_by_type_for_user.invalidate(
- (row.data_type, row.user_id)
+ (row.user_id, row.data_type)
)
self.get_account_data_for_user.invalidate((row.user_id,))
self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
@@ -476,7 +477,7 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
self.get_global_account_data_by_type_for_user.invalidate(
- (account_data_type, user_id)
+ (user_id, account_data_type)
)
return self._account_data_id_gen.get_current_token()
@@ -546,6 +547,74 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
+ async def purge_account_data_for_user(self, user_id: str) -> None:
+ """
+ Removes the account data for a user.
+
+ This is intended to be used upon user deactivation and also removes any
+ derived information from account data (e.g. push rules and ignored users).
+
+ Args:
+ user_id: The user ID to remove data for.
+ """
+
+ def purge_account_data_for_user_txn(txn: LoggingTransaction) -> None:
+ # Purge from the primary account_data tables.
+ self.db_pool.simple_delete_txn(
+ txn, table="account_data", keyvalues={"user_id": user_id}
+ )
+
+ self.db_pool.simple_delete_txn(
+ txn, table="room_account_data", keyvalues={"user_id": user_id}
+ )
+
+ # Purge from ignored_users where this user is the ignorer.
+ # N.B. We don't purge where this user is the ignoree, because that
+ # interferes with other users' account data.
+ # It's also not this user's data to delete!
+ self.db_pool.simple_delete_txn(
+ txn, table="ignored_users", keyvalues={"ignorer_user_id": user_id}
+ )
+
+ # Remove the push rules
+ self.db_pool.simple_delete_txn(
+ txn, table="push_rules", keyvalues={"user_name": user_id}
+ )
+ self.db_pool.simple_delete_txn(
+ txn, table="push_rules_enable", keyvalues={"user_name": user_id}
+ )
+ self.db_pool.simple_delete_txn(
+ txn, table="push_rules_stream", keyvalues={"user_id": user_id}
+ )
+
+ # Invalidate caches as appropriate
+ self._invalidate_cache_and_stream(
+ txn, self.get_account_data_for_room_and_type, (user_id,)
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_account_data_for_user, (user_id,)
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_global_account_data_by_type_for_user, (user_id,)
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_account_data_for_room, (user_id,)
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_push_rules_for_user, (user_id,)
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_push_rules_enabled_for_user, (user_id,)
+ )
+ # This user might be contained in the ignored_by cache for other users,
+ # so we have to invalidate it all.
+ self._invalidate_all_cache_and_stream(txn, self.ignored_by)
+
+ await self.db_pool.runInteraction(
+ "purge_account_data_for_user_txn",
+ purge_account_data_for_user_txn,
+ )
+
class AccountDataStore(AccountDataWorkerStore):
pass
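Everything in purge_account_data_for_user — the deletes across five tables and the cache invalidations — runs inside a single transaction via runInteraction, so a failure part-way through rolls back cleanly rather than leaving caches out of step with the tables (_invalidate_cache_and_stream also streams each invalidation to other workers). A rough sketch of that transactional shape using sqlite3, with trimmed-down table schemas; the real pool adds logging and retries:

    import sqlite3

    def purge_txn(txn: sqlite3.Cursor, user_id: str) -> None:
        # All deletes share one transaction: either the whole purge
        # lands or none of it does.
        txn.execute("DELETE FROM account_data WHERE user_id = ?", (user_id,))
        txn.execute("DELETE FROM room_account_data WHERE user_id = ?", (user_id,))
        txn.execute(
            "DELETE FROM ignored_users WHERE ignorer_user_id = ?", (user_id,)
        )

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE account_data (user_id TEXT, account_data_type TEXT)")
    conn.execute("CREATE TABLE room_account_data (user_id TEXT, room_id TEXT)")
    conn.execute(
        "CREATE TABLE ignored_users (ignorer_user_id TEXT, ignored_user_id TEXT)"
    )
    with conn:  # commit on success, rollback if purge_txn raises
        purge_txn(conn.cursor(), "@alice:example.org")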
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 92c95a41..2bb52884 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -384,7 +384,7 @@ class ApplicationServiceTransactionWorkerStore(
"get_new_events_for_appservice", get_new_events_for_appservice_txn
)
- events = await self.get_events_as_list(event_ids)
+ events = await self.get_events_as_list(event_ids, get_prev_content=True)
return upper_bound, events
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index a556f17d..ca71f073 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -65,7 +65,7 @@ class _NoChainCoverIndex(Exception):
super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
-class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
+class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBaseStore):
def __init__(
self,
database: DatabasePool,
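The base-class swap matters because SignatureWorkerStore now inherits from EventsWorkerStore (see the signatures.py hunks below): Python's C3 linearization requires a subclass to appear before its parent in any bases list, so the old ordering would fail to build a consistent MRO. A minimal reproduction, with stand-in class bodies:

    class EventsWorkerStore:
        ...

    class SignatureWorkerStore(EventsWorkerStore):
        ...

    # OK: the subclass precedes its base, as in the new ordering above.
    class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore):
        ...

    # The old ordering would now fail at class-creation time with:
    #   TypeError: Cannot create a consistent method resolution order (MRO)
    # class Broken(EventsWorkerStore, SignatureWorkerStore):
    #     ...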
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1ae1ebe1..b7554154 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1389,6 +1389,8 @@ class PersistEventsStore:
"received_ts",
"sender",
"contains_url",
+ "state_key",
+ "rejection_reason",
),
values=(
(
@@ -1405,8 +1407,10 @@ class PersistEventsStore:
self._clock.time_msec(),
event.sender,
"url" in event.content and isinstance(event.content["url"], str),
+ event.get_state_key(),
+ context.rejected or None,
)
- for event, _ in events_and_contexts
+ for event, context in events_and_contexts
),
)
@@ -1456,6 +1460,7 @@ class PersistEventsStore:
for event, context in events_and_contexts:
if context.rejected:
# Insert the event_id into the rejections table
+ # (events.rejection_reason has already been done)
self._store_rejections_txn(txn, event.event_id, context.rejected)
to_remove.add(event)
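The loop-variable rename from `for event, _ in` to `for event, context in` is load-bearing: the new rejection_reason column is filled from context.rejected, which the old code discarded. The values generator yields one tuple per row for the batch insert. A small sqlite3 sketch of the same pattern, with the column set trimmed down and illustrative names:

    import sqlite3

    events_and_contexts = [
        ({"event_id": "$a", "sender": "@alice:example.org"}, {"rejected": None}),
        ({"event_id": "$b", "sender": "@bob:example.org"}, {"rejected": "auth_error"}),
    ]

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE events (event_id TEXT, sender TEXT, rejection_reason TEXT)"
    )
    conn.executemany(
        "INSERT INTO events VALUES (?, ?, ?)",
        (
            # One row per (event, context) pair; rejection_reason comes
            # from the context, so the context cannot be discarded.
            (event["event_id"], event["sender"], context["rejected"])
            for event, context in events_and_contexts
        ),
    )
    assert conn.execute(
        "SELECT rejection_reason FROM events WHERE event_id = '$b'"
    ).fetchone() == ("auth_error",)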
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 91b0576b..e87a8fb8 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -390,7 +390,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"event_search",
"events",
"group_rooms",
- "public_room_list_stream",
"receipts_graph",
"receipts_linearized",
"room_aliases",
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 2cb5d06c..37468a51 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,17 +13,7 @@
# limitations under the License.
import logging
-from typing import (
- TYPE_CHECKING,
- Any,
- Dict,
- Iterable,
- List,
- Optional,
- Tuple,
- Union,
- cast,
-)
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union, cast
import attr
from frozendict import frozendict
@@ -43,6 +33,7 @@ from synapse.storage.relations import (
PaginationChunk,
RelationPaginationToken,
)
+from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
@@ -51,6 +42,30 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _ThreadAggregation:
+ latest_event: EventBase
+ count: int
+ current_user_participated: bool
+
+
+@attr.s(slots=True, auto_attribs=True)
+class BundledAggregations:
+ """
+ The bundled aggregations for an event.
+
+ Some values require additional processing during serialization.
+ """
+
+ annotations: Optional[JsonDict] = None
+ references: Optional[JsonDict] = None
+ replace: Optional[EventBase] = None
+ thread: Optional[_ThreadAggregation] = None
+
+ def __bool__(self) -> bool:
+ return bool(self.annotations or self.references or self.replace or self.thread)
+
+
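BundledAggregations replaces the former Dict[str, Any] with a typed container: slots=True avoids a per-instance __dict__, frozen=True keeps _ThreadAggregation immutable, and the custom __bool__ lets callers treat an aggregation set with no populated fields as falsy (used in get_bundled_aggregations further down). A self-contained sketch of the pattern, with simplified field types:

    from typing import Optional

    import attr

    @attr.s(slots=True, auto_attribs=True)
    class Aggregations:
        annotations: Optional[dict] = None
        replace: Optional[str] = None

        def __bool__(self) -> bool:
            return bool(self.annotations or self.replace)

    assert not Aggregations()                      # nothing bundled: falsy
    assert Aggregations(replace="$edit_event_id")  # any field set: truthy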
class RelationsWorkerStore(SQLBaseStore):
def __init__(
self,
@@ -60,7 +75,6 @@ class RelationsWorkerStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
- self._msc1849_enabled = hs.config.experimental.msc1849_enabled
self._msc3440_enabled = hs.config.experimental.msc3440_enabled
@cached(tree=True)
@@ -585,7 +599,7 @@ class RelationsWorkerStore(SQLBaseStore):
async def _get_bundled_aggregation_for_event(
self, event: EventBase, user_id: str
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[BundledAggregations]:
"""Generate bundled aggregations for an event.
Note that this does not use a cache, but depends on cached methods.
@@ -616,24 +630,24 @@ class RelationsWorkerStore(SQLBaseStore):
# The bundled aggregations to include, a mapping of relation type to a
# type-specific value. Some types include the direct return type here
# while others need more processing during serialization.
- aggregations: Dict[str, Any] = {}
+ aggregations = BundledAggregations()
annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
if annotations.chunk:
- aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
+ aggregations.annotations = annotations.to_dict()
references = await self.get_relations_for_event(
event_id, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
- aggregations[RelationTypes.REFERENCE] = references.to_dict()
+ aggregations.references = references.to_dict()
edit = None
if event.type == EventTypes.Message:
edit = await self.get_applicable_edit(event_id, room_id)
if edit:
- aggregations[RelationTypes.REPLACE] = edit
+ aggregations.replace = edit
# If this event is the start of a thread, include a summary of the replies.
if self._msc3440_enabled:
@@ -644,11 +658,11 @@ class RelationsWorkerStore(SQLBaseStore):
event_id, room_id, user_id
)
if latest_thread_event:
- aggregations[RelationTypes.THREAD] = {
- "latest_event": latest_thread_event,
- "count": thread_count,
- "current_user_participated": participated,
- }
+ aggregations.thread = _ThreadAggregation(
+ latest_event=latest_thread_event,
+ count=thread_count,
+ current_user_participated=participated,
+ )
# Store the bundled aggregations in the event metadata for later use.
return aggregations
@@ -657,7 +671,7 @@ class RelationsWorkerStore(SQLBaseStore):
self,
events: Iterable[EventBase],
user_id: str,
- ) -> Dict[str, Dict[str, Any]]:
+ ) -> Dict[str, BundledAggregations]:
"""Generate bundled aggregations for events.
Args:
@@ -668,15 +682,12 @@ class RelationsWorkerStore(SQLBaseStore):
A map of event ID to the bundled aggregation for the event. Not all
events may have bundled aggregations in the results.
"""
- # If bundled aggregations are disabled, nothing to do.
- if not self._msc1849_enabled:
- return {}
# TODO Parallelize.
results = {}
for event in events:
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
- if event_result is not None:
+ if event_result:
results[event.event_id] = event_result
return results
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index 3201623f..0518b8b9 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -12,16 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, Iterable, List, Tuple
+from typing import Collection, Dict, List, Tuple
from unpaddedbase64 import encode_base64
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.types import Cursor
+from synapse.crypto.event_signing import compute_event_reference_hash
+from synapse.storage.databases.main.events_worker import (
+ EventRedactBehaviour,
+ EventsWorkerStore,
+)
from synapse.util.caches.descriptors import cached, cachedList
-class SignatureWorkerStore(SQLBaseStore):
+class SignatureWorkerStore(EventsWorkerStore):
@cached()
def get_event_reference_hash(self, event_id):
# This is a dummy function to allow get_event_reference_hashes
@@ -32,7 +35,7 @@ class SignatureWorkerStore(SQLBaseStore):
cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1
)
async def get_event_reference_hashes(
- self, event_ids: Iterable[str]
+ self, event_ids: Collection[str]
) -> Dict[str, Dict[str, bytes]]:
"""Get all hashes for given events.
@@ -41,18 +44,27 @@ class SignatureWorkerStore(SQLBaseStore):
Returns:
A mapping of event ID to a mapping of algorithm to hash.
+ Returns an empty dict for a given event id if that event is unknown.
"""
+ events = await self.get_events(
+ event_ids,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ allow_rejected=True,
+ )
- def f(txn):
- return {
- event_id: self._get_event_reference_hashes_txn(txn, event_id)
- for event_id in event_ids
- }
+ hashes: Dict[str, Dict[str, bytes]] = {}
+ for event_id in event_ids:
+ event = events.get(event_id)
+ if event is None:
+ hashes[event_id] = {}
+ else:
+ ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
+ hashes[event_id] = {ref_alg: ref_hash_bytes}
- return await self.db_pool.runInteraction("get_event_reference_hashes", f)
+ return hashes
async def add_event_hashes(
- self, event_ids: Iterable[str]
+ self, event_ids: Collection[str]
) -> List[Tuple[str, Dict[str, str]]]:
"""
@@ -70,24 +82,6 @@ class SignatureWorkerStore(SQLBaseStore):
return list(encoded_hashes.items())
- def _get_event_reference_hashes_txn(
- self, txn: Cursor, event_id: str
- ) -> Dict[str, bytes]:
- """Get all the hashes for a given PDU.
- Args:
- txn:
- event_id: Id for the Event.
- Returns:
- A mapping of algorithm -> hash.
- """
- query = (
- "SELECT algorithm, hash"
- " FROM event_reference_hashes"
- " WHERE event_id = ?"
- )
- txn.execute(query, (event_id,))
- return {k: v for k, v in txn}
-
class SignatureStore(SignatureWorkerStore):
"""Persistence for event signatures and hashes"""
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 319464b1..a898f847 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -81,6 +81,14 @@ class _EventDictReturn:
stream_ordering: int
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _EventsAround:
+ events_before: List[EventBase]
+ events_after: List[EventBase]
+ start: RoomStreamToken
+ end: RoomStreamToken
+
+
def generate_pagination_where_clause(
direction: str,
column_names: Tuple[str, str],
@@ -846,7 +854,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
before_limit: int,
after_limit: int,
event_filter: Optional[Filter] = None,
- ) -> dict:
+ ) -> _EventsAround:
"""Retrieve events and pagination tokens around a given event in a
room.
"""
@@ -869,12 +877,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
list(results["after"]["event_ids"]), get_prev_content=True
)
- return {
- "events_before": events_before,
- "events_after": events_after,
- "start": results["before"]["token"],
- "end": results["after"]["token"],
- }
+ return _EventsAround(
+ events_before=events_before,
+ events_after=events_after,
+ start=results["before"]["token"],
+ end=results["after"]["token"],
+ )
def _get_events_around_txn(
self,
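Returning _EventsAround instead of a dict gives get_events_around a return type mypy can actually check: misspelled fields fail at the attribute site, slots=True rejects stray assignments, and frozen=True stops callers mutating the result. A tiny sketch of the trade-off, with simplified field types:

    from typing import List

    import attr

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class EventsAround:
        events_before: List[str]
        events_after: List[str]

    around = EventsAround(events_before=["$a"], events_after=["$b"])
    assert around.events_before == ["$a"]
    # around.extra = 1         would raise: slots=True leaves nowhere to put it
    # around.events_before = [] would raise: frozen=True forbids assignment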
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 4b78b4d0..ba79e19f 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -561,6 +561,54 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
"get_destinations_paginate_txn", get_destinations_paginate_txn
)
+ async def get_destination_rooms_paginate(
+ self, destination: str, start: int, limit: int, direction: str = "f"
+ ) -> Tuple[List[JsonDict], int]:
+ """Function to retrieve a paginated list of destination's rooms.
+ This will return a json list of rooms and the
+ total number of rooms.
+
+ Args:
+ destination: the destination to query
+ start: row offset to begin the query from
+ limit: number of rows to retrieve
+ direction: sort ascending or descending by room_id
+ Returns:
+ A tuple of a list of rooms and a count of total rooms.
+ """
+
+ def get_destination_rooms_paginate_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[JsonDict], int]:
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ sql = """
+ SELECT COUNT(*) as total_rooms
+ FROM destination_rooms
+ WHERE destination = ?
+ """
+ txn.execute(sql, [destination])
+ count = cast(Tuple[int], txn.fetchone())[0]
+
+ rooms = self.db_pool.simple_select_list_paginate_txn(
+ txn=txn,
+ table="destination_rooms",
+ orderby="room_id",
+ start=start,
+ limit=limit,
+ retcols=("room_id", "stream_ordering"),
+ order_direction=order,
+ )
+ return rooms, count
+
+ return await self.db_pool.runInteraction(
+ "get_destination_rooms_paginate_txn", get_destination_rooms_paginate_txn
+ )
+
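The pagination shape here is count-then-page inside a single transaction, so the returned total stays consistent with the returned page, with both queries scoped to the requested destination. A plain-SQL sketch of the pattern via sqlite3 — the LIMIT/OFFSET form is an assumption about what simple_select_list_paginate_txn issues, not its verified internals:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE destination_rooms "
        "(destination TEXT, room_id TEXT, stream_ordering INTEGER)"
    )
    conn.executemany(
        "INSERT INTO destination_rooms VALUES (?, ?, ?)",
        [("matrix.org", f"!r{i}:x", i) for i in range(10)],
    )

    destination, start, limit, order = "matrix.org", 0, 5, "ASC"
    (count,) = conn.execute(
        "SELECT COUNT(*) FROM destination_rooms WHERE destination = ?",
        (destination,),
    ).fetchone()
    rows = conn.execute(
        # `order` is one of two fixed literals, so interpolation is safe here.
        f"SELECT room_id, stream_ordering FROM destination_rooms "
        f"WHERE destination = ? ORDER BY room_id {order} LIMIT ? OFFSET ?",
        (destination, limit, start),
    ).fetchall()
    assert count == 10 and len(rows) == 5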
async def is_destination_known(self, destination: str) -> bool:
"""Check if a destination is known to the server."""
result = await self.db_pool.simple_select_one_onecol(