Diffstat (limited to 'synapse')
-rw-r--r-- synapse/__init__.py | 2
-rw-r--r-- synapse/api/auth.py | 4
-rw-r--r-- synapse/api/auth_blocking.py | 18
-rw-r--r-- synapse/api/constants.py | 13
-rw-r--r-- synapse/api/room_versions.py | 22
-rw-r--r-- synapse/api/urls.py | 4
-rw-r--r-- synapse/app/_base.py | 10
-rw-r--r-- synapse/app/admin_cmd.py | 8
-rw-r--r-- synapse/app/generic_worker.py | 10
-rw-r--r-- synapse/app/homeserver.py | 14
-rw-r--r-- synapse/app/phone_stats_home.py | 8
-rw-r--r-- synapse/config/_base.py | 9
-rw-r--r-- synapse/config/consent.py | 9
-rw-r--r-- synapse/config/logger.py | 4
-rw-r--r-- synapse/config/password_auth_providers.py | 4
-rw-r--r-- synapse/config/server.py | 2
-rw-r--r-- synapse/config/user_directory.py | 14
-rw-r--r-- synapse/crypto/context_factory.py | 12
-rw-r--r-- synapse/crypto/keyring.py | 2
-rw-r--r-- synapse/event_auth.py | 20
-rw-r--r-- synapse/events/__init__.py | 34
-rw-r--r-- synapse/events/snapshot.py | 14
-rw-r--r-- synapse/events/spamcheck.py | 44
-rw-r--r-- synapse/events/third_party_rules.py | 4
-rw-r--r-- synapse/events/utils.py | 8
-rw-r--r-- synapse/federation/federation_base.py | 6
-rw-r--r-- synapse/federation/federation_client.py | 8
-rw-r--r-- synapse/federation/federation_server.py | 8
-rw-r--r-- synapse/federation/sender/__init__.py | 2
-rw-r--r-- synapse/federation/sender/per_destination_queue.py | 4
-rw-r--r-- synapse/federation/transport/server/_base.py | 4
-rw-r--r-- synapse/groups/groups_server.py | 6
-rw-r--r-- synapse/handlers/_base.py | 14
-rw-r--r-- synapse/handlers/account_data.py | 13
-rw-r--r-- synapse/handlers/account_validity.py | 6
-rw-r--r-- synapse/handlers/appservice.py | 26
-rw-r--r-- synapse/handlers/auth.py | 81
-rw-r--r-- synapse/handlers/cas.py | 26
-rw-r--r-- synapse/handlers/deactivate_account.py | 12
-rw-r--r-- synapse/handlers/device.py | 2
-rw-r--r-- synapse/handlers/directory.py | 6
-rw-r--r-- synapse/handlers/e2e_keys.py | 4
-rw-r--r-- synapse/handlers/event_auth.py | 7
-rw-r--r-- synapse/handlers/federation.py | 159
-rw-r--r-- synapse/handlers/federation_event.py | 259
-rw-r--r-- synapse/handlers/groups_local.py | 8
-rw-r--r-- synapse/handlers/identity.py | 12
-rw-r--r-- synapse/handlers/initial_sync.py | 12
-rw-r--r-- synapse/handlers/message.py | 51
-rw-r--r-- synapse/handlers/oidc.py | 36
-rw-r--r-- synapse/handlers/pagination.py | 19
-rw-r--r-- synapse/handlers/password_policy.py | 4
-rw-r--r-- synapse/handlers/presence.py | 67
-rw-r--r-- synapse/handlers/profile.py | 24
-rw-r--r-- synapse/handlers/receipts.py | 17
-rw-r--r-- synapse/handlers/register.py | 24
-rw-r--r-- synapse/handlers/room.py | 58
-rw-r--r-- synapse/handlers/room_list.py | 14
-rw-r--r-- synapse/handlers/room_member.py | 22
-rw-r--r-- synapse/handlers/room_summary.py | 4
-rw-r--r-- synapse/handlers/saml.py | 29
-rw-r--r-- synapse/handlers/send_email.py | 4
-rw-r--r-- synapse/handlers/sso.py | 16
-rw-r--r-- synapse/handlers/stats.py | 4
-rw-r--r-- synapse/handlers/sync.py | 19
-rw-r--r-- synapse/handlers/typing.py | 15
-rw-r--r-- synapse/handlers/ui_auth/checkers.py | 19
-rw-r--r-- synapse/handlers/user_directory.py | 4
-rw-r--r-- synapse/http/client.py | 7
-rw-r--r-- synapse/http/matrixfederationclient.py | 24
-rw-r--r-- synapse/http/server.py | 90
-rw-r--r-- synapse/http/site.py | 75
-rw-r--r-- synapse/logging/opentracing.py | 6
-rw-r--r-- synapse/module_api/__init__.py | 92
-rw-r--r-- synapse/notifier.py | 2
-rw-r--r-- synapse/push/emailpusher.py | 2
-rw-r--r-- synapse/push/httppusher.py | 6
-rw-r--r-- synapse/push/mailer.py | 10
-rw-r--r-- synapse/push/pusher.py | 10
-rw-r--r-- synapse/push/pusherpool.py | 2
-rw-r--r-- synapse/replication/http/_base.py | 4
-rw-r--r-- synapse/replication/tcp/handler.py | 4
-rw-r--r-- synapse/rest/__init__.py | 11
-rw-r--r-- synapse/rest/admin/__init__.py | 2
-rw-r--r-- synapse/rest/admin/devices.py | 2
-rw-r--r-- synapse/rest/admin/registration_tokens.py | 2
-rw-r--r-- synapse/rest/admin/rooms.py | 6
-rw-r--r-- synapse/rest/admin/server_notice_servlet.py | 2
-rw-r--r-- synapse/rest/admin/users.py | 6
-rw-r--r-- synapse/rest/client/account.py | 40
-rw-r--r-- synapse/rest/client/auth.py | 10
-rw-r--r-- synapse/rest/client/devices.py | 4
-rw-r--r-- synapse/rest/client/login.py | 18
-rw-r--r-- synapse/rest/client/password_policy.py | 8
-rw-r--r-- synapse/rest/client/register.py | 30
-rw-r--r-- synapse/rest/client/room_batch.py | 130
-rw-r--r-- synapse/rest/client/user_directory.py | 2
-rw-r--r-- synapse/rest/client/versions.py | 6
-rw-r--r-- synapse/rest/client/voip.py | 12
-rw-r--r-- synapse/rest/consent/consent_resource.py | 52
-rw-r--r-- synapse/rest/health.py | 3
-rw-r--r-- synapse/rest/key/v2/__init__.py | 7
-rw-r--r-- synapse/rest/key/v2/local_key_resource.py | 25
-rw-r--r-- synapse/rest/key/v2/remote_key_resource.py | 35
-rw-r--r-- synapse/rest/media/v1/_base.py | 115
-rw-r--r-- synapse/rest/media/v1/config_resource.py | 6
-rw-r--r-- synapse/rest/media/v1/download_resource.py | 5
-rw-r--r-- synapse/rest/media/v1/filepath.py | 17
-rw-r--r-- synapse/rest/media/v1/media_repository.py | 82
-rw-r--r-- synapse/rest/media/v1/media_storage.py | 74
-rw-r--r-- synapse/rest/media/v1/oembed.py | 188
-rw-r--r-- synapse/rest/media/v1/preview_url_resource.py | 345
-rw-r--r-- synapse/rest/media/v1/storage_provider.py | 19
-rw-r--r-- synapse/rest/media/v1/thumbnail_resource.py | 94
-rw-r--r-- synapse/rest/media/v1/thumbnailer.py | 2
-rw-r--r-- synapse/rest/media/v1/upload_resource.py | 6
-rw-r--r-- synapse/rest/synapse/client/__init__.py | 4
-rw-r--r-- synapse/rest/synapse/client/new_user_consent.py | 6
-rw-r--r-- synapse/rest/synapse/client/oidc/__init__.py | 6
-rw-r--r-- synapse/rest/synapse/client/oidc/callback_resource.py | 5
-rw-r--r-- synapse/rest/synapse/client/password_reset.py | 10
-rw-r--r-- synapse/rest/synapse/client/pick_username.py | 9
-rw-r--r-- synapse/rest/synapse/client/saml2/__init__.py | 6
-rw-r--r-- synapse/rest/synapse/client/saml2/metadata_resource.py | 11
-rw-r--r-- synapse/rest/synapse/client/saml2/response_resource.py | 7
-rw-r--r-- synapse/rest/well_known.py | 20
-rw-r--r-- synapse/server.py | 16
-rw-r--r-- synapse/server_notices/consent_server_notices.py | 11
-rw-r--r-- synapse/server_notices/server_notices_manager.py | 23
-rw-r--r-- synapse/state/__init__.py | 34
-rw-r--r-- synapse/storage/database.py | 21
-rw-r--r-- synapse/storage/databases/main/__init__.py | 2
-rw-r--r-- synapse/storage/databases/main/account_data.py | 4
-rw-r--r-- synapse/storage/databases/main/appservice.py | 2
-rw-r--r-- synapse/storage/databases/main/client_ips.py | 31
-rw-r--r-- synapse/storage/databases/main/deviceinbox.py | 6
-rw-r--r-- synapse/storage/databases/main/end_to_end_keys.py | 4
-rw-r--r-- synapse/storage/databases/main/event_federation.py | 30
-rw-r--r-- synapse/storage/databases/main/events.py | 48
-rw-r--r-- synapse/storage/databases/main/events_bg_updates.py | 4
-rw-r--r-- synapse/storage/databases/main/events_worker.py | 2
-rw-r--r-- synapse/storage/databases/main/monthly_active_users.py | 2
-rw-r--r-- synapse/storage/databases/main/purge_events.py | 22
-rw-r--r-- synapse/storage/databases/main/pusher.py | 4
-rw-r--r-- synapse/storage/databases/main/receipts.py | 6
-rw-r--r-- synapse/storage/databases/main/registration.py | 4
-rw-r--r-- synapse/storage/databases/main/room_batch.py | 36
-rw-r--r-- synapse/storage/databases/main/roommember.py | 13
-rw-r--r-- synapse/storage/databases/main/search.py | 34
-rw-r--r-- synapse/storage/databases/main/state.py | 4
-rw-r--r-- synapse/storage/databases/main/state_deltas.py | 2
-rw-r--r-- synapse/storage/databases/main/stats.py | 2
-rw-r--r-- synapse/storage/databases/main/stream.py | 8
-rw-r--r-- synapse/storage/databases/main/ui_auth.py | 6
-rw-r--r-- synapse/storage/databases/main/user_directory.py | 155
-rw-r--r-- synapse/storage/databases/state/bg_updates.py | 60
-rw-r--r-- synapse/storage/databases/state/store.py | 142
-rw-r--r-- synapse/storage/prepare_database.py | 2
-rw-r--r-- synapse/storage/relations.py | 4
-rw-r--r-- synapse/storage/schema/__init__.py | 13
-rw-r--r-- synapse/storage/schema/main/delta/30/as_users.py | 2
-rw-r--r-- synapse/storage/schema/main/delta/64/01msc2716_chunk_to_batch_rename.sql.postgres | 23
-rw-r--r-- synapse/storage/schema/main/delta/64/01msc2716_chunk_to_batch_rename.sql.sqlite | 37
-rw-r--r-- synapse/storage/state.py | 48
-rw-r--r-- synapse/streams/__init__.py | 22
-rw-r--r-- synapse/streams/config.py | 2
-rw-r--r-- synapse/streams/events.py | 49
-rw-r--r-- synapse/types.py | 30
-rw-r--r-- synapse/util/caches/__init__.py | 31
-rw-r--r-- synapse/util/caches/deferred_cache.py | 2
-rw-r--r-- synapse/util/caches/descriptors.py | 5
-rw-r--r-- synapse/util/caches/dictionary_cache.py | 4
-rw-r--r-- synapse/util/caches/expiringcache.py | 10
-rw-r--r-- synapse/util/caches/lrucache.py | 20
-rw-r--r-- synapse/util/iterutils.py | 19
175 files changed, 2503 insertions(+), 1770 deletions(-)
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 5f5cff1d..b8979c36 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
except ImportError:
pass
-__version__ = "1.43.0"
+__version__ = "1.44.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 05699714..e6ca9232 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -70,8 +70,8 @@ class Auth:
self._auth_blocking = AuthBlocking(self.hs)
- self._track_appservice_user_ips = hs.config.track_appservice_user_ips
- self._macaroon_secret_key = hs.config.macaroon_secret_key
+ self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
+ self._macaroon_secret_key = hs.config.key.macaroon_secret_key
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
async def check_user_in_room(
diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py
index e6bced93..08fe160c 100644
--- a/synapse/api/auth_blocking.py
+++ b/synapse/api/auth_blocking.py
@@ -30,13 +30,15 @@ class AuthBlocking:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
- self._server_notices_mxid = hs.config.server_notices_mxid
- self._hs_disabled = hs.config.hs_disabled
- self._hs_disabled_message = hs.config.hs_disabled_message
- self._admin_contact = hs.config.admin_contact
- self._max_mau_value = hs.config.max_mau_value
- self._limit_usage_by_mau = hs.config.limit_usage_by_mau
- self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
+ self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
+ self._hs_disabled = hs.config.server.hs_disabled
+ self._hs_disabled_message = hs.config.server.hs_disabled_message
+ self._admin_contact = hs.config.server.admin_contact
+ self._max_mau_value = hs.config.server.max_mau_value
+ self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau
+ self._mau_limits_reserved_threepids = (
+ hs.config.server.mau_limits_reserved_threepids
+ )
self._server_name = hs.hostname
self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
@@ -79,7 +81,7 @@ class AuthBlocking:
# We never block the server from doing actions on behalf of
# users.
return
- elif requester.app_service and not self._track_appservice_user_ips:
+ if requester.app_service and not self._track_appservice_user_ips:
# If we're authenticated as an appservice then we only block
# auth if `track_appservice_user_ips` is set, as that option
# implicitly means that application services are part of MAU
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 236f0c7f..a31f0377 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -121,7 +121,7 @@ class EventTypes:
SpaceParent = "m.space.parent"
MSC2716_INSERTION = "org.matrix.msc2716.insertion"
- MSC2716_CHUNK = "org.matrix.msc2716.chunk"
+ MSC2716_BATCH = "org.matrix.msc2716.batch"
MSC2716_MARKER = "org.matrix.msc2716.marker"
@@ -209,14 +209,17 @@ class EventContentFields:
# Used on normal messages to indicate they were historically imported after the fact
MSC2716_HISTORICAL = "org.matrix.msc2716.historical"
- # For "insertion" events to indicate what the next chunk ID should be in
+ # For "insertion" events to indicate what the next batch ID should be in
# order to connect to it
- MSC2716_NEXT_CHUNK_ID = "org.matrix.msc2716.next_chunk_id"
- # Used on "chunk" events to indicate which insertion event it connects to
- MSC2716_CHUNK_ID = "org.matrix.msc2716.chunk_id"
+ MSC2716_NEXT_BATCH_ID = "org.matrix.msc2716.next_batch_id"
+ # Used on "batch" events to indicate which insertion event it connects to
+ MSC2716_BATCH_ID = "org.matrix.msc2716.batch_id"
# For "marker" events
MSC2716_MARKER_INSERTION = "org.matrix.msc2716.marker.insertion"
+ # The authorising user for joining a restricted room.
+ AUTHORISING_USER = "join_authorised_via_users_server"
+
class RoomTypes:
"""Understood values of the room_type field of m.room.create events."""
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index 61d9c658..0a895bba 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -244,24 +244,8 @@ class RoomVersions:
msc2716_historical=False,
msc2716_redactions=False,
)
- MSC2716 = RoomVersion(
- "org.matrix.msc2716",
- RoomDisposition.UNSTABLE,
- EventFormatVersions.V3,
- StateResolutionVersions.V2,
- enforce_key_validity=True,
- special_case_aliases_auth=False,
- strict_canonicaljson=True,
- limit_notifications_power_levels=True,
- msc2176_redaction_rules=False,
- msc3083_join_rules=False,
- msc3375_redaction_rules=False,
- msc2403_knocking=True,
- msc2716_historical=True,
- msc2716_redactions=False,
- )
- MSC2716v2 = RoomVersion(
- "org.matrix.msc2716v2",
+ MSC2716v3 = RoomVersion(
+ "org.matrix.msc2716v3",
RoomDisposition.UNSTABLE,
EventFormatVersions.V3,
StateResolutionVersions.V2,
@@ -289,9 +273,9 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
RoomVersions.V6,
RoomVersions.MSC2176,
RoomVersions.V7,
- RoomVersions.MSC2716,
RoomVersions.V8,
RoomVersions.V9,
+ RoomVersions.MSC2716v3,
)
}
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index d3270cd6..032c69b2 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -39,12 +39,12 @@ class ConsentURIBuilder:
Args:
hs_config (synapse.config.homeserver.HomeServerConfig):
"""
- if hs_config.form_secret is None:
+ if hs_config.key.form_secret is None:
raise ConfigError("form_secret not set in config")
if hs_config.server.public_baseurl is None:
raise ConfigError("public_baseurl not set in config")
- self._hmac_secret = hs_config.form_secret.encode("utf-8")
+ self._hmac_secret = hs_config.key.form_secret.encode("utf-8")
self._public_baseurl = hs_config.server.public_baseurl
def build_user_consent_uri(self, user_id):
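A plausible sketch of what `build_user_consent_uri` produces, inferred from the fields initialised above (the exact query layout lives in the method body, which this hunk does not show):

```python
import hmac
from hashlib import sha256
from urllib.parse import urlencode

def build_user_consent_uri(public_baseurl: str, hmac_secret: bytes, user_id: str) -> str:
    # Sign the user ID with form_secret so the consent page can verify
    # the link was generated by this homeserver.
    mac = hmac.new(hmac_secret, user_id.encode("ascii"), sha256).hexdigest()
    return "%s_matrix/consent?%s" % (public_baseurl, urlencode({"u": user_id, "h": mac}))
```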
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index d1aa2e7f..548f6dcd 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -88,8 +88,8 @@ def start_worker_reactor(appname, config, run_command=reactor.run):
appname,
soft_file_limit=config.soft_file_limit,
gc_thresholds=config.gc_thresholds,
- pid_file=config.worker_pid_file,
- daemonize=config.worker_daemonize,
+ pid_file=config.worker.worker_pid_file,
+ daemonize=config.worker.worker_daemonize,
print_pidfile=config.print_pidfile,
logger=logger,
run_command=run_command,
@@ -424,12 +424,14 @@ def setup_sentry(hs):
hs (synapse.server.HomeServer)
"""
- if not hs.config.sentry_enabled:
+ if not hs.config.metrics.sentry_enabled:
return
import sentry_sdk
- sentry_sdk.init(dsn=hs.config.sentry_dsn, release=get_version_string(synapse))
+ sentry_sdk.init(
+ dsn=hs.config.metrics.sentry_dsn, release=get_version_string(synapse)
+ )
# We set some default tags that give some context to this instance
with sentry_sdk.configure_scope() as scope:
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 5e956b1e..f2c5b752 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -186,13 +186,13 @@ def start(config_options):
config.worker.worker_app = "synapse.app.admin_cmd"
if (
- not config.worker_daemonize
- and not config.worker_log_file
- and not config.worker_log_config
+ not config.worker.worker_daemonize
+ and not config.worker.worker_log_file
+ and not config.worker.worker_log_config
):
# Since we're meant to be run as a "command" let's not redirect stdio
# unless we've actually set log config.
- config.no_redirect_stdio = True
+ config.logging.no_redirect_stdio = True
# Explicitly disable background processes
config.update_user_directory = False
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 33afd59c..3036e1b4 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -140,7 +140,7 @@ class KeyUploadServlet(RestServlet):
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.http_client = hs.get_simple_http_client()
- self.main_uri = hs.config.worker_main_http_uri
+ self.main_uri = hs.config.worker.worker_main_http_uri
async def on_POST(self, request: Request, device_id: Optional[str]):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
@@ -321,7 +321,7 @@ class GenericWorkerServer(HomeServer):
elif name == "federation":
resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
elif name == "media":
- if self.config.can_load_media_repo:
+ if self.config.media.can_load_media_repo:
media_repo = self.get_media_repository_resource()
# We need to serve the admin servlets for media on the
@@ -384,7 +384,7 @@ class GenericWorkerServer(HomeServer):
logger.info("Synapse worker now listening on port %d", port)
def start_listening(self):
- for listener in self.config.worker_listeners:
+ for listener in self.config.worker.worker_listeners:
if listener.type == "http":
self._listen_http(listener)
elif listener.type == "manhole":
@@ -395,7 +395,7 @@ class GenericWorkerServer(HomeServer):
manhole_globals={"hs": self},
)
elif listener.type == "metrics":
- if not self.config.enable_metrics:
+ if not self.config.metrics.enable_metrics:
logger.warning(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -488,7 +488,7 @@ def start(config_options):
register_start(_base.start, hs)
# redirect stdio to the logs, if configured.
- if not hs.config.no_redirect_stdio:
+ if not hs.config.logging.no_redirect_stdio:
redirect_stdio_to_logs()
_base.start_worker_reactor("synapse-generic-worker", config)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index b909f8db..205831dc 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -195,7 +195,7 @@ class SynapseHomeServer(HomeServer):
}
)
- if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
from synapse.rest.synapse.client.password_reset import (
PasswordResetSubmitTokenResource,
)
@@ -234,7 +234,7 @@ class SynapseHomeServer(HomeServer):
)
if name in ["media", "federation", "client"]:
- if self.config.enable_media_repo:
+ if self.config.media.enable_media_repo:
media_repo = self.get_media_repository_resource()
resources.update(
{MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo}
@@ -269,7 +269,7 @@ class SynapseHomeServer(HomeServer):
# https://twistedmatrix.com/trac/ticket/7678
resources[WEB_CLIENT_PREFIX] = File(webclient_loc)
- if name == "metrics" and self.config.enable_metrics:
+ if name == "metrics" and self.config.metrics.enable_metrics:
resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
if name == "replication":
@@ -278,7 +278,7 @@ class SynapseHomeServer(HomeServer):
return resources
def start_listening(self):
- if self.config.redis_enabled:
+ if self.config.redis.redis_enabled:
# If redis is enabled we connect via the replication command handler
# in the same way as the workers (since we're effectively a client
# rather than a server).
@@ -305,7 +305,7 @@ class SynapseHomeServer(HomeServer):
for s in services:
reactor.addSystemEventTrigger("before", "shutdown", s.stopListening)
elif listener.type == "metrics":
- if not self.config.enable_metrics:
+ if not self.config.metrics.enable_metrics:
logger.warning(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -366,7 +366,7 @@ def setup(config_options):
async def start():
# Load the OIDC provider metadatas, if OIDC is enabled.
- if hs.config.oidc_enabled:
+ if hs.config.oidc.oidc_enabled:
oidc = hs.get_oidc_handler()
# Loading the provider metadata also ensures the provider config is valid.
await oidc.load_metadata()
@@ -455,7 +455,7 @@ def main():
hs = setup(sys.argv[1:])
# redirect stdio to the logs, if configured.
- if not hs.config.no_redirect_stdio:
+ if not hs.config.logging.no_redirect_stdio:
redirect_stdio_to_logs()
run(hs)
diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
index 4a95da90..49e7a45e 100644
--- a/synapse/app/phone_stats_home.py
+++ b/synapse/app/phone_stats_home.py
@@ -131,10 +131,12 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
log_level = synapse_logger.getEffectiveLevel()
stats["log_level"] = logging.getLevelName(log_level)
- logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
+ logger.info(
+ "Reporting stats to %s: %s" % (hs.config.metrics.report_stats_endpoint, stats)
+ )
try:
await hs.get_proxied_http_client().put_json(
- hs.config.report_stats_endpoint, stats
+ hs.config.metrics.report_stats_endpoint, stats
)
except Exception as e:
logger.warning("Error reporting stats: %s", e)
@@ -188,7 +190,7 @@ def start_phone_stats_home(hs):
clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000)
# End of monthly active user settings
- if hs.config.report_stats:
+ if hs.config.metrics.report_stats:
logger.info("Scheduling stats reporting for 3 hour intervals")
clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000, hs, stats)
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 2cc24278..d974a1a2 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -200,11 +200,7 @@ class Config:
@classmethod
def ensure_directory(cls, dir_path):
dir_path = cls.abspath(dir_path)
- try:
- os.makedirs(dir_path)
- except OSError as e:
- if e.errno != errno.EEXIST:
- raise
+ os.makedirs(dir_path, exist_ok=True)
if not os.path.isdir(dir_path):
raise ConfigError("%s is not a directory" % (dir_path,))
return dir_path
@@ -693,8 +689,7 @@ class RootConfig:
open_private_ports=config_args.open_private_ports,
)
- if not path_exists(config_dir_path):
- os.makedirs(config_dir_path)
+ os.makedirs(config_dir_path, exist_ok=True)
with open(config_path, "w") as config_file:
config_file.write(config_str)
config_file.write("\n\n# vim:ft=yaml")
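Both hunks above swap the try/except-`EEXIST` dance for the modern stdlib idiom:

```python
import os

# Tolerates an existing directory, still raises for real failures (e.g. EACCES).
os.makedirs("/tmp/synapse-example", exist_ok=True)

# One behavioural nuance: unlike the old errno.EEXIST check, exist_ok=True
# still raises FileExistsError when the path exists but is a regular file,
# a case the following isdir() check in ensure_directory caught anyway.
```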
diff --git a/synapse/config/consent.py b/synapse/config/consent.py
index b05a9bd9..ecc43b08 100644
--- a/synapse/config/consent.py
+++ b/synapse/config/consent.py
@@ -13,6 +13,7 @@
# limitations under the License.
from os import path
+from typing import Optional
from synapse.config import ConfigError
@@ -78,8 +79,8 @@ class ConsentConfig(Config):
def __init__(self, *args):
super().__init__(*args)
- self.user_consent_version = None
- self.user_consent_template_dir = None
+ self.user_consent_version: Optional[str] = None
+ self.user_consent_template_dir: Optional[str] = None
self.user_consent_server_notice_content = None
self.user_consent_server_notice_to_guests = False
self.block_events_without_consent_error = None
@@ -94,7 +95,9 @@ class ConsentConfig(Config):
return
self.user_consent_version = str(consent_config["version"])
self.user_consent_template_dir = self.abspath(consent_config["template_dir"])
- if not path.isdir(self.user_consent_template_dir):
+ if not isinstance(self.user_consent_template_dir, str) or not path.isdir(
+ self.user_consent_template_dir
+ ):
raise ConfigError(
"Could not find template directory '%s'"
% (self.user_consent_template_dir,)
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index aca9d467..0a08231e 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -322,7 +322,9 @@ def setup_logging(
"""
log_config_path = (
- config.worker_log_config if use_worker_options else config.log_config
+ config.worker.worker_log_config
+ if use_worker_options
+ else config.logging.log_config
)
# Perform one-time logging configuration.
diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index 0f5b2b39..83994df7 100644
--- a/synapse/config/password_auth_providers.py
+++ b/synapse/config/password_auth_providers.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, List
+from typing import Any, List, Tuple, Type
from synapse.util.module_loader import load_module
@@ -25,7 +25,7 @@ class PasswordAuthProviderConfig(Config):
section = "authproviders"
def read_config(self, config, **kwargs):
- self.password_providers: List[Any] = []
+ self.password_providers: List[Tuple[Type, Any]] = []
providers = []
# We want to be backwards compatible with the old `ldap_config`
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 7b9109a5..ad8715da 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -1447,7 +1447,7 @@ def read_gc_thresholds(thresholds):
return None
try:
assert len(thresholds) == 3
- return (int(thresholds[0]), int(thresholds[1]), int(thresholds[2]))
+ return int(thresholds[0]), int(thresholds[1]), int(thresholds[2])
except Exception:
raise ConfigError(
"Value of `gc_threshold` must be a list of three integers if set"
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index b10df8a2..2552f688 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -45,12 +45,16 @@ class UserDirectoryConfig(Config):
#enabled: false
# Defines whether to search all users visible to your HS when searching
- # the user directory, rather than limiting to users visible in public
- # rooms. Defaults to false.
+ # the user directory. If false, search results will only contain users
+ # visible in public rooms and users sharing a room with the requester.
+ # Defaults to false.
#
- # If you set it true, you'll have to rebuild the user_directory search
- # indexes, see:
- # https://matrix-org.github.io/synapse/latest/user_directory.html
+ # NB. If you set this to true, and the last time the user_directory search
+ # indexes were (re)built was before Synapse 1.44, you'll have to
+ # rebuild the indexes in order to search through all known users.
+ # These indexes are built the first time Synapse starts; admins can
+ # manually trigger a rebuild following the instructions at
+ # https://matrix-org.github.io/synapse/latest/user_directory.html
#
# Uncomment to return search results containing all known users, even if that
# user does not share a room with the requester.
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index c644b4df..2a6110eb 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -74,8 +74,8 @@ class ServerContextFactory(ContextFactory):
context.set_options(
SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL.OP_NO_TLSv1 | SSL.OP_NO_TLSv1_1
)
- context.use_certificate_chain_file(config.tls_certificate_file)
- context.use_privatekey(config.tls_private_key)
+ context.use_certificate_chain_file(config.tls.tls_certificate_file)
+ context.use_privatekey(config.tls.tls_private_key)
# https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/
context.set_cipher_list(
@@ -102,7 +102,7 @@ class FederationPolicyForHTTPS:
self._config = config
# Check if we're using a custom list of a CA certificates
- trust_root = config.federation_ca_trust_root
+ trust_root = config.tls.federation_ca_trust_root
if trust_root is None:
# Use CA root certs provided by OpenSSL
trust_root = platformTrust()
@@ -113,7 +113,7 @@ class FederationPolicyForHTTPS:
# moving to TLS 1.2 by default, we want to respect the config option if
# it is set to 1.0 (which the alternate option, raiseMinimumTo, will not
# let us do).
- minTLS = _TLS_VERSION_MAP[config.federation_client_minimum_tls_version]
+ minTLS = _TLS_VERSION_MAP[config.tls.federation_client_minimum_tls_version]
_verify_ssl = CertificateOptions(
trustRoot=trust_root, insecurelyLowerMinimumTo=minTLS
@@ -125,10 +125,10 @@ class FederationPolicyForHTTPS:
self._no_verify_ssl_context = _no_verify_ssl.getContext()
self._no_verify_ssl_context.set_info_callback(_context_info_cb)
- self._should_verify = self._config.federation_verify_certificates
+ self._should_verify = self._config.tls.federation_verify_certificates
self._federation_certificate_verification_whitelist = (
- self._config.federation_certificate_verification_whitelist
+ self._config.tls.federation_certificate_verification_whitelist
)
def get_options(self, host: bytes):
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 9e9b1c1c..e1e13a24 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -572,7 +572,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
super().__init__(hs)
self.clock = hs.get_clock()
self.client = hs.get_federation_http_client()
- self.key_servers = self.config.key_servers
+ self.key_servers = self.config.key.key_servers
async def _fetch_keys(
self, keys_to_fetch: List[_FetchKeyRequest]
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index cb133f3f..65040283 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -115,11 +115,11 @@ def check(
is_invite_via_allow_rule = (
event.type == EventTypes.Member
and event.membership == Membership.JOIN
- and "join_authorised_via_users_server" in event.content
+ and EventContentFields.AUTHORISING_USER in event.content
)
if is_invite_via_allow_rule:
authoriser_domain = get_domain_from_id(
- event.content["join_authorised_via_users_server"]
+ event.content[EventContentFields.AUTHORISING_USER]
)
if not event.signatures.get(authoriser_domain):
raise AuthError(403, "Event not signed by authorising server")
@@ -213,7 +213,7 @@ def check(
if (
event.type == EventTypes.MSC2716_INSERTION
- or event.type == EventTypes.MSC2716_CHUNK
+ or event.type == EventTypes.MSC2716_BATCH
or event.type == EventTypes.MSC2716_MARKER
):
check_historical(room_version_obj, event, auth_events)
@@ -381,7 +381,9 @@ def _is_membership_change_allowed(
# Note that if the caller is in the room or invited, then they do
# not need to meet the allow rules.
if not caller_in_room and not caller_invited:
- authorising_user = event.content.get("join_authorised_via_users_server")
+ authorising_user = event.content.get(
+ EventContentFields.AUTHORISING_USER
+ )
if authorising_user is None:
raise AuthError(403, "Join event is missing authorising user.")
@@ -552,14 +554,14 @@ def check_historical(
auth_events: StateMap[EventBase],
) -> None:
"""Check whether the event sender is allowed to send historical related
- events like "insertion", "chunk", and "marker".
+ events like "insertion", "batch", and "marker".
Returns:
None
Raises:
AuthError if the event sender is not allowed to send historical related events
- ("insertion", "chunk", and "marker").
+ ("insertion", "batch", and "marker").
"""
# Ignore the auth checks in room versions that do not support historical
# events
@@ -573,7 +575,7 @@ def check_historical(
if user_level < historical_level:
raise AuthError(
403,
- 'You don\'t have permission to send send historical related events ("insertion", "chunk", and "marker")',
+            'You don\'t have permission to send historical related events ("insertion", "batch", and "marker")',
)
@@ -836,10 +838,10 @@ def auth_types_for_event(
auth_types.add(key)
if room_version.msc3083_join_rules and membership == Membership.JOIN:
- if "join_authorised_via_users_server" in event.content:
+ if EventContentFields.AUTHORISING_USER in event.content:
key = (
EventTypes.Member,
- event.content["join_authorised_via_users_server"],
+ event.content[EventContentFields.AUTHORISING_USER],
)
auth_types.add(key)
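The rule these hunks enforce: a restricted-room join that relies on an allow rule must name an authorising user, and the event must be signed by that user's server. A self-contained sketch of the domain extraction (the helper mirrors `synapse.types.get_domain_from_id`):

```python
def get_domain_from_id(user_id: str) -> str:
    # Same behaviour as synapse.types.get_domain_from_id for well-formed IDs.
    return user_id.split(":", 1)[1]

content = {"join_authorised_via_users_server": "@moderator:other.example"}
authoriser_domain = get_domain_from_id(content["join_authorised_via_users_server"])
# The event's signatures must then include "other.example", or auth fails
# with 403 "Event not signed by authorising server".
assert authoriser_domain == "other.example"
```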
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index a730c171..49190459 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -344,6 +344,18 @@ class EventBase(metaclass=abc.ABCMeta):
# this will be a no-op if the event dict is already frozen.
self._dict = freeze(self._dict)
+ def __str__(self):
+ return self.__repr__()
+
+ def __repr__(self):
+ return "<%s event_id=%r, type=%r, state_key=%r, outlier=%s>" % (
+ self.__class__.__name__,
+ self.event_id,
+ self.get("type", None),
+ self.get("state_key", None),
+ self.internal_metadata.is_outlier(),
+ )
+
class FrozenEvent(EventBase):
format_version = EventFormatVersions.V1 # All events of this type are V1
@@ -392,17 +404,6 @@ class FrozenEvent(EventBase):
def event_id(self) -> str:
return self._event_id
- def __str__(self):
- return self.__repr__()
-
- def __repr__(self):
- return "<FrozenEvent event_id=%r, type=%r, state_key=%r, outlier=%s>" % (
- self.get("event_id", None),
- self.get("type", None),
- self.get("state_key", None),
- self.internal_metadata.is_outlier(),
- )
-
class FrozenEventV2(EventBase):
format_version = EventFormatVersions.V2 # All events of this type are V2
@@ -478,17 +479,6 @@ class FrozenEventV2(EventBase):
"""
return self.auth_events
- def __str__(self):
- return self.__repr__()
-
- def __repr__(self):
- return "<%s event_id=%r, type=%r, state_key=%r>" % (
- self.__class__.__name__,
- self.event_id,
- self.get("type", None),
- self.get("state_key", None),
- )
-
class FrozenEventV3(FrozenEventV2):
"""FrozenEventV3, which differs from FrozenEventV2 only in the event_id format"""
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index f8d898c3..5ba01eee 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -80,9 +80,7 @@ class EventContext:
(type, state_key) -> event_id
- FIXME: what is this for an outlier? it seems ill-defined. It seems like
- it could be either {}, or the state we were given by the remote
- server, depending on $THINGS
+ For an outlier, this is {}
Note that this is a private attribute: it should be accessed via
``get_current_state_ids``. _AsyncEventContext impl calculates this
@@ -96,7 +94,7 @@ class EventContext:
(type, state_key) -> event_id
- FIXME: again, what is this for an outlier?
+ For an outlier, this is {}
As with _current_state_ids, this is a private attribute. It should be
accessed via get_prev_state_ids.
@@ -130,6 +128,14 @@ class EventContext:
delta_ids=delta_ids,
)
+ @staticmethod
+ def for_outlier():
+ """Return an EventContext instance suitable for persisting an outlier event"""
+ return EventContext(
+ current_state_ids={},
+ prev_state_ids={},
+ )
+
async def serialize(self, event: EventBase, store: "DataStore") -> dict:
"""Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize`
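A hedged usage sketch of the new factory (assumes a Synapse checkout; the call sites that adopt it are in the handler diffs further down):

```python
from synapse.events.snapshot import EventContext

# Outlier events sit outside the room DAG, so their context carries no state;
# call sites can now say this directly instead of building the dicts by hand:
context = EventContext.for_outlier()
# ...equivalent to the previous ad-hoc
# EventContext(current_state_ids={}, prev_state_ids={}) construction.
```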
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 57f1d53f..c389f70b 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -46,6 +46,9 @@ CHECK_EVENT_FOR_SPAM_CALLBACK = Callable[
]
USER_MAY_INVITE_CALLBACK = Callable[[str, str, str], Awaitable[bool]]
USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]]
+USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK = Callable[
+ [str, List[str], List[Dict[str, str]]], Awaitable[bool]
+]
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[[str, RoomAlias], Awaitable[bool]]
USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]]
CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[Dict[str, str]], Awaitable[bool]]
@@ -78,7 +81,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
"""
spam_checkers: List[Any] = []
api = hs.get_module_api()
- for module, config in hs.config.spam_checkers:
+ for module, config in hs.config.spamchecker.spam_checkers:
# Older spam checkers don't accept the `api` argument, so we
# try and detect support.
spam_args = inspect.getfullargspec(module)
@@ -164,6 +167,9 @@ class SpamChecker:
self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = []
self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
self._user_may_create_room_callbacks: List[USER_MAY_CREATE_ROOM_CALLBACK] = []
+ self._user_may_create_room_with_invites_callbacks: List[
+ USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK
+ ] = []
self._user_may_create_room_alias_callbacks: List[
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK
] = []
@@ -183,6 +189,9 @@ class SpamChecker:
check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None,
user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None,
user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None,
+ user_may_create_room_with_invites: Optional[
+ USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK
+ ] = None,
user_may_create_room_alias: Optional[
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK
] = None,
@@ -203,6 +212,11 @@ class SpamChecker:
if user_may_create_room is not None:
self._user_may_create_room_callbacks.append(user_may_create_room)
+ if user_may_create_room_with_invites is not None:
+ self._user_may_create_room_with_invites_callbacks.append(
+ user_may_create_room_with_invites,
+ )
+
if user_may_create_room_alias is not None:
self._user_may_create_room_alias_callbacks.append(
user_may_create_room_alias,
@@ -283,6 +297,34 @@ class SpamChecker:
return True
+ async def user_may_create_room_with_invites(
+ self,
+ userid: str,
+ invites: List[str],
+ threepid_invites: List[Dict[str, str]],
+ ) -> bool:
+ """Checks if a given user may create a room with invites
+
+ If this method returns false, the creation request will be rejected.
+
+ Args:
+ userid: The ID of the user attempting to create a room
+ invites: The IDs of the Matrix users to be invited if the room creation is
+ allowed.
+ threepid_invites: The threepids to be invited if the room creation is allowed,
+ as a dict including a "medium" key indicating the threepid's medium (e.g.
+ "email") and an "address" key indicating the threepid's address (e.g.
+ "alice@example.com")
+
+ Returns:
+ True if the user may create the room, otherwise False
+ """
+ for callback in self._user_may_create_room_with_invites_callbacks:
+ if await callback(userid, invites, threepid_invites) is False:
+ return False
+
+ return True
+
async def user_may_create_room_alias(
self, userid: str, room_alias: RoomAlias
) -> bool:
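A sketch of a module using the new hook. The callback signature comes from the hunk above; the registration method follows the documented modules interface, but the module itself is hypothetical:

```python
from typing import Dict, List

class InviteLimitModule:
    def __init__(self, config: dict, api) -> None:
        # `api` is a synapse.module_api.ModuleApi instance.
        self._max_invites = config.get("max_invites_at_creation", 20)
        api.register_spam_checker_callbacks(
            user_may_create_room_with_invites=self.user_may_create_room_with_invites,
        )

    async def user_may_create_room_with_invites(
        self,
        userid: str,
        invites: List[str],
        threepid_invites: List[Dict[str, str]],
    ) -> bool:
        # Reject room creations that try to invite too many users up front.
        return len(invites) + len(threepid_invites) <= self._max_invites
```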
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 7a6eb3e5..d94b1bb4 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -42,10 +42,10 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"):
"""Wrapper that loads a third party event rules module configured using the old
configuration, and registers the hooks they implement.
"""
- if hs.config.third_party_event_rules is None:
+ if hs.config.thirdpartyrules.third_party_event_rules is None:
return
- module, config = hs.config.third_party_event_rules
+ module, config = hs.config.thirdpartyrules.third_party_event_rules
api = hs.get_module_api()
third_party_rules = module(config=config, module_api=api)
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index fb22337e..38fccd1e 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -105,7 +105,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
if event_type == EventTypes.Member:
add_fields("membership")
if room_version.msc3375_redaction_rules:
- add_fields("join_authorised_via_users_server")
+ add_fields(EventContentFields.AUTHORISING_USER)
elif event_type == EventTypes.Create:
# MSC2176 rules state that create events cannot be redacted.
if room_version.msc2176_redaction_rules:
@@ -141,9 +141,9 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
elif event_type == EventTypes.Redaction and room_version.msc2176_redaction_rules:
add_fields("redacts")
elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_INSERTION:
- add_fields(EventContentFields.MSC2716_NEXT_CHUNK_ID)
- elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_CHUNK:
- add_fields(EventContentFields.MSC2716_CHUNK_ID)
+ add_fields(EventContentFields.MSC2716_NEXT_BATCH_ID)
+ elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_BATCH:
+ add_fields(EventContentFields.MSC2716_BATCH_ID)
elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_MARKER:
add_fields(EventContentFields.MSC2716_MARKER_INSERTION)
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 024e440f..0cd424e1 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -15,7 +15,7 @@
import logging
from collections import namedtuple
-from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
+from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import EventFormatVersions, RoomVersion
from synapse.crypto.event_signing import check_event_content_hash
@@ -184,10 +184,10 @@ async def _check_sigs_on_pdu(
room_version.msc3083_join_rules
and pdu.type == EventTypes.Member
and pdu.membership == Membership.JOIN
- and "join_authorised_via_users_server" in pdu.content
+ and EventContentFields.AUTHORISING_USER in pdu.content
):
authorising_server = get_domain_from_id(
- pdu.content["join_authorised_via_users_server"]
+ pdu.content[EventContentFields.AUTHORISING_USER]
)
try:
await keyring.verify_event_for_server(
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 1416abd0..2ab4dec8 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -37,7 +37,7 @@ from typing import (
import attr
from prometheus_client import Counter
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.api.errors import (
CodeMessageException,
Codes,
@@ -501,8 +501,6 @@ class FederationClient(FederationBase):
destination, auth_chain, outlier=True, room_version=room_version
)
- signed_auth.sort(key=lambda e: e.depth)
-
return signed_auth
def _is_unknown_endpoint(
@@ -877,9 +875,9 @@ class FederationClient(FederationBase):
# If the join is being authorised via allow rules, we need to send
# the /send_join back to the same server that was originally used
# with /make_join.
- if "join_authorised_via_users_server" in pdu.content:
+ if EventContentFields.AUTHORISING_USER in pdu.content:
destinations = [
- get_domain_from_id(pdu.content["join_authorised_via_users_server"])
+ get_domain_from_id(pdu.content[EventContentFields.AUTHORISING_USER])
]
return await self._try_destination_list(
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 214ee948..5f4383ee 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -34,7 +34,7 @@ from twisted.internet import defer
from twisted.internet.abstract import isIPAddress
from twisted.python import failure
-from synapse.api.constants import EduTypes, EventTypes, Membership
+from synapse.api.constants import EduTypes, EventContentFields, EventTypes, Membership
from synapse.api.errors import (
AuthError,
Codes,
@@ -765,11 +765,11 @@ class FederationServer(FederationBase):
if (
room_version.msc3083_join_rules
and event.membership == Membership.JOIN
- and "join_authorised_via_users_server" in event.content
+ and EventContentFields.AUTHORISING_USER in event.content
):
# We can only authorise our own users.
authorising_server = get_domain_from_id(
- event.content["join_authorised_via_users_server"]
+ event.content[EventContentFields.AUTHORISING_USER]
)
if authorising_server != self.server_name:
raise SynapseError(
@@ -1237,7 +1237,7 @@ class FederationHandlerRegistry:
self._edu_type_to_instance[edu_type] = instance_names
async def on_edu(self, edu_type: str, origin: str, content: dict) -> None:
- if not self.config.use_presence and edu_type == EduTypes.Presence:
+ if not self.config.server.use_presence and edu_type == EduTypes.Presence:
return
# Check if we have a handler on this instance
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 4671ac02..720d7bd7 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -594,7 +594,7 @@ class FederationSender(AbstractFederationSender):
destinations (list[str])
"""
- if not states or not self.hs.config.use_presence:
+ if not states or not self.hs.config.server.use_presence:
# No-op if presence is disabled.
return
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index c11d1f6d..afe35e72 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -560,7 +560,7 @@ class PerDestinationQueue:
assert len(edus) <= limit, "get_device_updates_by_remote returned too many EDUs"
- return (edus, now_stream_id)
+ return edus, now_stream_id
async def _get_to_device_message_edus(self, limit: int) -> Tuple[List[Edu], int]:
last_device_stream_id = self._last_device_stream_id
@@ -593,7 +593,7 @@ class PerDestinationQueue:
stream_id,
)
- return (edus, stream_id)
+ return edus, stream_id
def _start_catching_up(self) -> None:
"""
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
index 624c859f..cef65929 100644
--- a/synapse/federation/transport/server/_base.py
+++ b/synapse/federation/transport/server/_base.py
@@ -49,7 +49,9 @@ class Authenticator:
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
self.store = hs.get_datastore()
- self.federation_domain_whitelist = hs.config.federation_domain_whitelist
+ self.federation_domain_whitelist = (
+ hs.config.federation.federation_domain_whitelist
+ )
self.notifier = hs.get_notifier()
self.replication_client = None
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index d6b75ac2..449bbc70 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -847,16 +847,16 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
UserID.from_string(requester_user_id)
)
if not is_admin:
- if not self.hs.config.enable_group_creation:
+ if not self.hs.config.groups.enable_group_creation:
raise SynapseError(
403, "Only a server admin can create groups on this server"
)
localpart = group_id_obj.localpart
- if not localpart.startswith(self.hs.config.group_creation_prefix):
+ if not localpart.startswith(self.hs.config.groups.group_creation_prefix):
raise SynapseError(
400,
"Can only create groups with prefix %r on this server"
- % (self.hs.config.group_creation_prefix,),
+ % (self.hs.config.groups.group_creation_prefix,),
)
profile = content.get("profile", {})
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index c23ccd6d..0ccef884 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -16,6 +16,7 @@ import logging
from typing import TYPE_CHECKING, Optional
from synapse.api.ratelimiting import Ratelimiter
+from synapse.types import Requester
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -63,16 +64,21 @@ class BaseHandler:
self.event_builder_factory = hs.get_event_builder_factory()
- async def ratelimit(self, requester, update=True, is_admin_redaction=False):
+ async def ratelimit(
+ self,
+ requester: Requester,
+ update: bool = True,
+ is_admin_redaction: bool = False,
+ ) -> None:
"""Ratelimits requests.
Args:
- requester (Requester)
- update (bool): Whether to record that a request is being processed.
+ requester
+ update: Whether to record that a request is being processed.
Set to False when doing multiple checks for one request (e.g.
to check up front if we would reject the request), and set to
True for the last call for a given request.
- is_admin_redaction (bool): Whether this is a room admin/moderator
+ is_admin_redaction: Whether this is a room admin/moderator
redacting an event. If so then we may apply different
ratelimits depending on config.
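A hedged usage sketch of the `update=False` pattern the docstring describes (`handler` and `requester` stand in for a real `BaseHandler` and `Requester`):

```python
async def handle_request(handler, requester) -> None:
    # Pre-flight probe: would we reject? Does not record the request.
    await handler.ratelimit(requester, update=False)
    # ... perform validation that should not consume quota ...
    # Final check for this request: record it against the limiter.
    await handler.ratelimit(requester)
```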
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index affb54e0..96273e2f 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
-from typing import TYPE_CHECKING, List, Tuple
+from typing import TYPE_CHECKING, Collection, List, Optional, Tuple
from synapse.replication.http.account_data import (
ReplicationAddTagRestServlet,
@@ -21,6 +21,7 @@ from synapse.replication.http.account_data import (
ReplicationRoomAccountDataRestServlet,
ReplicationUserAccountDataRestServlet,
)
+from synapse.streams import EventSource
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
@@ -163,7 +164,7 @@ class AccountDataHandler:
return response["max_stream_id"]
-class AccountDataEventSource:
+class AccountDataEventSource(EventSource[int, JsonDict]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@@ -171,7 +172,13 @@ class AccountDataEventSource:
return self.store.get_max_account_data_stream_id()
async def get_new_events(
- self, user: UserID, from_key: int, **kwargs
+ self,
+ user: UserID,
+ from_key: int,
+ limit: Optional[int],
+ room_ids: Collection[str],
+ is_guest: bool,
+ explicit_room_id: Optional[str] = None,
) -> Tuple[List[JsonDict], int]:
user_id = user.to_string()
last_stream_id = from_key
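The `EventSource` base being subclassed here is new in this diff (see `synapse/streams/__init__.py` in the diffstat). A plausible shape for it, inferred from the subclass signature above:

```python
from abc import ABC, abstractmethod
from typing import Any, Collection, Generic, List, Optional, Tuple, TypeVar

K = TypeVar("K")  # stream-key type (int for account data)
R = TypeVar("R")  # returned event type (JsonDict for account data)

class EventSource(ABC, Generic[K, R]):
    """Inferred sketch; the real definition is in synapse/streams/__init__.py."""

    @abstractmethod
    async def get_new_events(
        self,
        user: Any,  # a synapse.types.UserID in the real interface
        from_key: K,
        limit: Optional[int],
        room_ids: Collection[str],
        is_guest: bool,
        explicit_room_id: Optional[str] = None,
    ) -> Tuple[List[R], K]:
        ...
```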
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index a9c2222f..5a5f124d 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -47,7 +47,7 @@ class AccountValidityHandler:
self.send_email_handler = self.hs.get_send_email_handler()
self.clock = self.hs.get_clock()
- self._app_name = self.hs.config.email_app_name
+ self._app_name = self.hs.config.email.email_app_name
self._account_validity_enabled = (
hs.config.account_validity.account_validity_enabled
@@ -99,7 +99,7 @@ class AccountValidityHandler:
on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
- ):
+ ) -> None:
"""Register callbacks from module for each hook."""
if is_user_expired is not None:
self._is_user_expired_callbacks.append(is_user_expired)
@@ -165,7 +165,7 @@ class AccountValidityHandler:
return False
- async def on_user_registration(self, user_id: str):
+ async def on_user_registration(self, user_id: str) -> None:
"""Tell third-party modules about a user's registration.
Args:
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index a7b5a4e9..16327870 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Union
from prometheus_client import Counter
@@ -52,13 +52,13 @@ class ApplicationServicesHandler:
self.scheduler = hs.get_application_service_scheduler()
self.started_scheduler = False
self.clock = hs.get_clock()
- self.notify_appservices = hs.config.notify_appservices
+ self.notify_appservices = hs.config.appservice.notify_appservices
self.event_sources = hs.get_event_sources()
self.current_max = 0
self.is_processing = False
- def notify_interested_services(self, max_token: RoomStreamToken):
+ def notify_interested_services(self, max_token: RoomStreamToken) -> None:
"""Notifies (pushes) all application services interested in this event.
Pushing is done asynchronously, so this method won't block for any
@@ -82,7 +82,7 @@ class ApplicationServicesHandler:
self._notify_interested_services(max_token)
@wrap_as_background_process("notify_interested_services")
- async def _notify_interested_services(self, max_token: RoomStreamToken):
+ async def _notify_interested_services(self, max_token: RoomStreamToken) -> None:
with Measure(self.clock, "notify_interested_services"):
self.is_processing = True
try:
@@ -100,7 +100,7 @@ class ApplicationServicesHandler:
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)
- async def handle_event(event):
+ async def handle_event(event: EventBase) -> None:
# Gather interested services
services = await self._get_services_for_event(event)
if len(services) == 0:
@@ -116,9 +116,9 @@ class ApplicationServicesHandler:
if not self.started_scheduler:
- async def start_scheduler():
+ async def start_scheduler() -> None:
try:
- return await self.scheduler.start()
+ await self.scheduler.start()
except Exception:
logger.error("Application Services Failure")
@@ -137,7 +137,7 @@ class ApplicationServicesHandler:
"appservice_sender"
).observe((now - ts) / 1000)
- async def handle_room_events(events):
+ async def handle_room_events(events: Iterable[EventBase]) -> None:
for event in events:
await handle_event(event)
@@ -184,7 +184,7 @@ class ApplicationServicesHandler:
stream_key: str,
new_token: Optional[int],
users: Optional[Collection[Union[str, UserID]]] = None,
- ):
+ ) -> None:
"""This is called by the notifier in the background
    when an ephemeral event is handled by the homeserver.
@@ -226,7 +226,7 @@ class ApplicationServicesHandler:
stream_key: str,
new_token: Optional[int],
users: Collection[Union[str, UserID]],
- ):
+ ) -> None:
logger.debug("Checking interested services for %s" % (stream_key))
with Measure(self.clock, "notify_interested_services_ephemeral"):
for service in services:
@@ -254,7 +254,7 @@ class ApplicationServicesHandler:
async def _handle_typing(
self, service: ApplicationService, new_token: int
) -> List[JsonDict]:
- typing_source = self.event_sources.sources["typing"]
+ typing_source = self.event_sources.sources.typing
# Get the typing events from just before current
typing, _ = await typing_source.get_new_events_as(
service=service,
@@ -269,7 +269,7 @@ class ApplicationServicesHandler:
from_key = await self.store.get_type_stream_id_for_appservice(
service, "read_receipt"
)
- receipts_source = self.event_sources.sources["receipt"]
+ receipts_source = self.event_sources.sources.receipt
receipts, _ = await receipts_source.get_new_events_as(
service=service, from_key=from_key
)
@@ -279,7 +279,7 @@ class ApplicationServicesHandler:
self, service: ApplicationService, users: Collection[Union[str, UserID]]
) -> List[JsonDict]:
events: List[JsonDict] = []
- presence_source = self.event_sources.sources["presence"]
+ presence_source = self.event_sources.sources.presence
from_key = await self.store.get_type_stream_id_for_appservice(
service, "presence"
)
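These hunks replace `self.event_sources.sources["typing"]` and friends with attribute access, which suggests the dict of event sources became a typed container (the real definition lives in `synapse/streams/events.py`, changed in this diff). A plausible sketch, with field types elided for brevity:

```python
import attr

@attr.s(auto_attribs=True, frozen=True, slots=True)
class _EventSourcesInner:
    room: object
    presence: object
    typing: object
    receipt: object
    account_data: object

# Attribute access (sources.typing) is visible to mypy and fails fast on
# typos, unlike sources["typing"], which only errors at runtime.
```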
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index fbbf6fd8..a8c717ef 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -29,6 +29,7 @@ from typing import (
Mapping,
Optional,
Tuple,
+ Type,
Union,
cast,
)
@@ -209,15 +210,15 @@ class AuthHandler(BaseHandler):
self.password_providers = [
PasswordProvider.load(module, config, account_handler)
- for module, config in hs.config.password_providers
+ for module, config in hs.config.authproviders.password_providers
]
logger.info("Extra password_providers: %s", self.password_providers)
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
- self._password_enabled = hs.config.password_enabled
- self._password_localdb_enabled = hs.config.password_localdb_enabled
+ self._password_enabled = hs.config.auth.password_enabled
+ self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
# start out by assuming PASSWORD is enabled; we will remove it later if not.
login_types = set()
@@ -249,7 +250,7 @@ class AuthHandler(BaseHandler):
)
# The number of seconds to keep a UI auth session active.
- self._ui_auth_session_timeout = hs.config.ui_auth_session_timeout
+ self._ui_auth_session_timeout = hs.config.auth.ui_auth_session_timeout
# Ratelimitier for failed /login attempts
self._failed_login_attempts_ratelimiter = Ratelimiter(
@@ -276,23 +277,25 @@ class AuthHandler(BaseHandler):
# after the SSO completes and before redirecting them back to their client.
# It notifies the user they are about to give access to their matrix account
# to the client.
- self._sso_redirect_confirm_template = hs.config.sso_redirect_confirm_template
+ self._sso_redirect_confirm_template = (
+ hs.config.sso.sso_redirect_confirm_template
+ )
# The following template is shown during user interactive authentication
# in the fallback auth scenario. It notifies the user that they are
# authenticating for an operation to occur on their account.
- self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template
+ self._sso_auth_confirm_template = hs.config.sso.sso_auth_confirm_template
# The following template is shown during the SSO authentication process if
# the account is deactivated.
self._sso_account_deactivated_template = (
- hs.config.sso_account_deactivated_template
+ hs.config.sso.sso_account_deactivated_template
)
self._server_name = hs.config.server.server_name
# cast to tuple for use with str.startswith
- self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
+ self._whitelisted_sso_clients = tuple(hs.config.sso.sso_client_whitelist)
# A mapping of user ID to extra attributes to include in the login
# response.
@@ -439,7 +442,7 @@ class AuthHandler(BaseHandler):
return ui_auth_types
- def get_enabled_auth_types(self):
+ def get_enabled_auth_types(self) -> Iterable[str]:
"""Return the enabled user-interactive authentication types
Returns the UI-Auth types which are supported by the homeserver's current
@@ -702,7 +705,7 @@ class AuthHandler(BaseHandler):
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
- async def _expire_old_sessions(self):
+ async def _expire_old_sessions(self) -> None:
"""
Invalidate any user interactive authentication sessions that have expired.
"""
@@ -738,19 +741,19 @@ class AuthHandler(BaseHandler):
return canonical_id
def _get_params_recaptcha(self) -> dict:
- return {"public_key": self.hs.config.recaptcha_public_key}
+ return {"public_key": self.hs.config.captcha.recaptcha_public_key}
def _get_params_terms(self) -> dict:
return {
"policies": {
"privacy_policy": {
- "version": self.hs.config.user_consent_version,
+ "version": self.hs.config.consent.user_consent_version,
"en": {
- "name": self.hs.config.user_consent_policy_name,
+ "name": self.hs.config.consent.user_consent_policy_name,
"url": "%s_matrix/consent?v=%s"
% (
self.hs.config.server.public_baseurl,
- self.hs.config.user_consent_version,
+ self.hs.config.consent.user_consent_version,
),
},
}
@@ -1015,7 +1018,7 @@ class AuthHandler(BaseHandler):
def can_change_password(self) -> bool:
"""Get whether users on this server are allowed to change or set a password.
- Both `config.password_enabled` and `config.password_localdb_enabled` must be true.
+ Both `config.auth.password_enabled` and `config.auth.password_localdb_enabled` must be true.
Note that any account (even an SSO account) is allowed to add passwords if the above
is true.
@@ -1347,12 +1350,12 @@ class AuthHandler(BaseHandler):
try:
res = self.macaroon_gen.verify_short_term_login_token(login_token)
except Exception:
- raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
+ raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)
await self.auth.check_auth_blocking(res.user_id)
return res
- async def delete_access_token(self, access_token: str):
+ async def delete_access_token(self, access_token: str) -> None:
"""Invalidate a single access token
Args:
@@ -1381,7 +1384,7 @@ class AuthHandler(BaseHandler):
user_id: str,
except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
- ):
+ ) -> None:
"""Invalidate access tokens belonging to a user
Args:
@@ -1409,7 +1412,7 @@ class AuthHandler(BaseHandler):
async def add_threepid(
self, user_id: str, medium: str, address: str, validated_at: int
- ):
+ ) -> None:
# check if medium has a valid value
if medium not in ["email", "msisdn"]:
raise SynapseError(
@@ -1480,12 +1483,12 @@ class AuthHandler(BaseHandler):
Hashed password.
"""
- def _do_hash():
+ def _do_hash() -> str:
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
return bcrypt.hashpw(
- pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
+ pw.encode("utf8") + self.hs.config.auth.password_pepper.encode("utf8"),
bcrypt.gensalt(self.bcrypt_rounds),
).decode("ascii")
@@ -1504,12 +1507,12 @@ class AuthHandler(BaseHandler):
Whether self.hash(password) == stored_hash.
"""
- def _do_validate_hash(checked_hash: bytes):
+ def _do_validate_hash(checked_hash: bytes) -> bool:
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
return bcrypt.checkpw(
- pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
+ pw.encode("utf8") + self.hs.config.auth.password_pepper.encode("utf8"),
checked_hash,
)
@@ -1581,7 +1584,7 @@ class AuthHandler(BaseHandler):
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
- ):
+ ) -> None:
"""Having figured out a mxid for this user, complete the HTTP request
Args:
@@ -1627,7 +1630,7 @@ class AuthHandler(BaseHandler):
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
user_profile_data: Optional[ProfileInfo] = None,
- ):
+ ) -> None:
"""
The synchronous portion of complete_sso_login.
@@ -1726,7 +1729,7 @@ class AuthHandler(BaseHandler):
del self._extra_attributes[user_id]
@staticmethod
- def add_query_param_to_url(url: str, param_name: str, param: Any):
+ def add_query_param_to_url(url: str, param_name: str, param: Any) -> str:
url_parts = list(urllib.parse.urlparse(url))
query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
query.append((param_name, param))
@@ -1734,9 +1737,9 @@ class AuthHandler(BaseHandler):
return urllib.parse.urlunparse(url_parts)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class MacaroonGenerator:
- hs = attr.ib()
+ hs: "HomeServer"
def generate_guest_access_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)
@@ -1801,7 +1804,7 @@ class MacaroonGenerator:
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server.server_name,
identifier="key",
- key=self.hs.config.macaroon_secret_key,
+ key=self.hs.config.key.macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
@@ -1816,7 +1819,9 @@ class PasswordProvider:
"""
@classmethod
- def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
+ def load(
+ cls, module: Type, config: JsonDict, module_api: ModuleApi
+ ) -> "PasswordProvider":
try:
pp = module(config=config, account_handler=module_api)
except Exception as e:
@@ -1824,7 +1829,7 @@ class PasswordProvider:
raise
return cls(pp, module_api)
- def __init__(self, pp, module_api: ModuleApi):
+ def __init__(self, pp: "PasswordProvider", module_api: ModuleApi):
self._pp = pp
self._module_api = module_api
@@ -1838,7 +1843,7 @@ class PasswordProvider:
if g:
self._supported_login_types.update(g())
- def __str__(self):
+ def __str__(self) -> str:
return str(self._pp)
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
@@ -1876,19 +1881,19 @@ class PasswordProvider:
"""
# first grandfather in a call to check_password
if login_type == LoginType.PASSWORD:
- g = getattr(self._pp, "check_password", None)
- if g:
+ check_password = getattr(self._pp, "check_password", None)
+ if check_password:
qualified_user_id = self._module_api.get_qualified_user_id(username)
- is_valid = await self._pp.check_password(
+ is_valid = await check_password(
qualified_user_id, login_dict["password"]
)
if is_valid:
return qualified_user_id, None
- g = getattr(self._pp, "check_auth", None)
- if not g:
+ check_auth = getattr(self._pp, "check_auth", None)
+ if not check_auth:
return None
- result = await g(username, login_type, login_dict)
+ result = await check_auth(username, login_type, login_dict)
# Check if the return value is a str or a tuple
if isinstance(result, str):
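
The check_password/check_auth hunks above bind each optional provider hook to a local name with getattr and then call that bound name, so the None-check and the call refer to the same object and mypy can narrow the type. A small runnable sketch of the same pattern, using a hypothetical legacy provider class:

    import asyncio
    from typing import Any, Optional

    class LegacyProvider:
        # Hypothetical provider module exposing the optional hook.
        async def check_password(self, user_id: str, password: str) -> bool:
            return password == "hunter2"

    async def check(provider: Any, user_id: str, password: str) -> Optional[str]:
        # Fetch the hook once; if present, await the bound callable.
        check_password = getattr(provider, "check_password", None)
        if check_password and await check_password(user_id, password):
            return user_id
        return None

    print(asyncio.run(check(LegacyProvider(), "@alice:example.org", "hunter2")))
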
diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py
index 47ddabbe..5d8f6c50 100644
--- a/synapse/handlers/cas.py
+++ b/synapse/handlers/cas.py
@@ -34,20 +34,20 @@ logger = logging.getLogger(__name__)
class CasError(Exception):
"""Used to catch errors when validating the CAS ticket."""
- def __init__(self, error, error_description=None):
+ def __init__(self, error: str, error_description: Optional[str] = None):
self.error = error
self.error_description = error_description
- def __str__(self):
+ def __str__(self) -> str:
if self.error_description:
return f"{self.error}: {self.error_description}"
return self.error
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class CasResponse:
- username = attr.ib(type=str)
- attributes = attr.ib(type=Dict[str, List[Optional[str]]])
+ username: str
+ attributes: Dict[str, List[Optional[str]]]
class CasHandler:
@@ -65,10 +65,10 @@ class CasHandler:
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
- self._cas_server_url = hs.config.cas_server_url
- self._cas_service_url = hs.config.cas_service_url
- self._cas_displayname_attribute = hs.config.cas_displayname_attribute
- self._cas_required_attributes = hs.config.cas_required_attributes
+ self._cas_server_url = hs.config.cas.cas_server_url
+ self._cas_service_url = hs.config.cas.cas_service_url
+ self._cas_displayname_attribute = hs.config.cas.cas_displayname_attribute
+ self._cas_required_attributes = hs.config.cas.cas_required_attributes
self._http_client = hs.get_proxied_http_client()
@@ -133,11 +133,9 @@ class CasHandler:
body = pde.response
except HttpResponseException as e:
description = (
- (
- 'Authorization server responded with a "{status}" error '
- "while exchanging the authorization code."
- ).format(status=e.code),
- )
+ 'Authorization server responded with a "{status}" error '
+ "while exchanging the authorization code."
+ ).format(status=e.code)
raise CasError("server_error", description) from e
return self._parse_cas_response(body)
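
The last cas.py hunk fixes a classic Python pitfall: the trailing comma after the parenthesised string turned `description` into a one-element tuple, so `CasError` carried a tuple where a string was expected. A self-contained illustration of the before and after:

    status = 502

    # Buggy form: the trailing comma makes this a 1-tuple, not a str.
    buggy = (
        (
            'Authorization server responded with a "{status}" error '
            "while exchanging the authorization code."
        ).format(status=status),
    )

    # Fixed form: parentheses used only for line continuation.
    fixed = (
        'Authorization server responded with a "{status}" error '
        "while exchanging the authorization code."
    ).format(status=status)

    assert isinstance(buggy, tuple)
    assert isinstance(fixed, str)
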
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index dcd320c5..9ae5b775 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -255,16 +255,16 @@ class DeactivateAccountHandler(BaseHandler):
Args:
user_id: ID of user to be re-activated
"""
- # Add the user to the directory, if necessary.
user = UserID.from_string(user_id)
- if self.hs.config.user_directory_search_all_users:
- profile = await self.store.get_profileinfo(user.localpart)
- await self.user_directory_handler.handle_local_profile_change(
- user_id, profile
- )
# Ensure the user is not marked as erased.
await self.store.mark_user_not_erased(user_id)
# Mark the user as active.
await self.store.set_user_deactivated_status(user_id, False)
+
+ # Add the user to the directory, if necessary. Note that
+ # this must be done after the user is re-activated, because
+ # deactivated users are excluded from the user directory.
+ profile = await self.store.get_profileinfo(user.localpart)
+ await self.user_directory_handler.handle_local_profile_change(user_id, profile)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 46ee8344..35334725 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -267,7 +267,7 @@ class DeviceHandler(DeviceWorkerHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
- def _check_device_name_length(self, name: Optional[str]):
+ def _check_device_name_length(self, name: Optional[str]) -> None:
"""
Checks whether a device name is longer than the maximum allowed length.
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index d487fee6..5cfba3c8 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -48,7 +48,7 @@ class DirectoryHandler(BaseHandler):
self.event_creation_handler = hs.get_event_creation_handler()
self.store = hs.get_datastore()
self.config = hs.config
- self.enable_room_list_search = hs.config.enable_room_list_search
+ self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
self.require_membership = hs.config.require_membership_for_aliases
self.third_party_event_rules = hs.get_third_party_event_rules()
@@ -143,7 +143,7 @@ class DirectoryHandler(BaseHandler):
):
raise AuthError(403, "This user is not permitted to create this alias")
- if not self.config.is_alias_creation_allowed(
+ if not self.config.roomdirectory.is_alias_creation_allowed(
user_id, room_id, room_alias_str
):
# Let's just return a generic message, as there may be all sorts of
@@ -459,7 +459,7 @@ class DirectoryHandler(BaseHandler):
if canonical_alias:
room_aliases.append(canonical_alias)
- if not self.config.is_publishing_room_allowed(
+ if not self.config.roomdirectory.is_publishing_room_allowed(
user_id, room_id, room_aliases
):
# Let's just return a generic message, as there may be all sorts of
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 08a13756..d0fb2fc7 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -202,7 +202,7 @@ class E2eKeysHandler:
# Now fetch any devices that we don't have in our cache
@trace
- async def do_remote_query(destination):
+ async def do_remote_query(destination: str) -> None:
"""This is called when we are querying the device list of a user on
a remote homeserver and their device list is not in the device list
cache. If we share a room with this user and we're not querying for
@@ -447,7 +447,7 @@ class E2eKeysHandler:
}
@trace
- async def claim_client_keys(destination):
+ async def claim_client_keys(destination: str) -> None:
set_tag("destination", destination)
device_keys = remote_queries[destination]
try:
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index 4288ffff..cb81fa09 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -25,6 +25,7 @@ from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
+from synapse.events.snapshot import EventContext
from synapse.types import StateMap, get_domain_from_id
from synapse.util.metrics import Measure
@@ -45,7 +46,11 @@ class EventAuthHandler:
self._server_name = hs.hostname
async def check_from_context(
- self, room_version: str, event, context, do_sig_check=True
+ self,
+ room_version: str,
+ event: EventBase,
+ context: EventContext,
+ do_sig_check: bool = True,
) -> None:
auth_event_ids = event.auth_event_ids()
auth_events_by_id = await self._store.get_events(auth_event_ids)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 6754c64c..adbd150e 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -27,7 +27,12 @@ from unpaddedbase64 import decode_base64
from twisted.internet import defer
from synapse import event_auth
-from synapse.api.constants import EventTypes, Membership, RejectedReason
+from synapse.api.constants import (
+ EventContentFields,
+ EventTypes,
+ Membership,
+ RejectedReason,
+)
from synapse.api.errors import (
AuthError,
CodeMessageException,
@@ -91,7 +96,7 @@ class FederationHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self._event_auth_handler = hs.get_event_auth_handler()
- self._server_notices_mxid = hs.config.server_notices_mxid
+ self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
self.config = hs.config
self.http_client = hs.get_proxied_blacklisted_http_client()
self._replication = hs.get_replication_data_handler()
@@ -593,6 +598,13 @@ class FederationHandler(BaseHandler):
target_hosts, room_id, knockee, Membership.KNOCK, content, params=params
)
+ # Mark the knock as an outlier as we don't yet have the state at this point in
+ # the DAG.
+ event.internal_metadata.outlier = True
+
+ # ... but tell /sync to send it to clients anyway.
+ event.internal_metadata.out_of_band_membership = True
+
# Record the room ID and its version so that we have a record of the room
await self._maybe_store_room_on_outlier_membership(
room_id=event.room_id, room_version=event_format_version
@@ -617,7 +629,7 @@ class FederationHandler(BaseHandler):
# in the invitee's sync stream. It is stripped out for all other local users.
event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
- context = await self.state_handler.compute_event_context(event)
+ context = EventContext.for_outlier()
stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
@@ -705,7 +717,7 @@ class FederationHandler(BaseHandler):
if include_auth_user_id:
event_content[
- "join_authorised_via_users_server"
+ EventContentFields.AUTHORISING_USER
] = await self._event_auth_handler.get_user_which_could_invite(
room_id,
state_ids,
@@ -807,7 +819,7 @@ class FederationHandler(BaseHandler):
)
)
- context = await self.state_handler.compute_event_context(event)
+ context = EventContext.for_outlier()
await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
@@ -836,7 +848,7 @@ class FederationHandler(BaseHandler):
await self.federation_client.send_leave(host_list, event)
- context = await self.state_handler.compute_event_context(event)
+ context = EventContext.for_outlier()
stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
@@ -1108,8 +1120,7 @@ class FederationHandler(BaseHandler):
events_to_context = {}
for e in itertools.chain(auth_events, state):
e.internal_metadata.outlier = True
- ctx = await self.state_handler.compute_event_context(e)
- events_to_context[e.event_id] = ctx
+ events_to_context[e.event_id] = EventContext.for_outlier()
event_map = {
e.event_id: e for e in itertools.chain(auth_events, state, [event])
@@ -1221,136 +1232,6 @@ class FederationHandler(BaseHandler):
return missing_events
- async def construct_auth_difference(
- self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase]
- ) -> Dict:
- """Given a local and remote auth chain, find the differences. This
- assumes that we have already processed all events in remote_auth
-
- Params:
- local_auth
- remote_auth
-
- Returns:
- dict
- """
-
- logger.debug("construct_auth_difference Start!")
-
- # TODO: Make sure we are OK with local_auth or remote_auth having more
- # auth events in them than strictly necessary.
-
- def sort_fun(ev):
- return ev.depth, ev.event_id
-
- logger.debug("construct_auth_difference after sort_fun!")
-
- # We find the differences by starting at the "bottom" of each list
- # and iterating up on both lists. The lists are ordered by depth and
- # then event_id, we iterate up both lists until we find the event ids
- # don't match. Then we look at depth/event_id to see which side is
- # missing that event, and iterate only up that list. Repeat.
-
- remote_list = list(remote_auth)
- remote_list.sort(key=sort_fun)
-
- local_list = list(local_auth)
- local_list.sort(key=sort_fun)
-
- local_iter = iter(local_list)
- remote_iter = iter(remote_list)
-
- logger.debug("construct_auth_difference before get_next!")
-
- def get_next(it, opt=None):
- try:
- return next(it)
- except Exception:
- return opt
-
- current_local = get_next(local_iter)
- current_remote = get_next(remote_iter)
-
- logger.debug("construct_auth_difference before while")
-
- missing_remotes = []
- missing_locals = []
- while current_local or current_remote:
- if current_remote is None:
- missing_locals.append(current_local)
- current_local = get_next(local_iter)
- continue
-
- if current_local is None:
- missing_remotes.append(current_remote)
- current_remote = get_next(remote_iter)
- continue
-
- if current_local.event_id == current_remote.event_id:
- current_local = get_next(local_iter)
- current_remote = get_next(remote_iter)
- continue
-
- if current_local.depth < current_remote.depth:
- missing_locals.append(current_local)
- current_local = get_next(local_iter)
- continue
-
- if current_local.depth > current_remote.depth:
- missing_remotes.append(current_remote)
- current_remote = get_next(remote_iter)
- continue
-
- # They have the same depth, so we fall back to the event_id order
- if current_local.event_id < current_remote.event_id:
- missing_locals.append(current_local)
- current_local = get_next(local_iter)
-
- if current_local.event_id > current_remote.event_id:
- missing_remotes.append(current_remote)
- current_remote = get_next(remote_iter)
- continue
-
- logger.debug("construct_auth_difference after while")
-
- # missing locals should be sent to the server
- # We should find why we are missing remotes, as they will have been
- # rejected.
-
- # Remove events from missing_remotes if they are referencing a missing
- # remote. We only care about the "root" rejected ones.
- missing_remote_ids = [e.event_id for e in missing_remotes]
- base_remote_rejected = list(missing_remotes)
- for e in missing_remotes:
- for e_id in e.auth_event_ids():
- if e_id in missing_remote_ids:
- try:
- base_remote_rejected.remove(e)
- except ValueError:
- pass
-
- reason_map = {}
-
- for e in base_remote_rejected:
- reason = await self.store.get_rejection_reason(e.event_id)
- if reason is None:
- # TODO: e is not in the current state, so we should
- # construct some proof of that.
- continue
-
- reason_map[e.event_id] = reason
-
- logger.debug("construct_auth_difference returning")
-
- return {
- "auth_chain": local_auth,
- "rejects": {
- e.event_id: {"reason": reason_map[e.event_id], "proof": None}
- for e in base_remote_rejected
- },
- "missing": [e.event_id for e in missing_locals],
- }
-
@log_function
async def exchange_third_party_invite(
self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict
@@ -1493,7 +1374,7 @@ class FederationHandler(BaseHandler):
builder=builder
)
EventValidator().validate_new(event, self.config)
- return (event, context)
+ return event, context
async def _check_signature(self, event: EventBase, context: EventContext) -> None:
"""
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 946343fa..01fd8411 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -27,11 +27,8 @@ from typing import (
Tuple,
)
-import attr
from prometheus_client import Counter
-from twisted.internet import defer
-
from synapse import event_auth
from synapse.api.constants import (
EventContentFields,
@@ -54,11 +51,7 @@ from synapse.event_auth import auth_types_for_event
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.federation.federation_client import InvalidResponseError
-from synapse.logging.context import (
- make_deferred_yieldable,
- nested_logging_context,
- run_in_background,
-)
+from synapse.logging.context import nested_logging_context, run_in_background
from synapse.logging.utils import log_function
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
@@ -75,7 +68,11 @@ from synapse.types import (
UserID,
get_domain_from_id,
)
-from synapse.util.async_helpers import Linearizer, concurrently_execute
+from synapse.util.async_helpers import (
+ Linearizer,
+ concurrently_execute,
+ yieldable_gather_results,
+)
from synapse.util.iterutils import batch_iter
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import shortstr
@@ -92,30 +89,6 @@ soft_failed_event_counter = Counter(
)
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class _NewEventInfo:
- """Holds information about a received event, ready for passing to _auth_and_persist_events
-
- Attributes:
- event: the received event
-
- claimed_auth_event_map: a map of (type, state_key) => event for the event's
- claimed auth_events.
-
- This can include events which have not yet been persisted, in the case that
- we are backfilling a batch of events.
-
- Note: May be incomplete: if we were unable to find all of the claimed auth
- events. Also, treat the contents with caution: the events might also have
- been rejected, might not yet have been authorized themselves, or they might
- be in the wrong room.
-
- """
-
- event: EventBase
- claimed_auth_event_map: StateMap[EventBase]
-
-
class FederationEventHandler:
"""Handles events that originated from federation.
@@ -1016,7 +989,7 @@ class FederationEventHandler:
except Exception:
logger.exception("Failed to resync device for %s", sender)
- async def _handle_marker_event(self, origin: str, marker_event: EventBase):
+ async def _handle_marker_event(self, origin: str, marker_event: EventBase) -> None:
"""Handles backfilling the insertion event when we receive a marker
event that points to one.
@@ -1107,9 +1080,9 @@ class FederationEventHandler:
room_version = await self._store.get_room_version(room_id)
- event_map: Dict[str, EventBase] = {}
+ events: List[EventBase] = []
- async def get_event(event_id: str):
+ async def get_event(event_id: str) -> None:
with nested_logging_context(event_id):
try:
event = await self._federation_client.get_pdu(
@@ -1125,8 +1098,7 @@ class FederationEventHandler:
event_id,
)
return
-
- event_map[event.event_id] = event
+ events.append(event)
except Exception as e:
logger.warning(
@@ -1137,11 +1109,29 @@ class FederationEventHandler:
)
await concurrently_execute(get_event, event_ids, 5)
- logger.info("Fetched %i events of %i requested", len(event_map), len(event_ids))
+ logger.info("Fetched %i events of %i requested", len(events), len(event_ids))
+ await self._auth_and_persist_fetched_events(destination, room_id, events)
+
+ async def _auth_and_persist_fetched_events(
+ self, origin: str, room_id: str, events: Iterable[EventBase]
+ ) -> None:
+ """Persist the events fetched by _get_events_and_persist or _get_remote_auth_chain_for_event
+
+ The events to be persisted must be outliers.
+
+ We first sort the events to make sure that we process each event's auth_events
+ before the event itself, and then auth and persist them.
+
+ Notifies about the events where appropriate.
+
+ Params:
+ origin: where the events came from
+ room_id: the room that the events are meant to be in (though this has
+ not yet been checked)
+ events: the events that have been fetched
+ """
+ event_map = {event.event_id: event for event in events}
- # we now need to auth the events in an order which ensures that each event's
- # auth_events are authed before the event itself.
- #
# XXX: it might be possible to kick this process off in parallel with fetching
# the events.
while event_map:
@@ -1168,22 +1158,18 @@ class FederationEventHandler:
"Persisting %i of %i remaining events", len(roots), len(event_map)
)
- await self._auth_and_persist_fetched_events(destination, room_id, roots)
+ await self._auth_and_persist_fetched_events_inner(origin, room_id, roots)
for ev in roots:
del event_map[ev.event_id]
- async def _auth_and_persist_fetched_events(
+ async def _auth_and_persist_fetched_events_inner(
self, origin: str, room_id: str, fetched_events: Collection[EventBase]
) -> None:
- """Persist the events fetched by _get_events_and_persist.
+ """Helper for _auth_and_persist_fetched_events
- The events should not depend on one another, e.g. this should be used to persist
- a bunch of outliers, but not a chunk of individual events that depend
- on each other for state calculations.
-
- We also assume that all of the auth events for all of the events have already
- been persisted.
+ Persists a batch of events where we have (theoretically) already persisted all
+ of their auth events.
Notifies about the events where appropriate.
@@ -1191,7 +1177,7 @@ class FederationEventHandler:
origin: where the events came from
room_id: the room that the events are meant to be in (though this has
not yet been checked)
- event_id: map from event_id -> event for the fetched events
+ fetched_events: the events to persist
"""
# get all the auth events for all the events in this batch. By now, they should
# have been persisted.
@@ -1203,47 +1189,36 @@ class FederationEventHandler:
allow_rejected=True,
)
- event_infos = []
- for event in fetched_events:
- auth = {}
- for auth_event_id in event.auth_event_ids():
- ae = persisted_events.get(auth_event_id)
- if ae:
+ async def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]:
+ with nested_logging_context(suffix=event.event_id):
+ auth = {}
+ for auth_event_id in event.auth_event_ids():
+ ae = persisted_events.get(auth_event_id)
+ if not ae:
+ logger.warning(
+ "Event %s relies on auth_event %s, which could not be found.",
+ event,
+ auth_event_id,
+ )
+ # the fact we can't find the auth event doesn't mean it doesn't
+ # exist, which means it is premature to reject `event`. Instead we
+ # just ignore it for now.
+ return None
auth[(ae.type, ae.state_key)] = ae
- else:
- logger.info("Missing auth event %s", auth_event_id)
- event_infos.append(_NewEventInfo(event, auth))
-
- if not event_infos:
- return
-
- async def prep(ev_info: _NewEventInfo):
- event = ev_info.event
- with nested_logging_context(suffix=event.event_id):
- res = await self._state_handler.compute_event_context(event)
- res = await self._check_event_auth(
+ context = EventContext.for_outlier()
+ context = await self._check_event_auth(
origin,
event,
- res,
- claimed_auth_event_map=ev_info.claimed_auth_event_map,
+ context,
+ claimed_auth_event_map=auth,
)
- return res
+ return event, context
- contexts = await make_deferred_yieldable(
- defer.gatherResults(
- [run_in_background(prep, ev_info) for ev_info in event_infos],
- consumeErrors=True,
- )
- )
-
- await self.persist_events_and_notify(
- room_id,
- [
- (ev_info.event, context)
- for ev_info, context in zip(event_infos, contexts)
- ],
+ events_to_persist = (
+ x for x in await yieldable_gather_results(prep, fetched_events) if x
)
+ await self.persist_events_and_notify(room_id, tuple(events_to_persist))
async def _check_event_auth(
self,
@@ -1269,8 +1244,7 @@ class FederationEventHandler:
claimed_auth_event_map:
A map of (type, state_key) => event for the event's claimed auth_events.
- Possibly incomplete, and possibly including events that are not yet
- persisted, or authed, or in the right room.
+ Possibly including events that were rejected, or are in the wrong room.
Only populated when populating outliers.
@@ -1505,64 +1479,22 @@ class FederationEventHandler:
# If we don't have all the auth events, we need to get them.
logger.info("auth_events contains unknown events: %s", missing_auth)
try:
- try:
- remote_auth_chain = await self._federation_client.get_event_auth(
- origin, event.room_id, event.event_id
- )
- except RequestSendFailed as e1:
- # The other side isn't around or doesn't implement the
- # endpoint, so lets just bail out.
- logger.info("Failed to get event auth from remote: %s", e1)
- return context, auth_events
-
- seen_remotes = await self._store.have_seen_events(
- event.room_id, [e.event_id for e in remote_auth_chain]
+ await self._get_remote_auth_chain_for_event(
+ origin, event.room_id, event.event_id
)
-
- for auth_event in remote_auth_chain:
- if auth_event.event_id in seen_remotes:
- continue
-
- if auth_event.event_id == event.event_id:
- continue
-
- try:
- auth_ids = auth_event.auth_event_ids()
- auth = {
- (e.type, e.state_key): e
- for e in remote_auth_chain
- if e.event_id in auth_ids or e.type == EventTypes.Create
- }
- auth_event.internal_metadata.outlier = True
-
- logger.debug(
- "_check_event_auth %s missing_auth: %s",
- event.event_id,
- auth_event.event_id,
- )
- missing_auth_event_context = (
- await self._state_handler.compute_event_context(auth_event)
- )
-
- missing_auth_event_context = await self._check_event_auth(
- origin,
- auth_event,
- missing_auth_event_context,
- claimed_auth_event_map=auth,
- )
- await self.persist_events_and_notify(
- event.room_id, [(auth_event, missing_auth_event_context)]
- )
-
- if auth_event.event_id in event_auth_events:
- auth_events[
- (auth_event.type, auth_event.state_key)
- ] = auth_event
- except AuthError:
- pass
-
except Exception:
logger.exception("Failed to get auth chain")
+ else:
+ # load any auth events we might have persisted from the database. This
+ # has the side-effect of correctly setting the rejected_reason on them.
+ auth_events.update(
+ {
+ (ae.type, ae.state_key): ae
+ for ae in await self._store.get_events_as_list(
+ missing_auth, allow_rejected=True
+ )
+ }
+ )
if event.internal_metadata.is_outlier():
# XXX: given that, for an outlier, we'll be working with the
@@ -1636,6 +1568,45 @@ class FederationEventHandler:
return context, auth_events
+ async def _get_remote_auth_chain_for_event(
+ self, destination: str, room_id: str, event_id: str
+ ) -> None:
+ """If we are missing some of an event's auth events, attempt to request them
+
+ Args:
+ destination: where to fetch the auth tree from
+ room_id: the room in which we are lacking auth events
+ event_id: the event for which we are lacking auth events
+ """
+ try:
+ remote_event_map = {
+ e.event_id: e
+ for e in await self._federation_client.get_event_auth(
+ destination, room_id, event_id
+ )
+ }
+ except RequestSendFailed as e1:
+ # The other side isn't around or doesn't implement the
+ # endpoint, so let's just bail out.
+ logger.info("Failed to get event auth from remote: %s", e1)
+ return
+
+ logger.info("/event_auth returned %i events", len(remote_event_map))
+
+ # `event` may be returned, but we should not yet process it.
+ remote_event_map.pop(event_id, None)
+
+ # nor should we reprocess any events we have already seen.
+ seen_remotes = await self._store.have_seen_events(
+ room_id, remote_event_map.keys()
+ )
+ for s in seen_remotes:
+ remote_event_map.pop(s, None)
+
+ await self._auth_and_persist_fetched_events(
+ destination, room_id, remote_event_map.values()
+ )
+
async def _update_context_for_auth_events(
self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
) -> EventContext:
@@ -1692,7 +1663,7 @@ class FederationEventHandler:
async def _run_push_actions_and_persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
- ):
+ ) -> None:
"""Run the push actions for a received event, and persist it.
Args:
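
The prep/persist refactor above drops the hand-rolled `defer.gatherResults` plumbing in favour of `yieldable_gather_results`, then filters out the `None` results for events whose auth events could not be found. An asyncio analogue of that shape (Synapse's helper is the Twisted, logcontext-aware equivalent):

    import asyncio
    from typing import Awaitable, Callable, Iterable, List, Optional, Tuple, TypeVar

    T = TypeVar("T")
    R = TypeVar("R")

    async def gather_and_filter(
        func: Callable[[T], Awaitable[Optional[R]]], args: Iterable[T]
    ) -> List[R]:
        # Run func over every item concurrently, then drop the Nones,
        # mirroring how prep() skips (rather than rejects) events whose
        # auth events cannot be found.
        results = await asyncio.gather(*(func(arg) for arg in args))
        return [r for r in results if r is not None]

    async def demo() -> None:
        async def prep(n: int) -> Optional[Tuple[int, int]]:
            return (n, n * n) if n % 2 else None

        print(await gather_and_filter(prep, range(5)))  # [(1, 1), (3, 9)]

    asyncio.run(demo())
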
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 1a6c5c64..9e270d46 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Set
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Iterable, List, Set
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.types import GroupID, JsonDict, get_domain_from_id
@@ -25,12 +25,14 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-def _create_rerouter(func_name):
+def _create_rerouter(func_name: str) -> Callable[..., Awaitable[JsonDict]]:
"""Returns an async function that looks at the group id and calls the function
on federation or the local group server if the group is local
"""
- async def f(self, group_id, *args, **kwargs):
+ async def f(
+ self: "GroupsLocalWorkerHandler", group_id: str, *args: Any, **kwargs: Any
+ ) -> JsonDict:
if not GroupID.is_valid(group_id):
raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
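
The groups_local hunk types `_create_rerouter` as returning `Callable[..., Awaitable[JsonDict]]`: the rerouted functions take arbitrary arguments, and `Callable[..., ...]` is the loosest accurate description short of a ParamSpec. A compact sketch of typing such a factory:

    import asyncio
    from typing import Any, Awaitable, Callable, Dict

    JsonDict = Dict[str, Any]

    def create_rerouter(func_name: str) -> Callable[..., Awaitable[JsonDict]]:
        # Only the return type is statically known; the ellipsis leaves the
        # argument list unchecked, which matches how call sites use it.
        async def f(group_id: str, *args: Any, **kwargs: Any) -> JsonDict:
            return {"handler": func_name, "group_id": group_id}

        return f

    handler = create_rerouter("get_group_profile")
    print(asyncio.run(handler("+group:example.org")))
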
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 8b8f1f41..fe8a9958 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -62,7 +62,7 @@ class IdentityHandler(BaseHandler):
self.federation_http_client = hs.get_federation_http_client()
self.hs = hs
- self._web_client_location = hs.config.invite_client_location
+ self._web_client_location = hs.config.email.invite_client_location
# Ratelimiters for `/requestToken` endpoints.
self._3pid_validation_ratelimiter_ip = Ratelimiter(
@@ -419,7 +419,7 @@ class IdentityHandler(BaseHandler):
token_expires = (
self.hs.get_clock().time_msec()
- + self.hs.config.email_validation_token_lifetime
+ + self.hs.config.email.email_validation_token_lifetime
)
await self.store.start_or_continue_validation_session(
@@ -465,7 +465,7 @@ class IdentityHandler(BaseHandler):
if next_link:
params["next_link"] = next_link
- if self.hs.config.using_identity_server_from_trusted_list:
+ if self.hs.config.email.using_identity_server_from_trusted_list:
# Warn that a deprecated config option is in use
logger.warning(
'The config option "trust_identity_server_for_password_resets" '
@@ -518,7 +518,7 @@ class IdentityHandler(BaseHandler):
if next_link:
params["next_link"] = next_link
- if self.hs.config.using_identity_server_from_trusted_list:
+ if self.hs.config.email.using_identity_server_from_trusted_list:
# Warn that a deprecated config option is in use
logger.warning(
'The config option "trust_identity_server_for_password_resets" '
@@ -572,12 +572,12 @@ class IdentityHandler(BaseHandler):
validation_session = None
# Try to validate as email
- if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
# Ask our delegated email identity server
validation_session = await self.threepid_from_creds(
self.hs.config.account_threepid_delegate_email, threepid_creds
)
- elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ elif self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
# Get a validated session matching these details
validation_session = await self.store.get_threepid_validation_session(
"email", client_secret, sid=sid, validated=True
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 4e8f7f1d..9ad39a65 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
from twisted.internet import defer
@@ -125,7 +125,7 @@ class InitialSyncHandler(BaseHandler):
now_token = self.hs.get_event_sources().get_current_token()
- presence_stream = self.hs.get_event_sources().sources["presence"]
+ presence_stream = self.hs.get_event_sources().sources.presence
presence, _ = await presence_stream.get_new_events(
user, from_key=None, include_offline=False
)
@@ -150,7 +150,7 @@ class InitialSyncHandler(BaseHandler):
if limit is None:
limit = 10
- async def handle_room(event: RoomsForUser):
+ async def handle_room(event: RoomsForUser) -> None:
d: JsonDict = {
"room_id": event.room_id,
"membership": event.membership,
@@ -411,9 +411,9 @@ class InitialSyncHandler(BaseHandler):
presence_handler = self.hs.get_presence_handler()
- async def get_presence():
+ async def get_presence() -> List[JsonDict]:
# If presence is disabled, return an empty list
- if not self.hs.config.use_presence:
+ if not self.hs.config.server.use_presence:
return []
states = await presence_handler.get_states(
@@ -428,7 +428,7 @@ class InitialSyncHandler(BaseHandler):
for s in states
]
- async def get_receipts():
+ async def get_receipts() -> List[JsonDict]:
receipts = await self.store.get_linearized_receipts_for_room(
room_id, to_key=now_token.receipt_key
)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 10f1584a..fd861e94 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -46,6 +46,7 @@ from synapse.events import EventBase
from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
+from synapse.handlers.directory import DirectoryHandler
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
@@ -298,7 +299,7 @@ class MessageHandler:
for user_id, profile in users_with_profile.items()
}
- def maybe_schedule_expiry(self, event: EventBase):
+ def maybe_schedule_expiry(self, event: EventBase) -> None:
"""Schedule the expiry of an event if there's not already one scheduled,
or if the one running is for an event that will expire after the provided
timestamp.
@@ -318,7 +319,7 @@ class MessageHandler:
# a task scheduled for a timestamp that's sooner than the provided one.
self._schedule_expiry_for_event(event.event_id, expiry_ts)
- async def _schedule_next_expiry(self):
+ async def _schedule_next_expiry(self) -> None:
"""Retrieve the ID and the expiry timestamp of the next event to be expired,
and schedule an expiry task for it.
@@ -331,7 +332,7 @@ class MessageHandler:
event_id, expiry_ts = res
self._schedule_expiry_for_event(event_id, expiry_ts)
- def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int):
+ def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int) -> None:
"""Schedule an expiry task for the provided event if there's not already one
scheduled at a timestamp that's sooner than the provided one.
@@ -367,7 +368,7 @@ class MessageHandler:
event_id,
)
- async def _expire_event(self, event_id: str):
+ async def _expire_event(self, event_id: str) -> None:
"""Retrieve and expire an event that needs to be expired from the database.
If the event doesn't exist in the database, log it and delete the expiry date
@@ -442,7 +443,7 @@ class EventCreationHandler:
)
self._block_events_without_consent_error = (
- self.config.block_events_without_consent_error
+ self.config.consent.block_events_without_consent_error
)
# we need to construct a ConsentURIBuilder here, as it checks that the necessary
@@ -665,7 +666,7 @@ class EventCreationHandler:
self.validator.validate_new(event, self.config)
- return (event, context)
+ return event, context
async def _is_exempt_from_privacy_policy(
self, builder: EventBuilder, requester: Requester
@@ -691,10 +692,10 @@ class EventCreationHandler:
return False
async def _is_server_notices_room(self, room_id: str) -> bool:
- if self.config.server_notices_mxid is None:
+ if self.config.servernotices.server_notices_mxid is None:
return False
user_ids = await self.store.get_users_in_room(room_id)
- return self.config.server_notices_mxid in user_ids
+ return self.config.servernotices.server_notices_mxid in user_ids
async def assert_accepted_privacy_policy(self, requester: Requester) -> None:
"""Check if a user has accepted the privacy policy
@@ -730,8 +731,8 @@ class EventCreationHandler:
# exempt the system notices user
if (
- self.config.server_notices_mxid is not None
- and user_id == self.config.server_notices_mxid
+ self.config.servernotices.server_notices_mxid is not None
+ and user_id == self.config.servernotices.server_notices_mxid
):
return
@@ -743,7 +744,7 @@ class EventCreationHandler:
if u["appservice_id"] is not None:
# users registered by an appservice are exempt
return
- if u["consent_version"] == self.config.user_consent_version:
+ if u["consent_version"] == self.config.consent.user_consent_version:
return
consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)
@@ -951,18 +952,13 @@ class EventCreationHandler:
depth=depth,
)
- old_state = None
-
# Pass on the outlier property from the builder to the event
# after it is created
if builder.internal_metadata.outlier:
- event.internal_metadata.outlier = builder.internal_metadata.outlier
-
- # Calculate the state for outliers that pass in their own `auth_event_ids`
- if auth_event_ids:
- old_state = await self.store.get_events_as_list(auth_event_ids)
-
- context = await self.state.compute_event_context(event, old_state=old_state)
+ event.internal_metadata.outlier = True
+ context = EventContext.for_outlier()
+ else:
+ context = await self.state.compute_event_context(event)
if requester:
context.app_service = requester.app_service
@@ -1003,7 +999,7 @@ class EventCreationHandler:
logger.debug("Created event %s", event.event_id)
- return (event, context)
+ return event, context
@measure_func("handle_new_client_event")
async def handle_new_client_event(
@@ -1229,7 +1225,10 @@ class EventCreationHandler:
self._external_cache_joined_hosts_updates[state_entry.state_group] = None
async def _validate_canonical_alias(
- self, directory_handler, room_alias_str: str, expected_room_id: str
+ self,
+ directory_handler: DirectoryHandler,
+ room_alias_str: str,
+ expected_room_id: str,
) -> None:
"""
Ensure that the given room alias points to the expected room ID.
@@ -1421,7 +1420,7 @@ class EventCreationHandler:
# structural protocol level).
is_msc2716_event = (
original_event.type == EventTypes.MSC2716_INSERTION
- or original_event.type == EventTypes.MSC2716_CHUNK
+ or original_event.type == EventTypes.MSC2716_BATCH
or original_event.type == EventTypes.MSC2716_MARKER
)
if not room_version_obj.msc2716_historical and is_msc2716_event:
@@ -1477,7 +1476,7 @@ class EventCreationHandler:
# If there's an expiry timestamp on the event, schedule its expiry.
self._message_handler.maybe_schedule_expiry(event)
- def _notify():
+ def _notify() -> None:
try:
self.notifier.on_new_room_event(
event, event_pos, max_stream_token, extra_users=extra_users
@@ -1523,7 +1522,7 @@ class EventCreationHandler:
except Exception:
logger.exception("Error bumping presence active time")
- async def _send_dummy_events_to_fill_extremities(self):
+ async def _send_dummy_events_to_fill_extremities(self) -> None:
"""Background task to send dummy events into rooms that have a large
number of extremities
"""
@@ -1600,7 +1599,7 @@ class EventCreationHandler:
)
return False
- def _expire_rooms_to_exclude_from_dummy_event_insertion(self):
+ def _expire_rooms_to_exclude_from_dummy_event_insertion(self) -> None:
expire_before = self.clock.time_msec() - _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY
to_expire = set()
for room_id, time in self._rooms_to_exclude_from_dummy_event_insertion.items():
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index dfc251b2..3665d915 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -14,7 +14,7 @@
# limitations under the License.
import inspect
import logging
-from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union
from urllib.parse import urlencode, urlparse
import attr
@@ -249,11 +249,11 @@ class OidcHandler:
class OidcError(Exception):
"""Used to catch errors when calling the token_endpoint"""
- def __init__(self, error, error_description=None):
+ def __init__(self, error: str, error_description: Optional[str] = None):
self.error = error
self.error_description = error_description
- def __str__(self):
+ def __str__(self) -> str:
if self.error_description:
return f"{self.error}: {self.error_description}"
return self.error
@@ -277,7 +277,7 @@ class OidcProvider:
self._token_generator = token_generator
self._config = provider
- self._callback_url: str = hs.config.oidc_callback_url
+ self._callback_url: str = hs.config.oidc.oidc_callback_url
# Calculate the prefix for OIDC callback paths based on the public_baseurl.
# We'll insert this into the Path= parameter of any session cookies we set.
@@ -1057,13 +1057,13 @@ class JwtClientSecret:
self._cached_secret = b""
self._cached_secret_replacement_time = 0
- def __str__(self):
+ def __str__(self) -> str:
# if client_auth_method is client_secret_basic, then ClientAuth.prepare calls
# encode_client_secret_basic, which calls "{}".format(secret), which ends up
# here.
return self._get_secret().decode("ascii")
- def __bytes__(self):
+ def __bytes__(self) -> bytes:
# if client_auth_method is client_secret_post, then ClientAuth.prepare calls
# encode_client_secret_post, which ends up here.
return self._get_secret()
@@ -1197,21 +1197,21 @@ class OidcSessionTokenGenerator:
)
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, auto_attribs=True)
class OidcSessionData:
"""The attributes which are stored in a OIDC session cookie"""
# the Identity Provider being used
- idp_id = attr.ib(type=str)
+ idp_id: str
# The `nonce` parameter passed to the OIDC provider.
- nonce = attr.ib(type=str)
+ nonce: str
# The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
- client_redirect_url = attr.ib(type=str)
+ client_redirect_url: str
# The session ID of the ongoing UI Auth ("" if this is a login)
- ui_auth_session_id = attr.ib(type=str)
+ ui_auth_session_id: str
class UserAttributeDict(TypedDict):
@@ -1290,20 +1290,20 @@ class OidcMappingProvider(Generic[C]):
# Used to clear out "None" values in templates
-def jinja_finalize(thing):
+def jinja_finalize(thing: Any) -> Any:
return thing if thing is not None else ""
env = Environment(finalize=jinja_finalize)
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class JinjaOidcMappingConfig:
- subject_claim = attr.ib(type=str)
- localpart_template = attr.ib(type=Optional[Template])
- display_name_template = attr.ib(type=Optional[Template])
- email_template = attr.ib(type=Optional[Template])
- extra_attributes = attr.ib(type=Dict[str, Template])
+ subject_claim: str
+ localpart_template: Optional[Template]
+ display_name_template: Optional[Template]
+ email_template: Optional[Template]
+ extra_attributes: Dict[str, Template]
class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
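
The OidcSessionData and JinjaOidcMappingConfig hunks convert `attr.ib(type=...)` declarations into plain annotations under `auto_attribs=True`. attrs treats both spellings identically at runtime; the annotated form is simply also visible to mypy. A minimal demonstration:

    import attr

    @attr.s(frozen=True, slots=True)
    class Old:
        idp_id = attr.ib(type=str)

    @attr.s(frozen=True, slots=True, auto_attribs=True)
    class New:
        idp_id: str

    assert Old(idp_id="oidc").idp_id == New(idp_id="oidc").idp_id
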
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 7dc0ee4b..08b93b3e 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -15,6 +15,8 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, Optional, Set
+import attr
+
from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
@@ -24,7 +26,7 @@ from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
-from synapse.types import Requester
+from synapse.types import JsonDict, Requester
from synapse.util.async_helpers import ReadWriteLock
from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client
@@ -36,15 +38,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+@attr.s(slots=True, auto_attribs=True)
class PurgeStatus:
"""Object tracking the status of a purge request
This class contains information on the progress of a purge request, for
return by get_purge_status.
-
- Attributes:
- status (int): Tracks whether this request has completed. One of
- STATUS_{ACTIVE,COMPLETE,FAILED}
"""
STATUS_ACTIVE = 0
@@ -57,10 +56,10 @@ class PurgeStatus:
STATUS_FAILED: "failed",
}
- def __init__(self):
- self.status = PurgeStatus.STATUS_ACTIVE
+ # Tracks whether this request has completed. One of STATUS_{ACTIVE,COMPLETE,FAILED}.
+ status: int = STATUS_ACTIVE
- def asdict(self):
+ def asdict(self) -> JsonDict:
return {"status": PurgeStatus.STATUS_TEXT[self.status]}
@@ -107,7 +106,7 @@ class PaginationHandler:
async def purge_history_for_rooms_in_range(
self, min_ms: Optional[int], max_ms: Optional[int]
- ):
+ ) -> None:
"""Purge outdated events from rooms within the given retention range.
If a default retention policy is defined in the server's configuration and its
@@ -291,7 +290,7 @@ class PaginationHandler:
self._purges_in_progress_by_room.discard(room_id)
# remove the purge from the list 24 hours after it completes
- def clear_purge():
+ def clear_purge() -> None:
del self._purges_by_id[purge_id]
self.hs.get_reactor().callLater(24 * 3600, clear_purge)
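
The PurgeStatus conversion above leans on a detail of `auto_attribs`: un-annotated class attributes (the STATUS_* constants and STATUS_TEXT) are ignored, while the annotated `status` becomes a field with a default, replacing the hand-written `__init__`. A sketch of exactly that split:

    import attr

    @attr.s(slots=True, auto_attribs=True)
    class Status:
        # Un-annotated: ignored by auto_attribs, so these stay class constants.
        STATUS_ACTIVE = 0
        STATUS_COMPLETE = 1

        # Annotated with a default: becomes the only generated-__init__ field.
        status: int = STATUS_ACTIVE

    assert Status().status == Status.STATUS_ACTIVE
    assert Status(Status.STATUS_COMPLETE).status == 1
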
diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py
index cd21efdc..eadd7ced 100644
--- a/synapse/handlers/password_policy.py
+++ b/synapse/handlers/password_policy.py
@@ -27,8 +27,8 @@ logger = logging.getLogger(__name__)
class PasswordPolicyHandler:
def __init__(self, hs: "HomeServer"):
- self.policy = hs.config.password_policy
- self.enabled = hs.config.password_policy_enabled
+ self.policy = hs.config.auth.password_policy
+ self.enabled = hs.config.auth.password_policy_enabled
# Regexps for the spec'd policy parameters.
self.regexp_digit = re.compile("[0-9]")
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 39b39cd3..404afb94 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -26,18 +26,22 @@ import contextlib
import logging
from bisect import bisect
from contextlib import contextmanager
+from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
+ Awaitable,
Callable,
Collection,
Dict,
FrozenSet,
+ Generator,
Iterable,
List,
Optional,
Set,
Tuple,
+ Type,
Union,
)
@@ -48,6 +52,7 @@ import synapse.metrics
from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
+from synapse.appservice import ApplicationService
from synapse.events.presence_router import PresenceRouter
from synapse.logging.context import run_in_background
from synapse.logging.utils import log_function
@@ -61,6 +66,7 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.replication.tcp.commands import ClearUserSyncsCommand
from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
from synapse.storage.databases.main import DataStore
+from synapse.streams import EventSource
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import _CacheContext, cached
@@ -240,7 +246,7 @@ class BasePresenceHandler(abc.ABC):
"""
@abc.abstractmethod
- async def bump_presence_active_time(self, user: UserID):
+ async def bump_presence_active_time(self, user: UserID) -> None:
"""We've seen the user do something that indicates they're interacting
with the app.
"""
@@ -274,7 +280,7 @@ class BasePresenceHandler(abc.ABC):
async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: list
- ):
+ ) -> None:
"""Process streams received over replication."""
await self._federation_queue.process_replication_rows(
stream_name, instance_name, token, rows
@@ -286,7 +292,7 @@ class BasePresenceHandler(abc.ABC):
async def maybe_send_presence_to_interested_destinations(
self, states: List[UserPresenceState]
- ):
+ ) -> None:
"""If this instance is a federation sender, send the states to all
destinations that are interested. Filters out any states for remote
users.
@@ -309,7 +315,7 @@ class BasePresenceHandler(abc.ABC):
for destination, host_states in hosts_to_states.items():
self._federation.send_presence_to_destinations(host_states, [destination])
- async def send_full_presence_to_users(self, user_ids: Collection[str]):
+ async def send_full_presence_to_users(self, user_ids: Collection[str]) -> None:
"""
Adds to the list of users who should receive a full snapshot of presence
upon their next sync. Note that this only works for local users.
@@ -363,7 +369,12 @@ class BasePresenceHandler(abc.ABC):
class _NullContextManager(ContextManager[None]):
"""A context manager which does nothing."""
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
pass
@@ -374,7 +385,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
self._presence_writer_instance = hs.config.worker.writers.presence[0]
- self._presence_enabled = hs.config.use_presence
+ self._presence_enabled = hs.config.server.use_presence
# Route presence EDUs to the right worker
hs.get_federation_registry().register_instances_for_edu(
@@ -468,7 +479,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
if self._user_to_num_current_syncs[user_id] == 1:
self.mark_as_coming_online(user_id)
- def _end():
+ def _end() -> None:
# We check that the user_id is in user_to_num_current_syncs because
# user_to_num_current_syncs may have been cleared if we are
# shutting down.
@@ -480,7 +491,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
self.mark_as_going_offline(user_id)
@contextlib.contextmanager
- def _user_syncing():
+ def _user_syncing() -> Generator[None, None, None]:
try:
yield
finally:
@@ -503,7 +514,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: list
- ):
+ ) -> None:
await super().process_replication_rows(stream_name, instance_name, token, rows)
if stream_name != PresenceStream.NAME:
@@ -584,7 +595,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
user_id = target_user.to_string()
# If presence is disabled, no-op
- if not self.hs.config.use_presence:
+ if not self.hs.config.server.use_presence:
return
# Proxy request to instance that writes presence
@@ -601,7 +612,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
with the app.
"""
# If presence is disabled, no-op
- if not self.hs.config.use_presence:
+ if not self.hs.config.server.use_presence:
return
# Proxy request to instance that writes presence
@@ -618,7 +629,7 @@ class PresenceHandler(BasePresenceHandler):
self.server_name = hs.hostname
self.wheel_timer: WheelTimer[str] = WheelTimer()
self.notifier = hs.get_notifier()
- self._presence_enabled = hs.config.use_presence
+ self._presence_enabled = hs.config.server.use_presence
federation_registry = hs.get_federation_registry()
@@ -689,7 +700,7 @@ class PresenceHandler(BasePresenceHandler):
# Start a LoopingCall in 30s that fires every 5s.
# The initial delay is to allow disconnected clients a chance to
# reconnect before we treat them as offline.
- def run_timeout_handler():
+ def run_timeout_handler() -> Awaitable[None]:
return run_as_background_process(
"handle_presence_timeouts", self._handle_timeouts
)
@@ -698,7 +709,7 @@ class PresenceHandler(BasePresenceHandler):
30, self.clock.looping_call, run_timeout_handler, 5000
)
- def run_persister():
+ def run_persister() -> Awaitable[None]:
return run_as_background_process(
"persist_presence_changes", self._persist_unpersisted_changes
)
@@ -916,7 +927,7 @@ class PresenceHandler(BasePresenceHandler):
with the app.
"""
# If presence is disabled, no-op
- if not self.hs.config.use_presence:
+ if not self.hs.config.server.use_presence:
return
user_id = user.to_string()
@@ -942,14 +953,14 @@ class PresenceHandler(BasePresenceHandler):
when users disconnect/reconnect.
Args:
- user_id (str)
- affect_presence (bool): If false this function will be a no-op.
+ user_id
+ affect_presence: If false this function will be a no-op.
Useful for streams that are not associated with an actual
client that is being used by a user.
"""
# If presence is disabled, override whether this should affect the
# user's presence.
- if not self.hs.config.use_presence:
+ if not self.hs.config.server.use_presence:
affect_presence = False
if affect_presence:
@@ -978,7 +989,7 @@ class PresenceHandler(BasePresenceHandler):
]
)
- async def _end():
+ async def _end() -> None:
try:
self.user_to_num_current_syncs[user_id] -= 1
@@ -994,7 +1005,7 @@ class PresenceHandler(BasePresenceHandler):
logger.exception("Error updating presence after sync")
@contextmanager
- def _user_syncing():
+ def _user_syncing() -> Generator[None, None, None]:
try:
yield
finally:
@@ -1264,7 +1275,7 @@ class PresenceHandler(BasePresenceHandler):
if self._event_processing:
return
- async def _process_presence():
+ async def _process_presence() -> None:
assert not self._event_processing
self._event_processing = True
@@ -1491,7 +1502,7 @@ def format_user_presence_state(
return content
-class PresenceEventSource:
+class PresenceEventSource(EventSource[int, UserPresenceState]):
def __init__(self, hs: "HomeServer"):
# We can't call get_presence_handler here because there's a cycle:
#
@@ -1510,10 +1521,12 @@ class PresenceEventSource:
self,
user: UserID,
from_key: Optional[int],
- room_ids: Optional[List[str]] = None,
- include_offline: bool = True,
+ limit: Optional[int] = None,
+ room_ids: Optional[Collection[str]] = None,
+ is_guest: bool = False,
explicit_room_id: Optional[str] = None,
- **kwargs,
+ include_offline: bool = True,
+ service: Optional[ApplicationService] = None,
) -> Tuple[List[UserPresenceState], int]:
# The process for getting presence events is:
# 1. Get the rooms the user is in.
@@ -2074,7 +2087,7 @@ class PresenceFederationQueue:
if self._queue_presence_updates:
self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS)
- def _clear_queue(self):
+ def _clear_queue(self) -> None:
"""Clear out older entries from the queue."""
clear_before = self._clock.time_msec() - self._KEEP_ITEMS_IN_QUEUE_FOR_MS
@@ -2205,7 +2218,7 @@ class PresenceFederationQueue:
async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: list
- ):
+ ) -> None:
if stream_name != PresenceFederationStream.NAME:
return
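
# A sketch (not shown in this diff) of the `EventSource` base class that
# `PresenceEventSource` now extends: it appears to pin every `get_new_events`
# implementation to one shared signature. Inferred from the call sites in this
# commit; the type-variable names and `room_ids` type are illustrative.
import abc
from typing import Generic, Iterable, List, Optional, Tuple, TypeVar

K = TypeVar("K")  # stream key type, e.g. int or RoomStreamToken
R = TypeVar("R")  # event type, e.g. JsonDict or UserPresenceState

class EventSource(Generic[K, R], metaclass=abc.ABCMeta):
    @abc.abstractmethod
    async def get_new_events(
        self,
        user: "UserID",
        from_key: K,
        limit: Optional[int],
        room_ids: Iterable[str],
        is_guest: bool,
        explicit_room_id: Optional[str] = None,
    ) -> Tuple[List[R], K]:
        ...
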
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 51adf876..b23a1541 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -214,11 +214,10 @@ class ProfileHandler(BaseHandler):
target_user.localpart, displayname_to_set
)
- if self.hs.config.user_directory_search_all_users:
- profile = await self.store.get_profileinfo(target_user.localpart)
- await self.user_directory_handler.handle_local_profile_change(
- target_user.to_string(), profile
- )
+ profile = await self.store.get_profileinfo(target_user.localpart)
+ await self.user_directory_handler.handle_local_profile_change(
+ target_user.to_string(), profile
+ )
await self._update_join_states(requester, target_user)
@@ -254,7 +253,7 @@ class ProfileHandler(BaseHandler):
requester: Requester,
new_avatar_url: str,
by_admin: bool = False,
- ):
+ ) -> None:
"""Set a new avatar URL for a user.
Args:
@@ -300,18 +299,17 @@ class ProfileHandler(BaseHandler):
target_user.localpart, avatar_url_to_set
)
- if self.hs.config.user_directory_search_all_users:
- profile = await self.store.get_profileinfo(target_user.localpart)
- await self.user_directory_handler.handle_local_profile_change(
- target_user.to_string(), profile
- )
+ profile = await self.store.get_profileinfo(target_user.localpart)
+ await self.user_directory_handler.handle_local_profile_change(
+ target_user.to_string(), profile
+ )
await self._update_join_states(requester, target_user)
async def on_profile_query(self, args: JsonDict) -> JsonDict:
"""Handles federation profile query requests."""
- if not self.hs.config.allow_profile_lookup_over_federation:
+ if not self.hs.config.federation.allow_profile_lookup_over_federation:
raise SynapseError(
403,
"Profile lookup over federation is disabled on this homeserver",
@@ -425,7 +423,7 @@ class ProfileHandler(BaseHandler):
raise
@wrap_as_background_process("Update remote profile")
- async def _update_remote_profile_cache(self):
+ async def _update_remote_profile_cache(self) -> None:
"""Called periodically to check profiles of remote users we haven't
checked in a while.
"""
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index a49b8ee4..f21f33ad 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from synapse.api.constants import ReadReceiptEventFields
from synapse.appservice import ApplicationService
from synapse.handlers._base import BaseHandler
+from synapse.streams import EventSource
from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
if TYPE_CHECKING:
@@ -162,7 +163,7 @@ class ReceiptsHandler(BaseHandler):
await self.federation_sender.send_read_receipt(receipt)
-class ReceiptEventSource:
+class ReceiptEventSource(EventSource[int, JsonDict]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.config = hs.config
@@ -216,7 +217,13 @@ class ReceiptEventSource:
return visible_events
async def get_new_events(
- self, from_key: int, room_ids: List[str], user: UserID, **kwargs
+ self,
+ user: UserID,
+ from_key: int,
+ limit: Optional[int],
+ room_ids: Iterable[str],
+ is_guest: bool,
+ explicit_room_id: Optional[str] = None,
) -> Tuple[List[JsonDict], int]:
from_key = int(from_key)
to_key = self.get_current_key()
@@ -231,7 +238,7 @@ class ReceiptEventSource:
if self.config.experimental.msc2285_enabled:
events = ReceiptEventSource.filter_out_hidden(events, user.to_string())
- return (events, to_key)
+ return events, to_key
async def get_new_events_as(
self, from_key: int, service: ApplicationService
@@ -263,7 +270,7 @@ class ReceiptEventSource:
events.append(event)
- return (events, to_key)
+ return events, to_key
def get_current_key(self, direction: str = "f") -> int:
return self.store.get_max_receipt_stream_id()
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 38c4993d..4f99f137 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -97,7 +97,8 @@ class RegistrationHandler(BaseHandler):
self.ratelimiter = hs.get_registration_ratelimiter()
self.macaroon_gen = hs.get_macaroon_generator()
self._account_validity_handler = hs.get_account_validity_handler()
- self._server_notices_mxid = hs.config.server_notices_mxid
+ self._user_consent_version = self.hs.config.consent.user_consent_version
+ self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
self._server_name = hs.hostname
self.spam_checker = hs.get_spam_checker()
@@ -125,7 +126,7 @@ class RegistrationHandler(BaseHandler):
localpart: str,
guest_access_token: Optional[str] = None,
assigned_user_id: Optional[str] = None,
- ):
+ ) -> None:
if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
400,
@@ -295,11 +296,10 @@ class RegistrationHandler(BaseHandler):
shadow_banned=shadow_banned,
)
- if self.hs.config.user_directory_search_all_users:
- profile = await self.store.get_profileinfo(localpart)
- await self.user_directory_handler.handle_local_profile_change(
- user_id, profile
- )
+ profile = await self.store.get_profileinfo(localpart)
+ await self.user_directory_handler.handle_local_profile_change(
+ user_id, profile
+ )
else:
# autogen a sequential user ID
@@ -340,7 +340,7 @@ class RegistrationHandler(BaseHandler):
auth_provider=(auth_provider_id or ""),
).inc()
- if not self.hs.config.user_consent_at_registration:
+ if not self.hs.config.consent.user_consent_at_registration:
if not self.hs.config.auto_join_rooms_for_guests and make_guest:
logger.info(
"Skipping auto-join for %s because auto-join for guests is disabled",
@@ -865,7 +865,9 @@ class RegistrationHandler(BaseHandler):
await self._register_msisdn_threepid(user_id, threepid)
if auth_result and LoginType.TERMS in auth_result:
- await self._on_user_consented(user_id, self.hs.config.user_consent_version)
+ # The terms type should only exist if consent is enabled.
+ assert self._user_consent_version is not None
+ await self._on_user_consented(user_id, self._user_consent_version)
async def _on_user_consented(self, user_id: str, consent_version: str) -> None:
"""A user consented to the terms on registration
@@ -911,8 +913,8 @@ class RegistrationHandler(BaseHandler):
# getting mail spam where they weren't before if email
# notifs are set up on a homeserver)
if (
- self.hs.config.email_enable_notifs
- and self.hs.config.email_notif_for_new_users
+ self.hs.config.email.email_enable_notifs
+ and self.hs.config.email.email_notif_for_new_users
and token
):
# Pull the ID of the access token back out of the db
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 9345ae02..8fede5e9 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1,6 +1,4 @@
-# Copyright 2014 - 2016 OpenMarket Ltd
-# Copyright 2018-2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,7 +20,16 @@ import math
import random
import string
from collections import OrderedDict
-from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Collection,
+ Dict,
+ List,
+ Optional,
+ Tuple,
+)
from synapse.api.constants import (
EventContentFields,
@@ -49,6 +56,7 @@ from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
from synapse.rest.admin._base import assert_user_is_admin
from synapse.storage.state import StateFilter
+from synapse.streams import EventSource
from synapse.types import (
JsonDict,
MutableStateMap,
@@ -118,7 +126,7 @@ class RoomCreationHandler(BaseHandler):
for preset_name, preset_config in self._presets_dict.items():
encrypted = (
preset_name
- in self.config.encryption_enabled_by_default_for_room_presets
+ in self.config.room.encryption_enabled_by_default_for_room_presets
)
preset_config["encrypted"] = encrypted
@@ -133,7 +141,7 @@ class RoomCreationHandler(BaseHandler):
self._upgrade_response_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
)
- self._server_notices_mxid = hs.config.server_notices_mxid
+ self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
self.third_party_event_rules = hs.get_third_party_event_rules()
@@ -186,7 +194,7 @@ class RoomCreationHandler(BaseHandler):
async def _upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
- ):
+ ) -> str:
"""
Args:
requester: the user requesting the upgrade
@@ -512,7 +520,7 @@ class RoomCreationHandler(BaseHandler):
old_room_id: str,
new_room_id: str,
old_room_state: StateMap[str],
- ):
+ ) -> None:
# check to see if we have a canonical alias.
canonical_alias_event = None
canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, ""))
@@ -641,8 +649,16 @@ class RoomCreationHandler(BaseHandler):
requester, config, is_requester_admin=is_requester_admin
)
- if not is_requester_admin and not await self.spam_checker.user_may_create_room(
- user_id
+ invite_3pid_list = config.get("invite_3pid", [])
+ invite_list = config.get("invite", [])
+
+ if not is_requester_admin and not (
+ await self.spam_checker.user_may_create_room(user_id)
+ and await self.spam_checker.user_may_create_room_with_invites(
+ user_id,
+ invite_list,
+ invite_3pid_list,
+ )
):
raise SynapseError(403, "You are not permitted to create rooms")
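
# The second spam-checker callback combined above is new in this commit; a
# module implementing it might look like this sketch (the signature is
# inferred from the call site; the limit of 100 is an illustrative policy,
# not Synapse's default):
from typing import Dict, List

async def user_may_create_room_with_invites(
    user_id: str,
    invites: List[str],
    threepid_invites: List[Dict[str, str]],
) -> bool:
    # Refuse creation when the new room would immediately invite too many users.
    return len(invites) + len(threepid_invites) <= 100
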
@@ -676,8 +692,6 @@ class RoomCreationHandler(BaseHandler):
if mapping:
raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
- invite_3pid_list = config.get("invite_3pid", [])
- invite_list = config.get("invite", [])
for i in invite_list:
try:
uid = UserID.from_string(i)
@@ -749,7 +763,9 @@ class RoomCreationHandler(BaseHandler):
)
if is_public:
- if not self.config.is_publishing_room_allowed(user_id, room_id, room_alias):
+ if not self.config.roomdirectory.is_publishing_room_allowed(
+ user_id, room_id, room_alias
+ ):
# Let's just return a generic message, as there may be all sorts of
# reasons why we said no. TODO: Allow configurable error messages
# per alias creation rule?
@@ -902,7 +918,7 @@ class RoomCreationHandler(BaseHandler):
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
- def create(etype: str, content: JsonDict, **kwargs) -> JsonDict:
+ def create(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict:
e = {"type": etype, "content": content}
e.update(event_keys)
@@ -910,7 +926,7 @@ class RoomCreationHandler(BaseHandler):
return e
- async def send(etype: str, content: JsonDict, **kwargs) -> int:
+ async def send(etype: str, content: JsonDict, **kwargs: Any) -> int:
event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype)
# Allow these events to be sent even if the user is shadow-banned to
@@ -1033,7 +1049,7 @@ class RoomCreationHandler(BaseHandler):
creator_id: str,
is_public: bool,
room_version: RoomVersion,
- ):
+ ) -> str:
# autogen room IDs and try to create the room. We may clash, so just
# try a few times till one goes through, giving up eventually.
attempts = 0
@@ -1097,7 +1113,7 @@ class RoomContextHandler:
users = await self.store.get_users_in_room(room_id)
is_peeking = user.to_string() not in users
- async def filter_evts(events):
+ async def filter_evts(events: List[EventBase]) -> List[EventBase]:
if use_admin_priviledge:
return events
return await filter_events_for_client(
@@ -1175,7 +1191,7 @@ class RoomContextHandler:
return results
-class RoomEventSource:
+class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@@ -1183,8 +1199,8 @@ class RoomEventSource:
self,
user: UserID,
from_key: RoomStreamToken,
- limit: int,
- room_ids: List[str],
+ limit: Optional[int],
+ room_ids: Collection[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[EventBase], RoomStreamToken]:
@@ -1227,7 +1243,7 @@ class RoomEventSource:
else:
end_key = to_key
- return (events, end_key)
+ return events, end_key
def get_current_key(self) -> RoomStreamToken:
return self.store.get_room_max_token()
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 81680b8d..c3d4199e 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -14,7 +14,7 @@
import logging
from collections import namedtuple
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Optional, Tuple
import msgpack
from unpaddedbase64 import decode_base64, encode_base64
@@ -33,7 +33,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.types import JsonDict, ThirdPartyInstanceID
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.caches.response_cache import ResponseCache
from ._base import BaseHandler
@@ -52,7 +52,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
class RoomListHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self.enable_room_list_search = hs.config.enable_room_list_search
+ self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
self.response_cache: ResponseCache[
Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]
] = ResponseCache(hs.get_clock(), "room_list")
@@ -169,7 +169,7 @@ class RoomListHandler(BaseHandler):
ignore_non_federatable=from_federation,
)
- def build_room_entry(room):
+ def build_room_entry(room: JsonDict) -> JsonDict:
entry = {
"room_id": room["room_id"],
"name": room["name"],
@@ -249,10 +249,10 @@ class RoomListHandler(BaseHandler):
self,
room_id: str,
num_joined_users: int,
- cache_context,
+ cache_context: _CacheContext,
with_alias: bool = True,
allow_private: bool = False,
- ) -> Optional[dict]:
+ ) -> Optional[JsonDict]:
"""Returns the entry for a room
Args:
@@ -507,7 +507,7 @@ class RoomListNextBatch(
)
)
- def copy_and_replace(self, **kwds) -> "RoomListNextBatch":
+ def copy_and_replace(self, **kwds: Any) -> "RoomListNextBatch":
return self._replace(**kwds)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 43902016..afa7e472 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -88,7 +88,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.clock = hs.get_clock()
self.spam_checker = hs.get_spam_checker()
self.third_party_event_rules = hs.get_third_party_event_rules()
- self._server_notices_mxid = self.config.server_notices_mxid
+ self._server_notices_mxid = self.config.servernotices.server_notices_mxid
self._enable_lookup = hs.config.enable_3pid_lookup
self.allow_per_room_profiles = self.config.allow_per_room_profiles
@@ -225,7 +225,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
room_id: Optional[str],
n_invites: int,
update: bool = True,
- ):
+ ) -> None:
"""Ratelimit more than one invite sent by the given requester in the given room.
Args:
@@ -249,7 +249,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
requester: Optional[Requester],
room_id: Optional[str],
invitee_user_id: str,
- ):
+ ) -> None:
"""Ratelimit invites by room and by target user.
If room ID is missing then we just rate limit by target user.
@@ -386,7 +386,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
return result_event.event_id, result_event.internal_metadata.stream_ordering
async def copy_room_tags_and_direct_to_room(
- self, old_room_id, new_room_id, user_id
+ self, old_room_id: str, new_room_id: str, user_id: str
) -> None:
"""Copies the tags and direct room state from one room to another.
@@ -573,6 +573,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
errcode=Codes.BAD_JSON,
)
+ # The event content should *not* include the authorising user as
+ # it won't be properly signed. Strip it out since it might come
+ # back from a client updating a display name / avatar.
+ #
+ # This only applies to restricted rooms, but there should be no reason
+ # for a client to include it. Unconditionally remove it.
+ content.pop(EventContentFields.AUTHORISING_USER, None)
+
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
@@ -668,7 +676,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
" (membership=%s)" % old_membership,
errcode=Codes.BAD_STATE,
)
- if old_membership == "ban" and action != "unban":
+ if old_membership == "ban" and action not in ["ban", "unban", "leave"]:
raise SynapseError(
403,
"Cannot %s user who was banned" % (action,),
@@ -939,7 +947,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# be included in the event content in order to efficiently validate
# the event.
content[
- "join_authorised_via_users_server"
+ EventContentFields.AUTHORISING_USER
] = await self.event_auth_handler.get_user_which_could_invite(
room_id,
current_state_ids,
@@ -1030,7 +1038,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
event: EventBase,
context: EventContext,
ratelimit: bool = True,
- ):
+ ) -> None:
"""
Change the membership status of a user in a room.
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index 781da9e8..fb26ee7a 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -541,7 +541,7 @@ class RoomSummaryHandler:
origin: str,
requested_room_id: str,
suggested_only: bool,
- ):
+ ) -> JsonDict:
"""
Implementation of the room hierarchy Federation API.
@@ -1179,4 +1179,4 @@ def _child_events_comparison_key(
order = None
# Items without an order come last.
- return (order is None, order, child.origin_server_ts, child.room_id)
+ return order is None, order, child.origin_server_ts, child.room_id
diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py
index 0066d570..2fed9f37 100644
--- a/synapse/handlers/saml.py
+++ b/synapse/handlers/saml.py
@@ -40,33 +40,32 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class Saml2SessionData:
"""Data we track about SAML2 sessions"""
# time the session was created, in milliseconds
- creation_time = attr.ib()
+ creation_time: int
# The user interactive authentication session ID associated with this SAML
# session (or None if this SAML session is for an initial login).
- ui_auth_session_id = attr.ib(type=Optional[str], default=None)
+ ui_auth_session_id: Optional[str] = None
class SamlHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self._saml_client = Saml2Client(hs.config.saml2_sp_config)
- self._saml_idp_entityid = hs.config.saml2_idp_entityid
+ self._saml_client = Saml2Client(hs.config.saml2.saml2_sp_config)
+ self._saml_idp_entityid = hs.config.saml2.saml2_idp_entityid
- self._saml2_session_lifetime = hs.config.saml2_session_lifetime
+ self._saml2_session_lifetime = hs.config.saml2.saml2_session_lifetime
self._grandfathered_mxid_source_attribute = (
- hs.config.saml2_grandfathered_mxid_source_attribute
+ hs.config.saml2.saml2_grandfathered_mxid_source_attribute
)
self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements
- self._error_template = hs.config.sso_error_template
# plugin to do custom mapping from saml response to mxid
- self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
- hs.config.saml2_user_mapping_provider_config,
+ self._user_mapping_provider = hs.config.saml2.saml2_user_mapping_provider_class(
+ hs.config.saml2.saml2_user_mapping_provider_config,
ModuleApi(hs, hs.get_auth_handler()),
)
@@ -359,7 +358,7 @@ class SamlHandler(BaseHandler):
return remote_user_id
- def expire_sessions(self):
+ def expire_sessions(self) -> None:
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set()
for reqid, data in self._outstanding_requests_dict.items():
@@ -391,10 +390,10 @@ MXID_MAPPER_MAP: Dict[str, Callable[[str], str]] = {
}
-@attr.s
+@attr.s(auto_attribs=True)
class SamlConfig:
- mxid_source_attribute = attr.ib()
- mxid_mapper = attr.ib()
+ mxid_source_attribute: str
+ mxid_mapper: Callable[[str], str]
class DefaultSamlMappingProvider:
@@ -411,7 +410,7 @@ class DefaultSamlMappingProvider:
self._mxid_mapper = parsed_config.mxid_mapper
self._grandfathered_mxid_source_attribute = (
- module_api._hs.config.saml2_grandfathered_mxid_source_attribute
+ module_api._hs.config.saml2.saml2_grandfathered_mxid_source_attribute
)
def get_remote_user_id(
diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py
index a31fe3e3..25e6b012 100644
--- a/synapse/handlers/send_email.py
+++ b/synapse/handlers/send_email.py
@@ -17,7 +17,7 @@ import logging
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from io import BytesIO
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional
from pkg_resources import parse_version
@@ -79,7 +79,7 @@ async def _sendmail(
msg = BytesIO(msg_bytes)
d: "Deferred[object]" = Deferred()
- def build_sender_factory(**kwargs) -> ESMTPSenderFactory:
+ def build_sender_factory(**kwargs: Any) -> ESMTPSenderFactory:
return ESMTPSenderFactory(
username,
password,
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 05aa76d6..49fde01c 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -184,15 +184,17 @@ class SsoHandler:
self._server_name = hs.hostname
self._registration_handler = hs.get_registration_handler()
self._auth_handler = hs.get_auth_handler()
- self._error_template = hs.config.sso_error_template
- self._bad_user_template = hs.config.sso_auth_bad_user_template
+ self._error_template = hs.config.sso.sso_error_template
+ self._bad_user_template = hs.config.sso.sso_auth_bad_user_template
self._profile_handler = hs.get_profile_handler()
# The following template is shown after a successful user interactive
# authentication session. It tells the user they can close the window.
- self._sso_auth_success_template = hs.config.sso_auth_success_template
+ self._sso_auth_success_template = hs.config.sso.sso_auth_success_template
- self._sso_update_profile_information = hs.config.sso_update_profile_information
+ self._sso_update_profile_information = (
+ hs.config.sso.sso_update_profile_information
+ )
# a lock on the mappings
self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
@@ -205,7 +207,7 @@ class SsoHandler:
self._consent_at_registration = hs.config.consent.user_consent_at_registration
- def register_identity_provider(self, p: SsoIdentityProvider):
+ def register_identity_provider(self, p: SsoIdentityProvider) -> None:
p_id = p.idp_id
assert p_id not in self._identity_providers
self._identity_providers[p_id] = p
@@ -856,7 +858,7 @@ class SsoHandler:
async def handle_terms_accepted(
self, request: Request, session_id: str, terms_version: str
- ):
+ ) -> None:
"""Handle a request to the new-user 'consent' endpoint
Will serve an HTTP response to the request.
@@ -959,7 +961,7 @@ class SsoHandler:
new_user=True,
)
- def _expire_old_sessions(self):
+ def _expire_old_sessions(self) -> None:
to_expire = []
now = int(self._clock.time_msec())
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index b64ce8ca..bd3e6f2e 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -46,7 +46,7 @@ class StatsHandler:
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
- self.stats_enabled = hs.config.stats_enabled
+ self.stats_enabled = hs.config.stats.stats_enabled
# The current position in the current_state_delta stream
self.pos: Optional[int] = None
@@ -68,7 +68,7 @@ class StatsHandler:
self._is_processing = True
- async def process():
+ async def process() -> None:
try:
await self._unsafe_process()
finally:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index edfdb99c..2c7c6d63 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -364,7 +364,9 @@ class SyncHandler:
)
else:
- async def current_sync_callback(before_token, after_token) -> SyncResult:
+ async def current_sync_callback(
+ before_token: StreamToken, after_token: StreamToken
+ ) -> SyncResult:
return await self.current_sync_for_user(sync_config, since_token)
result = await self.notifier.wait_for_events(
@@ -441,7 +443,7 @@ class SyncHandler:
room_ids = sync_result_builder.joined_room_ids
- typing_source = self.event_sources.sources["typing"]
+ typing_source = self.event_sources.sources.typing
typing, typing_key = await typing_source.get_new_events(
user=sync_config.user,
from_key=typing_key,
@@ -463,7 +465,7 @@ class SyncHandler:
receipt_key = since_token.receipt_key if since_token else 0
- receipt_source = self.event_sources.sources["receipt"]
+ receipt_source = self.event_sources.sources.receipt
receipts, receipt_key = await receipt_source.get_new_events(
user=sync_config.user,
from_key=receipt_key,
@@ -1090,7 +1092,7 @@ class SyncHandler:
block_all_presence_data = (
since_token is None and sync_config.filter_collection.blocks_all_presence()
)
- if self.hs_config.use_presence and not block_all_presence_data:
+ if self.hs_config.server.use_presence and not block_all_presence_data:
logger.debug("Fetching presence data")
await self._generate_sync_entry_for_presence(
sync_result_builder,
@@ -1413,7 +1415,7 @@ class SyncHandler:
sync_config = sync_result_builder.sync_config
user = sync_result_builder.sync_config.user
- presence_source = self.event_sources.sources["presence"]
+ presence_source = self.event_sources.sources.presence
since_token = sync_result_builder.since_token
presence_key = None
@@ -1532,9 +1534,9 @@ class SyncHandler:
newly_joined_rooms = room_changes.newly_joined_rooms
newly_left_rooms = room_changes.newly_left_rooms
- async def handle_room_entries(room_entry: "RoomSyncResultBuilder"):
+ async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None:
logger.debug("Generating room entry for %s", room_entry.room_id)
- res = await self._generate_room_entry(
+ await self._generate_room_entry(
sync_result_builder,
ignored_users,
room_entry,
@@ -1544,7 +1546,6 @@ class SyncHandler:
always_include=sync_result_builder.full_state,
)
logger.debug("Generated room entry for %s", room_entry.room_id)
- return res
await concurrently_execute(handle_room_entries, room_entries, 10)
@@ -1925,7 +1926,7 @@ class SyncHandler:
tags: Optional[Dict[str, Dict[str, Any]]],
account_data: Dict[str, JsonDict],
always_include: bool = False,
- ):
+ ) -> None:
"""Populates the `joined` and `archived` section of `sync_result_builder`
based on the `room_builder`.
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 9cea011e..d10e9b8e 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -23,6 +23,7 @@ from synapse.metrics.background_process_metrics import (
wrap_as_background_process,
)
from synapse.replication.tcp.streams import TypingStream
+from synapse.streams import EventSource
from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.metrics import Measure
@@ -439,7 +440,7 @@ class TypingWriterHandler(FollowerTypingHandler):
raise Exception("Typing writer instance got typing info over replication")
-class TypingNotificationEventSource:
+class TypingNotificationEventSource(EventSource[int, JsonDict]):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.clock = hs.get_clock()
@@ -482,10 +483,16 @@ class TypingNotificationEventSource:
events.append(self._make_event_for(room_id))
- return (events, handler._latest_room_serial)
+ return events, handler._latest_room_serial
async def get_new_events(
- self, from_key: int, room_ids: Iterable[str], **kwargs
+ self,
+ user: UserID,
+ from_key: int,
+ limit: Optional[int],
+ room_ids: Iterable[str],
+ is_guest: bool,
+ explicit_room_id: Optional[str] = None,
) -> Tuple[List[JsonDict], int]:
with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key)
@@ -500,7 +507,7 @@ class TypingNotificationEventSource:
events.append(self._make_event_for(room_id))
- return (events, handler._latest_room_serial)
+ return events, handler._latest_room_serial
def get_current_key(self) -> int:
return self.get_typing_handler()._latest_room_serial
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index d3828dec..8f5d465f 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -70,7 +70,7 @@ class DummyAuthChecker(UserInteractiveAuthChecker):
class TermsAuthChecker(UserInteractiveAuthChecker):
AUTH_TYPE = LoginType.TERMS
- def is_enabled(self):
+ def is_enabled(self) -> bool:
return True
async def check_auth(self, authdict: dict, clientip: str) -> Any:
@@ -82,10 +82,10 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self._enabled = bool(hs.config.recaptcha_private_key)
+ self._enabled = bool(hs.config.captcha.recaptcha_private_key)
self._http_client = hs.get_proxied_http_client()
- self._url = hs.config.recaptcha_siteverify_api
- self._secret = hs.config.recaptcha_private_key
+ self._url = hs.config.captcha.recaptcha_siteverify_api
+ self._secret = hs.config.captcha.recaptcha_private_key
def is_enabled(self) -> bool:
return self._enabled
@@ -161,12 +161,17 @@ class _BaseThreepidAuthChecker:
self.hs.config.account_threepid_delegate_msisdn, threepid_creds
)
elif medium == "email":
- if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ if (
+ self.hs.config.email.threepid_behaviour_email
+ == ThreepidBehaviour.REMOTE
+ ):
assert self.hs.config.account_threepid_delegate_email
threepid = await identity_handler.threepid_from_creds(
self.hs.config.account_threepid_delegate_email, threepid_creds
)
- elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ elif (
+ self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL
+ ):
threepid = None
row = await self.store.get_threepid_validation_session(
medium,
@@ -218,7 +223,7 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec
_BaseThreepidAuthChecker.__init__(self, hs)
def is_enabled(self) -> bool:
- return self.hs.config.threepid_behaviour_email in (
+ return self.hs.config.email.threepid_behaviour_email in (
ThreepidBehaviour.REMOTE,
ThreepidBehaviour.LOCAL,
)
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 6faa1d84..b91e7cb5 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -61,7 +61,7 @@ class UserDirectoryHandler(StateDeltasHandler):
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
self.update_user_directory = hs.config.update_user_directory
- self.search_all_users = hs.config.user_directory_search_all_users
+ self.search_all_users = hs.config.userdirectory.user_directory_search_all_users
self.spam_checker = hs.get_spam_checker()
# The current position in the current_state_delta stream
self.pos: Optional[int] = None
@@ -114,7 +114,7 @@ class UserDirectoryHandler(StateDeltasHandler):
if self._is_processing:
return
- async def process():
+ async def process() -> None:
try:
await self._unsafe_process()
finally:
diff --git a/synapse/http/client.py b/synapse/http/client.py
index c2ea51ee..5204c3d0 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -321,8 +321,11 @@ class SimpleHttpClient:
self.user_agent = hs.version_string
self.clock = hs.get_clock()
- if hs.config.user_agent_suffix:
- self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix)
+ if hs.config.server.user_agent_suffix:
+ self.user_agent = "%s %s" % (
+ self.user_agent,
+ hs.config.server.user_agent_suffix,
+ )
# We use this for our body producers to ensure that they use the correct
# reactor.
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 2e989899..cdc36b8d 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -66,7 +66,7 @@ from synapse.http.client import (
)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.logging import opentracing
-from synapse.logging.context import make_deferred_yieldable
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.types import JsonDict
from synapse.util import json_decoder
@@ -465,8 +465,9 @@ class MatrixFederationHttpClient:
_sec_timeout = self.default_timeout
if (
- self.hs.config.federation_domain_whitelist is not None
- and request.destination not in self.hs.config.federation_domain_whitelist
+ self.hs.config.federation.federation_domain_whitelist is not None
+ and request.destination
+ not in self.hs.config.federation.federation_domain_whitelist
):
raise FederationDeniedError(request.destination)
@@ -553,20 +554,29 @@ class MatrixFederationHttpClient:
with Measure(self.clock, "outbound_request"):
# we don't want all the fancy cookie and redirect handling
# that treq.request gives: just use the raw Agent.
- request_deferred = self.agent.request(
+
+ # To preserve the logging context, the timeout is treated
+ # in a similar way to `defer.gatherResults`:
+ # * Each logging context-preserving fork is wrapped in
+ # `run_in_background`. In this case there is only one,
+ # since the timeout fork is not logging-context aware.
+ # * The `Deferred` that joins the forks back together is
+ # wrapped in `make_deferred_yieldable` to restore the
+ # logging context regardless of the path taken.
+ request_deferred = run_in_background(
+ self.agent.request,
method_bytes,
url_bytes,
headers=Headers(headers_dict),
bodyProducer=producer,
)
-
request_deferred = timeout_deferred(
request_deferred,
timeout=_sec_timeout,
reactor=self.reactor,
)
- response = await request_deferred
+ response = await make_deferred_yieldable(request_deferred)
except DNSLookupError as e:
raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e
except Exception as e:
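
# The multi-fork case that the `defer.gatherResults` comparison above refers
# to follows the same shape; a minimal sketch, where `calls` is a hypothetical
# list of zero-argument callables returning coroutines or Deferreds:
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background

async def gather_preserving_logcontext(calls):
    # Each fork runs in its own logcontext; awaiting the joined Deferred via
    # make_deferred_yieldable restores the caller's context afterwards.
    deferreds = [run_in_background(c) for c in calls]
    return await make_deferred_yieldable(
        defer.gatherResults(deferreds, consumeErrors=True)
    )
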
@@ -1177,7 +1187,7 @@ class MatrixFederationHttpClient:
request.method,
request.uri.decode("ascii"),
)
- return (length, headers)
+ return length, headers
def _flatten_response_never_received(e):
diff --git a/synapse/http/server.py b/synapse/http/server.py
index b79fa722..0df1bfbe 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -21,7 +21,6 @@ import types
import urllib
from http import HTTPStatus
from inspect import isawaitable
-from io import BytesIO
from typing import (
Any,
Awaitable,
@@ -37,7 +36,7 @@ from typing import (
)
import jinja2
-from canonicaljson import iterencode_canonical_json
+from canonicaljson import encode_canonical_json
from typing_extensions import Protocol
from zope.interface import implementer
@@ -45,7 +44,7 @@ from twisted.internet import defer, interfaces
from twisted.python import failure
from twisted.web import resource
from twisted.web.server import NOT_DONE_YET, Request
-from twisted.web.static import File, NoRangeStaticProducer
+from twisted.web.static import File
from twisted.web.util import redirectTo
from synapse.api.errors import (
@@ -56,10 +55,11 @@ from synapse.api.errors import (
UnrecognizedRequestError,
)
from synapse.http.site import SynapseRequest
-from synapse.logging.context import preserve_fn
+from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background
from synapse.logging.opentracing import trace_servlet
from synapse.util import json_encoder
from synapse.util.caches import intern_dict
+from synapse.util.iterutils import chunk_seq
logger = logging.getLogger(__name__)
@@ -320,7 +320,7 @@ class DirectServeJsonResource(_AsyncResource):
def _send_response(
self,
- request: Request,
+ request: SynapseRequest,
code: int,
response_object: Any,
):
@@ -561,9 +561,17 @@ class _ByteProducer:
self._iterator = iterator
self._paused = False
- # Register the producer and start producing data.
- self._request.registerProducer(self, True)
- self.resumeProducing()
+ try:
+ self._request.registerProducer(self, True)
+ except RuntimeError as e:
+ logger.info("Connection disconnected before response was written: %r", e)
+
+ # We drop our references to data we'll not use.
+ self._request = None
+ self._iterator = iter(())
+ else:
+ # Start producing if `registerProducer` was successful
+ self.resumeProducing()
def _send_data(self, data: List[bytes]) -> None:
"""
@@ -620,16 +628,15 @@ class _ByteProducer:
self._request = None
-def _encode_json_bytes(json_object: Any) -> Iterator[bytes]:
+def _encode_json_bytes(json_object: Any) -> bytes:
"""
- Encode an object into JSON. Returns an iterator of bytes.
+ Encode an object into JSON. Returns the encoded bytes.
"""
- for chunk in json_encoder.iterencode(json_object):
- yield chunk.encode("utf-8")
+ return json_encoder.encode(json_object).encode("utf-8")
def respond_with_json(
- request: Request,
+ request: SynapseRequest,
code: int,
json_object: Any,
send_cors: bool = False,
@@ -659,7 +666,7 @@ def respond_with_json(
return None
if canonical_json:
- encoder = iterencode_canonical_json
+ encoder = encode_canonical_json
else:
encoder = _encode_json_bytes
@@ -670,7 +677,9 @@ def respond_with_json(
if send_cors:
set_cors_headers(request)
- _ByteProducer(request, encoder(json_object))
+ run_in_background(
+ _async_write_json_to_request_in_thread, request, encoder, json_object
+ )
return NOT_DONE_YET
@@ -706,15 +715,56 @@ def respond_with_json_bytes(
if send_cors:
set_cors_headers(request)
- # note that this is zero-copy (the bytesio shares a copy-on-write buffer with
- # the original `bytes`).
- bytes_io = BytesIO(json_bytes)
-
- producer = NoRangeStaticProducer(request, bytes_io)
- producer.start()
+ _write_bytes_to_request(request, json_bytes)
return NOT_DONE_YET
+async def _async_write_json_to_request_in_thread(
+ request: SynapseRequest,
+ json_encoder: Callable[[Any], bytes],
+ json_object: Any,
+):
+ """Encodes the given JSON object on a thread and then writes it to the
+ request.
+
+ This is done so that encoding large JSON objects doesn't block the reactor
+ thread.
+
+ Note: We don't use JsonEncoder.iterencode here as that falls back to the
+ Python implementation (rather than the C backend), which is *much* more
+ expensive.
+ """
+
+ json_str = await defer_to_thread(request.reactor, json_encoder, json_object)
+
+ _write_bytes_to_request(request, json_str)
+
+
+def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None:
+ """Writes the bytes to the request using an appropriate producer.
+
+ Note: This should be used instead of `Request.write` to correctly handle
+ large response bodies.
+ """
+
+ # The problem with dumping all of the response into the `Request` object at
+ # once (via `Request.write`) is that doing so starts the timeout for the
+ # next request to be received: so if it takes longer than 60s to stream back
+ # the response to the client, the client never gets it.
+ #
+ # The correct solution is to use a Producer; then the timeout is only
+ # started once all of the content is sent over the TCP connection.
+
+ # To make sure we don't write all of the bytes at once we split them up into
+ # chunks.
+ chunk_size = 4096
+ bytes_generator = chunk_seq(bytes_to_write, chunk_size)
+
+ # We use a `_ByteProducer` here rather than `NoRangeStaticProducer` as the
+ # unit tests can't cope with being given a pull producer.
+ _ByteProducer(request, bytes_generator)
+
+
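
# For reference, `chunk_seq` (imported above from `synapse.util.iterutils`) is
# assumed to behave roughly like this sketch: lazily slice a sequence into
# pieces of at most `maxlen` items.
def chunk_seq(iseq, maxlen):
    return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen))
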
def set_cors_headers(request: Request):
"""Set the CORS headers so that javascript running in a web browsers can
use this API
diff --git a/synapse/http/site.py b/synapse/http/site.py
index c665a9d5..755ad566 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,14 +14,15 @@
import contextlib
import logging
import time
-from typing import Optional, Tuple, Union
+from typing import Generator, Optional, Tuple, Union
import attr
from zope.interface import implementer
from twisted.internet.interfaces import IAddress, IReactorTime
from twisted.python.failure import Failure
-from twisted.web.resource import IResource
+from twisted.web.http import HTTPChannel
+from twisted.web.resource import IResource, Resource
from twisted.web.server import Request, Site
from synapse.config.server import ListenerConfig
@@ -61,10 +62,18 @@ class SynapseRequest(Request):
logcontext: the log context for this request
"""
- def __init__(self, channel, *args, max_request_body_size=1024, **kw):
- Request.__init__(self, channel, *args, **kw)
+ def __init__(
+ self,
+ channel: HTTPChannel,
+ site: "SynapseSite",
+ *args,
+ max_request_body_size: int = 1024,
+ **kw,
+ ):
+ super().__init__(channel, *args, **kw)
self._max_request_body_size = max_request_body_size
- self.site: SynapseSite = channel.site
+ self.synapse_site = site
+ self.reactor = site.reactor
self._channel = channel # this is used by the tests
self.start_time = 0.0
@@ -83,13 +92,13 @@ class SynapseRequest(Request):
self._is_processing = False
# the time when the asynchronous request handler completed its processing
- self._processing_finished_time = None
+ self._processing_finished_time: Optional[float] = None
# what time we finished sending the response to the client (or the connection
# dropped)
- self.finish_time = None
+ self.finish_time: Optional[float] = None
- def __repr__(self):
+ def __repr__(self) -> str:
# We override this so that we don't log ``access_token``
return "<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>" % (
self.__class__.__name__,
@@ -97,10 +106,10 @@ class SynapseRequest(Request):
self.get_method(),
self.get_redacted_uri(),
self.clientproto.decode("ascii", errors="replace"),
- self.site.site_tag,
+ self.synapse_site.site_tag,
)
- def handleContentChunk(self, data):
+ def handleContentChunk(self, data: bytes) -> None:
# we should have a `content` by now.
assert self.content, "handleContentChunk() called before gotLength()"
if self.content.tell() + len(data) > self._max_request_body_size:
@@ -139,7 +148,7 @@ class SynapseRequest(Request):
# If there's no authenticated entity, it was the requester.
self.logcontext.request.authenticated_entity = authenticated_entity or requester
- def get_request_id(self):
+ def get_request_id(self) -> str:
return "%s-%i" % (self.get_method(), self.request_seq)
def get_redacted_uri(self) -> str:
@@ -205,7 +214,7 @@ class SynapseRequest(Request):
return None, None
- def render(self, resrc):
+ def render(self, resrc: Resource) -> None:
# this is called once a Resource has been found to serve the request; in our
# case the Resource in question will normally be a JsonResource.
@@ -216,7 +225,7 @@ class SynapseRequest(Request):
request=ContextRequest(
request_id=request_id,
ip_address=self.getClientIP(),
- site_tag=self.site.site_tag,
+ site_tag=self.synapse_site.site_tag,
# The requester is going to be unknown at this point.
requester=None,
authenticated_entity=None,
@@ -228,7 +237,7 @@ class SynapseRequest(Request):
)
# override the Server header which is set by twisted
- self.setHeader("Server", self.site.server_version_string)
+ self.setHeader("Server", self.synapse_site.server_version_string)
with PreserveLoggingContext(self.logcontext):
# we start the request metrics timer here with an initial stab
@@ -247,7 +256,7 @@ class SynapseRequest(Request):
requests_counter.labels(self.get_method(), self.request_metrics.name).inc()
@contextlib.contextmanager
- def processing(self):
+ def processing(self) -> Generator[None, None, None]:
"""Record the fact that we are processing this request.
Returns a context manager; the correct way to use this is:
@@ -282,7 +291,7 @@ class SynapseRequest(Request):
if self.finish_time is not None:
self._finished_processing()
- def finish(self):
+ def finish(self) -> None:
"""Called when all response data has been written to this Request.
Overrides twisted.web.server.Request.finish to record the finish time and do
@@ -295,7 +304,7 @@ class SynapseRequest(Request):
with PreserveLoggingContext(self.logcontext):
self._finished_processing()
- def connectionLost(self, reason):
+ def connectionLost(self, reason: Union[Failure, Exception]) -> None:
"""Called when the client connection is closed before the response is written.
Overrides twisted.web.server.Request.connectionLost to record the finish time and
@@ -327,7 +336,7 @@ class SynapseRequest(Request):
if not self._is_processing:
self._finished_processing()
- def _started_processing(self, servlet_name):
+ def _started_processing(self, servlet_name: str) -> None:
"""Record the fact that we are processing this request.
This will log the request's arrival. Once the request completes,
@@ -346,17 +355,19 @@ class SynapseRequest(Request):
self.start_time, name=servlet_name, method=self.get_method()
)
- self.site.access_logger.debug(
+ self.synapse_site.access_logger.debug(
"%s - %s - Received request: %s %s",
self.getClientIP(),
- self.site.site_tag,
+ self.synapse_site.site_tag,
self.get_method(),
self.get_redacted_uri(),
)
- def _finished_processing(self):
+ def _finished_processing(self) -> None:
"""Log the completion of this request and update the metrics"""
assert self.logcontext is not None
+ assert self.finish_time is not None
+
usage = self.logcontext.get_resource_usage()
if self._processing_finished_time is None:
@@ -386,13 +397,13 @@ class SynapseRequest(Request):
if authenticated_entity:
requester = f"{authenticated_entity}|{requester}"
- self.site.access_logger.log(
+ self.synapse_site.access_logger.log(
log_level,
"%s - %s - {%s}"
" Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
' %sB %s "%s %s %s" "%s" [%d dbevts]',
self.getClientIP(),
- self.site.site_tag,
+ self.synapse_site.site_tag,
requester,
processing_time,
response_send_time,
@@ -437,7 +448,7 @@ class XForwardedForRequest(SynapseRequest):
_forwarded_for: "Optional[_XForwardedForAddress]" = None
_forwarded_https: bool = False
- def requestReceived(self, command, path, version):
+ def requestReceived(self, command: bytes, path: bytes, version: bytes) -> None:
# this method is called by the Channel once the full request has been
# received, to dispatch the request to a resource.
# We can use it to set the IP address and protocol according to the
@@ -445,7 +456,7 @@ class XForwardedForRequest(SynapseRequest):
self._process_forwarded_headers()
return super().requestReceived(command, path, version)
- def _process_forwarded_headers(self):
+ def _process_forwarded_headers(self) -> None:
headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for")
if not headers:
return
@@ -470,7 +481,7 @@ class XForwardedForRequest(SynapseRequest):
)
self._forwarded_https = True
- def isSecure(self):
+ def isSecure(self) -> bool:
if self._forwarded_https:
return True
return super().isSecure()
@@ -520,7 +531,7 @@ class SynapseSite(Site):
site_tag: str,
config: ListenerConfig,
resource: IResource,
- server_version_string,
+ server_version_string: str,
max_request_body_size: int,
reactor: IReactorTime,
):
@@ -540,19 +551,23 @@ class SynapseSite(Site):
Site.__init__(self, resource, reactor=reactor)
self.site_tag = site_tag
+ self.reactor = reactor
assert config.http_options is not None
proxied = config.http_options.x_forwarded
request_class = XForwardedForRequest if proxied else SynapseRequest
- def request_factory(channel, queued) -> Request:
+ def request_factory(channel, queued: bool) -> Request:
return request_class(
- channel, max_request_body_size=max_request_body_size, queued=queued
+ channel,
+ self,
+ max_request_body_size=max_request_body_size,
+ queued=queued,
)
self.requestFactory = request_factory # type: ignore
self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode("ascii")
- def log(self, request):
+ def log(self, request: SynapseRequest) -> None:
pass
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index c6c4d3bd..03d2dd94 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -363,7 +363,7 @@ def noop_context_manager(*args, **kwargs):
def init_tracer(hs: "HomeServer"):
"""Set the whitelists and initialise the JaegerClient tracer"""
global opentracing
- if not hs.config.opentracer_enabled:
+ if not hs.config.tracing.opentracer_enabled:
# We don't have a tracer
opentracing = None
return
@@ -377,12 +377,12 @@ def init_tracer(hs: "HomeServer"):
# Pull out the jaeger config if it was given. Otherwise set it to something sensible.
# See https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/config.py
- set_homeserver_whitelist(hs.config.opentracer_whitelist)
+ set_homeserver_whitelist(hs.config.tracing.opentracer_whitelist)
from jaeger_client.metrics.prometheus import PrometheusMetricsFactory
config = JaegerConfig(
- config=hs.config.jaeger_config,
+ config=hs.config.tracing.jaeger_config,
service_name=f"{hs.config.server.server_name} {hs.get_instance_name()}",
scope_manager=LogContextScopeManager(hs.config),
metrics_factory=PrometheusMetricsFactory(),
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 2d403532..8ae21bc4 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -24,8 +24,10 @@ from typing import (
List,
Optional,
Tuple,
+ Union,
)
+import attr
import jinja2
from twisted.internet import defer
@@ -46,7 +48,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.roommember import ProfileInfo
from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, Requester, UserID, UserInfo, create_requester
+from synapse.types import (
+ DomainSpecificString,
+ JsonDict,
+ Requester,
+ UserID,
+ UserInfo,
+ create_requester,
+)
from synapse.util import Clock
from synapse.util.caches.descriptors import cached
@@ -79,6 +88,18 @@ __all__ = [
logger = logging.getLogger(__name__)
+@attr.s(auto_attribs=True)
+class UserIpAndAgent:
+ """
+ An IP address and user agent used by a user to connect to this homeserver.
+ """
+
+ ip: str
+ user_agent: str
+ # The time at which this user agent/ip was last seen.
+ last_seen: int
+
+
class ModuleApi:
"""A proxy object that gets passed to various plugin modules so they
can register new users etc if necessary.
@@ -91,21 +112,23 @@ class ModuleApi:
self._auth = hs.get_auth()
self._auth_handler = auth_handler
self._server_name = hs.hostname
- self._presence_stream = hs.get_event_sources().sources["presence"]
+ self._presence_stream = hs.get_event_sources().sources.presence
self._state = hs.get_state_handler()
self._clock: Clock = hs.get_clock()
self._send_email_handler = hs.get_send_email_handler()
self.custom_template_dir = hs.config.server.custom_template_directory
try:
- app_name = self._hs.config.email_app_name
+ app_name = self._hs.config.email.email_app_name
- self._from_string = self._hs.config.email_notif_from % {"app": app_name}
+ self._from_string = self._hs.config.email.email_notif_from % {
+ "app": app_name
+ }
except (KeyError, TypeError):
# If substitution failed (which can happen if the string contains
# placeholders other than just "app", or if the type of the placeholder is
# not a string), fall back to the bare strings.
- self._from_string = self._hs.config.email_notif_from
+ self._from_string = self._hs.config.email.email_notif_from
self._raw_from = email.utils.parseaddr(self._from_string)[1]
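
# Illustrative values only: with `email.notif_from` set to
# "Your friendly %(app)s homeserver <noreply@example.com>" and an app_name of
# "Matrix", `_from_string` becomes
# "Your friendly Matrix homeserver <noreply@example.com>" and `_raw_from`
# parses out "noreply@example.com".
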
@@ -700,6 +723,65 @@ class ModuleApi:
(td for td in (self.custom_template_dir, custom_template_directory) if td),
)
+ def is_mine(self, id: Union[str, DomainSpecificString]) -> bool:
+ """
+ Checks whether an ID (user id, room, ...) comes from this homeserver.
+
+ Args:
+ id: any Matrix id (e.g. user id, room id, ...), either as a raw id,
+ e.g. string "@user:example.com" or as a parsed UserID, RoomID, ...
+ Returns:
+ True if id comes from this homeserver, False otherwise.
+
+ Added in Synapse v1.44.0.
+ """
+ if isinstance(id, DomainSpecificString):
+ return self._hs.is_mine(id)
+ else:
+ return self._hs.is_mine_id(id)
+
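
# Hypothetical module code using the helper above; `api` is the ModuleApi
# instance handed to the module. Both raw and parsed IDs are accepted.
from synapse.types import RoomID

def involves_only_local_entities(api, user_id: str, room_id: str) -> bool:
    return api.is_mine(user_id) and api.is_mine(RoomID.from_string(room_id))
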
+ async def get_user_ip_and_agents(
+ self, user_id: str, since_ts: int = 0
+ ) -> List[UserIpAndAgent]:
+ """
+ Return the list of user IPs and agents for a user.
+
+ Args:
+ user_id: the id of a user, local or remote
+ since_ts: a timestamp in seconds since the epoch,
+ or the epoch itself if not specified.
+ Returns:
+ The list of all UserIpAndAgent that the user has
+ used to connect to this homeserver since `since_ts`.
+ If the user is remote, this list is empty.
+
+ Added in Synapse v1.44.0.
+ """
+ # Don't hit the db if this is not a local user.
+ is_mine = False
+ try:
+ # Let's be defensive against ill-formed strings.
+ if self.is_mine(user_id):
+ is_mine = True
+ except Exception:
+ pass
+
+ if is_mine:
+ raw_data = await self._store.get_user_ip_and_agents(
+ UserID.from_string(user_id), since_ts
+ )
+ # Sanitize some of the data. We don't want to return tokens.
+ return [
+ UserIpAndAgent(
+ ip=str(data["ip"]),
+ user_agent=str(data["user_agent"]),
+ last_seen=int(data["last_seen"]),
+ )
+ for data in raw_data
+ ]
+ else:
+ return []
+
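
# Hypothetical module code exercising the helper above; `api` is the
# ModuleApi instance handed to the module.
async def log_all_ips(api, user_id: str) -> None:
    # since_ts=0 means "since the epoch", i.e. everything recorded.
    for entry in await api.get_user_ip_and_agents(user_id, since_ts=0):
        print(entry.ip, entry.user_agent, entry.last_seen)
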
class PublicRoomListManager:
"""Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/notifier.py b/synapse/notifier.py
index bbe33794..1a9f84ba 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -584,7 +584,7 @@ class Notifier:
events: List[EventBase] = []
end_token = from_token
- for name, source in self.event_sources.sources.items():
+ for name, source in self.event_sources.sources.get_sources():
keyname = "%s_key" % name
before_id = getattr(before_token, keyname)
after_id = getattr(after_token, keyname)
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index e08e125c..cf5abdfb 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -184,7 +184,7 @@ class EmailPusher(Pusher):
should_notify_at = max(notif_ready_at, room_ready_at)
- if should_notify_at < self.clock.time_msec():
+ if should_notify_at <= self.clock.time_msec():
# one of our notifications is ready for sending, so we send
# *one* email updating the user on their notifications,
# we then consider all previously outstanding notifications
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 36aabd84..eac65572 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -73,7 +73,9 @@ class HttpPusher(Pusher):
self.failing_since = pusher_config.failing_since
self.timed_call: Optional[IDelayedCall] = None
self._is_processing = False
- self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
+ self._group_unread_count_by_room = (
+ hs.config.push.push_group_unread_count_by_room
+ )
self._pusherpool = hs.get_pusherpool()
self.data = pusher_config.data
@@ -365,7 +367,7 @@ class HttpPusher(Pusher):
if event.type == "m.room.member" and event.is_state():
d["notification"]["membership"] = event.content["membership"]
d["notification"]["user_is_target"] = event.state_key == self.user_id
- if self.hs.config.push_include_content and event.content:
+ if self.hs.config.push.push_include_content and event.content:
d["notification"]["content"] = event.content
# We no longer send aliases separately, instead, we send the human
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index b89c6e6f..e38e3c5d 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -110,7 +110,7 @@ class Mailer:
self.state_handler = self.hs.get_state_handler()
self.storage = hs.get_storage()
self.app_name = app_name
- self.email_subjects: EmailSubjectConfig = hs.config.email_subjects
+ self.email_subjects: EmailSubjectConfig = hs.config.email.email_subjects
logger.info("Created Mailer for app_name %s" % app_name)
@@ -796,8 +796,8 @@ class Mailer:
Returns:
A link to open a room in the web client.
"""
- if self.hs.config.email_riot_base_url:
- base_url = "%s/#/room" % (self.hs.config.email_riot_base_url)
+ if self.hs.config.email.email_riot_base_url:
+ base_url = "%s/#/room" % (self.hs.config.email.email_riot_base_url)
elif self.app_name == "Vector":
# need /beta for Universal Links to work on iOS
base_url = "https://vector.im/beta/#/room"
@@ -815,9 +815,9 @@ class Mailer:
Returns:
A link to open the notification in the web client.
"""
- if self.hs.config.email_riot_base_url:
+ if self.hs.config.email.email_riot_base_url:
return "%s/#/room/%s/%s" % (
- self.hs.config.email_riot_base_url,
+ self.hs.config.email.email_riot_base_url,
notif["room_id"],
notif["event_id"],
)
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index 02127543..b57e0940 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -35,12 +35,12 @@ class PusherFactory:
"http": HttpPusher
}
- logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
- if hs.config.email_enable_notifs:
+ logger.info("email enable notifs: %r", hs.config.email.email_enable_notifs)
+ if hs.config.email.email_enable_notifs:
self.mailers: Dict[str, Mailer] = {}
- self._notif_template_html = hs.config.email_notif_template_html
- self._notif_template_text = hs.config.email_notif_template_text
+ self._notif_template_html = hs.config.email.email_notif_template_html
+ self._notif_template_text = hs.config.email.email_notif_template_text
self.pusher_types["email"] = self._create_email_pusher
@@ -77,4 +77,4 @@ class PusherFactory:
if isinstance(brand, str):
return brand
- return self.config.email_app_name
+ return self.config.email.email_app_name
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index a1436f39..26735447 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -62,7 +62,7 @@ class PusherPool:
self.clock = self.hs.get_clock()
# We shard the handling of push notifications by user ID.
- self._pusher_shard_config = hs.config.push.pusher_shard_config
+ self._pusher_shard_config = hs.config.worker.pusher_shard_config
self._instance_name = hs.get_instance_name()
self._should_start_pushers = (
self._instance_name in self._pusher_shard_config.instances
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 25589b00..f1b78d09 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -168,8 +168,8 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
client = hs.get_simple_http_client()
local_instance_name = hs.get_instance_name()
- master_host = hs.config.worker_replication_host
- master_port = hs.config.worker_replication_http_port
+ master_host = hs.config.worker.worker_replication_host
+ master_port = hs.config.worker.worker_replication_http_port
instance_map = hs.config.worker.instance_map
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 509ed7fb..1438a82b 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -322,8 +322,8 @@ class ReplicationCommandHandler:
else:
client_name = hs.get_instance_name()
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
- host = hs.config.worker_replication_host
- port = hs.config.worker_replication_port
+ host = hs.config.worker.worker_replication_host
+ port = hs.config.worker.worker_replication_port
hs.get_reactor().connectTCP(host.encode(), port, self._factory)
def get_streams(self) -> Dict[str, Stream]:
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 3adc5761..e04af705 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -12,7 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.http.server import JsonResource
+from typing import TYPE_CHECKING
+
+from synapse.http.server import HttpServer, JsonResource
from synapse.rest import admin
from synapse.rest.client import (
account,
@@ -57,6 +59,9 @@ from synapse.rest.client import (
voip,
)
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
class ClientRestResource(JsonResource):
"""Matrix Client API REST resource.
@@ -68,12 +73,12 @@ class ClientRestResource(JsonResource):
* etc
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(self, hs)
@staticmethod
- def register_servlets(client_resource, hs):
+ def register_servlets(client_resource: HttpServer, hs: "HomeServer") -> None:
versions.register_servlets(hs, client_resource)
# Deprecated in r0
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index a03774c9..e1506deb 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -267,7 +267,7 @@ def register_servlets_for_client_rest_resource(
# Load the media repo ones if we're using them. Otherwise load the servlets which
# don't need a media repo (typically readonly admin APIs).
- if hs.config.can_load_media_repo:
+ if hs.config.media.can_load_media_repo:
register_servlets_for_media_repo(hs, http_server)
else:
ListMediaInRoom(hs).register(http_server)
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index 5715190a..a6fa03c9 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -47,7 +47,7 @@ class DeviceRestServlet(RestServlet):
self.store = hs.get_datastore()
async def on_GET(
- self, request: SynapseRequest, user_id, device_id: str
+ self, request: SynapseRequest, user_id: str, device_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
index 5a1c929d..aba48f6e 100644
--- a/synapse/rest/admin/registration_tokens.py
+++ b/synapse/rest/admin/registration_tokens.py
@@ -113,7 +113,7 @@ class NewRegistrationTokenRestServlet(RestServlet):
self.store = hs.get_datastore()
self.clock = hs.get_clock()
# A string of all the characters allowed to be in a registration_token
- self.allowed_chars = string.ascii_letters + string.digits + "-_"
+ self.allowed_chars = string.ascii_letters + string.digits + "._~-"
self.allowed_chars_set = set(self.allowed_chars)
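The widened alphabet appears to line up with the registration token grammar from MSC3231 (letters, digits, and the unreserved characters `._~-`). As a rough sketch of drawing a token from this alphabet (illustrative only; the servlet's actual generation code is not shown in this hunk):

    import secrets
    import string

    allowed_chars = string.ascii_letters + string.digits + "._~-"
    # Draw a 16-character token; each character is picked uniformly at random.
    token = "".join(secrets.choice(allowed_chars) for _ in range(16))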
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index ad83d4b5..a4823ca6 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -125,7 +125,7 @@ class ListRoomRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM,
)
- search_term = parse_string(request, "search_term")
+ search_term = parse_string(request, "search_term", encoding="utf-8")
if search_term == "":
raise SynapseError(
400,
@@ -213,7 +213,7 @@ class RoomRestServlet(RestServlet):
members = await self.store.get_users_in_room(room_id)
ret["joined_local_devices"] = await self.store.count_devices_by_users(members)
- return (200, ret)
+ return 200, ret
async def on_DELETE(
self, request: SynapseRequest, room_id: str
@@ -668,4 +668,4 @@ async def _delete_room(
if purge:
await pagination_handler.purge_room(room_id, force=force_purge)
- return (200, ret)
+ return 200, ret
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index f5a38c26..19f84f33 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -57,7 +57,7 @@ class SendServerNoticeServlet(RestServlet):
self.admin_handler = hs.get_admin_handler()
self.txns = HttpTransactionCache(hs)
- def register(self, json_resource: HttpServer):
+ def register(self, json_resource: HttpServer) -> None:
PATTERN = "/send_server_notice"
json_resource.register_paths(
"POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index c1a1ba64..46bfec46 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -368,8 +368,8 @@ class UserRestServletV2(RestServlet):
user_id, medium, address, current_time
)
if (
- self.hs.config.email_enable_notifs
- and self.hs.config.email_notif_for_new_users
+ self.hs.config.email.email_enable_notifs
+ and self.hs.config.email.email_notif_for_new_users
):
await self.pusher_pool.add_pusher(
user_id=user_id,
@@ -419,7 +419,7 @@ class UserRegisterServlet(RestServlet):
self.nonces: Dict[str, int] = {}
self.hs = hs
- def _clear_old_nonces(self):
+ def _clear_old_nonces(self) -> None:
"""
Clear out old nonces that are older than NONCE_TIMEOUT.
"""
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index aefaaa8a..6a7608d6 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -64,17 +64,17 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
self.config = hs.config
self.identity_handler = hs.get_identity_handler()
- if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
self.mailer = Mailer(
hs=self.hs,
- app_name=self.config.email_app_name,
- template_html=self.config.email_password_reset_template_html,
- template_text=self.config.email_password_reset_template_text,
+ app_name=self.config.email.email_app_name,
+ template_html=self.config.email.email_password_reset_template_html,
+ template_text=self.config.email.email_password_reset_template_text,
)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
- if self.config.local_threepid_handling_disabled_due_to_email_config:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if self.config.email.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
"User password resets have been disabled due to lack of email config"
)
@@ -129,7 +129,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
- if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
assert self.hs.config.account_threepid_delegate_email
# Have the configured identity server handle the request
@@ -349,17 +349,17 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
self.identity_handler = hs.get_identity_handler()
self.store = self.hs.get_datastore()
- if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
self.mailer = Mailer(
hs=self.hs,
- app_name=self.config.email_app_name,
- template_html=self.config.email_add_threepid_template_html,
- template_text=self.config.email_add_threepid_template_text,
+ app_name=self.config.email.email_app_name,
+ template_html=self.config.email.email_add_threepid_template_html,
+ template_text=self.config.email.email_add_threepid_template_text,
)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
- if self.config.local_threepid_handling_disabled_due_to_email_config:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if self.config.email.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
"Adding emails have been disabled due to lack of an email config"
)
@@ -413,7 +413,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
- if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
assert self.hs.config.account_threepid_delegate_email
# Have the configured identity server handle the request
@@ -534,21 +534,21 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
self.config = hs.config
self.clock = hs.get_clock()
self.store = hs.get_datastore()
- if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
self._failure_email_template = (
- self.config.email_add_threepid_template_failure_html
+ self.config.email.email_add_threepid_template_failure_html
)
async def on_GET(self, request: Request) -> None:
- if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
- if self.config.local_threepid_handling_disabled_due_to_email_config:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if self.config.email.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
"Adding emails have been disabled due to lack of an email config"
)
raise SynapseError(
400, "Adding an email to your account is disabled on this server"
)
- elif self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ elif self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
raise SynapseError(
400,
"This homeserver is not validating threepids. Use an identity server "
@@ -575,7 +575,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
return None
# Otherwise show the success template
- html = self.config.email_add_threepid_template_success_html_content
+ html = self.config.email.email_add_threepid_template_success_html_content
status_code = 200
except ThreepidValidationError as e:
status_code = e.code
diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index 7bb78014..282861fa 100644
--- a/synapse/rest/client/auth.py
+++ b/synapse/rest/client/auth.py
@@ -47,7 +47,7 @@ class AuthRestServlet(RestServlet):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
- self.recaptcha_template = hs.config.recaptcha_template
+ self.recaptcha_template = hs.config.captcha.recaptcha_template
self.terms_template = hs.config.terms_template
self.registration_token_template = hs.config.registration_token_template
self.success_template = hs.config.fallback_success_template
@@ -62,7 +62,7 @@ class AuthRestServlet(RestServlet):
session=session,
myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
- sitekey=self.hs.config.recaptcha_public_key,
+ sitekey=self.hs.config.captcha.recaptcha_public_key,
)
elif stagetype == LoginType.TERMS:
html = self.terms_template.render(
@@ -70,7 +70,7 @@ class AuthRestServlet(RestServlet):
terms_url="%s_matrix/consent?v=%s"
% (
self.hs.config.server.public_baseurl,
- self.hs.config.user_consent_version,
+ self.hs.config.consent.user_consent_version,
),
myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
@@ -118,7 +118,7 @@ class AuthRestServlet(RestServlet):
session=session,
myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
- sitekey=self.hs.config.recaptcha_public_key,
+ sitekey=self.hs.config.captcha.recaptcha_public_key,
error=e.msg,
)
else:
@@ -139,7 +139,7 @@ class AuthRestServlet(RestServlet):
terms_url="%s_matrix/consent?v=%s"
% (
self.hs.config.server.public_baseurl,
- self.hs.config.user_consent_version,
+ self.hs.config.consent.user_consent_version,
),
myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 25bc3c8f..8566dc5c 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -211,7 +211,7 @@ class DehydratedDeviceServlet(RestServlet):
if dehydrated_device is not None:
(device_id, device_data) = dehydrated_device
result = {"device_id": device_id, "device_data": device_data}
- return (200, result)
+ return 200, result
else:
raise errors.NotFoundError("No dehydrated device available")
@@ -293,7 +293,7 @@ class ClaimDehydratedDeviceServlet(RestServlet):
submission["device_id"],
)
- return (200, result)
+ return 200, result
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index a6ede7e2..fa5c173f 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -69,16 +69,16 @@ class LoginRestServlet(RestServlet):
self.hs = hs
# JWT configuration variables.
- self.jwt_enabled = hs.config.jwt_enabled
- self.jwt_secret = hs.config.jwt_secret
- self.jwt_algorithm = hs.config.jwt_algorithm
- self.jwt_issuer = hs.config.jwt_issuer
- self.jwt_audiences = hs.config.jwt_audiences
+ self.jwt_enabled = hs.config.jwt.jwt_enabled
+ self.jwt_secret = hs.config.jwt.jwt_secret
+ self.jwt_algorithm = hs.config.jwt.jwt_algorithm
+ self.jwt_issuer = hs.config.jwt.jwt_issuer
+ self.jwt_audiences = hs.config.jwt.jwt_audiences
# SSO configuration.
- self.saml2_enabled = hs.config.saml2_enabled
- self.cas_enabled = hs.config.cas_enabled
- self.oidc_enabled = hs.config.oidc_enabled
+ self.saml2_enabled = hs.config.saml2.saml2_enabled
+ self.cas_enabled = hs.config.cas.cas_enabled
+ self.oidc_enabled = hs.config.oidc.oidc_enabled
self._msc2918_enabled = hs.config.access_token_lifetime is not None
self.auth = hs.get_auth()
@@ -559,7 +559,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
if hs.config.access_token_lifetime is not None:
RefreshTokenServlet(hs).register(http_server)
SsoRedirectServlet(hs).register(http_server)
- if hs.config.cas_enabled:
+ if hs.config.cas.cas_enabled:
CasTicketServlet(hs).register(http_server)
diff --git a/synapse/rest/client/password_policy.py b/synapse/rest/client/password_policy.py
index 6d64efb1..9f190800 100644
--- a/synapse/rest/client/password_policy.py
+++ b/synapse/rest/client/password_policy.py
@@ -35,12 +35,12 @@ class PasswordPolicyServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
- self.policy = hs.config.password_policy
- self.enabled = hs.config.password_policy_enabled
+ self.policy = hs.config.auth.password_policy
+ self.enabled = hs.config.auth.password_policy_enabled
def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
if not self.enabled or not self.policy:
- return (200, {})
+ return 200, {}
policy = {}
@@ -54,7 +54,7 @@ class PasswordPolicyServlet(RestServlet):
if param in self.policy:
policy["m.%s" % param] = self.policy[param]
- return (200, policy)
+ return 200, policy
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index abe4d7e2..48b0062c 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -75,17 +75,19 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
self.identity_handler = hs.get_identity_handler()
self.config = hs.config
- if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
self.mailer = Mailer(
hs=self.hs,
- app_name=self.config.email_app_name,
- template_html=self.config.email_registration_template_html,
- template_text=self.config.email_registration_template_text,
+ app_name=self.config.email.email_app_name,
+ template_html=self.config.email.email_registration_template_html,
+ template_text=self.config.email.email_registration_template_text,
)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
- if self.hs.config.local_threepid_handling_disabled_due_to_email_config:
+ if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if (
+ self.hs.config.email.local_threepid_handling_disabled_due_to_email_config
+ ):
logger.warning(
"Email registration has been disabled due to lack of email config"
)
@@ -137,7 +139,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
- if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
assert self.hs.config.account_threepid_delegate_email
# Have the configured identity server handle the request
@@ -259,9 +261,9 @@ class RegistrationSubmitTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
- if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
self._failure_email_template = (
- self.config.email_registration_template_failure_html
+ self.config.email.email_registration_template_failure_html
)
async def on_GET(self, request: Request, medium: str) -> None:
@@ -269,8 +271,8 @@ class RegistrationSubmitTokenServlet(RestServlet):
raise SynapseError(
400, "This medium is currently not supported for registration"
)
- if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
- if self.config.local_threepid_handling_disabled_due_to_email_config:
+ if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if self.config.email.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
"User registration via email has been disabled due to lack of email config"
)
@@ -303,7 +305,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
return None
# Otherwise show the success template
- html = self.config.email_registration_template_success_html_content
+ html = self.config.email.email_registration_template_success_html_content
status_code = 200
except ThreepidValidationError as e:
status_code = e.code
@@ -897,12 +899,12 @@ def _calculate_registration_flows(
flows.append([LoginType.MSISDN, LoginType.EMAIL_IDENTITY])
# Prepend m.login.terms to all flows if we're requiring consent
- if config.user_consent_at_registration:
+ if config.consent.user_consent_at_registration:
for flow in flows:
flow.insert(0, LoginType.TERMS)
# Prepend recaptcha to all flows if we're requiring captcha
- if config.enable_registration_captcha:
+ if config.captcha.enable_registration_captcha:
for flow in flows:
flow.insert(0, LoginType.RECAPTCHA)
diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py
index ed969784..bf14ec38 100644
--- a/synapse/rest/client/room_batch.py
+++ b/synapse/rest/client/room_batch.py
@@ -14,6 +14,7 @@
import logging
import re
+from http import HTTPStatus
from typing import TYPE_CHECKING, Awaitable, List, Tuple
from twisted.web.server import Request
@@ -42,25 +43,25 @@ logger = logging.getLogger(__name__)
class RoomBatchSendEventRestServlet(RestServlet):
"""
- API endpoint which can insert a chunk of events historically back in time
+ API endpoint which can insert a batch of events historically back in time
- next to the given `prev_event`.
+ next to the given `prev_event_id`.
- `chunk_id` comes from `next_chunk_id `in the response of the batch send
- endpoint and is derived from the "insertion" events added to each chunk.
+ `batch_id` comes from `next_batch_id` in the response of the batch send
+ endpoint and is derived from the "insertion" events added to each batch.
It's not required for the first batch send.
`state_events_at_start` is used to define the historical state events
needed to auth the events like join events. These events will float
outside of the normal DAG as outliers and won't be visible in the chat
- history which also allows us to insert multiple chunks without having a bunch
- of `@mxid joined the room` noise between each chunk.
+ history which also allows us to insert multiple batches without having a bunch
+ of `@mxid joined the room` noise between each batch.
- `events` is chronological chunk/list of events you want to insert.
- There is a reverse-chronological constraint on chunks so once you insert
+ `events` is a chronological list of events you want to insert.
+ There is a reverse-chronological constraint on batches so once you insert
some messages, you can only insert older ones after that.
- tldr; Insert chunks from your most recent history -> oldest history.
+ tl;dr: Insert batches from your most recent history -> oldest history.
- POST /_matrix/client/unstable/org.matrix.msc2716/rooms/<roomID>/batch_send?prev_event=<eventID>&chunk_id=<chunkID>
+ POST /_matrix/client/unstable/org.matrix.msc2716/rooms/<roomID>/batch_send?prev_event_id=<eventID>&batch_id=<batchID>
{
"events": [ ... ],
"state_events_at_start": [ ... ]
@@ -128,7 +129,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
self, sender: str, room_id: str, origin_server_ts: int
) -> JsonDict:
"""Creates an event dict for an "insertion" event with the proper fields
- and a random chunk ID.
+ and a random batch ID.
Args:
sender: The event author MXID
@@ -139,13 +140,13 @@ class RoomBatchSendEventRestServlet(RestServlet):
The new event dictionary to insert.
"""
- next_chunk_id = random_string(8)
+ next_batch_id = random_string(8)
insertion_event = {
"type": EventTypes.MSC2716_INSERTION,
"sender": sender,
"room_id": room_id,
"content": {
- EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id,
+ EventContentFields.MSC2716_NEXT_BATCH_ID: next_batch_id,
EventContentFields.MSC2716_HISTORICAL: True,
},
"origin_server_ts": origin_server_ts,
@@ -179,7 +180,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
if not requester.app_service:
raise AuthError(
- 403,
+ HTTPStatus.FORBIDDEN,
"Only application services can use the /batchsend endpoint",
)
@@ -187,24 +188,26 @@ class RoomBatchSendEventRestServlet(RestServlet):
assert_params_in_dict(body, ["state_events_at_start", "events"])
assert request.args is not None
- prev_events_from_query = parse_strings_from_args(request.args, "prev_event")
- chunk_id_from_query = parse_string(request, "chunk_id")
+ prev_event_ids_from_query = parse_strings_from_args(
+ request.args, "prev_event_id"
+ )
+ batch_id_from_query = parse_string(request, "batch_id")
- if prev_events_from_query is None:
+ if prev_event_ids_from_query is None:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"prev_event query parameter is required when inserting historical messages back in time",
errcode=Codes.MISSING_PARAM,
)
- # For the event we are inserting next to (`prev_events_from_query`),
+ # For the event we are inserting next to (`prev_event_ids_from_query`),
# find the most recent auth events (derived from state events) that
# allowed that message to be sent. We will use that as a base
# to auth our historical messages against.
(
most_recent_prev_event_id,
_,
- ) = await self.store.get_max_depth_of(prev_events_from_query)
+ ) = await self.store.get_max_depth_of(prev_event_ids_from_query)
# mapping from (type, state_key) -> state_event_id
prev_state_map = await self.state_store.get_state_ids_for_event(
most_recent_prev_event_id
@@ -213,7 +216,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
prev_state_ids = list(prev_state_map.values())
auth_event_ids = prev_state_ids
- state_events_at_start = []
+ state_event_ids_at_start = []
for state_event in body["state_events_at_start"]:
assert_params_in_dict(
state_event, ["type", "origin_server_ts", "content", "sender"]
@@ -279,27 +282,38 @@ class RoomBatchSendEventRestServlet(RestServlet):
)
event_id = event.event_id
- state_events_at_start.append(event_id)
+ state_event_ids_at_start.append(event_id)
auth_event_ids.append(event_id)
events_to_create = body["events"]
inherited_depth = await self._inherit_depth_from_prev_ids(
- prev_events_from_query
+ prev_event_ids_from_query
)
- # Figure out which chunk to connect to. If they passed in
- # chunk_id_from_query let's use it. The chunk ID passed in comes
- # from the chunk_id in the "insertion" event from the previous chunk.
- last_event_in_chunk = events_to_create[-1]
- chunk_id_to_connect_to = chunk_id_from_query
+ # Figure out which batch to connect to. If they passed in
+ # batch_id_from_query let's use it. The batch ID passed in comes
+ # from the batch_id in the "insertion" event from the previous batch.
+ last_event_in_batch = events_to_create[-1]
+ batch_id_to_connect_to = batch_id_from_query
base_insertion_event = None
- if chunk_id_from_query:
+ if batch_id_from_query:
# All but the first base insertion event should point at a fake
# event, which causes the HS to ask for the state at the start of
- # the chunk later.
+ # the batch later.
prev_event_ids = [fake_prev_event_id]
- # TODO: Verify the chunk_id_from_query corresponds to an insertion event
+
+ # Verify the batch_id_from_query corresponds to an actual insertion event
+ # and have the batch connected.
+ corresponding_insertion_event_id = (
+ await self.store.get_insertion_event_by_batch_id(batch_id_from_query)
+ )
+ if corresponding_insertion_event_id is None:
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "No insertion event corresponds to the given ?batch_id",
+ errcode=Codes.INVALID_PARAM,
+ )
- pass
# Otherwise, create an insertion event to act as a starting point.
#
@@ -309,12 +323,12 @@ class RoomBatchSendEventRestServlet(RestServlet):
# an insertion event), in which case we just create a new insertion event
# that can then get pointed to by a "marker" event later.
else:
- prev_event_ids = prev_events_from_query
+ prev_event_ids = prev_event_ids_from_query
base_insertion_event_dict = self._create_insertion_event_dict(
sender=requester.user.to_string(),
room_id=room_id,
- origin_server_ts=last_event_in_chunk["origin_server_ts"],
+ origin_server_ts=last_event_in_batch["origin_server_ts"],
)
base_insertion_event_dict["prev_events"] = prev_event_ids.copy()
@@ -333,38 +347,38 @@ class RoomBatchSendEventRestServlet(RestServlet):
depth=inherited_depth,
)
- chunk_id_to_connect_to = base_insertion_event["content"][
- EventContentFields.MSC2716_NEXT_CHUNK_ID
+ batch_id_to_connect_to = base_insertion_event["content"][
+ EventContentFields.MSC2716_NEXT_BATCH_ID
]
- # Connect this current chunk to the insertion event from the previous chunk
- chunk_event = {
- "type": EventTypes.MSC2716_CHUNK,
+ # Connect this current batch to the insertion event from the previous batch
+ batch_event = {
+ "type": EventTypes.MSC2716_BATCH,
"sender": requester.user.to_string(),
"room_id": room_id,
"content": {
- EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to,
+ EventContentFields.MSC2716_BATCH_ID: batch_id_to_connect_to,
EventContentFields.MSC2716_HISTORICAL: True,
},
- # Since the chunk event is put at the end of the chunk,
+ # Since the batch event is put at the end of the batch,
# where the newest-in-time event is, copy the origin_server_ts from
# the last event we're inserting
- "origin_server_ts": last_event_in_chunk["origin_server_ts"],
+ "origin_server_ts": last_event_in_batch["origin_server_ts"],
}
- # Add the chunk event to the end of the chunk (newest-in-time)
- events_to_create.append(chunk_event)
+ # Add the batch event to the end of the batch (newest-in-time)
+ events_to_create.append(batch_event)
- # Add an "insertion" event to the start of each chunk (next to the oldest-in-time
- # event in the chunk) so the next chunk can be connected to this one.
+ # Add an "insertion" event to the start of each batch (next to the oldest-in-time
+ # event in the batch) so the next batch can be connected to this one.
insertion_event = self._create_insertion_event_dict(
sender=requester.user.to_string(),
room_id=room_id,
- # Since the insertion event is put at the start of the chunk,
+ # Since the insertion event is put at the start of the batch,
# where the oldest-in-time event is, copy the origin_server_ts from
# the first event we're inserting
origin_server_ts=events_to_create[0]["origin_server_ts"],
)
- # Prepend the insertion event to the start of the chunk (oldest-in-time)
+ # Prepend the insertion event to the start of the batch (oldest-in-time)
events_to_create = [insertion_event] + events_to_create
event_ids = []
@@ -424,20 +438,26 @@ class RoomBatchSendEventRestServlet(RestServlet):
context=context,
)
- # Add the base_insertion_event to the bottom of the list we return
- if base_insertion_event is not None:
- event_ids.append(base_insertion_event.event_id)
+ insertion_event_id = event_ids[0]
+ batch_event_id = event_ids[-1]
+ historical_event_ids = event_ids[1:-1]
- return 200, {
- "state_events": state_events_at_start,
- "events": event_ids,
- "next_chunk_id": insertion_event["content"][
- EventContentFields.MSC2716_NEXT_CHUNK_ID
+ response_dict = {
+ "state_event_ids": state_event_ids_at_start,
+ "event_ids": historical_event_ids,
+ "next_batch_id": insertion_event["content"][
+ EventContentFields.MSC2716_NEXT_BATCH_ID
],
+ "insertion_event_id": insertion_event_id,
+ "batch_event_id": batch_event_id,
}
+ if base_insertion_event is not None:
+ response_dict["base_insertion_event_id"] = base_insertion_event.event_id
+
+ return HTTPStatus.OK, response_dict
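A minimal sketch (with invented IDs) of the response shape the endpoint now returns, matching the keys assembled above:

    response = {
        "state_event_ids": ["$state1"],      # IDs of the state_events_at_start events
        "event_ids": ["$hist1", "$hist2"],   # the historical events only
        "next_batch_id": "knxPZskM",         # pass as ?batch_id on the next call
        "insertion_event_id": "$insertion",  # first event of the created batch
        "batch_event_id": "$batch",          # last event of the created batch
    }
    # Only present when the request had no ?batch_id, i.e. a base insertion
    # event was created:
    response["base_insertion_event_id"] = "$baseInsertion"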
def on_GET(self, request: Request, room_id: str) -> Tuple[int, str]:
- return 501, "Not implemented"
+ return HTTPStatus.NOT_IMPLEMENTED, "Not implemented"
def on_PUT(
self, request: SynapseRequest, room_id: str
diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py
index 88528111..a47d9bd0 100644
--- a/synapse/rest/client/user_directory.py
+++ b/synapse/rest/client/user_directory.py
@@ -58,7 +58,7 @@ class UserDirectorySearchRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
- if not self.hs.config.user_directory_search_enabled:
+ if not self.hs.config.userdirectory.user_directory_search_enabled:
return 200, {"limited": False, "results": []}
body = parse_json_object_from_request(request)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index a1a815cf..b52a296d 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -42,15 +42,15 @@ class VersionsRestServlet(RestServlet):
# Calculate these once since they shouldn't change after start-up.
self.e2ee_forced_public = (
RoomCreationPreset.PUBLIC_CHAT
- in self.config.encryption_enabled_by_default_for_room_presets
+ in self.config.room.encryption_enabled_by_default_for_room_presets
)
self.e2ee_forced_private = (
RoomCreationPreset.PRIVATE_CHAT
- in self.config.encryption_enabled_by_default_for_room_presets
+ in self.config.room.encryption_enabled_by_default_for_room_presets
)
self.e2ee_forced_trusted_private = (
RoomCreationPreset.TRUSTED_PRIVATE_CHAT
- in self.config.encryption_enabled_by_default_for_room_presets
+ in self.config.room.encryption_enabled_by_default_for_room_presets
)
def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
diff --git a/synapse/rest/client/voip.py b/synapse/rest/client/voip.py
index 9d46ed3a..ea2b8aa4 100644
--- a/synapse/rest/client/voip.py
+++ b/synapse/rest/client/voip.py
@@ -37,14 +37,14 @@ class VoipRestServlet(RestServlet):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(
- request, self.hs.config.turn_allow_guests
+ request, self.hs.config.voip.turn_allow_guests
)
- turnUris = self.hs.config.turn_uris
- turnSecret = self.hs.config.turn_shared_secret
- turnUsername = self.hs.config.turn_username
- turnPassword = self.hs.config.turn_password
- userLifetime = self.hs.config.turn_user_lifetime
+ turnUris = self.hs.config.voip.turn_uris
+ turnSecret = self.hs.config.voip.turn_shared_secret
+ turnUsername = self.hs.config.voip.turn_username
+ turnPassword = self.hs.config.voip.turn_password
+ userLifetime = self.hs.config.voip.turn_user_lifetime
if turnUris and turnSecret and userLifetime:
expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 11f73208..3d2afacc 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -17,17 +17,22 @@ import logging
from hashlib import sha256
from http import HTTPStatus
from os import path
-from typing import Dict, List
+from typing import TYPE_CHECKING, Any, Dict, List
import jinja2
from jinja2 import TemplateNotFound
+from twisted.web.server import Request
+
from synapse.api.errors import NotFoundError, StoreError, SynapseError
from synapse.config import ConfigError
from synapse.http.server import DirectServeHtmlResource, respond_with_html
from synapse.http.servlet import parse_bytes_from_args, parse_string
from synapse.types import UserID
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
# language to use for the templates. TODO: figure this out from Accept-Language
TEMPLATE_LANGUAGE = "en"
@@ -69,11 +74,7 @@ class ConsentResource(DirectServeHtmlResource):
against the user.
"""
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): homeserver
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@@ -83,14 +84,15 @@ class ConsentResource(DirectServeHtmlResource):
# this is required by the request_handler wrapper
self.clock = hs.get_clock()
- self._default_consent_version = hs.config.user_consent_version
- if self._default_consent_version is None:
+ # Consent must be configured to create this resource.
+ default_consent_version = hs.config.consent.user_consent_version
+ consent_template_directory = hs.config.consent.user_consent_template_dir
+ if default_consent_version is None or consent_template_directory is None:
raise ConfigError(
"Consent resource is enabled but user_consent section is "
"missing in config file."
)
-
- consent_template_directory = hs.config.user_consent_template_dir
+ self._default_consent_version = default_consent_version
# TODO: switch to synapse.util.templates.build_jinja_env
loader = jinja2.FileSystemLoader(consent_template_directory)
@@ -98,26 +100,22 @@ class ConsentResource(DirectServeHtmlResource):
loader=loader, autoescape=jinja2.select_autoescape(["html", "htm", "xml"])
)
- if hs.config.form_secret is None:
+ if hs.config.key.form_secret is None:
raise ConfigError(
"Consent resource is enabled but form_secret is not set in "
"config file. It should be set to an arbitrary secret string."
)
- self._hmac_secret = hs.config.form_secret.encode("utf-8")
+ self._hmac_secret = hs.config.key.form_secret.encode("utf-8")
- async def _async_render_GET(self, request):
- """
- Args:
- request (twisted.web.http.Request):
- """
+ async def _async_render_GET(self, request: Request) -> None:
version = parse_string(request, "v", default=self._default_consent_version)
username = parse_string(request, "u", default="")
userhmac = None
has_consented = False
public_version = username == ""
if not public_version:
- args: Dict[bytes, List[bytes]] = request.args
+ args: Dict[bytes, List[bytes]] = request.args # type: ignore
userhmac_bytes = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac_bytes)
@@ -147,14 +145,10 @@ class ConsentResource(DirectServeHtmlResource):
except TemplateNotFound:
raise NotFoundError("Unknown policy version")
- async def _async_render_POST(self, request):
- """
- Args:
- request (twisted.web.http.Request):
- """
+ async def _async_render_POST(self, request: Request) -> None:
version = parse_string(request, "v", required=True)
username = parse_string(request, "u", required=True)
- args: Dict[bytes, List[bytes]] = request.args
+ args: Dict[bytes, List[bytes]] = request.args # type: ignore
userhmac = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac)
@@ -177,7 +171,9 @@ class ConsentResource(DirectServeHtmlResource):
except TemplateNotFound:
raise NotFoundError("success.html not found")
- def _render_template(self, request, template_name, **template_args):
+ def _render_template(
+ self, request: Request, template_name: str, **template_args: Any
+ ) -> None:
# get_template checks for ".." so we don't need to worry too much
# about path traversal here.
template_html = self._jinja_env.get_template(
@@ -186,11 +182,11 @@ class ConsentResource(DirectServeHtmlResource):
html = template_html.render(**template_args)
respond_with_html(request, 200, html)
- def _check_hash(self, userid, userhmac):
+ def _check_hash(self, userid: str, userhmac: bytes) -> None:
"""
Args:
- userid (unicode):
- userhmac (bytes):
+ userid: the Matrix user ID supplied in the request
+ userhmac: the hex-encoded HMAC of the user ID supplied by the client
Raises:
SynapseError if the hash doesn't match
diff --git a/synapse/rest/health.py b/synapse/rest/health.py
index 4487b54a..78df7af2 100644
--- a/synapse/rest/health.py
+++ b/synapse/rest/health.py
@@ -13,6 +13,7 @@
# limitations under the License.
from twisted.web.resource import Resource
+from twisted.web.server import Request
class HealthResource(Resource):
@@ -25,6 +26,6 @@ class HealthResource(Resource):
isLeaf = 1
- def render_GET(self, request):
+ def render_GET(self, request: Request) -> bytes:
request.setHeader(b"Content-Type", b"text/plain")
return b"OK"
diff --git a/synapse/rest/key/v2/__init__.py b/synapse/rest/key/v2/__init__.py
index c6c63073..7f8c1de1 100644
--- a/synapse/rest/key/v2/__init__.py
+++ b/synapse/rest/key/v2/__init__.py
@@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
+
from twisted.web.resource import Resource
from .local_key_resource import LocalKey
from .remote_key_resource import RemoteKey
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
class KeyApiV2Resource(Resource):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self.putChild(b"server", LocalKey(hs))
self.putChild(b"query", RemoteKey(hs))
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index 25f6eb84..12b3ae12 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -12,16 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
+from typing import TYPE_CHECKING
from canonicaljson import encode_canonical_json
from signedjson.sign import sign_json
from unpaddedbase64 import encode_base64
from twisted.web.resource import Resource
+from twisted.web.server import Request
from synapse.http.server import respond_with_json_bytes
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -58,26 +63,26 @@ class LocalKey(Resource):
isLeaf = True
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.config = hs.config
self.clock = hs.get_clock()
self.update_response_body(self.clock.time_msec())
Resource.__init__(self)
- def update_response_body(self, time_now_msec):
- refresh_interval = self.config.key_refresh_interval
+ def update_response_body(self, time_now_msec: int) -> None:
+ refresh_interval = self.config.key.key_refresh_interval
self.valid_until_ts = int(time_now_msec + refresh_interval)
self.response_body = encode_canonical_json(self.response_json_object())
- def response_json_object(self):
+ def response_json_object(self) -> JsonDict:
verify_keys = {}
- for key in self.config.signing_key:
+ for key in self.config.key.signing_key:
verify_key_bytes = key.verify_key.encode()
key_id = "%s:%s" % (key.alg, key.version)
verify_keys[key_id] = {"key": encode_base64(verify_key_bytes)}
old_verify_keys = {}
- for key_id, key in self.config.old_signing_keys.items():
+ for key_id, key in self.config.key.old_signing_keys.items():
verify_key_bytes = key.encode()
old_verify_keys[key_id] = {
"key": encode_base64(verify_key_bytes),
@@ -90,13 +95,13 @@ class LocalKey(Resource):
"verify_keys": verify_keys,
"old_verify_keys": old_verify_keys,
}
- for key in self.config.signing_key:
+ for key in self.config.key.signing_key:
json_object = sign_json(json_object, self.config.server.server_name, key)
return json_object
- def render_GET(self, request):
+ def render_GET(self, request: Request) -> int:
time_now = self.clock.time_msec()
# Update the expiry time if less than half the interval remains.
- if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts:
+ if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts:
self.update_response_body(time_now)
return respond_with_json_bytes(request, 200, self.response_body)
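To make the refresh rule concrete: with an assumed key_refresh_interval of one day, the cached body is regenerated once fewer than twelve hours remain before valid_until_ts. A sketch with invented timestamps (all values in milliseconds):

    refresh_interval = 24 * 60 * 60 * 1000  # assumed key_refresh_interval: one day
    valid_until_ts = 1_600_000_000_000      # set when the body was last regenerated
    time_now = 1_599_960_000_000            # ~11.1 hours before valid_until_ts

    if time_now + refresh_interval / 2 > valid_until_ts:
        # Less than half the interval remains, so regenerate and extend validity.
        valid_until_ts = int(time_now + refresh_interval)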
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 744360e5..3923ba84 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import Dict
+from typing import TYPE_CHECKING, Dict
from signedjson.sign import sign_json
@@ -21,9 +21,14 @@ from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_integer, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -85,16 +90,19 @@ class RemoteKey(DirectServeJsonResource):
isLeaf = True
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.fetcher = ServerKeyFetcher(hs)
self.store = hs.get_datastore()
self.clock = hs.get_clock()
- self.federation_domain_whitelist = hs.config.federation_domain_whitelist
+ self.federation_domain_whitelist = (
+ hs.config.federation.federation_domain_whitelist
+ )
self.config = hs.config
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: SynapseRequest) -> None:
+ assert request.postpath is not None
if len(request.postpath) == 1:
(server,) = request.postpath
query: dict = {server.decode("ascii"): {}}
@@ -110,14 +118,19 @@ class RemoteKey(DirectServeJsonResource):
await self.query_keys(request, query, query_remote_on_cache_miss=True)
- async def _async_render_POST(self, request):
+ async def _async_render_POST(self, request: SynapseRequest) -> None:
content = parse_json_object_from_request(request)
query = content["server_keys"]
await self.query_keys(request, query, query_remote_on_cache_miss=True)
- async def query_keys(self, request, query, query_remote_on_cache_miss=False):
+ async def query_keys(
+ self,
+ request: SynapseRequest,
+ query: JsonDict,
+ query_remote_on_cache_miss: bool = False,
+ ) -> None:
logger.info("Handling query for keys %r", query)
store_queries = []
@@ -142,8 +155,8 @@ class RemoteKey(DirectServeJsonResource):
# Note that the value is unused.
cache_misses: Dict[str, Dict[str, int]] = {}
- for (server_name, key_id, _), results in cached.items():
- results = [(result["ts_added_ms"], result) for result in results]
+ for (server_name, key_id, _), key_results in cached.items():
+ results = [(result["ts_added_ms"], result) for result in key_results]
if not results and key_id is not None:
cache_misses.setdefault(server_name, {})[key_id] = 0
@@ -223,13 +236,13 @@ class RemoteKey(DirectServeJsonResource):
signed_keys = []
for key_json in json_results:
key_json = json_decoder.decode(key_json.decode("utf-8"))
- for signing_key in self.config.key_server_signing_keys:
+ for signing_key in self.config.key.key_server_signing_keys:
key_json = sign_json(
key_json, self.config.server.server_name, signing_key
)
signed_keys.append(key_json)
- results = {"server_keys": signed_keys}
+ response = {"server_keys": signed_keys}
- respond_with_json(request, 200, results, canonical_json=True)
+ respond_with_json(request, 200, response, canonical_json=True)
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 90364ebc..014fa893 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -16,7 +16,10 @@
import logging
import os
import urllib
-from typing import Awaitable, Dict, Generator, List, Optional, Tuple
+from types import TracebackType
+from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type
+
+import attr
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
@@ -24,6 +27,7 @@ from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.http.server import finish_request, respond_with_json
+from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.util.stringutils import is_ascii
@@ -71,7 +75,7 @@ def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
)
-def respond_404(request: Request) -> None:
+def respond_404(request: SynapseRequest) -> None:
respond_with_json(
request,
404,
@@ -81,7 +85,7 @@ def respond_404(request: Request) -> None:
async def respond_with_file(
- request: Request,
+ request: SynapseRequest,
media_type: str,
file_path: str,
file_size: Optional[int] = None,
@@ -120,7 +124,7 @@ def add_file_headers(
upload_name: The name of the requested file, if any.
"""
- def _quote(x):
+ def _quote(x: str) -> str:
return urllib.parse.quote(x.encode("utf-8"))
# Default to a UTF-8 charset for text content types.
@@ -218,7 +222,7 @@ def _can_encode_filename_as_token(x: str) -> bool:
async def respond_with_responder(
- request: Request,
+ request: SynapseRequest,
responder: "Optional[Responder]",
media_type: str,
file_size: Optional[int],
@@ -280,51 +284,74 @@ class Responder:
"""
pass
- def __enter__(self):
+ def __enter__(self) -> None:
pass
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
pass
-class FileInfo:
- """Details about a requested/uploaded file.
-
- Attributes:
- server_name (str): The server name where the media originated from,
- or None if local.
- file_id (str): The local ID of the file. For local files this is the
- same as the media_id
- url_cache (bool): If the file is for the url preview cache
- thumbnail (bool): Whether the file is a thumbnail or not.
- thumbnail_width (int)
- thumbnail_height (int)
- thumbnail_method (str)
- thumbnail_type (str): Content type of thumbnail, e.g. image/png
- thumbnail_length (int): The size of the media file, in bytes.
- """
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class ThumbnailInfo:
+ """Details about a generated thumbnail."""
- def __init__(
- self,
- server_name,
- file_id,
- url_cache=False,
- thumbnail=False,
- thumbnail_width=None,
- thumbnail_height=None,
- thumbnail_method=None,
- thumbnail_type=None,
- thumbnail_length=None,
- ):
- self.server_name = server_name
- self.file_id = file_id
- self.url_cache = url_cache
- self.thumbnail = thumbnail
- self.thumbnail_width = thumbnail_width
- self.thumbnail_height = thumbnail_height
- self.thumbnail_method = thumbnail_method
- self.thumbnail_type = thumbnail_type
- self.thumbnail_length = thumbnail_length
+ width: int
+ height: int
+ method: str
+ # Content type of thumbnail, e.g. image/png
+ type: str
+ # The size of the thumbnail file, in bytes.
+ length: Optional[int] = None
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class FileInfo:
+ """Details about a requested/uploaded file."""
+
+ # The server name where the media originated from, or None if local.
+ server_name: Optional[str]
+ # The local ID of the file. For local files this is the same as the media_id
+ file_id: str
+ # If the file is for the url preview cache
+ url_cache: bool = False
+ # Details of the thumbnail, if this file is a thumbnail (otherwise None).
+ thumbnail: Optional[ThumbnailInfo] = None
+
+ # The below properties exist to maintain compatibility with third-party modules.
+ @property
+ def thumbnail_width(self) -> Optional[int]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.width
+
+ @property
+ def thumbnail_height(self) -> Optional[int]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.height
+
+ @property
+ def thumbnail_method(self) -> Optional[str]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.method
+
+ @property
+ def thumbnail_type(self) -> Optional[str]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.type
+
+ @property
+ def thumbnail_length(self) -> Optional[int]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.length
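A brief sketch of how the refactored classes fit together (values invented); third-party code reading the legacy attributes keeps working through the compatibility properties:

    thumb = ThumbnailInfo(width=96, height=96, method="crop", type="image/png", length=2345)
    info = FileInfo(server_name=None, file_id="GerZNDnDZVjsOtar", thumbnail=thumb)

    info.thumbnail_width   # -> 96, via the compatibility property
    info.thumbnail_type    # -> "image/png"
    # Non-thumbnail files simply leave `thumbnail` unset:
    FileInfo(server_name="example.com", file_id="abc").thumbnail_length  # -> None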
def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py
index a1d36e5c..a95804d3 100644
--- a/synapse/rest/media/v1/config_resource.py
+++ b/synapse/rest/media/v1/config_resource.py
@@ -16,8 +16,6 @@
from typing import TYPE_CHECKING
-from twisted.web.server import Request
-
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.site import SynapseRequest
@@ -33,11 +31,11 @@ class MediaConfigResource(DirectServeJsonResource):
config = hs.config
self.clock = hs.get_clock()
self.auth = hs.get_auth()
- self.limits_dict = {"m.upload.size": config.max_upload_size}
+ self.limits_dict = {"m.upload.size": config.media.max_upload_size}
async def _async_render_GET(self, request: SynapseRequest) -> None:
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)
- async def _async_render_OPTIONS(self, request: Request) -> None:
+ async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
respond_with_json(request, 200, {}, send_cors=True)
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index d6d93895..6180fa57 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -15,10 +15,9 @@
import logging
from typing import TYPE_CHECKING
-from twisted.web.server import Request
-
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_boolean
+from synapse.http.site import SynapseRequest
from ._base import parse_media_id, respond_404
@@ -37,7 +36,7 @@ class DownloadResource(DirectServeJsonResource):
self.media_repo = media_repo
self.server_name = hs.hostname
- async def _async_render_GET(self, request: Request) -> None:
+ async def _async_render_GET(self, request: SynapseRequest) -> None:
set_cors_headers(request)
request.setHeader(
b"Content-Security-Policy",
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index 09531ebf..08bd85f6 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -16,7 +16,7 @@
import functools
import os
import re
-from typing import Callable, List
+from typing import Any, Callable, List
NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
@@ -27,7 +27,7 @@ def _wrap_in_base_path(func: Callable[..., str]) -> Callable[..., str]:
"""
@functools.wraps(func)
- def _wrapped(self, *args, **kwargs):
+ def _wrapped(self: "MediaFilePaths", *args: Any, **kwargs: Any) -> str:
path = func(self, *args, **kwargs)
return os.path.join(self.base_path, path)
@@ -129,7 +129,7 @@ class MediaFilePaths:
# using the new path.
def remote_media_thumbnail_rel_legacy(
self, server_name: str, file_id: str, width: int, height: int, content_type: str
- ):
+ ) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
return os.path.join(
@@ -195,23 +195,24 @@ class MediaFilePaths:
url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
- def url_cache_thumbnail_directory(self, media_id: str) -> str:
+ def url_cache_thumbnail_directory_rel(self, media_id: str) -> str:
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
if NEW_FORMAT_ID_RE.match(media_id):
- return os.path.join(
- self.base_path, "url_cache_thumbnails", media_id[:10], media_id[11:]
- )
+ return os.path.join("url_cache_thumbnails", media_id[:10], media_id[11:])
else:
return os.path.join(
- self.base_path,
"url_cache_thumbnails",
media_id[0:2],
media_id[2:4],
media_id[4:],
)
+ url_cache_thumbnail_directory = _wrap_in_base_path(
+ url_cache_thumbnail_directory_rel
+ )
+
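The *_rel/wrapped split keeps one source of truth for the layout; a sketch of the resulting paths (base path and media IDs illustrative):

    paths = MediaFilePaths("/var/lib/synapse/media")

    # New-format IDs start with a date and split as <date>/<rest>:
    paths.url_cache_thumbnail_directory_rel("2017-09-28-fsdRDt24DS234dsf")
    # -> "url_cache_thumbnails/2017-09-28/fsdRDt24DS234dsf"

    # Old-format IDs fan out over two 2-character prefixes:
    paths.url_cache_thumbnail_directory_rel("GerZNDnDZVjsOtar")
    # -> "url_cache_thumbnails/Ge/rZ/NDnDZVjsOtar"

    # The wrapped variant prefixes the media store base path:
    paths.url_cache_thumbnail_directory("GerZNDnDZVjsOtar")
    # -> "/var/lib/synapse/media/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtar"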
def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id thumbnails"
# Media id is of the form <DATE><RANDOM_STRING>
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 0f5ce41f..abd88a2d 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -21,8 +21,8 @@ from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import twisted.internet.error
import twisted.web.http
+from twisted.internet.defer import Deferred
from twisted.web.resource import Resource
-from twisted.web.server import Request
from synapse.api.errors import (
FederationDeniedError,
@@ -32,6 +32,8 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.config._base import ConfigError
+from synapse.config.repository import ThumbnailRequirement
+from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID
@@ -42,6 +44,7 @@ from synapse.util.stringutils import random_string
from ._base import (
FileInfo,
Responder,
+ ThumbnailInfo,
get_filename_from_headers,
respond_404,
respond_with_responder,
@@ -73,29 +76,35 @@ class MediaRepository:
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.store = hs.get_datastore()
- self.max_upload_size = hs.config.max_upload_size
- self.max_image_pixels = hs.config.max_image_pixels
+ self.max_upload_size = hs.config.media.max_upload_size
+ self.max_image_pixels = hs.config.media.max_image_pixels
Thumbnailer.set_limits(self.max_image_pixels)
- self.primary_base_path: str = hs.config.media_store_path
+ self.primary_base_path: str = hs.config.media.media_store_path
self.filepaths: MediaFilePaths = MediaFilePaths(self.primary_base_path)
- self.dynamic_thumbnails = hs.config.dynamic_thumbnails
- self.thumbnail_requirements = hs.config.thumbnail_requirements
+ self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
+ self.thumbnail_requirements = hs.config.media.thumbnail_requirements
self.remote_media_linearizer = Linearizer(name="media_remote")
self.recently_accessed_remotes: Set[Tuple[str, str]] = set()
self.recently_accessed_locals: Set[str] = set()
- self.federation_domain_whitelist = hs.config.federation_domain_whitelist
+ self.federation_domain_whitelist = (
+ hs.config.federation.federation_domain_whitelist
+ )
# List of StorageProviders where we should search for media and
# potentially upload to.
storage_providers = []
- for clz, provider_config, wrapper_config in hs.config.media_storage_providers:
+ for (
+ clz,
+ provider_config,
+ wrapper_config,
+ ) in hs.config.media.media_storage_providers:
backend = clz(hs, provider_config)
provider = StorageProviderWrapper(
backend,
@@ -113,7 +122,7 @@ class MediaRepository:
self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS
)
- def _start_update_recently_accessed(self):
+ def _start_update_recently_accessed(self) -> Deferred:
return run_as_background_process(
"update_recently_accessed_media", self._update_recently_accessed
)
@@ -184,7 +193,7 @@ class MediaRepository:
return "mxc://%s/%s" % (self.server_name, media_id)
async def get_local_media(
- self, request: Request, media_id: str, name: Optional[str]
+ self, request: SynapseRequest, media_id: str, name: Optional[str]
) -> None:
"""Responds to requests for local media, if exists, or returns 404.
@@ -210,7 +219,7 @@ class MediaRepository:
upload_name = name if name else media_info["upload_name"]
url_cache = media_info["url_cache"]
- file_info = FileInfo(None, media_id, url_cache=url_cache)
+ file_info = FileInfo(None, media_id, url_cache=bool(url_cache))
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(
@@ -218,7 +227,11 @@ class MediaRepository:
)
async def get_remote_media(
- self, request: Request, server_name: str, media_id: str, name: Optional[str]
+ self,
+ request: SynapseRequest,
+ server_name: str,
+ media_id: str,
+ name: Optional[str],
) -> None:
"""Respond to requests for remote media.
@@ -468,7 +481,9 @@ class MediaRepository:
return media_info
- def _get_thumbnail_requirements(self, media_type):
+ def _get_thumbnail_requirements(
+ self, media_type: str
+ ) -> Tuple[ThumbnailRequirement, ...]:
scpos = media_type.find(";")
if scpos > 0:
media_type = media_type[:scpos]
@@ -514,7 +529,7 @@ class MediaRepository:
t_height: int,
t_method: str,
t_type: str,
- url_cache: Optional[str],
+ url_cache: bool,
) -> Optional[str]:
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(None, media_id, url_cache=url_cache)
@@ -548,11 +563,12 @@ class MediaRepository:
server_name=None,
file_id=media_id,
url_cache=url_cache,
- thumbnail=True,
- thumbnail_width=t_width,
- thumbnail_height=t_height,
- thumbnail_method=t_method,
- thumbnail_type=t_type,
+ thumbnail=ThumbnailInfo(
+ width=t_width,
+ height=t_height,
+ method=t_method,
+ type=t_type,
+ ),
)
output_path = await self.media_storage.store_file(
@@ -585,7 +601,7 @@ class MediaRepository:
t_type: str,
) -> Optional[str]:
input_path = await self.media_storage.ensure_media_is_in_local_cache(
- FileInfo(server_name, file_id, url_cache=False)
+ FileInfo(server_name, file_id)
)
try:
@@ -616,11 +632,12 @@ class MediaRepository:
file_info = FileInfo(
server_name=server_name,
file_id=file_id,
- thumbnail=True,
- thumbnail_width=t_width,
- thumbnail_height=t_height,
- thumbnail_method=t_method,
- thumbnail_type=t_type,
+ thumbnail=ThumbnailInfo(
+ width=t_width,
+ height=t_height,
+ method=t_method,
+ type=t_type,
+ ),
)
output_path = await self.media_storage.store_file(
@@ -742,12 +759,13 @@ class MediaRepository:
file_info = FileInfo(
server_name=server_name,
file_id=file_id,
- thumbnail=True,
- thumbnail_width=t_width,
- thumbnail_height=t_height,
- thumbnail_method=t_method,
- thumbnail_type=t_type,
url_cache=url_cache,
+ thumbnail=ThumbnailInfo(
+ width=t_width,
+ height=t_height,
+ method=t_method,
+ type=t_type,
+ ),
)
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
@@ -961,7 +979,7 @@ class MediaRepositoryResource(Resource):
def __init__(self, hs: "HomeServer"):
# If we're not configured to use it, raise if we somehow got here.
- if not hs.config.can_load_media_repo:
+ if not hs.config.media.can_load_media_repo:
raise ConfigError("Synapse is not configured to use a media repo.")
super().__init__()
@@ -972,7 +990,7 @@ class MediaRepositoryResource(Resource):
self.putChild(
b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage)
)
- if hs.config.url_preview_enabled:
+ if hs.config.media.url_preview_enabled:
self.putChild(
b"preview_url",
PreviewUrlResource(hs, media_repo, media_repo.media_storage),
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 56cdc1b4..fca239d8 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -15,7 +15,20 @@ import contextlib
import logging
import os
import shutil
-from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Sequence
+from types import TracebackType
+from typing import (
+ IO,
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ BinaryIO,
+ Callable,
+ Generator,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+)
import attr
@@ -83,12 +96,14 @@ class MediaStorage:
return fname
- async def write_to_file(self, source: IO, output: IO):
+ async def write_to_file(self, source: IO, output: IO) -> None:
"""Asynchronously write the `source` to `output`."""
await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
@contextlib.contextmanager
- def store_into_file(self, file_info: FileInfo):
+ def store_into_file(
+ self, file_info: FileInfo
+ ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
"""Context manager used to get a file like object to write into, as
described by file_info.
@@ -117,15 +132,14 @@ class MediaStorage:
fname = os.path.join(self.local_media_directory, path)
dirname = os.path.dirname(fname)
- if not os.path.exists(dirname):
- os.makedirs(dirname)
+ os.makedirs(dirname, exist_ok=True)
finished_called = [False]
try:
with open(fname, "wb") as f:
- async def finish():
+ async def finish() -> None:
# Ensure that all writes have been flushed and close the
# file.
f.flush()
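For reference, call sites elsewhere in this diff use the newly annotated context manager like the sketch below (inside an async function, with media_storage and file_info assumed in scope):

    with media_storage.store_into_file(file_info) as (f, fname, finish):
        f.write(b"...media bytes...")  # fname is the local path being written
        await finish()                 # flush, close, and hand off to storage providers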
@@ -176,9 +190,9 @@ class MediaStorage:
self.filepaths.remote_media_thumbnail_rel_legacy(
server_name=file_info.server_name,
file_id=file_info.file_id,
- width=file_info.thumbnail_width,
- height=file_info.thumbnail_height,
- content_type=file_info.thumbnail_type,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
)
)
@@ -220,17 +234,16 @@ class MediaStorage:
legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
server_name=file_info.server_name,
file_id=file_info.file_id,
- width=file_info.thumbnail_width,
- height=file_info.thumbnail_height,
- content_type=file_info.thumbnail_type,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
)
legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
if os.path.exists(legacy_local_path):
return legacy_local_path
dirname = os.path.dirname(local_path)
- if not os.path.exists(dirname):
- os.makedirs(dirname)
+ os.makedirs(dirname, exist_ok=True)
for provider in self.storage_providers:
res: Any = await provider.fetch(path, file_info)
@@ -255,10 +268,10 @@ class MediaStorage:
if file_info.thumbnail:
return self.filepaths.url_cache_thumbnail_rel(
media_id=file_info.file_id,
- width=file_info.thumbnail_width,
- height=file_info.thumbnail_height,
- content_type=file_info.thumbnail_type,
- method=file_info.thumbnail_method,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
+ method=file_info.thumbnail.method,
)
return self.filepaths.url_cache_filepath_rel(file_info.file_id)
@@ -267,10 +280,10 @@ class MediaStorage:
return self.filepaths.remote_media_thumbnail_rel(
server_name=file_info.server_name,
file_id=file_info.file_id,
- width=file_info.thumbnail_width,
- height=file_info.thumbnail_height,
- content_type=file_info.thumbnail_type,
- method=file_info.thumbnail_method,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
+ method=file_info.thumbnail.method,
)
return self.filepaths.remote_media_filepath_rel(
file_info.server_name, file_info.file_id
@@ -279,10 +292,10 @@ class MediaStorage:
if file_info.thumbnail:
return self.filepaths.local_media_thumbnail_rel(
media_id=file_info.file_id,
- width=file_info.thumbnail_width,
- height=file_info.thumbnail_height,
- content_type=file_info.thumbnail_type,
- method=file_info.thumbnail_method,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
+ method=file_info.thumbnail.method,
)
return self.filepaths.local_media_filepath_rel(file_info.file_id)
@@ -315,7 +328,12 @@ class FileResponder(Responder):
FileSender().beginFileTransfer(self.open_file, consumer)
)
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
self.open_file.close()
@@ -339,7 +357,7 @@ class ReadableFileWrapper:
clock = attr.ib(type=Clock)
path = attr.ib(type=str)
- async def write_chunks_to(self, callback: Callable[[bytes], None]):
+ async def write_chunks_to(self, callback: Callable[[bytes], None]) -> None:
"""Reads the file in chunks and calls the callback with each chunk."""
with open(self.path, "rb") as file:
diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py
index 2e6706db..e04671fb 100644
--- a/synapse/rest/media/v1/oembed.py
+++ b/synapse/rest/media/v1/oembed.py
@@ -12,30 +12,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Optional
+import urllib.parse
+from typing import TYPE_CHECKING, List, Optional
import attr
from synapse.http.client import SimpleHttpClient
+from synapse.types import JsonDict
+from synapse.util import json_decoder
if TYPE_CHECKING:
+ from lxml import etree
+
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-@attr.s(slots=True, auto_attribs=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class OEmbedResult:
- # Either HTML content or URL must be provided.
- html: Optional[str]
- url: Optional[str]
- title: Optional[str]
- # Number of seconds to cache the content.
- cache_age: int
-
-
-class OEmbedError(Exception):
- """An error occurred processing the oEmbed object."""
+ # The Open Graph result (converted from the oEmbed result).
+ open_graph_result: JsonDict
+ # Number of milliseconds to cache the content, according to the oEmbed response.
+ #
+ # This will be None if no cache-age is provided in the oEmbed response (or
+ # if the oEmbed response cannot be turned into an Open Graph response).
+ cache_age: Optional[int]
class OEmbedProvider:
@@ -81,75 +83,145 @@ class OEmbedProvider:
"""
for url_pattern, endpoint in self._oembed_patterns.items():
if url_pattern.fullmatch(url):
- return endpoint
+ # TODO Specify max height / width.
+
+ # Note that only the JSON format is supported; some endpoints want
+ # this in the URL, others want it as an argument.
+ endpoint = endpoint.replace("{format}", "json")
+
+ args = {"url": url, "format": "json"}
+ query_str = urllib.parse.urlencode(args, True)
+ return f"{endpoint}?{query_str}"
# No match.
return None
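Net effect: get_oembed_url now returns a fully formed request URL rather than a bare endpoint. A standalone sketch of the transformation (endpoint pattern illustrative):

    import urllib.parse

    endpoint = "https://publish.example.com/oembed.{format}"  # illustrative pattern
    url = "https://example.com/some/page"

    endpoint = endpoint.replace("{format}", "json")
    query_str = urllib.parse.urlencode({"url": url, "format": "json"}, True)
    oembed_url = f"{endpoint}?{query_str}"
    # -> "https://publish.example.com/oembed.json?url=https%3A%2F%2Fexample.com%2Fsome%2Fpage&format=json"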
- async def get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult:
+ def parse_oembed_response(self, url: str, raw_body: bytes) -> OEmbedResult:
"""
- Request content from an oEmbed endpoint.
+ Parse the oEmbed response into an Open Graph response.
Args:
- endpoint: The oEmbed API endpoint.
- url: The URL to pass to the API.
+ url: The URL which is being previewed (not the one which was
+ requested).
+ raw_body: The oEmbed response as JSON encoded as bytes.
Returns:
- An object representing the metadata returned.
-
- Raises:
- OEmbedError if fetching or parsing of the oEmbed information fails.
+ An OEmbedResult containing the Open Graph dictionary and the cache age.
"""
- try:
- logger.debug("Trying to get oEmbed content for url '%s'", url)
-
- # Note that only the JSON format is supported, some endpoints want
- # this in the URL, others want it as an argument.
- endpoint = endpoint.replace("{format}", "json")
- result = await self._client.get_json(
- endpoint,
- # TODO Specify max height / width.
- args={"url": url, "format": "json"},
- )
+ try:
+ # oEmbed responses *must* be UTF-8 according to the spec.
+ oembed = json_decoder.decode(raw_body.decode("utf-8"))
# Ensure there's a version of 1.0.
- if result.get("version") != "1.0":
- raise OEmbedError("Invalid version: %s" % (result.get("version"),))
-
- oembed_type = result.get("type")
+ oembed_version = oembed["version"]
+ if oembed_version != "1.0":
+ raise RuntimeError(f"Invalid version: {oembed_version}")
# Ensure the cache age is None or an int.
- cache_age = result.get("cache_age")
+ cache_age = oembed.get("cache_age")
if cache_age:
- cache_age = int(cache_age)
+ cache_age = int(cache_age) * 1000
- oembed_result = OEmbedResult(None, None, result.get("title"), cache_age)
+ # The base Open Graph response to build up.
+ open_graph_response = {
+ "og:url": url,
+ }
- # HTML content.
- if oembed_type == "rich":
- oembed_result.html = result.get("html")
- return oembed_result
+ # Use either title or author's name as the title.
+ title = oembed.get("title") or oembed.get("author_name")
+ if title:
+ open_graph_response["og:title"] = title
- if oembed_type == "photo":
- oembed_result.url = result.get("url")
- return oembed_result
+ # Use the provider name as the site name.
+ provider_name = oembed.get("provider_name")
+ if provider_name:
+ open_graph_response["og:site_name"] = provider_name
- # TODO Handle link and video types.
+ # If a thumbnail exists, use it. Note that dimensions will be calculated later.
+ if "thumbnail_url" in oembed:
+ open_graph_response["og:image"] = oembed["thumbnail_url"]
- if "thumbnail_url" in result:
- oembed_result.url = result.get("thumbnail_url")
- return oembed_result
+ # Process each type separately.
+ oembed_type = oembed["type"]
+ if oembed_type == "rich":
+ calc_description_and_urls(open_graph_response, oembed["html"])
+
+ elif oembed_type == "photo":
+ # If this is a photo, use the full image, not the thumbnail.
+ open_graph_response["og:image"] = oembed["url"]
- raise OEmbedError("Incompatible oEmbed information.")
+ elif oembed_type == "video":
+ open_graph_response["og:type"] = "video.other"
+ calc_description_and_urls(open_graph_response, oembed["html"])
+ open_graph_response["og:video:width"] = oembed["width"]
+ open_graph_response["og:video:height"] = oembed["height"]
- except OEmbedError as e:
- # Trap OEmbedErrors first so we can directly re-raise them.
- logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
- raise
+ elif oembed_type == "link":
+ open_graph_response["og:type"] = "website"
+
+ else:
+ raise RuntimeError(f"Unknown oEmbed type: {oembed_type}")
except Exception as e:
# Trap any exception and let the code follow as usual.
- # FIXME: pass through 404s and other error messages nicely
- logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
- raise OEmbedError() from e
+ logger.warning(f"Error parsing oEmbed metadata from {url}: {e:r}")
+ open_graph_response = {}
+ cache_age = None
+
+ return OEmbedResult(open_graph_response, cache_age)
+
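A worked example of the mapping above for a "photo" response (body illustrative; provider is an OEmbedProvider instance):

    raw_body = (
        b'{"version": "1.0", "type": "photo", "title": "A cat",'
        b' "provider_name": "Flickr", "url": "https://example.com/cat-full.jpg",'
        b' "thumbnail_url": "https://example.com/cat-thumb.jpg", "cache_age": "3600"}'
    )
    result = provider.parse_oembed_response("https://example.com/cat", raw_body)
    # result.open_graph_result == {
    #     "og:url": "https://example.com/cat",
    #     "og:title": "A cat",
    #     "og:site_name": "Flickr",
    #     "og:image": "https://example.com/cat-full.jpg",  # photo: full image wins over the thumbnail
    # }
    # result.cache_age == 3600000  # seconds in the response, stored as milliseconds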
+
+def _fetch_urls(tree: "etree.Element", tag_name: str) -> List[str]:
+ results = []
+ for tag in tree.xpath("//*/" + tag_name):
+ if "src" in tag.attrib:
+ results.append(tag.attrib["src"])
+ return results
+
+
+def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) -> None:
+ """
+ Calculate description for an HTML document.
+
+ This uses lxml to convert the HTML document into plaintext. If errors
+ occur during processing of the document, an empty response is returned.
+
+ Args:
+ open_graph_response: The current Open Graph summary. This is updated with additional fields.
+ html_body: The HTML document, as a string.
+ """
+ # If there's no body, nothing useful is going to be found.
+ if not html_body:
+ return
+
+ from lxml import etree
+
+ # Create an HTML parser, tolerating malformed markup (recover=True).
+ parser = etree.HTMLParser(recover=True, encoding="utf-8")
+
+ # Attempt to parse the body; if no tree can be built, return without metadata.
+ tree = etree.fromstring(html_body, parser)
+
+ # The data was successfully parsed, but no tree was found.
+ if tree is None:
+ return
+
+ # Attempt to find interesting URLs (images, videos, embeds).
+ if "og:image" not in open_graph_response:
+ image_urls = _fetch_urls(tree, "img")
+ if image_urls:
+ open_graph_response["og:image"] = image_urls[0]
+
+ video_urls = _fetch_urls(tree, "video") + _fetch_urls(tree, "embed")
+ if video_urls:
+ open_graph_response["og:video"] = video_urls[0]
+
+ from synapse.rest.media.v1.preview_url_resource import _calc_description
+
+ description = _calc_description(tree)
+ if description:
+ open_graph_response["og:description"] = description
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index f108da05..79a42b24 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -27,8 +27,8 @@ from urllib import parse as urlparse
import attr
+from twisted.internet.defer import Deferred
from twisted.internet.error import DNSLookupError
-from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient
@@ -43,7 +43,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.rest.media.v1.media_storage import MediaStorage
-from synapse.rest.media.v1.oembed import OEmbedError, OEmbedProvider
+from synapse.rest.media.v1.oembed import OEmbedProvider
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
@@ -72,6 +72,7 @@ OG_TAG_NAME_MAXLEN = 50
OG_TAG_VALUE_MAXLEN = 1000
ONE_HOUR = 60 * 60 * 1000
+ONE_DAY = 24 * ONE_HOUR
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -124,14 +125,14 @@ class PreviewUrlResource(DirectServeJsonResource):
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.filepaths = media_repo.filepaths
- self.max_spider_size = hs.config.max_spider_size
+ self.max_spider_size = hs.config.media.max_spider_size
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.client = SimpleHttpClient(
hs,
treq_args={"browser_like_redirects": True},
- ip_whitelist=hs.config.url_preview_ip_range_whitelist,
- ip_blacklist=hs.config.url_preview_ip_range_blacklist,
+ ip_whitelist=hs.config.media.url_preview_ip_range_whitelist,
+ ip_blacklist=hs.config.media.url_preview_ip_range_blacklist,
use_proxy=True,
)
self.media_repo = media_repo
@@ -149,8 +150,8 @@ class PreviewUrlResource(DirectServeJsonResource):
or instance_running_jobs == hs.get_instance_name()
)
- self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
- self.url_preview_accept_language = hs.config.url_preview_accept_language
+ self.url_preview_url_blacklist = hs.config.media.url_preview_url_blacklist
+ self.url_preview_accept_language = hs.config.media.url_preview_accept_language
# memory cache mapping urls to an ObservableDeferred returning
# JSON-encoded OG metadata
@@ -166,7 +167,7 @@ class PreviewUrlResource(DirectServeJsonResource):
self._start_expire_url_cache_data, 10 * 1000
)
- async def _async_render_OPTIONS(self, request: Request) -> None:
+ async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
request.setHeader(b"Allow", b"OPTIONS, GET")
respond_with_json(request, 200, {}, send_cors=True)
@@ -254,10 +255,19 @@ class PreviewUrlResource(DirectServeJsonResource):
og = og.encode("utf8")
return og
- media_info = await self._download_url(url, user)
+ # If this URL can be accessed via oEmbed, use that instead.
+ url_to_download = url
+ oembed_url = self._oembed.get_oembed_url(url)
+ if oembed_url:
+ url_to_download = oembed_url
+
+ media_info = await self._download_url(url_to_download, user)
logger.debug("got media_info of '%s'", media_info)
+ # The number of milliseconds that the response should be considered valid.
+ expiration_ms = media_info.expires
+
if _is_media(media_info.media_type):
file_id = media_info.filesystem_id
dims = await self.media_repo._generate_thumbnails(
@@ -287,34 +297,22 @@ class PreviewUrlResource(DirectServeJsonResource):
encoding = get_html_media_encoding(body, media_info.media_type)
og = decode_and_calc_og(body, media_info.uri, encoding)
- # pre-cache the image for posterity
- # FIXME: it might be cleaner to use the same flow as the main /preview_url
- # request itself and benefit from the same caching etc. But for now we
- # just rely on the caching on the master request to speed things up.
- if "og:image" in og and og["og:image"]:
- image_info = await self._download_url(
- _rebase_url(og["og:image"], media_info.uri), user
- )
+ await self._precache_image_url(user, media_info, og)
+
+ elif oembed_url and _is_json(media_info.media_type):
+ # Handle an oEmbed response.
+ with open(media_info.filename, "rb") as file:
+ body = file.read()
+
+ oembed_response = self._oembed.parse_oembed_response(url, body)
+ og = oembed_response.open_graph_result
+
+ # Use the cache age from the oEmbed result, instead of the HTTP response.
+ if oembed_response.cache_age is not None:
+ expiration_ms = oembed_response.cache_age
+
+ await self._precache_image_url(user, media_info, og)
- if _is_media(image_info.media_type):
- # TODO: make sure we don't choke on white-on-transparent images
- file_id = image_info.filesystem_id
- dims = await self.media_repo._generate_thumbnails(
- None, file_id, file_id, image_info.media_type, url_cache=True
- )
- if dims:
- og["og:image:width"] = dims["width"]
- og["og:image:height"] = dims["height"]
- else:
- logger.warning("Couldn't get dims for %s", og["og:image"])
-
- og[
- "og:image"
- ] = f"mxc://{self.server_name}/{image_info.filesystem_id}"
- og["og:image:type"] = image_info.media_type
- og["matrix:image:size"] = image_info.media_length
- else:
- del og["og:image"]
else:
logger.warning("Failed to find any OG data in %s", url)
og = {}
@@ -335,12 +333,15 @@ class PreviewUrlResource(DirectServeJsonResource):
jsonog = json_encoder.encode(og)
+ # Cap the amount of time to consider a response valid.
+ expiration_ms = min(expiration_ms, ONE_DAY)
+
# store OG in history-aware DB cache
await self.store.store_url_cache(
url,
media_info.response_code,
media_info.etag,
- media_info.expires + media_info.created_ts_ms,
+ media_info.created_ts_ms + expiration_ms,
jsonog,
media_info.filesystem_id,
media_info.created_ts_ms,
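The expiry arithmetic is now a capped relative age converted to an absolute timestamp; illustratively:

    ONE_HOUR = 60 * 60 * 1000
    ONE_DAY = 24 * ONE_HOUR
    created_ts_ms = 1_600_000_000_000           # illustrative fetch time

    expiration_ms = min(7 * ONE_DAY, ONE_DAY)   # an oEmbed cache_age of 7 days is capped to 1
    expires_at = created_ts_ms + expiration_ms  # absolute expiry stored with the cache row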
@@ -357,88 +358,52 @@ class PreviewUrlResource(DirectServeJsonResource):
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
- # If this URL can be accessed via oEmbed, use that instead.
- url_to_download: Optional[str] = url
- oembed_url = self._oembed.get_oembed_url(url)
- if oembed_url:
- # The result might be a new URL to download, or it might be HTML content.
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
try:
- oembed_result = await self._oembed.get_oembed_content(oembed_url, url)
- if oembed_result.url:
- url_to_download = oembed_result.url
- elif oembed_result.html:
- url_to_download = None
- except OEmbedError:
- # If an error occurs, try doing a normal preview.
- pass
+ logger.debug("Trying to get preview for url '%s'", url)
+ length, headers, uri, code = await self.client.get_file(
+ url,
+ output_stream=f,
+ max_size=self.max_spider_size,
+ headers={"Accept-Language": self.url_preview_accept_language},
+ )
+ except SynapseError:
+ # Pass SynapseErrors through directly, so that the servlet
+ # handler will return a SynapseError to the client instead of
+ # blank data or a 500.
+ raise
+ except DNSLookupError:
+ # DNS lookup returned no results
+ # Note: This will also be the case if one of the resolved IP
+ # addresses is blacklisted
+ raise SynapseError(
+ 502,
+ "DNS resolution failure during URL preview generation",
+ Codes.UNKNOWN,
+ )
+ except Exception as e:
+ # FIXME: pass through 404s and other error messages nicely
+ logger.warning("Error downloading %s: %r", url, e)
- if url_to_download:
- with self.media_storage.store_into_file(file_info) as (f, fname, finish):
- try:
- logger.debug("Trying to get preview for url '%s'", url_to_download)
- length, headers, uri, code = await self.client.get_file(
- url_to_download,
- output_stream=f,
- max_size=self.max_spider_size,
- headers={"Accept-Language": self.url_preview_accept_language},
- )
- except SynapseError:
- # Pass SynapseErrors through directly, so that the servlet
- # handler will return a SynapseError to the client instead of
- # blank data or a 500.
- raise
- except DNSLookupError:
- # DNS lookup returned no results
- # Note: This will also be the case if one of the resolved IP
- # addresses is blacklisted
- raise SynapseError(
- 502,
- "DNS resolution failure during URL preview generation",
- Codes.UNKNOWN,
- )
- except Exception as e:
- # FIXME: pass through 404s and other error messages nicely
- logger.warning("Error downloading %s: %r", url_to_download, e)
-
- raise SynapseError(
- 500,
- "Failed to download content: %s"
- % (traceback.format_exception_only(sys.exc_info()[0], e),),
- Codes.UNKNOWN,
- )
- await finish()
-
- if b"Content-Type" in headers:
- media_type = headers[b"Content-Type"][0].decode("ascii")
- else:
- media_type = "application/octet-stream"
+ raise SynapseError(
+ 500,
+ "Failed to download content: %s"
+ % (traceback.format_exception_only(sys.exc_info()[0], e),),
+ Codes.UNKNOWN,
+ )
+ await finish()
- download_name = get_filename_from_headers(headers)
+ if b"Content-Type" in headers:
+ media_type = headers[b"Content-Type"][0].decode("ascii")
+ else:
+ media_type = "application/octet-stream"
- # FIXME: we should calculate a proper expiration based on the
- # Cache-Control and Expire headers. But for now, assume 1 hour.
- expires = ONE_HOUR
- etag = (
- headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None
- )
- else:
- # we can only get here if we did an oembed request and have an oembed_result.html
- assert oembed_result.html is not None
- assert oembed_url is not None
-
- html_bytes = oembed_result.html.encode("utf-8")
- with self.media_storage.store_into_file(file_info) as (f, fname, finish):
- f.write(html_bytes)
- await finish()
-
- media_type = "text/html"
- download_name = oembed_result.title
- length = len(html_bytes)
- # If a specific cache age was not given, assume 1 hour.
- expires = oembed_result.cache_age or ONE_HOUR
- uri = oembed_url
- code = 200
- etag = None
+ download_name = get_filename_from_headers(headers)
+
+ # FIXME: we should calculate a proper expiration based on the
+ # Cache-Control and Expire headers. But for now, assume 1 hour.
+ expires = ONE_HOUR
+ etag = headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None
try:
time_now_ms = self.clock.time_msec()
@@ -473,14 +438,53 @@ class PreviewUrlResource(DirectServeJsonResource):
etag=etag,
)
- def _start_expire_url_cache_data(self):
+ async def _precache_image_url(
+ self, user: str, media_info: MediaInfo, og: JsonDict
+ ) -> None:
+ """
+ Pre-cache the image (if one exists) for posterity.
+
+ Args:
+ user: The user requesting the preview.
+ media_info: The media being previewed.
+ og: The Open Graph dictionary. This is modified with image information.
+ """
+ # If there's no image or it is blank, there's nothing to do.
+ if "og:image" not in og or not og["og:image"]:
+ return
+
+ # FIXME: it might be cleaner to use the same flow as the main /preview_url
+ # request itself and benefit from the same caching etc. But for now we
+ # just rely on the caching on the master request to speed things up.
+ image_info = await self._download_url(
+ _rebase_url(og["og:image"], media_info.uri), user
+ )
+
+ if _is_media(image_info.media_type):
+ # TODO: make sure we don't choke on white-on-transparent images
+ file_id = image_info.filesystem_id
+ dims = await self.media_repo._generate_thumbnails(
+ None, file_id, file_id, image_info.media_type, url_cache=True
+ )
+ if dims:
+ og["og:image:width"] = dims["width"]
+ og["og:image:height"] = dims["height"]
+ else:
+ logger.warning("Couldn't get dims for %s", og["og:image"])
+
+ og["og:image"] = f"mxc://{self.server_name}/{image_info.filesystem_id}"
+ og["og:image:type"] = image_info.media_type
+ og["matrix:image:size"] = image_info.media_length
+ else:
+ del og["og:image"]
+
+ def _start_expire_url_cache_data(self) -> Deferred:
return run_as_background_process(
"expire_url_cache_data", self._expire_url_cache_data
)
async def _expire_url_cache_data(self) -> None:
"""Clean up expired url cache content, media and thumbnails."""
- # TODO: Delete from backup media store
assert self._worker_run_media_background_jobs
@@ -526,7 +530,7 @@ class PreviewUrlResource(DirectServeJsonResource):
# These may be cached for a bit on the client (i.e., they
# may have a room open with a preview url thing open).
# So we wait a couple of days before deleting, just in case.
- expire_before = now - 2 * 24 * ONE_HOUR
+ expire_before = now - 2 * ONE_DAY
media_ids = await self.store.get_url_cache_media_before(expire_before)
removed_media = []
@@ -668,7 +672,18 @@ def decode_and_calc_og(
def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
- # suck our tree into lxml and define our OG response.
+ """
+ Calculate metadata for an HTML document.
+
+ This uses lxml to search the HTML document for Open Graph data.
+
+ Args:
+ tree: The parsed HTML document.
+ media_uri: The URI used to download the body.
+
+ Returns:
+ The Open Graph response as a dictionary.
+ """
# if we see any image URLs in the OG response, then spider them
# (although the client could choose to do this by asking for previews of those
@@ -742,35 +757,7 @@ def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
if meta_description:
og["og:description"] = meta_description[0]
else:
- # grab any text nodes which are inside the <body/> tag...
- # unless they are within an HTML5 semantic markup tag...
- # <header/>, <nav/>, <aside/>, <footer/>
- # ...or if they are within a <script/> or <style/> tag.
- # This is a very very very coarse approximation to a plain text
- # render of the page.
-
- # We don't just use XPATH here as that is slow on some machines.
-
- from lxml import etree
-
- TAGS_TO_REMOVE = (
- "header",
- "nav",
- "aside",
- "footer",
- "script",
- "noscript",
- "style",
- etree.Comment,
- )
-
- # Split all the text nodes into paragraphs (by splitting on new
- # lines)
- text_nodes = (
- re.sub(r"\s+", "\n", el).strip()
- for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
- )
- og["og:description"] = summarize_paragraphs(text_nodes)
+ og["og:description"] = _calc_description(tree)
elif og["og:description"]:
# This must be a non-empty string at this point.
assert isinstance(og["og:description"], str)
@@ -781,8 +768,48 @@ def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
return og
+def _calc_description(tree: "etree.Element") -> Optional[str]:
+ """
+ Calculate a text description based on an HTML document.
+
+ Grabs any text nodes which are inside the <body/> tag, unless they are within
+ an HTML5 semantic markup tag (<header/>, <nav/>, <aside/>, <footer/>), or
+ if they are within a <script/> or <style/> tag.
+
+ This is a very very very coarse approximation to a plain text render of the page.
+
+ Args:
+ tree: The parsed HTML document.
+
+ Returns:
+ The plain text description, or None if one cannot be generated.
+ """
+ # We don't just use XPATH here as that is slow on some machines.
+
+ from lxml import etree
+
+ TAGS_TO_REMOVE = (
+ "header",
+ "nav",
+ "aside",
+ "footer",
+ "script",
+ "noscript",
+ "style",
+ etree.Comment,
+ )
+
+ # Split all the text nodes into paragraphs (by splitting on new
+ # lines)
+ text_nodes = (
+ re.sub(r"\s+", "\n", el).strip()
+ for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
+ )
+ return summarize_paragraphs(text_nodes)
+
+
def _iterate_over_text(
- tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
+ tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
) -> Generator[str, None, None]:
"""Iterate over the tree returning text nodes in a depth first fashion,
skipping text nodes inside certain tags.
@@ -840,11 +867,25 @@ def _is_html(content_type: str) -> bool:
)
+def _is_json(content_type: str) -> bool:
+ return content_type.lower().startswith("application/json")
+
+
def summarize_paragraphs(
text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
) -> Optional[str]:
- # Try to get a summary of between 200 and 500 words, respecting
- # first paragraph and then word boundaries.
+ """
+ Try to get a summary respecting first paragraph and then word boundaries.
+
+ Args:
+ text_nodes: The paragraphs to summarize.
+ min_size: The minimum number of words to include.
+ max_size: The maximum number of words to include.
+
+ Returns:
+ A summary of the text nodes, or None if that was not possible.
+ """
+
# TODO: Respect sentences?
description = ""
@@ -867,7 +908,7 @@ def summarize_paragraphs(
new_desc = ""
# This splits the paragraph into words, but keeping the
- # (preceeding) whitespace intact so we can easily concat
+ # (preceding) whitespace intact so we can easily concat
# words back together.
for match in re.finditer(r"\s*\S+", description):
word = match.group()
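The regex keeps each word's preceding whitespace, so truncated summaries concatenate back losslessly; a standalone check:

    import re

    description = "one  two\nthree"
    words = [m.group() for m in re.finditer(r"\s*\S+", description)]
    # words == ["one", "  two", "\nthree"]
    assert "".join(words) == description  # concatenation round-trips exactly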
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 0ff6ad3c..18bf977d 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -93,13 +93,18 @@ class StorageProviderWrapper(StorageProvider):
if file_info.server_name and not self.store_remote:
return None
+ if file_info.url_cache:
+ # The URL preview cache is short lived and not worth offloading or
+ # backing up.
+ return None
+
if self.store_synchronous:
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore
else:
# TODO: Handle errors.
- async def store():
+ async def store() -> None:
try:
return await maybe_awaitable(
self.backend.store_file(path, file_info)
@@ -110,6 +115,11 @@ class StorageProviderWrapper(StorageProvider):
run_in_background(store)
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
+ if file_info.url_cache:
+ # Files in the URL preview cache definitely aren't stored here,
+ # so avoid any potentially slow I/O or network access.
+ return None
+
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
return await maybe_awaitable(self.backend.fetch(path, file_info))
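A minimal sketch of the short-circuit (wrapper is a StorageProviderWrapper configured as above; inside an async function, with FileInfo assumed in scope):

    cache_file = FileInfo(server_name=None, file_id="2021-09-28-abcdef", url_cache=True)

    # Neither call reaches the wrapped backend for URL-cache files:
    await wrapper.store_file("url_cache/...", cache_file)  # returns None, nothing offloaded
    await wrapper.fetch("url_cache/...", cache_file)       # returns None, no I/O attempted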
@@ -125,10 +135,10 @@ class FileStorageProviderBackend(StorageProvider):
def __init__(self, hs: "HomeServer", config: str):
self.hs = hs
- self.cache_directory = hs.config.media_store_path
+ self.cache_directory = hs.config.media.media_store_path
self.base_directory = config
- def __str__(self):
+ def __str__(self) -> str:
return "FileStorageProviderBackend[%s]" % (self.base_directory,)
async def store_file(self, path: str, file_info: FileInfo) -> None:
@@ -138,8 +148,7 @@ class FileStorageProviderBackend(StorageProvider):
backup_fname = os.path.join(self.base_directory, path)
dirname = os.path.dirname(backup_fname)
- if not os.path.exists(dirname):
- os.makedirs(dirname)
+ os.makedirs(dirname, exist_ok=True)
await defer_to_thread(
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 12bd745c..ed91ef5a 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -17,15 +17,15 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
-from twisted.web.server import Request
-
from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_integer, parse_string
+from synapse.http.site import SynapseRequest
from synapse.rest.media.v1.media_storage import MediaStorage
from ._base import (
FileInfo,
+ ThumbnailInfo,
parse_media_id,
respond_404,
respond_with_file,
@@ -53,10 +53,10 @@ class ThumbnailResource(DirectServeJsonResource):
self.store = hs.get_datastore()
self.media_repo = media_repo
self.media_storage = media_storage
- self.dynamic_thumbnails = hs.config.dynamic_thumbnails
+ self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
self.server_name = hs.hostname
- async def _async_render_GET(self, request: Request) -> None:
+ async def _async_render_GET(self, request: SynapseRequest) -> None:
set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request)
width = parse_integer(request, "width", required=True)
@@ -87,7 +87,7 @@ class ThumbnailResource(DirectServeJsonResource):
async def _respond_local_thumbnail(
self,
- request: Request,
+ request: SynapseRequest,
media_id: str,
width: int,
height: int,
@@ -114,13 +114,13 @@ class ThumbnailResource(DirectServeJsonResource):
thumbnail_infos,
media_id,
media_id,
- url_cache=media_info["url_cache"],
+ url_cache=bool(media_info["url_cache"]),
server_name=None,
)
async def _select_or_generate_local_thumbnail(
self,
- request: Request,
+ request: SynapseRequest,
media_id: str,
desired_width: int,
desired_height: int,
@@ -149,11 +149,12 @@ class ThumbnailResource(DirectServeJsonResource):
server_name=None,
file_id=media_id,
url_cache=media_info["url_cache"],
- thumbnail=True,
- thumbnail_width=info["thumbnail_width"],
- thumbnail_height=info["thumbnail_height"],
- thumbnail_type=info["thumbnail_type"],
- thumbnail_method=info["thumbnail_method"],
+ thumbnail=ThumbnailInfo(
+ width=info["thumbnail_width"],
+ height=info["thumbnail_height"],
+ type=info["thumbnail_type"],
+ method=info["thumbnail_method"],
+ ),
)
t_type = file_info.thumbnail_type
@@ -173,7 +174,7 @@ class ThumbnailResource(DirectServeJsonResource):
desired_height,
desired_method,
desired_type,
- url_cache=media_info["url_cache"],
+ url_cache=bool(media_info["url_cache"]),
)
if file_path:
@@ -184,7 +185,7 @@ class ThumbnailResource(DirectServeJsonResource):
async def _select_or_generate_remote_thumbnail(
self,
- request: Request,
+ request: SynapseRequest,
server_name: str,
media_id: str,
desired_width: int,
@@ -210,11 +211,12 @@ class ThumbnailResource(DirectServeJsonResource):
file_info = FileInfo(
server_name=server_name,
file_id=media_info["filesystem_id"],
- thumbnail=True,
- thumbnail_width=info["thumbnail_width"],
- thumbnail_height=info["thumbnail_height"],
- thumbnail_type=info["thumbnail_type"],
- thumbnail_method=info["thumbnail_method"],
+ thumbnail=ThumbnailInfo(
+ width=info["thumbnail_width"],
+ height=info["thumbnail_height"],
+ type=info["thumbnail_type"],
+ method=info["thumbnail_method"],
+ ),
)
t_type = file_info.thumbnail_type
@@ -246,7 +248,7 @@ class ThumbnailResource(DirectServeJsonResource):
async def _respond_remote_thumbnail(
self,
- request: Request,
+ request: SynapseRequest,
server_name: str,
media_id: str,
width: int,
@@ -271,13 +273,13 @@ class ThumbnailResource(DirectServeJsonResource):
thumbnail_infos,
media_id,
media_info["filesystem_id"],
- url_cache=None,
+ url_cache=False,
server_name=server_name,
)
async def _select_and_respond_with_thumbnail(
self,
- request: Request,
+ request: SynapseRequest,
desired_width: int,
desired_height: int,
desired_method: str,
@@ -285,7 +287,7 @@ class ThumbnailResource(DirectServeJsonResource):
thumbnail_infos: List[Dict[str, Any]],
media_id: str,
file_id: str,
- url_cache: Optional[str] = None,
+ url_cache: bool,
server_name: Optional[str] = None,
) -> None:
"""
@@ -299,7 +301,7 @@ class ThumbnailResource(DirectServeJsonResource):
desired_type: The desired content-type of the thumbnail.
thumbnail_infos: A list of dictionaries of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
- url_cache: The URL cache value.
+ url_cache: True if this is from a URL cache.
server_name: The server name, if this is a remote thumbnail.
"""
if thumbnail_infos:
@@ -318,13 +320,16 @@ class ThumbnailResource(DirectServeJsonResource):
respond_404(request)
return
+ # The thumbnail property must exist.
+ assert file_info.thumbnail is not None
+
responder = await self.media_storage.fetch_media(file_info)
if responder:
await respond_with_responder(
request,
responder,
- file_info.thumbnail_type,
- file_info.thumbnail_length,
+ file_info.thumbnail.type,
+ file_info.thumbnail.length,
)
return
@@ -351,18 +356,18 @@ class ThumbnailResource(DirectServeJsonResource):
server_name,
file_id=file_id,
media_id=media_id,
- t_width=file_info.thumbnail_width,
- t_height=file_info.thumbnail_height,
- t_method=file_info.thumbnail_method,
- t_type=file_info.thumbnail_type,
+ t_width=file_info.thumbnail.width,
+ t_height=file_info.thumbnail.height,
+ t_method=file_info.thumbnail.method,
+ t_type=file_info.thumbnail.type,
)
else:
await self.media_repo.generate_local_exact_thumbnail(
media_id=media_id,
- t_width=file_info.thumbnail_width,
- t_height=file_info.thumbnail_height,
- t_method=file_info.thumbnail_method,
- t_type=file_info.thumbnail_type,
+ t_width=file_info.thumbnail.width,
+ t_height=file_info.thumbnail.height,
+ t_method=file_info.thumbnail.method,
+ t_type=file_info.thumbnail.type,
url_cache=url_cache,
)
@@ -370,8 +375,8 @@ class ThumbnailResource(DirectServeJsonResource):
await respond_with_responder(
request,
responder,
- file_info.thumbnail_type,
- file_info.thumbnail_length,
+ file_info.thumbnail.type,
+ file_info.thumbnail.length,
)
else:
logger.info("Failed to find any generated thumbnails")
@@ -385,7 +390,7 @@ class ThumbnailResource(DirectServeJsonResource):
desired_type: str,
thumbnail_infos: List[Dict[str, Any]],
file_id: str,
- url_cache: Optional[str],
+ url_cache: bool,
server_name: Optional[str],
) -> Optional[FileInfo]:
"""
@@ -398,7 +403,7 @@ class ThumbnailResource(DirectServeJsonResource):
desired_type: The desired content-type of the thumbnail.
thumbnail_infos: A list of dictionaries of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
- url_cache: The URL cache value.
+ url_cache: True if this is from a URL cache.
server_name: The server name, if this is a remote thumbnail.
Returns:
@@ -495,12 +500,13 @@ class ThumbnailResource(DirectServeJsonResource):
file_id=file_id,
url_cache=url_cache,
server_name=server_name,
- thumbnail=True,
- thumbnail_width=thumbnail_info["thumbnail_width"],
- thumbnail_height=thumbnail_info["thumbnail_height"],
- thumbnail_type=thumbnail_info["thumbnail_type"],
- thumbnail_method=thumbnail_info["thumbnail_method"],
- thumbnail_length=thumbnail_info["thumbnail_length"],
+ thumbnail=ThumbnailInfo(
+ width=thumbnail_info["thumbnail_width"],
+ height=thumbnail_info["thumbnail_height"],
+ type=thumbnail_info["thumbnail_type"],
+ method=thumbnail_info["thumbnail_method"],
+ length=thumbnail_info["thumbnail_length"],
+ ),
)
# No matching thumbnail was found.
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index a65e9e18..df54a406 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -41,7 +41,7 @@ class Thumbnailer:
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
@staticmethod
- def set_limits(max_image_pixels: int):
+ def set_limits(max_image_pixels: int) -> None:
Image.MAX_IMAGE_PIXELS = max_image_pixels
def __init__(self, input_path: str):
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 146adca8..7dcb1428 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -16,8 +16,6 @@
import logging
from typing import IO, TYPE_CHECKING, Dict, List, Optional
-from twisted.web.server import Request
-
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_bytes_from_args
@@ -43,10 +41,10 @@ class UploadResource(DirectServeJsonResource):
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.auth = hs.get_auth()
- self.max_upload_size = hs.config.max_upload_size
+ self.max_upload_size = hs.config.media.max_upload_size
self.clock = hs.get_clock()
- async def _async_render_OPTIONS(self, request: Request) -> None:
+ async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_POST(self, request: SynapseRequest) -> None:
diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py
index 47a2f72b..6ad558f5 100644
--- a/synapse/rest/synapse/client/__init__.py
+++ b/synapse/rest/synapse/client/__init__.py
@@ -45,12 +45,12 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc
# provider-specific SSO bits. Only load these if they are enabled, since they
# rely on optional dependencies.
- if hs.config.oidc_enabled:
+ if hs.config.oidc.oidc_enabled:
from synapse.rest.synapse.client.oidc import OIDCResource
resources["/_synapse/client/oidc"] = OIDCResource(hs)
- if hs.config.saml2_enabled:
+ if hs.config.saml2.saml2_enabled:
from synapse.rest.synapse.client.saml2 import SAML2Resource
res = SAML2Resource(hs)
diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py
index 67c1ed1f..1c1c7b36 100644
--- a/synapse/rest/synapse/client/new_user_consent.py
+++ b/synapse/rest/synapse/client/new_user_consent.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Generator
from twisted.web.server import Request
@@ -45,7 +45,7 @@ class NewUserConsentResource(DirectServeHtmlResource):
self._server_name = hs.hostname
self._consent_version = hs.config.consent.user_consent_version
- def template_search_dirs():
+ def template_search_dirs() -> Generator[str, None, None]:
if hs.config.server.custom_template_directory:
yield hs.config.server.custom_template_directory
if hs.config.sso.sso_template_dir:
@@ -88,7 +88,7 @@ class NewUserConsentResource(DirectServeHtmlResource):
html = template.render(template_params)
respond_with_html(request, 200, html)
- async def _async_render_POST(self, request: Request):
+ async def _async_render_POST(self, request: Request) -> None:
try:
session_id = get_username_mapping_session_cookie_from_request(request)
except SynapseError as e:
diff --git a/synapse/rest/synapse/client/oidc/__init__.py b/synapse/rest/synapse/client/oidc/__init__.py
index 36ba4016..81fec396 100644
--- a/synapse/rest/synapse/client/oidc/__init__.py
+++ b/synapse/rest/synapse/client/oidc/__init__.py
@@ -13,16 +13,20 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
from twisted.web.resource import Resource
from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class OIDCResource(Resource):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self.putChild(b"callback", OIDCCallbackResource(hs))
diff --git a/synapse/rest/synapse/client/oidc/callback_resource.py b/synapse/rest/synapse/client/oidc/callback_resource.py
index 7785f17e..4f375cb7 100644
--- a/synapse/rest/synapse/client/oidc/callback_resource.py
+++ b/synapse/rest/synapse/client/oidc/callback_resource.py
@@ -16,6 +16,7 @@ import logging
from typing import TYPE_CHECKING
from synapse.http.server import DirectServeHtmlResource
+from synapse.http.site import SynapseRequest
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -30,10 +31,10 @@ class OIDCCallbackResource(DirectServeHtmlResource):
super().__init__()
self._oidc_handler = hs.get_oidc_handler()
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: SynapseRequest) -> None:
await self._oidc_handler.handle_oidc_callback(request)
- async def _async_render_POST(self, request):
+ async def _async_render_POST(self, request: SynapseRequest) -> None:
# the auth response can be returned via an x-www-form-urlencoded form instead
# of GET params, as per
# https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html.
diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py
index f2800bf2..28a67f04 100644
--- a/synapse/rest/synapse/client/password_reset.py
+++ b/synapse/rest/synapse/client/password_reset.py
@@ -47,20 +47,20 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource):
self.store = hs.get_datastore()
self._local_threepid_handling_disabled_due_to_email_config = (
- hs.config.local_threepid_handling_disabled_due_to_email_config
+ hs.config.email.local_threepid_handling_disabled_due_to_email_config
)
self._confirmation_email_template = (
- hs.config.email_password_reset_template_confirmation_html
+ hs.config.email.email_password_reset_template_confirmation_html
)
self._email_password_reset_template_success_html = (
- hs.config.email_password_reset_template_success_html_content
+ hs.config.email.email_password_reset_template_success_html_content
)
self._failure_email_template = (
- hs.config.email_password_reset_template_failure_html
+ hs.config.email.email_password_reset_template_failure_html
)
# This resource should not be mounted if threepid behaviour is not LOCAL
- assert hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL
+ assert hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL
async def _async_render_GET(self, request: Request) -> Tuple[int, bytes]:
sid = parse_string(request, "sid", required=True)
diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py
index d30b478b..28ae0834 100644
--- a/synapse/rest/synapse/client/pick_username.py
+++ b/synapse/rest/synapse/client/pick_username.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, Generator, List, Tuple
from twisted.web.resource import Resource
from twisted.web.server import Request
@@ -27,6 +27,7 @@ from synapse.http.server import (
)
from synapse.http.servlet import parse_boolean, parse_string
from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from synapse.util.templates import build_jinja_env
if TYPE_CHECKING:
@@ -57,7 +58,7 @@ class AvailabilityCheckResource(DirectServeJsonResource):
super().__init__()
self._sso_handler = hs.get_sso_handler()
- async def _async_render_GET(self, request: Request):
+ async def _async_render_GET(self, request: Request) -> Tuple[int, JsonDict]:
localpart = parse_string(request, "username", required=True)
session_id = get_username_mapping_session_cookie_from_request(request)
@@ -73,7 +74,7 @@ class AccountDetailsResource(DirectServeHtmlResource):
super().__init__()
self._sso_handler = hs.get_sso_handler()
- def template_search_dirs():
+ def template_search_dirs() -> Generator[str, None, None]:
if hs.config.server.custom_template_directory:
yield hs.config.server.custom_template_directory
if hs.config.sso.sso_template_dir:
@@ -104,7 +105,7 @@ class AccountDetailsResource(DirectServeHtmlResource):
html = template.render(template_params)
respond_with_html(request, 200, html)
- async def _async_render_POST(self, request: SynapseRequest):
+ async def _async_render_POST(self, request: SynapseRequest) -> None:
# This will always be set by the time Twisted calls us.
assert request.args is not None
diff --git a/synapse/rest/synapse/client/saml2/__init__.py b/synapse/rest/synapse/client/saml2/__init__.py
index 781ccb23..3f247e6a 100644
--- a/synapse/rest/synapse/client/saml2/__init__.py
+++ b/synapse/rest/synapse/client/saml2/__init__.py
@@ -13,17 +13,21 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
from twisted.web.resource import Resource
from synapse.rest.synapse.client.saml2.metadata_resource import SAML2MetadataResource
from synapse.rest.synapse.client.saml2.response_resource import SAML2ResponseResource
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class SAML2Resource(Resource):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self.putChild(b"metadata.xml", SAML2MetadataResource(hs))
self.putChild(b"authn_response", SAML2ResponseResource(hs))
diff --git a/synapse/rest/synapse/client/saml2/metadata_resource.py b/synapse/rest/synapse/client/saml2/metadata_resource.py
index b37c7083..d8eae397 100644
--- a/synapse/rest/synapse/client/saml2/metadata_resource.py
+++ b/synapse/rest/synapse/client/saml2/metadata_resource.py
@@ -12,10 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
import saml2.metadata
from twisted.web.resource import Resource
+from twisted.web.server import Request
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
class SAML2MetadataResource(Resource):
@@ -23,11 +28,11 @@ class SAML2MetadataResource(Resource):
isLeaf = 1
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
- self.sp_config = hs.config.saml2_sp_config
+ self.sp_config = hs.config.saml2.saml2_sp_config
- def render_GET(self, request):
+ def render_GET(self, request: Request) -> bytes:
metadata_xml = saml2.metadata.create_metadata_string(
configfile=None, config=self.sp_config
)
diff --git a/synapse/rest/synapse/client/saml2/response_resource.py b/synapse/rest/synapse/client/saml2/response_resource.py
index 774ccd87..47d2a6a2 100644
--- a/synapse/rest/synapse/client/saml2/response_resource.py
+++ b/synapse/rest/synapse/client/saml2/response_resource.py
@@ -15,7 +15,10 @@
from typing import TYPE_CHECKING
+from twisted.web.server import Request
+
from synapse.http.server import DirectServeHtmlResource
+from synapse.http.site import SynapseRequest
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -31,7 +34,7 @@ class SAML2ResponseResource(DirectServeHtmlResource):
self._saml_handler = hs.get_saml_handler()
self._sso_handler = hs.get_sso_handler()
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: Request) -> None:
# We're not expecting any GET request on that resource if everything goes right,
# but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
# In this case, just tell the user that something went wrong and they should
@@ -40,5 +43,5 @@ class SAML2ResponseResource(DirectServeHtmlResource):
request, "unexpected_get", "Unexpected GET request on /saml2/authn_response"
)
- async def _async_render_POST(self, request):
+ async def _async_render_POST(self, request: SynapseRequest) -> None:
await self._saml_handler.handle_saml_response(request)
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index 6a66a88c..c80a3a99 100644
--- a/synapse/rest/well_known.py
+++ b/synapse/rest/well_known.py
@@ -13,26 +13,26 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Optional
from twisted.web.resource import Resource
+from twisted.web.server import Request
from synapse.http.server import set_cors_headers
+from synapse.types import JsonDict
from synapse.util import json_encoder
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class WellKnownBuilder:
- """Utility to construct the well-known response
-
- Args:
- hs (synapse.server.HomeServer):
- """
-
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self._config = hs.config
- def get_well_known(self):
+ def get_well_known(self) -> Optional[JsonDict]:
# if we don't have a public_baseurl, we can't help much here.
if self._config.server.public_baseurl is None:
return None
@@ -52,11 +52,11 @@ class WellKnownResource(Resource):
isLeaf = 1
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self._well_known_builder = WellKnownBuilder(hs)
- def render_GET(self, request):
+ def render_GET(self, request: Request) -> bytes:
set_cors_headers(request)
r = self._well_known_builder.get_well_known()
if not r:
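`get_well_known()` now advertises its `Optional[JsonDict]` return type. An illustrative sketch of the kind of document it builds when `public_baseurl` is configured (key names per the Matrix client well-known spec; the values here are placeholders, and the exact contents depend on server config):

    {
        "m.homeserver": {"base_url": "https://matrix.example.com/"},
        "m.identity_server": {"base_url": "https://identity.example.com/"},
    }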
diff --git a/synapse/server.py b/synapse/server.py
index 4777ef58..637eb15b 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -392,7 +392,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_http_client_context_factory(self) -> IPolicyForHTTPS:
- if self.config.use_insecure_ssl_client_just_for_testing_do_not_use:
+ if self.config.tls.use_insecure_ssl_client_just_for_testing_do_not_use:
return InsecureInterceptableContextFactory()
return RegularPolicyForHTTPS()
@@ -418,8 +418,8 @@ class HomeServer(metaclass=abc.ABCMeta):
"""
return SimpleHttpClient(
self,
- ip_whitelist=self.config.ip_range_whitelist,
- ip_blacklist=self.config.ip_range_blacklist,
+ ip_whitelist=self.config.server.ip_range_whitelist,
+ ip_blacklist=self.config.server.ip_range_blacklist,
use_proxy=True,
)
@@ -801,18 +801,18 @@ class HomeServer(metaclass=abc.ABCMeta):
logger.info(
"Connecting to redis (host=%r port=%r) for external cache",
- self.config.redis_host,
- self.config.redis_port,
+ self.config.redis.redis_host,
+ self.config.redis.redis_port,
)
return lazyConnection(
hs=self,
- host=self.config.redis_host,
- port=self.config.redis_port,
+ host=self.config.redis.redis_host,
+ port=self.config.redis.redis_port,
password=self.config.redis.redis_password,
reconnect=True,
)
def should_send_federation(self) -> bool:
"Should this server be sending federation traffic directly?"
- return self.config.send_federation
+ return self.config.worker.send_federation
diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py
index 4e0f8140..e09a2559 100644
--- a/synapse/server_notices/consent_server_notices.py
+++ b/synapse/server_notices/consent_server_notices.py
@@ -36,9 +36,11 @@ class ConsentServerNotices:
self._users_in_progress: Set[str] = set()
- self._current_consent_version = hs.config.user_consent_version
- self._server_notice_content = hs.config.user_consent_server_notice_content
- self._send_to_guests = hs.config.user_consent_server_notice_to_guests
+ self._current_consent_version = hs.config.consent.user_consent_version
+ self._server_notice_content = (
+ hs.config.consent.user_consent_server_notice_content
+ )
+ self._send_to_guests = hs.config.consent.user_consent_server_notice_to_guests
if self._server_notice_content is not None:
if not self._server_notices_manager.is_enabled():
@@ -63,6 +65,9 @@ class ConsentServerNotices:
# not enabled
return
+ # A consent version must be given.
+ assert self._current_consent_version is not None
+
# make sure we don't send two messages to the same user at once
if user_id in self._users_in_progress:
return
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index d87a5389..cd1c5ff6 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -39,7 +39,7 @@ class ServerNoticesManager:
self._server_name = hs.hostname
self._notifier = hs.get_notifier()
- self.server_notices_mxid = self._config.server_notices_mxid
+ self.server_notices_mxid = self._config.servernotices.server_notices_mxid
def is_enabled(self):
"""Checks if server notices are enabled on this server.
@@ -47,7 +47,7 @@ class ServerNoticesManager:
Returns:
bool
"""
- return self._config.server_notices_mxid is not None
+ return self.server_notices_mxid is not None
async def send_notice(
self,
@@ -71,9 +71,9 @@ class ServerNoticesManager:
room_id = await self.get_or_create_notice_room_for_user(user_id)
await self.maybe_invite_user_to_room(user_id, room_id)
- system_mxid = self._config.server_notices_mxid
+ assert self.server_notices_mxid is not None
requester = create_requester(
- system_mxid, authenticated_entity=self._server_name
+ self.server_notices_mxid, authenticated_entity=self._server_name
)
logger.info("Sending server notice to %s", user_id)
@@ -81,7 +81,7 @@ class ServerNoticesManager:
event_dict = {
"type": type,
"room_id": room_id,
- "sender": system_mxid,
+ "sender": self.server_notices_mxid,
"content": event_content,
}
@@ -106,7 +106,7 @@ class ServerNoticesManager:
Returns:
room id of notice room.
"""
- if not self.is_enabled():
+ if self.server_notices_mxid is None:
raise Exception("Server notices not enabled")
assert self._is_mine_id(user_id), "Cannot send server notices to remote users"
@@ -139,12 +139,12 @@ class ServerNoticesManager:
# avatar, we have to use both.
join_profile = None
if (
- self._config.server_notices_mxid_display_name is not None
- or self._config.server_notices_mxid_avatar_url is not None
+ self._config.servernotices.server_notices_mxid_display_name is not None
+ or self._config.servernotices.server_notices_mxid_avatar_url is not None
):
join_profile = {
- "displayname": self._config.server_notices_mxid_display_name,
- "avatar_url": self._config.server_notices_mxid_avatar_url,
+ "displayname": self._config.servernotices.server_notices_mxid_display_name,
+ "avatar_url": self._config.servernotices.server_notices_mxid_avatar_url,
}
requester = create_requester(
@@ -154,7 +154,7 @@ class ServerNoticesManager:
requester,
config={
"preset": RoomCreationPreset.PRIVATE_CHAT,
- "name": self._config.server_notices_room_name,
+ "name": self._config.servernotices.server_notices_room_name,
"power_level_content_override": {"users_default": -10},
},
ratelimit=False,
@@ -178,6 +178,7 @@ class ServerNoticesManager:
user_id: The ID of the user to invite.
room_id: The ID of the room to invite the user to.
"""
+ assert self.server_notices_mxid is not None
requester = create_requester(
self.server_notices_mxid, authenticated_entity=self._server_name
)
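The repeated `assert self.server_notices_mxid is not None` lines exist to narrow the Optional for the type checker once the enabled check has been made. A self-contained sketch of the pattern, with hypothetical names:

    from typing import Optional

    class Notices:
        def __init__(self, mxid: Optional[str]) -> None:
            self.mxid = mxid

        def sender(self) -> str:
            # mypy sees Optional[str]; the assert narrows it to str for
            # the rest of the function, mirroring the hunks above.
            assert self.mxid is not None
            return self.mxid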
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 463ce58d..c981df3f 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -263,7 +263,9 @@ class StateHandler:
async def compute_event_context(
self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
) -> EventContext:
- """Build an EventContext structure for the event.
+ """Build an EventContext structure for a non-outlier event.
+
+ (for an outlier, call EventContext.for_outlier directly)
This works out what the current state should be for the event, and
generates a new state group if necessary.
@@ -278,35 +280,7 @@ class StateHandler:
The event context.
"""
- if event.internal_metadata.is_outlier():
- # If this is an outlier, then we know it shouldn't have any current
- # state. Certainly store.get_current_state won't return any, and
- # persisting the event won't store the state group.
-
- # FIXME: why do we populate current_state_ids? I thought the point was
- # that we weren't supposed to have any state for outliers?
- if old_state:
- prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state}
- if event.is_state():
- current_state_ids = dict(prev_state_ids)
- key = (event.type, event.state_key)
- current_state_ids[key] = event.event_id
- else:
- current_state_ids = prev_state_ids
- else:
- current_state_ids = {}
- prev_state_ids = {}
-
- # We don't store state for outliers, so we don't generate a state
- # group for it.
- context = EventContext.with_state(
- state_group=None,
- state_group_before_event=None,
- current_state_ids=current_state_ids,
- prev_state_ids=prev_state_ids,
- )
-
- return context
+ assert not event.internal_metadata.is_outlier()
#
# first of all, figure out the state before the event
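With the outlier branch removed, `compute_event_context` is only valid for non-outliers. A hedged sketch of what a caller is expected to do after this change (names assumed from the surrounding code):

    # Hypothetical caller: outliers get a context via EventContext.for_outlier(),
    # everything else still goes through the state handler.
    if event.internal_metadata.is_outlier():
        context = EventContext.for_outlier()
    else:
        context = await state_handler.compute_event_context(event)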
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 0084d9f9..f5a8f90a 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -1632,7 +1632,7 @@ class DatabasePool:
txn: LoggingTransaction,
table: str,
column: str,
- iterable: Iterable[Any],
+ iterable: Collection[Any],
keyvalues: Dict[str, Any],
retcols: Iterable[str],
) -> List[Dict[str, Any]]:
@@ -1891,29 +1891,32 @@ class DatabasePool:
txn: LoggingTransaction,
table: str,
column: str,
- iterable: Iterable[Any],
+ values: Collection[Any],
keyvalues: Dict[str, Any],
) -> int:
"""Executes a DELETE query on the named table.
- Filters rows by if value of `column` is in `iterable`.
+ Deletes the rows:
+ - whose value of `column` is in `values`; AND
+ - that match extra column-value pairs specified in `keyvalues`.
Args:
txn: Transaction object
table: string giving the table name
- column: column name to test for inclusion against `iterable`
- iterable: list
- keyvalues: dict of column names and values to select the rows with
+ column: column name to test for inclusion against `values`
+ values: values of `column` which choose rows to delete
+ keyvalues: dict of extra column names and values to select the rows
+ with. They will be ANDed together with the main predicate.
Returns:
Number of rows deleted
"""
- if not iterable:
+ if not values:
return 0
sql = "DELETE FROM %s" % table
- clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
+ clause, values = make_in_list_sql_clause(txn.database_engine, column, values)
clauses = [clause]
for key, value in keyvalues.items():
@@ -2098,7 +2101,7 @@ class DatabasePool:
def make_in_list_sql_clause(
- database_engine: BaseDatabaseEngine, column: str, iterable: Iterable
+ database_engine: BaseDatabaseEngine, column: str, iterable: Collection[Any]
) -> Tuple[str, list]:
"""Returns an SQL clause that checks the given column is in the iterable.
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 1dc347f0..5c21402d 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -61,6 +61,7 @@ from .registration import RegistrationStore
from .rejections import RejectionsStore
from .relations import RelationsStore
from .room import RoomStore
+from .room_batch import RoomBatchStore
from .roommember import RoomMemberStore
from .search import SearchStore
from .session import SessionStore
@@ -81,6 +82,7 @@ class DataStore(
EventsBackgroundUpdatesStore,
RoomMemberStore,
RoomStore,
+ RoomBatchStore,
RegistrationStore,
StreamStore,
ProfileStore,
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 1d02795f..70ca3e09 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -324,7 +324,7 @@ class AccountDataWorkerStore(SQLBaseStore):
user_id, int(stream_id)
)
if not changed:
- return ({}, {})
+ return {}, {}
return await self.db_pool.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
@@ -494,7 +494,7 @@ class AccountDataWorkerStore(SQLBaseStore):
txn,
table="ignored_users",
column="ignored_user_id",
- iterable=previously_ignored_users - currently_ignored_users,
+ values=previously_ignored_users - currently_ignored_users,
keyvalues={"ignorer_user_id": user_id},
)
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index e2d1b758..2da2659f 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -60,7 +60,7 @@ def _make_exclusive_regex(
class ApplicationServiceWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
self.services_cache = load_appservices(
- hs.hostname, hs.config.app_service_config_files
+ hs.hostname, hs.config.appservice.app_service_config_files
)
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 7a98275d..cc192f5c 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -555,8 +555,11 @@ class ClientIpStore(ClientIpWorkerStore):
return ret
async def get_user_ip_and_agents(
- self, user: UserID
+ self, user: UserID, since_ts: int = 0
) -> List[Dict[str, Union[str, int]]]:
+ """
+ Fetch IP/user agent connections for a user since a given timestamp.
+ """
user_id = user.to_string()
results = {}
@@ -568,18 +571,28 @@ class ClientIpStore(ClientIpWorkerStore):
) = key
if uid == user_id:
user_agent, _, last_seen = self._batch_row_update[key]
- results[(access_token, ip)] = (user_agent, last_seen)
+ if last_seen >= since_ts:
+ results[(access_token, ip)] = (user_agent, last_seen)
- rows = await self.db_pool.simple_select_list(
- table="user_ips",
- keyvalues={"user_id": user_id},
- retcols=["access_token", "ip", "user_agent", "last_seen"],
- desc="get_user_ip_and_agents",
+ def get_recent(txn):
+ txn.execute(
+ """
+ SELECT access_token, ip, user_agent, last_seen FROM user_ips
+ WHERE last_seen >= ? AND user_id = ?
+ ORDER BY last_seen DESC
+ """,
+ (since_ts, user_id),
+ )
+ return txn.fetchall()
+
+ rows = await self.db_pool.runInteraction(
+ desc="get_user_ip_and_agents", func=get_recent
)
results.update(
- ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
- for row in rows
+ ((access_token, ip), (user_agent, last_seen))
+ for access_token, ip, user_agent, last_seen in rows
)
return [
{
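A hedged usage sketch for the new `since_ts` parameter; the `hs`, `store`, and `user` objects are assumed from the calling context:

    # Fetch only connections seen in the last 24 hours.
    now_ms = hs.get_clock().time_msec()   # assumed clock access
    day_ago = now_ms - 24 * 60 * 60 * 1000
    sessions = await store.get_user_ip_and_agents(user, since_ts=day_ago)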
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index c5550886..3154906d 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -136,7 +136,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
user_id, last_stream_id
)
if not has_changed:
- return ([], current_stream_id)
+ return [], current_stream_id
def get_new_messages_for_device_txn(txn):
sql = (
@@ -240,11 +240,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
if not has_changed or last_stream_id == current_stream_id:
log_kv({"message": "No new messages in stream"})
- return ([], current_stream_id)
+ return [], current_stream_id
if limit <= 0:
# This can happen if we run out of room for EDUs in the transaction.
- return ([], last_stream_id)
+ return [], last_stream_id
@trace
def get_new_messages_for_remote_destination_txn(txn):
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 1f0a39ea..a95ac34f 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -824,6 +824,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
if otk_row is None:
return None
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
+ )
+
key_id, key_json = otk_row
return f"{algorithm}:{key_id}", key_json
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 047782eb..10184d6a 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1034,13 +1034,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
LIMIT ?
"""
- # Find any chunk connections of a given insertion event
- chunk_connection_query = """
+ # Find any batch connections of a given insertion event
+ batch_connection_query = """
SELECT e.depth, c.event_id FROM insertion_events AS i
- /* Find the chunk that connects to the given insertion event */
- INNER JOIN chunk_events AS c
- ON i.next_chunk_id = c.chunk_id
- /* Get the depth of the chunk start event from the events table */
+ /* Find the batch that connects to the given insertion event */
+ INNER JOIN batch_events AS c
+ ON i.next_batch_id = c.batch_id
+ /* Get the depth of the batch start event from the events table */
INNER JOIN events AS e USING (event_id)
/* Find an insertion event which matches the given event_id */
WHERE i.event_id = ?
@@ -1077,12 +1077,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
event_results.add(event_id)
- # Try and find any potential historical chunks of message history.
+ # Try and find any potential historical batches of message history.
#
# First we look for an insertion event connected to the current
# event (by prev_event). If we find any, we need to go and try to
- # find any chunk events connected to the insertion event (by
- # chunk_id). If we find any, we'll add them to the queue and
+ # find any batch events connected to the insertion event (by
+ # batch_id). If we find any, we'll add them to the queue and
# navigate up the DAG like normal in the next iteration of the loop.
txn.execute(
connected_insertion_event_query, (event_id, limit - len(event_results))
@@ -1097,17 +1097,17 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
connected_insertion_event = row[1]
queue.put((-connected_insertion_event_depth, connected_insertion_event))
- # Find any chunk connections for the given insertion event
+ # Find any batch connections for the given insertion event
txn.execute(
- chunk_connection_query,
+ batch_connection_query,
(connected_insertion_event, limit - len(event_results)),
)
- chunk_start_event_id_results = txn.fetchall()
+ batch_start_event_id_results = txn.fetchall()
logger.debug(
- "_get_backfill_events: chunk_start_event_id_results %s",
- chunk_start_event_id_results,
+ "_get_backfill_events: batch_start_event_id_results %s",
+ batch_start_event_id_results,
)
- for row in chunk_start_event_id_results:
+ for row in batch_start_event_id_results:
if row[1] not in event_results:
queue.put((-row[0], row[1]))
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 8e691678..584f818f 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -667,7 +667,7 @@ class PersistEventsStore:
table="event_auth_chain_to_calculate",
keyvalues={},
column="event_id",
- iterable=new_chain_tuples,
+ values=new_chain_tuples,
)
# Now we need to calculate any new links between chains caused by
@@ -1509,7 +1509,7 @@ class PersistEventsStore:
self._handle_event_relations(txn, event)
self._handle_insertion_event(txn, event)
- self._handle_chunk_event(txn, event)
+ self._handle_batch_event(txn, event)
# Store the labels for this event.
labels = event.content.get(EventContentFields.LABELS)
@@ -1790,23 +1790,23 @@ class PersistEventsStore:
):
return
- next_chunk_id = event.content.get(EventContentFields.MSC2716_NEXT_CHUNK_ID)
- if next_chunk_id is None:
- # Invalid insertion event without next chunk ID
+ next_batch_id = event.content.get(EventContentFields.MSC2716_NEXT_BATCH_ID)
+ if next_batch_id is None:
+ # Invalid insertion event without next batch ID
return
logger.debug(
- "_handle_insertion_event (next_chunk_id=%s) %s", next_chunk_id, event
+ "_handle_insertion_event (next_batch_id=%s) %s", next_batch_id, event
)
- # Keep track of the insertion event and the chunk ID
+ # Keep track of the insertion event and the batch ID
self.db_pool.simple_insert_txn(
txn,
table="insertion_events",
values={
"event_id": event.event_id,
"room_id": event.room_id,
- "next_chunk_id": next_chunk_id,
+ "next_batch_id": next_batch_id,
},
)
@@ -1822,8 +1822,8 @@ class PersistEventsStore:
},
)
- def _handle_chunk_event(self, txn: LoggingTransaction, event: EventBase):
- """Handles inserting the chunk edges/connections between the chunk event
+ def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase):
+ """Handles inserting the batch edges/connections between the batch event
and an insertion event. Part of MSC2716.
Args:
@@ -1831,11 +1831,11 @@ class PersistEventsStore:
event: The event to process
"""
- if event.type != EventTypes.MSC2716_CHUNK:
- # Not a chunk event
+ if event.type != EventTypes.MSC2716_BATCH:
+ # Not a batch event
return
- # Skip processing a chunk event if the room version doesn't
+ # Skip processing a batch event if the room version doesn't
# support it or the event is not from the room creator.
room_version = self.store.get_room_version_txn(txn, event.room_id)
room_creator = self.db_pool.simple_select_one_onecol_txn(
@@ -1852,35 +1852,35 @@ class PersistEventsStore:
):
return
- chunk_id = event.content.get(EventContentFields.MSC2716_CHUNK_ID)
- if chunk_id is None:
- # Invalid chunk event without a chunk ID
+ batch_id = event.content.get(EventContentFields.MSC2716_BATCH_ID)
+ if batch_id is None:
+ # Invalid batch event without a batch ID
return
- logger.debug("_handle_chunk_event chunk_id=%s %s", chunk_id, event)
+ logger.debug("_handle_batch_event batch_id=%s %s", batch_id, event)
- # Keep track of the insertion event and the chunk ID
+ # Keep track of the insertion event and the batch ID
self.db_pool.simple_insert_txn(
txn,
- table="chunk_events",
+ table="batch_events",
values={
"event_id": event.event_id,
"room_id": event.room_id,
- "chunk_id": chunk_id,
+ "batch_id": batch_id,
},
)
- # When we receive an event with a `chunk_id` referencing the
- # `next_chunk_id` of the insertion event, we can remove it from the
+ # When we receive an event with a `batch_id` referencing the
+ # `next_batch_id` of the insertion event, we can remove it from the
# `insertion_event_extremities` table.
sql = """
DELETE FROM insertion_event_extremities WHERE event_id IN (
SELECT event_id FROM insertion_events
- WHERE next_chunk_id = ?
+ WHERE next_batch_id = ?
)
"""
- txn.execute(sql, (chunk_id,))
+ txn.execute(sql, (batch_id,))
def _handle_redaction(self, txn, redacted_event_id):
"""Handles receiving a redaction and checking whether we need to remove
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 6fcb2b83..1afc59fa 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -490,7 +490,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn=txn,
table="event_forward_extremities",
column="event_id",
- iterable=to_delete,
+ values=to_delete,
keyvalues={},
)
@@ -520,7 +520,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn=txn,
table="_extremities_to_check",
column="event_id",
- iterable=original_set,
+ values=original_set,
keyvalues={},
)
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index d72e716b..4a1a2f4a 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1495,7 +1495,7 @@ class EventsWorkerStore(SQLBaseStore):
if not res:
raise SynapseError(404, "Could not find event %s" % (event_id,))
- return (int(res["topological_ordering"]), int(res["stream_ordering"]))
+ return int(res["topological_ordering"]), int(res["stream_ordering"])
async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index d213b267..b76ee51a 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -63,7 +63,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"""Generates current count of monthly active users broken down by service.
A service is typically an appservice but also includes native matrix users.
Since the `monthly_active_users` table is populated from the `user_ips` table,
- `config.track_appservice_user_ips` must be set to `true` for this
+ `config.appservice.track_appservice_user_ips` must be set to `true` for this
method to return anything other than native matrix users.
Returns:
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index bccff5e5..3eb30944 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -102,15 +102,19 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
(room_id,),
)
rows = txn.fetchall()
- max_depth = max(row[1] for row in rows)
-
- if max_depth < token.topological:
- # We need to ensure we don't delete all the events from the database
- # otherwise we wouldn't be able to send any events (due to not
- # having any backwards extremities)
- raise SynapseError(
- 400, "topological_ordering is greater than forward extremeties"
- )
+ # if we already have no forwards extremities (for example because they were
+ # cleared out by the `delete_old_current_state_events` background database
+ # update), then we may as well carry on.
+ if rows:
+ max_depth = max(row[1] for row in rows)
+
+ if max_depth < token.topological:
+ # We need to ensure we don't delete all the events from the database
+ # otherwise we wouldn't be able to send any events (due to not
+ # having any backwards extremities)
+ raise SynapseError(
+ 400, "topological_ordering is greater than forward extremities"
+ )
logger.info("[purge] looking for events to delete")
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 63ac09c6..a93caae8 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -324,7 +324,7 @@ class PusherWorkerStore(SQLBaseStore):
txn,
table="pushers",
column="user_name",
- iterable=users,
+ values=users,
keyvalues={},
)
@@ -373,7 +373,7 @@ class PusherWorkerStore(SQLBaseStore):
txn,
table="pushers",
column="id",
- iterable=(pusher_id for pusher_id, token in pushers if token is None),
+ values=[pusher_id for pusher_id, token in pushers if token is None],
keyvalues={},
)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index edeaacd7..01a42813 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
from twisted.internet import defer
@@ -153,12 +153,12 @@ class ReceiptsWorkerStore(SQLBaseStore):
}
async def get_linearized_receipts_for_rooms(
- self, room_ids: List[str], to_key: int, from_key: Optional[int] = None
+ self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
) -> List[dict]:
"""Get receipts for multiple rooms for sending to clients.
Args:
- room_id: List of room_ids.
+ room_ids: The room IDs to fetch receipts for.
to_key: Max stream id to fetch receipts up to.
from_key: Min stream id to fetch receipts from. None fetches
from the start.
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index fafadb88..c83089ee 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -388,7 +388,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"get_users_expiring_soon",
select_users_txn,
self._clock.time_msec(),
- self.config.account_validity_renew_at,
+ self.config.account_validity.account_validity_renew_at,
)
async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None:
@@ -2015,7 +2015,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
(user_id_obj.localpart, create_profile_with_displayname),
)
- if self.hs.config.stats_enabled:
+ if self.hs.config.stats.stats_enabled:
# we create a new completed user statistics row
# we don't strictly need current_token since this user really can't
diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py
new file mode 100644
index 00000000..a3833887
--- /dev/null
+++ b/synapse/storage/databases/main/room_batch.py
@@ -0,0 +1,36 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from synapse.storage._base import SQLBaseStore
+
+
+class RoomBatchStore(SQLBaseStore):
+ async def get_insertion_event_by_batch_id(self, batch_id: str) -> Optional[str]:
+ """Retrieve a insertion event ID.
+
+ Args:
+ batch_id: The batch ID of the insertion event to retrieve.
+
+ Returns:
+ The event_id of an insertion event, or None if there is no known
+ insertion event for the given batch ID.
+ """
+ return await self.db_pool.simple_select_one_onecol(
+ table="insertion_events",
+ keyvalues={"next_batch_id": batch_id},
+ retcol="event_id",
+ allow_none=True,
+ )
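A short, hedged usage sketch for the new store method (the batch ID is a made-up example):

    # Resolve a batch ID back to the insertion event that advertised it.
    event_id = await store.get_insertion_event_by_batch_id("mybatch123")
    if event_id is None:
        ...  # no insertion event carries this batch ID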
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 9beeb96a..ddb162a4 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -82,7 +82,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
if (
self.hs.config.worker.run_background_tasks
- and self.hs.config.metrics_flags.known_servers
+ and self.hs.config.metrics.metrics_flags.known_servers
):
self._known_servers_count = 1
self.hs.get_clock().looping_call(
@@ -162,7 +162,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
self._check_safe_current_state_events_membership_updated_txn,
)
- @cached(max_entries=100000, iterable=True)
+ @cached(max_entries=100000, iterable=True, prune_unread_entries=False)
async def get_users_in_room(self, room_id: str) -> List[str]:
return await self.db_pool.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id
@@ -439,7 +439,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results_dict.get("membership"), results_dict.get("event_id")
- @cached(max_entries=500000, iterable=True)
+ @cached(max_entries=500000, iterable=True, prune_unread_entries=False)
async def get_rooms_for_user_with_stream_ordering(
self, user_id: str
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
@@ -544,7 +544,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
return frozenset(r.room_id for r in rooms)
- @cached(max_entries=500000, cache_context=True, iterable=True)
+ @cached(
+ max_entries=500000,
+ cache_context=True,
+ iterable=True,
+ prune_unread_entries=False,
+ )
async def get_users_who_share_room_with_user(
self, user_id: str, cache_context: _CacheContext
) -> Set[str]:
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 6480d5a9..2a1e99e1 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -15,12 +15,12 @@
import logging
import re
from collections import namedtuple
-from typing import Collection, List, Optional, Set
+from typing import Collection, Iterable, List, Optional, Set
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
@@ -32,14 +32,24 @@ SearchEntry = namedtuple(
)
+def _clean_value_for_search(value: str) -> str:
+ """
+ Replaces any null code points in the string with spaces as
+ Postgres and SQLite do not like the insertion of strings with
+ null code points into the full-text search tables.
+ """
+ return value.replace("\u0000", " ")
+
+
class SearchWorkerStore(SQLBaseStore):
- def store_search_entries_txn(self, txn, entries):
+ def store_search_entries_txn(
+ self, txn: LoggingTransaction, entries: Iterable[SearchEntry]
+ ) -> None:
"""Add entries to the search table
Args:
- txn (cursor):
- entries (iterable[SearchEntry]):
- entries to be added to the table
+ txn:
+ entries: entries to be added to the table
"""
if not self.hs.config.enable_search:
return
@@ -55,7 +65,7 @@ class SearchWorkerStore(SQLBaseStore):
entry.event_id,
entry.room_id,
entry.key,
- entry.value,
+ _clean_value_for_search(entry.value),
entry.stream_ordering,
entry.origin_server_ts,
)
@@ -70,11 +80,16 @@ class SearchWorkerStore(SQLBaseStore):
" VALUES (?,?,?,?)"
)
args = (
- (entry.event_id, entry.room_id, entry.key, entry.value)
+ (
+ entry.event_id,
+ entry.room_id,
+ entry.key,
+ _clean_value_for_search(entry.value),
+ )
for entry in entries
)
-
txn.execute_batch(sql, args)
+
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
@@ -646,6 +661,7 @@ class SearchStore(SearchBackgroundUpdateStore):
for key in ("body", "name", "topic"):
v = event.content.get(key, None)
if v:
+ v = _clean_value_for_search(v)
values.append(v)
if not values:
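The helper is small but load-bearing: Postgres's full-text functions reject strings containing NUL code points, and the same values break SQLite FTS inserts. An illustrative check of the behaviour defined above:

    assert _clean_value_for_search("abc\u0000def") == "abc def"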
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 8e22da99..a8e8dd45 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -473,7 +473,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
txn,
table="current_state_events",
column="room_id",
- iterable=to_delete,
+ values=to_delete,
keyvalues={},
)
@@ -481,7 +481,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
txn,
table="event_forward_extremities",
column="room_id",
- iterable=to_delete,
+ values=to_delete,
keyvalues={},
)
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index bff7d040..a89747d7 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -58,7 +58,7 @@ class StateDeltasStore(SQLBaseStore):
# if the CSDs haven't changed between prev_stream_id and now, we
# know for certain that they haven't changed between prev_stream_id and
# max_stream_id.
- return (max_stream_id, [])
+ return max_stream_id, []
def get_current_state_deltas_txn(txn):
# First we calculate the max stream id that will give us less than
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 343d6efc..e20033bb 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -98,7 +98,7 @@ class StatsStore(StateDeltasStore):
self.server_name = hs.hostname
self.clock = self.hs.get_clock()
- self.stats_enabled = hs.config.stats_enabled
+ self.stats_enabled = hs.config.stats.stats_enabled
self.stats_delta_processing_lock = DeferredLock()
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 959f13de..dc7884b1 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -39,6 +39,8 @@ import logging
from collections import namedtuple
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple
+from frozendict import frozendict
+
from twisted.internet import defer
from synapse.api.filtering import Filter
@@ -379,7 +381,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
if p > min_pos
}
- return RoomStreamToken(None, min_pos, positions)
+ return RoomStreamToken(None, min_pos, frozendict(positions))
async def get_room_events_stream_for_rooms(
self,
@@ -622,7 +624,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
self._set_before_and_after(events, rows)
- return (events, token)
+ return events, token
async def get_recent_event_ids_for_room(
self, room_id: str, limit: int, end_token: RoomStreamToken
@@ -1240,7 +1242,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
self._set_before_and_after(events, rows)
- return (events, token)
+ return events, token
@cached()
async def get_id_for_instance(self, instance_name: str) -> int:
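Wrapping `positions` in a frozendict matters because the token may be shared and cached after construction. A hedged sketch of the property this buys:

    from frozendict import frozendict

    positions = frozendict({"worker1": 10, "worker2": 12})
    try:
        positions["worker1"] = 99  # type: ignore[index]
    except TypeError:
        pass  # frozendict refuses in-place mutation, so tokens stay stable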
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 4d6bbc94..340ca9e4 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -326,7 +326,7 @@ class UIAuthWorkerStore(SQLBaseStore):
txn,
table="ui_auth_sessions_ips",
column="session_id",
- iterable=session_ids,
+ values=session_ids,
keyvalues={},
)
@@ -377,7 +377,7 @@ class UIAuthWorkerStore(SQLBaseStore):
txn,
table="ui_auth_sessions_credentials",
column="session_id",
- iterable=session_ids,
+ values=session_ids,
keyvalues={},
)
@@ -386,7 +386,7 @@ class UIAuthWorkerStore(SQLBaseStore):
txn,
table="ui_auth_sessions",
column="session_id",
- iterable=session_ids,
+ values=session_ids,
keyvalues={},
)
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 8aebdc28..90d65edc 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -14,14 +14,28 @@
import logging
import re
-from typing import Any, Dict, Iterable, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ cast,
+)
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.state import StateFilter
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.types import get_domain_from_id, get_localpart_from_id
+from synapse.storage.types import Connection
+from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -36,7 +50,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: Connection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
@@ -57,10 +76,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
"populate_user_directory_cleanup", self._populate_user_directory_cleanup
)
- async def _populate_user_directory_createtables(self, progress, batch_size):
+ async def _populate_user_directory_createtables(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
# Get all the rooms that we want to process.
- def _make_staging_area(txn):
+ def _make_staging_area(txn: LoggingTransaction) -> None:
sql = (
"CREATE TABLE IF NOT EXISTS "
+ TEMP_TABLE
@@ -85,19 +106,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
del rooms
- # If search all users is on, get all the users we want to add.
- if self.hs.config.user_directory_search_all_users:
- sql = (
- "CREATE TABLE IF NOT EXISTS "
- + TEMP_TABLE
- + "_users(user_id TEXT NOT NULL)"
- )
- txn.execute(sql)
+ sql = (
+ "CREATE TABLE IF NOT EXISTS "
+ + TEMP_TABLE
+ + "_users(user_id TEXT NOT NULL)"
+ )
+ txn.execute(sql)
- txn.execute("SELECT name FROM users")
- users = [{"user_id": x[0]} for x in txn.fetchall()]
+ txn.execute("SELECT name FROM users")
+ users = [{"user_id": x[0]} for x in txn.fetchall()]
- self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
+ self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
new_pos = await self.get_max_stream_id_in_current_state_deltas()
await self.db_pool.runInteraction(
@@ -112,16 +131,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
return 1
- async def _populate_user_directory_cleanup(self, progress, batch_size):
+ async def _populate_user_directory_cleanup(
+ self,
+ progress: JsonDict,
+ batch_size: int,
+ ) -> int:
"""
Update the user directory stream position, then clean up the old tables.
"""
position = await self.db_pool.simple_select_one_onecol(
- TEMP_TABLE + "_position", None, "position"
+ TEMP_TABLE + "_position", {}, "position"
)
await self.update_user_directory_stream_pos(position)
- def _delete_staging_area(txn):
+ def _delete_staging_area(txn: LoggingTransaction) -> None:
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
@@ -135,18 +158,32 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
return 1
- async def _populate_user_directory_process_rooms(self, progress, batch_size):
+ async def _populate_user_directory_process_rooms(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""
+ Rescan the state of all rooms so we can track
+
+ - who's in a public room;
+ - which local users share a private room with other users (local
+ and remote); and
+ - who should be in the user_directory.
+
Args:
progress (dict)
batch_size (int): Maximum number of state events to process
per cycle.
+
+ Returns:
+ number of events processed.
"""
# If we don't have any progress recorded, delete everything.
if not progress:
await self.delete_all_from_user_dir()
- def _get_next_batch(txn):
+ def _get_next_batch(
+ txn: LoggingTransaction,
+ ) -> Optional[Sequence[Tuple[str, int]]]:
# Only fetch 250 rooms, so we don't fetch too many at once, even
# if those 250 rooms have less than batch_size state events.
sql = """
@@ -157,7 +194,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
TEMP_TABLE + "_rooms",
)
txn.execute(sql)
- rooms_to_work_on = txn.fetchall()
+ rooms_to_work_on = cast(List[Tuple[str, int]], txn.fetchall())
if not rooms_to_work_on:
return None
@@ -165,7 +202,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# Get how many are left to process, so we can give status on how
# far we are in processing
txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
- progress["remaining"] = txn.fetchone()[0]
+ result = txn.fetchone()
+ assert result is not None
+ progress["remaining"] = result[0]
return rooms_to_work_on
@@ -263,34 +302,33 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return processed_event_count
- async def _populate_user_directory_process_users(self, progress, batch_size):
+ async def _populate_user_directory_process_users(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""
- If search_all_users is enabled, add all of the users to the user directory.
+ Add all local users to the user directory.
"""
- if not self.hs.config.user_directory_search_all_users:
- await self.db_pool.updates._end_background_update(
- "populate_user_directory_process_users"
- )
- return 1
- def _get_next_batch(txn):
+ def _get_next_batch(txn: LoggingTransaction) -> Optional[List[str]]:
sql = "SELECT user_id FROM %s LIMIT %s" % (
TEMP_TABLE + "_users",
str(batch_size),
)
txn.execute(sql)
- users_to_work_on = txn.fetchall()
+ user_result = cast(List[Tuple[str]], txn.fetchall())
- if not users_to_work_on:
+ if not user_result:
return None
- users_to_work_on = [x[0] for x in users_to_work_on]
+ users_to_work_on = [x[0] for x in user_result]
# Get how many are left to process, so we can give status on how
# far we are in processing
sql = "SELECT COUNT(*) FROM " + TEMP_TABLE + "_users"
txn.execute(sql)
- progress["remaining"] = txn.fetchone()[0]
+ count_result = txn.fetchone()
+ assert count_result is not None
+ progress["remaining"] = count_result[0]
return users_to_work_on
@@ -331,7 +369,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return len(users_to_work_on)
- async def is_room_world_readable_or_publicly_joinable(self, room_id):
+ async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
"""Check if the room is either world_readable or publically joinable"""
# Create a state filter that only queries join and history state event
@@ -375,7 +413,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
if not isinstance(avatar_url, str):
avatar_url = None
- def _update_profile_in_user_dir_txn(txn):
+ def _update_profile_in_user_dir_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_upsert_txn(
txn,
table="user_directory",
@@ -442,7 +480,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
for user_id, other_user_id in user_id_tuples
],
value_names=(),
- value_values=None,
+ value_values=(),
desc="add_users_who_share_room",
)
@@ -461,14 +499,14 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
key_names=["user_id", "room_id"],
key_values=[(user_id, room_id) for user_id in user_ids],
value_names=(),
- value_values=None,
+ value_values=(),
desc="add_users_in_public_rooms",
)
async def delete_all_from_user_dir(self) -> None:
"""Delete the entire user directory"""
- def _delete_all_from_user_dir_txn(txn):
+ def _delete_all_from_user_dir_txn(txn: LoggingTransaction) -> None:
txn.execute("DELETE FROM user_directory")
txn.execute("DELETE FROM user_directory_search")
txn.execute("DELETE FROM users_in_public_rooms")
@@ -480,7 +518,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
@cached()
- async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
+ async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]:
return await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
@@ -504,16 +542,21 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: Connection,
+ hs: "HomeServer",
+ ) -> None:
super().__init__(database, db_conn, hs)
self._prefer_local_users_in_search = (
- hs.config.user_directory_search_prefer_local_users
+ hs.config.userdirectory.user_directory_search_prefer_local_users
)
self._server_name = hs.config.server.server_name
async def remove_from_user_dir(self, user_id: str) -> None:
- def _remove_from_user_dir_txn(txn):
+ def _remove_from_user_dir_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn, table="user_directory", keyvalues={"user_id": user_id}
)
@@ -539,7 +582,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
"remove_from_user_dir", _remove_from_user_dir_txn
)
- async def get_users_in_dir_due_to_room(self, room_id):
+ async def get_users_in_dir_due_to_room(self, room_id: str) -> Set[str]:
"""Get all user_ids that are in the room directory because they're
in the given room_id
"""
@@ -572,7 +615,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
room_id
"""
- def _remove_user_who_share_room_txn(txn):
+ def _remove_user_who_share_room_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
@@ -593,7 +636,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
"remove_user_who_share_room", _remove_user_who_share_room_txn
)
- async def get_user_dir_rooms_user_is_in(self, user_id):
+ async def get_user_dir_rooms_user_is_in(self, user_id: str) -> List[str]:
"""
Returns the rooms that a user is in.
@@ -635,7 +678,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
A set of room IDs that the users share.
"""
- def _get_shared_rooms_for_users_txn(txn):
+ def _get_shared_rooms_for_users_txn(
+ txn: LoggingTransaction,
+ ) -> List[Dict[str, str]]:
txn.execute(
"""
SELECT p1.room_id
@@ -676,7 +721,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
desc="get_user_directory_stream_pos",
)
- async def search_user_dir(self, user_id, search_term, limit):
+ async def search_user_dir(
+ self, user_id: str, search_term: str, limit: int
+ ) -> JsonDict:
"""Searches for users in directory
Returns:
@@ -694,7 +741,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
}
"""
- if self.hs.config.user_directory_search_all_users:
+ if self.hs.config.userdirectory.user_directory_search_all_users:
join_args = (user_id,)
where_clause = "user_id != ?"
else:
@@ -712,7 +759,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# We allow manipulating the ranking algorithm by injecting statements
# based on config options.
additional_ordering_statements = []
- ordering_arguments = ()
+ ordering_arguments: Tuple[str, ...] = ()
if isinstance(self.database_engine, PostgresEngine):
full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
@@ -818,7 +865,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
return {"limited": limited, "results": results}
-def _parse_query_sqlite(search_term):
+def _parse_query_sqlite(search_term: str) -> str:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to the database.
We use this so that we can add prefix matching, which isn't something
@@ -833,7 +880,7 @@ def _parse_query_sqlite(search_term):
return " & ".join("(%s* OR %s)" % (result, result) for result in results)
-def _parse_query_postgres(search_term):
+def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to the database.
We use this so that we can add prefix matching, which isn't something
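Following the `" & ".join(...)` in the SQLite variant above, an illustrative expected result (hedged, since the tokenising regex sits outside this hunk's context):

    # e.g. _parse_query_sqlite("alice bob")
    #   -> "(alice* OR alice) & (bob* OR bob)"
    # so typing "ali" can prefix-match "alice" in the FTS index.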
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index c2891cb0..eb1118d2 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -13,12 +13,20 @@
# limitations under the License.
import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.engines import PostgresEngine
from synapse.storage.state import StateFilter
+from synapse.types import MutableStateMap, StateMap
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -31,7 +39,9 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
updates.
"""
- def _count_state_group_hops_txn(self, txn, state_group):
+ def _count_state_group_hops_txn(
+ self, txn: LoggingTransaction, state_group: int
+ ) -> int:
"""Given a state group, count how many hops there are in the tree.
This is used to ensure the delta chains don't get too long.
@@ -56,7 +66,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
else:
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
- next_group = state_group
+ next_group: Optional[int] = state_group
count = 0
while next_group:
@@ -73,11 +83,14 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
return count
def _get_state_groups_from_groups_txn(
- self, txn, groups, state_filter: Optional[StateFilter] = None
- ):
+ self,
+ txn: LoggingTransaction,
+ groups: List[int],
+ state_filter: Optional[StateFilter] = None,
+ ) -> Mapping[int, StateMap[str]]:
state_filter = state_filter or StateFilter.all()
- results = {group: {} for group in groups}
+ results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups}
where_clause, where_args = state_filter.make_sql_filter_clause()
@@ -117,7 +130,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
"""
for group in groups:
- args = [group]
+ args: List[Union[int, str]] = [group]
args.extend(where_args)
txn.execute(sql % (where_clause,), args)
@@ -131,7 +144,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:
- next_group = group
+ next_group: Optional[int] = group
while next_group:
# We did this before by getting the list of group ids, and
@@ -173,6 +186,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
allow_none=True,
)
+ # The results shouldn't be considered mutable.
return results
@@ -182,7 +196,12 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
@@ -198,7 +217,9 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
columns=["room_id"],
)
- async def _background_deduplicate_state(self, progress, batch_size):
+ async def _background_deduplicate_state(
+ self, progress: dict, batch_size: int
+ ) -> int:
"""This background update will slowly deduplicate state by reencoding
them as deltas.
"""
@@ -218,7 +239,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
)
max_group = rows[0][0]
- def reindex_txn(txn):
+ def reindex_txn(txn: LoggingTransaction) -> Tuple[bool, int]:
new_last_state_group = last_state_group
for count in range(batch_size):
txn.execute(
@@ -251,7 +272,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
" WHERE id < ? AND room_id = ?",
(state_group, room_id),
)
- (prev_group,) = txn.fetchone()
+ # There will be a result due to the coalesce.
+ (prev_group,) = txn.fetchone() # type: ignore
new_last_state_group = state_group
if prev_group:
@@ -261,15 +283,15 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
# otherwise read performance degrades.
continue
- prev_state = self._get_state_groups_from_groups_txn(
+ prev_state_by_group = self._get_state_groups_from_groups_txn(
txn, [prev_group]
)
- prev_state = prev_state[prev_group]
+ prev_state = prev_state_by_group[prev_group]
- curr_state = self._get_state_groups_from_groups_txn(
+ curr_state_by_group = self._get_state_groups_from_groups_txn(
txn, [state_group]
)
- curr_state = curr_state[state_group]
+ curr_state = curr_state_by_group[state_group]
if not set(prev_state.keys()) - set(curr_state.keys()):
# We can only do a delta if the current has a strict super set
@@ -340,8 +362,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
return result * BATCH_SIZE_SCALE_FACTOR
- async def _background_index_state(self, progress, batch_size):
- def reindex_txn(conn):
+ async def _background_index_state(self, progress: dict, batch_size: int) -> int:
+ def reindex_txn(conn: LoggingDatabaseConnection) -> None:
conn.rollback()
if isinstance(self.database_engine, PostgresEngine):
# postgres insists on autocommit for the index
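_count_state_group_hops_txn walks the state_group_edges chain to keep delta chains short. A minimal, self-contained sketch of the same walk, with plain dicts standing in for the SQL tables (all names here are hypothetical):

from typing import Dict, Optional

MAX_STATE_DELTA_HOPS = 100  # mirrors the constant used by the store

def count_hops(edges: Dict[int, int], state_group: int) -> int:
    """Follow prev-group edges from `state_group` and count them."""
    count = 0
    next_group: Optional[int] = edges.get(state_group)
    while next_group is not None:
        count += 1
        next_group = edges.get(next_group)
    return count

# 4 -> 3 -> 2 -> 1, where 1 is a full snapshot with no predecessor.
assert count_hops({4: 3, 3: 2, 2: 1}, 4) == 3
assert count_hops({4: 3, 3: 2, 2: 1}, 1) == 0

Once the hop count would exceed MAX_STATE_DELTA_HOPS, the store writes a full snapshot instead of another delta, which is what the `potential_hops` check later in this patch relies on.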
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index f839c0c2..c4c8c002 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -13,43 +13,56 @@
# limitations under the License.
import logging
-from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+
+import attr
from synapse.api.constants import EventTypes
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.state import StateFilter
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateKey, StateMap
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
-class _GetStateGroupDelta(
- namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
-):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _GetStateGroupDelta:
"""Return type of get_state_group_delta that implements __len__, which lets
- us use the itrable flag when caching
+ us use the iterable flag when caching
"""
- __slots__ = []
+ prev_group: Optional[int]
+ delta_ids: Optional[StateMap[str]]
- def __len__(self):
+ def __len__(self) -> int:
return len(self.delta_ids) if self.delta_ids else 0
class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"""A data store for fetching/storing state groups."""
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
# Originally the state store used a single DictionaryCache to cache the
@@ -81,19 +94,21 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# We size the non-members cache to be smaller than the members cache as the
# vast majority of state in Matrix (today) is member events.
- self._state_group_cache = DictionaryCache(
+ self._state_group_cache: DictionaryCache[int, StateKey, str] = DictionaryCache(
"*stateGroupCache*",
# TODO: this hasn't been tuned yet
50000,
)
- self._state_group_members_cache = DictionaryCache(
+ self._state_group_members_cache: DictionaryCache[
+ int, StateKey, str
+ ] = DictionaryCache(
"*stateGroupMembersCache*",
500000,
)
- def get_max_state_group_txn(txn: Cursor):
+ def get_max_state_group_txn(txn: Cursor) -> int:
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
- return txn.fetchone()[0]
+ return txn.fetchone()[0] # type: ignore
self._state_group_seq_gen = build_sequence_generator(
db_conn,
@@ -105,15 +120,15 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
@cached(max_entries=10000, iterable=True)
- async def get_state_group_delta(self, state_group):
+ async def get_state_group_delta(self, state_group: int) -> _GetStateGroupDelta:
"""Given a state group try to return a previous group and a delta between
the old and the new.
Returns:
- (prev_group, delta_ids), where both may be None.
+ _GetStateGroupDelta containing prev_group and delta_ids, where both may be None.
"""
- def _get_state_group_delta_txn(txn):
+ def _get_state_group_delta_txn(txn: LoggingTransaction) -> _GetStateGroupDelta:
prev_group = self.db_pool.simple_select_one_onecol_txn(
txn,
table="state_group_edges",
@@ -154,7 +169,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
Returns:
Dict of state group to state map.
"""
- results = {}
+ results: Dict[int, StateMap[str]] = {}
chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
for chunk in chunks:
@@ -168,19 +183,24 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return results
- def _get_state_for_group_using_cache(self, cache, group, state_filter):
+ def _get_state_for_group_using_cache(
+ self,
+ cache: DictionaryCache[int, StateKey, str],
+ group: int,
+ state_filter: StateFilter,
+ ) -> Tuple[MutableStateMap[str], bool]:
"""Checks if group is in cache. See `_get_state_for_groups`
Args:
- cache(DictionaryCache): the state group cache to use
- group(int): The state group to lookup
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
+ cache: the state group cache to use
+ group: The state group to lookup
+ state_filter: The state filter used to fetch state from the database.
- Returns 2-tuple (`state_dict`, `got_all`).
- `got_all` is a bool indicating if we successfully retrieved all
- requests state from the cache, if False we need to query the DB for the
- missing state.
+ Returns:
+ 2-tuple (`state_dict`, `got_all`).
+ `got_all` is a bool indicating if we successfully retrieved all
+ requested state from the cache; if False we need to query the DB for the
+ missing state.
"""
cache_entry = cache.get(group)
state_dict_ids = cache_entry.value
@@ -277,8 +297,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state
def _get_state_for_groups_using_cache(
- self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter
- ) -> Tuple[Dict[int, StateMap[str]], Set[int]]:
+ self,
+ groups: Iterable[int],
+ cache: DictionaryCache[int, StateKey, str],
+ state_filter: StateFilter,
+ ) -> Tuple[Dict[int, MutableStateMap[str]], Set[int]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key, querying from a specific cache.
@@ -310,21 +333,21 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
def _insert_into_cache(
self,
- group_to_state_dict,
- state_filter,
- cache_seq_num_members,
- cache_seq_num_non_members,
- ):
+ group_to_state_dict: Dict[int, StateMap[str]],
+ state_filter: StateFilter,
+ cache_seq_num_members: int,
+ cache_seq_num_non_members: int,
+ ) -> None:
"""Inserts results from querying the database into the relevant cache.
Args:
- group_to_state_dict (dict): The new entries pulled from database.
+ group_to_state_dict: The new entries pulled from database.
Map from state group to state dict
- state_filter (StateFilter): The state filter used to fetch state
+ state_filter: The state filter used to fetch state
from the database.
- cache_seq_num_members (int): Sequence number of member cache since
+ cache_seq_num_members: Sequence number of member cache since
last lookup in cache
- cache_seq_num_non_members (int): Sequence number of member cache since
+ cache_seq_num_non_members: Sequence number of member cache since
last lookup in cache
"""
@@ -395,7 +418,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
The state group ID
"""
- def _store_state_group_txn(txn):
+ def _store_state_group_txn(txn: LoggingTransaction) -> int:
if current_state_ids is None:
# AFAIK, this can never happen
raise Exception("current_state_ids cannot be None")
@@ -426,6 +449,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
potential_hops = self._count_state_group_hops_txn(txn, prev_group)
if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
+ assert delta_ids is not None
+
self.db_pool.simple_insert_txn(
txn,
table="state_group_edges",
@@ -498,7 +523,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
async def purge_unreferenced_state_groups(
- self, room_id: str, state_groups_to_delete
+ self, room_id: str, state_groups_to_delete: Collection[int]
) -> None:
"""Deletes no longer referenced state groups and de-deltas any state
groups that reference them.
@@ -506,8 +531,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
Args:
room_id: The room the state groups belong to (must all be in the
same room).
- state_groups_to_delete (Collection[int]): Set of all state groups
- to delete.
+ state_groups_to_delete: Set of all state groups to delete.
"""
await self.db_pool.runInteraction(
@@ -517,7 +541,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_groups_to_delete,
)
- def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete):
+ def _purge_unreferenced_state_groups(
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ state_groups_to_delete: Collection[int],
+ ) -> None:
logger.info(
"[purge] found %i state groups to delete", len(state_groups_to_delete)
)
@@ -546,8 +575,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# groups to non delta versions.
for sg in remaining_state_groups:
logger.info("[purge] de-delta-ing remaining state group %s", sg)
- curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
- curr_state = curr_state[sg]
+ curr_state_by_group = self._get_state_groups_from_groups_txn(txn, [sg])
+ curr_state = curr_state_by_group[sg]
self.db_pool.simple_delete_txn(
txn, table="state_groups_state", keyvalues={"state_group": sg}
@@ -605,12 +634,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return {row["state_group"]: row["prev_state_group"] for row in rows}
- async def purge_room_state(self, room_id, state_groups_to_delete):
+ async def purge_room_state(
+ self, room_id: str, state_groups_to_delete: Collection[int]
+ ) -> None:
"""Deletes all record of a room from state tables
Args:
- room_id (str):
- state_groups_to_delete (list[int]): State groups to delete
+ room_id: The room to purge state from.
+ state_groups_to_delete: State groups to delete
"""
await self.db_pool.runInteraction(
@@ -620,7 +651,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_groups_to_delete,
)
- def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete):
+ def _purge_room_state_txn(
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ state_groups_to_delete: Collection[int],
+ ) -> None:
# first we have to delete the state groups states
logger.info("[purge] removing %s from state_groups_state", room_id)
@@ -628,7 +664,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
txn,
table="state_groups_state",
column="state_group",
- iterable=state_groups_to_delete,
+ values=state_groups_to_delete,
keyvalues={},
)
@@ -639,7 +675,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
txn,
table="state_group_edges",
column="state_group",
- iterable=state_groups_to_delete,
+ values=state_groups_to_delete,
keyvalues={},
)
@@ -650,6 +686,6 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
txn,
table="state_groups",
column="id",
- iterable=state_groups_to_delete,
+ values=state_groups_to_delete,
keyvalues={},
)
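The namedtuple-to-attrs conversion above keeps `__len__` so that `@cached(iterable=True)` can size cache entries by their content rather than counting each entry as 1. A hedged, self-contained sketch of the pattern (the StateMap alias below is a simplified stand-in for synapse.types.StateMap):

from typing import Mapping, Optional, Tuple

import attr

StateMap = Mapping[Tuple[str, str], str]  # simplified stand-in

@attr.s(slots=True, frozen=True, auto_attribs=True)
class GetStateGroupDelta:
    prev_group: Optional[int]
    delta_ids: Optional[StateMap]

    def __len__(self) -> int:
        # An absent delta counts as zero entries for cache sizing.
        return len(self.delta_ids) if self.delta_ids else 0

delta = GetStateGroupDelta(prev_group=42, delta_ids={("m.room.member", "@a:hs"): "$ev"})
assert len(delta) == 1
assert len(GetStateGroupDelta(None, None)) == 0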
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index d4754c90..f31880b8 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -545,7 +545,7 @@ def _apply_module_schemas(
database_engine:
config: application config
"""
- for (mod, _config) in config.password_providers:
+ for (mod, _config) in config.authproviders.password_providers:
if not hasattr(mod, "get_db_schema_files"):
continue
modname = ".".join((mod.__module__, mod.__name__))
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index c552dbf0..10a46b5e 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -73,7 +73,7 @@ class RelationPaginationToken:
t, s = string.split("-")
return RelationPaginationToken(int(t), int(s))
except ValueError:
- raise SynapseError(400, "Invalid token")
+ raise SynapseError(400, "Invalid relation pagination token")
def to_string(self) -> str:
return "%d-%d" % (self.topological, self.stream)
@@ -103,7 +103,7 @@ class AggregationPaginationToken:
c, s = string.split("-")
return AggregationPaginationToken(int(c), int(s))
except ValueError:
- raise SynapseError(400, "Invalid token")
+ raise SynapseError(400, "Invalid aggregation pagination token")
def to_string(self) -> str:
return "%d-%d" % (self.count, self.stream)
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index af9cc699..573e05a4 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -14,7 +14,7 @@
# When updating these values, please leave a short summary of the changes below.
-SCHEMA_VERSION = 63
+SCHEMA_VERSION = 64
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
@@ -27,11 +27,22 @@ for more information on how this works.
Changes in SCHEMA_VERSION = 61:
- The `user_stats_historical` and `room_stats_historical` tables are not written and
are not read (previously, they were written but not read).
+ - MSC2716: Add `insertion_events` and `insertion_event_edges` tables to keep track
+ of insertion events in order to navigate historical chunks of messages.
+ - MSC2716: Add `chunk_events` table to track how each chunk is labeled and
+ which insertion event it points to.
+
+Changes in SCHEMA_VERSION = 62:
+ - MSC2716: Add `insertion_event_extremities` table that keeps track of which
+ insertion events need to be backfilled.
Changes in SCHEMA_VERSION = 63:
- The `public_room_list_stream` table is no longer written to nor read from
(previously, it was written and read, but not for any significant purpose).
https://github.com/matrix-org/synapse/pull/10565
+
+Changes in SCHEMA_VERSION = 64:
+ - MSC2716: Rename related tables and columns from "chunks" to "batches".
"""
diff --git a/synapse/storage/schema/main/delta/30/as_users.py b/synapse/storage/schema/main/delta/30/as_users.py
index 8a1f3400..22a7901e 100644
--- a/synapse/storage/schema/main/delta/30/as_users.py
+++ b/synapse/storage/schema/main/delta/30/as_users.py
@@ -33,7 +33,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
config_files = []
try:
- config_files = config.app_service_config_files
+ config_files = config.appservice.app_service_config_files
except AttributeError:
logger.warning("Could not get app_service_config_files from config")
pass
diff --git a/synapse/storage/schema/main/delta/64/01msc2716_chunk_to_batch_rename.sql.postgres b/synapse/storage/schema/main/delta/64/01msc2716_chunk_to_batch_rename.sql.postgres
new file mode 100644
index 00000000..5f389932
--- /dev/null
+++ b/synapse/storage/schema/main/delta/64/01msc2716_chunk_to_batch_rename.sql.postgres
@@ -0,0 +1,23 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE insertion_events RENAME COLUMN next_chunk_id TO next_batch_id;
+DROP INDEX insertion_events_next_chunk_id;
+CREATE INDEX IF NOT EXISTS insertion_events_next_batch_id ON insertion_events(next_batch_id);
+
+ALTER TABLE chunk_events RENAME TO batch_events;
+ALTER TABLE batch_events RENAME COLUMN chunk_id TO batch_id;
+DROP INDEX chunk_events_chunk_id;
+CREATE INDEX IF NOT EXISTS batch_events_batch_id ON batch_events(batch_id);
diff --git a/synapse/storage/schema/main/delta/64/01msc2716_chunk_to_batch_rename.sql.sqlite b/synapse/storage/schema/main/delta/64/01msc2716_chunk_to_batch_rename.sql.sqlite
new file mode 100644
index 00000000..49895639
--- /dev/null
+++ b/synapse/storage/schema/main/delta/64/01msc2716_chunk_to_batch_rename.sql.sqlite
@@ -0,0 +1,37 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Re-create the insertion_events table since SQLite doesn't support better
+-- renames for columns (next_chunk_id -> next_batch_id)
+DROP TABLE insertion_events;
+CREATE TABLE IF NOT EXISTS insertion_events(
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ next_batch_id TEXT NOT NULL
+);
+CREATE UNIQUE INDEX IF NOT EXISTS insertion_events_event_id ON insertion_events(event_id);
+CREATE INDEX IF NOT EXISTS insertion_events_next_batch_id ON insertion_events(next_batch_id);
+
+-- Re-create the chunk_events table since SQLite doesn't support better renames
+-- for columns (chunk_id -> batch_id)
+DROP TABLE chunk_events;
+CREATE TABLE IF NOT EXISTS batch_events(
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ batch_id TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS batch_events_event_id ON batch_events(event_id);
+CREATE INDEX IF NOT EXISTS batch_events_batch_id ON batch_events(batch_id);
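The SQLite script above rebuilds the tables empty (they back an experimental MSC2716 feature) because older SQLite releases cannot rename a column in place. A sketch of the same rebuild pattern against an in-memory database, for illustration only:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE chunk_events(event_id TEXT NOT NULL, room_id TEXT NOT NULL, chunk_id TEXT NOT NULL);
    -- "rename" chunk_id -> batch_id by dropping and re-creating the table
    DROP TABLE chunk_events;
    CREATE TABLE IF NOT EXISTS batch_events(event_id TEXT NOT NULL, room_id TEXT NOT NULL, batch_id TEXT NOT NULL);
    CREATE UNIQUE INDEX IF NOT EXISTS batch_events_event_id ON batch_events(event_id);
    """
)
cols = [row[1] for row in conn.execute("PRAGMA table_info(batch_events)")]
assert cols == ["event_id", "room_id", "batch_id"]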
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index e5400d68..5e86befd 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -25,12 +25,15 @@ from typing import (
)
import attr
+from frozendict import frozendict
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
if TYPE_CHECKING:
+ from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad
+
from synapse.server import HomeServer
from synapse.storage.databases import Databases
@@ -40,7 +43,7 @@ logger = logging.getLogger(__name__)
T = TypeVar("T")
-@attr.s(slots=True)
+@attr.s(slots=True, frozen=True)
class StateFilter:
"""A filter used when querying for state.
@@ -53,14 +56,19 @@ class StateFilter:
appear in `types`.
"""
- types = attr.ib(type=Dict[str, Optional[Set[str]]])
+ types = attr.ib(type="frozendict[str, Optional[FrozenSet[str]]]")
include_others = attr.ib(default=False, type=bool)
def __attrs_post_init__(self):
# If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary
if self.include_others:
- self.types = {k: v for k, v in self.types.items() if v is not None}
+ # this is needed to work around the fact that StateFilter is frozen
+ object.__setattr__(
+ self,
+ "types",
+ frozendict({k: v for k, v in self.types.items() if v is not None}),
+ )
@staticmethod
def all() -> "StateFilter":
@@ -69,7 +77,7 @@ class StateFilter:
Returns:
The new state filter.
"""
- return StateFilter(types={}, include_others=True)
+ return StateFilter(types=frozendict(), include_others=True)
@staticmethod
def none() -> "StateFilter":
@@ -78,7 +86,7 @@ class StateFilter:
Returns:
The new state filter.
"""
- return StateFilter(types={}, include_others=False)
+ return StateFilter(types=frozendict(), include_others=False)
@staticmethod
def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
@@ -103,7 +111,12 @@ class StateFilter:
type_dict.setdefault(typ, set()).add(s) # type: ignore
- return StateFilter(types=type_dict)
+ return StateFilter(
+ types=frozendict(
+ (k, frozenset(v) if v is not None else None)
+ for k, v in type_dict.items()
+ )
+ )
@staticmethod
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
@@ -116,7 +129,10 @@ class StateFilter:
Returns:
The new state filter
"""
- return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
+ return StateFilter(
+ types=frozendict({EventTypes.Member: frozenset(members)}),
+ include_others=True,
+ )
def return_expanded(self) -> "StateFilter":
"""Creates a new StateFilter where type wild cards have been removed
@@ -173,7 +189,7 @@ class StateFilter:
# We want to return all non-members, but only particular
# memberships
return StateFilter(
- types={EventTypes.Member: self.types[EventTypes.Member]},
+ types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
include_others=True,
)
@@ -245,14 +261,15 @@ class StateFilter:
return len(self.concrete_types())
- def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]:
- """Returns the state filtered with by this StateFilter
+ def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]:
+ """Returns the state filtered with by this StateFilter.
Args:
state_dict: The state map to filter
Returns:
- The filtered state map
+ The filtered state map.
+ This is a copy, so it's safe to mutate.
"""
if self.is_full():
return dict(state_dict)
@@ -324,14 +341,16 @@ class StateFilter:
if state_keys is None:
member_filter = StateFilter.all()
else:
- member_filter = StateFilter({EventTypes.Member: state_keys})
+ member_filter = StateFilter(frozendict({EventTypes.Member: state_keys}))
elif self.include_others:
member_filter = StateFilter.all()
else:
member_filter = StateFilter.none()
non_member_filter = StateFilter(
- types={k: v for k, v in self.types.items() if k != EventTypes.Member},
+ types=frozendict(
+ {k: v for k, v in self.types.items() if k != EventTypes.Member}
+ ),
include_others=self.include_others,
)
@@ -358,7 +377,8 @@ class StateGroupStorage:
make up the delta between the old and new state groups.
"""
- return await self.stores.state.get_state_group_delta(state_group)
+ state_group_delta = await self.stores.state.get_state_group_delta(state_group)
+ return state_group_delta.prev_group, state_group_delta.delta_ids
async def get_state_groups_ids(
self, _room_id: str, event_ids: Iterable[str]
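The StateFilter changes above swap plain dicts and sets for frozendict and frozenset so that a frozen filter is hashable and can be used as a cache key. A minimal sketch of why that combination works, assuming the `frozendict` package the patch already imports (the class here is a toy, not the real StateFilter):

import attr
from frozendict import frozendict

@attr.s(slots=True, frozen=True, auto_attribs=True)
class TinyStateFilter:
    types: frozendict  # e.g. {event_type: frozenset(state_keys) or None}
    include_others: bool = False

f1 = TinyStateFilter(types=frozendict({"m.room.member": frozenset({"@a:hs"})}))
f2 = TinyStateFilter(types=frozendict({"m.room.member": frozenset({"@a:hs"})}))
assert f1 == f2 and hash(f1) == hash(f2)  # usable as a dict or cache key

Note the `object.__setattr__` dance in `__attrs_post_init__` above: attrs-frozen instances reject normal assignment, so canonicalising `types` after init has to bypass the frozen setter.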
diff --git a/synapse/streams/__init__.py b/synapse/streams/__init__.py
index 5e83dba2..806b6713 100644
--- a/synapse/streams/__init__.py
+++ b/synapse/streams/__init__.py
@@ -11,3 +11,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+from typing import Collection, Generic, List, Optional, Tuple, TypeVar
+
+from synapse.types import UserID
+
+# The key; this is either a stream token or an int.
+K = TypeVar("K")
+# The return type.
+R = TypeVar("R")
+
+
+class EventSource(Generic[K, R]):
+ async def get_new_events(
+ self,
+ user: UserID,
+ from_key: K,
+ limit: Optional[int],
+ room_ids: Collection[str],
+ is_guest: bool,
+ explicit_room_id: Optional[str] = None,
+ ) -> Tuple[List[R], K]:
+ ...
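The new EventSource base class fixes a single awaitable signature that every stream source implements, returning (new_events, next_key). A toy implementation under stated assumptions (plain str user IDs instead of UserID, an int key indexing a list):

import asyncio
from typing import Collection, Generic, List, Optional, Tuple, TypeVar

K = TypeVar("K")
R = TypeVar("R")

class EventSource(Generic[K, R]):
    async def get_new_events(
        self,
        user: str,
        from_key: K,
        limit: Optional[int],
        room_ids: Collection[str],
        is_guest: bool,
        explicit_room_id: Optional[str] = None,
    ) -> Tuple[List[R], K]:
        ...

class ListSource(EventSource[int, str]):
    """Toy source: the key is an index into a list of string 'events'."""

    def __init__(self, events: List[str]):
        self._events = events

    async def get_new_events(self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None):
        new = self._events[from_key:]
        if limit is not None:
            new = new[:limit]
        return new, from_key + len(new)

events, key = asyncio.run(ListSource(["a", "b", "c"]).get_new_events("@u:hs", 1, None, [], False))
assert events == ["b", "c"] and key == 3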
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index cf400598..c08d591f 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -81,7 +81,7 @@ class PaginationConfig:
raise SynapseError(400, "Invalid request.")
def __repr__(self) -> str:
- return ("PaginationConfig(from_tok=%r, to_tok=%r, direction=%r, limit=%r)") % (
+ return "PaginationConfig(from_tok=%r, to_tok=%r, direction=%r, limit=%r)" % (
self.from_token,
self.to_token,
self.direction,
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 99b0aac2..21591d0b 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -12,29 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict
+from typing import TYPE_CHECKING, Iterator, Tuple
+
+import attr
from synapse.handlers.account_data import AccountDataEventSource
from synapse.handlers.presence import PresenceEventSource
from synapse.handlers.receipts import ReceiptEventSource
from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource
+from synapse.streams import EventSource
from synapse.types import StreamToken
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
-class EventSources:
- SOURCE_TYPES = {
- "room": RoomEventSource,
- "presence": PresenceEventSource,
- "typing": TypingNotificationEventSource,
- "receipt": ReceiptEventSource,
- "account_data": AccountDataEventSource,
- }
- def __init__(self, hs):
- self.sources: Dict[str, Any] = {
- name: cls(hs) for name, cls in EventSources.SOURCE_TYPES.items()
- }
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class _EventSourcesInner:
+ room: RoomEventSource
+ presence: PresenceEventSource
+ typing: TypingNotificationEventSource
+ receipt: ReceiptEventSource
+ account_data: AccountDataEventSource
+
+ def get_sources(self) -> Iterator[Tuple[str, EventSource]]:
+ for attribute in _EventSourcesInner.__attrs_attrs__: # type: ignore[attr-defined]
+ yield attribute.name, getattr(self, attribute.name)
+
+
+class EventSources:
+ def __init__(self, hs: "HomeServer"):
+ self.sources = _EventSourcesInner(
+ *(attribute.type(hs) for attribute in _EventSourcesInner.__attrs_attrs__) # type: ignore[attr-defined]
+ )
self.store = hs.get_datastore()
def get_current_token(self) -> StreamToken:
@@ -44,11 +55,11 @@ class EventSources:
groups_key = self.store.get_group_stream_token()
token = StreamToken(
- room_key=self.sources["room"].get_current_key(),
- presence_key=self.sources["presence"].get_current_key(),
- typing_key=self.sources["typing"].get_current_key(),
- receipt_key=self.sources["receipt"].get_current_key(),
- account_data_key=self.sources["account_data"].get_current_key(),
+ room_key=self.sources.room.get_current_key(),
+ presence_key=self.sources.presence.get_current_key(),
+ typing_key=self.sources.typing.get_current_key(),
+ receipt_key=self.sources.receipt.get_current_key(),
+ account_data_key=self.sources.account_data.get_current_key(),
push_rules_key=push_rules_key,
to_device_key=to_device_key,
device_list_key=device_list_key,
@@ -67,7 +78,7 @@ class EventSources:
The current token for pagination.
"""
token = StreamToken(
- room_key=self.sources["room"].get_current_key(),
+ room_key=self.sources.room.get_current_key(),
presence_key=0,
typing_key=0,
receipt_key=0,
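Replacing the SOURCE_TYPES dict with an attrs class trades string lookups (self.sources["room"]) for attribute access, while keeping iteration via the generated __attrs_attrs__ tuple. A small sketch of that introspection trick, with stand-in source classes:

import attr

@attr.s(frozen=True, slots=True, auto_attribs=True)
class Sources:
    room: str
    typing: str

    def get_sources(self):
        # __attrs_attrs__ lists the attributes in declaration order, so no
        # hand-maintained name -> source mapping is needed.
        for attribute in Sources.__attrs_attrs__:
            yield attribute.name, getattr(self, attribute.name)

s = Sources(room="room-source", typing="typing-source")
assert list(s.get_sources()) == [("room", "room-source"), ("typing", "typing-source")]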
diff --git a/synapse/types.py b/synapse/types.py
index d4759b2d..364ecf7d 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -30,6 +30,7 @@ from typing import (
)
import attr
+from frozendict import frozendict
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
from zope.interface import Interface
@@ -457,6 +458,9 @@ class RoomStreamToken:
Note: The `RoomStreamToken` cannot have both a topological part and an
instance map.
+
+ For caching purposes, `RoomStreamToken`s, and by extension all their
+ attributes, must be hashable.
"""
topological = attr.ib(
@@ -466,12 +470,12 @@ class RoomStreamToken:
stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
instance_map = attr.ib(
- type=Dict[str, int],
- factory=dict,
+ type="frozendict[str, int]",
+ factory=frozendict,
validator=attr.validators.deep_mapping(
key_validator=attr.validators.instance_of(str),
value_validator=attr.validators.instance_of(int),
- mapping_validator=attr.validators.instance_of(dict),
+ mapping_validator=attr.validators.instance_of(frozendict),
),
)
@@ -507,11 +511,11 @@ class RoomStreamToken:
return cls(
topological=None,
stream=stream,
- instance_map=instance_map,
+ instance_map=frozendict(instance_map),
)
except Exception:
pass
- raise SynapseError(400, "Invalid token %r" % (string,))
+ raise SynapseError(400, "Invalid room stream token %r" % (string,))
@classmethod
def parse_stream_token(cls, string: str) -> "RoomStreamToken":
@@ -520,7 +524,7 @@ class RoomStreamToken:
return cls(topological=None, stream=int(string[1:]))
except Exception:
pass
- raise SynapseError(400, "Invalid token %r" % (string,))
+ raise SynapseError(400, "Invalid room stream token %r" % (string,))
def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken":
"""Return a new token such that if an event is after both this token and
@@ -540,7 +544,7 @@ class RoomStreamToken:
for instance in set(self.instance_map).union(other.instance_map)
}
- return RoomStreamToken(None, max_stream, instance_map)
+ return RoomStreamToken(None, max_stream, frozendict(instance_map))
def as_historical_tuple(self) -> Tuple[int, int]:
"""Returns a tuple of `(topological, stream)` for historical tokens.
@@ -552,7 +556,7 @@ class RoomStreamToken:
"Cannot call `RoomStreamToken.as_historical_tuple` on live token"
)
- return (self.topological, self.stream)
+ return self.topological, self.stream
def get_stream_pos_for_instance(self, instance_name: str) -> int:
"""Get the stream position that the given writer was at at this token.
@@ -593,6 +597,12 @@ class RoomStreamToken:
@attr.s(slots=True, frozen=True)
class StreamToken:
+ """A collection of positions within multiple streams.
+
+ For caching purposes, `StreamToken`s, and by extension all their
+ attributes, must be hashable.
+ """
+
room_key = attr.ib(
type=RoomStreamToken, validator=attr.validators.instance_of(RoomStreamToken)
)
@@ -619,7 +629,7 @@ class StreamToken:
await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:])
)
except Exception:
- raise SynapseError(400, "Invalid Token")
+ raise SynapseError(400, "Invalid stream token")
async def to_string(self, store: "DataStore") -> str:
return self._SEPARATOR.join(
@@ -756,7 +766,7 @@ def get_verify_key_from_cross_signing_key(key_info):
raise ValueError("Invalid key")
# and return that one key
for key_id, key_data in keys.items():
- return (key_id, decode_verify_key_bytes(key_id, decode_base64(key_data)))
+ return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data))
@attr.s(auto_attribs=True, frozen=True, slots=True)
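As the new docstrings say, tokens must be hashable for caching, which is why instance_map becomes a frozendict with a matching mapping_validator. A hedged sketch of the same attrs pattern (TinyToken is illustrative, not the real RoomStreamToken):

import attr
from frozendict import frozendict

@attr.s(slots=True, frozen=True)
class TinyToken:
    stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
    instance_map = attr.ib(
        factory=frozendict,
        validator=attr.validators.deep_mapping(
            key_validator=attr.validators.instance_of(str),
            value_validator=attr.validators.instance_of(int),
            mapping_validator=attr.validators.instance_of(frozendict),
        ),
    )

tok = TinyToken(stream=5, instance_map=frozendict({"worker1": 3}))
assert hash(tok) == hash(TinyToken(stream=5, instance_map=frozendict({"worker1": 3})))
try:
    TinyToken(stream=5, instance_map={"worker1": 3})  # a plain dict is rejected
except TypeError:
    pass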
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index cab1bf0c..df4d61e4 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -12,8 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import collections
import logging
+import typing
+from enum import Enum, auto
from sys import intern
from typing import Callable, Dict, Optional, Sized
@@ -34,7 +36,7 @@ collectors_by_name: Dict[str, "CacheMetric"] = {}
cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
-cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name"])
+cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name", "reason"])
cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"])
cache_max_size = Gauge("synapse_util_caches_cache_max_size", "", ["name"])
cache_memory_usage = Gauge(
@@ -46,11 +48,16 @@ cache_memory_usage = Gauge(
response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"])
response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"])
response_cache_evicted = Gauge(
- "synapse_util_caches_response_cache:evicted_size", "", ["name"]
+ "synapse_util_caches_response_cache:evicted_size", "", ["name", "reason"]
)
response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"])
+class EvictionReason(Enum):
+ size = auto()
+ time = auto()
+
+
@attr.s(slots=True)
class CacheMetric:
@@ -61,7 +68,9 @@ class CacheMetric:
hits = attr.ib(default=0)
misses = attr.ib(default=0)
- evicted_size = attr.ib(default=0)
+ eviction_size_by_reason: typing.Counter[EvictionReason] = attr.ib(
+ factory=collections.Counter
+ )
memory_usage = attr.ib(default=None)
def inc_hits(self) -> None:
@@ -70,8 +79,8 @@ class CacheMetric:
def inc_misses(self) -> None:
self.misses += 1
- def inc_evictions(self, size: int = 1) -> None:
- self.evicted_size += size
+ def inc_evictions(self, reason: EvictionReason, size: int = 1) -> None:
+ self.eviction_size_by_reason[reason] += size
def inc_memory_usage(self, memory: int) -> None:
if self.memory_usage is None:
@@ -94,14 +103,20 @@ class CacheMetric:
if self._cache_type == "response_cache":
response_cache_size.labels(self._cache_name).set(len(self._cache))
response_cache_hits.labels(self._cache_name).set(self.hits)
- response_cache_evicted.labels(self._cache_name).set(self.evicted_size)
+ for reason in EvictionReason:
+ response_cache_evicted.labels(self._cache_name, reason.name).set(
+ self.eviction_size_by_reason[reason]
+ )
response_cache_total.labels(self._cache_name).set(
self.hits + self.misses
)
else:
cache_size.labels(self._cache_name).set(len(self._cache))
cache_hits.labels(self._cache_name).set(self.hits)
- cache_evicted.labels(self._cache_name).set(self.evicted_size)
+ for reason in EvictionReason:
+ cache_evicted.labels(self._cache_name, reason.name).set(
+ self.eviction_size_by_reason[reason]
+ )
cache_total.labels(self._cache_name).set(self.hits + self.misses)
if getattr(self._cache, "max_size", None):
cache_max_size.labels(self._cache_name).set(self._cache.max_size)
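Labelling evictions by reason means the single evicted_size counter becomes a Counter keyed by an Enum, which Prometheus can then export as one time series per (name, reason) label pair. A self-contained sketch of the accounting (the Gauge plumbing is omitted):

import collections
from enum import Enum, auto

class EvictionReason(Enum):
    size = auto()
    time = auto()

eviction_size_by_reason: "collections.Counter[EvictionReason]" = collections.Counter()

def inc_evictions(reason: EvictionReason, size: int = 1) -> None:
    eviction_size_by_reason[reason] += size

inc_evictions(EvictionReason.size, 3)  # e.g. an iterable entry of 3 items
inc_evictions(EvictionReason.time)
assert eviction_size_by_reason[EvictionReason.size] == 3
assert eviction_size_by_reason[EvictionReason.time] == 1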
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index f05590da..6262efe0 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -73,6 +73,7 @@ class DeferredCache(Generic[KT, VT]):
tree: bool = False,
iterable: bool = False,
apply_cache_factor_from_config: bool = True,
+ prune_unread_entries: bool = True,
):
"""
Args:
@@ -105,6 +106,7 @@ class DeferredCache(Generic[KT, VT]):
size_callback=(lambda d: len(d) or 1) if iterable else None,
metrics_collection_callback=metrics_cb,
apply_cache_factor_from_config=apply_cache_factor_from_config,
+ prune_unread_entries=prune_unread_entries,
)
self.thread: Optional[threading.Thread] = None
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 1ca31e41..b9dcca17 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -258,6 +258,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
tree=False,
cache_context=False,
iterable=False,
+ prune_unread_entries: bool = True,
):
super().__init__(orig, num_args=num_args, cache_context=cache_context)
@@ -269,6 +270,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
self.max_entries = max_entries
self.tree = tree
self.iterable = iterable
+ self.prune_unread_entries = prune_unread_entries
def __get__(self, obj, owner):
cache: DeferredCache[CacheKey, Any] = DeferredCache(
@@ -276,6 +278,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
max_entries=self.max_entries,
tree=self.tree,
iterable=self.iterable,
+ prune_unread_entries=self.prune_unread_entries,
)
get_cache_key = self.cache_key_builder
@@ -507,6 +510,7 @@ def cached(
tree: bool = False,
cache_context: bool = False,
iterable: bool = False,
+ prune_unread_entries: bool = True,
) -> Callable[[F], _CachedFunction[F]]:
func = lambda orig: DeferredCacheDescriptor(
orig,
@@ -515,6 +519,7 @@ def cached(
tree=tree,
cache_context=cache_context,
iterable=iterable,
+ prune_unread_entries=prune_unread_entries,
)
return cast(Callable[[F], _CachedFunction[F]], func)
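From a caller's point of view, the new keyword is just another @cached option; passing prune_unread_entries=False opts a cache out of the global time-based pruning list (see the _Node change in lrucache.py below). A usage sketch, assuming a Synapse checkout on the path and a hypothetical store method:

from synapse.util.caches.descriptors import cached

class ExampleStore:
    @cached(max_entries=10_000, iterable=True, prune_unread_entries=False)
    async def get_something(self, key: str) -> str:
        ...  # hypothetical method; only the decorator kwargs are real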
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index ade088aa..485ddb18 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -130,7 +130,7 @@ class DictionaryCache(Generic[KT, DKT, DV]):
sequence: int,
key: KT,
value: Dict[DKT, DV],
- fetched_keys: Optional[Set[DKT]] = None,
+ fetched_keys: Optional[Iterable[DKT]] = None,
) -> None:
"""Updates the entry in the cache
@@ -155,7 +155,7 @@ class DictionaryCache(Generic[KT, DKT, DV]):
self._update_or_insert(key, value, fetched_keys)
def _update_or_insert(
- self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]
+ self, key: KT, value: Dict[DKT, DV], known_absent: Iterable[DKT]
) -> None:
# We pop and reinsert as we need to tell the cache the size may have
# changed
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index bde16b85..c3f72aa0 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -22,7 +22,7 @@ from typing_extensions import Literal
from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import Clock
-from synapse.util.caches import register_cache
+from synapse.util.caches import EvictionReason, register_cache
logger = logging.getLogger(__name__)
@@ -98,9 +98,9 @@ class ExpiringCache(Generic[KT, VT]):
while self._max_size and len(self) > self._max_size:
_key, value = self._cache.popitem(last=False)
if self.iterable:
- self.metrics.inc_evictions(len(value.value))
+ self.metrics.inc_evictions(EvictionReason.size, len(value.value))
else:
- self.metrics.inc_evictions()
+ self.metrics.inc_evictions(EvictionReason.size)
def __getitem__(self, key: KT) -> VT:
try:
@@ -175,9 +175,9 @@ class ExpiringCache(Generic[KT, VT]):
for k in keys_to_delete:
value = self._cache.pop(k)
if self.iterable:
- self.metrics.inc_evictions(len(value.value))
+ self.metrics.inc_evictions(EvictionReason.time, len(value.value))
else:
- self.metrics.inc_evictions()
+ self.metrics.inc_evictions(EvictionReason.time)
logger.debug(
"[%s] _prune_cache before: %d, after len: %d",
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 39dce9dd..4ff62b40 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -40,7 +40,7 @@ from twisted.internet.interfaces import IReactorTime
from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.util import Clock, caches
-from synapse.util.caches import CacheMetric, register_cache
+from synapse.util.caches import CacheMetric, EvictionReason, register_cache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from synapse.util.linked_list import ListNode
@@ -202,10 +202,11 @@ class _Node:
cache: "weakref.ReferenceType[LruCache]",
clock: Clock,
callbacks: Collection[Callable[[], None]] = (),
+ prune_unread_entries: bool = True,
):
self._list_node = ListNode.insert_after(self, root)
- self._global_list_node = None
- if USE_GLOBAL_LIST:
+ self._global_list_node: Optional[_TimedListNode] = None
+ if USE_GLOBAL_LIST and prune_unread_entries:
self._global_list_node = _TimedListNode.insert_after(self, GLOBAL_ROOT)
self._global_list_node.update_last_access(clock)
@@ -314,6 +315,7 @@ class LruCache(Generic[KT, VT]):
metrics_collection_callback: Optional[Callable[[], None]] = None,
apply_cache_factor_from_config: bool = True,
clock: Optional[Clock] = None,
+ prune_unread_entries: bool = True,
):
"""
Args:
@@ -403,7 +405,7 @@ class LruCache(Generic[KT, VT]):
evicted_len = delete_node(node)
cache.pop(node.key, None)
if metrics:
- metrics.inc_evictions(evicted_len)
+ metrics.inc_evictions(EvictionReason.size, evicted_len)
def synchronized(f: FT) -> FT:
@wraps(f)
@@ -427,7 +429,15 @@ class LruCache(Generic[KT, VT]):
self.len = synchronized(cache_len)
def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()):
- node = _Node(list_root, key, value, weak_ref_to_self, real_clock, callbacks)
+ node = _Node(
+ list_root,
+ key,
+ value,
+ weak_ref_to_self,
+ real_clock,
+ callbacks,
+ prune_unread_entries,
+ )
cache[key] = node
if size_callback:
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index 8ac3eab2..4938ddf7 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -21,13 +21,28 @@ from typing import (
Iterable,
Iterator,
Mapping,
- Sequence,
Set,
+ Sized,
Tuple,
TypeVar,
)
+from typing_extensions import Protocol
+
T = TypeVar("T")
+S = TypeVar("S", bound="_SelfSlice")
+
+
+class _SelfSlice(Sized, Protocol):
+ """A helper protocol that matches types where taking a slice results in the
+ same type being returned.
+
+ This is more specific than `Sequence`, which allows another `Sequence` to be
+ returned.
+ """
+
+ def __getitem__(self: S, i: slice) -> S:
+ ...
def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
@@ -46,7 +61,7 @@ def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
return iter(lambda: tuple(islice(sourceiter, size)), ())
-def chunk_seq(iseq: Sequence[T], maxlen: int) -> Iterable[Sequence[T]]:
+def chunk_seq(iseq: S, maxlen: int) -> Iterator[S]:
"""Split the given sequence into chunks of the given size
The last chunk may be shorter than the given size.
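The patch truncates here, but the new signature implies a body along these lines: a generator of slices, where the _SelfSlice protocol guarantees each slice has the caller's concrete type. A sketch matching that signature (the body is a plausible reconstruction, not quoted from the patch):

from typing import Iterator, Sized, TypeVar

from typing_extensions import Protocol

S = TypeVar("S", bound="_SelfSlice")

class _SelfSlice(Sized, Protocol):
    def __getitem__(self: S, i: slice) -> S:
        ...

def chunk_seq(iseq: S, maxlen: int) -> Iterator[S]:
    # Each slice of a str/bytes/list/tuple is the same type as the input.
    return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen))

assert list(chunk_seq("abcdef", 4)) == ["abcd", "ef"]
assert list(chunk_seq([1, 2, 3], 2)) == [[1, 2], [3]]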