summaryrefslogtreecommitdiff
path: root/synapse/storage/state.py
blob: b5ba1560d139553b09b85ef7eaa2684bd64c5876 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import (
    TYPE_CHECKING,
    Awaitable,
    Collection,
    Dict,
    Iterable,
    List,
    Mapping,
    Optional,
    Set,
    Tuple,
    TypeVar,
)

import attr
from frozendict import frozendict

from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateKey, StateMap

if TYPE_CHECKING:
    from typing import FrozenSet  # noqa: used within quoted type hint; flake8 sad

    from synapse.server import HomeServer
    from synapse.storage.databases import Databases

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Used for generic functions below: the value type of a state map, e.g. in
# `StateFilter.filter_state`, which preserves whatever value type it is given.
T = TypeVar("T")


@attr.s(slots=True, frozen=True)
class StateFilter:
    """A filter used when querying for state.

    Attributes:
        types: Map from type to set of state keys (or None). This specifies
            which state_keys for the given type to fetch from the DB. If None
            then all events with that type are fetched. If the set is empty
            then no events with that type are fetched.
        include_others: Whether to fetch events with types that do not
            appear in `types`.
    """

    types = attr.ib(type="frozendict[str, Optional[FrozenSet[str]]]")
    include_others = attr.ib(default=False, type=bool)

    def __attrs_post_init__(self) -> None:
        # If `include_others` is set we canonicalise the filter by removing
        # wildcards from the types dictionary. A wildcard entry is redundant
        # in that case: a type absent from `types` is already fully included
        # by `include_others`.
        if self.include_others:
            # this is needed to work around the fact that StateFilter is frozen
            object.__setattr__(
                self,
                "types",
                frozendict({k: v for k, v in self.types.items() if v is not None}),
            )

    @staticmethod
    def all() -> "StateFilter":
        """Creates a filter that fetches everything.

        Returns:
            The new state filter.
        """
        return StateFilter(types=frozendict(), include_others=True)

    @staticmethod
    def none() -> "StateFilter":
        """Creates a filter that fetches nothing.

        Returns:
            The new state filter.
        """
        return StateFilter(types=frozendict(), include_others=False)

    @staticmethod
    def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
        """Creates a filter that only fetches the given types

        Args:
            types: A list of type and state keys to fetch. A state_key of None
                fetches everything for that type

        Returns:
            The new state filter.
        """
        type_dict: Dict[str, Optional[Set[str]]] = {}
        for typ, s in types:
            if typ in type_dict:
                # A wildcard for this type has already been seen; it subsumes
                # any specific state keys.
                if type_dict[typ] is None:
                    continue

            if s is None:
                type_dict[typ] = None
                continue

            type_dict.setdefault(typ, set()).add(s)  # type: ignore

        return StateFilter(
            types=frozendict(
                (k, frozenset(v) if v is not None else None)
                for k, v in type_dict.items()
            )
        )

    @staticmethod
    def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
        """Creates a filter that returns all non-member events, plus the member
        events for the given users

        Args:
            members: Set of user IDs

        Returns:
            The new state filter
        """
        return StateFilter(
            types=frozendict({EventTypes.Member: frozenset(members)}),
            include_others=True,
        )

    @staticmethod
    def freeze(
        types: Mapping[str, Optional[Collection[str]]], include_others: bool
    ) -> "StateFilter":
        """
        Returns a (frozen) StateFilter with the same contents as the parameters
        specified here, which can be made of mutable types.

        Args:
            types: Map from state type to a collection of state keys, or None
                for a wildcard.
            include_others: Whether to fetch types not appearing in `types`.

        Returns:
            The new state filter.
        """
        types_with_frozen_values: Dict[str, Optional[FrozenSet[str]]] = {}
        for state_types, state_keys in types.items():
            if state_keys is not None:
                types_with_frozen_values[state_types] = frozenset(state_keys)
            else:
                types_with_frozen_values[state_types] = None

        return StateFilter(
            frozendict(types_with_frozen_values), include_others=include_others
        )

    def return_expanded(self) -> "StateFilter":
        """Creates a new StateFilter where type wild cards have been removed
        (except for memberships). The returned filter is a superset of the
        current one, i.e. anything that passes the current filter will pass
        the returned filter.

        This helps the caching as the DictionaryCache knows if it has *all* the
        state, but does not know if it has all of the keys of a particular type,
        which makes wildcard lookups expensive unless we have a complete cache.
        Hence, if we are doing a wildcard lookup, populate the cache fully so
        that we can do an efficient lookup next time.

        Note that since we have two caches, one for membership events and one for
        other events, we can be a bit more clever than simply returning
        `StateFilter.all()` if `has_wildcards()` is True.

        We return a StateFilter where:
            1. the list of membership events to return is the same
            2. if there is a wildcard that matches non-member events we
               return all non-member events

        Returns:
            The new state filter.
        """

        if self.is_full():
            # If we're going to return everything then there's nothing to do
            return self

        if not self.has_wildcards():
            # If there are no wild cards, there's nothing to do
            return self

        if EventTypes.Member in self.types:
            get_all_members = self.types[EventTypes.Member] is None
        else:
            get_all_members = self.include_others

        has_non_member_wildcard = self.include_others or any(
            state_keys is None
            for t, state_keys in self.types.items()
            if t != EventTypes.Member
        )

        if not has_non_member_wildcard:
            # If there are no non-member wild cards we can just return ourselves
            return self

        if get_all_members:
            # We want to return everything.
            return StateFilter.all()
        elif EventTypes.Member in self.types:
            # We want to return all non-members, but only particular
            # memberships
            return StateFilter(
                types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
                include_others=True,
            )
        else:
            # Membership is not mentioned in `types` and `include_others` is
            # False (otherwise `get_all_members` would be True), so this filter
            # matches no member events. Return all non-member state and no
            # member state; `self.types[EventTypes.Member]` does not exist here.
            return StateFilter(
                types=frozendict({EventTypes.Member: frozenset()}),
                include_others=True,
            )

    def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
        """Converts the filter to an SQL clause.

        For example:

            f = StateFilter.from_types([("m.room.create", "")])
            clause, args = f.make_sql_filter_clause()
            clause == "(type = ? AND state_key = ?)"
            args == ['m.room.create', '']


        Returns:
            The SQL string (may be empty) and arguments. An empty SQL string is
            returned when the filter matches everything (i.e. is "full").
        """

        where_clause = ""
        where_args: List[str] = []

        if self.is_full():
            return where_clause, where_args

        if not self.include_others and not self.types:
            # i.e. this is an empty filter, so we need to return a clause that
            # will match nothing
            return "1 = 2", []

        # First we build up a list of clauses for each type/state_key combo
        clauses = []
        for etype, state_keys in self.types.items():
            if state_keys is None:
                clauses.append("(type = ?)")
                where_args.append(etype)
                continue

            for state_key in state_keys:
                clauses.append("(type = ? AND state_key = ?)")
                where_args.extend((etype, state_key))

        # This will match anything that appears in `self.types`
        where_clause = " OR ".join(clauses)

        # If we want to include stuff that's not in the types dict then we add
        # a `OR type NOT IN (...)` clause to the end.
        if self.include_others:
            if where_clause:
                where_clause += " OR "

            where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),)
            where_args.extend(self.types)

        return where_clause, where_args

    def max_entries_returned(self) -> Optional[int]:
        """Returns the maximum number of entries this filter will return if
        known, otherwise returns None.

        For example a simple state filter asking for `("m.room.create", "")`
        will return 1, whereas the default state filter will return None.

        This is used to bail out early if the right number of entries have been
        fetched.
        """
        if self.has_wildcards():
            return None

        return len(self.concrete_types())

    def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]:
        """Returns the state filtered with by this StateFilter.

        Args:
            state_dict: The state map to filter

        Returns:
            The filtered state map.
            This is a copy, so it's safe to mutate.
        """
        if self.is_full():
            return dict(state_dict)

        filtered_state = {}
        for k, v in state_dict.items():
            typ, state_key = k
            if typ in self.types:
                state_keys = self.types[typ]
                if state_keys is None or state_key in state_keys:
                    filtered_state[k] = v
            elif self.include_others:
                filtered_state[k] = v

        return filtered_state

    def is_full(self) -> bool:
        """Whether this filter fetches everything or not

        Returns:
            True if the filter fetches everything.
        """
        return self.include_others and not self.types

    def has_wildcards(self) -> bool:
        """Whether the filter includes wildcards or is attempting to fetch
        specific state.

        Returns:
            True if the filter includes wildcards.
        """

        return self.include_others or any(
            state_keys is None for state_keys in self.types.values()
        )

    def concrete_types(self) -> List[Tuple[str, str]]:
        """Returns a list of concrete type/state_keys (i.e. not None) that
        will be fetched. This will be a complete list if `has_wildcards`
        returns False, but otherwise will be a subset (or even empty).

        Returns:
            A list of type/state_keys tuples.
        """
        return [
            (t, s)
            for t, state_keys in self.types.items()
            if state_keys is not None
            for s in state_keys
        ]

    def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
        """Return the filter split into two: one which assumes it's exclusively
        matching against member state, and one which assumes it's matching
        against non member state.

        This is useful due to the returned filters giving correct results for
        `is_full()`, `has_wildcards()`, etc, when operating against maps that
        either exclusively contain member events or only contain non-member
        events. (Which is the case when dealing with the member vs non-member
        state caches).

        Returns:
            The member and non member filters
        """

        if EventTypes.Member in self.types:
            state_keys = self.types[EventTypes.Member]
            if state_keys is None:
                member_filter = StateFilter.all()
            else:
                member_filter = StateFilter(frozendict({EventTypes.Member: state_keys}))
        elif self.include_others:
            member_filter = StateFilter.all()
        else:
            member_filter = StateFilter.none()

        non_member_filter = StateFilter(
            types=frozendict(
                {k: v for k, v in self.types.items() if k != EventTypes.Member}
            ),
            include_others=self.include_others,
        )

        return member_filter, non_member_filter

    def _decompose_into_four_parts(
        self,
    ) -> Tuple[Tuple[bool, Set[str]], Tuple[Set[str], Set[StateKey]]]:
        """
        Decomposes this state filter into 4 constituent parts, which can be
        thought of as this:
            all? - minus_wildcards + plus_wildcards + plus_state_keys

        where
        * all represents ALL state
        * minus_wildcards represents entire state types to remove
        * plus_wildcards represents entire state types to add
        * plus_state_keys represents individual state keys to add

        See `_recompose_from_four_parts` for the other direction of this
        correspondence.
        """
        is_all = self.include_others
        # When starting from "all", every type listed in `types` is restricted
        # (wildcards were canonicalised away), so it is excluded wholesale and
        # its specific keys are re-added via `concrete_keys`.
        excluded_types: Set[str] = {t for t in self.types if is_all}
        wildcard_types: Set[str] = {t for t, s in self.types.items() if s is None}
        concrete_keys: Set[StateKey] = set(self.concrete_types())

        return (is_all, excluded_types), (wildcard_types, concrete_keys)

    @staticmethod
    def _recompose_from_four_parts(
        all_part: bool,
        minus_wildcards: Set[str],
        plus_wildcards: Set[str],
        plus_state_keys: Set[StateKey],
    ) -> "StateFilter":
        """
        Recomposes a state filter from 4 parts.

        See `_decompose_into_four_parts` (the other direction of this
        correspondence) for descriptions on each of the parts.
        """

        # {state type -> set of state keys OR None for wildcard}
        # (The same structure as that of a StateFilter.)
        new_types: Dict[str, Optional[Set[str]]] = {}

        # if we start with all, insert the excluded statetypes as empty sets
        # to prevent them from being included
        if all_part:
            new_types.update({state_type: set() for state_type in minus_wildcards})

        # insert the plus wildcards
        new_types.update({state_type: None for state_type in plus_wildcards})

        # insert the specific state keys
        for state_type, state_key in plus_state_keys:
            if state_type in new_types:
                entry = new_types[state_type]
                if entry is not None:
                    entry.add(state_key)
            elif not all_part:
                # don't insert if the entire type is already included by
                # include_others as this would actually shrink the state allowed
                # by this filter.
                new_types[state_type] = {state_key}

        return StateFilter.freeze(new_types, include_others=all_part)

    def approx_difference(self, other: "StateFilter") -> "StateFilter":
        """
        Returns a state filter which represents `self - other`.

        This is useful for determining what state remains to be pulled out of the
        database if we want the state included by `self` but already have the state
        included by `other`.

        The returned state filter
        - MUST include all state events that are included by this filter (`self`)
          unless they are included by `other`;
        - MUST NOT include state events not included by this filter (`self`); and
        - MAY be an over-approximation: the returned state filter
          MAY additionally include some state events from `other`.

        This implementation attempts to return the narrowest such state filter.
        In the case that `self` contains wildcards for state types where
        `other` contains specific state keys, an approximation must be made:
        the returned state filter keeps the wildcard, as state filters are not
        able to express 'all state keys except some given examples'.
        e.g.
            StateFilter(m.room.member -> None (wildcard))
                minus
            StateFilter(m.room.member -> {'@wombat:example.org'})
                is approximated as
            StateFilter(m.room.member -> None (wildcard))
        """

        # We first transform self and other into an alternative representation:
        #   - whether or not they include all events to begin with ('all')
        #   - if so, which event types are excluded? ('excludes')
        #   - which entire event types to include ('wildcards')
        #   - which concrete state keys to include ('concrete state keys')
        (self_all, self_excludes), (
            self_wildcards,
            self_concrete_keys,
        ) = self._decompose_into_four_parts()
        (other_all, other_excludes), (
            other_wildcards,
            other_concrete_keys,
        ) = other._decompose_into_four_parts()

        # Start with an estimate of the difference based on self
        new_all = self_all
        # Wildcards from the other can be added to the exclusion filter
        new_excludes = self_excludes | other_wildcards
        # We remove wildcards that appeared as wildcards in the other
        new_wildcards = self_wildcards - other_wildcards
        # We filter out the concrete state keys that appear in the other
        # as wildcards or concrete state keys.
        new_concrete_keys = {
            (state_type, state_key)
            for (state_type, state_key) in self_concrete_keys
            if state_type not in other_wildcards
        } - other_concrete_keys

        if other_all:
            if self_all:
                # If self starts with all, then we add as wildcards any
                # types which appear in the other's exclusion filter (but
                # aren't in the self exclusion filter). This is as the other
                # filter will return everything BUT the types in its exclusion, so
                # we need to add those excluded types that also match the self
                # filter as wildcard types in the new filter.
                new_wildcards |= other_excludes.difference(self_excludes)

            # If other is an `include_others` then the difference isn't.
            new_all = False
            # (We have no need for excludes when we don't start with all, as there
            #  is nothing to exclude.)
            new_excludes = set()

            # We also filter out all state types that aren't in the exclusion
            # list of the other.
            new_wildcards &= other_excludes
            new_concrete_keys = {
                (state_type, state_key)
                for (state_type, state_key) in new_concrete_keys
                if state_type in other_excludes
            }

        # Transform our newly-constructed state filter from the alternative
        # representation back into the normal StateFilter representation.
        return StateFilter._recompose_from_four_parts(
            new_all, new_excludes, new_wildcards, new_concrete_keys
        )


