Diffstat (limited to 'synapse/handlers')
-rw-r--r--  synapse/handlers/account_data.py     |  10
-rw-r--r--  synapse/handlers/admin.py            |  12
-rw-r--r--  synapse/handlers/appservice.py       |  43
-rw-r--r--  synapse/handlers/auth.py             |  29
-rw-r--r--  synapse/handlers/device.py           |  54
-rw-r--r--  synapse/handlers/devicemessage.py    |  12
-rw-r--r--  synapse/handlers/directory.py        |  12
-rw-r--r--  synapse/handlers/e2e_keys.py         |  34
-rw-r--r--  synapse/handlers/event_auth.py       |  10
-rw-r--r--  synapse/handlers/events.py           |   6
-rw-r--r--  synapse/handlers/federation.py       | 163
-rw-r--r--  synapse/handlers/federation_event.py | 272
-rw-r--r--  synapse/handlers/groups_local.py     | 503
-rw-r--r--  synapse/handlers/initial_sync.py     |  46
-rw-r--r--  synapse/handlers/message.py          | 270
-rw-r--r--  synapse/handlers/oidc.py             |   4
-rw-r--r--  synapse/handlers/pagination.py       |  48
-rw-r--r--  synapse/handlers/presence.py         |  20
-rw-r--r--  synapse/handlers/profile.py          |  83
-rw-r--r--  synapse/handlers/receipts.py         | 112
-rw-r--r--  synapse/handlers/register.py         |   3
-rw-r--r--  synapse/handlers/relations.py        |  85
-rw-r--r--  synapse/handlers/room.py             | 218
-rw-r--r--  synapse/handlers/room_batch.py       |   6
-rw-r--r--  synapse/handlers/room_list.py        |   3
-rw-r--r--  synapse/handlers/room_member.py      |  41
-rw-r--r--  synapse/handlers/room_summary.py     |  26
-rw-r--r--  synapse/handlers/search.py           |  26
-rw-r--r--  synapse/handlers/stats.py            |   6
-rw-r--r--  synapse/handlers/sync.py             | 168
-rw-r--r--  synapse/handlers/typing.py           |  22
-rw-r--r--  synapse/handlers/user_directory.py   |   6
32 files changed, 1134 insertions(+), 1219 deletions(-)
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 4af9fbc5..0478448b 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -23,7 +23,7 @@ from synapse.replication.http.account_data import (
ReplicationUserAccountDataRestServlet,
)
from synapse.streams import EventSource
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, StreamKeyType, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -105,7 +105,7 @@ class AccountDataHandler:
)
self._notifier.on_new_event(
- "account_data_key", max_stream_id, users=[user_id]
+ StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
)
await self._notify_modules(user_id, room_id, account_data_type, content)
@@ -141,7 +141,7 @@ class AccountDataHandler:
)
self._notifier.on_new_event(
- "account_data_key", max_stream_id, users=[user_id]
+ StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
)
await self._notify_modules(user_id, None, account_data_type, content)
@@ -176,7 +176,7 @@ class AccountDataHandler:
)
self._notifier.on_new_event(
- "account_data_key", max_stream_id, users=[user_id]
+ StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
)
return max_stream_id
else:
@@ -201,7 +201,7 @@ class AccountDataHandler:
)
self._notifier.on_new_event(
- "account_data_key", max_stream_id, users=[user_id]
+ StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
)
return max_stream_id
else:
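
The account_data.py hunks above (and several later files) replace bare stream-key strings with `StreamKeyType` constants from `synapse.types`. As a rough sketch, inferred from the literals being replaced throughout this diff, the new constants presumably keep the old string values so stream consumers see no behavioural change:

```python
from typing import Final

# Sketch of StreamKeyType as implied by this diff; each value is the string
# literal it replaces, so notifier behaviour is unchanged.
class StreamKeyType:
    TYPING: Final = "typing_key"
    RECEIPT: Final = "receipt_key"
    PRESENCE: Final = "presence_key"
    ACCOUNT_DATA: Final = "account_data_key"
    TO_DEVICE: Final = "to_device_key"
    DEVICE_LIST: Final = "device_list_key"
```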
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 96376963..d4fe7df5 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -30,8 +30,8 @@ logger = logging.getLogger(__name__)
class AdminHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
- self.state_store = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
async def get_whois(self, user: UserID) -> JsonDict:
connections = []
@@ -197,7 +197,9 @@ class AdminHandler:
from_key = events[-1].internal_metadata.after
- events = await filter_events_for_client(self.storage, user_id, events)
+ events = await filter_events_for_client(
+ self._storage_controllers, user_id, events
+ )
writer.write_events(room_id, events)
@@ -233,7 +235,9 @@ class AdminHandler:
for event_id in extremities:
if not event_to_unseen_prevs[event_id]:
continue
- state = await self.state_store.get_state_for_event(event_id)
+ state = await self._state_storage_controller.get_state_for_event(
+ event_id
+ )
writer.write_state(room_id, event_id, state)
return writer.finished()
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 85bd5e47..814553e0 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -19,7 +19,7 @@ from prometheus_client import Counter
from twisted.internet import defer
import synapse
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EduTypes, EventTypes
from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state
@@ -38,6 +38,7 @@ from synapse.types import (
JsonDict,
RoomAlias,
RoomStreamToken,
+ StreamKeyType,
UserID,
)
from synapse.util.async_helpers import Linearizer
@@ -213,8 +214,8 @@ class ApplicationServicesHandler:
Args:
stream_key: The stream the event came from.
- `stream_key` can be "typing_key", "receipt_key", "presence_key",
- "to_device_key" or "device_list_key". Any other value for `stream_key`
+ `stream_key` can be StreamKeyType.TYPING, StreamKeyType.RECEIPT, StreamKeyType.PRESENCE,
+ StreamKeyType.TO_DEVICE or StreamKeyType.DEVICE_LIST. Any other value for `stream_key`
will cause this function to return early.
Ephemeral events will only be pushed to appservices that have opted into
@@ -235,11 +236,11 @@ class ApplicationServicesHandler:
# Only the following streams are currently supported.
# FIXME: We should use constants for these values.
if stream_key not in (
- "typing_key",
- "receipt_key",
- "presence_key",
- "to_device_key",
- "device_list_key",
+ StreamKeyType.TYPING,
+ StreamKeyType.RECEIPT,
+ StreamKeyType.PRESENCE,
+ StreamKeyType.TO_DEVICE,
+ StreamKeyType.DEVICE_LIST,
):
return
@@ -258,14 +259,14 @@ class ApplicationServicesHandler:
# Ignore to-device messages if the feature flag is not enabled
if (
- stream_key == "to_device_key"
+ stream_key == StreamKeyType.TO_DEVICE
and not self._msc2409_to_device_messages_enabled
):
return
# Ignore device lists if the feature flag is not enabled
if (
- stream_key == "device_list_key"
+ stream_key == StreamKeyType.DEVICE_LIST
and not self._msc3202_transaction_extensions_enabled
):
return
@@ -283,15 +284,15 @@ class ApplicationServicesHandler:
if (
stream_key
in (
- "typing_key",
- "receipt_key",
- "presence_key",
- "to_device_key",
+ StreamKeyType.TYPING,
+ StreamKeyType.RECEIPT,
+ StreamKeyType.PRESENCE,
+ StreamKeyType.TO_DEVICE,
)
and service.supports_ephemeral
)
or (
- stream_key == "device_list_key"
+ stream_key == StreamKeyType.DEVICE_LIST
and service.msc3202_transaction_extensions
)
]
@@ -317,7 +318,7 @@ class ApplicationServicesHandler:
logger.debug("Checking interested services for %s", stream_key)
with Measure(self.clock, "notify_interested_services_ephemeral"):
for service in services:
- if stream_key == "typing_key":
+ if stream_key == StreamKeyType.TYPING:
# Note that we don't persist the token (via set_appservice_stream_type_pos)
# for typing_key due to performance reasons and due to their highly
# ephemeral nature.
@@ -333,7 +334,7 @@ class ApplicationServicesHandler:
async with self._ephemeral_events_linearizer.queue(
(service.id, stream_key)
):
- if stream_key == "receipt_key":
+ if stream_key == StreamKeyType.RECEIPT:
events = await self._handle_receipts(service, new_token)
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
@@ -342,7 +343,7 @@ class ApplicationServicesHandler:
service, "read_receipt", new_token
)
- elif stream_key == "presence_key":
+ elif stream_key == StreamKeyType.PRESENCE:
events = await self._handle_presence(service, users, new_token)
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
@@ -351,7 +352,7 @@ class ApplicationServicesHandler:
service, "presence", new_token
)
- elif stream_key == "to_device_key":
+ elif stream_key == StreamKeyType.TO_DEVICE:
# Retrieve a list of to-device message events, as well as the
# maximum stream token of the messages we were able to retrieve.
to_device_messages = await self._get_to_device_messages(
@@ -366,7 +367,7 @@ class ApplicationServicesHandler:
service, "to_device", new_token
)
- elif stream_key == "device_list_key":
+ elif stream_key == StreamKeyType.DEVICE_LIST:
device_list_summary = await self._get_device_list_summary(
service, new_token
)
@@ -502,7 +503,7 @@ class ApplicationServicesHandler:
time_now = self.clock.time_msec()
events.extend(
{
- "type": "m.presence",
+ "type": EduTypes.PRESENCE,
"sender": event.user_id,
"content": format_user_presence_state(
event, time_now, include_user_id=False
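
appservice.py likewise swaps the literal `"m.presence"` for `EduTypes.PRESENCE`, and later files do the same for other EDU types. Collecting the replacements made throughout this diff, a plausible sketch of the constants (values taken verbatim from the literals they replace):

```python
from typing import Final

# Sketch of the EduTypes constants used across this diff; the EDU type
# strings sent over federation are unchanged.
class EduTypes:
    PRESENCE: Final = "m.presence"
    DEVICE_LIST_UPDATE: Final = "m.device_list_update"
    DIRECT_TO_DEVICE: Final = "m.direct_to_device"
    SIGNING_KEY_UPDATE: Final = "m.signing_key_update"
    UNSTABLE_SIGNING_KEY_UPDATE: Final = "org.matrix.signing_key_update"
```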
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 1b9050ea..fbafbbee 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -210,7 +210,8 @@ class AuthHandler:
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
- self._password_enabled = hs.config.auth.password_enabled
+ self._password_enabled_for_login = hs.config.auth.password_enabled_for_login
+ self._password_enabled_for_reauth = hs.config.auth.password_enabled_for_reauth
self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
self._third_party_rules = hs.get_third_party_event_rules()
@@ -387,13 +388,13 @@ class AuthHandler:
return params, session_id
async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
- """Get a list of the authentication types this user can use"""
+ """Get a list of the user-interactive authentication types this user can use."""
ui_auth_types = set()
# if the HS supports password auth, and the user has a non-null password, we
# support password auth
- if self._password_localdb_enabled and self._password_enabled:
+ if self._password_localdb_enabled and self._password_enabled_for_reauth:
lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
if lookupres:
_, password_hash = lookupres
@@ -402,7 +403,7 @@ class AuthHandler:
# also allow auth from password providers
for t in self.password_auth_provider.get_supported_login_types().keys():
- if t == LoginType.PASSWORD and not self._password_enabled:
+ if t == LoginType.PASSWORD and not self._password_enabled_for_reauth:
continue
ui_auth_types.add(t)
@@ -710,7 +711,7 @@ class AuthHandler:
return res
# fall back to the v1 login flow
- canonical_id, _ = await self.validate_login(authdict)
+ canonical_id, _ = await self.validate_login(authdict, is_reauth=True)
return canonical_id
def _get_params_recaptcha(self) -> dict:
@@ -1064,7 +1065,7 @@ class AuthHandler:
Returns:
Whether users on this server are allowed to change or set a password
"""
- return self._password_enabled and self._password_localdb_enabled
+ return self._password_enabled_for_login and self._password_localdb_enabled
def get_supported_login_types(self) -> Iterable[str]:
"""Get a the login types supported for the /login API
@@ -1089,9 +1090,9 @@ class AuthHandler:
# that comes first, where it's present.
if LoginType.PASSWORD in types:
types.remove(LoginType.PASSWORD)
- if self._password_enabled:
+ if self._password_enabled_for_login:
types.insert(0, LoginType.PASSWORD)
- elif self._password_localdb_enabled and self._password_enabled:
+ elif self._password_localdb_enabled and self._password_enabled_for_login:
types.insert(0, LoginType.PASSWORD)
return types
@@ -1100,6 +1101,7 @@ class AuthHandler:
self,
login_submission: Dict[str, Any],
ratelimit: bool = False,
+ is_reauth: bool = False,
) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
"""Authenticates the user for the /login API
@@ -1110,6 +1112,9 @@ class AuthHandler:
login_submission: the whole of the login submission
(including 'type' and other relevant fields)
ratelimit: whether to apply the failed_login_attempt ratelimiter
+ is_reauth: whether this is part of a User-Interactive Authorisation
+ flow to reauthenticate for a privileged action (rather than a
+ new login)
Returns:
A tuple of the canonical user id, and optional callback
to be called once the access token and device id are issued
@@ -1132,8 +1137,14 @@ class AuthHandler:
# special case to check for "password" for the check_password interface
# for the auth providers
password = login_submission.get("password")
+
if login_type == LoginType.PASSWORD:
- if not self._password_enabled:
+ if is_reauth:
+ passwords_allowed_here = self._password_enabled_for_reauth
+ else:
+ passwords_allowed_here = self._password_enabled_for_login
+
+ if not passwords_allowed_here:
raise SynapseError(400, "Password login has been disabled.")
if not isinstance(password, str):
raise SynapseError(400, "Bad parameter: password", Codes.INVALID_PARAM)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index a91b1ee4..a0cbeedc 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -28,7 +28,7 @@ from typing import (
)
from synapse.api import errors
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EduTypes, EventTypes
from synapse.api.errors import (
Codes,
FederationDeniedError,
@@ -43,6 +43,7 @@ from synapse.metrics.background_process_metrics import (
)
from synapse.types import (
JsonDict,
+ StreamKeyType,
StreamToken,
UserID,
get_domain_from_id,
@@ -60,6 +61,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
MAX_DEVICE_DISPLAY_NAME_LEN = 100
+DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000
class DeviceWorkerHandler:
@@ -69,7 +71,7 @@ class DeviceWorkerHandler:
self.store = hs.get_datastores().main
self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
- self.state_store = hs.get_storage().state
+ self._state_storage = hs.get_storage_controllers().state
self._auth_handler = hs.get_auth_handler()
self.server_name = hs.hostname
@@ -164,7 +166,7 @@ class DeviceWorkerHandler:
possibly_changed = set(changed)
possibly_left = set()
for room_id in rooms_changed:
- current_state_ids = await self.store.get_current_state_ids(room_id)
+ current_state_ids = await self._state_storage.get_current_state_ids(room_id)
# The user may have left the room
# TODO: Check if they actually did or if we were just invited.
@@ -202,7 +204,9 @@ class DeviceWorkerHandler:
continue
# mapping from event_id -> state_dict
- prev_state_ids = await self.state_store.get_state_ids_for_events(event_ids)
+ prev_state_ids = await self._state_storage.get_state_ids_for_events(
+ event_ids
+ )
# Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users.
@@ -276,7 +280,8 @@ class DeviceHandler(DeviceWorkerHandler):
federation_registry = hs.get_federation_registry()
federation_registry.register_edu_handler(
- "m.device_list_update", self.device_list_updater.incoming_device_list_update
+ EduTypes.DEVICE_LIST_UPDATE,
+ self.device_list_updater.incoming_device_list_update,
)
hs.get_distributor().observe("user_left_room", self.user_left_room)
@@ -291,6 +296,19 @@ class DeviceHandler(DeviceWorkerHandler):
# On start up check if there are any updates pending.
hs.get_reactor().callWhenRunning(self._handle_new_device_update_async)
+ self._delete_stale_devices_after = hs.config.server.delete_stale_devices_after
+
+ # Ideally we would run this on a worker and condition this on the
+ # "run_background_tasks_on" setting, but this would mean making the notification
+ # of device list changes over federation work on workers, which is nontrivial.
+ if self._delete_stale_devices_after is not None:
+ self.clock.looping_call(
+ run_as_background_process,
+ DELETE_STALE_DEVICES_INTERVAL_MS,
+ "delete_stale_devices",
+ self._delete_stale_devices,
+ )
+
def _check_device_name_length(self, name: Optional[str]) -> None:
"""
Checks whether a device name is longer than the maximum allowed length.
@@ -366,6 +384,19 @@ class DeviceHandler(DeviceWorkerHandler):
raise errors.StoreError(500, "Couldn't generate a device ID.")
+ async def _delete_stale_devices(self) -> None:
+ """Background task that deletes devices which haven't been accessed for more than
+ a configured time period.
+ """
+ # We should only be running this job if the config option is defined.
+ assert self._delete_stale_devices_after is not None
+ now_ms = self.clock.time_msec()
+ since_ms = now_ms - self._delete_stale_devices_after
+ devices = await self.store.get_local_devices_not_accessed_since(since_ms)
+
+ for user_id, user_devices in devices.items():
+ await self.delete_devices(user_id, user_devices)
+
@trace
async def delete_device(self, user_id: str, device_id: str) -> None:
"""Delete the given device
@@ -502,7 +533,7 @@ class DeviceHandler(DeviceWorkerHandler):
# specify the user ID too since the user should always get their own device list
# updates, even if they aren't in any rooms.
self.notifier.on_new_event(
- "device_list_key", position, users={user_id}, rooms=room_ids
+ StreamKeyType.DEVICE_LIST, position, users={user_id}, rooms=room_ids
)
# We may need to do some processing asynchronously for local user IDs.
@@ -523,7 +554,9 @@ class DeviceHandler(DeviceWorkerHandler):
from_user_id, user_ids
)
- self.notifier.on_new_event("device_list_key", position, users=[from_user_id])
+ self.notifier.on_new_event(
+ StreamKeyType.DEVICE_LIST, position, users=[from_user_id]
+ )
async def user_left_room(self, user: UserID, room_id: str) -> None:
user_id = user.to_string()
@@ -686,7 +719,8 @@ class DeviceHandler(DeviceWorkerHandler):
)
# TODO: when called, this isn't in a logging context.
# This leads to log spam, sentry event spam, and massive
- # memory usage. See #12552.
+ # memory usage.
+ # See https://github.com/matrix-org/synapse/issues/12552.
# log_kv(
# {"message": "sent device update to host", "host": host}
# )
@@ -760,6 +794,10 @@ class DeviceListUpdater:
device_id = edu_content.pop("device_id")
stream_id = str(edu_content.pop("stream_id")) # They may come as ints
prev_ids = edu_content.pop("prev_id", [])
+ if not isinstance(prev_ids, list):
+ raise SynapseError(
+ 400, "Device list update had an invalid 'prev_ids' field"
+ )
prev_ids = [str(p) for p in prev_ids] # They may come as ints
if get_domain_from_id(user_id) != origin:
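
The device.py additions schedule `_delete_stale_devices` as a daily background job (`DELETE_STALE_DEVICES_INTERVAL_MS` is 24 hours), gated on `hs.config.server.delete_stale_devices_after`; as the comment notes, it runs on the main process because notifying device-list changes over federation does not yet work on workers. A hedged config sketch (the option name comes from the diff; the duration syntax is an assumption):

```yaml
# homeserver.yaml (sketch): delete local devices not accessed for a year.
delete_stale_devices_after: 1y
```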
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 4cb725d0..444c08bc 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING, Any, Dict
-from synapse.api.constants import ToDeviceEventTypes
+from synapse.api.constants import EduTypes, ToDeviceEventTypes
from synapse.api.errors import SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.logging.context import run_in_background
@@ -26,7 +26,7 @@ from synapse.logging.opentracing import (
set_tag,
)
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
-from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.stringutils import random_string
@@ -59,11 +59,11 @@ class DeviceMessageHandler:
# to the appropriate worker.
if hs.get_instance_name() in hs.config.worker.writers.to_device:
hs.get_federation_registry().register_edu_handler(
- "m.direct_to_device", self.on_direct_to_device_edu
+ EduTypes.DIRECT_TO_DEVICE, self.on_direct_to_device_edu
)
else:
hs.get_federation_registry().register_instances_for_edu(
- "m.direct_to_device",
+ EduTypes.DIRECT_TO_DEVICE,
hs.config.worker.writers.to_device,
)
@@ -151,7 +151,7 @@ class DeviceMessageHandler:
# Notify listeners that there are new to-device messages to process,
# handing them the latest stream id.
self.notifier.on_new_event(
- "to_device_key", last_stream_id, users=local_messages.keys()
+ StreamKeyType.TO_DEVICE, last_stream_id, users=local_messages.keys()
)
async def _check_for_unknown_devices(
@@ -285,7 +285,7 @@ class DeviceMessageHandler:
# Notify listeners that there are new to-device messages to process,
# handing them the latest stream id.
self.notifier.on_new_event(
- "to_device_key", last_stream_id, users=local_messages.keys()
+ StreamKeyType.TO_DEVICE, last_stream_id, users=local_messages.keys()
)
if self.federation_sender:
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 33d827a4..1459a046 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -45,6 +45,7 @@ class DirectoryHandler:
self.appservice_handler = hs.get_application_service_handler()
self.event_creation_handler = hs.get_event_creation_handler()
self.store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self.config = hs.config
self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
self.require_membership = hs.config.server.require_membership_for_aliases
@@ -71,6 +72,9 @@ class DirectoryHandler:
if wchar in room_alias.localpart:
raise SynapseError(400, "Invalid characters in room alias")
+ if ":" in room_alias.localpart:
+ raise SynapseError(400, "Invalid character in room alias localpart: ':'.")
+
if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local")
# TODO(erikj): Change this.
@@ -316,7 +320,7 @@ class DirectoryHandler:
Raises:
ShadowBanError if the requester has been shadow-banned.
"""
- alias_event = await self.state.get_current_state(
+ alias_event = await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.CanonicalAlias, ""
)
@@ -460,7 +464,11 @@ class DirectoryHandler:
making_public = visibility == "public"
if making_public:
room_aliases = await self.store.get_aliases_for_room(room_id)
- canonical_alias = await self.store.get_canonical_alias_for_room(room_id)
+ canonical_alias = (
+ await self._storage_controllers.state.get_canonical_alias_for_room(
+ room_id
+ )
+ )
if canonical_alias:
room_aliases.append(canonical_alias)
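
directory.py gains a check rejecting `:` in the alias localpart. Since aliases have the form `#localpart:server_name`, and the server name itself may contain a colon (for a port), a colon in the localpart makes parsing ambiguous. A simplified illustration (not Synapse's actual parser):

```python
from typing import Tuple

def split_alias(alias: str) -> Tuple[str, str]:
    # Split on the first ":" after the "#" sigil, per the alias grammar.
    localpart, _, server_name = alias[1:].partition(":")
    return localpart, server_name

print(split_alias("#foo:example.org"))      # ('foo', 'example.org')
# With a ":" in the localpart, the intended split is unrecoverable:
print(split_alias("#foo:bar:example.org"))  # ('foo', 'bar:example.org')
```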
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d6714228..52bb5c9c 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json
@@ -25,6 +25,7 @@ from unpaddedbase64 import decode_base64
from twisted.internet import defer
+from synapse.api.constants import EduTypes
from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
@@ -66,13 +67,13 @@ class E2eKeysHandler:
# Only register this edu handler on master as it requires writing
# device updates to the db
federation_registry.register_edu_handler(
- "m.signing_key_update",
+ EduTypes.SIGNING_KEY_UPDATE,
self._edu_updater.incoming_signing_key_update,
)
# also handle the unstable version
# FIXME: remove this when enough servers have upgraded
federation_registry.register_edu_handler(
- "org.matrix.signing_key_update",
+ EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
self._edu_updater.incoming_signing_key_update,
)
@@ -1105,22 +1106,19 @@ class E2eKeysHandler:
# can request over federation
raise NotFoundError("No %s key found for %s" % (key_type, user_id))
- (
- key,
- key_id,
- verify_key,
- ) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
-
- if key is None:
+ cross_signing_keys = await self._retrieve_cross_signing_keys_for_remote_user(
+ user, key_type
+ )
+ if cross_signing_keys is None:
raise NotFoundError("No %s key found for %s" % (key_type, user_id))
- return key, key_id, verify_key
+ return cross_signing_keys
async def _retrieve_cross_signing_keys_for_remote_user(
self,
user: UserID,
desired_key_type: str,
- ) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]:
+ ) -> Optional[Tuple[Dict[str, Any], str, VerifyKey]]:
"""Queries cross-signing keys for a remote user and saves them to the database
Only the key specified by `key_type` will be returned, while all retrieved keys
@@ -1146,12 +1144,10 @@ class E2eKeysHandler:
type(e),
e,
)
- return None, None, None
+ return None
# Process each of the retrieved cross-signing keys
- desired_key = None
- desired_key_id = None
- desired_verify_key = None
+ desired_key_data = None
retrieved_device_ids = []
for key_type in ["master", "self_signing"]:
key_content = remote_result.get(key_type + "_key")
@@ -1196,9 +1192,7 @@ class E2eKeysHandler:
# If this is the desired key type, save it and its ID/VerifyKey
if key_type == desired_key_type:
- desired_key = key_content
- desired_verify_key = verify_key
- desired_key_id = key_id
+ desired_key_data = key_content, key_id, verify_key
# At the same time, store this key in the db for subsequent queries
await self.store.set_e2e_cross_signing_key(
@@ -1212,7 +1206,7 @@ class E2eKeysHandler:
user.to_string(), retrieved_device_ids
)
- return desired_key, desired_key_id, desired_verify_key
+ return desired_key_data
def _check_cross_signing_key(
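
In e2e_keys.py, `_retrieve_cross_signing_keys_for_remote_user` now returns `Optional[Tuple[Dict[str, Any], str, VerifyKey]]` instead of `Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]`: "not found" becomes a single `None` rather than three independently-optional fields. A minimal sketch of why that helps callers (with `str` standing in for `VerifyKey` to keep it self-contained):

```python
from typing import Any, Dict, Optional, Tuple

def lookup(found: bool) -> Optional[Tuple[Dict[str, Any], str, str]]:
    # Return all three values together, or nothing at all.
    if not found:
        return None
    return {"keys": {}}, "key_id", "verify_key"

result = lookup(found=True)
if result is None:
    # One check covers every element; previously each of the three values
    # was separately Optional and had to be checked on its own.
    raise LookupError("no key found")
key, key_id, verify_key = result  # all three are non-Optional from here
```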
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index d441ebb0..6bed4643 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -241,7 +241,15 @@ class EventAuthHandler:
# If the join rule is not restricted, this doesn't apply.
join_rules_event = await self._store.get_event(join_rules_event_id)
- return join_rules_event.content.get("join_rule") == JoinRules.RESTRICTED
+ content_join_rule = join_rules_event.content.get("join_rule")
+ if content_join_rule == JoinRules.RESTRICTED:
+ return True
+
+ # also check for MSC3787 behaviour
+ if room_version.msc3787_knock_restricted_join_rule:
+ return content_join_rule == JoinRules.KNOCK_RESTRICTED
+
+ return False
async def get_rooms_that_allow_join(
self, state_ids: StateMap[str]
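
event_auth.py's restricted-join-rule check now also accepts MSC3787's `knock_restricted` rule on room versions that support it. A minimal sketch of the logic, with the `JoinRules` string values assumed:

```python
# Assumed values: MSC3787 adds "knock_restricted" alongside "restricted".
RESTRICTED = "restricted"
KNOCK_RESTRICTED = "knock_restricted"

def has_restricted_join_rule(join_rule: str, msc3787_supported: bool) -> bool:
    if join_rule == RESTRICTED:
        return True
    # knock_restricted only counts on room versions implementing MSC3787.
    return msc3787_supported and join_rule == KNOCK_RESTRICTED
```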
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 82a5aac3..ac13340d 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -113,7 +113,7 @@ class EventStreamHandler:
states = await presence_handler.get_states(users)
to_add.extend(
{
- "type": EduTypes.Presence,
+ "type": EduTypes.PRESENCE,
"content": format_user_presence_state(state, time_now),
}
for state in states
@@ -139,7 +139,7 @@ class EventStreamHandler:
class EventHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
async def get_event(
self,
@@ -177,7 +177,7 @@ class EventHandler:
is_peeking = user.to_string() not in users
filtered = await filter_events_for_client(
- self.storage, user.to_string(), [event], is_peeking=is_peeking
+ self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking
)
if not filtered:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 38dc5b1f..6a143440 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -20,7 +20,16 @@ import itertools
import logging
from enum import Enum
from http import HTTPStatus
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
import attr
from signedjson.key import decode_verify_key_bytes
@@ -34,6 +43,7 @@ from synapse.api.errors import (
CodeMessageException,
Codes,
FederationDeniedError,
+ FederationError,
HttpResponseException,
NotFoundError,
RequestSendFailed,
@@ -54,6 +64,7 @@ from synapse.replication.http.federation import (
ReplicationStoreRoomOnOutlierMembershipRestServlet,
)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.storage.state import StateFilter
from synapse.types import JsonDict, StateMap, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
@@ -124,8 +135,8 @@ class FederationHandler:
self.hs = hs
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
- self.state_store = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname
@@ -158,6 +169,14 @@ class FederationHandler:
self.third_party_event_rules = hs.get_third_party_event_rules()
+ # if this is the main process, fire off a background process to resume
+ # any partial-state-resync operations which were in flight when we
+ # were shut down.
+ if not hs.config.worker.worker_app:
+ run_as_background_process(
+ "resume_sync_partial_state_room", self._resume_sync_partial_state_room
+ )
+
async def maybe_backfill(
self, room_id: str, current_depth: int, limit: int
) -> bool:
@@ -323,7 +342,7 @@ class FederationHandler:
# We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased.
filtered_extremities = await filter_events_for_server(
- self.storage,
+ self._storage_controllers,
self.server_name,
events_to_check,
redact=False,
@@ -352,7 +371,7 @@ class FederationHandler:
# First we try hosts that are already in the room
# TODO: HEURISTIC ALERT.
- curr_state = await self.state_handler.get_current_state(room_id)
+ curr_state = await self._storage_controllers.state.get_current_state(room_id)
curr_domains = get_domains_from_state(curr_state)
@@ -459,6 +478,8 @@ class FederationHandler:
"""
# TODO: We should be able to call this on workers, but the upgrading of
# room stuff after join currently doesn't work on workers.
+ # TODO: Before we relax this condition, we need to allow re-syncing of
+ # partial room state to happen on workers.
assert self.config.worker.worker_app is None
logger.debug("Joining %s to %s", joinee, room_id)
@@ -539,12 +560,11 @@ class FederationHandler:
if ret.partial_state:
# Kick off the process of asynchronously fetching the state for this
# room.
- #
- # TODO(faster_joins): pick this up again on restart
run_as_background_process(
desc="sync_partial_state_room",
func=self._sync_partial_state_room,
- destination=origin,
+ initial_destination=origin,
+ other_destinations=ret.servers_in_room,
room_id=room_id,
)
@@ -659,7 +679,7 @@ class FederationHandler:
# in the invitee's sync stream. It is stripped out for all other local users.
event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
- context = EventContext.for_outlier()
+ context = EventContext.for_outlier(self._storage_controllers)
stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
@@ -730,7 +750,9 @@ class FederationHandler:
# Note that this requires the /send_join request to come back to the
# same server.
if room_version.msc3083_join_rules:
- state_ids = await self.store.get_current_state_ids(room_id)
+ state_ids = await self._state_storage_controller.get_current_state_ids(
+ room_id
+ )
if await self._event_auth_handler.has_restricted_join_rules(
state_ids, room_version
):
@@ -848,7 +870,7 @@ class FederationHandler:
)
)
- context = EventContext.for_outlier()
+ context = EventContext.for_outlier(self._storage_controllers)
await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
@@ -877,7 +899,7 @@ class FederationHandler:
await self.federation_client.send_leave(host_list, event)
- context = EventContext.for_outlier()
+ context = EventContext.for_outlier(self._storage_controllers)
stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
@@ -1026,7 +1048,9 @@ class FederationHandler:
if event.internal_metadata.outlier:
raise NotFoundError("State not known at event %s" % (event_id,))
- state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
+ state_groups = await self._state_storage_controller.get_state_groups_ids(
+ room_id, [event_id]
+ )
# get_state_groups_ids should return exactly one result
assert len(state_groups) == 1
@@ -1075,7 +1099,9 @@ class FederationHandler:
],
)
- events = await filter_events_for_server(self.storage, origin, events)
+ events = await filter_events_for_server(
+ self._storage_controllers, origin, events
+ )
return events
@@ -1106,7 +1132,9 @@ class FederationHandler:
if not in_room:
raise AuthError(403, "Host not in room.")
- events = await filter_events_for_server(self.storage, origin, [event])
+ events = await filter_events_for_server(
+ self._storage_controllers, origin, [event]
+ )
event = events[0]
return event
else:
@@ -1135,7 +1163,7 @@ class FederationHandler:
)
missing_events = await filter_events_for_server(
- self.storage, origin, missing_events
+ self._storage_controllers, origin, missing_events
)
return missing_events
@@ -1259,7 +1287,9 @@ class FederationHandler:
event.content["third_party_invite"]["signed"]["token"],
)
original_invite = None
- prev_state_ids = await context.get_prev_state_ids()
+ prev_state_ids = await context.get_prev_state_ids(
+ StateFilter.from_types([(EventTypes.ThirdPartyInvite, None)])
+ )
original_invite_id = prev_state_ids.get(key)
if original_invite_id:
original_invite = await self.store.get_event(
@@ -1308,7 +1338,9 @@ class FederationHandler:
signed = event.content["third_party_invite"]["signed"]
token = signed["token"]
- prev_state_ids = await context.get_prev_state_ids()
+ prev_state_ids = await context.get_prev_state_ids(
+ StateFilter.from_types([(EventTypes.ThirdPartyInvite, None)])
+ )
invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
invite_event = None
@@ -1441,17 +1473,35 @@ class FederationHandler:
# well.
return None
+ async def _resume_sync_partial_state_room(self) -> None:
+ """Resumes resyncing of all partial-state rooms after a restart."""
+ assert not self.config.worker.worker_app
+
+ partial_state_rooms = await self.store.get_partial_state_rooms_and_servers()
+ for room_id, servers_in_room in partial_state_rooms.items():
+ run_as_background_process(
+ desc="sync_partial_state_room",
+ func=self._sync_partial_state_room,
+ initial_destination=None,
+ other_destinations=servers_in_room,
+ room_id=room_id,
+ )
+
async def _sync_partial_state_room(
self,
- destination: str,
+ initial_destination: Optional[str],
+ other_destinations: Collection[str],
room_id: str,
) -> None:
"""Background process to resync the state of a partial-state room
Args:
- destination: homeserver to pull the state from
+ initial_destination: the initial homeserver to pull the state from
+ other_destinations: other homeservers to try to pull the state from, if
+ `initial_destination` is unavailable
room_id: room to be resynced
"""
+ assert not self.config.worker.worker_app
# TODO(faster_joins): do we need to lock to avoid races? What happens if other
# worker processes kick off a resync in parallel? Perhaps we should just elect
@@ -1461,8 +1511,29 @@ class FederationHandler:
# really leave, that might mean we have difficulty getting the room state over
# federation.
#
- # TODO(faster_joins): try other destinations if the one we have fails
+ # TODO(faster_joins): we need some way of prioritising which homeservers in
+ # `other_destinations` to try first, otherwise we'll spend ages trying dead
+ # homeservers for large rooms.
+
+ if initial_destination is None and len(other_destinations) == 0:
+ raise ValueError(
+ f"Cannot resync state of {room_id}: no destinations provided"
+ )
+
+ # Make an infinite iterator of destinations to try. Once we find a working
+ # destination, we'll stick with it until it flakes.
+ if initial_destination is not None:
+ # Move `initial_destination` to the front of the list.
+ destinations = list(other_destinations)
+ if initial_destination in destinations:
+ destinations.remove(initial_destination)
+ destinations = [initial_destination] + destinations
+ destination_iter = itertools.cycle(destinations)
+ else:
+ destination_iter = itertools.cycle(other_destinations)
+ # `destination` is the current remote homeserver we're pulling from.
+ destination = next(destination_iter)
logger.info("Syncing state for room %s via %s", room_id, destination)
# we work through the queue in order of increasing stream ordering.
@@ -1473,14 +1544,19 @@ class FederationHandler:
# clear the lazy-loading flag.
logger.info("Updating current state for %s", room_id)
assert (
- self.storage.persistence is not None
+ self._storage_controllers.persistence is not None
), "TODO(faster_joins): support for workers"
- await self.storage.persistence.update_current_state(room_id)
+ await self._storage_controllers.persistence.update_current_state(
+ room_id
+ )
logger.info("Clearing partial-state flag for %s", room_id)
success = await self.store.clear_partial_state_room(room_id)
if success:
logger.info("State resync complete for %s", room_id)
+ self._storage_controllers.state.notify_room_un_partial_stated(
+ room_id
+ )
# TODO(faster_joins) update room stats and user directory?
return
@@ -1498,6 +1574,41 @@ class FederationHandler:
allow_rejected=True,
)
for event in events:
- await self._federation_event_handler.update_state_for_partial_state_event(
- destination, event
- )
+ for attempt in itertools.count():
+ try:
+ await self._federation_event_handler.update_state_for_partial_state_event(
+ destination, event
+ )
+ break
+ except FederationError as e:
+ if attempt == len(destinations) - 1:
+ # We have tried every remote server for this event. Give up.
+ # TODO(faster_joins) giving up isn't the right thing to do
+ # if there's a temporary network outage. retrying
+ # indefinitely is also not the right thing to do if we can
+ # reach all homeservers and they all claim they don't have
+ # the state we want.
+ logger.error(
+ "Failed to get state for %s at %s from %s because %s, "
+ "giving up!",
+ room_id,
+ event,
+ destination,
+ e,
+ )
+ raise
+
+ # Try the next remote server.
+ logger.info(
+ "Failed to get state for %s at %s from %s because %s",
+ room_id,
+ event,
+ destination,
+ e,
+ )
+ destination = next(destination_iter)
+ logger.info(
+ "Syncing state for room %s via %s instead",
+ room_id,
+ destination,
+ )
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 6cf927e4..87a06083 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -30,6 +30,7 @@ from typing import (
from prometheus_client import Counter
+from synapse import event_auth
from synapse.api.constants import (
EventContentFields,
EventTypes,
@@ -63,6 +64,7 @@ from synapse.replication.http.federation import (
)
from synapse.state import StateResolutionStore
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.storage.state import StateFilter
from synapse.types import (
PersistedEventPosition,
RoomStreamToken,
@@ -96,14 +98,14 @@ class FederationEventHandler:
def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastores().main
- self._storage = hs.get_storage()
- self._state_store = self._storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
self._state_handler = hs.get_state_handler()
self._event_creation_handler = hs.get_event_creation_handler()
self._event_auth_handler = hs.get_event_auth_handler()
self._message_handler = hs.get_message_handler()
- self._action_generator = hs.get_action_generator()
+ self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
self._state_resolution_handler = hs.get_state_resolution_handler()
# avoid a circular dependency by deferring execution here
self._get_room_member_handler = hs.get_room_member_handler
@@ -272,7 +274,7 @@ class FederationEventHandler:
affected=pdu.event_id,
)
- await self._process_received_pdu(origin, pdu, state=None)
+ await self._process_received_pdu(origin, pdu, state_ids=None)
async def on_send_membership_event(
self, origin: str, event: EventBase
@@ -461,7 +463,9 @@ class FederationEventHandler:
with nested_logging_context(suffix=event.event_id):
context = await self._state_handler.compute_event_context(
event,
- old_state=state,
+ state_ids_before_event={
+ (e.type, e.state_key): e.event_id for e in state
+ },
partial_state=partial_state,
)
@@ -475,7 +479,23 @@ class FederationEventHandler:
# and discover that we do not have it.
event.internal_metadata.proactively_send = False
- return await self.persist_events_and_notify(room_id, [(event, context)])
+ stream_id_after_persist = await self.persist_events_and_notify(
+ room_id, [(event, context)]
+ )
+
+ # If we're joining the room again, check if there is new marker
+ # state indicating that there is new history imported somewhere in
+ # the DAG. Multiple markers can exist in the current state with
+ # unique state_keys.
+ #
+ # Do this after the state from the remote join was persisted (via
+ # `persist_events_and_notify`). Otherwise we can run into a
+ # situation where the create event doesn't exist yet in the
+ # `current_state_events`
+ for e in state:
+ await self._handle_marker_event(origin, e)
+
+ return stream_id_after_persist
async def update_state_for_partial_state_event(
self, destination: str, event: EventBase
@@ -485,6 +505,9 @@ class FederationEventHandler:
Args:
destination: server to request full state from
event: partial-state event to be de-partial-stated
+
+ Raises:
+ FederationError if we fail to request state from the remote server.
"""
logger.info("Updating state for %s", event.event_id)
with nested_logging_context(suffix=event.event_id):
@@ -494,12 +517,12 @@ class FederationEventHandler:
#
# This is the same operation as we do when we receive a regular event
# over federation.
- state = await self._resolve_state_at_missing_prevs(destination, event)
+ state_ids = await self._resolve_state_at_missing_prevs(destination, event)
# build a new state group for it if need be
context = await self._state_handler.compute_event_context(
event,
- old_state=state,
+ state_ids_before_event=state_ids,
)
if context.partial_state:
# this can happen if some or all of the event's prev_events still have
@@ -515,7 +538,9 @@ class FederationEventHandler:
)
return
await self._store.update_state_for_partial_state_event(event, context)
- self._state_store.notify_event_un_partial_stated(event.event_id)
+ self._state_storage_controller.notify_event_un_partial_stated(
+ event.event_id
+ )
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Collection[str]
@@ -749,11 +774,12 @@ class FederationEventHandler:
return
try:
- state = await self._resolve_state_at_missing_prevs(origin, event)
+ state_ids = await self._resolve_state_at_missing_prevs(origin, event)
# TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
# not return partial state
+
await self._process_received_pdu(
- origin, event, state=state, backfilled=backfilled
+ origin, event, state_ids=state_ids, backfilled=backfilled
)
except FederationError as e:
if e.code == 403:
@@ -763,7 +789,7 @@ class FederationEventHandler:
async def _resolve_state_at_missing_prevs(
self, dest: str, event: EventBase
- ) -> Optional[Iterable[EventBase]]:
+ ) -> Optional[StateMap[str]]:
"""Calculate the state at an event with missing prev_events.
This is used when we have pulled a batch of events from a remote server, and
@@ -790,8 +816,12 @@ class FederationEventHandler:
event: an event to check for missing prevs.
Returns:
- if we already had all the prev events, `None`. Otherwise, returns a list of
- the events in the state at `event`.
+ if we already had all the prev events, `None`. Otherwise, returns
+ the event ids of the state at `event`.
+
+ Raises:
+ FederationError if we fail to get the state from the remote server after any
+ missing `prev_event`s.
"""
room_id = event.room_id
event_id = event.event_id
@@ -811,10 +841,12 @@ class FederationEventHandler:
)
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
- event_map = {event_id: event}
+
try:
# Get the state of the events we know about
- ours = await self._state_store.get_state_groups_ids(room_id, seen)
+ ours = await self._state_storage_controller.get_state_groups_ids(
+ room_id, seen
+ )
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps: List[StateMap[str]] = list(ours.values())
@@ -831,40 +863,23 @@ class FederationEventHandler:
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
- remote_state = await self._get_state_after_missing_prev_event(
- dest, room_id, p
+ remote_state_map = (
+ await self._get_state_ids_after_missing_prev_event(
+ dest, room_id, p
+ )
)
- remote_state_map = {
- (x.type, x.state_key): x.event_id for x in remote_state
- }
state_maps.append(remote_state_map)
- for x in remote_state:
- event_map[x.event_id] = x
-
room_version = await self._store.get_room_version_id(room_id)
state_map = await self._state_resolution_handler.resolve_events_with_store(
room_id,
room_version,
state_maps,
- event_map,
+ event_map={event_id: event},
state_res_store=StateResolutionStore(self._store),
)
- # We need to give _process_received_pdu the actual state events
- # rather than event ids, so generate that now.
-
- # First though we need to fetch all the events that are in
- # state_map, so we can build up the state below.
- evs = await self._store.get_events(
- list(state_map.values()),
- get_prev_content=False,
- redact_behaviour=EventRedactBehaviour.as_is,
- )
- event_map.update(evs)
-
- state = [event_map[e] for e in state_map.values()]
except Exception:
logger.warning(
"Error attempting to resolve state at missing prev_events",
@@ -876,14 +891,14 @@ class FederationEventHandler:
"We can't get valid state history.",
affected=event_id,
)
- return state
+ return state_map
- async def _get_state_after_missing_prev_event(
+ async def _get_state_ids_after_missing_prev_event(
self,
destination: str,
room_id: str,
event_id: str,
- ) -> List[EventBase]:
+ ) -> StateMap[str]:
"""Requests all of the room state at a given event from a remote homeserver.
Args:
@@ -892,7 +907,11 @@ class FederationEventHandler:
event_id: The id of the event we want the state at.
Returns:
- A list of events in the state, including the event itself
+ The event ids of the state *after* the given event.
+
+ Raises:
+ InvalidResponseError: if the remote homeserver's response contains fields
+ of the wrong type.
"""
(
state_event_ids,
@@ -907,19 +926,17 @@ class FederationEventHandler:
len(auth_event_ids),
)
- # start by just trying to fetch the events from the store
+ # Start by checking events we already have in the DB
desired_events = set(state_event_ids)
desired_events.add(event_id)
logger.debug("Fetching %i events from cache/store", len(desired_events))
- fetched_events = await self._store.get_events(
- desired_events, allow_rejected=True
- )
+ have_events = await self._store.have_seen_events(room_id, desired_events)
- missing_desired_events = desired_events - fetched_events.keys()
+ missing_desired_events = desired_events - have_events
logger.debug(
"We are missing %i events (got %i)",
len(missing_desired_events),
- len(fetched_events),
+ len(have_events),
)
# We probably won't need most of the auth events, so let's just check which
@@ -930,7 +947,7 @@ class FederationEventHandler:
# already have a bunch of the state events. It would be nice if the
# federation api gave us a way of finding out which we actually need.
- missing_auth_events = set(auth_event_ids) - fetched_events.keys()
+ missing_auth_events = set(auth_event_ids) - have_events
missing_auth_events.difference_update(
await self._store.have_seen_events(room_id, missing_auth_events)
)
@@ -956,47 +973,51 @@ class FederationEventHandler:
destination=destination, room_id=room_id, event_ids=missing_events
)
- # we need to make sure we re-load from the database to get the rejected
- # state correct.
- fetched_events.update(
- await self._store.get_events(missing_desired_events, allow_rejected=True)
- )
-
- # check for events which were in the wrong room.
- #
- # this can happen if a remote server claims that the state or
- # auth_events at an event in room A are actually events in room B
+ # We now need to fill out the state map, which involves fetching the
+ # type and state key for each event ID in the state.
+ state_map = {}
- bad_events = [
- (event_id, event.room_id)
- for event_id, event in fetched_events.items()
- if event.room_id != room_id
- ]
+ event_metadata = await self._store.get_metadata_for_events(state_event_ids)
+ for state_event_id, metadata in event_metadata.items():
+ if metadata.room_id != room_id:
+ # This is a bogus situation, but since we may only discover it a long time
+ # after it happened, we try our best to carry on, by just omitting the
+ # bad events from the returned state set.
+ #
+ # This can happen if a remote server claims that the state or
+ # auth_events at an event in room A are actually events in room B
+ logger.warning(
+ "Remote server %s claims event %s in room %s is an auth/state "
+ "event in room %s",
+ destination,
+ state_event_id,
+ metadata.room_id,
+ room_id,
+ )
+ continue
- for bad_event_id, bad_room_id in bad_events:
- # This is a bogus situation, but since we may only discover it a long time
- # after it happened, we try our best to carry on, by just omitting the
- # bad events from the returned state set.
- logger.warning(
- "Remote server %s claims event %s in room %s is an auth/state "
- "event in room %s",
- destination,
- bad_event_id,
- bad_room_id,
- room_id,
- )
+ if metadata.state_key is None:
+ logger.warning(
+ "Remote server gave us non-state event in state: %s", state_event_id
+ )
+ continue
- del fetched_events[bad_event_id]
+ state_map[(metadata.event_type, metadata.state_key)] = state_event_id
# if we couldn't get the prev event in question, that's a problem.
- remote_event = fetched_events.get(event_id)
+ remote_event = await self._store.get_event(
+ event_id,
+ allow_none=True,
+ allow_rejected=True,
+ redact_behaviour=EventRedactBehaviour.as_is,
+ )
if not remote_event:
raise Exception("Unable to get missing prev_event %s" % (event_id,))
# missing state at that event is a warning, not a blocker
# XXX: this doesn't sound right? it means that we'll end up with incomplete
# state.
- failed_to_fetch = desired_events - fetched_events.keys()
+ failed_to_fetch = desired_events - event_metadata.keys()
if failed_to_fetch:
logger.warning(
"Failed to fetch missing state events for %s %s",
@@ -1004,14 +1025,12 @@ class FederationEventHandler:
failed_to_fetch,
)
- remote_state = [
- fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
- ]
-
if remote_event.is_state() and remote_event.rejected_reason is None:
- remote_state.append(remote_event)
+ state_map[
+ (remote_event.type, remote_event.state_key)
+ ] = remote_event.event_id
- return remote_state
+ return state_map
async def _get_state_and_persist(
self, destination: str, room_id: str, event_id: str
@@ -1038,7 +1057,7 @@ class FederationEventHandler:
self,
origin: str,
event: EventBase,
- state: Optional[Iterable[EventBase]],
+ state_ids: Optional[StateMap[str]],
backfilled: bool = False,
) -> None:
"""Called when we have a new non-outlier event.
@@ -1060,7 +1079,7 @@ class FederationEventHandler:
event: event to be persisted
- state: Normally None, but if we are handling a gap in the graph
+ state_ids: Normally None, but if we are handling a gap in the graph
(ie, we are missing one or more prev_events), the resolved state at the
event
@@ -1072,7 +1091,8 @@ class FederationEventHandler:
try:
context = await self._state_handler.compute_event_context(
- event, old_state=state
+ event,
+ state_ids_before_event=state_ids,
)
context = await self._check_event_auth(
origin,
@@ -1089,7 +1109,7 @@ class FederationEventHandler:
# For new (non-backfilled and non-outlier) events we check if the event
# passes auth based on the current state. If it doesn't then we
# "soft-fail" the event.
- await self._check_for_soft_fail(event, state, origin=origin)
+ await self._check_for_soft_fail(event, state_ids, origin=origin)
await self._run_push_actions_and_persist_event(event, context, backfilled)
@@ -1228,6 +1248,14 @@ class FederationEventHandler:
# Nothing to retrieve then (invalid marker)
return
+ already_seen_insertion_event = await self._store.have_seen_event(
+ marker_event.room_id, insertion_event_id
+ )
+ if already_seen_insertion_event:
+ # No need to process a marker again if we have already seen the
+ # insertion event that it was pointing to
+ return
+
logger.debug(
"_handle_marker_event: backfilling insertion event %s", insertion_event_id
)
@@ -1423,7 +1451,7 @@ class FederationEventHandler:
# we're not bothering about room state, so flag the event as an outlier.
event.internal_metadata.outlier = True
- context = EventContext.for_outlier()
+ context = EventContext.for_outlier(self._storage_controllers)
try:
validate_event_for_room_version(room_version_obj, event)
check_auth_rules_for_event(room_version_obj, event, auth)
@@ -1500,7 +1528,11 @@ class FederationEventHandler:
return context
# now check auth against what we think the auth events *should* be.
- prev_state_ids = await context.get_prev_state_ids()
+ event_types = event_auth.auth_types_for_event(event.room_version, event)
+ prev_state_ids = await context.get_prev_state_ids(
+ StateFilter.from_types(event_types)
+ )
+
auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=True
)
@@ -1552,14 +1584,16 @@ class FederationEventHandler:
if guest_access == GuestAccess.CAN_JOIN:
return
- current_state_map = await self._state_handler.get_current_state(event.room_id)
- current_state = list(current_state_map.values())
- await self._get_room_member_handler().kick_guest_users(current_state)
+ current_state = await self._storage_controllers.state.get_current_state(
+ event.room_id
+ )
+ current_state_list = list(current_state.values())
+ await self._get_room_member_handler().kick_guest_users(current_state_list)
async def _check_for_soft_fail(
self,
event: EventBase,
- state: Optional[Iterable[EventBase]],
+ state_ids: Optional[StateMap[str]],
origin: str,
) -> None:
"""Checks if we should soft fail the event; if so, marks the event as
@@ -1567,7 +1601,7 @@ class FederationEventHandler:
Args:
event
- state: The state at the event if we don't have all the event's prev events
+ state_ids: The state at the event if we don't have all the event's prev events
origin: The host the event originates from.
"""
extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
@@ -1582,8 +1616,11 @@ class FederationEventHandler:
room_version = await self._store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+ # The event types we want to pull from the "current" state.
+ auth_types = auth_types_for_event(room_version_obj, event)
+
# Calculate the "current state".
- if state is not None:
+ if state_ids is not None:
# If we're explicitly given the state then we won't have all the
# prev events, and so we have a gap in the graph. In this case
# we want to be a little careful as we might have been down for
@@ -1596,20 +1633,25 @@ class FederationEventHandler:
# given state at the event. This should correctly handle cases
# like bans, especially with state res v2.
- state_sets_d = await self._state_store.get_state_groups(
+ state_sets_d = await self._state_storage_controller.get_state_groups_ids(
event.room_id, extrem_ids
)
- state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
- state_sets.append(state)
- current_states = await self._state_handler.resolve_events(
- room_version, state_sets, event
+ state_sets: List[StateMap[str]] = list(state_sets_d.values())
+ state_sets.append(state_ids)
+ current_state_ids = (
+ await self._state_resolution_handler.resolve_events_with_store(
+ event.room_id,
+ room_version,
+ state_sets,
+ event_map=None,
+ state_res_store=StateResolutionStore(self._store),
+ )
)
- current_state_ids: StateMap[str] = {
- k: e.event_id for k, e in current_states.items()
- }
else:
- current_state_ids = await self._state_handler.get_current_state_ids(
- event.room_id, latest_event_ids=extrem_ids
+ current_state_ids = (
+ await self._state_storage_controller.get_current_state_ids(
+ event.room_id, StateFilter.from_types(auth_types)
+ )
)
logger.debug(
@@ -1619,7 +1661,6 @@ class FederationEventHandler:
)
# Now check if event pass auth against said current state
- auth_types = auth_types_for_event(room_version_obj, event)
current_state_ids_list = [
e for k, e in current_state_ids.items() if k in auth_types
]
@@ -1865,7 +1906,7 @@ class FederationEventHandler:
# create a new state group as a delta from the existing one.
prev_group = context.state_group
- state_group = await self._state_store.store_state_group(
+ state_group = await self._state_storage_controller.store_state_group(
event.event_id,
event.room_id,
prev_group=prev_group,
@@ -1874,10 +1915,10 @@ class FederationEventHandler:
)
return EventContext.with_state(
+ storage=self._storage_controllers,
state_group=state_group,
state_group_before_event=context.state_group_before_event,
- current_state_ids=current_state_ids,
- prev_state_ids=prev_state_ids,
+ state_delta_due_to_event=state_updates,
prev_group=prev_group,
delta_ids=state_updates,
partial_state=context.partial_state,
@@ -1913,7 +1954,7 @@ class FederationEventHandler:
min_depth,
)
else:
- await self._action_generator.handle_push_actions_for_event(
+ await self._bulk_push_rule_evaluator.action_for_event_by_user(
event, context
)
@@ -1964,11 +2005,14 @@ class FederationEventHandler:
)
return result["max_stream_id"]
else:
- assert self._storage.persistence
+ assert self._storage_controllers.persistence
# Note that this returns the events that were persisted, which may not be
# the same as were passed in if some were deduplicated due to transaction IDs.
- events, max_stream_token = await self._storage.persistence.persist_events(
+ (
+ events,
+ max_stream_token,
+ ) = await self._storage_controllers.persistence.persist_events(
event_and_contexts, backfilled=backfilled
)
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
deleted file mode 100644
index e7a39978..00000000
--- a/synapse/handlers/groups_local.py
+++ /dev/null
@@ -1,503 +0,0 @@
-# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Iterable, List, Set
-
-from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
-from synapse.types import GroupID, JsonDict, get_domain_from_id
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-def _create_rerouter(func_name: str) -> Callable[..., Awaitable[JsonDict]]:
- """Returns an async function that looks at the group id and calls the function
- on federation or the local group server if the group is local
- """
-
- async def f(
- self: "GroupsLocalWorkerHandler", group_id: str, *args: Any, **kwargs: Any
- ) -> JsonDict:
- if not GroupID.is_valid(group_id):
- raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
-
- if self.is_mine_id(group_id):
- return await getattr(self.groups_server_handler, func_name)(
- group_id, *args, **kwargs
- )
- else:
- destination = get_domain_from_id(group_id)
-
- try:
- return await getattr(self.transport_client, func_name)(
- destination, group_id, *args, **kwargs
- )
- except HttpResponseException as e:
- # Capture errors returned by the remote homeserver and
- # re-throw specific errors as SynapseErrors. This is so
- # when the remote end responds with things like 403 Not
- # In Group, we can communicate that to the client instead
- # of a 500.
- raise e.to_synapse_error()
- except RequestSendFailed:
- raise SynapseError(502, "Failed to contact group server")
-
- return f
-
-
-class GroupsLocalWorkerHandler:
- def __init__(self, hs: "HomeServer"):
- self.hs = hs
- self.store = hs.get_datastores().main
- self.room_list_handler = hs.get_room_list_handler()
- self.groups_server_handler = hs.get_groups_server_handler()
- self.transport_client = hs.get_federation_transport_client()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.keyring = hs.get_keyring()
- self.is_mine_id = hs.is_mine_id
- self.signing_key = hs.signing_key
- self.server_name = hs.hostname
- self.notifier = hs.get_notifier()
- self.attestations = hs.get_groups_attestation_signing()
-
- self.profile_handler = hs.get_profile_handler()
-
- # The following functions merely route the query to the local groups server
- # or federation depending on whether the group is local or remote
-
- get_group_profile = _create_rerouter("get_group_profile")
- get_rooms_in_group = _create_rerouter("get_rooms_in_group")
- get_invited_users_in_group = _create_rerouter("get_invited_users_in_group")
- get_group_category = _create_rerouter("get_group_category")
- get_group_categories = _create_rerouter("get_group_categories")
- get_group_role = _create_rerouter("get_group_role")
- get_group_roles = _create_rerouter("get_group_roles")
-
- async def get_group_summary(
- self, group_id: str, requester_user_id: str
- ) -> JsonDict:
- """Get the group summary for a group.
-
- If the group is remote we check that the users have valid attestations.
- """
- if self.is_mine_id(group_id):
- res = await self.groups_server_handler.get_group_summary(
- group_id, requester_user_id
- )
- else:
- try:
- res = await self.transport_client.get_group_summary(
- get_domain_from_id(group_id), group_id, requester_user_id
- )
- except HttpResponseException as e:
- raise e.to_synapse_error()
- except RequestSendFailed:
- raise SynapseError(502, "Failed to contact group server")
-
- group_server_name = get_domain_from_id(group_id)
-
- # Loop through the users and validate the attestations.
- chunk = res["users_section"]["users"]
- valid_users = []
- for entry in chunk:
- g_user_id = entry["user_id"]
- attestation = entry.pop("attestation", {})
- try:
- if get_domain_from_id(g_user_id) != group_server_name:
- await self.attestations.verify_attestation(
- attestation,
- group_id=group_id,
- user_id=g_user_id,
- server_name=get_domain_from_id(g_user_id),
- )
- valid_users.append(entry)
- except Exception as e:
- logger.info("Failed to verify user is in group: %s", e)
-
- res["users_section"]["users"] = valid_users
-
- res["users_section"]["users"].sort(key=lambda e: e.get("order", 0))
- res["rooms_section"]["rooms"].sort(key=lambda e: e.get("order", 0))
-
- # Add `is_publicised` flag to indicate whether the user has publicised their
- # membership of the group on their profile
- result = await self.store.get_publicised_groups_for_user(requester_user_id)
- is_publicised = group_id in result
-
- res.setdefault("user", {})["is_publicised"] = is_publicised
-
- return res
-
- async def get_users_in_group(
- self, group_id: str, requester_user_id: str
- ) -> JsonDict:
- """Get users in a group"""
- if self.is_mine_id(group_id):
- return await self.groups_server_handler.get_users_in_group(
- group_id, requester_user_id
- )
-
- group_server_name = get_domain_from_id(group_id)
-
- try:
- res = await self.transport_client.get_users_in_group(
- get_domain_from_id(group_id), group_id, requester_user_id
- )
- except HttpResponseException as e:
- raise e.to_synapse_error()
- except RequestSendFailed:
- raise SynapseError(502, "Failed to contact group server")
-
- chunk = res["chunk"]
- valid_entries = []
- for entry in chunk:
- g_user_id = entry["user_id"]
- attestation = entry.pop("attestation", {})
- try:
- if get_domain_from_id(g_user_id) != group_server_name:
- await self.attestations.verify_attestation(
- attestation,
- group_id=group_id,
- user_id=g_user_id,
- server_name=get_domain_from_id(g_user_id),
- )
- valid_entries.append(entry)
- except Exception as e:
- logger.info("Failed to verify user is in group: %s", e)
-
- res["chunk"] = valid_entries
-
- return res
-
- async def get_joined_groups(self, user_id: str) -> JsonDict:
- group_ids = await self.store.get_joined_groups(user_id)
- return {"groups": group_ids}
-
- async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict:
- if self.hs.is_mine_id(user_id):
- result = await self.store.get_publicised_groups_for_user(user_id)
-
- # Check AS associated groups for this user - this depends on the
- # RegExps in the AS registration file (under `users`)
- for app_service in self.store.get_app_services():
- result.extend(app_service.get_groups_for_user(user_id))
-
- return {"groups": result}
- else:
- try:
- bulk_result = await self.transport_client.bulk_get_publicised_groups(
- get_domain_from_id(user_id), [user_id]
- )
- except HttpResponseException as e:
- raise e.to_synapse_error()
- except RequestSendFailed:
- raise SynapseError(502, "Failed to contact group server")
-
- result = bulk_result.get("users", {}).get(user_id)
- # TODO: Verify attestations
- return {"groups": result}
-
- async def bulk_get_publicised_groups(
- self, user_ids: Iterable[str], proxy: bool = True
- ) -> JsonDict:
- destinations: Dict[str, Set[str]] = {}
- local_users = set()
-
- for user_id in user_ids:
- if self.hs.is_mine_id(user_id):
- local_users.add(user_id)
- else:
- destinations.setdefault(get_domain_from_id(user_id), set()).add(user_id)
-
- if not proxy and destinations:
- raise SynapseError(400, "Some user_ids are not local")
-
- results = {}
- failed_results: List[str] = []
- for destination, dest_user_ids in destinations.items():
- try:
- r = await self.transport_client.bulk_get_publicised_groups(
- destination, list(dest_user_ids)
- )
- results.update(r["users"])
- except Exception:
- failed_results.extend(dest_user_ids)
-
- for uid in local_users:
- results[uid] = await self.store.get_publicised_groups_for_user(uid)
-
- # Check AS associated groups for this user - this depends on the
- # RegExps in the AS registration file (under `users`)
- for app_service in self.store.get_app_services():
- results[uid].extend(app_service.get_groups_for_user(uid))
-
- return {"users": results}
-
-
-class GroupsLocalHandler(GroupsLocalWorkerHandler):
- def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
- # Ensure attestations get renewed
- hs.get_groups_attestation_renewer()
-
- # The following functions merely route the query to the local groups server
- # or federation depending on whether the group is local or remote
-
- update_group_profile = _create_rerouter("update_group_profile")
-
- add_room_to_group = _create_rerouter("add_room_to_group")
- update_room_in_group = _create_rerouter("update_room_in_group")
- remove_room_from_group = _create_rerouter("remove_room_from_group")
-
- update_group_summary_room = _create_rerouter("update_group_summary_room")
- delete_group_summary_room = _create_rerouter("delete_group_summary_room")
-
- update_group_category = _create_rerouter("update_group_category")
- delete_group_category = _create_rerouter("delete_group_category")
-
- update_group_summary_user = _create_rerouter("update_group_summary_user")
- delete_group_summary_user = _create_rerouter("delete_group_summary_user")
-
- update_group_role = _create_rerouter("update_group_role")
- delete_group_role = _create_rerouter("delete_group_role")
-
- set_group_join_policy = _create_rerouter("set_group_join_policy")
-
- async def create_group(
- self, group_id: str, user_id: str, content: JsonDict
- ) -> JsonDict:
- """Create a group"""
-
- logger.info("Asking to create group with ID: %r", group_id)
-
- if self.is_mine_id(group_id):
- res = await self.groups_server_handler.create_group(
- group_id, user_id, content
- )
- local_attestation = None
- remote_attestation = None
- else:
- raise SynapseError(400, "Unable to create remote groups")
-
- is_publicised = content.get("publicise", False)
- token = await self.store.register_user_group_membership(
- group_id,
- user_id,
- membership="join",
- is_admin=True,
- local_attestation=local_attestation,
- remote_attestation=remote_attestation,
- is_publicised=is_publicised,
- )
- self.notifier.on_new_event("groups_key", token, users=[user_id])
-
- return res
-
- async def join_group(
- self, group_id: str, user_id: str, content: JsonDict
- ) -> JsonDict:
- """Request to join a group"""
- if self.is_mine_id(group_id):
- await self.groups_server_handler.join_group(group_id, user_id, content)
- local_attestation = None
- remote_attestation = None
- else:
- local_attestation = self.attestations.create_attestation(group_id, user_id)
- content["attestation"] = local_attestation
-
- try:
- res = await self.transport_client.join_group(
- get_domain_from_id(group_id), group_id, user_id, content
- )
- except HttpResponseException as e:
- raise e.to_synapse_error()
- except RequestSendFailed:
- raise SynapseError(502, "Failed to contact group server")
-
- remote_attestation = res["attestation"]
-
- await self.attestations.verify_attestation(
- remote_attestation,
- group_id=group_id,
- user_id=user_id,
- server_name=get_domain_from_id(group_id),
- )
-
- # TODO: Check that the group is public and we're being added publicly
- is_publicised = content.get("publicise", False)
-
- token = await self.store.register_user_group_membership(
- group_id,
- user_id,
- membership="join",
- is_admin=False,
- local_attestation=local_attestation,
- remote_attestation=remote_attestation,
- is_publicised=is_publicised,
- )
- self.notifier.on_new_event("groups_key", token, users=[user_id])
-
- return {}
-
- async def accept_invite(
- self, group_id: str, user_id: str, content: JsonDict
- ) -> JsonDict:
- """Accept an invite to a group"""
- if self.is_mine_id(group_id):
- await self.groups_server_handler.accept_invite(group_id, user_id, content)
- local_attestation = None
- remote_attestation = None
- else:
- local_attestation = self.attestations.create_attestation(group_id, user_id)
- content["attestation"] = local_attestation
-
- try:
- res = await self.transport_client.accept_group_invite(
- get_domain_from_id(group_id), group_id, user_id, content
- )
- except HttpResponseException as e:
- raise e.to_synapse_error()
- except RequestSendFailed:
- raise SynapseError(502, "Failed to contact group server")
-
- remote_attestation = res["attestation"]
-
- await self.attestations.verify_attestation(
- remote_attestation,
- group_id=group_id,
- user_id=user_id,
- server_name=get_domain_from_id(group_id),
- )
-
- # TODO: Check that the group is public and we're being added publicly
- is_publicised = content.get("publicise", False)
-
- token = await self.store.register_user_group_membership(
- group_id,
- user_id,
- membership="join",
- is_admin=False,
- local_attestation=local_attestation,
- remote_attestation=remote_attestation,
- is_publicised=is_publicised,
- )
- self.notifier.on_new_event("groups_key", token, users=[user_id])
-
- return {}
-
- async def invite(
- self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
- ) -> JsonDict:
- """Invite a user to a group"""
- content = {"requester_user_id": requester_user_id, "config": config}
- if self.is_mine_id(group_id):
- res = await self.groups_server_handler.invite_to_group(
- group_id, user_id, requester_user_id, content
- )
- else:
- try:
- res = await self.transport_client.invite_to_group(
- get_domain_from_id(group_id),
- group_id,
- user_id,
- requester_user_id,
- content,
- )
- except HttpResponseException as e:
- raise e.to_synapse_error()
- except RequestSendFailed:
- raise SynapseError(502, "Failed to contact group server")
-
- return res
-
- async def on_invite(
- self, group_id: str, user_id: str, content: JsonDict
- ) -> JsonDict:
- """One of our users were invited to a group"""
- # TODO: Support auto join and rejection
-
- if not self.is_mine_id(user_id):
- raise SynapseError(400, "User not on this server")
-
- local_profile = {}
- if "profile" in content:
- if "name" in content["profile"]:
- local_profile["name"] = content["profile"]["name"]
- if "avatar_url" in content["profile"]:
- local_profile["avatar_url"] = content["profile"]["avatar_url"]
-
- token = await self.store.register_user_group_membership(
- group_id,
- user_id,
- membership="invite",
- content={"profile": local_profile, "inviter": content["inviter"]},
- )
- self.notifier.on_new_event("groups_key", token, users=[user_id])
- try:
- user_profile = await self.profile_handler.get_profile(user_id)
- except Exception as e:
- logger.warning("No profile for user %s: %s", user_id, e)
- user_profile = {}
-
- return {"state": "invite", "user_profile": user_profile}
-
- async def remove_user_from_group(
- self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
- ) -> JsonDict:
- """Remove a user from a group"""
- if user_id == requester_user_id:
- token = await self.store.register_user_group_membership(
- group_id, user_id, membership="leave"
- )
- self.notifier.on_new_event("groups_key", token, users=[user_id])
-
- # TODO: Should probably remember that we tried to leave so that we can
- # retry if the group server is currently down.
-
- if self.is_mine_id(group_id):
- res = await self.groups_server_handler.remove_user_from_group(
- group_id, user_id, requester_user_id, content
- )
- else:
- content["requester_user_id"] = requester_user_id
- try:
- res = await self.transport_client.remove_user_from_group(
- get_domain_from_id(group_id),
- group_id,
- requester_user_id,
- user_id,
- content,
- )
- except HttpResponseException as e:
- raise e.to_synapse_error()
- except RequestSendFailed:
- raise SynapseError(502, "Failed to contact group server")
-
- return res
-
- async def user_removed_from_group(
- self, group_id: str, user_id: str, content: JsonDict
- ) -> None:
- """One of our users was removed/kicked from a group"""
- # TODO: Check if user in group
- token = await self.store.register_user_group_membership(
- group_id, user_id, membership="leave"
- )
- self.notifier.on_new_event("groups_key", token, users=[user_id])
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 7b94770f..85b472f2 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -30,6 +30,7 @@ from synapse.types import (
Requester,
RoomStreamToken,
StateMap,
+ StreamKeyType,
StreamToken,
UserID,
)
@@ -66,8 +67,8 @@ class InitialSyncHandler:
]
] = ResponseCache(hs.get_clock(), "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer()
- self.storage = hs.get_storage()
- self.state_store = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
async def snapshot_all_rooms(
self,
@@ -143,7 +144,7 @@ class InitialSyncHandler:
to_key=int(now_token.receipt_key),
)
if self.hs.config.experimental.msc2285_enabled:
- receipt = ReceiptEventSource.filter_out_private(receipt, user_id)
+ receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id)
tags_by_room = await self.store.get_tags_for_user(user_id)
@@ -189,7 +190,7 @@ class InitialSyncHandler:
if event.membership == Membership.JOIN:
room_end_token = now_token.room_key
deferred_room_state = run_in_background(
- self.state_handler.get_current_state, event.room_id
+ self._state_storage_controller.get_current_state, event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = RoomStreamToken(
@@ -197,7 +198,8 @@ class InitialSyncHandler:
event.stream_ordering,
)
deferred_room_state = run_in_background(
- self.state_store.get_state_for_events, [event.event_id]
+ self._state_storage_controller.get_state_for_events,
+ [event.event_id],
).addCallback(
lambda states: cast(StateMap[EventBase], states[event.event_id])
)
@@ -217,11 +219,13 @@ class InitialSyncHandler:
).addErrback(unwrapFirstError)
messages = await filter_events_for_client(
- self.storage, user_id, messages
+ self._storage_controllers, user_id, messages
)
- start_token = now_token.copy_and_replace("room_key", token)
- end_token = now_token.copy_and_replace("room_key", room_end_token)
+ start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
+ end_token = now_token.copy_and_replace(
+ StreamKeyType.ROOM, room_end_token
+ )
time_now = self.clock.time_msec()
d["messages"] = {
@@ -271,7 +275,7 @@ class InitialSyncHandler:
"rooms": rooms_ret,
"presence": [
{
- "type": "m.presence",
+ "type": EduTypes.PRESENCE,
"content": format_user_presence_state(event, now),
}
for event in presence
@@ -352,7 +356,9 @@ class InitialSyncHandler:
member_event_id: str,
is_peeking: bool,
) -> JsonDict:
- room_state = await self.state_store.get_state_for_event(member_event_id)
+ room_state = await self._state_storage_controller.get_state_for_event(
+ member_event_id
+ )
limit = pagin_config.limit if pagin_config else None
if limit is None:
@@ -366,11 +372,11 @@ class InitialSyncHandler:
)
messages = await filter_events_for_client(
- self.storage, user_id, messages, is_peeking=is_peeking
+ self._storage_controllers, user_id, messages, is_peeking=is_peeking
)
- start_token = StreamToken.START.copy_and_replace("room_key", token)
- end_token = StreamToken.START.copy_and_replace("room_key", stream_token)
+ start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token)
+ end_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, stream_token)
time_now = self.clock.time_msec()
@@ -401,7 +407,9 @@ class InitialSyncHandler:
membership: str,
is_peeking: bool,
) -> JsonDict:
- current_state = await self.state.get_current_state(room_id=room_id)
+ current_state = await self._storage_controllers.state.get_current_state(
+ room_id=room_id
+ )
# TODO: These concurrently
time_now = self.clock.time_msec()
@@ -436,7 +444,7 @@ class InitialSyncHandler:
return [
{
- "type": EduTypes.Presence,
+ "type": EduTypes.PRESENCE,
"content": format_user_presence_state(s, time_now),
}
for s in states
@@ -449,7 +457,9 @@ class InitialSyncHandler:
if not receipts:
return []
if self.hs.config.experimental.msc2285_enabled:
- receipts = ReceiptEventSource.filter_out_private(receipts, user_id)
+ receipts = ReceiptEventSource.filter_out_private_receipts(
+ receipts, user_id
+ )
return receipts
presence, receipts, (messages, token) = await make_deferred_yieldable(
@@ -469,10 +479,10 @@ class InitialSyncHandler:
)
messages = await filter_events_for_client(
- self.storage, user_id, messages, is_peeking=is_peeking
+ self._storage_controllers, user_id, messages, is_peeking=is_peeking
)
- start_token = now_token.copy_and_replace("room_key", token)
+ start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
end_token = now_token
time_now = self.clock.time_msec()
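
Aside (a sketch, not part of this diff): the string keys previously passed to `copy_and_replace` are replaced throughout by `StreamKeyType` members. A hedged sketch of the shape of that type; the real definition in `synapse/types.py` may differ in detail:

    class StreamKeyType:
        ROOM = "room_key"
        PRESENCE = "presence_key"
        RECEIPT = "receipt_key"
        ACCOUNT_DATA = "account_data_key"

    # Misspelling a member is now an AttributeError at the call site, whereas
    # a misspelled string key would previously fail silently.
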
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index c28b792e..f455158a 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -28,6 +28,7 @@ from synapse.api.constants import (
EventContentFields,
EventTypes,
GuestAccess,
+ HistoryVisibility,
Membership,
RelationTypes,
UserTypes,
@@ -44,7 +45,7 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.api.urls import ConsentURIBuilder
from synapse.event_auth import validate_event_for_room_version
-from synapse.events import EventBase
+from synapse.events import EventBase, relation_from_event
from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
@@ -54,12 +55,19 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
-from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
+from synapse.types import (
+ MutableStateMap,
+ Requester,
+ RoomAlias,
+ StreamToken,
+ UserID,
+ create_requester,
+)
from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
from synapse.util.async_helpers import Linearizer, gather_results
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import measure_func
-from synapse.visibility import filter_events_for_client
+from synapse.visibility import get_effective_room_visibility_from_state
if TYPE_CHECKING:
from synapse.events.third_party_rules import ThirdPartyEventRules
@@ -76,8 +84,8 @@ class MessageHandler:
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
- self.state_store = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
self._event_serializer = hs.get_event_client_serializer()
self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages
@@ -117,14 +125,16 @@ class MessageHandler:
)
if membership == Membership.JOIN:
- data = await self.state.get_current_state(room_id, event_type, state_key)
+ data = await self._storage_controllers.state.get_current_state_event(
+ room_id, event_type, state_key
+ )
elif membership == Membership.LEAVE:
key = (event_type, state_key)
# If the membership is not JOIN, then the event ID should exist.
assert (
membership_event_id is not None
), "check_user_in_room_or_world_readable returned invalid data"
- room_state = await self.state_store.get_state_for_events(
+ room_state = await self._state_storage_controller.get_state_for_events(
[membership_event_id], StateFilter.from_types([key])
)
data = room_state[membership_event_id].get(key)
@@ -175,49 +185,31 @@ class MessageHandler:
state_filter = state_filter or StateFilter.all()
if at_token:
- last_event = await self.store.get_last_event_in_room_before_stream_ordering(
- room_id,
- end_token=at_token.room_key,
+ last_event_id = (
+ await self.store.get_last_event_in_room_before_stream_ordering(
+ room_id,
+ end_token=at_token.room_key,
+ )
)
- if not last_event:
+ if not last_event_id:
raise NotFoundError("Can't find event for token %s" % (at_token,))
- # check whether the user is in the room at that time to determine
- # whether they should be treated as peeking.
- state_map = await self.state_store.get_state_for_event(
- last_event.event_id,
- StateFilter.from_types([(EventTypes.Member, user_id)]),
- )
-
- joined = False
- membership_event = state_map.get((EventTypes.Member, user_id))
- if membership_event:
- joined = membership_event.membership == Membership.JOIN
-
- is_peeking = not joined
-
- visible_events = await filter_events_for_client(
- self.storage,
- user_id,
- [last_event],
- filter_send_to_client=False,
- is_peeking=is_peeking,
- )
-
- if visible_events:
- room_state_events = await self.state_store.get_state_for_events(
- [last_event.event_id], state_filter=state_filter
- )
- room_state: Mapping[Any, EventBase] = room_state_events[
- last_event.event_id
- ]
- else:
+ if not await self._user_can_see_state_at_event(
+ user_id, room_id, last_event_id
+ ):
raise AuthError(
403,
"User %s not allowed to view events in room %s at token %s"
% (user_id, room_id, at_token),
)
+
+ room_state_events = (
+ await self._state_storage_controller.get_state_for_events(
+ [last_event_id], state_filter=state_filter
+ )
+ )
+ room_state: Mapping[Any, EventBase] = room_state_events[last_event_id]
else:
(
membership,
@@ -227,7 +219,7 @@ class MessageHandler:
)
if membership == Membership.JOIN:
- state_ids = await self.store.get_filtered_current_state_ids(
+ state_ids = await self._state_storage_controller.get_current_state_ids(
room_id, state_filter=state_filter
)
room_state = await self.store.get_events(state_ids.values())
@@ -236,8 +228,10 @@ class MessageHandler:
assert (
membership_event_id is not None
), "check_user_in_room_or_world_readable returned invalid data"
- room_state_events = await self.state_store.get_state_for_events(
- [membership_event_id], state_filter=state_filter
+ room_state_events = (
+ await self._state_storage_controller.get_state_for_events(
+ [membership_event_id], state_filter=state_filter
+ )
)
room_state = room_state_events[membership_event_id]
@@ -245,6 +239,65 @@ class MessageHandler:
events = self._event_serializer.serialize_events(room_state.values(), now)
return events
+ async def _user_can_see_state_at_event(
+ self, user_id: str, room_id: str, event_id: str
+ ) -> bool:
+ # check whether the user was in the room, and the history visibility,
+ # at that time.
+ state_map = await self._state_storage_controller.get_state_for_event(
+ event_id,
+ StateFilter.from_types(
+ [
+ (EventTypes.Member, user_id),
+ (EventTypes.RoomHistoryVisibility, ""),
+ ]
+ ),
+ )
+
+ membership = None
+ membership_event = state_map.get((EventTypes.Member, user_id))
+ if membership_event:
+ membership = membership_event.membership
+
+ # if the user was a member of the room at the time of the event,
+ # they can see it.
+ if membership == Membership.JOIN:
+ return True
+
+ # otherwise, it depends on the history visibility.
+ visibility = get_effective_room_visibility_from_state(state_map)
+
+ if visibility == HistoryVisibility.JOINED:
+ # we weren't a member at the time of the event, so we can't see this event.
+ return False
+
+ # otherwise *invited* is good enough
+ if membership == Membership.INVITE:
+ return True
+
+ if visibility == HistoryVisibility.INVITED:
+ # we weren't invited, so we can't see this event.
+ return False
+
+ if visibility == HistoryVisibility.WORLD_READABLE:
+ return True
+
+ # So it's SHARED, and the user was not a member at the time. The user cannot
+ # see history, unless they have *subsequently* joined the room.
+ #
+ # XXX: if the user has subsequently joined and then left again,
+ # ideally we would share history up to the point they left. But
+ # we don't know when they left. We just treat it as though they
+ # never joined, and restrict access.
+
+ (
+ current_membership,
+ _,
+ ) = await self.store.get_local_current_membership_for_user_in_room(
+ user_id, room_id
+ )
+ return current_membership == Membership.JOIN
+
async def get_joined_members(self, requester: Requester, room_id: str) -> dict:
"""Get all the joined members in the room and their profile information.
@@ -394,7 +447,7 @@ class EventCreationHandler:
self.auth = hs.get_auth()
self._event_auth_handler = hs.get_event_auth_handler()
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
@@ -426,7 +479,7 @@ class EventCreationHandler:
# This is to stop us from diverging history *too* much.
self.limiter = Linearizer(max_count=5, name="room_event_creation_limit")
- self.action_generator = hs.get_action_generator()
+ self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
self.spam_checker = hs.get_spam_checker()
self.third_party_event_rules: "ThirdPartyEventRules" = (
@@ -634,7 +687,9 @@ class EventCreationHandler:
# federation as well as those created locally. As of room v3, aliases events
# can be created by users that are not in the room, therefore we have to
# tolerate them in event_auth.check().
- prev_state_ids = await context.get_prev_state_ids()
+ prev_state_ids = await context.get_prev_state_ids(
+ StateFilter.from_types([(EventTypes.Member, None)])
+ )
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
prev_event = (
await self.store.get_event(prev_event_id, allow_none=True)
@@ -757,7 +812,13 @@ class EventCreationHandler:
The previous version of the event is returned, if it is found in the
event context. Otherwise, None is returned.
"""
- prev_state_ids = await context.get_prev_state_ids()
+ if event.internal_metadata.is_outlier():
+ # This can happen due to out of band memberships
+ return None
+
+ prev_state_ids = await context.get_prev_state_ids(
+ StateFilter.from_types([(event.type, None)])
+ )
prev_event_id = prev_state_ids.get((event.type, event.state_key))
if not prev_event_id:
return None
@@ -877,11 +938,39 @@ class EventCreationHandler:
event.sender,
)
- spam_error = await self.spam_checker.check_event_for_spam(event)
- if spam_error:
- if not isinstance(spam_error, str):
- spam_error = "Spam is not permitted here"
- raise SynapseError(403, spam_error, Codes.FORBIDDEN)
+ spam_check_result = await self.spam_checker.check_event_for_spam(event)
+ if spam_check_result != self.spam_checker.NOT_SPAM:
+ if isinstance(spam_check_result, tuple):
+ try:
+ [code, dict] = spam_check_result
+ raise SynapseError(
+ 403,
+ "This message had been rejected as probable spam",
+ code,
+ dict,
+ )
+ except ValueError:
+ logger.error(
+ "Spam-check module returned invalid error value. Expecting [code, dict], got %s",
+ spam_check_result,
+ )
+ spam_check_result = Codes.FORBIDDEN
+
+ if isinstance(spam_check_result, Codes):
+ raise SynapseError(
+ 403,
+ "This message has been rejected as probable spam",
+ spam_check_result,
+ )
+
+ # Backwards compatibility: if the return value is not an error code, it
+ # means the module returned an error message to be included in the
+ # SynapseError (which is now deprecated).
+ raise SynapseError(
+ 403,
+ spam_check_result,
+ Codes.FORBIDDEN,
+ )
ev = await self.handle_new_client_event(
requester=requester,
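
Aside (a sketch under stated assumptions, not part of this diff): with this change a `check_event_for_spam` callback can return `NOT_SPAM`, an error `Codes` value, a `(Codes, dict)` tuple, or (deprecated) a bare string. A hedged sketch of a module-side callback under that contract; the `NOT_SPAM` import path is assumed from the module API, and `looks_fine` is an invented helper:

    from synapse.api.errors import Codes
    from synapse.module_api import NOT_SPAM  # assumed export

    def looks_fine(event) -> bool:
        # Hypothetical policy check; always allows in this sketch.
        return True

    async def check_event_for_spam(event):
        if looks_fine(event):
            return NOT_SPAM
        # New-style rejection: an error code instead of a bare string.
        return Codes.FORBIDDEN
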
@@ -1001,7 +1090,7 @@ class EventCreationHandler:
# after it is created
if builder.internal_metadata.outlier:
event.internal_metadata.outlier = True
- context = EventContext.for_outlier()
+ context = EventContext.for_outlier(self._storage_controllers)
elif (
event.type == EventTypes.MSC2716_INSERTION
and state_event_ids
@@ -1013,8 +1102,35 @@ class EventCreationHandler:
#
# TODO(faster_joins): figure out how this works, and make sure that the
# old state is complete.
- old_state = await self.store.get_events_as_list(state_event_ids)
- context = await self.state.compute_event_context(event, old_state=old_state)
+ metadata = await self.store.get_metadata_for_events(state_event_ids)
+
+ state_map_for_event: MutableStateMap[str] = {}
+ for state_id in state_event_ids:
+ data = metadata.get(state_id)
+ if data is None:
+ # We're trying to persist a new historical batch of events
+ # with the given state, e.g. via
+ # `RoomBatchSendEventRestServlet`. The state can be inferred
+ # by Synapse or set directly by the client.
+ #
+ # Either way, we should have persisted all the state before
+ # getting here.
+ raise Exception(
+ f"State event {state_id} not found in DB,"
+ " Synapse should have persisted it before using it."
+ )
+
+ if data.state_key is None:
+ raise Exception(
+ f"Trying to set non-state event {state_id} as state"
+ )
+
+ state_map_for_event[(data.event_type, data.state_key)] = state_id
+
+ context = await self.state.compute_event_context(
+ event,
+ state_ids_before_event=state_map_for_event,
+ )
else:
context = await self.state.compute_event_context(event)
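
Aside (illustrative, with invented event IDs): `state_ids_before_event` takes a `StateMap[str]`, i.e. a mapping from `(event type, state key)` to event ID, which is what the loop above assembles from the stored metadata:

    state_map_for_event = {
        ("m.room.create", ""): "$create_event_id",
        ("m.room.member", "@alice:example.com"): "$alice_join_event_id",
    }
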
@@ -1056,20 +1172,11 @@ class EventCreationHandler:
SynapseError if the event is invalid.
"""
- relation = event.content.get("m.relates_to")
+ relation = relation_from_event(event)
if not relation:
return
- relation_type = relation.get("rel_type")
- if not relation_type:
- return
-
- # Ensure the parent is real.
- relates_to = relation.get("event_id")
- if not relates_to:
- return
-
- parent_event = await self.store.get_event(relates_to, allow_none=True)
+ parent_event = await self.store.get_event(relation.parent_id, allow_none=True)
if parent_event:
# And in the same room.
if parent_event.room_id != event.room_id:
@@ -1078,28 +1185,31 @@ class EventCreationHandler:
else:
# There must be some reason that the client knows the event exists,
# see if there are existing relations. If so, assume everything is fine.
- if not await self.store.event_is_target_of_relation(relates_to):
+ if not await self.store.event_is_target_of_relation(relation.parent_id):
# Otherwise, the client can't know about the parent event!
raise SynapseError(400, "Can't send relation to unknown event")
# If this event is an annotation then we check that the sender
# can't annotate the same way twice (e.g. stops users from liking an
# event multiple times).
- if relation_type == RelationTypes.ANNOTATION:
- aggregation_key = relation["key"]
+ if relation.rel_type == RelationTypes.ANNOTATION:
+ aggregation_key = relation.aggregation_key
+
+ if aggregation_key is None:
+ raise SynapseError(400, "Missing aggregation key")
if len(aggregation_key) > 500:
raise SynapseError(400, "Aggregation key is too long")
already_exists = await self.store.has_user_annotated_event(
- relates_to, event.type, aggregation_key, event.sender
+ relation.parent_id, event.type, aggregation_key, event.sender
)
if already_exists:
raise SynapseError(400, "Can't send same reaction twice")
# Don't attempt to start a thread if the parent event is a relation.
- elif relation_type == RelationTypes.THREAD:
- if await self.store.event_includes_relation(relates_to):
+ elif relation.rel_type == RelationTypes.THREAD:
+ if await self.store.event_includes_relation(relation.parent_id):
raise SynapseError(
400, "Cannot start threads from an event with a relation"
)
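
Aside (a sketch, not part of this diff): `relation_from_event` replaces the ad-hoc dict poking with a structured parse of `m.relates_to`, returning `None` when there is no valid relation. Field names are as used in this diff; `event` stands for an `EventBase` from the surrounding handler:

    rel = relation_from_event(event)
    if rel is not None:
        rel.parent_id        # ID of the event being related to
        rel.rel_type         # e.g. RelationTypes.ANNOTATION or RelationTypes.THREAD
        rel.aggregation_key  # the annotation "key" (e.g. an emoji), else None
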
@@ -1245,7 +1355,9 @@ class EventCreationHandler:
# and `state_groups` because they have `prev_events` that aren't persisted yet
# (historical messages persisted in reverse-chronological order).
if not event.internal_metadata.is_historical():
- await self.action_generator.handle_push_actions_for_event(event, context)
+ await self._bulk_push_rule_evaluator.action_for_event_by_user(
+ event, context
+ )
try:
# If we're a worker we need to hit out to the master.
@@ -1391,7 +1503,7 @@ class EventCreationHandler:
"""
extra_users = extra_users or []
- assert self.storage.persistence is not None
+ assert self._storage_controllers.persistence is not None
assert self._events_shard_config.should_handle(
self._instance_name, event.room_id
)
@@ -1547,7 +1659,11 @@ class EventCreationHandler:
"Redacting MSC2716 events is not supported in this room version",
)
- prev_state_ids = await context.get_prev_state_ids()
+ event_types = event_auth.auth_types_for_event(event.room_version, event)
+ prev_state_ids = await context.get_prev_state_ids(
+ StateFilter.from_types(event_types)
+ )
+
auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=True
)
@@ -1621,7 +1737,7 @@ class EventCreationHandler:
event,
event_pos,
max_stream_token,
- ) = await self.storage.persistence.persist_event(
+ ) = await self._storage_controllers.persistence.persist_event(
event, context=context, backfilled=backfilled
)
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index f6ffb7d1..9de61d55 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -224,7 +224,7 @@ class OidcHandler:
self._sso_handler.render_error(request, "invalid_session", str(e))
return
except MacaroonInvalidSignatureException as e:
- logger.exception("Could not verify session for OIDC callback")
+ logger.warning("Could not verify session for OIDC callback: %s", e)
self._sso_handler.render_error(request, "mismatching_session", str(e))
return
@@ -827,7 +827,7 @@ class OidcProvider:
logger.debug("Exchanging OAuth2 code for a token")
token = await self._exchange_code(code)
except OidcError as e:
- logger.exception("Could not exchange OAuth2 code")
+ logger.warning("Could not exchange OAuth2 code: %s", e)
self._sso_handler.render_error(request, e.error, e.error_description)
return
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 7ee33403..6262a358 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -27,7 +27,7 @@ from synapse.handlers.room import ShutdownRoomResponse
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, Requester
+from synapse.types import JsonDict, Requester, StreamKeyType
from synapse.util.async_helpers import ReadWriteLock
from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client
@@ -129,8 +129,8 @@ class PaginationHandler:
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
- self.state_store = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
self.clock = hs.get_clock()
self._server_name = hs.hostname
self._room_shutdown_handler = hs.get_room_shutdown_handler()
@@ -239,7 +239,7 @@ class PaginationHandler:
# defined in the server's configuration, we can safely assume that's the
# case and use it for this room.
max_lifetime = (
- retention_policy["max_lifetime"] or self._retention_default_max_lifetime
+ retention_policy.max_lifetime or self._retention_default_max_lifetime
)
# Cap the effective max_lifetime to be within the range allowed in the
@@ -352,7 +352,7 @@ class PaginationHandler:
self._purges_in_progress_by_room.add(room_id)
try:
async with self.pagination_lock.write(room_id):
- await self.storage.purge_events.purge_history(
+ await self._storage_controllers.purge_events.purge_history(
room_id, token, delete_local_events
)
logger.info("[purge] complete")
@@ -414,7 +414,7 @@ class PaginationHandler:
if joined:
raise SynapseError(400, "Users are still joined to this room")
- await self.storage.purge_events.purge_room(room_id)
+ await self._storage_controllers.purge_events.purge_room(room_id)
async def get_messages(
self,
@@ -448,7 +448,7 @@ class PaginationHandler:
)
# We expect `/messages` to use historic pagination tokens by default but
# `/messages` should still work with live tokens when manually provided.
- assert from_token.room_key.topological
+ assert from_token.room_key.topological is not None
if pagin_config.limit is None:
# This shouldn't happen as we've set a default limit before this
@@ -491,7 +491,7 @@ class PaginationHandler:
if leave_token.topological < curr_topo:
from_token = from_token.copy_and_replace(
- "room_key", leave_token
+ StreamKeyType.ROOM, leave_token
)
await self.hs.get_federation_handler().maybe_backfill(
@@ -513,16 +513,30 @@ class PaginationHandler:
event_filter=event_filter,
)
- next_token = from_token.copy_and_replace("room_key", next_key)
+ next_token = from_token.copy_and_replace(StreamKeyType.ROOM, next_key)
- if events:
- if event_filter:
- events = await event_filter.filter(events)
+ # If no events are returned from pagination, we have reached the end
+ # of the available events. In that case we omit the `end` token to
+ # tell the client there is no need for further queries.
+ if not events:
+ return {
+ "chunk": [],
+ "start": await from_token.to_string(self.store),
+ }
- events = await filter_events_for_client(
- self.storage, user_id, events, is_peeking=(member_event_id is None)
- )
+ if event_filter:
+ events = await event_filter.filter(events)
+
+ events = await filter_events_for_client(
+ self._storage_controllers,
+ user_id,
+ events,
+ is_peeking=(member_event_id is None),
+ )
+ # If no events remain after filtering, return immediately - but there
+ # might be more events in the next_token batch.
if not events:
return {
"chunk": [],
@@ -539,7 +553,7 @@ class PaginationHandler:
(EventTypes.Member, event.sender) for event in events
)
- state_ids = await self.state_store.get_state_ids_for_event(
+ state_ids = await self._state_storage_controller.get_state_ids_for_event(
events[0].event_id, state_filter=state_filter
)
@@ -653,7 +667,7 @@ class PaginationHandler:
400, "Users are still joined to this room"
)
- await self.storage.purge_events.purge_room(room_id)
+ await self._storage_controllers.purge_events.purge_room(room_id)
logger.info("complete")
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_COMPLETE
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 268481ec..895ea63e 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -49,7 +49,7 @@ from prometheus_client import Counter
from typing_extensions import ContextManager
import synapse.metrics
-from synapse.api.constants import EventTypes, Membership, PresenceState
+from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState
from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
from synapse.appservice import ApplicationService
@@ -66,7 +66,7 @@ from synapse.replication.tcp.commands import ClearUserSyncsCommand
from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
from synapse.storage.databases.main import DataStore
from synapse.streams import EventSource
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import JsonDict, StreamKeyType, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.metrics import Measure
@@ -134,6 +134,7 @@ class BasePresenceHandler(abc.ABC):
def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self.presence_router = hs.get_presence_router()
self.state = hs.get_state_handler()
self.is_mine_id = hs.is_mine_id
@@ -394,7 +395,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
# Route presence EDUs to the right worker
hs.get_federation_registry().register_instances_for_edu(
- "m.presence",
+ EduTypes.PRESENCE,
hs.config.worker.writers.presence,
)
@@ -522,7 +523,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
- "presence_key",
+ StreamKeyType.PRESENCE,
stream_id,
rooms=room_ids_to_states.keys(),
users=users_to_states.keys(),
@@ -649,7 +650,9 @@ class PresenceHandler(BasePresenceHandler):
federation_registry = hs.get_federation_registry()
- federation_registry.register_edu_handler("m.presence", self.incoming_presence)
+ federation_registry.register_edu_handler(
+ EduTypes.PRESENCE, self.incoming_presence
+ )
LaterGauge(
"synapse_handlers_presence_user_to_current_state_size",
@@ -1145,7 +1148,7 @@ class PresenceHandler(BasePresenceHandler):
room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
- "presence_key",
+ StreamKeyType.PRESENCE,
stream_id,
rooms=room_ids_to_states.keys(),
users=[UserID.from_string(u) for u in users_to_states],
@@ -1346,7 +1349,10 @@ class PresenceHandler(BasePresenceHandler):
self._event_pos,
room_max_stream_ordering,
)
- max_pos, deltas = await self.store.get_current_state_deltas(
+ (
+ max_pos,
+ deltas,
+ ) = await self._storage_controllers.state.get_current_state_deltas(
self._event_pos, room_max_stream_ordering
)
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 239b0aa7..6eed3826 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -23,14 +23,7 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
-from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.types import (
- JsonDict,
- Requester,
- UserID,
- create_requester,
- get_domain_from_id,
-)
+from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util.caches.descriptors import cached
from synapse.util.stringutils import parse_and_validate_mxc_uri
@@ -50,9 +43,6 @@ class ProfileHandler:
delegate to master when necessary.
"""
- PROFILE_UPDATE_MS = 60 * 1000
- PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
-
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.clock = hs.get_clock()
@@ -73,11 +63,6 @@ class ProfileHandler:
self._third_party_rules = hs.get_third_party_event_rules()
- if hs.config.worker.run_background_tasks:
- self.clock.looping_call(
- self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
- )
-
async def get_profile(self, user_id: str) -> JsonDict:
target_user = UserID.from_string(user_id)
@@ -116,30 +101,6 @@ class ProfileHandler:
raise SynapseError(502, "Failed to fetch profile")
raise e.to_synapse_error()
- async def get_profile_from_cache(self, user_id: str) -> JsonDict:
- """Get the profile information from our local cache. If the user is
- ours then the profile information will always be correct. Otherwise,
- it may be out of date/missing.
- """
- target_user = UserID.from_string(user_id)
- if self.hs.is_mine(target_user):
- try:
- displayname = await self.store.get_profile_displayname(
- target_user.localpart
- )
- avatar_url = await self.store.get_profile_avatar_url(
- target_user.localpart
- )
- except StoreError as e:
- if e.code == 404:
- raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
- raise
-
- return {"displayname": displayname, "avatar_url": avatar_url}
- else:
- profile = await self.store.get_from_remote_profile_cache(user_id)
- return profile or {}
-
async def get_displayname(self, target_user: UserID) -> Optional[str]:
if self.hs.is_mine(target_user):
try:
@@ -509,45 +470,3 @@ class ProfileHandler:
# so we act as if we couldn't find the profile.
raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN)
raise
-
- @wrap_as_background_process("Update remote profile")
- async def _update_remote_profile_cache(self) -> None:
- """Called periodically to check profiles of remote users we haven't
- checked in a while.
- """
- entries = await self.store.get_remote_profile_cache_entries_that_expire(
- last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS
- )
-
- for user_id, displayname, avatar_url in entries:
- is_subscribed = await self.store.is_subscribed_remote_profile_for_user(
- user_id
- )
- if not is_subscribed:
- await self.store.maybe_delete_remote_profile_cache(user_id)
- continue
-
- try:
- profile = await self.federation.make_query(
- destination=get_domain_from_id(user_id),
- query_type="profile",
- args={"user_id": user_id},
- ignore_backoff=True,
- )
- except Exception:
- logger.exception("Failed to get avatar_url")
-
- await self.store.update_remote_profile_cache(
- user_id, displayname, avatar_url
- )
- continue
-
- new_name = profile.get("displayname")
- if not isinstance(new_name, str):
- new_name = None
- new_avatar = profile.get("avatar_url")
- if not isinstance(new_avatar, str):
- new_avatar = None
-
- # We always hit update to update the last_check timestamp
- await self.store.update_remote_profile_cache(user_id, new_name, new_avatar)
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 43d61535..43d2882b 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -14,10 +14,16 @@
import logging
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
-from synapse.api.constants import ReceiptTypes
+from synapse.api.constants import EduTypes, ReceiptTypes
from synapse.appservice import ApplicationService
from synapse.streams import EventSource
-from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
+from synapse.types import (
+ JsonDict,
+ ReadReceipt,
+ StreamKeyType,
+ UserID,
+ get_domain_from_id,
+)
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -46,11 +52,11 @@ class ReceiptsHandler:
# to the appropriate worker.
if hs.get_instance_name() in hs.config.worker.writers.receipts:
hs.get_federation_registry().register_edu_handler(
- "m.receipt", self._received_remote_receipt
+ EduTypes.RECEIPT, self._received_remote_receipt
)
else:
hs.get_federation_registry().register_instances_for_edu(
- "m.receipt",
+ EduTypes.RECEIPT,
hs.config.worker.writers.receipts,
)
@@ -129,7 +135,9 @@ class ReceiptsHandler:
affected_room_ids = list({r.room_id for r in receipts})
- self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
+ self.notifier.on_new_event(
+ StreamKeyType.RECEIPT, max_batch_id, rooms=affected_room_ids
+ )
# Note that the min here shouldn't be relied upon to be accurate.
await self.hs.get_pusherpool().on_new_receipts(
min_batch_id, max_batch_id, affected_room_ids
@@ -165,43 +173,69 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
self.config = hs.config
@staticmethod
- def filter_out_private(events: List[JsonDict], user_id: str) -> List[JsonDict]:
- """
- This method takes in what is returned by
- get_linearized_receipts_for_rooms() and goes through read receipts
- filtering out m.read.private receipts if they were not sent by the
- current user.
+ def filter_out_private_receipts(
+ rooms: List[JsonDict], user_id: str
+ ) -> List[JsonDict]:
"""
+ Filters a list of serialized receipts (as returned by /sync and /initialSync)
+ and removes private read receipts of other users.
- visible_events = []
-
- # filter out private receipts the user shouldn't see
- for event in events:
- content = event.get("content", {})
- new_event = event.copy()
- new_event["content"] = {}
-
- for event_id, event_content in content.items():
- receipt_event = {}
- for receipt_type, receipt_content in event_content.items():
- if receipt_type == ReceiptTypes.READ_PRIVATE:
- user_rr = receipt_content.get(user_id, None)
- if user_rr:
- receipt_event[ReceiptTypes.READ_PRIVATE] = {
- user_id: user_rr.copy()
- }
- else:
- receipt_event[receipt_type] = receipt_content.copy()
+ This operates on the return value of get_linearized_receipts_for_rooms(),
+ which is wrapped in a cache. Care must be taken to ensure that the input
+ values are not modified.
- # Only include the receipt event if it is non-empty.
- if receipt_event:
- new_event["content"][event_id] = receipt_event
+ Args:
+ rooms: A list of mappings, each with a `content` field that maps
+ event ID -> receipt type -> user ID -> receipt information.
- # Append new_event to visible_events unless empty
- if len(new_event["content"].keys()) > 0:
- visible_events.append(new_event)
+ Returns:
+ The same as rooms, but filtered.
+ """
- return visible_events
+ result = []
+
+ # Iterate through each room's receipt content.
+ for room in rooms:
+ # The receipt content with other users' private read receipts removed.
+ content = {}
+
+ # Iterate over each event ID / receipts for that event.
+ for event_id, orig_event_content in room.get("content", {}).items():
+ event_content = orig_event_content
+ # If there are private read receipts, additional logic is necessary.
+ if ReceiptTypes.READ_PRIVATE in event_content:
+ # Make a copy without private read receipts to avoid leaking
+ # other users' private read receipts.
+ event_content = {
+ receipt_type: receipt_value
+ for receipt_type, receipt_value in event_content.items()
+ if receipt_type != ReceiptTypes.READ_PRIVATE
+ }
+
+ # Copy the current user's private read receipt from the
+ # original content, if it exists.
+ user_private_read_receipt = orig_event_content[
+ ReceiptTypes.READ_PRIVATE
+ ].get(user_id, None)
+ if user_private_read_receipt:
+ event_content[ReceiptTypes.READ_PRIVATE] = {
+ user_id: user_private_read_receipt
+ }
+
+ # Include the event if there is at least one non-private read
+ # receipt or the current user has a private read receipt.
+ if event_content:
+ content[event_id] = event_content
+
+ # Include the room if any of its events still have receipt content
+ # after filtering.
+ if content:
+ # Build a new event to avoid mutating the cache.
+ new_room = {k: v for k, v in room.items() if k != "content"}
+ new_room["content"] = content
+ result.append(new_room)
+
+ return result
async def get_new_events(
self,
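
Aside (a worked example with invented IDs, not part of this diff): the filtering keeps public receipts and only the requester's own private receipts. Here the private receipt type is the MSC2285 unstable identifier in use at the time:

    rooms = [{
        "room_id": "!room:example.com",
        "type": "m.receipt",
        "content": {
            "$event1": {
                "m.read": {"@bob:example.com": {"ts": 1}},
                "org.matrix.msc2285.read.private": {"@bob:example.com": {"ts": 2}},
            },
        },
    }]
    # Filtering for @alice drops Bob's private receipt but keeps his public one:
    # content becomes {"$event1": {"m.read": {"@bob:example.com": {"ts": 1}}}}
    filtered = ReceiptEventSource.filter_out_private_receipts(
        rooms, "@alice:example.com"
    )
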
@@ -223,7 +257,9 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
)
if self.config.experimental.msc2285_enabled:
- events = ReceiptEventSource.filter_out_private(events, user.to_string())
+ events = ReceiptEventSource.filter_out_private_receipts(
+ events, user.to_string()
+ )
return events, to_key
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 05bb1e02..33820428 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -87,6 +87,7 @@ class LoginDict(TypedDict):
class RegistrationHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self.clock = hs.get_clock()
self.hs = hs
self.auth = hs.get_auth()
@@ -528,7 +529,7 @@ class RegistrationHandler:
if requires_invite:
# If the server is in the room, check if the room is public.
- state = await self.store.get_filtered_current_state_ids(
+ state = await self._storage_controllers.state.get_current_state_ids(
room_id, StateFilter.from_types([(EventTypes.JoinRules, "")])
)
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index c2754ec9..0b63cd21 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -11,24 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import collections.abc
import logging
-from typing import (
- TYPE_CHECKING,
- Collection,
- Dict,
- FrozenSet,
- Iterable,
- List,
- Optional,
- Tuple,
-)
+from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
import attr
from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError
-from synapse.events import EventBase
+from synapse.events import EventBase, relation_from_event
from synapse.storage.databases.main.relations import _RelatedEvent
from synapse.types import JsonDict, Requester, StreamToken, UserID
from synapse.visibility import filter_events_for_client
@@ -70,7 +60,7 @@ class BundledAggregations:
class RelationsHandler:
def __init__(self, hs: "HomeServer"):
self._main_store = hs.get_datastores().main
- self._storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self._auth = hs.get_auth()
self._clock = hs.get_clock()
self._event_handler = hs.get_event_handler()
@@ -144,7 +134,10 @@ class RelationsHandler:
)
events = await filter_events_for_client(
- self._storage, user_id, events, is_peeking=(member_event_id is None)
+ self._storage_controllers,
+ user_id,
+ events,
+ is_peeking=(member_event_id is None),
)
now = self._clock.time_msec()
@@ -254,13 +247,19 @@ class RelationsHandler:
return filtered_results
- async def get_threads_for_events(
- self, event_ids: Collection[str], user_id: str, ignored_users: FrozenSet[str]
+ async def _get_threads_for_events(
+ self,
+ events_by_id: Dict[str, EventBase],
+ relations_by_id: Dict[str, str],
+ user_id: str,
+ ignored_users: FrozenSet[str],
) -> Dict[str, _ThreadAggregation]:
"""Get the bundled aggregations for threads for the requested events.
Args:
- event_ids: Events to get aggregations for threads.
+ events_by_id: A map of event_id to events to get aggregations for threads.
+ relations_by_id: A map of event_id to the relation type, if one exists
+ for that event.
user_id: The user requesting the bundled aggregations.
ignored_users: The users ignored by the requesting user.
@@ -271,16 +270,34 @@ class RelationsHandler:
"""
user = UserID.from_string(user_id)
+ # It is not valid to start a thread on an event which itself relates to another event.
+ event_ids = [eid for eid in events_by_id.keys() if eid not in relations_by_id]
+
# Fetch thread summaries.
summaries = await self._main_store.get_thread_summaries(event_ids)
- # Only fetch participated for a limited selection based on what had
- # summaries.
+ # Limit fetching whether the requester has participated in a thread to
+ # events which are thread roots.
thread_event_ids = [
event_id for event_id, summary in summaries.items() if summary
]
- participated = await self._main_store.get_threads_participated(
- thread_event_ids, user_id
+
+ # Pre-seed thread participation with whether the requester sent the event.
+ participated = {
+ event_id: events_by_id[event_id].sender == user_id
+ for event_id in thread_event_ids
+ }
+ # For events the requester did not send, check the database for whether
+ # the requester sent a threaded reply.
+ participated.update(
+ await self._main_store.get_threads_participated(
+ [
+ event_id
+ for event_id in thread_event_ids
+ if not participated[event_id]
+ ],
+ user_id,
+ )
)
# Then subtract off the results for any ignored users.
@@ -341,7 +358,8 @@ class RelationsHandler:
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
- current_user_participated=participated[event_id],
+ current_user_participated=events_by_id[event_id].sender == user_id
+ or participated[event_id],
)
return results
@@ -373,20 +391,21 @@ class RelationsHandler:
if event.is_state():
continue
- relates_to = event.content.get("m.relates_to")
- relation_type = None
- if isinstance(relates_to, collections.abc.Mapping):
- relation_type = relates_to.get("rel_type")
+ relates_to = relation_from_event(event)
+ if relates_to:
# An event which is a replacement (ie edit) or annotation (ie,
# reaction) may not have any other event related to it.
- if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
+ if relates_to.rel_type in (
+ RelationTypes.ANNOTATION,
+ RelationTypes.REPLACE,
+ ):
continue
+ # Track the event's relation information for later.
+ relations_by_id[event.event_id] = relates_to.rel_type
+
# The event should get bundled aggregations.
events_by_id[event.event_id] = event
- # Track the event's relation information for later.
- if isinstance(relation_type, str):
- relations_by_id[event.event_id] = relation_type
# event ID -> bundled aggregation in non-serialized form.
results: Dict[str, BundledAggregations] = {}
@@ -398,9 +417,9 @@ class RelationsHandler:
# events to be fetched. Thus, we check those first!
# Fetch thread summaries (but only for the directly requested events).
- threads = await self.get_threads_for_events(
- # It is not valid to start a thread on an event which itself relates to another event.
- [eid for eid in events_by_id.keys() if eid not in relations_by_id],
+ threads = await self._get_threads_for_events(
+ events_by_id,
+ relations_by_id,
user_id,
ignored_users,
)
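A small sketch of the pre-seed-then-backfill pattern used for thread participation above: a cheap local check on the sender avoids a database round-trip for events the requester sent themselves. The stub function and all names here are illustrative.

import asyncio
from typing import Dict, List

async def get_threads_participated_stub(event_ids: List[str], user_id: str) -> Dict[str, bool]:
    # Stand-in for the datastore lookup; pretend the user replied in no threads.
    return {event_id: False for event_id in event_ids}

async def participation(thread_roots: List[str], senders: Dict[str, str], user_id: str) -> Dict[str, bool]:
    # Cheap local check first: sending the thread root counts as participating.
    participated = {eid: senders[eid] == user_id for eid in thread_roots}
    # Only the remaining events need a database round-trip.
    remaining = [eid for eid, seen in participated.items() if not seen]
    participated.update(await get_threads_participated_stub(remaining, user_id))
    return participated

print(asyncio.run(participation(
    ["$root1", "$root2"],
    {"$root1": "@me:hs", "$root2": "@other:hs"},
    "@me:hs",
)))  # {'$root1': True, '$root2': False}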
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 604eb6ec..520663f1 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -33,6 +33,7 @@ from typing import (
import attr
from typing_extensions import TypedDict
+import synapse.events.snapshot
from synapse.api.constants import (
EventContentFields,
EventTypes,
@@ -72,12 +73,12 @@ from synapse.types import (
RoomID,
RoomStreamToken,
StateMap,
+ StreamKeyType,
StreamToken,
UserID,
create_requester,
)
from synapse.util import stringutils
-from synapse.util.async_helpers import Linearizer
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import parse_and_validate_server_name
from synapse.visibility import filter_events_for_client
@@ -106,6 +107,7 @@ class EventContext:
class RoomCreationHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.hs = hs
@@ -149,10 +151,11 @@ class RoomCreationHandler:
)
preset_config["encrypted"] = encrypted
- self._replication = hs.get_replication_data_handler()
+ self._default_power_level_content_override = (
+ self.config.room.default_power_level_content_override
+ )
- # linearizer to stop two upgrades happening at once
- self._upgrade_linearizer = Linearizer("room_upgrade_linearizer")
+ self._replication = hs.get_replication_data_handler()
# If a user tries to update the same room multiple times in quick
# succession, only process the first attempt and return its result to
@@ -196,6 +199,39 @@ class RoomCreationHandler:
400, "An upgrade for this room is currently in progress"
)
+ # Check whether the room exists and 404 if it doesn't.
+ # We could go straight for the auth check, but that will raise a 403 instead.
+ old_room = await self.store.get_room(old_room_id)
+ if old_room is None:
+ raise NotFoundError("Unknown room id %s" % (old_room_id,))
+
+ new_room_id = self._generate_room_id()
+
+ # Check whether the user has the power level to carry out the upgrade.
+ # `check_auth_rules_from_context` will check that they are in the room and have
+ # the required power level to send the tombstone event.
+ (
+ tombstone_event,
+ tombstone_context,
+ ) = await self.event_creation_handler.create_event(
+ requester,
+ {
+ "type": EventTypes.Tombstone,
+ "state_key": "",
+ "room_id": old_room_id,
+ "sender": user_id,
+ "content": {
+ "body": "This room has been replaced",
+ "replacement_room": new_room_id,
+ },
+ },
+ )
+ old_room_version = await self.store.get_room_version(old_room_id)
+ validate_event_for_room_version(old_room_version, tombstone_event)
+ await self._event_auth_handler.check_auth_rules_from_context(
+ old_room_version, tombstone_event, tombstone_context
+ )
+
# Upgrade the room
#
# If this user has sent multiple upgrade requests for the same room
@@ -206,19 +242,35 @@ class RoomCreationHandler:
self._upgrade_room,
requester,
old_room_id,
- new_version, # args for _upgrade_room
+ old_room, # args for _upgrade_room
+ new_room_id,
+ new_version,
+ tombstone_event,
+ tombstone_context,
)
return ret
async def _upgrade_room(
- self, requester: Requester, old_room_id: str, new_version: RoomVersion
+ self,
+ requester: Requester,
+ old_room_id: str,
+ old_room: Dict[str, Any],
+ new_room_id: str,
+ new_version: RoomVersion,
+ tombstone_event: EventBase,
+ tombstone_context: synapse.events.snapshot.EventContext,
) -> str:
"""
Args:
requester: the user requesting the upgrade
old_room_id: the id of the room to be replaced
- new_versions: the version to upgrade the room to
+ old_room: a dict containing room information for the room to be replaced,
+ as returned by `RoomWorkerStore.get_room`.
+ new_room_id: the id of the replacement room
+ new_version: the version to upgrade the room to
+ tombstone_event: the tombstone event to send to the old room
+ tombstone_context: the context for the tombstone event
Raises:
ShadowBanError if the requester is shadow-banned.
@@ -226,40 +278,15 @@ class RoomCreationHandler:
user_id = requester.user.to_string()
assert self.hs.is_mine_id(user_id), "User must be our own: %s" % (user_id,)
- # start by allocating a new room id
- r = await self.store.get_room(old_room_id)
- if r is None:
- raise NotFoundError("Unknown room id %s" % (old_room_id,))
- new_room_id = await self._generate_room_id(
- creator_id=user_id,
- is_public=r["is_public"],
- room_version=new_version,
- )
-
logger.info("Creating new room %s to replace %s", new_room_id, old_room_id)
- # we create and auth the tombstone event before properly creating the new
- # room, to check our user has perms in the old room.
- (
- tombstone_event,
- tombstone_context,
- ) = await self.event_creation_handler.create_event(
- requester,
- {
- "type": EventTypes.Tombstone,
- "state_key": "",
- "room_id": old_room_id,
- "sender": user_id,
- "content": {
- "body": "This room has been replaced",
- "replacement_room": new_room_id,
- },
- },
- )
- old_room_version = await self.store.get_room_version(old_room_id)
- validate_event_for_room_version(old_room_version, tombstone_event)
- await self._event_auth_handler.check_auth_rules_from_context(
- old_room_version, tombstone_event, tombstone_context
+ # create the new room. may raise a `StoreError` in the exceedingly unlikely
+ # event of a room ID collision.
+ await self.store.store_room(
+ room_id=new_room_id,
+ room_creator_user_id=user_id,
+ is_public=old_room["is_public"],
+ room_version=new_version,
)
await self.clone_existing_room(
@@ -277,7 +304,10 @@ class RoomCreationHandler:
context=tombstone_context,
)
- old_room_state = await tombstone_context.get_current_state_ids()
+ state_filter = StateFilter.from_types(
+ [(EventTypes.CanonicalAlias, ""), (EventTypes.PowerLevels, "")]
+ )
+ old_room_state = await tombstone_context.get_current_state_ids(state_filter)
# We know the tombstone event isn't an outlier so it has current state.
assert old_room_state is not None
@@ -401,7 +431,7 @@ class RoomCreationHandler:
requester: the user requesting the upgrade
old_room_id : the id of the room to be replaced
new_room_id: the id to give the new room (should already have been
- created with _gemerate_room_id())
+ created with _generate_room_id())
new_room_version: the new room version to use
tombstone_event_id: the ID of the tombstone event in the old room.
"""
@@ -439,21 +469,22 @@ class RoomCreationHandler:
(EventTypes.RoomAvatar, ""),
(EventTypes.RoomEncryption, ""),
(EventTypes.ServerACL, ""),
- (EventTypes.RelatedGroups, ""),
(EventTypes.PowerLevels, ""),
]
- # If the old room was a space, copy over the room type and the rooms in
- # the space.
- if (
- old_room_create_event.content.get(EventContentFields.ROOM_TYPE)
- == RoomTypes.SPACE
- ):
- creation_content[EventContentFields.ROOM_TYPE] = RoomTypes.SPACE
- types_to_copy.append((EventTypes.SpaceChild, None))
+ # Copy the room type as per MSC3818.
+ room_type = old_room_create_event.content.get(EventContentFields.ROOM_TYPE)
+ if room_type is not None:
+ creation_content[EventContentFields.ROOM_TYPE] = room_type
- old_room_state_ids = await self.store.get_filtered_current_state_ids(
- old_room_id, StateFilter.from_types(types_to_copy)
+ # If the old room was a space, copy over the rooms in the space.
+ if room_type == RoomTypes.SPACE:
+ types_to_copy.append((EventTypes.SpaceChild, None))
+
+ old_room_state_ids = (
+ await self._storage_controllers.state.get_current_state_ids(
+ old_room_id, StateFilter.from_types(types_to_copy)
+ )
)
# map from event_id to BaseEvent
old_room_state_events = await self.store.get_events(old_room_state_ids.values())
@@ -530,8 +561,10 @@ class RoomCreationHandler:
)
# Transfer membership events
- old_room_member_state_ids = await self.store.get_filtered_current_state_ids(
- old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
+ old_room_member_state_ids = (
+ await self._storage_controllers.state.get_current_state_ids(
+ old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
+ )
)
# map from event_id to BaseEvent
@@ -725,6 +758,21 @@ class RoomCreationHandler:
if wchar in config["room_alias_name"]:
raise SynapseError(400, "Invalid characters in room alias")
+ if ":" in config["room_alias_name"]:
+ # Prevent someone from trying to pass in a full alias here.
+ # Note that it's permissible for a room alias to have multiple
+ # hash symbols at the start (notably bridged over from IRC, too),
+ # but the first colon in the alias is defined to separate the local
+ # part from the server name.
+ # (remember server names can contain port numbers, also separated
+ # by a colon. But under no circumstances should the local part be
+ # allowed to contain a colon!)
+ raise SynapseError(
+ 400,
+ "':' is not permitted in the room alias name. "
+ "Please note this expects a local part — 'wombat', not '#wombat:example.com'.",
+ )
+
room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname)
mapping = await self.store.get_association_from_room_alias(room_alias)
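As a worked example of the colon rule above, a hedged standalone check; ValueError stands in for SynapseError and the alias names are illustrative.

def validate_alias_localpart(alias_name: str) -> None:
    # The first colon in a full alias separates the local part from the
    # server name, so a local part must never contain one.
    if ":" in alias_name:
        raise ValueError(
            "':' is not permitted in the room alias name; "
            "expected a local part like 'wombat', not '#wombat:example.com'"
        )

validate_alias_localpart("wombat")         # ok
validate_alias_localpart("##irc-channel")  # ok: leading hashes are permitted
# validate_alias_localpart("wombat:example.com") would raise ValueError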
@@ -778,7 +826,7 @@ class RoomCreationHandler:
visibility = config.get("visibility", "private")
is_public = visibility == "public"
- room_id = await self._generate_room_id(
+ room_id = await self._generate_and_create_room_id(
creator_id=user_id,
is_public=is_public,
room_version=room_version,
@@ -1042,9 +1090,19 @@ class RoomCreationHandler:
for invitee in invite_list:
power_level_content["users"][invitee] = 100
- # Power levels overrides are defined per chat preset
+ # If the user supplied a preset name e.g. "private_chat",
+ # we apply that preset
power_level_content.update(config["power_level_content_override"])
+ # If the server config contains default_power_level_content_override,
+ # and that contains information for this room preset, apply it.
+ if self._default_power_level_content_override:
+ override = self._default_power_level_content_override.get(preset_config)
+ if override is not None:
+ power_level_content.update(override)
+
+ # Finally, if the user supplied specific permissions for this room,
+ # apply those.
if power_level_content_override:
power_level_content.update(power_level_content_override)
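The three layers above apply in a fixed precedence, with later updates winning. A sketch of that merge order, with plain dicts standing in for the real preset and config structures (the values are illustrative):

preset_levels = {"users_default": 0, "events_default": 0}      # from the preset
server_override = {"events_default": 50}   # default_power_level_content_override
request_override = {"users_default": 10}   # supplied in the createRoom request

power_level_content = dict(preset_levels)
power_level_content.update(server_override)   # server config beats the preset
power_level_content.update(request_override)  # the request beats everything
assert power_level_content == {"users_default": 10, "events_default": 50}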
@@ -1090,7 +1148,26 @@ class RoomCreationHandler:
return last_sent_stream_id
- async def _generate_room_id(
+ def _generate_room_id(self) -> str:
+ """Generates a random room ID.
+
+ Room IDs look like "!opaque_id:domain" and are case-sensitive as per the spec
+ at https://spec.matrix.org/v1.2/appendices/#room-ids-and-event-ids.
+
+ Does not check for collisions with existing rooms or prevent future calls from
+ returning the same room ID. To ensure the uniqueness of a new room ID, use
+ `_generate_and_create_room_id` instead.
+
+ Synapse's room IDs are 18 [a-zA-Z] characters long, which comes out to around
+ 102 bits.
+
+ Returns:
+ A random room ID of the form "!opaque_id:domain".
+ """
+ random_string = stringutils.random_string(18)
+ return RoomID(random_string, self.hs.hostname).to_string()
+
+ async def _generate_and_create_room_id(
self,
creator_id: str,
is_public: bool,
@@ -1101,8 +1178,7 @@ class RoomCreationHandler:
attempts = 0
while attempts < 5:
try:
- random_string = stringutils.random_string(18)
- gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
+ gen_room_id = self._generate_room_id()
await self.store.store_room(
room_id=gen_room_id,
room_creator_user_id=creator_id,
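The refactor above splits pure ID generation from the create-with-retry loop. A minimal sketch of that pattern, with the store call and StoreError stubbed:

import random
import string

class StoreError(Exception):
    """Stand-in for synapse.api.errors.StoreError."""

def generate_room_id(domain: str) -> str:
    # 18 characters of [a-zA-Z]: log2(52**18) is roughly 102 bits of entropy.
    opaque = "".join(random.choices(string.ascii_letters, k=18))
    return f"!{opaque}:{domain}"

async def generate_and_create_room_id(store_room, domain: str) -> str:
    # Retry a few times in the exceedingly unlikely event of a collision.
    for _ in range(5):
        room_id = generate_room_id(domain)
        try:
            await store_room(room_id)
            return room_id
        except StoreError:
            continue
    raise StoreError("Couldn't generate a room ID")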
@@ -1120,8 +1196,8 @@ class RoomContextHandler:
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
- self.state_store = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
self._relations_handler = hs.get_relations_handler()
async def get_event_context(
@@ -1164,7 +1240,10 @@ class RoomContextHandler:
if use_admin_priviledge:
return events
return await filter_events_for_client(
- self.storage, user.to_string(), events, is_peeking=is_peeking
+ self._storage_controllers,
+ user.to_string(),
+ events,
+ is_peeking=is_peeking,
)
event = await self.store.get_event(
@@ -1221,7 +1300,7 @@ class RoomContextHandler:
# first? Shouldn't we be consistent with /sync?
# https://github.com/matrix-org/matrix-doc/issues/687
- state = await self.state_store.get_state_for_events(
+ state = await self._state_storage_controller.get_state_for_events(
[last_event_id], state_filter=state_filter
)
@@ -1239,10 +1318,10 @@ class RoomContextHandler:
events_after=events_after,
state=await filter_evts(state_events),
aggregations=aggregations,
- start=await token.copy_and_replace("room_key", results.start).to_string(
- self.store
- ),
- end=await token.copy_and_replace("room_key", results.end).to_string(
+ start=await token.copy_and_replace(
+ StreamKeyType.ROOM, results.start
+ ).to_string(self.store),
+ end=await token.copy_and_replace(StreamKeyType.ROOM, results.end).to_string(
self.store
),
)
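These hunks (and the matching ones in search.py and sync.py below) replace bare string keys like "room_key" with StreamKeyType members. A hedged sketch of such a string-valued holder; the real definition in synapse.types may be a plain class of constants rather than an Enum, and the member list here is only a subset.

from enum import Enum

class StreamKeyType(str, Enum):
    # str-valued members compare equal to the old literal keys, so callers
    # that still pass "room_key" etc. keep working during the migration.
    ROOM = "room_key"
    TYPING = "typing_key"
    RECEIPT = "receipt_key"
    ACCOUNT_DATA = "account_data_key"
    PRESENCE = "presence_key"
    TO_DEVICE = "to_device_key"

assert StreamKeyType.ROOM == "room_key"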
@@ -1254,6 +1333,7 @@ class TimestampLookupHandler:
self.store = hs.get_datastores().main
self.state_handler = hs.get_state_handler()
self.federation_client = hs.get_federation_client()
+ self._storage_controllers = hs.get_storage_controllers()
async def get_event_for_timestamp(
self,
@@ -1327,7 +1407,9 @@ class TimestampLookupHandler:
)
# Find other homeservers from the given state in the room
- curr_state = await self.state_handler.get_current_state(room_id)
+ curr_state = await self._storage_controllers.state.get_current_state(
+ room_id
+ )
curr_domains = get_domains_from_state(curr_state)
likely_domains = [
domain for domain, depth in curr_domains if domain != self.server_name
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index 29de7e5b..1414e575 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -17,7 +17,7 @@ class RoomBatchHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastores().main
- self.state_store = hs.get_storage().state
+ self._state_storage_controller = hs.get_storage_controllers().state
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
@@ -53,6 +53,7 @@ class RoomBatchHandler:
# We want to use the successor event depth so they appear after `prev_event` because
# it has a larger `depth` but before the successor event because the `stream_ordering`
# is negative before the successor event.
+ assert most_recent_prev_event_id is not None
successor_event_ids = await self.store.get_successor_events(
most_recent_prev_event_id
)
@@ -139,7 +140,8 @@ class RoomBatchHandler:
_,
) = await self.store.get_max_depth_of(event_ids)
# mapping from (type, state_key) -> state_event_id
- prev_state_map = await self.state_store.get_state_ids_for_event(
+ assert most_recent_event_id is not None
+ prev_state_map = await self._state_storage_controller.get_state_ids_for_event(
most_recent_event_id
)
# List of state event ID's
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index f3577b5d..183d4ae3 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -50,6 +50,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
class RoomListHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self.hs = hs
self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
self.response_cache: ResponseCache[
@@ -274,7 +275,7 @@ class RoomListHandler:
if aliases:
result["aliases"] = aliases
- current_state_ids = await self.store.get_current_state_ids(
+ current_state_ids = await self._storage_controllers.state.get_current_state_ids(
room_id, on_invalidate=cache_context.invalidate
)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 802e57c4..d1199a06 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -38,6 +38,7 @@ from synapse.event_auth import get_named_level, get_power_level_event
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
+from synapse.storage.state import StateFilter
from synapse.types import (
JsonDict,
Requester,
@@ -67,6 +68,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self.auth = hs.get_auth()
self.state_handler = hs.get_state_handler()
self.config = hs.config
@@ -362,7 +364,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
historical=historical,
)
- prev_state_ids = await context.get_prev_state_ids()
+ prev_state_ids = await context.get_prev_state_ids(
+ StateFilter.from_types([(EventTypes.Member, None)])
+ )
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
@@ -991,7 +995,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# If the host is in the room, but not one of the authorised hosts
# for restricted join rules, a remote join must be used.
room_version = await self.store.get_room_version(room_id)
- current_state_ids = await self.store.get_current_state_ids(room_id)
+ current_state_ids = await self._storage_controllers.state.get_current_state_ids(
+ room_id
+ )
# If restricted join rules are not being used, a local join can always
# be used.
@@ -1078,17 +1084,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Transfer alias mappings in the room directory
await self.store.update_aliases_for_room(old_room_id, room_id)
- # Check if any groups we own contain the predecessor room
- local_group_ids = await self.store.get_local_groups_for_room(old_room_id)
- for group_id in local_group_ids:
- # Add new the new room to those groups
- await self.store.add_room_to_group(
- group_id, room_id, old_room is not None and old_room["is_public"]
- )
-
- # Remove the old room from those groups
- await self.store.remove_room_from_group(group_id, old_room_id)
-
async def copy_user_state_on_room_upgrade(
self, old_room_id: str, new_room_id: str, user_ids: Iterable[str]
) -> None:
@@ -1160,7 +1155,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
else:
requester = types.create_requester(target_user)
- prev_state_ids = await context.get_prev_state_ids()
+ prev_state_ids = await context.get_prev_state_ids(
+ StateFilter.from_types([(EventTypes.GuestAccess, None)])
+ )
if event.membership == Membership.JOIN:
if requester.is_guest:
guest_can_join = await self._can_guest_join(prev_state_ids)
@@ -1404,7 +1401,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
txn_id: Optional[str],
id_access_token: Optional[str] = None,
) -> int:
- room_state = await self.state_handler.get_current_state(room_id)
+ room_state = await self._storage_controllers.state.get_current_state(
+ room_id,
+ StateFilter.from_types(
+ [
+ (EventTypes.Member, user.to_string()),
+ (EventTypes.CanonicalAlias, ""),
+ (EventTypes.Name, ""),
+ (EventTypes.Create, ""),
+ (EventTypes.JoinRules, ""),
+ (EventTypes.RoomAvatar, ""),
+ ]
+ ),
+ )
inviter_display_name = ""
inviter_avatar_url = ""
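Several hunks in this file swap full-state fetches for ones scoped by a StateFilter, so only the event types actually read get loaded. A short example of building such filters with StateFilter.from_types (the user ID is illustrative):

from synapse.api.constants import EventTypes
from synapse.storage.state import StateFilter

# Only the pieces of state the invite path actually reads:
invite_state_filter = StateFilter.from_types(
    [
        (EventTypes.Member, "@inviter:example.com"),  # one specific member
        (EventTypes.CanonicalAlias, ""),
        (EventTypes.Name, ""),
        (EventTypes.Create, ""),
        (EventTypes.JoinRules, ""),
        (EventTypes.RoomAvatar, ""),
    ]
)

# A None state_key wildcards over every state_key of that event type:
all_members_filter = StateFilter.from_types([(EventTypes.Member, None)])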
@@ -1800,7 +1809,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
async def forget(self, user: UserID, room_id: str) -> None:
user_id = user.to_string()
- member = await self.state_handler.get_current_state(
+ member = await self._storage_controllers.state.get_current_state_event(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
membership = member.membership if member else None
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index ff24ec80..13098f56 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -90,6 +90,7 @@ class RoomSummaryHandler:
def __init__(self, hs: "HomeServer"):
self._event_auth_handler = hs.get_event_auth_handler()
self._store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self._event_serializer = hs.get_event_client_serializer()
self._server_name = hs.hostname
self._federation_client = hs.get_federation_client()
@@ -537,7 +538,7 @@ class RoomSummaryHandler:
Returns:
True if the room is accessible to the requesting user or server.
"""
- state_ids = await self._store.get_current_state_ids(room_id)
+ state_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
# If there's no state for the room, it isn't known.
if not state_ids:
@@ -562,8 +563,13 @@ class RoomSummaryHandler:
if join_rules_event_id:
join_rules_event = await self._store.get_event(join_rules_event_id)
join_rule = join_rules_event.content.get("join_rule")
- if join_rule == JoinRules.PUBLIC or (
- room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
+ if (
+ join_rule == JoinRules.PUBLIC
+ or (room_version.msc2403_knocking and join_rule == JoinRules.KNOCK)
+ or (
+ room_version.msc3787_knock_restricted_join_rule
+ and join_rule == JoinRules.KNOCK_RESTRICTED
+ )
):
return True
@@ -657,7 +663,8 @@ class RoomSummaryHandler:
# The API doesn't return the room version so assume that a
# join rule of knock is valid.
if (
- room.get("join_rules") in (JoinRules.PUBLIC, JoinRules.KNOCK)
+ room.get("join_rule")
+ in (JoinRules.PUBLIC, JoinRules.KNOCK, JoinRules.KNOCK_RESTRICTED)
or room.get("world_readable") is True
):
return True
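The widened checks above reduce to one predicate over the room's join rule. A sketch with the version/feature flags passed explicitly; the string values mirror the Matrix spec and MSC3787.

def join_rule_allows_summary(join_rule: str, msc2403_knocking: bool,
                             msc3787_knock_restricted: bool) -> bool:
    return (
        join_rule == "public"
        or (msc2403_knocking and join_rule == "knock")
        or (msc3787_knock_restricted and join_rule == "knock_restricted")
    )

assert join_rule_allows_summary("knock_restricted", True, True)
assert not join_rule_allows_summary("knock_restricted", True, False)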
@@ -696,7 +703,9 @@ class RoomSummaryHandler:
# there should always be an entry
assert stats is not None, "unable to retrieve stats for %s" % (room_id,)
- current_state_ids = await self._store.get_current_state_ids(room_id)
+ current_state_ids = await self._storage_controllers.state.get_current_state_ids(
+ room_id
+ )
create_event = await self._store.get_event(
current_state_ids[(EventTypes.Create, "")]
)
@@ -708,9 +717,6 @@ class RoomSummaryHandler:
"canonical_alias": stats["canonical_alias"],
"num_joined_members": stats["joined_members"],
"avatar_url": stats["avatar"],
- # plural join_rules is a documentation error but kept for historical
- # purposes. Should match /publicRooms.
- "join_rules": stats["join_rules"],
"join_rule": stats["join_rules"],
"world_readable": (
stats["history_visibility"] == HistoryVisibility.WORLD_READABLE
@@ -757,7 +763,9 @@ class RoomSummaryHandler:
"""
# look for child rooms/spaces.
- current_state_ids = await self._store.get_current_state_ids(room_id)
+ current_state_ids = await self._storage_controllers.state.get_current_state_ids(
+ room_id
+ )
events = await self._store.get_events_as_list(
[
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 5619f8f5..bcab98c6 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -24,7 +24,7 @@ from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, StreamKeyType, UserID
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
@@ -55,8 +55,8 @@ class SearchHandler:
self.hs = hs
self._event_serializer = hs.get_event_client_serializer()
self._relations_handler = hs.get_relations_handler()
- self.storage = hs.get_storage()
- self.state_store = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
self.auth = hs.get_auth()
async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]:
@@ -348,7 +348,7 @@ class SearchHandler:
state_results = {}
if include_state:
for room_id in {e.room_id for e in search_result.allowed_events}:
- state = await self.state_handler.get_current_state(room_id)
+ state = await self._storage_controllers.state.get_current_state(room_id)
state_results[room_id] = list(state.values())
aggregations = await self._relations_handler.get_bundled_aggregations(
@@ -460,7 +460,7 @@ class SearchHandler:
filtered_events = await search_filter.filter([r["event"] for r in results])
events = await filter_events_for_client(
- self.storage, user.to_string(), filtered_events
+ self._storage_controllers, user.to_string(), filtered_events
)
events.sort(key=lambda e: -rank_map[e.event_id])
@@ -559,7 +559,7 @@ class SearchHandler:
filtered_events = await search_filter.filter([r["event"] for r in results])
events = await filter_events_for_client(
- self.storage, user.to_string(), filtered_events
+ self._storage_controllers, user.to_string(), filtered_events
)
room_events.extend(events)
@@ -644,22 +644,22 @@ class SearchHandler:
)
events_before = await filter_events_for_client(
- self.storage, user.to_string(), res.events_before
+ self._storage_controllers, user.to_string(), res.events_before
)
events_after = await filter_events_for_client(
- self.storage, user.to_string(), res.events_after
+ self._storage_controllers, user.to_string(), res.events_after
)
context: JsonDict = {
"events_before": events_before,
"events_after": events_after,
"start": await now_token.copy_and_replace(
- "room_key", res.start
+ StreamKeyType.ROOM, res.start
+ ).to_string(self.store),
+ "end": await now_token.copy_and_replace(
+ StreamKeyType.ROOM, res.end
).to_string(self.store),
- "end": await now_token.copy_and_replace("room_key", res.end).to_string(
- self.store
- ),
}
if include_profile:
@@ -677,7 +677,7 @@ class SearchHandler:
[(EventTypes.Member, sender) for sender in senders]
)
- state = await self.state_store.get_state_for_event(
+ state = await self._state_storage_controller.get_state_for_event(
last_event_id, state_filter
)
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 436cd971..f45e06eb 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -40,6 +40,7 @@ class StatsHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self.state = hs.get_state_handler()
self.server_name = hs.hostname
self.clock = hs.get_clock()
@@ -105,7 +106,10 @@ class StatsHandler:
logger.debug(
"Processing room stats %s->%s", self.pos, room_max_stream_ordering
)
- max_pos, deltas = await self.store.get_current_state_deltas(
+ (
+ max_pos,
+ deltas,
+ ) = await self._storage_controllers.state.get_current_state_deltas(
self.pos, room_max_stream_ordering
)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 2c555a66..b4ead79f 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -37,6 +37,7 @@ from synapse.types import (
Requester,
RoomStreamToken,
StateMap,
+ StreamKeyType,
StreamToken,
UserID,
)
@@ -165,16 +166,6 @@ class KnockedSyncResult:
return True
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class GroupsSyncResult:
- join: JsonDict
- invite: JsonDict
- leave: JsonDict
-
- def __bool__(self) -> bool:
- return bool(self.join or self.invite or self.leave)
-
-
@attr.s(slots=True, auto_attribs=True)
class _RoomChanges:
"""The set of room entries to include in the sync, plus the set of joined
@@ -205,7 +196,6 @@ class SyncResult:
for this device
device_unused_fallback_key_types: List of key types that have an unused fallback
key
- groups: Group updates, if any
"""
next_batch: StreamToken
@@ -219,7 +209,6 @@ class SyncResult:
device_lists: DeviceListUpdates
device_one_time_keys_count: JsonDict
device_unused_fallback_key_types: List[str]
- groups: Optional[GroupsSyncResult]
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -235,7 +224,6 @@ class SyncResult:
or self.account_data
or self.to_device
or self.device_lists
- or self.groups
)
@@ -250,8 +238,8 @@ class SyncHandler:
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
- self.storage = hs.get_storage()
- self.state_store = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
# TODO: flush cache entries on subsequent sync request.
# Once we get the next /sync request (ie, one with the same access token
@@ -410,10 +398,10 @@ class SyncHandler:
set_tag(SynapseTags.SYNC_RESULT, bool(sync_result))
return sync_result
- async def push_rules_for_user(self, user: UserID) -> JsonDict:
+ async def push_rules_for_user(self, user: UserID) -> Dict[str, Dict[str, list]]:
user_id = user.to_string()
- rules = await self.store.get_push_rules_for_user(user_id)
- rules = format_push_rules_for_user(user, rules)
+ rules_raw = await self.store.get_push_rules_for_user(user_id)
+ rules = format_push_rules_for_user(user, rules_raw)
return rules
async def ephemeral_by_room(
@@ -449,7 +437,7 @@ class SyncHandler:
room_ids=room_ids,
is_guest=sync_config.is_guest,
)
- now_token = now_token.copy_and_replace("typing_key", typing_key)
+ now_token = now_token.copy_and_replace(StreamKeyType.TYPING, typing_key)
ephemeral_by_room: JsonDict = {}
@@ -471,7 +459,7 @@ class SyncHandler:
room_ids=room_ids,
is_guest=sync_config.is_guest,
)
- now_token = now_token.copy_and_replace("receipt_key", receipt_key)
+ now_token = now_token.copy_and_replace(StreamKeyType.RECEIPT, receipt_key)
for event in receipts:
room_id = event["room_id"]
@@ -518,13 +506,15 @@ class SyncHandler:
# ensure that we always include current state in the timeline
current_state_ids: FrozenSet[str] = frozenset()
if any(e.is_state() for e in recents):
- current_state_ids_map = await self.store.get_current_state_ids(
- room_id
+ current_state_ids_map = (
+ await self._state_storage_controller.get_current_state_ids(
+ room_id
+ )
)
current_state_ids = frozenset(current_state_ids_map.values())
recents = await filter_events_for_client(
- self.storage,
+ self._storage_controllers,
sync_config.user.to_string(),
recents,
always_include_ids=current_state_ids,
@@ -537,7 +527,9 @@ class SyncHandler:
prev_batch_token = now_token
if recents:
room_key = recents[0].internal_metadata.before
- prev_batch_token = now_token.copy_and_replace("room_key", room_key)
+ prev_batch_token = now_token.copy_and_replace(
+ StreamKeyType.ROOM, room_key
+ )
return TimelineBatch(
events=recents, prev_batch=prev_batch_token, limited=False
@@ -584,13 +576,16 @@ class SyncHandler:
# ensure that we always include current state in the timeline
current_state_ids = frozenset()
if any(e.is_state() for e in loaded_recents):
- current_state_ids_map = await self.store.get_current_state_ids(
- room_id
+ # FIXME(faster_joins): We use the partial state here as
+ # we don't want to block `/sync` on finishing a lazy join.
+ # Is this the correct way of doing it?
+ current_state_ids_map = (
+ await self.store.get_partial_current_state_ids(room_id)
)
current_state_ids = frozenset(current_state_ids_map.values())
loaded_recents = await filter_events_for_client(
- self.storage,
+ self._storage_controllers,
sync_config.user.to_string(),
loaded_recents,
always_include_ids=current_state_ids,
@@ -611,7 +606,7 @@ class SyncHandler:
recents = recents[-timeline_limit:]
room_key = recents[0].internal_metadata.before
- prev_batch_token = now_token.copy_and_replace("room_key", room_key)
+ prev_batch_token = now_token.copy_and_replace(StreamKeyType.ROOM, room_key)
# Don't bother to bundle aggregations if the timeline is unlimited,
# as clients will have all the necessary information.
@@ -631,21 +626,32 @@ class SyncHandler:
)
async def get_state_after_event(
- self, event: EventBase, state_filter: Optional[StateFilter] = None
+ self, event_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
"""
Get the room state after the given event
Args:
- event: event of interest
+ event_id: event of interest
state_filter: The state filter used to fetch state from the database.
"""
- state_ids = await self.state_store.get_state_ids_for_event(
- event.event_id, state_filter=state_filter or StateFilter.all()
+ state_ids = await self._state_storage_controller.get_state_ids_for_event(
+ event_id, state_filter=state_filter or StateFilter.all()
)
- if event.is_state():
+
+ # using get_metadata_for_events here (instead of get_event) sidesteps an issue
+ # with redactions: if `event_id` is a redaction event, and we don't have the
+ # original (possibly because it got purged), get_event will refuse to return
+ # the redaction event, which isn't terribly helpful here.
+ #
+ # (To be fair, in that case we could assume it's *not* a state event, and
+ # therefore we don't need to worry about it. But still, it seems cleaner just
+ # to pull the metadata.)
+ m = (await self.store.get_metadata_for_events([event_id]))[event_id]
+ if m.state_key is not None and m.rejection_reason is None:
state_ids = dict(state_ids)
- state_ids[(event.type, event.state_key)] = event.event_id
+ state_ids[(m.event_type, m.state_key)] = event_id
+
return state_ids
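The metadata branch above splices the event itself into its own after-state when it is an accepted state event. A stubbed sketch of just that splice; EventMetadata here is an illustrative stand-in for what get_metadata_for_events returns.

from dataclasses import dataclass
from typing import Dict, Optional, Tuple

@dataclass
class EventMetadata:
    event_type: str
    state_key: Optional[str]
    rejection_reason: Optional[str]

def splice_event_into_state(
    state_ids: Dict[Tuple[str, str], str], event_id: str, m: EventMetadata
) -> Dict[Tuple[str, str], str]:
    # Accepted state events are part of the state *after* themselves.
    if m.state_key is not None and m.rejection_reason is None:
        state_ids = dict(state_ids)  # don't mutate the (possibly cached) input
        state_ids[(m.event_type, m.state_key)] = event_id
    return state_ids

m = EventMetadata("m.room.name", "", None)
after = splice_event_into_state({}, "$name_event", m)
assert after == {("m.room.name", ""): "$name_event"}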
async def get_state_at(
@@ -664,14 +670,14 @@ class SyncHandler:
# FIXME: This gets the state at the latest event before the stream ordering,
# which might not be the same as the "current state" of the room at the time
# of the stream token if there were multiple forward extremities at the time.
- last_event = await self.store.get_last_event_in_room_before_stream_ordering(
+ last_event_id = await self.store.get_last_event_in_room_before_stream_ordering(
room_id,
end_token=stream_position.room_key,
)
- if last_event:
+ if last_event_id:
state = await self.get_state_after_event(
- last_event, state_filter=state_filter or StateFilter.all()
+ last_event_id, state_filter=state_filter or StateFilter.all()
)
else:
@@ -720,7 +726,7 @@ class SyncHandler:
return None
last_event = last_events[-1]
- state_ids = await self.state_store.get_state_ids_for_event(
+ state_ids = await self._state_storage_controller.get_state_ids_for_event(
last_event.event_id,
state_filter=StateFilter.from_types(
[(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
@@ -898,12 +904,16 @@ class SyncHandler:
if full_state:
if batch:
- current_state_ids = await self.state_store.get_state_ids_for_event(
- batch.events[-1].event_id, state_filter=state_filter
+ current_state_ids = (
+ await self._state_storage_controller.get_state_ids_for_event(
+ batch.events[-1].event_id, state_filter=state_filter
+ )
)
- state_ids = await self.state_store.get_state_ids_for_event(
- batch.events[0].event_id, state_filter=state_filter
+ state_ids = (
+ await self._state_storage_controller.get_state_ids_for_event(
+ batch.events[0].event_id, state_filter=state_filter
+ )
)
else:
@@ -923,7 +933,7 @@ class SyncHandler:
elif batch.limited:
if batch:
state_at_timeline_start = (
- await self.state_store.get_state_ids_for_event(
+ await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter
)
)
@@ -957,8 +967,10 @@ class SyncHandler:
)
if batch:
- current_state_ids = await self.state_store.get_state_ids_for_event(
- batch.events[-1].event_id, state_filter=state_filter
+ current_state_ids = (
+ await self._state_storage_controller.get_state_ids_for_event(
+ batch.events[-1].event_id, state_filter=state_filter
+ )
)
else:
# Its not clear how we get here, but empirically we do
@@ -988,7 +1000,7 @@ class SyncHandler:
# So we fish out all the member events corresponding to the
# timeline here, and then dedupe any redundant ones below.
- state_ids = await self.state_store.get_state_ids_for_event(
+ state_ids = await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id,
# we only want members!
state_filter=StateFilter.from_types(
@@ -1154,10 +1166,6 @@ class SyncHandler:
await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
)
- if self.hs_config.experimental.groups_enabled:
- logger.debug("Fetching group data")
- await self._generate_sync_entry_for_groups(sync_result_builder)
-
num_events = 0
# debug for https://github.com/matrix-org/synapse/issues/9424
@@ -1181,57 +1189,11 @@ class SyncHandler:
archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device,
device_lists=device_lists,
- groups=sync_result_builder.groups,
device_one_time_keys_count=one_time_key_counts,
device_unused_fallback_key_types=unused_fallback_key_types,
next_batch=sync_result_builder.now_token,
)
- @measure_func("_generate_sync_entry_for_groups")
- async def _generate_sync_entry_for_groups(
- self, sync_result_builder: "SyncResultBuilder"
- ) -> None:
- user_id = sync_result_builder.sync_config.user.to_string()
- since_token = sync_result_builder.since_token
- now_token = sync_result_builder.now_token
-
- if since_token and since_token.groups_key:
- results = await self.store.get_groups_changes_for_user(
- user_id, since_token.groups_key, now_token.groups_key
- )
- else:
- results = await self.store.get_all_groups_for_user(
- user_id, now_token.groups_key
- )
-
- invited = {}
- joined = {}
- left = {}
- for result in results:
- membership = result["membership"]
- group_id = result["group_id"]
- gtype = result["type"]
- content = result["content"]
-
- if membership == "join":
- if gtype == "membership":
- # TODO: Add profile
- content.pop("membership", None)
- joined[group_id] = content["content"]
- else:
- joined.setdefault(group_id, {})[gtype] = content
- elif membership == "invite":
- if gtype == "membership":
- content.pop("membership", None)
- invited[group_id] = content["content"]
- else:
- if gtype == "membership":
- left[group_id] = content["content"]
-
- sync_result_builder.groups = GroupsSyncResult(
- join=joined, invite=invited, leave=left
- )
-
@measure_func("_generate_sync_entry_for_device_list")
async def _generate_sync_entry_for_device_list(
self,
@@ -1398,7 +1360,7 @@ class SyncHandler:
now_token.to_device_key,
)
sync_result_builder.now_token = now_token.copy_and_replace(
- "to_device_key", stream_id
+ StreamKeyType.TO_DEVICE, stream_id
)
sync_result_builder.to_device = messages
else:
@@ -1503,7 +1465,7 @@ class SyncHandler:
)
assert presence_key
sync_result_builder.now_token = now_token.copy_and_replace(
- "presence_key", presence_key
+ StreamKeyType.PRESENCE, presence_key
)
extra_users_ids = set(newly_joined_or_invited_users)
@@ -1826,7 +1788,7 @@ class SyncHandler:
# stream token as it'll only be used in the context of this
# room. (c.f. the docstring of `to_room_stream_token`).
leave_token = since_token.copy_and_replace(
- "room_key", leave_position.to_room_stream_token()
+ StreamKeyType.ROOM, leave_position.to_room_stream_token()
)
# If this is an out of band message, like a remote invite
@@ -1875,7 +1837,9 @@ class SyncHandler:
if room_entry:
events, start_key = room_entry
- prev_batch_token = now_token.copy_and_replace("room_key", start_key)
+ prev_batch_token = now_token.copy_and_replace(
+ StreamKeyType.ROOM, start_key
+ )
entry = RoomSyncResultBuilder(
room_id=room_id,
@@ -1972,7 +1936,7 @@ class SyncHandler:
continue
leave_token = now_token.copy_and_replace(
- "room_key", RoomStreamToken(None, event.stream_ordering)
+ StreamKeyType.ROOM, RoomStreamToken(None, event.stream_ordering)
)
room_entries.append(
RoomSyncResultBuilder(
@@ -2328,7 +2292,6 @@ class SyncResultBuilder:
invited
knocked
archived
- groups
to_device
"""
@@ -2344,7 +2307,6 @@ class SyncResultBuilder:
invited: List[InvitedSyncResult] = attr.Factory(list)
knocked: List[KnockedSyncResult] = attr.Factory(list)
archived: List[ArchivedSyncResult] = attr.Factory(list)
- groups: Optional[GroupsSyncResult] = None
to_device: List[JsonDict] = attr.Factory(list)
def calculate_user_changes(self) -> Tuple[Set[str], Set[str]]:
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 6854428b..d104ea07 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
import attr
+from synapse.api.constants import EduTypes
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
from synapse.appservice import ApplicationService
from synapse.metrics.background_process_metrics import (
@@ -25,7 +26,7 @@ from synapse.metrics.background_process_metrics import (
)
from synapse.replication.tcp.streams import TypingStream
from synapse.streams import EventSource
-from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
@@ -58,6 +59,7 @@ class FollowerTypingHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self.server_name = hs.config.server.server_name
self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id
@@ -68,7 +70,7 @@ class FollowerTypingHandler:
if hs.get_instance_name() not in hs.config.worker.writers.typing:
hs.get_federation_registry().register_instances_for_edu(
- "m.typing",
+ EduTypes.TYPING,
hs.config.worker.writers.typing,
)
@@ -130,7 +132,6 @@ class FollowerTypingHandler:
return
try:
- users = await self.store.get_users_in_room(member.room_id)
self._member_last_federation_poke[member] = self.clock.time_msec()
now = self.clock.time_msec()
@@ -138,12 +139,15 @@ class FollowerTypingHandler:
now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
)
- for domain in {get_domain_from_id(u) for u in users}:
+ hosts = await self._storage_controllers.state.get_current_hosts_in_room(
+ member.room_id
+ )
+ for domain in hosts:
if domain != self.server_name:
logger.debug("sending typing update to %s", domain)
self.federation.build_and_send_edu(
destination=domain,
- edu_type="m.typing",
+ edu_type=EduTypes.TYPING,
content={
"room_id": member.room_id,
"user_id": member.user_id,
@@ -218,7 +222,9 @@ class TypingWriterHandler(FollowerTypingHandler):
self.hs = hs
- hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
+ hs.get_federation_registry().register_edu_handler(
+ EduTypes.TYPING, self._recv_edu
+ )
hs.get_distributor().observe("user_left_room", self.user_left_room)
@@ -382,7 +388,7 @@ class TypingWriterHandler(FollowerTypingHandler):
)
self.notifier.on_new_event(
- "typing_key", self._latest_room_serial, rooms=[member.room_id]
+ StreamKeyType.TYPING, self._latest_room_serial, rooms=[member.room_id]
)
async def get_all_typing_updates(
@@ -458,7 +464,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
def _make_event_for(self, room_id: str) -> JsonDict:
typing = self.get_typing_handler()._room_typing[room_id]
return {
- "type": "m.typing",
+ "type": EduTypes.TYPING,
"room_id": room_id,
"content": {"user_ids": list(typing)},
}
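The typing hunks replace the literal "m.typing" with EduTypes.TYPING. A hedged sketch of such a constants holder; the real class lives in synapse.api.constants and this member list is only a subset.

from typing import Final

class EduTypes:
    # Wire identifiers for ephemeral data units (EDUs).
    TYPING: Final = "m.typing"
    PRESENCE: Final = "m.presence"
    RECEIPT: Final = "m.receipt"
    DEVICE_LIST_UPDATE: Final = "m.device_list_update"

assert EduTypes.TYPING == "m.typing"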
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 74f7fdfe..8c3c52e1 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -56,6 +56,7 @@ class UserDirectoryHandler(StateDeltasHandler):
super().__init__(hs)
self.store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self.server_name = hs.hostname
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
@@ -174,7 +175,10 @@ class UserDirectoryHandler(StateDeltasHandler):
logger.debug(
"Processing user stats %s->%s", self.pos, room_max_stream_ordering
)
- max_pos, deltas = await self.store.get_current_state_deltas(
+ (
+ max_pos,
+ deltas,
+ ) = await self._storage_controllers.state.get_current_state_deltas(
self.pos, room_max_stream_ordering
)