summaryrefslogtreecommitdiff
path: root/synapse/storage/databases/main/relations.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/relations.py')
-rw-r--r--synapse/storage/databases/main/relations.py65
1 files changed, 38 insertions, 27 deletions
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 2cb5d06c..37468a51 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,17 +13,7 @@
# limitations under the License.
import logging
-from typing import (
- TYPE_CHECKING,
- Any,
- Dict,
- Iterable,
- List,
- Optional,
- Tuple,
- Union,
- cast,
-)
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union, cast
import attr
from frozendict import frozendict
@@ -43,6 +33,7 @@ from synapse.storage.relations import (
PaginationChunk,
RelationPaginationToken,
)
+from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
@@ -51,6 +42,30 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _ThreadAggregation:
+ latest_event: EventBase
+ count: int
+ current_user_participated: bool
+
+
+@attr.s(slots=True, auto_attribs=True)
+class BundledAggregations:
+ """
+ The bundled aggregations for an event.
+
+ Some values require additional processing during serialization.
+ """
+
+ annotations: Optional[JsonDict] = None
+ references: Optional[JsonDict] = None
+ replace: Optional[EventBase] = None
+ thread: Optional[_ThreadAggregation] = None
+
+ def __bool__(self) -> bool:
+ return bool(self.annotations or self.references or self.replace or self.thread)
+
+
class RelationsWorkerStore(SQLBaseStore):
def __init__(
self,
@@ -60,7 +75,6 @@ class RelationsWorkerStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
- self._msc1849_enabled = hs.config.experimental.msc1849_enabled
self._msc3440_enabled = hs.config.experimental.msc3440_enabled
@cached(tree=True)
@@ -585,7 +599,7 @@ class RelationsWorkerStore(SQLBaseStore):
async def _get_bundled_aggregation_for_event(
self, event: EventBase, user_id: str
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[BundledAggregations]:
"""Generate bundled aggregations for an event.
Note that this does not use a cache, but depends on cached methods.
@@ -616,24 +630,24 @@ class RelationsWorkerStore(SQLBaseStore):
# The bundled aggregations to include, a mapping of relation type to a
# type-specific value. Some types include the direct return type here
# while others need more processing during serialization.
- aggregations: Dict[str, Any] = {}
+ aggregations = BundledAggregations()
annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
if annotations.chunk:
- aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
+ aggregations.annotations = annotations.to_dict()
references = await self.get_relations_for_event(
event_id, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
- aggregations[RelationTypes.REFERENCE] = references.to_dict()
+ aggregations.references = references.to_dict()
edit = None
if event.type == EventTypes.Message:
edit = await self.get_applicable_edit(event_id, room_id)
if edit:
- aggregations[RelationTypes.REPLACE] = edit
+ aggregations.replace = edit
# If this event is the start of a thread, include a summary of the replies.
if self._msc3440_enabled:
@@ -644,11 +658,11 @@ class RelationsWorkerStore(SQLBaseStore):
event_id, room_id, user_id
)
if latest_thread_event:
- aggregations[RelationTypes.THREAD] = {
- "latest_event": latest_thread_event,
- "count": thread_count,
- "current_user_participated": participated,
- }
+ aggregations.thread = _ThreadAggregation(
+ latest_event=latest_thread_event,
+ count=thread_count,
+ current_user_participated=participated,
+ )
# Store the bundled aggregations in the event metadata for later use.
return aggregations
@@ -657,7 +671,7 @@ class RelationsWorkerStore(SQLBaseStore):
self,
events: Iterable[EventBase],
user_id: str,
- ) -> Dict[str, Dict[str, Any]]:
+ ) -> Dict[str, BundledAggregations]:
"""Generate bundled aggregations for events.
Args:
@@ -668,15 +682,12 @@ class RelationsWorkerStore(SQLBaseStore):
A map of event ID to the bundled aggregation for the event. Not all
events may have bundled aggregations in the results.
"""
- # If bundled aggregations are disabled, nothing to do.
- if not self._msc1849_enabled:
- return {}
# TODO Parallelize.
results = {}
for event in events:
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
- if event_result is not None:
+ if event_result:
results[event.event_id] = event_result
return results