class StateGroupStorage:
    """High level interface to fetching state for event."""

    def __init__(self, hs: "HomeServer", stores: "Databases"):
        self.stores = stores

    async def get_state_group_delta(
        self, state_group: int
    ) -> Tuple[Optional[int], Optional[StateMap[str]]]:
        """Given a state group try to return a previous group and a delta between
        the old and the new.

        Args:
            state_group: The state group used to retrieve state deltas.

        Returns:
            A tuple of the previous group and a state map of the event IDs which
            make up the delta between the old and new state groups.
        """

        state_group_delta = await self.stores.state.get_state_group_delta(state_group)
        return state_group_delta.prev_group, state_group_delta.delta_ids

    async def get_state_groups_ids(
        self, _room_id: str, event_ids: Iterable[str]
    ) -> Dict[int, MutableStateMap[str]]:
        """Get the event IDs of all the state for the state groups for the given events

        Args:
            _room_id: id of the room for these events
            event_ids: ids of the events

        Returns:
            dict of state_group_id -> (dict of (type, state_key) -> event id)
        """
        if not event_ids:
            return {}

        event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)

        groups = set(event_to_groups.values())
        group_to_state = await self.stores.state._get_state_for_groups(groups)

        return group_to_state

    async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
        """Get the event IDs of all the state in the given state group

        Args:
            state_group: A state group for which we want to get the state IDs.

        Returns:
            Resolves to a map of (type, state_key) -> event_id
        """
        group_to_state = await self._get_state_for_groups((state_group,))

        return group_to_state[state_group]

    async def get_state_groups(
        self, room_id: str, event_ids: Iterable[str]
    ) -> Dict[int, List[EventBase]]:
        """Get the state groups for the given list of event_ids

        Args:
            room_id: ID of the room for these events.
            event_ids: The event IDs to retrieve state for.

        Returns:
            dict of state_group_id -> list of state events.
        """
        if not event_ids:
            return {}

        group_to_ids = await self.get_state_groups_ids(room_id, event_ids)

        state_event_map = await self.stores.main.get_events(
            [
                ev_id
                for group_ids in group_to_ids.values()
                for ev_id in group_ids.values()
            ],
            get_prev_content=False,
        )

        # Events that could not be fetched are silently dropped from the lists.
        return {
            group: [
                state_event_map[v]
                for v in event_id_map.values()
                if v in state_event_map
            ]
            for group, event_id_map in group_to_ids.items()
        }

    def _get_state_groups_from_groups(
        self, groups: List[int], state_filter: StateFilter
    ) -> Awaitable[Dict[int, StateMap[str]]]:
        """Returns the state groups for a given set of groups, filtering on
        types of state events.

        Args:
            groups: list of state group IDs to query
            state_filter: The state filter used to fetch state
                from the database.

        Returns:
            Dict of state group to state map.
        """

        return self.stores.state._get_state_groups_from_groups(groups, state_filter)

    async def get_state_for_events(
        self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
    ) -> Dict[str, StateMap[EventBase]]:
        """Given a list of event_ids and an optional state filter, return a
        state dict for each event.

        Args:
            event_ids: The events to fetch the state of.
            state_filter: The state filter used to fetch state. Defaults to
                fetching all state.

        Returns:
            A dict of (event_id) -> (type, state_key) -> state_event
        """
        # `event_ids` is iterated twice (to look up state groups and to build
        # the result), so materialise it: a one-shot iterator would otherwise
        # be exhausted by the first pass and the result would be empty.
        event_ids = list(event_ids)

        event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)

        groups = set(event_to_groups.values())
        group_to_state = await self.stores.state._get_state_for_groups(
            groups, state_filter or StateFilter.all()
        )

        state_event_map = await self.stores.main.get_events(
            [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
            get_prev_content=False,
        )

        # Events that could not be fetched are silently dropped from the maps.
        event_to_state = {
            event_id: {
                k: state_event_map[v]
                for k, v in group_to_state[group].items()
                if v in state_event_map
            }
            for event_id, group in event_to_groups.items()
        }

        return {event: event_to_state[event] for event in event_ids}

    async def get_state_ids_for_events(
        self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
    ) -> Dict[str, StateMap[str]]:
        """
        Get the state dicts corresponding to a list of events, containing the event_ids
        of the state events (as opposed to the events themselves)

        Args:
            event_ids: events whose state should be returned
            state_filter: The state filter used to fetch state from the database.
                Defaults to fetching all state.

        Returns:
            A dict from event_id -> (type, state_key) -> event_id
        """
        # `event_ids` is iterated twice (to look up state groups and to build
        # the result), so materialise it: a one-shot iterator would otherwise
        # be exhausted by the first pass and the result would be empty.
        event_ids = list(event_ids)

        event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)

        groups = set(event_to_groups.values())
        group_to_state = await self.stores.state._get_state_for_groups(
            groups, state_filter or StateFilter.all()
        )

        event_to_state = {
            event_id: group_to_state[group]
            for event_id, group in event_to_groups.items()
        }

        return {event: event_to_state[event] for event in event_ids}

    async def get_state_for_event(
        self, event_id: str, state_filter: Optional[StateFilter] = None
    ) -> StateMap[EventBase]:
        """
        Get the state dict corresponding to a particular event

        Args:
            event_id: event whose state should be returned
            state_filter: The state filter used to fetch state from the database.

        Returns:
            A dict from (type, state_key) -> state_event
        """
        state_map = await self.get_state_for_events(
            [event_id], state_filter or StateFilter.all()
        )
        return state_map[event_id]

    async def get_state_ids_for_event(
        self, event_id: str, state_filter: Optional[StateFilter] = None
    ) -> StateMap[str]:
        """
        Get the state dict corresponding to a particular event

        Args:
            event_id: event whose state should be returned
            state_filter: The state filter used to fetch state from the database.

        Returns:
            A dict from (type, state_key) -> state_event_id
        """
        state_map = await self.get_state_ids_for_events(
            [event_id], state_filter or StateFilter.all()
        )
        return state_map[event_id]

    def _get_state_for_groups(
        self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
    ) -> Awaitable[Dict[int, MutableStateMap[str]]]:
        """Gets the state at each of a list of state groups, optionally
        filtering by type/state_key

        Args:
            groups: list of state groups for which we want to get the state.
            state_filter: The state filter used to fetch state.
                from the database.

        Returns:
            Dict of state group to state map.
        """
        return self.stores.state._get_state_for_groups(
            groups, state_filter or StateFilter.all()
        )

    async def store_state_group(
        self,
        event_id: str,
        room_id: str,
        prev_group: Optional[int],
        delta_ids: Optional[StateMap[str]],
        current_state_ids: StateMap[str],
    ) -> int:
        """Store a new set of state, returning a newly assigned state group.

        Args:
            event_id: The event ID for which the state was calculated.
            room_id: ID of the room for which the state was calculated.
            prev_group: A previous state group for the room, optional.
            delta_ids: The delta between state at `prev_group` and
                `current_state_ids`, if `prev_group` was given. Same format as
                `current_state_ids`.
            current_state_ids: The state to store. Map of (type, state_key)
                to event_id.

        Returns:
            The state group ID
        """
        return await self.stores.state.store_state_group(
            event_id, room_id, prev_group, delta_ids, current_state_ids
        )