author    Andrej Shadura <andrewsh@debian.org>    2021-01-13 14:00:16 +0100
committer Andrej Shadura <andrewsh@debian.org>    2021-01-13 14:00:16 +0100
commit    56044cac92cdd65dc3b4fd03557eaf32976e6da9 (patch)
tree      caca2f4e58b83affd235455f61b1bc84bc202816 /synapse
parent    f509bf3ab10e82c5ba6e4e3b5a7db0f9c55026c1 (diff)
New upstream version 1.25.0
Diffstat (limited to 'synapse')
-rw-r--r-- synapse/__init__.py | 2
-rw-r--r-- synapse/api/auth.py | 9
-rw-r--r-- synapse/api/auth_blocking.py | 7
-rw-r--r-- synapse/api/constants.py | 9
-rw-r--r-- synapse/app/_base.py | 6
-rw-r--r-- synapse/app/generic_worker.py | 32
-rw-r--r-- synapse/app/homeserver.py | 48
-rw-r--r-- synapse/config/_base.py | 14
-rw-r--r-- synapse/config/_base.pyi | 11
-rw-r--r-- synapse/config/_util.py | 35
-rw-r--r-- synapse/config/auth.py (renamed from synapse/config/password.py) | 26
-rw-r--r-- synapse/config/emailconfig.py | 27
-rw-r--r-- synapse/config/federation.py | 43
-rw-r--r-- synapse/config/groups.py | 2
-rw-r--r-- synapse/config/homeserver.py | 4
-rw-r--r-- synapse/config/logger.py | 2
-rw-r--r-- synapse/config/oidc_config.py | 7
-rw-r--r-- synapse/config/password_auth_providers.py | 5
-rw-r--r-- synapse/config/repository.py | 26
-rw-r--r-- synapse/config/room_directory.py | 2
-rw-r--r-- synapse/config/saml2_config.py | 2
-rw-r--r-- synapse/config/server.py | 92
-rw-r--r-- synapse/config/spam_checker.py | 9
-rw-r--r-- synapse/config/sso.py | 7
-rw-r--r-- synapse/config/third_party_event_rules.py | 4
-rw-r--r-- synapse/config/workers.py | 10
-rw-r--r-- synapse/crypto/context_factory.py | 2
-rw-r--r-- synapse/crypto/event_signing.py | 29
-rw-r--r-- synapse/crypto/keyring.py | 210
-rw-r--r-- synapse/events/spamcheck.py | 55
-rw-r--r-- synapse/federation/federation_base.py | 7
-rw-r--r-- synapse/federation/federation_server.py | 1
-rw-r--r-- synapse/federation/transport/client.py | 2
-rw-r--r-- synapse/federation/transport/server.py | 4
-rw-r--r-- synapse/handlers/_base.py | 4
-rw-r--r-- synapse/handlers/admin.py | 63
-rw-r--r-- synapse/handlers/auth.py | 150
-rw-r--r-- synapse/handlers/cas_handler.py | 311
-rw-r--r-- synapse/handlers/directory.py | 6
-rw-r--r-- synapse/handlers/federation.py | 4
-rw-r--r-- synapse/handlers/groups_local.py | 2
-rw-r--r-- synapse/handlers/identity.py | 11
-rw-r--r-- synapse/handlers/initial_sync.py | 4
-rw-r--r-- synapse/handlers/message.py | 4
-rw-r--r-- synapse/handlers/oidc_handler.py | 157
-rw-r--r-- synapse/handlers/receipts.py | 37
-rw-r--r-- synapse/handlers/register.py | 9
-rw-r--r-- synapse/handlers/room.py | 16
-rw-r--r-- synapse/handlers/room_list.py | 101
-rw-r--r-- synapse/handlers/room_member.py | 25
-rw-r--r-- synapse/handlers/saml_handler.py | 135
-rw-r--r-- synapse/handlers/sso.py | 411
-rw-r--r-- synapse/handlers/sync.py | 2
-rw-r--r-- synapse/handlers/user_directory.py | 85
-rw-r--r-- synapse/http/client.py | 79
-rw-r--r-- synapse/http/federation/matrix_federation_agent.py | 16
-rw-r--r-- synapse/http/federation/well_known_resolver.py | 25
-rw-r--r-- synapse/http/matrixfederationclient.py | 39
-rw-r--r-- synapse/http/proxyagent.py | 16
-rw-r--r-- synapse/http/server.py | 8
-rw-r--r-- synapse/http/site.py | 3
-rw-r--r-- synapse/logging/context.py | 24
-rw-r--r-- synapse/metrics/background_process_metrics.py | 16
-rw-r--r-- synapse/notifier.py | 6
-rw-r--r-- synapse/push/__init__.py | 108
-rw-r--r-- synapse/push/action_generator.py | 15
-rw-r--r-- synapse/push/baserules.py | 23
-rw-r--r-- synapse/push/bulk_push_rule_evaluator.py | 98
-rw-r--r-- synapse/push/clientformat.py | 23
-rw-r--r-- synapse/push/emailpusher.py | 114
-rw-r--r-- synapse/push/httppusher.py | 127
-rw-r--r-- synapse/push/mailer.py | 129
-rw-r--r-- synapse/push/presentable_names.py | 48
-rw-r--r-- synapse/push/push_rule_evaluator.py | 28
-rw-r--r-- synapse/push/push_tools.py | 7
-rw-r--r-- synapse/push/pusher.py | 34
-rw-r--r-- synapse/push/pusherpool.py | 172
-rw-r--r-- synapse/replication/http/_base.py | 47
-rw-r--r-- synapse/replication/http/login.py | 12
-rw-r--r-- synapse/replication/slave/storage/_slaved_id_tracker.py | 20
-rw-r--r-- synapse/replication/slave/storage/pushers.py | 17
-rw-r--r-- synapse/replication/tcp/protocol.py | 3
-rw-r--r-- synapse/res/templates/notif.html | 2
-rw-r--r-- synapse/res/username_picker/index.html | 19
-rw-r--r-- synapse/res/username_picker/script.js | 95
-rw-r--r-- synapse/res/username_picker/style.css | 27
-rw-r--r-- synapse/rest/admin/__init__.py | 2
-rw-r--r-- synapse/rest/admin/rooms.py | 184
-rw-r--r-- synapse/rest/admin/users.py | 23
-rw-r--r-- synapse/rest/client/v1/login.py | 25
-rw-r--r-- synapse/rest/client/v1/pusher.py | 15
-rw-r--r-- synapse/rest/client/v1/room.py | 17
-rw-r--r-- synapse/rest/client/v2_alpha/account.py | 10
-rw-r--r-- synapse/rest/client/v2_alpha/groups.py | 48
-rw-r--r-- synapse/rest/client/v2_alpha/register.py | 16
-rw-r--r-- synapse/rest/client/v2_alpha/sendtodevice.py | 3
-rw-r--r-- synapse/rest/key/v2/remote_key_resource.py | 9
-rw-r--r-- synapse/rest/media/v1/_base.py | 5
-rw-r--r-- synapse/rest/media/v1/media_repository.py | 2
-rw-r--r-- synapse/rest/media/v1/preview_url_resource.py | 6
-rw-r--r-- synapse/rest/media/v1/storage_provider.py | 16
-rw-r--r-- synapse/rest/media/v1/upload_resource.py | 2
-rw-r--r-- synapse/rest/synapse/client/pick_username.py | 88
-rw-r--r-- synapse/server.py | 39
-rw-r--r-- synapse/state/__init__.py | 4
-rw-r--r-- synapse/state/v2.py | 92
-rw-r--r-- synapse/storage/__init__.py | 9
-rw-r--r-- synapse/storage/_base.py | 36
-rw-r--r-- synapse/storage/background_updates.py | 111
-rw-r--r-- synapse/storage/databases/main/__init__.py | 10
-rw-r--r-- synapse/storage/databases/main/client_ips.py | 7
-rw-r--r-- synapse/storage/databases/main/devices.py | 32
-rw-r--r-- synapse/storage/databases/main/event_federation.py | 4
-rw-r--r-- synapse/storage/databases/main/event_push_actions.py | 10
-rw-r--r-- synapse/storage/databases/main/keys.py | 10
-rw-r--r-- synapse/storage/databases/main/pusher.py | 93
-rw-r--r-- synapse/storage/databases/main/registration.py | 63
-rw-r--r-- synapse/storage/databases/main/room.py | 4
-rw-r--r-- synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql | 17
-rw-r--r-- synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql | 18
-rw-r--r-- synapse/storage/databases/main/schema/delta/58/27local_invites.sql | 18
-rw-r--r-- synapse/storage/databases/main/user_directory.py | 31
-rw-r--r-- synapse/storage/keys.py | 5
-rw-r--r-- synapse/storage/persist_events.py | 200
-rw-r--r-- synapse/storage/prepare_database.py | 104
-rw-r--r-- synapse/storage/purge_events.py | 11
-rw-r--r-- synapse/storage/relations.py | 44
-rw-r--r-- synapse/storage/state.py | 35
-rw-r--r-- synapse/storage/util/id_generators.py | 4
-rw-r--r-- synapse/types.py | 8
-rw-r--r-- synapse/util/async_helpers.py | 8
-rw-r--r-- synapse/util/distributor.py | 7
-rw-r--r-- synapse/util/module_loader.py | 64
-rw-r--r-- synapse/visibility.py | 44
134 files changed, 3743 insertions(+), 1558 deletions(-)
diff --git a/synapse/__init__.py b/synapse/__init__.py
index f2d3ac68..193adca6 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try:
except ImportError:
pass
-__version__ = "1.24.0"
+__version__ = "1.25.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index bfcaf68b..48c4d7b0 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -23,7 +23,7 @@ from twisted.web.server import Request
import synapse.types
from synapse import event_auth
from synapse.api.auth_blocking import AuthBlocking
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.api.errors import (
AuthError,
Codes,
@@ -31,7 +31,9 @@ from synapse.api.errors import (
MissingClientTokenError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.appservice import ApplicationService
from synapse.events import EventBase
+from synapse.http.site import SynapseRequest
from synapse.logging import opentracing as opentracing
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import StateMap, UserID
@@ -474,7 +476,7 @@ class Auth:
now = self.hs.get_clock().time_msec()
return now < expiry
- def get_appservice_by_req(self, request):
+ def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
token = self.get_access_token_from_request(request)
service = self.store.get_app_service_by_token(token)
if not service:
@@ -646,7 +648,8 @@ class Auth:
)
if (
visibility
- and visibility.content["history_visibility"] == "world_readable"
+ and visibility.content.get("history_visibility")
+ == HistoryVisibility.WORLD_READABLE
):
return Membership.JOIN, None
raise AuthError(
diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py
index 9c227218..d8088f52 100644
--- a/synapse/api/auth_blocking.py
+++ b/synapse/api/auth_blocking.py
@@ -36,6 +36,7 @@ class AuthBlocking:
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
self._server_name = hs.hostname
+ self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
async def check_auth_blocking(
self,
@@ -76,6 +77,12 @@ class AuthBlocking:
# We never block the server from doing actions on behalf of
# users.
return
+ elif requester.app_service and not self._track_appservice_user_ips:
+ # If we're authenticated as an appservice then we only block
+ # auth if `track_appservice_user_ips` is set, as that option
+ # implicitly means that application services are part of MAU
+ # limits.
+ return
# Never fail an auth check for the server notices users or support user
# This can be a problem where event creation is prohibited due to blocking
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 592abd84..565a8cd7 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -95,6 +95,8 @@ class EventTypes:
Presence = "m.presence"
+ Dummy = "org.matrix.dummy_event"
+
class RejectedReason:
AUTH_ERROR = "auth_error"
@@ -160,3 +162,10 @@ class RoomEncryptionAlgorithms:
class AccountDataTypes:
DIRECT = "m.direct"
IGNORED_USER_LIST = "m.ignored_user_list"
+
+
+class HistoryVisibility:
+ INVITED = "invited"
+ JOINED = "joined"
+ SHARED = "shared"
+ WORLD_READABLE = "world_readable"
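For illustration, a minimal sketch of the new constant in use, mirroring the auth.py hunk above (the helper function here is hypothetical):

    from synapse.api.constants import HistoryVisibility

    def is_world_readable(visibility_content: dict) -> bool:
        # Mirror the auth.py change: .get() tolerates a missing key, and
        # the named constant replaces the bare "world_readable" string.
        return (
            visibility_content.get("history_visibility")
            == HistoryVisibility.WORLD_READABLE
        )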
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 895b38ae..37ecdbe3 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -245,6 +245,8 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
# Set up the SIGHUP machinery.
if hasattr(signal, "SIGHUP"):
+ reactor = hs.get_reactor()
+
@wrap_as_background_process("sighup")
def handle_sighup(*args, **kwargs):
# Tell systemd our state, if we're using it. This will silently fail if
@@ -260,7 +262,9 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
# is so that we're in a sane state, e.g. flushing the logs may fail
# if the sighup happens in the middle of writing a log entry.
def run_sighup(*args, **kwargs):
- hs.get_clock().call_later(0, handle_sighup, *args, **kwargs)
+ # `callFromThread` should be "signal safe" as well as thread
+ # safe.
+ reactor.callFromThread(handle_sighup, *args, **kwargs)
signal.signal(signal.SIGHUP, run_sighup)
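The same signal-safe pattern in isolation, as a hedged sketch (handler bodies fabricated; assumes a Twisted reactor):

    import signal

    from twisted.internet import reactor

    def handle_sighup(*args, **kwargs):
        # Runs on the reactor thread, where it is safe to touch reactor
        # and homeserver state.
        print("reloading config")

    def run_sighup(*args, **kwargs):
        # Signal handlers run in a restricted context; callFromThread is
        # both signal safe and thread safe, unlike clock.call_later.
        reactor.callFromThread(handle_sighup, *args, **kwargs)

    signal.signal(signal.SIGHUP, run_sighup)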
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 1b511890..fa23d9bb 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -89,7 +89,7 @@ from synapse.replication.tcp.streams import (
ToDeviceStream,
)
from synapse.rest.admin import register_servlets_for_media_repo
-from synapse.rest.client.v1 import events
+from synapse.rest.client.v1 import events, room
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
from synapse.rest.client.v1.login import LoginRestServlet
from synapse.rest.client.v1.profile import (
@@ -98,20 +98,6 @@ from synapse.rest.client.v1.profile import (
ProfileRestServlet,
)
from synapse.rest.client.v1.push_rule import PushRuleRestServlet
-from synapse.rest.client.v1.room import (
- JoinedRoomMemberListRestServlet,
- JoinRoomAliasServlet,
- PublicRoomListRestServlet,
- RoomEventContextServlet,
- RoomInitialSyncRestServlet,
- RoomMemberListRestServlet,
- RoomMembershipRestServlet,
- RoomMessageListRestServlet,
- RoomSendEventRestServlet,
- RoomStateEventRestServlet,
- RoomStateRestServlet,
- RoomTypingRestServlet,
-)
from synapse.rest.client.v1.voip import VoipRestServlet
from synapse.rest.client.v2_alpha import groups, sync, user_directory
from synapse.rest.client.v2_alpha._base import client_patterns
@@ -266,7 +252,6 @@ class GenericWorkerPresence(BasePresenceHandler):
super().__init__(hs)
self.hs = hs
self.is_mine_id = hs.is_mine_id
- self.http_client = hs.get_simple_http_client()
self._presence_enabled = hs.config.use_presence
@@ -513,12 +498,6 @@ class GenericWorkerServer(HomeServer):
elif name == "client":
resource = JsonResource(self, canonical_json=False)
- PublicRoomListRestServlet(self).register(resource)
- RoomMemberListRestServlet(self).register(resource)
- JoinedRoomMemberListRestServlet(self).register(resource)
- RoomStateRestServlet(self).register(resource)
- RoomEventContextServlet(self).register(resource)
- RoomMessageListRestServlet(self).register(resource)
RegisterRestServlet(self).register(resource)
LoginRestServlet(self).register(resource)
ThreepidRestServlet(self).register(resource)
@@ -527,22 +506,19 @@ class GenericWorkerServer(HomeServer):
VoipRestServlet(self).register(resource)
PushRuleRestServlet(self).register(resource)
VersionsRestServlet(self).register(resource)
- RoomSendEventRestServlet(self).register(resource)
- RoomMembershipRestServlet(self).register(resource)
- RoomStateEventRestServlet(self).register(resource)
- JoinRoomAliasServlet(self).register(resource)
+
ProfileAvatarURLRestServlet(self).register(resource)
ProfileDisplaynameRestServlet(self).register(resource)
ProfileRestServlet(self).register(resource)
KeyUploadServlet(self).register(resource)
AccountDataServlet(self).register(resource)
RoomAccountDataServlet(self).register(resource)
- RoomTypingRestServlet(self).register(resource)
sync.register_servlets(self, resource)
events.register_servlets(self, resource)
+ room.register_servlets(self, resource, True)
+ room.register_deprecated_servlets(self, resource)
InitialSyncRestServlet(self).register(resource)
- RoomInitialSyncRestServlet(self).register(resource)
user_directory.register_servlets(self, resource)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 2b546541..8d9b53be 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -19,7 +19,7 @@ import gc
import logging
import os
import sys
-from typing import Iterable
+from typing import Iterable, Iterator
from twisted.application import service
from twisted.internet import defer, reactor
@@ -63,6 +63,7 @@ from synapse.rest import ClientRestResource
from synapse.rest.admin import AdminRestResource
from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.rest.synapse.client.pick_username import pick_username_resource
from synapse.rest.well_known import WellKnownResource
from synapse.server import HomeServer
from synapse.storage import DataStore
@@ -90,7 +91,7 @@ class SynapseHomeServer(HomeServer):
tls = listener_config.tls
site_tag = listener_config.http_options.tag
if site_tag is None:
- site_tag = port
+ site_tag = str(port)
# We always include a health resource.
resources = {"/health": HealthResource()}
@@ -107,7 +108,10 @@ class SynapseHomeServer(HomeServer):
logger.debug("Configuring additional resources: %r", additional_resources)
module_api = self.get_module_api()
for path, resmodule in additional_resources.items():
- handler_cls, config = load_module(resmodule)
+ handler_cls, config = load_module(
+ resmodule,
+ ("listeners", site_tag, "additional_resources", "<%s>" % (path,)),
+ )
handler = handler_cls(config, module_api)
if IResource.providedBy(handler):
resource = handler
@@ -189,6 +193,7 @@ class SynapseHomeServer(HomeServer):
"/_matrix/client/versions": client_resource,
"/.well-known/matrix/client": WellKnownResource(self),
"/_synapse/admin": AdminRestResource(self),
+ "/_synapse/client/pick_username": pick_username_resource(self),
}
)
@@ -342,7 +347,10 @@ def setup(config_options):
"Synapse Homeserver", config_options
)
except ConfigError as e:
- sys.stderr.write("\nERROR: %s\n" % (e,))
+ sys.stderr.write("\n")
+ for f in format_config_error(e):
+ sys.stderr.write(f)
+ sys.stderr.write("\n")
sys.exit(1)
if not config:
@@ -445,6 +453,38 @@ def setup(config_options):
return hs
+def format_config_error(e: ConfigError) -> Iterator[str]:
+ """
+ Formats a config error neatly
+
+ The idea is to format the immediate error, plus the "causes" of those errors,
+ hopefully in a way that makes sense to the user. For example:
+
+ Error in configuration at 'oidc_config.user_mapping_provider.config.display_name_template':
+ Failed to parse config for module 'JinjaOidcMappingProvider':
+ invalid jinja template:
+ unexpected end of template, expected 'end of print statement'.
+
+ Args:
+ e: the error to be formatted
+
+ Returns: An iterator which yields string fragments to be formatted
+ """
+ yield "Error in configuration"
+
+ if e.path:
+ yield " at '%s'" % (".".join(e.path),)
+
+ yield ":\n %s" % (e.msg,)
+
+ e = e.__cause__
+ indent = 1
+ while e:
+ indent += 1
+ yield ":\n%s%s" % (" " * indent, str(e))
+ e = e.__cause__
+
+
class SynapseService(service.Service):
"""
A twisted Service class that will start synapse. Used to run synapse
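A sketch of what format_config_error renders for a nested error (the ConfigError chain below is fabricated for illustration):

    from synapse.app.homeserver import format_config_error
    from synapse.config._base import ConfigError

    try:
        raise ConfigError(
            "Failed to parse config for module 'JinjaOidcMappingProvider'",
            path=("oidc_config", "user_mapping_provider", "config"),
        ) from ValueError("invalid jinja template")
    except ConfigError as e:
        # Renders the message, the dotted path, and each __cause__ on its
        # own increasingly-indented line.
        print("".join(format_config_error(e)))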
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 85f65da4..2931a882 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -23,7 +23,7 @@ import urllib.parse
from collections import OrderedDict
from hashlib import sha256
from textwrap import dedent
-from typing import Any, Callable, List, MutableMapping, Optional
+from typing import Any, Callable, Iterable, List, MutableMapping, Optional
import attr
import jinja2
@@ -32,7 +32,17 @@ import yaml
class ConfigError(Exception):
- pass
+ """Represents a problem parsing the configuration
+
+ Args:
+ msg: A textual description of the error.
+ path: Where appropriate, an indication of where in the configuration
+ the problem lies.
+ """
+
+ def __init__(self, msg: str, path: Optional[Iterable[str]] = None):
+ self.msg = msg
+ self.path = path
# We split these messages out to allow packages to override with package
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index b8faafa9..29aa064e 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -1,8 +1,9 @@
-from typing import Any, List, Optional
+from typing import Any, Iterable, List, Optional
from synapse.config import (
api,
appservice,
+ auth,
captcha,
cas,
consent_config,
@@ -14,7 +15,6 @@ from synapse.config import (
logger,
metrics,
oidc_config,
- password,
password_auth_providers,
push,
ratelimiting,
@@ -35,7 +35,10 @@ from synapse.config import (
workers,
)
-class ConfigError(Exception): ...
+class ConfigError(Exception):
+ def __init__(self, msg: str, path: Optional[Iterable[str]] = None):
+ self.msg = msg
+ self.path = path
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str
MISSING_REPORT_STATS_SPIEL: str
@@ -62,7 +65,7 @@ class RootConfig:
sso: sso.SSOConfig
oidc: oidc_config.OIDCConfig
jwt: jwt_config.JWTConfig
- password: password.PasswordConfig
+ auth: auth.AuthConfig
email: emailconfig.EmailConfig
worker: workers.WorkerConfig
authproviders: password_auth_providers.PasswordAuthProviderConfig
diff --git a/synapse/config/_util.py b/synapse/config/_util.py
index c74969a9..1bbe83c3 100644
--- a/synapse/config/_util.py
+++ b/synapse/config/_util.py
@@ -38,14 +38,27 @@ def validate_config(
try:
jsonschema.validate(config, json_schema)
except jsonschema.ValidationError as e:
- # copy `config_path` before modifying it.
- path = list(config_path)
- for p in list(e.path):
- if isinstance(p, int):
- path.append("<item %i>" % p)
- else:
- path.append(str(p))
-
- raise ConfigError(
- "Unable to parse configuration: %s at %s" % (e.message, ".".join(path))
- )
+ raise json_error_to_config_error(e, config_path)
+
+
+def json_error_to_config_error(
+ e: jsonschema.ValidationError, config_path: Iterable[str]
+) -> ConfigError:
+ """Converts a json validation error to a user-readable ConfigError
+
+ Args:
+ e: the exception to be converted
+ config_path: the path within the config file. This will be used as a basis
+ for the error message.
+
+ Returns:
+ a ConfigError
+ """
+ # copy `config_path` before modifying it.
+ path = list(config_path)
+ for p in list(e.path):
+ if isinstance(p, int):
+ path.append("<item %i>" % p)
+ else:
+ path.append(str(p))
+ return ConfigError(e.message, path)
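A hedged sketch of the conversion in practice (schema and config invented):

    import jsonschema

    from synapse.config._util import json_error_to_config_error

    schema = {"type": "object", "properties": {"port": {"type": "integer"}}}
    config = {"port": "not-a-number"}

    try:
        jsonschema.validate(config, schema)
    except jsonschema.ValidationError as e:
        err = json_error_to_config_error(e, ("listeners", "<item 0>"))
        # err.path is now ["listeners", "<item 0>", "port"], ready for
        # format_config_error to render.
        print(err.msg, list(err.path))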
diff --git a/synapse/config/password.py b/synapse/config/auth.py
index 9c0ea8c3..2b3e2ce8 100644
--- a/synapse/config/password.py
+++ b/synapse/config/auth.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,11 +17,11 @@
from ._base import Config
-class PasswordConfig(Config):
- """Password login configuration
+class AuthConfig(Config):
+ """Password and login configuration
"""
- section = "password"
+ section = "auth"
def read_config(self, config, **kwargs):
password_config = config.get("password_config", {})
@@ -35,6 +36,10 @@ class PasswordConfig(Config):
self.password_policy = password_config.get("policy") or {}
self.password_policy_enabled = self.password_policy.get("enabled", False)
+ # User-interactive authentication
+ ui_auth = config.get("ui_auth") or {}
+ self.ui_auth_session_timeout = ui_auth.get("session_timeout", 0)
+
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
password_config:
@@ -87,4 +92,19 @@ class PasswordConfig(Config):
# Defaults to 'false'.
#
#require_uppercase: true
+
+ ui_auth:
+ # The number of milliseconds to allow a user-interactive authentication
+ # session to be active.
+ #
+ # This defaults to 0, meaning the user is queried for their credentials
+ # before every action, but this can be overridden to allow a single
+ # validation to be re-used. This weakens the protections afforded by
+ # the user-interactive authentication process, by allowing for multiple
+ # (and potentially different) operations to use the same validation session.
+ #
+ # Uncomment below to allow for credential validation to last for 15
+ # seconds.
+ #
+ #session_timeout: 15000
"""
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index cceffbfe..d4328c46 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -322,6 +322,22 @@ class EmailConfig(Config):
self.email_subjects = EmailSubjectConfig(**subjects)
+ # The invite client location should be a HTTP(S) URL or None.
+ self.invite_client_location = email_config.get("invite_client_location") or None
+ if self.invite_client_location:
+ if not isinstance(self.invite_client_location, str):
+ raise ConfigError(
+ "Config option email.invite_client_location must be type str"
+ )
+ if not (
+ self.invite_client_location.startswith("http://")
+ or self.invite_client_location.startswith("https://")
+ ):
+ raise ConfigError(
+ "Config option email.invite_client_location must be a http or https URL",
+ path=("email", "invite_client_location"),
+ )
+
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return (
"""\
@@ -389,10 +405,15 @@ class EmailConfig(Config):
#
#validation_token_lifetime: 15m
- # Directory in which Synapse will try to find the template files below.
- # If not set, default templates from within the Synapse package will be used.
+ # The web client location to direct users to during an invite. This is passed
+ # to the identity server as the org.matrix.web_client_location key. Defaults
+ # to unset, giving no guidance to the identity server.
#
- # Do not uncomment this setting unless you want to customise the templates.
+ #invite_client_location: https://app.element.io
+
+ # Directory in which Synapse will try to find the template files below.
+ # If not set, or the files named below are not found within the template
+ # directory, default templates from within the Synapse package will be used.
#
# Synapse will look for the following templates in this directory:
#
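The new validation can be exercised on its own; a sketch re-implementing the checks above (config values fabricated):

    from synapse.config._base import ConfigError

    def check_invite_client_location(email_config: dict):
        loc = email_config.get("invite_client_location") or None
        if loc is None:
            return None
        if not isinstance(loc, str):
            raise ConfigError(
                "Config option email.invite_client_location must be type str"
            )
        if not loc.startswith(("http://", "https://")):
            raise ConfigError(
                "Config option email.invite_client_location must be a http or https URL",
                path=("email", "invite_client_location"),
            )
        return loc

    assert check_invite_client_location(
        {"invite_client_location": "https://app.element.io"}
    )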
diff --git a/synapse/config/federation.py b/synapse/config/federation.py
index ffd8fca5..9f3c57e6 100644
--- a/synapse/config/federation.py
+++ b/synapse/config/federation.py
@@ -12,12 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from typing import Optional
-from netaddr import IPSet
-
-from synapse.config._base import Config, ConfigError
+from synapse.config._base import Config
from synapse.config._util import validate_config
@@ -36,23 +33,6 @@ class FederationConfig(Config):
for domain in federation_domain_whitelist:
self.federation_domain_whitelist[domain] = True
- self.federation_ip_range_blacklist = config.get(
- "federation_ip_range_blacklist", []
- )
-
- # Attempt to create an IPSet from the given ranges
- try:
- self.federation_ip_range_blacklist = IPSet(
- self.federation_ip_range_blacklist
- )
-
- # Always blacklist 0.0.0.0, ::
- self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
- except Exception as e:
- raise ConfigError(
- "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e
- )
-
federation_metrics_domains = config.get("federation_metrics_domains") or []
validate_config(
_METRICS_FOR_DOMAINS_SCHEMA,
@@ -76,27 +56,6 @@ class FederationConfig(Config):
# - nyc.example.com
# - syd.example.com
- # Prevent federation requests from being sent to the following
- # blacklist IP address CIDR ranges. If this option is not specified, or
- # specified with an empty list, no ip range blacklist will be enforced.
- #
- # As of Synapse v1.4.0 this option also affects any outbound requests to identity
- # servers provided by user input.
- #
- # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
- # listed here, since they correspond to unroutable addresses.)
- #
- federation_ip_range_blacklist:
- - '127.0.0.0/8'
- - '10.0.0.0/8'
- - '172.16.0.0/12'
- - '192.168.0.0/16'
- - '100.64.0.0/10'
- - '169.254.0.0/16'
- - '::1/128'
- - 'fe80::/64'
- - 'fc00::/7'
-
# Report prometheus metrics on the age of PDUs being sent to and received from
# the following domains. This can be used to give an idea of "delay" on inbound
# and outbound federation, though be aware that any delay can be due to problems
diff --git a/synapse/config/groups.py b/synapse/config/groups.py
index d6862d9a..7b7860ea 100644
--- a/synapse/config/groups.py
+++ b/synapse/config/groups.py
@@ -32,5 +32,5 @@ class GroupsConfig(Config):
# If enabled, non server admins can only create groups with local parts
# starting with this prefix
#
- #group_creation_prefix: "unofficial/"
+ #group_creation_prefix: "unofficial_"
"""
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index be655545..4bd2b358 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -17,6 +17,7 @@
from ._base import RootConfig
from .api import ApiConfig
from .appservice import AppServiceConfig
+from .auth import AuthConfig
from .cache import CacheConfig
from .captcha import CaptchaConfig
from .cas import CasConfig
@@ -30,7 +31,6 @@ from .key import KeyConfig
from .logger import LoggingConfig
from .metrics import MetricsConfig
from .oidc_config import OIDCConfig
-from .password import PasswordConfig
from .password_auth_providers import PasswordAuthProviderConfig
from .push import PushConfig
from .ratelimiting import RatelimitConfig
@@ -76,7 +76,7 @@ class HomeServerConfig(RootConfig):
CasConfig,
SSOConfig,
JWTConfig,
- PasswordConfig,
+ AuthConfig,
EmailConfig,
PasswordAuthProviderConfig,
PushConfig,
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index d4e887a3..4df3f93c 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -206,7 +206,7 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
# filter options, but care must be taken when using e.g. MemoryHandler to buffer
# writes.
- log_context_filter = LoggingContextFilter(request="")
+ log_context_filter = LoggingContextFilter()
log_metadata_filter = MetadataFilter({"server_name": config.server_name})
old_factory = logging.getLogRecordFactory()
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index 69d18834..4e305528 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -66,7 +66,7 @@ class OIDCConfig(Config):
(
self.oidc_user_mapping_provider_class,
self.oidc_user_mapping_provider_config,
- ) = load_module(ump_config)
+ ) = load_module(ump_config, ("oidc_config", "user_mapping_provider"))
# Ensure loaded user mapping module has defined all necessary methods
required_methods = [
@@ -203,9 +203,10 @@ class OIDCConfig(Config):
# * user: The claims returned by the UserInfo Endpoint and/or in the ID
# Token
#
- # This must be configured if using the default mapping provider.
+ # If this is not set, the user will be prompted to choose their
+ # own username.
#
- localpart_template: "{{{{ user.preferred_username }}}}"
+ #localpart_template: "{{{{ user.preferred_username }}}}"
# Jinja2 template for the display name to set on first login.
#
diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index 4fda8ae9..85d07c4f 100644
--- a/synapse/config/password_auth_providers.py
+++ b/synapse/config/password_auth_providers.py
@@ -36,7 +36,7 @@ class PasswordAuthProviderConfig(Config):
providers.append({"module": LDAP_PROVIDER, "config": ldap_config})
providers.extend(config.get("password_providers") or [])
- for provider in providers:
+ for i, provider in enumerate(providers):
mod_name = provider["module"]
# This is for backwards compat when the ldap auth provider resided
@@ -45,7 +45,8 @@ class PasswordAuthProviderConfig(Config):
mod_name = LDAP_PROVIDER
(provider_class, provider_config) = load_module(
- {"module": mod_name, "config": provider["config"]}
+ {"module": mod_name, "config": provider["config"]},
+ ("password_providers", "<item %i>" % i),
)
self.password_providers.append((provider_class, provider_config))
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index ba1e9d23..850ac3eb 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -17,6 +17,9 @@ import os
from collections import namedtuple
from typing import Dict, List
+from netaddr import IPSet
+
+from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module
@@ -142,7 +145,7 @@ class ContentRepositoryConfig(Config):
# them to be started.
self.media_storage_providers = [] # type: List[tuple]
- for provider_config in storage_providers:
+ for i, provider_config in enumerate(storage_providers):
# We special case the module "file_system" so as not to need to
# expose FileStorageProviderBackend
if provider_config["module"] == "file_system":
@@ -151,7 +154,9 @@ class ContentRepositoryConfig(Config):
".FileStorageProviderBackend"
)
- provider_class, parsed_config = load_module(provider_config)
+ provider_class, parsed_config = load_module(
+ provider_config, ("media_storage_providers", "<item %i>" % i)
+ )
wrapper_config = MediaStorageProviderConfig(
provider_config.get("store_local", False),
@@ -182,9 +187,6 @@ class ContentRepositoryConfig(Config):
"to work"
)
- # netaddr is a dependency for url_preview
- from netaddr import IPSet
-
self.url_preview_ip_range_blacklist = IPSet(
config["url_preview_ip_range_blacklist"]
)
@@ -213,6 +215,10 @@ class ContentRepositoryConfig(Config):
# strip final NL
formatted_thumbnail_sizes = formatted_thumbnail_sizes[:-1]
+ ip_range_blacklist = "\n".join(
+ " # - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST
+ )
+
return (
r"""
## Media Store ##
@@ -283,15 +289,7 @@ class ContentRepositoryConfig(Config):
# you uncomment the following list as a starting point.
#
#url_preview_ip_range_blacklist:
- # - '127.0.0.0/8'
- # - '10.0.0.0/8'
- # - '172.16.0.0/12'
- # - '192.168.0.0/16'
- # - '100.64.0.0/10'
- # - '169.254.0.0/16'
- # - '::1/128'
- # - 'fe80::/64'
- # - 'fc00::/7'
+%(ip_range_blacklist)s
# List of IP address CIDR ranges that the URL preview spider is allowed
# to access even if they are specified in url_preview_ip_range_blacklist.
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
index 92e1b675..9a3e1c3e 100644
--- a/synapse/config/room_directory.py
+++ b/synapse/config/room_directory.py
@@ -180,7 +180,7 @@ class _RoomDirectoryRule:
self._alias_regex = glob_to_regex(alias)
self._room_id_regex = glob_to_regex(room_id)
except Exception as e:
- raise ConfigError("Failed to parse glob into regex: %s", e)
+ raise ConfigError("Failed to parse glob into regex") from e
def matches(self, user_id, room_id, aliases):
"""Tests if this rule matches the given user_id, room_id and aliases.
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index c1b8e98a..7b97d4f1 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -125,7 +125,7 @@ class SAML2Config(Config):
(
self.saml2_user_mapping_provider_class,
self.saml2_user_mapping_provider_config,
- ) = load_module(ump_dict)
+ ) = load_module(ump_dict, ("saml2_config", "user_mapping_provider"))
# Ensure loaded user mapping module has defined all necessary methods
# Note parse_config() is already checked during the call to load_module
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 85aa49c0..7242a4aa 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -23,6 +23,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set
import attr
import yaml
+from netaddr import IPSet
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.endpoint import parse_and_validate_server_name
@@ -39,6 +40,34 @@ logger = logging.Logger(__name__)
# in the list.
DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"]
+DEFAULT_IP_RANGE_BLACKLIST = [
+ # Localhost
+ "127.0.0.0/8",
+ # Private networks.
+ "10.0.0.0/8",
+ "172.16.0.0/12",
+ "192.168.0.0/16",
+ # Carrier grade NAT.
+ "100.64.0.0/10",
+ # Address registry.
+ "192.0.0.0/24",
+ # Link-local networks.
+ "169.254.0.0/16",
+ # Testing networks.
+ "198.18.0.0/15",
+ "192.0.2.0/24",
+ "198.51.100.0/24",
+ "203.0.113.0/24",
+ # Multicast.
+ "224.0.0.0/4",
+ # Localhost
+ "::1/128",
+ # Link-local addresses.
+ "fe80::/10",
+ # Unique local addresses.
+ "fc00::/7",
+]
+
DEFAULT_ROOM_VERSION = "6"
ROOM_COMPLEXITY_TOO_GREAT = (
@@ -256,6 +285,38 @@ class ServerConfig(Config):
# due to resource constraints
self.admin_contact = config.get("admin_contact", None)
+ ip_range_blacklist = config.get(
+ "ip_range_blacklist", DEFAULT_IP_RANGE_BLACKLIST
+ )
+
+ # Attempt to create an IPSet from the given ranges
+ try:
+ self.ip_range_blacklist = IPSet(ip_range_blacklist)
+ except Exception as e:
+ raise ConfigError("Invalid range(s) provided in ip_range_blacklist.") from e
+ # Always blacklist 0.0.0.0, ::
+ self.ip_range_blacklist.update(["0.0.0.0", "::"])
+
+ try:
+ self.ip_range_whitelist = IPSet(config.get("ip_range_whitelist", ()))
+ except Exception as e:
+ raise ConfigError("Invalid range(s) provided in ip_range_whitelist.") from e
+
+ # The federation_ip_range_blacklist is used for backwards-compatibility
+ # and only applies to federation and identity servers. If it is not given,
+ # default to ip_range_blacklist.
+ federation_ip_range_blacklist = config.get(
+ "federation_ip_range_blacklist", ip_range_blacklist
+ )
+ try:
+ self.federation_ip_range_blacklist = IPSet(federation_ip_range_blacklist)
+ except Exception as e:
+ raise ConfigError(
+ "Invalid range(s) provided in federation_ip_range_blacklist."
+ ) from e
+ # Always blacklist 0.0.0.0, ::
+ self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
+
if self.public_baseurl is not None:
if self.public_baseurl[-1] != "/":
self.public_baseurl += "/"
@@ -561,6 +622,10 @@ class ServerConfig(Config):
def generate_config_section(
self, server_name, data_dir_path, open_private_ports, listeners, **kwargs
):
+ ip_range_blacklist = "\n".join(
+ " # - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST
+ )
+
_, bind_port = parse_and_validate_server_name(server_name)
if bind_port is not None:
unsecure_port = bind_port - 400
@@ -752,6 +817,33 @@ class ServerConfig(Config):
#
#enable_search: false
+ # Prevent outgoing requests from being sent to the following blacklisted IP address
+ # CIDR ranges. If this option is not specified then it defaults to private IP
+ # address ranges (see the example below).
+ #
+ # The blacklist applies to the outbound requests for federation, identity servers,
+ # push servers, and for checking key validity for third-party invite events.
+ #
+ # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
+ # listed here, since they correspond to unroutable addresses.)
+ #
+ # This option replaces federation_ip_range_blacklist in Synapse v1.25.0.
+ #
+ #ip_range_blacklist:
+%(ip_range_blacklist)s
+
+ # List of IP address CIDR ranges that should be allowed for federation,
+ # identity servers, push servers, and for checking key validity for
+ # third-party invite events. This is useful for specifying exceptions to
+ # wide-ranging blacklisted target IP ranges - e.g. for communication with
+ # a push server only visible in your network.
+ #
+ # This whitelist overrides ip_range_blacklist and defaults to an empty
+ # list.
+ #
+ #ip_range_whitelist:
+ # - '192.168.1.1'
+
# List of ports that Synapse should listen on, their purpose and their
# configuration.
#
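A small netaddr sketch of how the parsed blacklist behaves (test addresses chosen arbitrarily):

    from netaddr import IPAddress, IPSet

    from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST

    blacklist = IPSet(DEFAULT_IP_RANGE_BLACKLIST)
    # 0.0.0.0 and :: are always added, mirroring read_config above.
    blacklist.update(["0.0.0.0", "::"])

    assert IPAddress("10.1.2.3") in blacklist            # private range
    assert IPAddress("169.254.0.1") in blacklist         # link-local
    assert IPAddress("93.184.216.34") not in blacklist   # public address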
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
index 3d067d29..3d05abc1 100644
--- a/synapse/config/spam_checker.py
+++ b/synapse/config/spam_checker.py
@@ -33,13 +33,14 @@ class SpamCheckerConfig(Config):
# spam checker, and thus was simply a dictionary with module
# and config keys. Support this old behaviour by checking
# to see if the option resolves to a dictionary
- self.spam_checkers.append(load_module(spam_checkers))
+ self.spam_checkers.append(load_module(spam_checkers, ("spam_checker",)))
elif isinstance(spam_checkers, list):
- for spam_checker in spam_checkers:
+ for i, spam_checker in enumerate(spam_checkers):
+ config_path = ("spam_checker", "<item %i>" % i)
if not isinstance(spam_checker, dict):
- raise ConfigError("spam_checker syntax is incorrect")
+ raise ConfigError("expected a mapping", config_path)
- self.spam_checkers.append(load_module(spam_checker))
+ self.spam_checkers.append(load_module(spam_checker, config_path))
else:
raise ConfigError("spam_checker syntax is incorrect")
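The extra argument threaded through these load_module call sites exists purely for error reporting; a hedged sketch (the module name is a placeholder and will not import):

    from synapse.config._base import ConfigError
    from synapse.util.module_loader import load_module

    provider = {
        "module": "my_module.ExampleSpamChecker",  # placeholder module
        "config": {"default": True},
    }

    try:
        checker_class, checker_config = load_module(
            provider, ("spam_checker", "<item 0>")
        )
    except ConfigError as e:
        # The path pinpoints the failing entry in the YAML, e.g.
        # spam_checker.<item 0>.module
        print(e.msg, e.path)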
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 44276761..93bbd409 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -93,11 +93,8 @@ class SSOConfig(Config):
# - https://my.custom.client/
# Directory in which Synapse will try to find the template files below.
- # If not set, default templates from within the Synapse package will be used.
- #
- # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
- # If you *do* uncomment it, you will need to make sure that all the templates
- # below are in the directory.
+ # If not set, or the files named below are not found within the template
+ # directory, default templates from within the Synapse package will be used.
#
# Synapse will look for the following templates in this directory:
#
diff --git a/synapse/config/third_party_event_rules.py b/synapse/config/third_party_event_rules.py
index 10a99c79..c04e1c4e 100644
--- a/synapse/config/third_party_event_rules.py
+++ b/synapse/config/third_party_event_rules.py
@@ -26,7 +26,9 @@ class ThirdPartyRulesConfig(Config):
provider = config.get("third_party_event_rules", None)
if provider is not None:
- self.third_party_event_rules = load_module(provider)
+ self.third_party_event_rules = load_module(
+ provider, ("third_party_event_rules",)
+ )
def generate_config_section(self, **kwargs):
return """\
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 57ab097e..7ca9efec 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -85,6 +85,9 @@ class WorkerConfig(Config):
# The port on the main synapse for HTTP replication endpoint
self.worker_replication_http_port = config.get("worker_replication_http_port")
+ # The shared secret used for authentication when connecting to the main synapse.
+ self.worker_replication_secret = config.get("worker_replication_secret", None)
+
self.worker_name = config.get("worker_name", self.worker_app)
self.worker_main_http_uri = config.get("worker_main_http_uri", None)
@@ -185,6 +188,13 @@ class WorkerConfig(Config):
# data). If not provided this defaults to the main process.
#
#run_background_tasks_on: worker1
+
+ # A shared secret used by the replication APIs to authenticate HTTP requests
+ # from workers.
+ #
+ # By default this is unused and traffic is not authenticated.
+ #
+ #worker_replication_secret: ""
"""
def read_arguments(self, args):
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 57fd426e..74b67b23 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -227,7 +227,7 @@ class ConnectionVerifier:
# This code is based on twisted.internet.ssl.ClientTLSOptions.
- def __init__(self, hostname: bytes, verify_certs):
+ def __init__(self, hostname: bytes, verify_certs: bool):
self._verify_certs = verify_certs
_decoded = hostname.decode("ascii")
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index 0422c43f..8fb116ae 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -18,7 +18,7 @@
import collections.abc
import hashlib
import logging
-from typing import Dict
+from typing import Any, Callable, Dict, Tuple
from canonicaljson import encode_canonical_json
from signedjson.sign import sign_json
@@ -27,13 +27,18 @@ from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion
+from synapse.events import EventBase
from synapse.events.utils import prune_event, prune_event_dict
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
+Hasher = Callable[[bytes], "hashlib._Hash"]
-def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
+
+def check_event_content_hash(
+ event: EventBase, hash_algorithm: Hasher = hashlib.sha256
+) -> bool:
"""Check whether the hash for this PDU matches the contents"""
name, expected_hash = compute_content_hash(event.get_pdu_json(), hash_algorithm)
logger.debug(
@@ -67,18 +72,19 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
return message_hash_bytes == expected_hash
-def compute_content_hash(event_dict, hash_algorithm):
+def compute_content_hash(
+ event_dict: Dict[str, Any], hash_algorithm: Hasher
+) -> Tuple[str, bytes]:
"""Compute the content hash of an event, which is the hash of the
unredacted event.
Args:
- event_dict (dict): The unredacted event as a dict
+ event_dict: The unredacted event as a dict
hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
to hash the event
Returns:
- tuple[str, bytes]: A tuple of the name of hash and the hash as raw
- bytes.
+ A tuple of the name of hash and the hash as raw bytes.
"""
event_dict = dict(event_dict)
event_dict.pop("age_ts", None)
@@ -94,18 +100,19 @@ def compute_content_hash(event_dict, hash_algorithm):
return hashed.name, hashed.digest()
-def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256):
+def compute_event_reference_hash(
+ event, hash_algorithm: Hasher = hashlib.sha256
+) -> Tuple[str, bytes]:
"""Computes the event reference hash. This is the hash of the redacted
event.
Args:
- event (FrozenEvent)
+ event
hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
to hash the event
Returns:
- tuple[str, bytes]: A tuple of the name of hash and the hash as raw
- bytes.
+ A tuple of the name of hash and the hash as raw bytes.
"""
tmp_event = prune_event(event)
event_dict = tmp_event.get_pdu_json()
@@ -156,7 +163,7 @@ def add_hashes_and_signatures(
event_dict: JsonDict,
signature_name: str,
signing_key: SigningKey,
-):
+) -> None:
"""Add content hash and sign the event
Args:
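The new Hasher alias covers anything with the hashlib interface; a simplified sketch of hashing an event dict the way compute_content_hash does (toy event with no signature keys to strip):

    import hashlib

    from canonicaljson import encode_canonical_json

    event_dict = {"type": "m.room.message", "content": {"body": "hi"}}

    # compute_content_hash removes signature-related keys first; this toy
    # event has none, so we hash the canonical JSON directly.
    hashed = hashlib.sha256(encode_canonical_json(event_dict))
    print(hashed.name, hashed.hexdigest())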
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index c04ad77c..902128a2 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -14,9 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import logging
import urllib
from collections import defaultdict
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
import attr
from signedjson.key import (
@@ -40,6 +42,7 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
+from synapse.config.key import TrustedKeyServer
from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
@@ -47,11 +50,15 @@ from synapse.logging.context import (
run_in_background,
)
from synapse.storage.keys import FetchKeyResult
+from synapse.types import JsonDict
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.metrics import Measure
from synapse.util.retryutils import NotRetryingDestination
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -61,16 +68,17 @@ class VerifyJsonRequest:
A request to verify a JSON object.
Attributes:
- server_name(str): The name of the server to verify against.
-
- key_ids(set[str]): The set of key_ids to that could be used to verify the
- JSON object
+ server_name: The name of the server to verify against.
- json_object(dict): The JSON object to verify.
+ json_object: The JSON object to verify.
- minimum_valid_until_ts (int): time at which we require the signing key to
+ minimum_valid_until_ts: time at which we require the signing key to
be valid. (0 implies we don't care)
+ request_name: The name of the request.
+
+ key_ids: The set of key_ids that could be used to verify the JSON object
+
key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
A deferred (server_name, key_id, verify_key) tuple that resolves when
a verify key has been fetched. The deferreds' callbacks are run with no
@@ -80,12 +88,12 @@ class VerifyJsonRequest:
errbacks with an M_UNAUTHORIZED SynapseError.
"""
- server_name = attr.ib()
- json_object = attr.ib()
- minimum_valid_until_ts = attr.ib()
- request_name = attr.ib()
- key_ids = attr.ib(init=False)
- key_ready = attr.ib(default=attr.Factory(defer.Deferred))
+ server_name = attr.ib(type=str)
+ json_object = attr.ib(type=JsonDict)
+ minimum_valid_until_ts = attr.ib(type=int)
+ request_name = attr.ib(type=str)
+ key_ids = attr.ib(init=False, type=List[str])
+ key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
def __attrs_post_init__(self):
self.key_ids = signature_ids(self.json_object, self.server_name)
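The attrs change adds types without changing behaviour; an equivalent standalone sketch:

    import attr

    @attr.s
    class Point:
        # attr.ib(type=...) records the annotation for tooling (e.g. mypy
        # via the attrs plugin) while keeping the classic attr.s style.
        x = attr.ib(type=int)
        y = attr.ib(type=int, default=0)

    p = Point(x=1)
    assert (p.x, p.y) == (1, 0)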
@@ -96,7 +104,9 @@ class KeyLookupError(ValueError):
class Keyring:
- def __init__(self, hs, key_fetchers=None):
+ def __init__(
+ self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
+ ):
self.clock = hs.get_clock()
if key_fetchers is None:
@@ -112,22 +122,26 @@ class Keyring:
# completes.
#
# These are regular, logcontext-agnostic Deferreds.
- self.key_downloads = {}
+ self.key_downloads = {} # type: Dict[str, defer.Deferred]
def verify_json_for_server(
- self, server_name, json_object, validity_time, request_name
- ):
+ self,
+ server_name: str,
+ json_object: JsonDict,
+ validity_time: int,
+ request_name: str,
+ ) -> defer.Deferred:
"""Verify that a JSON object has been signed by a given server
Args:
- server_name (str): name of the server which must have signed this object
+ server_name: name of the server which must have signed this object
- json_object (dict): object to be checked
+ json_object: object to be checked
- validity_time (int): timestamp at which we require the signing key to
+ validity_time: timestamp at which we require the signing key to
be valid. (0 implies we don't care)
- request_name (str): an identifier for this json object (eg, an event id)
+ request_name: an identifier for this json object (eg, an event id)
for logging.
Returns:
@@ -138,12 +152,14 @@ class Keyring:
requests = (req,)
return make_deferred_yieldable(self._verify_objects(requests)[0])
- def verify_json_objects_for_server(self, server_and_json):
+ def verify_json_objects_for_server(
+ self, server_and_json: Iterable[Tuple[str, dict, int, str]]
+ ) -> List[defer.Deferred]:
"""Bulk verifies signatures of json objects, bulk fetching keys as
necessary.
Args:
- server_and_json (iterable[Tuple[str, dict, int, str]):
+ server_and_json:
Iterable of (server_name, json_object, validity_time, request_name)
tuples.
@@ -164,13 +180,14 @@ class Keyring:
for server_name, json_object, validity_time, request_name in server_and_json
)
- def _verify_objects(self, verify_requests):
+ def _verify_objects(
+ self, verify_requests: Iterable[VerifyJsonRequest]
+ ) -> List[defer.Deferred]:
"""Does the work of verify_json_[objects_]for_server
Args:
- verify_requests (iterable[VerifyJsonRequest]):
- Iterable of verification requests.
+ verify_requests: Iterable of verification requests.
Returns:
List<Deferred[None]>: for each input item, a deferred indicating success
@@ -182,7 +199,7 @@ class Keyring:
key_lookups = []
handle = preserve_fn(_handle_key_deferred)
- def process(verify_request):
+ def process(verify_request: VerifyJsonRequest) -> defer.Deferred:
"""Process an entry in the request list
Adds a key request to key_lookups, and returns a deferred which
@@ -222,18 +239,20 @@ class Keyring:
return results
- async def _start_key_lookups(self, verify_requests):
+ async def _start_key_lookups(
+ self, verify_requests: List[VerifyJsonRequest]
+ ) -> None:
"""Sets off the key fetches for each verify request
Once each fetch completes, verify_request.key_ready will be resolved.
Args:
- verify_requests (List[VerifyJsonRequest]):
+ verify_requests:
"""
try:
# map from server name to a set of outstanding request ids
- server_to_request_ids = {}
+ server_to_request_ids = {} # type: Dict[str, Set[int]]
for verify_request in verify_requests:
server_name = verify_request.server_name
@@ -275,11 +294,11 @@ class Keyring:
except Exception:
logger.exception("Error starting key lookups")
- async def wait_for_previous_lookups(self, server_names) -> None:
+ async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None:
"""Waits for any previous key lookups for the given servers to finish.
Args:
- server_names (Iterable[str]): list of servers which we want to look up
+ server_names: list of servers which we want to look up
Returns:
Resolves once all key lookups for the given servers have
@@ -304,7 +323,7 @@ class Keyring:
loop_count += 1
- def _get_server_verify_keys(self, verify_requests):
+ def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None:
"""Tries to find at least one key for each verify request
For each verify_request, verify_request.key_ready is called back with
@@ -312,7 +331,7 @@ class Keyring:
with a SynapseError if none of the keys are found.
Args:
- verify_requests (list[VerifyJsonRequest]): list of verify requests
+ verify_requests: list of verify requests
"""
remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
@@ -366,17 +385,19 @@ class Keyring:
run_in_background(do_iterations)
- async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
+ async def _attempt_key_fetches_with_fetcher(
+ self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest]
+ ):
"""Use a key fetcher to attempt to satisfy some key requests
Args:
- fetcher (KeyFetcher): fetcher to use to fetch the keys
- remaining_requests (set[VerifyJsonRequest]): outstanding key requests.
+ fetcher: fetcher to use to fetch the keys
+ remaining_requests: outstanding key requests.
Any successfully-completed requests will be removed from the list.
"""
- # dict[str, dict[str, int]]: keys to fetch.
+ # The keys to fetch.
# server_name -> key_id -> min_valid_ts
- missing_keys = defaultdict(dict)
+ missing_keys = defaultdict(dict) # type: Dict[str, Dict[str, int]]
for verify_request in remaining_requests:
# any completed requests should already have been removed
@@ -438,16 +459,18 @@ class Keyring:
remaining_requests.difference_update(completed)
-class KeyFetcher:
- async def get_keys(self, keys_to_fetch):
+class KeyFetcher(metaclass=abc.ABCMeta):
+ @abc.abstractmethod
+ async def get_keys(
+ self, keys_to_fetch: Dict[str, Dict[str, int]]
+ ) -> Dict[str, Dict[str, FetchKeyResult]]:
"""
Args:
- keys_to_fetch (dict[str, dict[str, int]]):
+ keys_to_fetch:
the keys to be fetched. server_name -> key_id -> min_valid_ts
Returns:
- Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
- map from server_name -> key_id -> FetchKeyResult
+ Map from server_name -> key_id -> FetchKeyResult
"""
raise NotImplementedError
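Making KeyFetcher abstract turns a missing get_keys implementation into an instantiation-time TypeError instead of a runtime NotImplementedError; a minimal sketch of the pattern:

    import abc

    class Fetcher(metaclass=abc.ABCMeta):
        @abc.abstractmethod
        async def get_keys(self, keys_to_fetch):
            raise NotImplementedError

    class BrokenFetcher(Fetcher):
        pass

    try:
        BrokenFetcher()  # TypeError: abstract method get_keys
    except TypeError as e:
        print(e)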
@@ -455,31 +478,35 @@ class KeyFetcher:
class StoreKeyFetcher(KeyFetcher):
"""KeyFetcher impl which fetches keys from our data store"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
- async def get_keys(self, keys_to_fetch):
+ async def get_keys(
+ self, keys_to_fetch: Dict[str, Dict[str, int]]
+ ) -> Dict[str, Dict[str, FetchKeyResult]]:
"""see KeyFetcher.get_keys"""
- keys_to_fetch = (
+ key_ids_to_fetch = (
(server_name, key_id)
for server_name, keys_for_server in keys_to_fetch.items()
for key_id in keys_for_server.keys()
)
- res = await self.store.get_server_verify_keys(keys_to_fetch)
- keys = {}
+ res = await self.store.get_server_verify_keys(key_ids_to_fetch)
+ keys = {} # type: Dict[str, Dict[str, FetchKeyResult]]
for (server_name, key_id), key in res.items():
keys.setdefault(server_name, {})[key_id] = key
return keys
-class BaseV2KeyFetcher:
- def __init__(self, hs):
+class BaseV2KeyFetcher(KeyFetcher):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.config = hs.get_config()
- async def process_v2_response(self, from_server, response_json, time_added_ms):
+ async def process_v2_response(
+ self, from_server: str, response_json: JsonDict, time_added_ms: int
+ ) -> Dict[str, FetchKeyResult]:
"""Parse a 'Server Keys' structure from the result of a /key request
This is used to parse either the entirety of the response from
@@ -493,16 +520,16 @@ class BaseV2KeyFetcher:
to /_matrix/key/v2/query.
Args:
- from_server (str): the name of the server producing this result: either
+ from_server: the name of the server producing this result: either
the origin server for a /_matrix/key/v2/server request, or the notary
for a /_matrix/key/v2/query.
- response_json (dict): the json-decoded Server Keys response object
+ response_json: the json-decoded Server Keys response object
- time_added_ms (int): the timestamp to record in server_keys_json
+ time_added_ms: the timestamp to record in server_keys_json
Returns:
- Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
+ Map from key_id to result object
"""
ts_valid_until_ms = response_json["valid_until_ts"]
@@ -575,21 +602,22 @@ class BaseV2KeyFetcher:
class PerspectivesKeyFetcher(BaseV2KeyFetcher):
"""KeyFetcher impl which fetches keys from the "perspectives" servers"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.clock = hs.get_clock()
- self.client = hs.get_http_client()
+ self.client = hs.get_federation_http_client()
self.key_servers = self.config.key_servers
- async def get_keys(self, keys_to_fetch):
+ async def get_keys(
+ self, keys_to_fetch: Dict[str, Dict[str, int]]
+ ) -> Dict[str, Dict[str, FetchKeyResult]]:
"""see KeyFetcher.get_keys"""
- async def get_key(key_server):
+ async def get_key(key_server: TrustedKeyServer) -> Dict:
try:
- result = await self.get_server_verify_key_v2_indirect(
+ return await self.get_server_verify_key_v2_indirect(
keys_to_fetch, key_server
)
- return result
except KeyLookupError as e:
logger.warning(
"Key lookup failed from %r: %s", key_server.server_name, e
@@ -611,25 +639,25 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
).addErrback(unwrapFirstError)
)
- union_of_keys = {}
+ union_of_keys = {} # type: Dict[str, Dict[str, FetchKeyResult]]
for result in results:
for server_name, keys in result.items():
union_of_keys.setdefault(server_name, {}).update(keys)
return union_of_keys
- async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
+ async def get_server_verify_key_v2_indirect(
+ self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer
+ ) -> Dict[str, Dict[str, FetchKeyResult]]:
"""
Args:
- keys_to_fetch (dict[str, dict[str, int]]):
+ keys_to_fetch:
the keys to be fetched. server_name -> key_id -> min_valid_ts
- key_server (synapse.config.key.TrustedKeyServer): notary server to query for
- the keys
+ key_server: notary server to query for the keys
Returns:
- dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map
- from server_name -> key_id -> FetchKeyResult
+ Map from server_name -> key_id -> FetchKeyResult
Raises:
KeyLookupError if there was an error processing the entire response from
@@ -662,11 +690,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
except HttpResponseException as e:
raise KeyLookupError("Remote server returned an error: %s" % (e,))
- keys = {}
- added_keys = []
+ keys = {} # type: Dict[str, Dict[str, FetchKeyResult]]
+ added_keys = [] # type: List[Tuple[str, str, FetchKeyResult]]
time_now_ms = self.clock.time_msec()
+ assert isinstance(query_response, dict)
for response in query_response["server_keys"]:
# do this first, so that we can give useful errors thereafter
server_name = response.get("server_name")
@@ -704,14 +733,15 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
return keys
- def _validate_perspectives_response(self, key_server, response):
+ def _validate_perspectives_response(
+ self, key_server: TrustedKeyServer, response: JsonDict
+ ) -> None:
"""Optionally check the signature on the result of a /key/query request
Args:
- key_server (synapse.config.key.TrustedKeyServer): the notary server that
- produced this result
+ key_server: the notary server that produced this result
- response (dict): the json-decoded Server Keys response object
+ response: the json-decoded Server Keys response object
"""
perspective_name = key_server.server_name
perspective_keys = key_server.verify_keys
@@ -745,25 +775,26 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
class ServerKeyFetcher(BaseV2KeyFetcher):
"""KeyFetcher impl which fetches keys from the origin servers"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.clock = hs.get_clock()
- self.client = hs.get_http_client()
+ self.client = hs.get_federation_http_client()
- async def get_keys(self, keys_to_fetch):
+ async def get_keys(
+ self, keys_to_fetch: Dict[str, Dict[str, int]]
+ ) -> Dict[str, Dict[str, FetchKeyResult]]:
"""
Args:
- keys_to_fetch (dict[str, iterable[str]]):
+ keys_to_fetch:
the keys to be fetched. server_name -> key_ids
Returns:
- dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]:
- map from server_name -> key_id -> FetchKeyResult
+ Map from server_name -> key_id -> FetchKeyResult
"""
results = {}
- async def get_key(key_to_fetch_item):
+ async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None:
server_name, key_ids = key_to_fetch_item
try:
keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
@@ -778,20 +809,22 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
await yieldable_gather_results(get_key, keys_to_fetch.items())
return results
- async def get_server_verify_key_v2_direct(self, server_name, key_ids):
+ async def get_server_verify_key_v2_direct(
+ self, server_name: str, key_ids: Iterable[str]
+ ) -> Dict[str, FetchKeyResult]:
"""
Args:
- server_name (str):
- key_ids (iterable[str]):
+ server_name:
+ key_ids:
Returns:
- dict[str, FetchKeyResult]: map from key ID to lookup result
+ Map from key ID to lookup result
Raises:
KeyLookupError if there was a problem making the lookup
"""
- keys = {} # type: dict[str, FetchKeyResult]
+ keys = {} # type: Dict[str, FetchKeyResult]
for requested_key_id in key_ids:
# we may have found this key as a side-effect of asking for another.
@@ -825,6 +858,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
except HttpResponseException as e:
raise KeyLookupError("Remote server returned an error: %s" % (e,))
+ assert isinstance(response, dict)
if response["server_name"] != server_name:
raise KeyLookupError(
"Expected a response for server %r not %r"
@@ -846,11 +880,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
return keys
-async def _handle_key_deferred(verify_request) -> None:
+async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
"""Waits for the key to become available, and then performs a verification
Args:
- verify_request (VerifyJsonRequest):
+ verify_request:
Raises:
SynapseError if there was a problem performing the verification
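The get_keys interface formalised above takes a map of server_name -> key_id -> minimum valid-until timestamp, and returns server_name -> key_id -> FetchKeyResult. A minimal sketch of a fetcher honouring that contract, with an invented in-memory cache and a stand-in for FetchKeyResult (the real class lives in synapse.storage.keys):

from typing import Dict, NamedTuple


class FetchKeyResult(NamedTuple):
    # Stand-in for synapse.storage.keys.FetchKeyResult.
    verify_key: bytes      # the raw public key
    valid_until_ts: int    # ms timestamp after which the key must be refetched


class InMemoryKeyFetcher:
    """Serves keys from a local dict: server_name -> key_id -> FetchKeyResult."""

    def __init__(self, cache: Dict[str, Dict[str, FetchKeyResult]]):
        self._cache = cache

    async def get_keys(
        self, keys_to_fetch: Dict[str, Dict[str, int]]
    ) -> Dict[str, Dict[str, FetchKeyResult]]:
        results = {}  # type: Dict[str, Dict[str, FetchKeyResult]]
        for server_name, key_ids in keys_to_fetch.items():
            for key_id, min_valid_ts in key_ids.items():
                result = self._cache.get(server_name, {}).get(key_id)
                # Only return keys still valid at the requested timestamp.
                if result is not None and result.valid_until_ts >= min_valid_ts:
                    results.setdefault(server_name, {})[key_id] = result
        return results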
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 93689665..e7e3a7b9 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,10 +15,11 @@
# limitations under the License.
import inspect
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import Collection
+from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
import synapse.events
@@ -39,7 +40,9 @@ class SpamChecker:
else:
self.spam_checkers.append(module(config=config))
- def check_event_for_spam(self, event: "synapse.events.EventBase") -> bool:
+ async def check_event_for_spam(
+ self, event: "synapse.events.EventBase"
+ ) -> Union[bool, str]:
"""Checks if a given event is considered "spammy" by this server.
If the server considers an event spammy, then it will be rejected if
@@ -50,15 +53,16 @@ class SpamChecker:
event: the event to be checked
Returns:
- True if the event is spammy.
+ True or a string if the event is spammy. If a string is returned it
+ will be used as the error message returned to the user.
"""
for spam_checker in self.spam_checkers:
- if spam_checker.check_event_for_spam(event):
+ if await maybe_awaitable(spam_checker.check_event_for_spam(event)):
return True
return False
- def user_may_invite(
+ async def user_may_invite(
self, inviter_userid: str, invitee_userid: str, room_id: str
) -> bool:
"""Checks if a given user may send an invite
@@ -75,14 +79,18 @@ class SpamChecker:
"""
for spam_checker in self.spam_checkers:
if (
- spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
+ await maybe_awaitable(
+ spam_checker.user_may_invite(
+ inviter_userid, invitee_userid, room_id
+ )
+ )
is False
):
return False
return True
- def user_may_create_room(self, userid: str) -> bool:
+ async def user_may_create_room(self, userid: str) -> bool:
"""Checks if a given user may create a room
If this method returns false, the creation request will be rejected.
@@ -94,12 +102,15 @@ class SpamChecker:
True if the user may create a room, otherwise False
"""
for spam_checker in self.spam_checkers:
- if spam_checker.user_may_create_room(userid) is False:
+ if (
+ await maybe_awaitable(spam_checker.user_may_create_room(userid))
+ is False
+ ):
return False
return True
- def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
+ async def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
"""Checks if a given user may create a room alias
If this method returns false, the association request will be rejected.
@@ -112,12 +123,17 @@ class SpamChecker:
True if the user may create a room alias, otherwise False
"""
for spam_checker in self.spam_checkers:
- if spam_checker.user_may_create_room_alias(userid, room_alias) is False:
+ if (
+ await maybe_awaitable(
+ spam_checker.user_may_create_room_alias(userid, room_alias)
+ )
+ is False
+ ):
return False
return True
- def user_may_publish_room(self, userid: str, room_id: str) -> bool:
+ async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
"""Checks if a given user may publish a room to the directory
If this method returns false, the publish request will be rejected.
@@ -130,12 +146,17 @@ class SpamChecker:
True if the user may publish the room, otherwise False
"""
for spam_checker in self.spam_checkers:
- if spam_checker.user_may_publish_room(userid, room_id) is False:
+ if (
+ await maybe_awaitable(
+ spam_checker.user_may_publish_room(userid, room_id)
+ )
+ is False
+ ):
return False
return True
- def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
+ async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server.
If the server considers a username spammy, then it will not be included in
@@ -157,12 +178,12 @@ class SpamChecker:
if checker:
# Make a copy of the user profile object to ensure the spam checker
# cannot modify it.
- if checker(user_profile.copy()):
+ if await maybe_awaitable(checker(user_profile.copy())):
return True
return False
- def check_registration_for_spam(
+ async def check_registration_for_spam(
self,
email_threepid: Optional[dict],
username: Optional[str],
@@ -185,7 +206,9 @@ class SpamChecker:
# spam checker
checker = getattr(spam_checker, "check_registration_for_spam", None)
if checker:
- behaviour = checker(email_threepid, username, request_info)
+ behaviour = await maybe_awaitable(
+ checker(email_threepid, username, request_info)
+ )
assert isinstance(behaviour, RegistrationBehaviour)
if behaviour != RegistrationBehaviour.ALLOW:
return behaviour
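Every callback above is now routed through maybe_awaitable, so installed spam-checker modules may implement their hooks either synchronously or as coroutines. A hypothetical module illustrating both styles (the class and its heuristics are invented for illustration):

class ExampleSpamChecker:
    def __init__(self, config, api=None):
        self._config = config

    @staticmethod
    def parse_config(config):
        return config

    def user_may_create_room(self, userid):
        # A synchronous callback: still supported after this change.
        return not userid.startswith("@throwaway")

    async def check_event_for_spam(self, event):
        # An asynchronous callback: now also supported. Returning a string
        # rejects the event and surfaces the string as the error message.
        if "buy cheap pills" in event.content.get("body", ""):
            return "Spam is not permitted here"
        return False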
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 38aa4796..38373752 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -78,6 +78,7 @@ class FederationBase:
ctx = current_context()
+ @defer.inlineCallbacks
def callback(_, pdu: EventBase):
with PreserveLoggingContext(ctx):
if not check_event_content_hash(pdu):
@@ -105,7 +106,11 @@ class FederationBase:
)
return redacted_event
- if self.spam_checker.check_event_for_spam(pdu):
+ result = yield defer.ensureDeferred(
+ self.spam_checker.check_event_for_spam(pdu)
+ )
+
+ if result:
logger.warning(
"Event contains spam, redacting %s: %s",
pdu.event_id,
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 4b6ab470..35e345ce 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -845,7 +845,6 @@ class FederationHandlerRegistry:
def __init__(self, hs: "HomeServer"):
self.config = hs.config
- self.http_client = hs.get_simple_http_client()
self.clock = hs.get_clock()
self._instance_name = hs.get_instance_name()
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 17a10f62..abe9168c 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -35,7 +35,7 @@ class TransportLayerClient:
def __init__(self, hs):
self.server_name = hs.hostname
- self.client = hs.get_http_client()
+ self.client = hs.get_federation_http_client()
@log_function
def get_room_state_ids(self, destination, room_id, event_id):
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index b53e7a20..cfd094e5 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -144,7 +144,7 @@ class Authenticator:
):
raise FederationDeniedError(origin)
- if not json_request["signatures"]:
+ if origin is None or not json_request["signatures"]:
raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED
)
@@ -1462,7 +1462,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N
Args:
hs (synapse.server.HomeServer): homeserver
- resource (TransportLayerServer): resource class to register to
+ resource (JsonResource): resource class to register to
authenticator (Authenticator): authenticator to use
ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use
servlet_groups (list[str], optional): List of servlet groups to register.
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index bb81c0e8..d29b066a 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -32,6 +32,10 @@ logger = logging.getLogger(__name__)
class BaseHandler:
"""
Common base class for the event handlers.
+
+ Deprecated: new code should not use this. Instead, Handler classes should define the
+ fields they actually need. The utility methods should either be factored out to
+ standalone helper functions, or to different Handler classes.
"""
def __init__(self, hs: "HomeServer"):
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index a7039445..37e63da9 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -13,27 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import logging
-from typing import List
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
from synapse.api.constants import Membership
-from synapse.events import FrozenEvent
-from synapse.types import RoomStreamToken, StateMap
+from synapse.events import EventBase
+from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class AdminHandler(BaseHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.storage = hs.get_storage()
self.state_store = self.storage.state
- async def get_whois(self, user):
+ async def get_whois(self, user: UserID) -> JsonDict:
connections = []
sessions = await self.store.get_user_ip_and_agents(user)
@@ -53,7 +57,7 @@ class AdminHandler(BaseHandler):
return ret
- async def get_user(self, user):
+ async def get_user(self, user: UserID) -> Optional[JsonDict]:
"""Function to get user details"""
ret = await self.store.get_user_by_id(user.to_string())
if ret:
@@ -64,12 +68,12 @@ class AdminHandler(BaseHandler):
ret["threepids"] = threepids
return ret
- async def export_user_data(self, user_id, writer):
+ async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:
"""Write all data we have on the user to the given writer.
Args:
- user_id (str)
- writer (ExfiltrationWriter)
+ user_id: The user ID to fetch data of.
+ writer: The writer to write to.
Returns:
Resolves when all data for a user has been written.
@@ -128,7 +132,8 @@ class AdminHandler(BaseHandler):
from_key = RoomStreamToken(0, 0)
to_key = RoomStreamToken(None, stream_ordering)
- written_events = set() # Events that we've processed in this room
+ # Events that we've processed in this room
+ written_events = set() # type: Set[str]
# We need to track gaps in the events stream so that we can then
# write out the state at those events. We do this by keeping track
@@ -140,8 +145,8 @@ class AdminHandler(BaseHandler):
# The reverse mapping to above, i.e. map from unseen event to events
# that have the unseen event in their prev_events, i.e. the unseen
- # events "children". dict[str, set[str]]
- unseen_to_child_events = {}
+ # events "children".
+ unseen_to_child_events = {} # type: Dict[str, Set[str]]
# We fetch events in the room the user could see by fetching *all*
# events that we have and then filtering, this isn't the most
@@ -197,38 +202,46 @@ class AdminHandler(BaseHandler):
return writer.finished()
-class ExfiltrationWriter:
+class ExfiltrationWriter(metaclass=abc.ABCMeta):
"""Interface used to specify how to write exported data.
"""
- def write_events(self, room_id: str, events: List[FrozenEvent]):
+ @abc.abstractmethod
+ def write_events(self, room_id: str, events: List[EventBase]) -> None:
"""Write a batch of events for a room.
"""
- pass
+ raise NotImplementedError()
- def write_state(self, room_id: str, event_id: str, state: StateMap[FrozenEvent]):
+ @abc.abstractmethod
+ def write_state(
+ self, room_id: str, event_id: str, state: StateMap[EventBase]
+ ) -> None:
"""Write the state at the given event in the room.
This only gets called for backward extremities rather than for each
event.
"""
- pass
+ raise NotImplementedError()
- def write_invite(self, room_id: str, event: FrozenEvent, state: StateMap[dict]):
+ @abc.abstractmethod
+ def write_invite(
+ self, room_id: str, event: EventBase, state: StateMap[dict]
+ ) -> None:
"""Write an invite for the room, with associated invite state.
Args:
- room_id
- event
- state: A subset of the state at the
- invite, with a subset of the event keys (type, state_key
- content and sender)
+ room_id: The room ID the invite is for.
+ event: The invite event.
+ state: A subset of the state at the invite, with a subset of the
+ event keys (type, state_key, content and sender).
"""
+ raise NotImplementedError()
- def finished(self):
+ @abc.abstractmethod
+ def finished(self) -> Any:
"""Called when all data has successfully been exported and written.
This function's return value is passed to the caller of
`export_user_data`.
"""
- pass
+ raise NotImplementedError()
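Since ExfiltrationWriter is now an abstract base class, callers of export_user_data must provide a concrete implementation of all four methods. A minimal sketch that appends each room's data to JSON-lines files (the on-disk layout is illustrative, and EventBase.get_pdu_json is assumed for serialisation):

import json
import os

from synapse.handlers.admin import ExfiltrationWriter


class JsonLinesWriter(ExfiltrationWriter):
    def __init__(self, base_dir: str):
        self.base_dir = base_dir
        os.makedirs(base_dir, exist_ok=True)

    def _append(self, name: str, obj) -> None:
        with open(os.path.join(self.base_dir, name), "a") as f:
            f.write(json.dumps(obj) + "\n")

    def write_events(self, room_id, events) -> None:
        for event in events:
            self._append("%s.events.jsonl" % (room_id,), event.get_pdu_json())

    def write_state(self, room_id, event_id, state) -> None:
        state_json = {str(key): e.get_pdu_json() for key, e in state.items()}
        self._append("%s.state.jsonl" % (room_id,), {event_id: state_json})

    def write_invite(self, room_id, event, state) -> None:
        self._append("%s.invites.jsonl" % (room_id,), event.get_pdu_json())

    def finished(self):
        # This value is handed back to the caller of export_user_data.
        return self.base_dir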
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c7dc0700..f4434673 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import logging
import time
import unicodedata
@@ -22,6 +21,7 @@ import urllib.parse
from typing import (
TYPE_CHECKING,
Any,
+ Awaitable,
Callable,
Dict,
Iterable,
@@ -36,6 +36,8 @@ import attr
import bcrypt
import pymacaroons
+from twisted.web.http import Request
+
from synapse.api.constants import LoginType
from synapse.api.errors import (
AuthError,
@@ -56,6 +58,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils
+from synapse.util.async_helpers import maybe_awaitable
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
@@ -193,39 +196,27 @@ class AuthHandler(BaseHandler):
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled
- self._sso_enabled = (
- hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
- )
-
- # we keep this as a list despite the O(N^2) implication so that we can
- # keep PASSWORD first and avoid confusing clients which pick the first
- # type in the list. (NB that the spec doesn't require us to do so and
- # clients which favour types that they don't understand over those that
- # they do are technically broken)
+ self._password_localdb_enabled = hs.config.password_localdb_enabled
# start out by assuming PASSWORD is enabled; we will remove it later if not.
- login_types = []
- if hs.config.password_localdb_enabled:
- login_types.append(LoginType.PASSWORD)
+ login_types = set()
+ if self._password_localdb_enabled:
+ login_types.add(LoginType.PASSWORD)
for provider in self.password_providers:
- if hasattr(provider, "get_supported_login_types"):
- for t in provider.get_supported_login_types().keys():
- if t not in login_types:
- login_types.append(t)
+ login_types.update(provider.get_supported_login_types().keys())
if not self._password_enabled:
+ login_types.discard(LoginType.PASSWORD)
+
+ # Some clients just pick the first type in the list. In this case, we want
+ # them to use PASSWORD (rather than token or whatever), so we want to make sure
+ # that comes first, where it's present.
+ self._supported_login_types = []
+ if LoginType.PASSWORD in login_types:
+ self._supported_login_types.append(LoginType.PASSWORD)
login_types.remove(LoginType.PASSWORD)
-
- self._supported_login_types = login_types
-
- # Login types and UI Auth types have a heavy overlap, but are not
- # necessarily identical. Login types have SSO (and other login types)
- # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
- ui_auth_types = login_types.copy()
- if self._sso_enabled:
- ui_auth_types.append(LoginType.SSO)
- self._supported_ui_auth_types = ui_auth_types
+ self._supported_login_types.extend(login_types)
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
@@ -235,6 +226,9 @@ class AuthHandler(BaseHandler):
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
+ # The number of seconds to keep a UI auth session active.
+ self._ui_auth_session_timeout = hs.config.ui_auth_session_timeout
+
# Ratelimitier for failed /login attempts
self._failed_login_attempts_ratelimiter = Ratelimiter(
clock=hs.get_clock(),
@@ -292,7 +286,7 @@ class AuthHandler(BaseHandler):
request_body: Dict[str, Any],
clientip: str,
description: str,
- ) -> Tuple[dict, str]:
+ ) -> Tuple[dict, Optional[str]]:
"""
Checks that the user is who they claim to be, via a UI auth.
@@ -319,7 +313,8 @@ class AuthHandler(BaseHandler):
have been given only in a previous call).
'session_id' is the ID of this session, either passed in by the
- client or assigned by this call
+ client or assigned by this call. This is None if UI auth was
+ skipped (by re-using a previous validation).
Raises:
InteractiveAuthIncompleteError if the client has not yet completed
@@ -333,13 +328,26 @@ class AuthHandler(BaseHandler):
"""
+ if self._ui_auth_session_timeout:
+ last_validated = await self.store.get_access_token_last_validated(
+ requester.access_token_id
+ )
+ if self.clock.time_msec() - last_validated < self._ui_auth_session_timeout:
+ # Return the input parameters, minus the auth key, which matches
+ # the logic in check_ui_auth.
+ request_body.pop("auth", None)
+ return request_body, None
+
user_id = requester.user.to_string()
# Check if we should be ratelimited due to too many previous failed attempts
self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
# build a list of supported flows
- flows = [[login_type] for login_type in self._supported_ui_auth_types]
+ supported_ui_auth_types = await self._get_available_ui_auth_types(
+ requester.user
+ )
+ flows = [[login_type] for login_type in supported_ui_auth_types]
try:
result, params, session_id = await self.check_ui_auth(
@@ -351,7 +359,7 @@ class AuthHandler(BaseHandler):
raise
# find the completed login type
- for login_type in self._supported_ui_auth_types:
+ for login_type in supported_ui_auth_types:
if login_type not in result:
continue
@@ -365,8 +373,46 @@ class AuthHandler(BaseHandler):
if user_id != requester.user.to_string():
raise AuthError(403, "Invalid auth")
+ # Note that the access token has been validated.
+ await self.store.update_access_token_last_validated(requester.access_token_id)
+
return params, session_id
+ async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
+ """Get a list of the authentication types this user can use
+ """
+
+ ui_auth_types = set()
+
+ # if the HS supports password auth, and the user has a non-null password, we
+ # support password auth
+ if self._password_localdb_enabled and self._password_enabled:
+ lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
+ if lookupres:
+ _, password_hash = lookupres
+ if password_hash:
+ ui_auth_types.add(LoginType.PASSWORD)
+
+ # also allow auth from password providers
+ for provider in self.password_providers:
+ for t in provider.get_supported_login_types().keys():
+ if t == LoginType.PASSWORD and not self._password_enabled:
+ continue
+ ui_auth_types.add(t)
+
+ # if sso is enabled, allow the user to log in via SSO iff they have a mapping
+ # from sso to mxid.
+ if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
+ if await self.store.get_external_ids_by_user(user.to_string()):
+ ui_auth_types.add(LoginType.SSO)
+
+ # Our CAS impl does not (yet) correctly register users in user_external_ids,
+ # so always offer that if it's available.
+ if self.hs.config.cas.cas_enabled:
+ ui_auth_types.add(LoginType.SSO)
+
+ return ui_auth_types
+
def get_enabled_auth_types(self):
"""Return the enabled user-interactive authentication types
@@ -423,13 +469,10 @@ class AuthHandler(BaseHandler):
all the stages in any of the permitted flows.
"""
- authdict = None
sid = None # type: Optional[str]
- if clientdict and "auth" in clientdict:
- authdict = clientdict["auth"]
- del clientdict["auth"]
- if "session" in authdict:
- sid = authdict["session"]
+ authdict = clientdict.pop("auth", {})
+ if "session" in authdict:
+ sid = authdict["session"]
# Convert the URI and method to strings.
uri = request.uri.decode("utf-8")
@@ -534,6 +577,8 @@ class AuthHandler(BaseHandler):
creds = await self.store.get_completed_ui_auth_stages(session.session_id)
for f in flows:
+ # If all the required credentials have been supplied, the user has
+ # successfully completed the UI auth process!
if len(set(f) - set(creds)) == 0:
# it's very useful to know what args are stored, but this can
# include the password in the case of registering, so only log
@@ -709,6 +754,7 @@ class AuthHandler(BaseHandler):
device_id: Optional[str],
valid_until_ms: Optional[int],
puppets_user_id: Optional[str] = None,
+ is_appservice_ghost: bool = False,
) -> str:
"""
Creates a new access token for the user with the given user ID.
@@ -725,6 +771,7 @@ class AuthHandler(BaseHandler):
we should always have a device ID)
valid_until_ms: when the token is valid until. None for
no expiry.
+ is_appservice_ghost: Whether the user is an application service ghost user
Returns:
The access token for the user's session.
Raises:
@@ -745,7 +792,11 @@ class AuthHandler(BaseHandler):
"Logging in user %s on device %s%s", user_id, device_id, fmt_expiry
)
- await self.auth.check_auth_blocking(user_id)
+ if (
+ not is_appservice_ghost
+ or self.hs.config.appservice.track_appservice_user_ips
+ ):
+ await self.auth.check_auth_blocking(user_id)
access_token = self.macaroon_gen.generate_access_token(user_id)
await self.store.add_access_token_to_user(
@@ -831,7 +882,7 @@ class AuthHandler(BaseHandler):
async def validate_login(
self, login_submission: Dict[str, Any], ratelimit: bool = False,
- ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
+ ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate auth types which don't
@@ -974,7 +1025,7 @@ class AuthHandler(BaseHandler):
async def _validate_userid_login(
self, username: str, login_submission: Dict[str, Any],
- ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
+ ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Helper for validate_login
Handles login, once we've mapped 3pids onto userids
@@ -1029,7 +1080,7 @@ class AuthHandler(BaseHandler):
if result:
return result
- if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
+ if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
known_login_type = True
# we've already checked that there is a (valid) password field
@@ -1052,7 +1103,7 @@ class AuthHandler(BaseHandler):
async def check_password_provider_3pid(
self, medium: str, address: str, password: str
- ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]:
+ ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Check if a password provider is able to validate a thirdparty login
Args:
@@ -1303,15 +1354,14 @@ class AuthHandler(BaseHandler):
)
async def complete_sso_ui_auth(
- self, registered_user_id: str, session_id: str, request: SynapseRequest,
+ self, registered_user_id: str, session_id: str, request: Request,
):
"""Having figured out a mxid for this user, complete the HTTP request
Args:
registered_user_id: The registered user ID to complete SSO login for.
+ session_id: The ID of the user-interactive auth session.
request: The request to complete.
- client_redirect_url: The URL to which to redirect the user at the end of the
- process.
"""
# Mark the stage of the authentication as successful.
# Save the user who authenticated with SSO, this will be used to ensure
@@ -1327,7 +1377,7 @@ class AuthHandler(BaseHandler):
async def complete_sso_login(
self,
registered_user_id: str,
- request: SynapseRequest,
+ request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
):
@@ -1355,7 +1405,7 @@ class AuthHandler(BaseHandler):
def _complete_sso_login(
self,
registered_user_id: str,
- request: SynapseRequest,
+ request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
):
@@ -1609,6 +1659,6 @@ class PasswordProvider:
# This might return an awaitable, if it does block the log out
# until it completes.
- result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
- if inspect.isawaitable(result):
- await result
+ await maybe_awaitable(
+ g(user_id=user_id, device_id=device_id, access_token=access_token,)
+ )
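The hand-rolled inspect.isawaitable check is replaced by maybe_awaitable, which lets call sites treat synchronous and asynchronous provider callbacks uniformly. Ignoring the Deferred handling of the real helper in synapse.util.async_helpers, it behaves roughly like this sketch:

import inspect
from typing import Any, Awaitable, Union


async def maybe_awaitable(value: Union[Awaitable[Any], Any]) -> Any:
    # Await the value if the callback was a coroutine; otherwise pass it
    # straight through.
    if inspect.isawaitable(value):
        return await value
    return value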
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index f4ea0a97..fca210a5 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -13,13 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import urllib
-from typing import TYPE_CHECKING, Dict, Optional, Tuple
+import urllib.parse
+from typing import TYPE_CHECKING, Dict, Optional
from xml.etree import ElementTree as ET
+import attr
+
from twisted.web.client import PartialDownloadError
-from synapse.api.errors import Codes, LoginError
+from synapse.api.errors import HttpResponseException
+from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.types import UserID, map_username_to_mxid_localpart
@@ -29,6 +32,26 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+class CasError(Exception):
+ """Used to catch errors when validating the CAS ticket.
+ """
+
+ def __init__(self, error, error_description=None):
+ self.error = error
+ self.error_description = error_description
+
+ def __str__(self):
+ if self.error_description:
+ return "{}: {}".format(self.error, self.error_description)
+ return self.error
+
+
+@attr.s(slots=True, frozen=True)
+class CasResponse:
+ username = attr.ib(type=str)
+ attributes = attr.ib(type=Dict[str, Optional[str]])
+
+
class CasHandler:
"""
Utility class to handle the response from a CAS SSO service.
@@ -40,6 +63,7 @@ class CasHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self._hostname = hs.hostname
+ self._store = hs.get_datastore()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
@@ -50,6 +74,11 @@ class CasHandler:
self._http_client = hs.get_proxied_http_client()
+ # identifier for the external_ids table
+ self._auth_provider_id = "cas"
+
+ self._sso_handler = hs.get_sso_handler()
+
def _build_service_param(self, args: Dict[str, str]) -> str:
"""
Generates a value to use as the "service" parameter when redirecting or
@@ -69,14 +98,20 @@ class CasHandler:
async def _validate_ticket(
self, ticket: str, service_args: Dict[str, str]
- ) -> Tuple[str, Optional[str]]:
+ ) -> CasResponse:
"""
- Validate a CAS ticket with the server, parse the response, and return the user and display name.
+ Validate a CAS ticket with the server, and return the parsed response.
Args:
ticket: The CAS ticket from the client.
service_args: Additional arguments to include in the service URL.
Should be the same as those passed to `get_redirect_url`.
+
+ Raises:
+ CasError: If there's an error parsing the CAS response.
+
+ Returns:
+ The parsed CAS response.
"""
uri = self._cas_server_url + "/proxyValidate"
args = {
@@ -89,66 +124,65 @@ class CasHandler:
# Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data
body = pde.response
+ except HttpResponseException as e:
+ description = (
+ 'Authorization server responded with a "{status}" error '
+ "while validating the CAS ticket."
+ ).format(status=e.code)
+ raise CasError("server_error", description) from e
- user, attributes = self._parse_cas_response(body)
- displayname = attributes.pop(self._cas_displayname_attribute, None)
-
- for required_attribute, required_value in self._cas_required_attributes.items():
- # If required attribute was not in CAS Response - Forbidden
- if required_attribute not in attributes:
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
- # Also need to check value
- if required_value is not None:
- actual_value = attributes[required_attribute]
- # If required attribute value does not match expected - Forbidden
- if required_value != actual_value:
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
- return user, displayname
+ return self._parse_cas_response(body)
- def _parse_cas_response(
- self, cas_response_body: bytes
- ) -> Tuple[str, Dict[str, Optional[str]]]:
+ def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:
"""
Retrieve the user and other parameters from the CAS response.
Args:
cas_response_body: The response from the CAS query.
+ Raises:
+ CasError: If there's an error parsing the CAS response.
+
Returns:
- A tuple of the user and a mapping of other attributes.
+ The parsed CAS response.
"""
+
+ # Ensure the response is valid.
+ root = ET.fromstring(cas_response_body)
+ if not root.tag.endswith("serviceResponse"):
+ raise CasError(
+ "missing_service_response",
+ "root of CAS response is not serviceResponse",
+ )
+
+ success = root[0].tag.endswith("authenticationSuccess")
+ if not success:
+ raise CasError("unsucessful_response", "Unsuccessful CAS response")
+
+ # Iterate through the nodes and pull out the user and any extra attributes.
user = None
attributes = {}
- try:
- root = ET.fromstring(cas_response_body)
- if not root.tag.endswith("serviceResponse"):
- raise Exception("root of CAS response is not serviceResponse")
- success = root[0].tag.endswith("authenticationSuccess")
- for child in root[0]:
- if child.tag.endswith("user"):
- user = child.text
- if child.tag.endswith("attributes"):
- for attribute in child:
- # ElementTree library expands the namespace in
- # attribute tags to the full URL of the namespace.
- # We don't care about namespace here and it will always
- # be encased in curly braces, so we remove them.
- tag = attribute.tag
- if "}" in tag:
- tag = tag.split("}")[1]
- attributes[tag] = attribute.text
- if user is None:
- raise Exception("CAS response does not contain user")
- except Exception:
- logger.exception("Error parsing CAS response")
- raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
- if not success:
- raise LoginError(
- 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
- )
- return user, attributes
+ for child in root[0]:
+ if child.tag.endswith("user"):
+ user = child.text
+ if child.tag.endswith("attributes"):
+ for attribute in child:
+ # ElementTree library expands the namespace in
+ # attribute tags to the full URL of the namespace.
+ # We don't care about namespace here and it will always
+ # be encased in curly braces, so we remove them.
+ tag = attribute.tag
+ if "}" in tag:
+ tag = tag.split("}")[1]
+ attributes[tag] = attribute.text
+
+ # Ensure a user was found.
+ if user is None:
+ raise CasError("no_user", "CAS response does not contain user")
+
+ return CasResponse(user, attributes)
def get_redirect_url(self, service_args: Dict[str, str]) -> str:
"""
@@ -201,59 +235,150 @@ class CasHandler:
args["redirectUrl"] = client_redirect_url
if session:
args["session"] = session
- username, user_display_name = await self._validate_ticket(ticket, args)
- # Pull out the user-agent and IP from the request.
- user_agent = request.get_user_agent("")
- ip_address = self.hs.get_ip_from_request(request)
-
- # Get the matrix ID from the CAS username.
- user_id = await self._map_cas_user_to_matrix_user(
- username, user_display_name, user_agent, ip_address
+ try:
+ cas_response = await self._validate_ticket(ticket, args)
+ except CasError as e:
+ logger.exception("Could not validate ticket")
+ self._sso_handler.render_error(request, e.error, e.error_description, 401)
+ return
+
+ await self._handle_cas_response(
+ request, cas_response, client_redirect_url, session
)
+ async def _handle_cas_response(
+ self,
+ request: SynapseRequest,
+ cas_response: CasResponse,
+ client_redirect_url: Optional[str],
+ session: Optional[str],
+ ) -> None:
+ """Handle a CAS response to a ticket request.
+
+ Assumes that the response has been validated. Maps the user onto an MXID,
+ registering them if necessary, and returns a response to the browser.
+
+ Args:
+ request: the incoming request from the browser. We'll respond to it with an
+ HTML page or a redirect
+
+ cas_response: The parsed CAS response.
+
+ client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given.
+ This should be the same as the redirectUrl from the original `/login/sso/redirect` request.
+
+ session: The session parameter from the `/cas/ticket` HTTP request, if given.
+ This should be the UI Auth session id.
+ """
+
+ # first check if we're doing a UIA
if session:
- await self._auth_handler.complete_sso_ui_auth(
- user_id, session, request,
+ return await self._sso_handler.complete_sso_ui_auth_request(
+ self._auth_provider_id, cas_response.username, session, request,
)
- else:
- # If this not a UI auth request than there must be a redirect URL.
- assert client_redirect_url
- await self._auth_handler.complete_sso_login(
- user_id, request, client_redirect_url
- )
+ # otherwise, we're handling a login request.
+
+ # Ensure that the attributes of the logged in user meet the required
+ # attributes.
+ for required_attribute, required_value in self._cas_required_attributes.items():
+ # If required attribute was not in CAS Response - Forbidden
+ if required_attribute not in cas_response.attributes:
+ self._sso_handler.render_error(
+ request,
+ "unauthorised",
+ "You are not authorised to log in here.",
+ 401,
+ )
+ return
+
+ # Also need to check value
+ if required_value is not None:
+ actual_value = cas_response.attributes[required_attribute]
+ # If required attribute value does not match expected - Forbidden
+ if required_value != actual_value:
+ self._sso_handler.render_error(
+ request,
+ "unauthorised",
+ "You are not authorised to log in here.",
+ 401,
+ )
+ return
+
+ # Call the mapper to register/login the user
+
+ # If this is not a UI auth request then there must be a redirect URL.
+ assert client_redirect_url is not None
+
+ try:
+ await self._complete_cas_login(cas_response, request, client_redirect_url)
+ except MappingException as e:
+ logger.exception("Could not map user")
+ self._sso_handler.render_error(request, "mapping_error", str(e))
- async def _map_cas_user_to_matrix_user(
+ async def _complete_cas_login(
self,
- remote_user_id: str,
- display_name: Optional[str],
- user_agent: str,
- ip_address: str,
- ) -> str:
+ cas_response: CasResponse,
+ request: SynapseRequest,
+ client_redirect_url: str,
+ ) -> None:
"""
- Given a CAS username, retrieve the user ID for it and possibly register the user.
+ Given a CAS response, complete the login flow
- Args:
- remote_user_id: The username from the CAS response.
- display_name: The display name from the CAS response.
- user_agent: The user agent of the client making the request.
- ip_address: The IP address of the client making the request.
+ Retrieves the remote user ID, registers the user if necessary, and serves
+ a redirect back to the client with a login-token.
- Returns:
- The user ID associated with this response.
+ Args:
+ cas_response: The parsed CAS response.
+ request: The request to respond to
+ client_redirect_url: The redirect URL passed in by the client.
+
+ Raises:
+ MappingException if there was a problem mapping the response to a user.
+ RedirectException: some mapping providers may raise this if they need
+ to redirect to an interstitial page.
"""
+ # Note that CAS does not support a mapping provider, so the logic is hard-coded.
+ localpart = map_username_to_mxid_localpart(cas_response.username)
+
+ async def cas_response_to_user_attributes(failures: int) -> UserAttributes:
+ """
+ Map from CAS attributes to user attributes.
+ """
+ # Due to the grandfathering logic matching any previously registered
+ # mxids, no failures are expected here.
+ if failures:
+ raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
+
+ display_name = cas_response.attributes.get(
+ self._cas_displayname_attribute, None
+ )
- localpart = map_username_to_mxid_localpart(remote_user_id)
- user_id = UserID(localpart, self._hostname).to_string()
- registered_user_id = await self._auth_handler.check_user_exists(user_id)
+ return UserAttributes(localpart=localpart, display_name=display_name)
- # If the user does not exist, register it.
- if not registered_user_id:
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=display_name,
- user_agent_ips=[(user_agent, ip_address)],
+ async def grandfather_existing_users() -> Optional[str]:
+ # Since CAS did not always use the user_external_ids table, always
+ # attempt to map to existing users.
+ user_id = UserID(localpart, self._hostname).to_string()
+
+ logger.debug(
+ "Looking for existing account based on mapped %s", user_id,
)
- return registered_user_id
+ users = await self._store.get_users_by_id_case_insensitive(user_id)
+ if users:
+ registered_user_id = list(users.keys())[0]
+ logger.info("Grandfathering mapping to %s", registered_user_id)
+ return registered_user_id
+
+ return None
+
+ await self._sso_handler.complete_sso_login_request(
+ self._auth_provider_id,
+ cas_response.username,
+ request,
+ client_redirect_url,
+ cas_response_to_user_attributes,
+ grandfather_existing_users,
+ )
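The rewritten parser strips XML namespaces by hand: ElementTree expands a namespaced tag to "{uri}localname", so splitting on the closing brace recovers the bare name. A standalone illustration with a made-up response body:

from xml.etree import ElementTree as ET

body = b"""<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">
  <cas:authenticationSuccess>
    <cas:user>alice</cas:user>
    <cas:attributes><cas:mail>alice@example.com</cas:mail></cas:attributes>
  </cas:authenticationSuccess>
</cas:serviceResponse>"""

root = ET.fromstring(body)
assert root.tag.endswith("serviceResponse")
assert root[0].tag.endswith("authenticationSuccess")

for child in root[0]:
    if child.tag.endswith("user"):
        print("user:", child.text)           # -> user: alice
    if child.tag.endswith("attributes"):
        for attribute in child:
            tag = attribute.tag
            if "}" in tag:                    # drop the "{namespace}" prefix
                tag = tag.split("}")[1]
            print(tag, "=", attribute.text)   # -> mail = alice@example.com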
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index ad5683d2..abcf8635 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -133,7 +133,9 @@ class DirectoryHandler(BaseHandler):
403, "You must be in the room to create an alias for it"
)
- if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
+ if not await self.spam_checker.user_may_create_room_alias(
+ user_id, room_alias
+ ):
raise AuthError(403, "This user is not permitted to create this alias")
if not self.config.is_alias_creation_allowed(
@@ -409,7 +411,7 @@ class DirectoryHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- if not self.spam_checker.user_may_publish_room(user_id, room_id):
+ if not await self.spam_checker.user_may_publish_room(user_id, room_id):
raise AuthError(
403, "This user is not permitted to publish rooms to the room list"
)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index b9799090..fd8de869 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -140,7 +140,7 @@ class FederationHandler(BaseHandler):
self._message_handler = hs.get_message_handler()
self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config
- self.http_client = hs.get_simple_http_client()
+ self.http_client = hs.get_proxied_blacklisted_http_client()
self._instance_name = hs.get_instance_name()
self._replication = hs.get_replication_data_handler()
@@ -1593,7 +1593,7 @@ class FederationHandler(BaseHandler):
if self.hs.config.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites")
- if not self.spam_checker.user_may_invite(
+ if not await self.spam_checker.user_may_invite(
event.sender, event.state_key, event.room_id
):
raise SynapseError(
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index abd8d2af..df29edeb 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -29,7 +29,7 @@ def _create_rerouter(func_name):
async def f(self, group_id, *args, **kwargs):
if not GroupID.is_valid(group_id):
- raise SynapseError(400, "%s was not legal group ID" % (group_id,))
+ raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
if self.is_mine_id(group_id):
return await getattr(self.groups_server_handler, func_name)(
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 9b3c6b45..c05036ad 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -46,15 +46,17 @@ class IdentityHandler(BaseHandler):
def __init__(self, hs):
super().__init__(hs)
+ # An HTTP client for contacting trusted URLs.
self.http_client = SimpleHttpClient(hs)
- # We create a blacklisting instance of SimpleHttpClient for contacting identity
- # servers specified by clients
+ # An HTTP client for contacting identity servers specified by clients.
self.blacklisting_http_client = SimpleHttpClient(
hs, ip_blacklist=hs.config.federation_ip_range_blacklist
)
- self.federation_http_client = hs.get_http_client()
+ self.federation_http_client = hs.get_federation_http_client()
self.hs = hs
+ self._web_client_location = hs.config.invite_client_location
+
async def threepid_from_creds(
self, id_server: str, creds: Dict[str, str]
) -> Optional[JsonDict]:
@@ -803,6 +805,9 @@ class IdentityHandler(BaseHandler):
"sender_display_name": inviter_display_name,
"sender_avatar_url": inviter_avatar_url,
}
+ # If a custom web client location is available, include it in the request.
+ if self._web_client_location:
+ invite_config["org.matrix.web_client_location"] = self._web_client_location
# Add the identity service access token to the JSON body and use the v2
# Identity Service endpoints if id_access_token is present
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index cb11754b..fbd8df9d 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -323,9 +323,7 @@ class InitialSyncHandler(BaseHandler):
member_event_id: str,
is_peeking: bool,
) -> JsonDict:
- room_state = await self.state_store.get_state_for_events([member_event_id])
-
- room_state = room_state[member_event_id]
+ room_state = await self.state_store.get_state_for_event(member_event_id)
limit = pagin_config.limit if pagin_config else None
if limit is None:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 96843338..9dfeab09 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -744,7 +744,7 @@ class EventCreationHandler:
event.sender,
)
- spam_error = self.spam_checker.check_event_for_spam(event)
+ spam_error = await self.spam_checker.check_event_for_spam(event)
if spam_error:
if not isinstance(spam_error, str):
spam_error = "Spam is not permitted here"
@@ -1261,7 +1261,7 @@ class EventCreationHandler:
event, context = await self.create_event(
requester,
{
- "type": "org.matrix.dummy_event",
+ "type": EventTypes.Dummy,
"content": {},
"room_id": room_id,
"sender": user_id,
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index c605f708..709f8dfc 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -115,8 +115,6 @@ class OidcHandler(BaseHandler):
self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool
self._http_client = hs.get_proxied_http_client()
- self._auth_handler = hs.get_auth_handler()
- self._registration_handler = hs.get_registration_handler()
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key
@@ -674,38 +672,29 @@ class OidcHandler(BaseHandler):
self._sso_handler.render_error(request, "invalid_token", str(e))
return
- # Pull out the user-agent and IP from the request.
- user_agent = request.get_user_agent("")
- ip_address = self.hs.get_ip_from_request(request)
+ # first check if we're doing a UIA
+ if ui_auth_session_id:
+ try:
+ remote_user_id = self._remote_id_from_userinfo(userinfo)
+ except Exception as e:
+ logger.exception("Could not extract remote user id")
+ self._sso_handler.render_error(request, "mapping_error", str(e))
+ return
+
+ return await self._sso_handler.complete_sso_ui_auth_request(
+ self._auth_provider_id, remote_user_id, ui_auth_session_id, request
+ )
+
+ # otherwise, it's a login
# Call the mapper to register/login the user
try:
- user_id = await self._map_userinfo_to_user(
- userinfo, token, user_agent, ip_address
+ await self._complete_oidc_login(
+ userinfo, token, request, client_redirect_url
)
except MappingException as e:
logger.exception("Could not map user")
self._sso_handler.render_error(request, "mapping_error", str(e))
- return
-
- # Mapping providers might not have get_extra_attributes: only call this
- # method if it exists.
- extra_attributes = None
- get_extra_attributes = getattr(
- self._user_mapping_provider, "get_extra_attributes", None
- )
- if get_extra_attributes:
- extra_attributes = await get_extra_attributes(userinfo, token)
-
- # and finally complete the login
- if ui_auth_session_id:
- await self._auth_handler.complete_sso_ui_auth(
- user_id, ui_auth_session_id, request
- )
- else:
- await self._auth_handler.complete_sso_login(
- user_id, request, client_redirect_url, extra_attributes
- )
def _generate_oidc_session_token(
self,
@@ -828,10 +817,14 @@ class OidcHandler(BaseHandler):
now = self.clock.time_msec()
return now < expiry
- async def _map_userinfo_to_user(
- self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str
- ) -> str:
- """Maps a UserInfo object to a mxid.
+ async def _complete_oidc_login(
+ self,
+ userinfo: UserInfo,
+ token: Token,
+ request: SynapseRequest,
+ client_redirect_url: str,
+ ) -> None:
+ """Given a UserInfo response, complete the login flow
UserInfo should have a claim that uniquely identifies users. This claim
is usually `sub`, but can be configured with `oidc_config.subject_claim`.
@@ -843,27 +836,23 @@ class OidcHandler(BaseHandler):
If a user already exists with the mxid we've mapped and allow_existing_users
is disabled, raise an exception.
+ Otherwise, render a redirect back to the client_redirect_url with a loginToken.
+
Args:
userinfo: an object representing the user
token: a dict with the tokens obtained from the provider
- user_agent: The user agent of the client making the request.
- ip_address: The IP address of the client making the request.
+ request: The request to respond to
+ client_redirect_url: The redirect URL passed in by the client.
Raises:
MappingException: if there was an error while mapping some properties
-
- Returns:
- The mxid of the user
"""
try:
- remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
+ remote_user_id = self._remote_id_from_userinfo(userinfo)
except Exception as e:
raise MappingException(
"Failed to extract subject from OIDC response: %s" % (e,)
)
- # Some OIDC providers use integer IDs, but Synapse expects external IDs
- # to be strings.
- remote_user_id = str(remote_user_id)
# Older mapping providers don't accept the `failures` argument, so we
# try and detect support.
@@ -924,18 +913,41 @@ class OidcHandler(BaseHandler):
return None
- return await self._sso_handler.get_mxid_from_sso(
+ # Mapping providers might not have get_extra_attributes: only call this
+ # method if it exists.
+ extra_attributes = None
+ get_extra_attributes = getattr(
+ self._user_mapping_provider, "get_extra_attributes", None
+ )
+ if get_extra_attributes:
+ extra_attributes = await get_extra_attributes(userinfo, token)
+
+ await self._sso_handler.complete_sso_login_request(
self._auth_provider_id,
remote_user_id,
- user_agent,
- ip_address,
+ request,
+ client_redirect_url,
oidc_response_to_user_attributes,
grandfather_existing_users,
+ extra_attributes,
)
+ def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
+ """Extract the unique remote id from an OIDC UserInfo block
+
+ Args:
+ userinfo: An object representing the user given by the OIDC provider
+ Returns:
+ remote user id
+ """
+ remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
+ # Some OIDC providers use integer IDs, but Synapse expects external IDs
+ # to be strings.
+ return str(remote_user_id)
+
UserAttributeDict = TypedDict(
- "UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
+ "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
)
C = TypeVar("C")
@@ -1016,10 +1028,10 @@ env = Environment(finalize=jinja_finalize)
@attr.s
class JinjaOidcMappingConfig:
- subject_claim = attr.ib() # type: str
- localpart_template = attr.ib() # type: Template
- display_name_template = attr.ib() # type: Optional[Template]
- extra_attributes = attr.ib() # type: Dict[str, Template]
+ subject_claim = attr.ib(type=str)
+ localpart_template = attr.ib(type=Optional[Template])
+ display_name_template = attr.ib(type=Optional[Template])
+ extra_attributes = attr.ib(type=Dict[str, Template])
class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@@ -1035,18 +1047,14 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
def parse_config(config: dict) -> JinjaOidcMappingConfig:
subject_claim = config.get("subject_claim", "sub")
- if "localpart_template" not in config:
- raise ConfigError(
- "missing key: oidc_config.user_mapping_provider.config.localpart_template"
- )
-
- try:
- localpart_template = env.from_string(config["localpart_template"])
- except Exception as e:
- raise ConfigError(
- "invalid jinja template for oidc_config.user_mapping_provider.config.localpart_template: %r"
- % (e,)
- )
+ localpart_template = None # type: Optional[Template]
+ if "localpart_template" in config:
+ try:
+ localpart_template = env.from_string(config["localpart_template"])
+ except Exception as e:
+ raise ConfigError(
+ "invalid jinja template", path=["localpart_template"]
+ ) from e
display_name_template = None # type: Optional[Template]
if "display_name_template" in config:
@@ -1054,26 +1062,22 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
display_name_template = env.from_string(config["display_name_template"])
except Exception as e:
raise ConfigError(
- "invalid jinja template for oidc_config.user_mapping_provider.config.display_name_template: %r"
- % (e,)
- )
+ "invalid jinja template", path=["display_name_template"]
+ ) from e
extra_attributes = {} # type: Dict[str, Template]
if "extra_attributes" in config:
extra_attributes_config = config.get("extra_attributes") or {}
if not isinstance(extra_attributes_config, dict):
- raise ConfigError(
- "oidc_config.user_mapping_provider.config.extra_attributes must be a dict"
- )
+ raise ConfigError("must be a dict", path=["extra_attributes"])
for key, value in extra_attributes_config.items():
try:
extra_attributes[key] = env.from_string(value)
except Exception as e:
raise ConfigError(
- "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r"
- % (key, e)
- )
+ "invalid jinja template", path=["extra_attributes", key]
+ ) from e
return JinjaOidcMappingConfig(
subject_claim=subject_claim,
@@ -1088,14 +1092,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
async def map_user_attributes(
self, userinfo: UserInfo, token: Token, failures: int
) -> UserAttributeDict:
- localpart = self._config.localpart_template.render(user=userinfo).strip()
+ localpart = None
+
+ if self._config.localpart_template:
+ localpart = self._config.localpart_template.render(user=userinfo).strip()
- # Ensure only valid characters are included in the MXID.
- localpart = map_username_to_mxid_localpart(localpart)
+ # Ensure only valid characters are included in the MXID.
+ localpart = map_username_to_mxid_localpart(localpart)
- # Append suffix integer if last call to this function failed to produce
- # a usable mxid.
- localpart += str(failures) if failures else ""
+ # Append suffix integer if last call to this function failed to produce
+ # a usable mxid.
+ localpart += str(failures) if failures else ""
display_name = None # type: Optional[str]
if self._config.display_name_template is not None:
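With localpart_template now optional, map_user_attributes only derives a localpart when one is configured, appending the failures counter so a retry does not collide with an existing mxid. The rendering step in isolation, as a sketch ("preferred_username" is an invented claim name):

from jinja2 import Environment

env = Environment()
localpart_template = env.from_string("{{ user.preferred_username }}")

userinfo = {"preferred_username": "alice"}
failures = 1  # a previous attempt collided with an existing mxid

localpart = localpart_template.render(user=userinfo).strip()
# (the real code also passes this through map_username_to_mxid_localpart)
localpart += str(failures) if failures else ""
print(localpart)  # -> alice1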
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 153cbae7..a9abdf42 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -13,18 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.appservice import ApplicationService
from synapse.handlers._base import BaseHandler
from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
-from synapse.util.async_helpers import maybe_awaitable
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
class ReceiptsHandler(BaseHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.server_name = hs.config.server_name
@@ -37,7 +39,7 @@ class ReceiptsHandler(BaseHandler):
self.clock = self.hs.get_clock()
self.state = hs.get_state_handler()
- async def _received_remote_receipt(self, origin, content):
+ async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None:
"""Called when we receive an EDU of type m.receipt from a remote HS.
"""
receipts = []
@@ -64,11 +66,11 @@ class ReceiptsHandler(BaseHandler):
await self._handle_new_receipts(receipts)
- async def _handle_new_receipts(self, receipts):
+ async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
"""Takes a list of receipts, stores them and informs the notifier.
"""
- min_batch_id = None
- max_batch_id = None
+ min_batch_id = None # type: Optional[int]
+ max_batch_id = None # type: Optional[int]
for receipt in receipts:
res = await self.store.insert_receipt(
@@ -90,7 +92,8 @@ class ReceiptsHandler(BaseHandler):
if max_batch_id is None or max_persisted_id > max_batch_id:
max_batch_id = max_persisted_id
- if min_batch_id is None:
+ # Either both of these should be None or neither.
+ if min_batch_id is None or max_batch_id is None:
# no new receipts
return False
@@ -98,15 +101,15 @@ class ReceiptsHandler(BaseHandler):
self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
# Note that the min here shouldn't be relied upon to be accurate.
- await maybe_awaitable(
- self.hs.get_pusherpool().on_new_receipts(
- min_batch_id, max_batch_id, affected_room_ids
- )
+ await self.hs.get_pusherpool().on_new_receipts(
+ min_batch_id, max_batch_id, affected_room_ids
)
return True
- async def received_client_receipt(self, room_id, receipt_type, user_id, event_id):
+ async def received_client_receipt(
+ self, room_id: str, receipt_type: str, user_id: str, event_id: str
+ ) -> None:
"""Called when a client tells us a local user has read up to the given
event_id in the room.
"""
@@ -126,10 +129,12 @@ class ReceiptsHandler(BaseHandler):
class ReceiptEventSource:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
- async def get_new_events(self, from_key, room_ids, **kwargs):
+ async def get_new_events(
+ self, from_key: int, room_ids: List[str], **kwargs
+ ) -> Tuple[List[JsonDict], int]:
from_key = int(from_key)
to_key = self.get_current_key()
@@ -174,5 +179,5 @@ class ReceiptEventSource:
return (events, to_key)
- def get_current_key(self, direction="f"):
+ def get_current_key(self, direction: str = "f") -> int:
return self.store.get_max_receipt_stream_id()
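
The "both None or neither" invariant in _handle_new_receipts can be seen in isolation. A small sketch of the same accumulation pattern, with plain integers standing in for persisted receipt stream IDs:

from typing import List, Optional, Tuple

def batch_bounds(stream_ids: List[int]) -> Optional[Tuple[int, int]]:
    min_id = None  # type: Optional[int]
    max_id = None  # type: Optional[int]
    for stream_id in stream_ids:
        if min_id is None or stream_id < min_id:
            min_id = stream_id
        if max_id is None or stream_id > max_id:
            max_id = stream_id
    # Either both bounds are None (no receipts were persisted) or neither is.
    if min_id is None or max_id is None:
        return None
    return (min_id, max_id)

assert batch_bounds([]) is None
assert batch_bounds([7, 3, 9]) == (3, 9)
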
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 0d85fd08..a2cf0f6f 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -187,7 +187,7 @@ class RegistrationHandler(BaseHandler):
"""
self.check_registration_ratelimit(address)
- result = self.spam_checker.check_registration_for_spam(
+ result = await self.spam_checker.check_registration_for_spam(
threepid, localpart, user_agent_ips or [],
)
@@ -630,6 +630,7 @@ class RegistrationHandler(BaseHandler):
device_id: Optional[str],
initial_display_name: Optional[str],
is_guest: bool = False,
+ is_appservice_ghost: bool = False,
) -> Tuple[str, str]:
"""Register a device for a user and generate an access token.
@@ -651,6 +652,7 @@ class RegistrationHandler(BaseHandler):
device_id=device_id,
initial_display_name=initial_display_name,
is_guest=is_guest,
+ is_appservice_ghost=is_appservice_ghost,
)
return r["device_id"], r["access_token"]
@@ -672,7 +674,10 @@ class RegistrationHandler(BaseHandler):
)
else:
access_token = await self._auth_handler.get_access_token_for_user_id(
- user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms
+ user_id,
+ device_id=registered_device_id,
+ valid_until_ms=valid_until_ms,
+ is_appservice_ghost=is_appservice_ghost,
)
return (registered_device_id, access_token)
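
check_registration_for_spam is now awaited, so spam-checker callbacks may be coroutines. A hypothetical module sketch of the async form; the class name and blocking rule are invented, and the string return values are assumed to mirror the allow/shadow_ban/deny registration behaviours:

from typing import List, Optional, Tuple

class ExampleSpamChecker:
    # Hypothetical async spam-checker callback; real modules are loaded
    # from the spam_checker config and may also be plain (sync) functions.
    async def check_registration_for_spam(
        self,
        email_threepid: Optional[dict],
        username: Optional[str],
        request_info: List[Tuple[str, str]],  # (user_agent, ip) pairs
    ) -> str:
        # Deny registrations from an obviously scripted username.
        if username and username.startswith("spambot"):
            return "deny"
        return "allow"
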
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 930047e7..1f809fa1 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -27,6 +27,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
from synapse.api.constants import (
EventTypes,
+ HistoryVisibility,
JoinRules,
Membership,
RoomCreationPreset,
@@ -81,21 +82,21 @@ class RoomCreationHandler(BaseHandler):
self._presets_dict = {
RoomCreationPreset.PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE,
- "history_visibility": "shared",
+ "history_visibility": HistoryVisibility.SHARED,
"original_invitees_have_ops": False,
"guest_can_join": True,
"power_level_content_override": {"invite": 0},
},
RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE,
- "history_visibility": "shared",
+ "history_visibility": HistoryVisibility.SHARED,
"original_invitees_have_ops": True,
"guest_can_join": True,
"power_level_content_override": {"invite": 0},
},
RoomCreationPreset.PUBLIC_CHAT: {
"join_rules": JoinRules.PUBLIC,
- "history_visibility": "shared",
+ "history_visibility": HistoryVisibility.SHARED,
"original_invitees_have_ops": False,
"guest_can_join": False,
"power_level_content_override": {},
@@ -358,7 +359,7 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- if not self.spam_checker.user_may_create_room(user_id):
+ if not await self.spam_checker.user_may_create_room(user_id):
raise SynapseError(403, "You are not permitted to create rooms")
creation_content = {
@@ -440,6 +441,7 @@ class RoomCreationHandler(BaseHandler):
invite_list=[],
initial_state=initial_state,
creation_content=creation_content,
+ ratelimit=False,
)
# Transfer membership events
@@ -608,7 +610,7 @@ class RoomCreationHandler(BaseHandler):
403, "You are not permitted to create rooms", Codes.FORBIDDEN
)
- if not is_requester_admin and not self.spam_checker.user_may_create_room(
+ if not is_requester_admin and not await self.spam_checker.user_may_create_room(
user_id
):
raise SynapseError(403, "You are not permitted to create rooms")
@@ -735,6 +737,7 @@ class RoomCreationHandler(BaseHandler):
room_alias=room_alias,
power_level_content_override=power_level_content_override,
creator_join_profile=creator_join_profile,
+ ratelimit=ratelimit,
)
if "name" in config:
@@ -838,6 +841,7 @@ class RoomCreationHandler(BaseHandler):
room_alias: Optional[RoomAlias] = None,
power_level_content_override: Optional[JsonDict] = None,
creator_join_profile: Optional[JsonDict] = None,
+ ratelimit: bool = True,
) -> int:
"""Sends the initial events into a new room.
@@ -884,7 +888,7 @@ class RoomCreationHandler(BaseHandler):
creator.user,
room_id,
"join",
- ratelimit=False,
+ ratelimit=ratelimit,
content=creator_join_profile,
)
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 4a13c8e9..14f14db4 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -15,19 +15,22 @@
import logging
from collections import namedtuple
-from typing import Any, Dict, Optional
+from typing import TYPE_CHECKING, Optional, Tuple
import msgpack
from unpaddedbase64 import decode_base64, encode_base64
-from synapse.api.constants import EventTypes, JoinRules
+from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
from synapse.api.errors import Codes, HttpResponseException
-from synapse.types import ThirdPartyInstanceID
+from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util.caches.descriptors import cached
from synapse.util.caches.response_cache import ResponseCache
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
@@ -37,37 +40,38 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
class RoomListHandler(BaseHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.enable_room_list_search = hs.config.enable_room_list_search
- self.response_cache = ResponseCache(hs, "room_list")
+ self.response_cache = ResponseCache(
+ hs, "room_list"
+ ) # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]]
self.remote_response_cache = ResponseCache(
hs, "remote_room_list", timeout_ms=30 * 1000
- )
+ ) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
async def get_local_public_room_list(
self,
- limit=None,
- since_token=None,
- search_filter=None,
- network_tuple=EMPTY_THIRD_PARTY_ID,
- from_federation=False,
- ):
+ limit: Optional[int] = None,
+ since_token: Optional[str] = None,
+ search_filter: Optional[dict] = None,
+ network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID,
+ from_federation: bool = False,
+ ) -> JsonDict:
"""Generate a local public room list.
There are multiple different lists: the main one plus one per third
party network. A client can ask for a specific list or to return all.
Args:
- limit (int|None)
- since_token (str|None)
- search_filter (dict|None)
- network_tuple (ThirdPartyInstanceID): Which public list to use.
+ limit
+ since_token
+ search_filter
+ network_tuple: Which public list to use.
This can be (None, None) to indicate the main list, or a particular
appservice and network id to use an appservice specific one.
Setting to None returns all public rooms across all lists.
- from_federation (bool): true iff the request comes from the federation
- API
+ from_federation: true iff the request comes from the federation API
"""
if not self.enable_room_list_search:
return {"chunk": [], "total_room_count_estimate": 0}
@@ -107,10 +111,10 @@ class RoomListHandler(BaseHandler):
self,
limit: Optional[int] = None,
since_token: Optional[str] = None,
- search_filter: Optional[Dict] = None,
+ search_filter: Optional[dict] = None,
network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID,
from_federation: bool = False,
- ) -> Dict[str, Any]:
+ ) -> JsonDict:
"""Generate a public room list.
Args:
limit: Maximum amount of rooms to return.
@@ -131,13 +135,17 @@ class RoomListHandler(BaseHandler):
if since_token:
batch_token = RoomListNextBatch.from_token(since_token)
- bounds = (batch_token.last_joined_members, batch_token.last_room_id)
+ bounds = (
+ batch_token.last_joined_members,
+ batch_token.last_room_id,
+ ) # type: Optional[Tuple[int, str]]
forwards = batch_token.direction_is_forward
+ has_batch_token = True
else:
- batch_token = None
bounds = None
forwards = True
+ has_batch_token = False
# we request one more than wanted to see if there are more pages to come
probing_limit = limit + 1 if limit is not None else None
@@ -159,7 +167,8 @@ class RoomListHandler(BaseHandler):
"canonical_alias": room["canonical_alias"],
"num_joined_members": room["joined_members"],
"avatar_url": room["avatar"],
- "world_readable": room["history_visibility"] == "world_readable",
+ "world_readable": room["history_visibility"]
+ == HistoryVisibility.WORLD_READABLE,
"guest_can_join": room["guest_access"] == "can_join",
}
@@ -168,7 +177,7 @@ class RoomListHandler(BaseHandler):
results = [build_room_entry(r) for r in results]
- response = {}
+ response = {} # type: JsonDict
num_results = len(results)
if limit is not None:
more_to_come = num_results == probing_limit
@@ -186,7 +195,7 @@ class RoomListHandler(BaseHandler):
initial_entry = results[0]
if forwards:
- if batch_token:
+ if has_batch_token:
# If there was a token given then we assume that there
# must be previous results.
response["prev_batch"] = RoomListNextBatch(
@@ -202,7 +211,7 @@ class RoomListHandler(BaseHandler):
direction_is_forward=True,
).to_token()
else:
- if batch_token:
+ if has_batch_token:
response["next_batch"] = RoomListNextBatch(
last_joined_members=final_entry["num_joined_members"],
last_room_id=final_entry["room_id"],
@@ -292,7 +301,7 @@ class RoomListHandler(BaseHandler):
return None
# Return whether this room is open to federation users or not
- create_event = current_state.get((EventTypes.Create, ""))
+ create_event = current_state[EventTypes.Create, ""]
result["m.federate"] = create_event.content.get("m.federate", True)
name_event = current_state.get((EventTypes.Name, ""))
@@ -317,7 +326,7 @@ class RoomListHandler(BaseHandler):
visibility = None
if visibility_event:
visibility = visibility_event.content.get("history_visibility", None)
- result["world_readable"] = visibility == "world_readable"
+ result["world_readable"] = visibility == HistoryVisibility.WORLD_READABLE
guest_event = current_state.get((EventTypes.GuestAccess, ""))
guest = None
@@ -335,13 +344,13 @@ class RoomListHandler(BaseHandler):
async def get_remote_public_room_list(
self,
- server_name,
- limit=None,
- since_token=None,
- search_filter=None,
- include_all_networks=False,
- third_party_instance_id=None,
- ):
+ server_name: str,
+ limit: Optional[int] = None,
+ since_token: Optional[str] = None,
+ search_filter: Optional[dict] = None,
+ include_all_networks: bool = False,
+ third_party_instance_id: Optional[str] = None,
+ ) -> JsonDict:
if not self.enable_room_list_search:
return {"chunk": [], "total_room_count_estimate": 0}
@@ -398,13 +407,13 @@ class RoomListHandler(BaseHandler):
async def _get_remote_list_cached(
self,
- server_name,
- limit=None,
- since_token=None,
- search_filter=None,
- include_all_networks=False,
- third_party_instance_id=None,
- ):
+ server_name: str,
+ limit: Optional[int] = None,
+ since_token: Optional[str] = None,
+ search_filter: Optional[dict] = None,
+ include_all_networks: bool = False,
+ third_party_instance_id: Optional[str] = None,
+ ) -> JsonDict:
repl_layer = self.hs.get_federation_client()
if search_filter:
# We can't cache when asking for search
@@ -455,24 +464,24 @@ class RoomListNextBatch(
REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}
@classmethod
- def from_token(cls, token):
+ def from_token(cls, token: str) -> "RoomListNextBatch":
decoded = msgpack.loads(decode_base64(token), raw=False)
return RoomListNextBatch(
**{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()}
)
- def to_token(self):
+ def to_token(self) -> str:
return encode_base64(
msgpack.dumps(
{self.KEY_DICT[key]: val for key, val in self._asdict().items()}
)
)
- def copy_and_replace(self, **kwds):
+ def copy_and_replace(self, **kwds) -> "RoomListNextBatch":
return self._replace(**kwds)
-def _matches_room_entry(room_entry, search_filter):
+def _matches_room_entry(room_entry: JsonDict, search_filter: dict) -> bool:
if search_filter and search_filter.get("generic_search_term", None):
generic_search_term = search_filter["generic_search_term"].upper()
if generic_search_term in room_entry.get("name", "").upper():
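
RoomListNextBatch's pagination tokens are msgpack maps in unpadded base64. A standalone sketch of the same round trip, requiring the msgpack and unpaddedbase64 packages; the short wire keys below are illustrative, not necessarily the ones in KEY_DICT:

import msgpack
from unpaddedbase64 import decode_base64, encode_base64

KEY_DICT = {
    "last_joined_members": "m",
    "last_room_id": "r",
    "direction_is_forward": "d",
}
REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}

def to_token(batch: dict) -> str:
    # Shorten the field names before packing to keep the token compact.
    return encode_base64(
        msgpack.dumps({KEY_DICT[k]: v for k, v in batch.items()})
    )

def from_token(token: str) -> dict:
    decoded = msgpack.loads(decode_base64(token), raw=False)
    return {REVERSE_KEY_DICT[k]: v for k, v in decoded.items()}

batch = {
    "last_joined_members": 42,
    "last_room_id": "!abc:example.com",
    "direction_is_forward": True,
}
assert from_token(to_token(batch)) == batch
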
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index c0028863..cb5a29bc 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -203,7 +203,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Only rate-limit if the user actually joined the room, otherwise we'll end
# up blocking profile updates.
- if newly_joined:
+ if newly_joined and ratelimit:
time_now_s = self.clock.time()
(
allowed,
@@ -408,7 +408,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
)
block_invite = True
- if not self.spam_checker.user_may_invite(
+ if not await self.spam_checker.user_may_invite(
requester.user.to_string(), target.to_string(), room_id
):
logger.info("Blocking invite due to spam checker")
@@ -488,17 +488,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise AuthError(403, "Guest access not allowed")
if not is_host_in_room:
- time_now_s = self.clock.time()
- (
- allowed,
- time_allowed,
- ) = self._join_rate_limiter_remote.can_requester_do_action(requester,)
-
- if not allowed:
- raise LimitExceededError(
- retry_after_ms=int(1000 * (time_allowed - time_now_s))
+ if ratelimit:
+ time_now_s = self.clock.time()
+ (
+ allowed,
+ time_allowed,
+ ) = self._join_rate_limiter_remote.can_requester_do_action(
+ requester,
)
+ if not allowed:
+ raise LimitExceededError(
+ retry_after_ms=int(1000 * (time_allowed - time_now_s))
+ )
+
inviter = await self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
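
The remote-join limiter is now only consulted when ratelimit is set. The check-and-raise pattern it wraps can be sketched with a toy sliding-window limiter; this is a simplification of can_requester_do_action, and the class here is invented for illustration:

import time

class ToyRatelimiter:
    """Allow at most `burst` actions per `period` seconds, per key."""

    def __init__(self, burst: int, period: float):
        self.burst = burst
        self.period = period
        self.actions = {}  # key -> timestamps of recent actions

    def can_do_action(self, key: str, now: float):
        recent = [t for t in self.actions.get(key, ()) if t > now - self.period]
        if len(recent) >= self.burst:
            # Earliest moment at which the oldest action leaves the window.
            return False, min(recent) + self.period
        recent.append(now)
        self.actions[key] = recent
        return True, now

limiter = ToyRatelimiter(burst=3, period=10.0)
for _ in range(3):
    allowed, _ = limiter.can_do_action("@alice:example.com", time.time())
    assert allowed
allowed, time_allowed = limiter.can_do_action("@alice:example.com", time.time())
assert not allowed  # the handler would raise LimitExceededError here
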
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 76d4169f..5fa7ab3f 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -34,7 +34,6 @@ from synapse.types import (
map_username_to_mxid_localpart,
mxid_localpart_allowed_characters,
)
-from synapse.util.async_helpers import Linearizer
from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
@@ -59,8 +58,6 @@ class SamlHandler(BaseHandler):
super().__init__(hs)
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._saml_idp_entityid = hs.config.saml2_idp_entityid
- self._auth_handler = hs.get_auth_handler()
- self._registration_handler = hs.get_registration_handler()
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
self._grandfathered_mxid_source_attribute = (
@@ -81,9 +78,6 @@ class SamlHandler(BaseHandler):
# a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
- # a lock on the mappings
- self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
-
self._sso_handler = hs.get_sso_handler()
def handle_redirect_request(
@@ -167,6 +161,29 @@ class SamlHandler(BaseHandler):
return
logger.debug("SAML2 response: %s", saml2_auth.origxml)
+
+ await self._handle_authn_response(request, saml2_auth, relay_state)
+
+ async def _handle_authn_response(
+ self,
+ request: SynapseRequest,
+ saml2_auth: saml2.response.AuthnResponse,
+ relay_state: str,
+ ) -> None:
+ """Handle an AuthnResponse, having parsed it from the request params
+
+ Assumes that the signature on the response object has been checked. Maps
+ the user onto an MXID, registering them if necessary, and returns a response
+ to the browser.
+
+ Args:
+ request: the incoming request from the browser. We'll respond to it with an
+ HTML page or a redirect
+ saml2_auth: the parsed AuthnResponse object
+ relay_state: the RelayState query param, which encodes the URI to redirect
+ back to
+ """
+
for assertion in saml2_auth.assertions:
# kibana limits the length of a log field, whereas this is all rather
# useful, so split it up.
@@ -183,6 +200,24 @@ class SamlHandler(BaseHandler):
saml2_auth.in_response_to, None
)
+ # first check if we're doing a UIA
+ if current_session and current_session.ui_auth_session_id:
+ try:
+ remote_user_id = self._remote_id_from_saml_response(saml2_auth, None)
+ except MappingException as e:
+ logger.exception("Failed to extract remote user id from SAML response")
+ self._sso_handler.render_error(request, "mapping_error", str(e))
+ return
+
+ return await self._sso_handler.complete_sso_ui_auth_request(
+ self._auth_provider_id,
+ remote_user_id,
+ current_session.ui_auth_session_id,
+ request,
+ )
+
+ # otherwise, we're handling a login request.
+
# Ensure that the attributes of the logged in user meet the required
# attributes.
for requirement in self._saml2_attribute_requirements:
@@ -192,63 +227,39 @@ class SamlHandler(BaseHandler):
)
return
- # Pull out the user-agent and IP from the request.
- user_agent = request.get_user_agent("")
- ip_address = self.hs.get_ip_from_request(request)
-
# Call the mapper to register/login the user
try:
- user_id = await self._map_saml_response_to_user(
- saml2_auth, relay_state, user_agent, ip_address
- )
+ await self._complete_saml_login(saml2_auth, request, relay_state)
except MappingException as e:
logger.exception("Could not map user")
self._sso_handler.render_error(request, "mapping_error", str(e))
- return
- # Complete the interactive auth session or the login.
- if current_session and current_session.ui_auth_session_id:
- await self._auth_handler.complete_sso_ui_auth(
- user_id, current_session.ui_auth_session_id, request
- )
-
- else:
- await self._auth_handler.complete_sso_login(user_id, request, relay_state)
-
- async def _map_saml_response_to_user(
+ async def _complete_saml_login(
self,
saml2_auth: saml2.response.AuthnResponse,
+ request: SynapseRequest,
client_redirect_url: str,
- user_agent: str,
- ip_address: str,
- ) -> str:
+ ) -> None:
"""
- Given a SAML response, retrieve the user ID for it and possibly register the user.
+ Given a SAML response, complete the login flow
+
+ Retrieves the remote user ID, registers the user if necessary, and serves
+ a redirect back to the client with a login-token.
Args:
saml2_auth: The parsed SAML2 response.
+ request: The request to respond to
client_redirect_url: The redirect URL passed in by the client.
- user_agent: The user agent of the client making the request.
- ip_address: The IP address of the client making the request.
-
- Returns:
- The user ID associated with this response.
Raises:
MappingException if there was a problem mapping the response to a user.
RedirectException: some mapping providers may raise this if they need
to redirect to an interstitial page.
"""
-
- remote_user_id = self._user_mapping_provider.get_remote_user_id(
+ remote_user_id = self._remote_id_from_saml_response(
saml2_auth, client_redirect_url
)
- if not remote_user_id:
- raise MappingException(
- "Failed to extract remote user id from SAML response"
- )
-
async def saml_response_to_remapped_user_attributes(
failures: int,
) -> UserAttributes:
@@ -294,16 +305,44 @@ class SamlHandler(BaseHandler):
return None
- with (await self._mapping_lock.queue(self._auth_provider_id)):
- return await self._sso_handler.get_mxid_from_sso(
- self._auth_provider_id,
- remote_user_id,
- user_agent,
- ip_address,
- saml_response_to_remapped_user_attributes,
- grandfather_existing_users,
+ await self._sso_handler.complete_sso_login_request(
+ self._auth_provider_id,
+ remote_user_id,
+ request,
+ client_redirect_url,
+ saml_response_to_remapped_user_attributes,
+ grandfather_existing_users,
+ )
+
+ def _remote_id_from_saml_response(
+ self,
+ saml2_auth: saml2.response.AuthnResponse,
+ client_redirect_url: Optional[str],
+ ) -> str:
+ """Extract the unique remote id from a SAML2 AuthnResponse
+
+ Args:
+ saml2_auth: The parsed SAML2 response.
+ client_redirect_url: The redirect URL passed in by the client.
+ Returns:
+ remote user id
+
+ Raises:
+ MappingException if there was an error extracting the user id
+ """
+ # It's not obvious why we need to pass in the redirect URI to the mapping
+ # provider, but we do :/
+ remote_user_id = self._user_mapping_provider.get_remote_user_id(
+ saml2_auth, client_redirect_url
+ )
+
+ if not remote_user_id:
+ raise MappingException(
+ "Failed to extract remote user id from SAML response"
)
+ return remote_user_id
+
def expire_sessions(self):
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set()
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 47ad96f9..33cd6bc1 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -13,14 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
import attr
+from typing_extensions import NoReturn
-from synapse.api.errors import RedirectException
-from synapse.handlers._base import BaseHandler
+from twisted.web.http import Request
+
+from synapse.api.errors import RedirectException, SynapseError
from synapse.http.server import respond_with_html
-from synapse.types import UserID, contains_invalid_mxid_characters
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
+from synapse.util.async_helpers import Linearizer
+from synapse.util.stringutils import random_string
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -37,22 +42,70 @@ class MappingException(Exception):
@attr.s
class UserAttributes:
- localpart = attr.ib(type=str)
+ # the localpart of the mxid that the mapper has assigned to the user.
+ # if `None`, the mapper has not picked a userid, and the user should be prompted to
+ # enter one.
+ localpart = attr.ib(type=Optional[str])
display_name = attr.ib(type=Optional[str], default=None)
emails = attr.ib(type=List[str], default=attr.Factory(list))
-class SsoHandler(BaseHandler):
+@attr.s(slots=True)
+class UsernameMappingSession:
+ """Data we track about SSO sessions"""
+
+ # A unique identifier for this SSO provider, e.g. "oidc" or "saml".
+ auth_provider_id = attr.ib(type=str)
+
+ # user ID on the IdP server
+ remote_user_id = attr.ib(type=str)
+
+ # attributes returned by the ID mapper
+ display_name = attr.ib(type=Optional[str])
+ emails = attr.ib(type=List[str])
+
+ # An optional dictionary of extra attributes to be provided to the client in the
+ # login response.
+ extra_login_attributes = attr.ib(type=Optional[JsonDict])
+
+ # where to redirect the client back to
+ client_redirect_url = attr.ib(type=str)
+
+ # expiry time for the session, in milliseconds
+ expiry_time_ms = attr.ib(type=int)
+
+
+# the HTTP cookie used to track the mapping session id
+USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"
+
+
+class SsoHandler:
# The number of attempts to ask the mapping provider for when generating an MXID.
_MAP_USERNAME_RETRIES = 1000
+ # the time a UsernameMappingSession remains valid for
+ _MAPPING_SESSION_VALIDITY_PERIOD_MS = 15 * 60 * 1000
+
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self._clock = hs.get_clock()
+ self._store = hs.get_datastore()
+ self._server_name = hs.hostname
self._registration_handler = hs.get_registration_handler()
self._error_template = hs.config.sso_error_template
+ self._auth_handler = hs.get_auth_handler()
+
+ # a lock on the mappings
+ self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
+
+ # a map from session id to session data
+ self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession]
def render_error(
- self, request, error: str, error_description: Optional[str] = None
+ self,
+ request: Request,
+ error: str,
+ error_description: Optional[str] = None,
+ code: int = 400,
) -> None:
"""Renders the error template and responds with it.
@@ -64,11 +117,12 @@ class SsoHandler(BaseHandler):
We'll respond with an HTML page describing the error.
error: A technical identifier for this error.
error_description: A human-readable description of the error.
+ code: The integer error code (an HTTP response code)
"""
html = self._error_template.render(
error=error, error_description=error_description
)
- respond_with_html(request, 400, html)
+ respond_with_html(request, code, html)
async def get_sso_user_by_remote_user_id(
self, auth_provider_id: str, remote_user_id: str
@@ -95,7 +149,7 @@ class SsoHandler(BaseHandler):
)
# Check if we already have a mapping for this user.
- previously_registered_user_id = await self.store.get_user_by_external_id(
+ previously_registered_user_id = await self._store.get_user_by_external_id(
auth_provider_id, remote_user_id,
)
@@ -112,15 +166,16 @@ class SsoHandler(BaseHandler):
# No match.
return None
- async def get_mxid_from_sso(
+ async def complete_sso_login_request(
self,
auth_provider_id: str,
remote_user_id: str,
- user_agent: str,
- ip_address: str,
+ request: SynapseRequest,
+ client_redirect_url: str,
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
- grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
- ) -> str:
+ grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
+ extra_login_attributes: Optional[JsonDict] = None,
+ ) -> None:
"""
Given an SSO ID, retrieve the user ID for it and possibly register the user.
@@ -139,12 +194,18 @@ class SsoHandler(BaseHandler):
given user-agent and IP address and the SSO ID is linked to this matrix
ID for subsequent calls.
+ Finally, we generate a redirect to the supplied redirect uri, with a login token
+
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml".
+
remote_user_id: The unique identifier from the SSO provider.
- user_agent: The user agent of the client making the request.
- ip_address: The IP address of the client making the request.
+
+ request: The request to respond to
+
+ client_redirect_url: The redirect URL passed in by the client.
+
sso_to_matrix_id_mapper: A callable to generate the user attributes.
The only parameter is an integer which represents the amount of
times the returned mxid localpart mapping has failed.
@@ -156,12 +217,13 @@ class SsoHandler(BaseHandler):
to the user.
RedirectException to redirect to an additional page (e.g.
to prompt the user for more information).
+
grandfather_existing_users: A callable which can return a previously
existing matrix ID. The SSO ID is then linked to the returned
matrix ID.
- Returns:
- The user ID associated with the SSO response.
+ extra_login_attributes: An optional dictionary of extra
+ attributes to be provided to the client in the login response.
Raises:
MappingException if there was a problem mapping the response to a user.
@@ -169,24 +231,55 @@ class SsoHandler(BaseHandler):
to an additional page. (e.g. to prompt for more information)
"""
- # first of all, check if we already have a mapping for this user
- previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
- auth_provider_id, remote_user_id,
- )
- if previously_registered_user_id:
- return previously_registered_user_id
+ # grab a lock while we try to find a mapping for this user. This seems...
+ # optimistic, especially for implementations that end up redirecting to
+ # interstitial pages.
+ with await self._mapping_lock.queue(auth_provider_id):
+ # first of all, check if we already have a mapping for this user
+ user_id = await self.get_sso_user_by_remote_user_id(
+ auth_provider_id, remote_user_id,
+ )
- # Check for grandfathering of users.
- if grandfather_existing_users:
- previously_registered_user_id = await grandfather_existing_users()
- if previously_registered_user_id:
- # Future logins should also match this user ID.
- await self.store.record_user_external_id(
- auth_provider_id, remote_user_id, previously_registered_user_id
+ # Check for grandfathering of users.
+ if not user_id:
+ user_id = await grandfather_existing_users()
+ if user_id:
+ # Future logins should also match this user ID.
+ await self._store.record_user_external_id(
+ auth_provider_id, remote_user_id, user_id
+ )
+
+ # Otherwise, generate a new user.
+ if not user_id:
+ attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
+
+ if attributes.localpart is None:
+ # the mapper doesn't return a username. bail out with a redirect to
+ # the username picker.
+ await self._redirect_to_username_picker(
+ auth_provider_id,
+ remote_user_id,
+ attributes,
+ client_redirect_url,
+ extra_login_attributes,
+ )
+
+ user_id = await self._register_mapped_user(
+ attributes,
+ auth_provider_id,
+ remote_user_id,
+ request.get_user_agent(""),
+ request.getClientIP(),
)
- return previously_registered_user_id
- # Otherwise, generate a new user.
+ await self._auth_handler.complete_sso_login(
+ user_id, request, client_redirect_url, extra_login_attributes
+ )
+
+ async def _call_attribute_mapper(
+ self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
+ ) -> UserAttributes:
+ """Call the attribute mapper function in a loop, until we get a unique userid"""
for i in range(self._MAP_USERNAME_RETRIES):
try:
attributes = await sso_to_matrix_id_mapper(i)
@@ -208,14 +301,12 @@ class SsoHandler(BaseHandler):
)
if not attributes.localpart:
- raise MappingException(
- "Error parsing SSO response: SSO mapping provider plugin "
- "did not return a localpart value"
- )
+ # the mapper has not picked a localpart
+ return attributes
# Check if this mxid already exists
- user_id = UserID(attributes.localpart, self.server_name).to_string()
- if not await self.store.get_users_by_id_case_insensitive(user_id):
+ user_id = UserID(attributes.localpart, self._server_name).to_string()
+ if not await self._store.get_users_by_id_case_insensitive(user_id):
# This mxid is free
break
else:
@@ -224,10 +315,101 @@ class SsoHandler(BaseHandler):
raise MappingException(
"Unable to generate a Matrix ID from the SSO response"
)
+ return attributes
+
+ async def _redirect_to_username_picker(
+ self,
+ auth_provider_id: str,
+ remote_user_id: str,
+ attributes: UserAttributes,
+ client_redirect_url: str,
+ extra_login_attributes: Optional[JsonDict],
+ ) -> NoReturn:
+ """Creates a UsernameMappingSession and redirects the browser
+
+ Called if the user mapping provider doesn't return a localpart for a new user.
+ Raises a RedirectException which redirects the browser to the username picker.
+
+ Args:
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+
+ remote_user_id: The unique identifier from the SSO provider.
+
+ attributes: the user attributes returned by the user mapping provider.
+
+ client_redirect_url: The redirect URL passed in by the client, which we
+ will eventually redirect back to.
+
+ extra_login_attributes: An optional dictionary of extra
+ attributes to be provided to the client in the login response.
+
+ Raises:
+ RedirectException
+ """
+ session_id = random_string(16)
+ now = self._clock.time_msec()
+ session = UsernameMappingSession(
+ auth_provider_id=auth_provider_id,
+ remote_user_id=remote_user_id,
+ display_name=attributes.display_name,
+ emails=attributes.emails,
+ client_redirect_url=client_redirect_url,
+ expiry_time_ms=now + self._MAPPING_SESSION_VALIDITY_PERIOD_MS,
+ extra_login_attributes=extra_login_attributes,
+ )
+
+ self._username_mapping_sessions[session_id] = session
+ logger.info("Recorded registration session id %s", session_id)
+
+ # Set the cookie and redirect to the username picker
+ e = RedirectException(b"/_synapse/client/pick_username")
+ e.cookies.append(
+ b"%s=%s; path=/"
+ % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii"))
+ )
+ raise e
+
+ async def _register_mapped_user(
+ self,
+ attributes: UserAttributes,
+ auth_provider_id: str,
+ remote_user_id: str,
+ user_agent: str,
+ ip_address: str,
+ ) -> str:
+ """Register a new SSO user.
+
+ This is called once we have successfully mapped the remote user id onto a local
+ user id, one way or another.
+
+ Args:
+ attributes: user attributes returned by the user mapping provider,
+ including a non-empty localpart.
+
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+
+ remote_user_id: The unique identifier from the SSO provider.
+
+ user_agent: The user-agent in the HTTP request (used for potential
+ shadow-banning.)
+
+ ip_address: The IP address of the requester (used for potential
+ shadow-banning.)
+
+ Raises:
+ a MappingException if the localpart is invalid.
+
+ a SynapseError with code 400 and errcode Codes.USER_IN_USE if the localpart
+ is already taken.
+ """
# Since the localpart is provided via a potentially untrusted module,
# ensure the MXID is valid before registering.
- if contains_invalid_mxid_characters(attributes.localpart):
+ if not attributes.localpart or contains_invalid_mxid_characters(
+ attributes.localpart
+ ):
raise MappingException("localpart is invalid: %s" % (attributes.localpart,))
logger.debug("Mapped SSO user to local part %s", attributes.localpart)
@@ -238,7 +420,152 @@ class SsoHandler(BaseHandler):
user_agent_ips=[(user_agent, ip_address)],
)
- await self.store.record_user_external_id(
+ await self._store.record_user_external_id(
auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
+
+ async def complete_sso_ui_auth_request(
+ self,
+ auth_provider_id: str,
+ remote_user_id: str,
+ ui_auth_session_id: str,
+ request: Request,
+ ) -> None:
+ """
+ Given an SSO ID, retrieve the user ID for it and complete UIA.
+
+ Note that this requires that the user is mapped in the "user_external_ids"
+ table. This will be the case if they have ever logged in via SAML or OIDC in
+ recentish synapse versions, but may not be for older users.
+
+ Args:
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+ remote_user_id: The unique identifier from the SSO provider.
+ ui_auth_session_id: The ID of the user-interactive auth session.
+ request: The request to complete.
+ """
+
+ user_id = await self.get_sso_user_by_remote_user_id(
+ auth_provider_id, remote_user_id,
+ )
+
+ if not user_id:
+ logger.warning(
+ "Remote user %s/%s has not previously logged in here: UIA will fail",
+ auth_provider_id,
+ remote_user_id,
+ )
+ # Let the UIA flow handle this the same as if they presented creds for a
+ # different user.
+ user_id = ""
+
+ await self._auth_handler.complete_sso_ui_auth(
+ user_id, ui_auth_session_id, request
+ )
+
+ async def check_username_availability(
+ self, localpart: str, session_id: str,
+ ) -> bool:
+ """Handle an "is username available" callback check
+
+ Args:
+ localpart: desired localpart
+ session_id: the session id for the username picker
+ Returns:
+ True if the username is available
+ Raises:
+ SynapseError if the localpart is invalid or the session is unknown
+ """
+
+ # make sure that there is a valid mapping session, to stop people dictionary-
+ # scanning for accounts
+
+ self._expire_old_sessions()
+ session = self._username_mapping_sessions.get(session_id)
+ if not session:
+ logger.info("Couldn't find session id %s", session_id)
+ raise SynapseError(400, "unknown session")
+
+ logger.info(
+ "[session %s] Checking for availability of username %s",
+ session_id,
+ localpart,
+ )
+
+ if contains_invalid_mxid_characters(localpart):
+ raise SynapseError(400, "localpart is invalid: %s" % (localpart,))
+ user_id = UserID(localpart, self._server_name).to_string()
+ user_infos = await self._store.get_users_by_id_case_insensitive(user_id)
+
+ logger.info("[session %s] users: %s", session_id, user_infos)
+ return not user_infos
+
+ async def handle_submit_username_request(
+ self, request: SynapseRequest, localpart: str, session_id: str
+ ) -> None:
+ """Handle a request to the username-picker 'submit' endpoint
+
+ Will serve an HTTP response to the request.
+
+ Args:
+ request: HTTP request
+ localpart: localpart requested by the user
+ session_id: ID of the username mapping session, extracted from a cookie
+ """
+ self._expire_old_sessions()
+ session = self._username_mapping_sessions.get(session_id)
+ if not session:
+ logger.info("Couldn't find session id %s", session_id)
+ raise SynapseError(400, "unknown session")
+
+ logger.info("[session %s] Registering localpart %s", session_id, localpart)
+
+ attributes = UserAttributes(
+ localpart=localpart,
+ display_name=session.display_name,
+ emails=session.emails,
+ )
+
+ # the following will raise a 400 error if the username has been taken in the
+ # meantime.
+ user_id = await self._register_mapped_user(
+ attributes,
+ session.auth_provider_id,
+ session.remote_user_id,
+ request.get_user_agent(""),
+ request.getClientIP(),
+ )
+
+ logger.info("[session %s] Registered userid %s", session_id, user_id)
+
+ # delete the mapping session and the cookie
+ del self._username_mapping_sessions[session_id]
+
+ # delete the cookie
+ request.addCookie(
+ USERNAME_MAPPING_SESSION_COOKIE_NAME,
+ b"",
+ expires=b"Thu, 01 Jan 1970 00:00:00 GMT",
+ path=b"/",
+ )
+
+ await self._auth_handler.complete_sso_login(
+ user_id,
+ request,
+ session.client_redirect_url,
+ session.extra_login_attributes,
+ )
+
+ def _expire_old_sessions(self):
+ to_expire = []
+ now = int(self._clock.time_msec())
+
+ for session_id, session in self._username_mapping_sessions.items():
+ if session.expiry_time_ms <= now:
+ to_expire.append(session_id)
+
+ for session_id in to_expire:
+ logger.info("Expiring mapping session %s", session_id)
+ del self._username_mapping_sessions[session_id]
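
The username-mapping sessions above live only in memory and are time-bounded. A condensed sketch of the create/expire lifecycle, with millisecond timestamps passed in explicitly instead of coming from hs.get_clock():

import secrets

SESSION_VALIDITY_MS = 15 * 60 * 1000

class ToySessionStore:
    def __init__(self):
        self.sessions = {}  # session_id -> expiry time in ms

    def create(self, now_ms: int) -> str:
        session_id = secrets.token_hex(8)
        self.sessions[session_id] = now_ms + SESSION_VALIDITY_MS
        return session_id

    def expire_old(self, now_ms: int) -> None:
        expired = [s for s, exp in self.sessions.items() if exp <= now_ms]
        for session_id in expired:
            del self.sessions[session_id]

store = ToySessionStore()
sid = store.create(now_ms=0)
store.expire_old(now_ms=SESSION_VALIDITY_MS - 1)
assert sid in store.sessions
store.expire_old(now_ms=SESSION_VALIDITY_MS)
assert sid not in store.sessions
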
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 9827c7eb..5c7590f3 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -554,7 +554,7 @@ class SyncHandler:
event.event_id, state_filter=state_filter
)
if event.is_state():
- state_ids = state_ids.copy()
+ state_ids = dict(state_ids)
state_ids[(event.type, event.state_key)] = event.event_id
return state_ids
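
The switch from .copy() to dict(...) matters because state_ids may be a cached, possibly immutable mapping shared with other requests: dict() always yields a fresh mutable dict, and mutating the copy leaves the cached mapping untouched. A minimal illustration:

cached_state_ids = {("m.room.name", ""): "$event1"}

# Copy first, then mutate: the shared mapping is unaffected.
state_ids = dict(cached_state_ids)
state_ids[("m.room.topic", "")] = "$event2"

assert ("m.room.topic", "") not in cached_state_ids
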
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index afbebfc2..d4651c83 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -14,14 +14,19 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
import synapse.metrics
-from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership
from synapse.handlers.state_deltas import StateDeltasHandler
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.roommember import ProfileInfo
+from synapse.types import JsonDict
from synapse.util.metrics import Measure
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -36,7 +41,7 @@ class UserDirectoryHandler(StateDeltasHandler):
be in the directory or not when necessary.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.store = hs.get_datastore()
@@ -49,7 +54,7 @@ class UserDirectoryHandler(StateDeltasHandler):
self.search_all_users = hs.config.user_directory_search_all_users
self.spam_checker = hs.get_spam_checker()
# The current position in the current_state_delta stream
- self.pos = None
+ self.pos = None # type: Optional[int]
# Guard to ensure we only process deltas one at a time
self._is_processing = False
@@ -61,7 +66,9 @@ class UserDirectoryHandler(StateDeltasHandler):
# we start populating the user directory
self.clock.call_later(0, self.notify_new_event)
- async def search_users(self, user_id, search_term, limit):
+ async def search_users(
+ self, user_id: str, search_term: str, limit: int
+ ) -> JsonDict:
"""Searches for users in directory
Returns:
@@ -81,15 +88,15 @@ class UserDirectoryHandler(StateDeltasHandler):
results = await self.store.search_user_dir(user_id, search_term, limit)
# Remove any spammy users from the results.
- results["results"] = [
- user
- for user in results["results"]
- if not self.spam_checker.check_username_for_spam(user)
- ]
+ non_spammy_users = []
+ for user in results["results"]:
+ if not await self.spam_checker.check_username_for_spam(user):
+ non_spammy_users.append(user)
+ results["results"] = non_spammy_users
return results
- def notify_new_event(self):
+ def notify_new_event(self) -> None:
"""Called when there may be more deltas to process
"""
if not self.update_user_directory:
@@ -107,27 +114,33 @@ class UserDirectoryHandler(StateDeltasHandler):
self._is_processing = True
run_as_background_process("user_directory.notify_new_event", process)
- async def handle_local_profile_change(self, user_id, profile):
+ async def handle_local_profile_change(
+ self, user_id: str, profile: ProfileInfo
+ ) -> None:
"""Called to update index of our local user profiles when they change
irrespective of any rooms the user may be in.
"""
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
- is_support = await self.store.is_support_user(user_id)
+
# Support users are for diagnostics and should not appear in the user directory.
- if not is_support:
+ is_support = await self.store.is_support_user(user_id)
+ # Likewise, profile changes for a deactivated user should not appear in the user directory.
+ is_deactivated = await self.store.get_user_deactivated_status(user_id)
+
+ if not (is_support or is_deactivated):
await self.store.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)
- async def handle_user_deactivated(self, user_id):
+ async def handle_user_deactivated(self, user_id: str) -> None:
"""Called when a user ID is deactivated
"""
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
await self.store.remove_from_user_dir(user_id)
- async def _unsafe_process(self):
+ async def _unsafe_process(self) -> None:
# If self.pos is None then means we haven't fetched it from DB
if self.pos is None:
self.pos = await self.store.get_user_directory_stream_pos()
@@ -162,7 +175,7 @@ class UserDirectoryHandler(StateDeltasHandler):
await self.store.update_user_directory_stream_pos(max_pos)
- async def _handle_deltas(self, deltas):
+ async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None:
"""Called with the state deltas to process
"""
for delta in deltas:
@@ -232,16 +245,20 @@ class UserDirectoryHandler(StateDeltasHandler):
logger.debug("Ignoring irrelevant type: %r", typ)
async def _handle_room_publicity_change(
- self, room_id, prev_event_id, event_id, typ
- ):
+ self,
+ room_id: str,
+ prev_event_id: Optional[str],
+ event_id: Optional[str],
+ typ: str,
+ ) -> None:
"""Handle a room having potentially changed from/to world_readable/publicly
joinable.
Args:
- room_id (str)
- prev_event_id (str|None): The previous event before the state change
- event_id (str|None): The new event after the state change
- typ (str): Type of the event
+ room_id: The ID of the room which changed.
+ prev_event_id: The previous event before the state change
+ event_id: The new event after the state change
+ typ: Type of the event
"""
logger.debug("Handling change for %s: %s", typ, room_id)
@@ -250,7 +267,7 @@ class UserDirectoryHandler(StateDeltasHandler):
prev_event_id,
event_id,
key_name="history_visibility",
- public_value="world_readable",
+ public_value=HistoryVisibility.WORLD_READABLE,
)
elif typ == EventTypes.JoinRules:
change = await self._get_key_change(
@@ -299,12 +316,14 @@ class UserDirectoryHandler(StateDeltasHandler):
for user_id, profile in users_with_profile.items():
await self._handle_new_user(room_id, user_id, profile)
- async def _handle_new_user(self, room_id, user_id, profile):
+ async def _handle_new_user(
+ self, room_id: str, user_id: str, profile: ProfileInfo
+ ) -> None:
"""Called when we might need to add user to directory
Args:
- room_id (str): room_id that user joined or started being public
- user_id (str)
+ room_id: The room ID that the user joined or that started being public
+ user_id
"""
logger.debug("Adding new user to dir, %r", user_id)
@@ -352,12 +371,12 @@ class UserDirectoryHandler(StateDeltasHandler):
if to_insert:
await self.store.add_users_who_share_private_room(room_id, to_insert)
- async def _handle_remove_user(self, room_id, user_id):
+ async def _handle_remove_user(self, room_id: str, user_id: str) -> None:
"""Called when we might need to remove user from directory
Args:
- room_id (str): room_id that user left or stopped being public that
- user_id (str)
+ room_id: The room ID that the user left, or that stopped being public
+ user_id
"""
logger.debug("Removing user %r", user_id)
@@ -370,7 +389,13 @@ class UserDirectoryHandler(StateDeltasHandler):
if len(rooms_user_is_in) == 0:
await self.store.remove_from_user_dir(user_id)
- async def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id):
+ async def _handle_profile_change(
+ self,
+ user_id: str,
+ room_id: str,
+ prev_event_id: Optional[str],
+ event_id: Optional[str],
+ ) -> None:
"""Check member event changes for any profile changes and update the
database if there are.
"""
diff --git a/synapse/http/client.py b/synapse/http/client.py
index e5b13593..5f74ee11 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -125,7 +125,7 @@ def _make_scheduler(reactor):
return _scheduler
-class IPBlacklistingResolver:
+class _IPBlacklistingResolver:
"""
A proxy for reactor.nameResolver which only produces non-blacklisted IP
addresses, preventing DNS rebinding attacks on URL preview.
@@ -199,6 +199,35 @@ class IPBlacklistingResolver:
return r
+@implementer(IReactorPluggableNameResolver)
+class BlacklistingReactorWrapper:
+ """
+ A Reactor wrapper which will prevent DNS resolution to blacklisted IP
+ addresses, to prevent DNS rebinding.
+ """
+
+ def __init__(
+ self,
+ reactor: IReactorPluggableNameResolver,
+ ip_whitelist: Optional[IPSet],
+ ip_blacklist: IPSet,
+ ):
+ self._reactor = reactor
+
+ # We need to use a DNS resolver which filters out blacklisted IP
+ # addresses, to prevent DNS rebinding.
+ self._nameResolver = _IPBlacklistingResolver(
+ self._reactor, ip_whitelist, ip_blacklist
+ )
+
+ def __getattr__(self, attr: str) -> Any:
+ # Passthrough to the real reactor except for the DNS resolver.
+ if attr == "nameResolver":
+ return self._nameResolver
+ else:
+ return getattr(self._reactor, attr)
+
+
class BlacklistingAgentWrapper(Agent):
"""
An Agent wrapper which will prevent access to IP addresses being accessed
@@ -292,22 +321,11 @@ class SimpleHttpClient:
self.user_agent = self.user_agent.encode("ascii")
if self._ip_blacklist:
- real_reactor = hs.get_reactor()
# If we have an IP blacklist, we need to use a DNS resolver which
# filters out blacklisted IP addresses, to prevent DNS rebinding.
- nameResolver = IPBlacklistingResolver(
- real_reactor, self._ip_whitelist, self._ip_blacklist
+ self.reactor = BlacklistingReactorWrapper(
+ hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
)
-
- @implementer(IReactorPluggableNameResolver)
- class Reactor:
- def __getattr__(_self, attr):
- if attr == "nameResolver":
- return nameResolver
- else:
- return getattr(real_reactor, attr)
-
- self.reactor = Reactor()
else:
self.reactor = hs.get_reactor()
@@ -323,6 +341,7 @@ class SimpleHttpClient:
self.agent = ProxyAgent(
self.reactor,
+ hs.get_reactor(),
connectTimeout=15,
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
@@ -702,11 +721,14 @@ class SimpleHttpClient:
try:
length = await make_deferred_yieldable(
- readBodyToFile(response, output_stream, max_size)
+ read_body_with_max_size(response, output_stream, max_size)
+ )
+ except BodyExceededMaxSize:
+ raise SynapseError(
+ 502,
+ "Requested file is too large > %r bytes" % (max_size,),
+ Codes.TOO_LARGE,
)
- except SynapseError:
- # This can happen e.g. because the body is too large.
- raise
except Exception as e:
raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e
@@ -730,7 +752,11 @@ def _timeout_to_request_timed_out_error(f: Failure):
return f
-class _ReadBodyToFileProtocol(protocol.Protocol):
+class BodyExceededMaxSize(Exception):
+ """The maximum allowed size of the HTTP body was exceeded."""
+
+
+class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
def __init__(
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
):
@@ -743,13 +769,7 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.stream.write(data)
self.length += len(data)
if self.max_size is not None and self.length >= self.max_size:
- self.deferred.errback(
- SynapseError(
- 502,
- "Requested file is too large > %r bytes" % (self.max_size,),
- Codes.TOO_LARGE,
- )
- )
+ self.deferred.errback(BodyExceededMaxSize())
self.deferred = defer.Deferred()
self.transport.loseConnection()
@@ -764,12 +784,15 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.deferred.errback(reason)
-def readBodyToFile(
+def read_body_with_max_size(
response: IResponse, stream: BinaryIO, max_size: Optional[int]
) -> defer.Deferred:
"""
Read a HTTP response body to a file-object. Optionally enforcing a maximum file size.
+ If the maximum file size is reached, the returned Deferred will resolve to a
+ Failure with a BodyExceededMaxSize exception.
+
Args:
response: The HTTP response to read from.
stream: The file-object to write to.
@@ -780,7 +803,7 @@ def readBodyToFile(
"""
d = defer.Deferred()
- response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
+ response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size))
return d
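
read_body_with_max_size streams the body into a file-object and errbacks with BodyExceededMaxSize once the cap is crossed. Stripped of the Twisted plumbing, the core pattern reduces to this sketch:

import io

class BodyExceededMaxSize(Exception):
    """The maximum allowed size of the HTTP body was exceeded."""

def read_with_max_size(chunks, stream, max_size):
    """Write chunks to stream, failing once max_size or more bytes arrive."""
    length = 0
    for data in chunks:
        stream.write(data)
        length += len(data)
        if max_size is not None and length >= max_size:
            raise BodyExceededMaxSize()
    return length

out = io.BytesIO()
assert read_with_max_size([b"ab", b"cd"], out, max_size=10) == 4

try:
    read_with_max_size([b"x" * 1024] * 64, io.BytesIO(), max_size=50 * 1024)
except BodyExceededMaxSize:
    pass  # callers translate this into a 502 SynapseError
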
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index e77f9587..3b756a7d 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -16,7 +16,7 @@ import logging
import urllib.parse
from typing import List, Optional
-from netaddr import AddrFormatError, IPAddress
+from netaddr import AddrFormatError, IPAddress, IPSet
from zope.interface import implementer
from twisted.internet import defer
@@ -31,6 +31,7 @@ from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer
from synapse.crypto.context_factory import FederationPolicyForHTTPS
+from synapse.http.client import BlacklistingAgentWrapper
from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import WellKnownResolver
from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -70,6 +71,7 @@ class MatrixFederationAgent:
reactor: IReactorCore,
tls_client_options_factory: Optional[FederationPolicyForHTTPS],
user_agent: bytes,
+ ip_blacklist: IPSet,
_srv_resolver: Optional[SrvResolver] = None,
_well_known_resolver: Optional[WellKnownResolver] = None,
):
@@ -90,12 +92,18 @@ class MatrixFederationAgent:
self.user_agent = user_agent
if _well_known_resolver is None:
+ # Note that the name resolver has already been wrapped in a
+ # IPBlacklistingResolver by MatrixFederationHttpClient.
_well_known_resolver = WellKnownResolver(
self._reactor,
- agent=Agent(
+ agent=BlacklistingAgentWrapper(
+ Agent(
+ self._reactor,
+ pool=self._pool,
+ contextFactory=tls_client_options_factory,
+ ),
self._reactor,
- pool=self._pool,
- contextFactory=tls_client_options_factory,
+ ip_blacklist=ip_blacklist,
),
user_agent=self.user_agent,
)
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 5e08ef16..b3b6dbca 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -15,17 +15,19 @@
import logging
import random
import time
+from io import BytesIO
from typing import Callable, Dict, Optional, Tuple
import attr
from twisted.internet import defer
from twisted.internet.interfaces import IReactorTime
-from twisted.web.client import RedirectAgent, readBody
+from twisted.web.client import RedirectAgent
from twisted.web.http import stringToDatetime
from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent, IResponse
+from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock, json_decoder
from synapse.util.caches.ttlcache import TTLCache
@@ -53,6 +55,9 @@ WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
# lower bound for .well-known cache period
WELL_KNOWN_MIN_CACHE_PERIOD = 5 * 60
+# The maximum size (in bytes) to allow a well-known file to be.
+WELL_KNOWN_MAX_SIZE = 50 * 1024 # 50 KiB
+
# Attempt to refetch a cached well-known N% of the TTL before it expires.
# e.g. if set to 0.2 and we have a cached entry with a TTL of 5mins, then
# we'll start trying to refetch 1 minute before it expires.
@@ -229,6 +234,9 @@ class WellKnownResolver:
server_name: name of the server, from the requested url
retry: Whether to retry the request if it fails.
+ Raises:
+ _FetchWellKnownFailure if we fail to lookup a result
+
Returns:
Returns the response object and body. Response may be a non-200 response.
"""
@@ -250,7 +258,11 @@ class WellKnownResolver:
b"GET", uri, headers=Headers(headers)
)
)
- body = await make_deferred_yieldable(readBody(response))
+ body_stream = BytesIO()
+ await make_deferred_yieldable(
+ read_body_with_max_size(response, body_stream, WELL_KNOWN_MAX_SIZE)
+ )
+ body = body_stream.getvalue()
if 500 <= response.code < 600:
raise Exception("Non-200 response %s" % (response.code,))
@@ -259,6 +271,15 @@ class WellKnownResolver:
except defer.CancelledError:
# Bail if we've been cancelled
raise
+ except BodyExceededMaxSize:
+ # If the well-known file was too large, do not keep attempting
+ # to download it, but consider it a temporary error.
+ logger.warning(
+ "Requested .well-known file for %s is too large > %r bytes",
+ server_name.decode("ascii"),
+ WELL_KNOWN_MAX_SIZE,
+ )
+ raise _FetchWellKnownFailure(temporary=True)
except Exception as e:
if not retry or i >= WELL_KNOWN_RETRY_ATTEMPTS:
logger.info("Error fetching %s: %s", uri_str, e)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 4e27f93b..b261e078 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -26,11 +26,10 @@ import treq
from canonicaljson import encode_canonical_json
from prometheus_client import Counter
from signedjson.sign import sign_json
-from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.error import DNSLookupError
-from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime
+from twisted.internet.interfaces import IReactorTime
from twisted.internet.task import _EPSILON, Cooperator
from twisted.web.http_headers import Headers
from twisted.web.iweb import IBodyProducer, IResponse
@@ -38,16 +37,19 @@ from twisted.web.iweb import IBodyProducer, IResponse
import synapse.metrics
import synapse.util.retryutils
from synapse.api.errors import (
+ Codes,
FederationDeniedError,
HttpResponseException,
RequestSendFailed,
+ SynapseError,
)
from synapse.http import QuieterFileBodyProducer
from synapse.http.client import (
BlacklistingAgentWrapper,
- IPBlacklistingResolver,
+ BlacklistingReactorWrapper,
+ BodyExceededMaxSize,
encode_query_args,
- readBodyToFile,
+ read_body_with_max_size,
)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.logging.context import make_deferred_yieldable
@@ -221,31 +223,22 @@ class MatrixFederationHttpClient:
self.signing_key = hs.signing_key
self.server_name = hs.hostname
- real_reactor = hs.get_reactor()
-
# We need to use a DNS resolver which filters out blacklisted IP
# addresses, to prevent DNS rebinding.
- nameResolver = IPBlacklistingResolver(
- real_reactor, None, hs.config.federation_ip_range_blacklist
+ self.reactor = BlacklistingReactorWrapper(
+ hs.get_reactor(), None, hs.config.federation_ip_range_blacklist
)
- @implementer(IReactorPluggableNameResolver)
- class Reactor:
- def __getattr__(_self, attr):
- if attr == "nameResolver":
- return nameResolver
- else:
- return getattr(real_reactor, attr)
-
- self.reactor = Reactor()
-
user_agent = hs.version_string
if hs.config.user_agent_suffix:
user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix)
user_agent = user_agent.encode("ascii")
self.agent = MatrixFederationAgent(
- self.reactor, tls_client_options_factory, user_agent
+ self.reactor,
+ tls_client_options_factory,
+ user_agent,
+ hs.config.federation_ip_range_blacklist,
)
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
@@ -985,9 +978,15 @@ class MatrixFederationHttpClient:
headers = dict(response.headers.getAllRawHeaders())
try:
- d = readBodyToFile(response, output_stream, max_size)
+ d = read_body_with_max_size(response, output_stream, max_size)
d.addTimeout(self.default_timeout, self.reactor)
length = await make_deferred_yieldable(d)
+ except BodyExceededMaxSize:
+ msg = "Requested file is too large > %r bytes" % (max_size,)
+ logger.warning(
+ "{%s} [%s] %s", request.txn_id, request.destination, msg,
+ )
+ SynapseError(502, msg, Codes.TOO_LARGE)
except Exception as e:
logger.warning(
"{%s} [%s] Error reading response: %s",
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index e32d3f43..b730d2c6 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -39,6 +39,10 @@ class ProxyAgent(_AgentBase):
reactor: twisted reactor to place outgoing
connections.
+ proxy_reactor: twisted reactor to use for connections to the proxy server.
+ `reactor` might have some blacklisting applied (i.e. for DNS queries),
+ but we need unblocked access to the proxy.
+
contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the
verification parameters of OpenSSL. The default is to use a
`BrowserLikePolicyForHTTPS`, so unless you have special
@@ -59,6 +63,7 @@ class ProxyAgent(_AgentBase):
def __init__(
self,
reactor,
+ proxy_reactor=None,
contextFactory=BrowserLikePolicyForHTTPS(),
connectTimeout=None,
bindAddress=None,
@@ -68,6 +73,11 @@ class ProxyAgent(_AgentBase):
):
_AgentBase.__init__(self, reactor, pool)
+ if proxy_reactor is None:
+ self.proxy_reactor = reactor
+ else:
+ self.proxy_reactor = proxy_reactor
+
self._endpoint_kwargs = {}
if connectTimeout is not None:
self._endpoint_kwargs["timeout"] = connectTimeout
@@ -75,11 +85,11 @@ class ProxyAgent(_AgentBase):
self._endpoint_kwargs["bindAddress"] = bindAddress
self.http_proxy_endpoint = _http_proxy_endpoint(
- http_proxy, reactor, **self._endpoint_kwargs
+ http_proxy, self.proxy_reactor, **self._endpoint_kwargs
)
self.https_proxy_endpoint = _http_proxy_endpoint(
- https_proxy, reactor, **self._endpoint_kwargs
+ https_proxy, self.proxy_reactor, **self._endpoint_kwargs
)
self._policy_for_https = contextFactory
@@ -137,7 +147,7 @@ class ProxyAgent(_AgentBase):
request_path = uri
elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint:
endpoint = HTTPConnectProxyEndpoint(
- self._reactor,
+ self.proxy_reactor,
self.https_proxy_endpoint,
parsed_uri.host,
parsed_uri.port,
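A hedged sketch of the intended wiring, assuming netaddr's IPSet for the blacklist (hostnames and ranges are invented): DNS resolution goes through the filtering reactor, while the proxy endpoint itself is dialled on the unfiltered one.

    from netaddr import IPSet
    from twisted.internet import reactor

    from synapse.http.client import BlacklistingReactorWrapper
    from synapse.http.proxyagent import ProxyAgent

    filtering_reactor = BlacklistingReactorWrapper(
        reactor, None, IPSet(["10.0.0.0/8", "127.0.0.0/8"])
    )
    agent = ProxyAgent(
        filtering_reactor,      # DNS answers are checked against the blacklist
        proxy_reactor=reactor,  # but the (trusted) proxy needs unblocked access
        http_proxy=b"proxy.example.com:8888",
    )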
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 6a4e429a..e464bfe6 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -275,6 +275,10 @@ class DirectServeJsonResource(_AsyncResource):
formatting responses and errors as JSON.
"""
+ def __init__(self, canonical_json=False, extract_context=False):
+ super().__init__(extract_context)
+ self.canonical_json = canonical_json
+
def _send_response(
self, request: Request, code: int, response_object: Any,
):
@@ -318,9 +322,7 @@ class JsonResource(DirectServeJsonResource):
)
def __init__(self, hs, canonical_json=True, extract_context=False):
- super().__init__(extract_context)
-
- self.canonical_json = canonical_json
+ super().__init__(canonical_json, extract_context)
self.clock = hs.get_clock()
self.path_regexs = {}
self.hs = hs
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 5f0581dc..5a579083 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -128,8 +128,7 @@ class SynapseRequest(Request):
# create a LogContext for this request
request_id = self.get_request_id()
- logcontext = self.logcontext = LoggingContext(request_id)
- logcontext.request = request_id
+ self.logcontext = LoggingContext(request_id, request=request_id)
# override the Server header which is set by twisted
self.setHeader("Server", self.site.server_version_string)
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index ca0c774c..a507a83e 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -203,10 +203,6 @@ class _Sentinel:
def copy_to(self, record):
pass
- def copy_to_twisted_log_entry(self, record):
- record["request"] = None
- record["scope"] = None
-
def start(self, rusage: "Optional[resource._RUsage]"):
pass
@@ -372,13 +368,6 @@ class LoggingContext:
# we also track the current scope:
record.scope = self.scope
- def copy_to_twisted_log_entry(self, record) -> None:
- """
- Copy logging fields from this context to a Twisted log record.
- """
- record["request"] = self.request
- record["scope"] = self.scope
-
def start(self, rusage: "Optional[resource._RUsage]") -> None:
"""
Record that this logcontext is currently running.
@@ -542,13 +531,10 @@ class LoggingContext:
class LoggingContextFilter(logging.Filter):
"""Logging filter that adds values from the current logging context to each
record.
- Args:
- **defaults: Default values to avoid formatters complaining about
- missing fields
"""
- def __init__(self, **defaults) -> None:
- self.defaults = defaults
+ def __init__(self, request: str = ""):
+ self._default_request = request
def filter(self, record) -> Literal[True]:
"""Add each fields from the logging contexts to the record.
@@ -556,14 +542,14 @@ class LoggingContextFilter(logging.Filter):
True to include the record in the log output.
"""
context = current_context()
- for key, value in self.defaults.items():
- setattr(record, key, value)
+ record.request = self._default_request
# context should never be None, but if it somehow ends up being, then
# we end up in a death spiral of infinite loops, so let's check, for
# robustness' sake.
if context is not None:
- context.copy_to(record)
+ # Logging is only interested in the request ID.
+ record.request = context.request
return True
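With the narrowed signature, wiring the filter into stdlib logging looks roughly like this (a sketch; `%(request)s` falls back to the constructor default whenever no logging context is active):

    import logging

    from synapse.logging.context import LoggingContextFilter

    handler = logging.StreamHandler()
    # The only supported default is now ``request``.
    handler.addFilter(LoggingContextFilter(request=""))
    handler.setFormatter(
        logging.Formatter("%(asctime)s - %(request)s - %(message)s")
    )
    logging.getLogger("synapse").addHandler(handler)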
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 658f6ecd..70e0fa45 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import logging
import threading
from functools import wraps
@@ -25,6 +24,7 @@ from twisted.internet import defer
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.logging.opentracing import noop_context_manager, start_active_span
+from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
import resource
@@ -199,19 +199,13 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
_background_process_start_count.labels(desc).inc()
_background_process_in_flight_count.labels(desc).inc()
- with BackgroundProcessLoggingContext(desc) as context:
- context.request = "%s-%i" % (desc, count)
+ with BackgroundProcessLoggingContext(desc, "%s-%i" % (desc, count)) as context:
try:
ctx = noop_context_manager()
if bg_start_span:
ctx = start_active_span(desc, tags={"request_id": context.request})
with ctx:
- result = func(*args, **kwargs)
-
- if inspect.isawaitable(result):
- result = await result
-
- return result
+ return await maybe_awaitable(func(*args, **kwargs))
except Exception:
logger.exception(
"Background process '%s' threw an exception", desc,
@@ -249,8 +243,8 @@ class BackgroundProcessLoggingContext(LoggingContext):
__slots__ = ["_proc"]
- def __init__(self, name: str):
- super().__init__(name)
+ def __init__(self, name: str, request: Optional[str] = None):
+ super().__init__(name, request=request)
self._proc = _BackgroundProcess(name, self)
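maybe_awaitable replaces the inlined inspect.isawaitable dance, letting run_as_background_process accept plain and async callables alike. A simplified sketch of the semantics it provides (the real helper lives in synapse.util.async_helpers):

    import asyncio
    import inspect

    async def maybe_awaitable_sketch(value):
        """Await the value if it is awaitable, otherwise pass it through."""
        if inspect.isawaitable(value):
            return await value
        return value

    async def main():
        assert await maybe_awaitable_sketch(42) == 42

        async def answer():
            return 42

        assert await maybe_awaitable_sketch(answer()) == 42

    asyncio.run(main())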
diff --git a/synapse/notifier.py b/synapse/notifier.py
index a17352ef..c4c8bb27 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -34,7 +34,7 @@ from prometheus_client import Counter
from twisted.internet import defer
import synapse.server
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.api.errors import AuthError
from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state
@@ -611,7 +611,9 @@ class Notifier:
room_id, EventTypes.RoomHistoryVisibility, ""
)
if state and "history_visibility" in state.content:
- return state.content["history_visibility"] == "world_readable"
+ return (
+ state.content["history_visibility"] == HistoryVisibility.WORLD_READABLE
+ )
else:
return False
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 5a437f98..f4f7ec96 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -13,7 +13,111 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
+from typing import TYPE_CHECKING, Any, Dict, Optional
+
+import attr
+
+from synapse.types import JsonDict, RoomStreamToken
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
+
+@attr.s(slots=True)
+class PusherConfig:
+ """Parameters necessary to configure a pusher."""
+
+ id = attr.ib(type=Optional[str])
+ user_name = attr.ib(type=str)
+ access_token = attr.ib(type=Optional[int])
+ profile_tag = attr.ib(type=str)
+ kind = attr.ib(type=str)
+ app_id = attr.ib(type=str)
+ app_display_name = attr.ib(type=str)
+ device_display_name = attr.ib(type=str)
+ pushkey = attr.ib(type=str)
+ ts = attr.ib(type=int)
+ lang = attr.ib(type=Optional[str])
+ data = attr.ib(type=Optional[JsonDict])
+ last_stream_ordering = attr.ib(type=int)
+ last_success = attr.ib(type=Optional[int])
+ failing_since = attr.ib(type=Optional[int])
+
+ def as_dict(self) -> Dict[str, Any]:
+ """Information that can be retrieved about a pusher after creation."""
+ return {
+ "app_display_name": self.app_display_name,
+ "app_id": self.app_id,
+ "data": self.data,
+ "device_display_name": self.device_display_name,
+ "kind": self.kind,
+ "lang": self.lang,
+ "profile_tag": self.profile_tag,
+ "pushkey": self.pushkey,
+ }
+
+
+@attr.s(slots=True)
+class ThrottleParams:
+ """Parameters for controlling the rate of sending pushes via email."""
+
+ last_sent_ts = attr.ib(type=int)
+ throttle_ms = attr.ib(type=int)
+
+
+class Pusher(metaclass=abc.ABCMeta):
+ def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
+ self.hs = hs
+ self.store = self.hs.get_datastore()
+ self.clock = self.hs.get_clock()
+
+ self.pusher_id = pusher_config.id
+ self.user_id = pusher_config.user_name
+ self.app_id = pusher_config.app_id
+ self.pushkey = pusher_config.pushkey
+
+ self.last_stream_ordering = pusher_config.last_stream_ordering
+
+ # This is the highest stream ordering we know it's safe to process.
+ # When new events arrive, we'll be given a window of new events: we
+ # should honour this rather than just looking for anything higher
+ # because of potential out-of-order event serialisation.
+ self.max_stream_ordering = self.store.get_room_max_stream_ordering()
+
+ def on_new_notifications(self, max_token: RoomStreamToken) -> None:
+ # We just use the minimum stream ordering and ignore the vector clock
+ # component. This is safe to do as long as we *always* ignore the vector
+ # clock components.
+ max_stream_ordering = max_token.stream
+
+ self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
+ self._start_processing()
+
+ @abc.abstractmethod
+ def _start_processing(self):
+ """Start processing push notifications."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def on_started(self, should_check_for_notifs: bool) -> None:
+ """Called when this pusher has been started.
+
+ Args:
+ should_check_for_notifs: Whether we should immediately
+ check for push to send. Set to False only if it's known there
+ is nothing to send
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def on_stop(self) -> None:
+ raise NotImplementedError()
+
class PusherConfigException(Exception):
- def __init__(self, msg):
- super().__init__(msg)
+ """An error occurred when creating a pusher."""
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index fabc9ba1..aaed2865 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -14,19 +14,22 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
+from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
from synapse.util.metrics import Measure
-from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
class ActionGenerator:
- def __init__(self, hs):
- self.hs = hs
+ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
- self.store = hs.get_datastore()
self.bulk_evaluator = BulkPushRuleEvaluator(hs)
# really we want to get all user ids and all profile tags too,
# since we want the actions for each profile tag for every user and
@@ -35,6 +38,8 @@ class ActionGenerator:
# event stream, so we just run the rules for a client with no profile
# tag (ie. we just need all the users).
- async def handle_push_actions_for_event(self, event, context):
+ async def handle_push_actions_for_event(
+ self, event: EventBase, context: EventContext
+ ) -> None:
with Measure(self.clock, "action_for_event_by_user"):
await self.bulk_evaluator.action_for_event_by_user(event, context)
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index f5788c1d..62115069 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -15,16 +15,19 @@
# limitations under the License.
import copy
+from typing import Any, Dict, List
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
-def list_with_base_rules(rawrules, use_new_defaults=False):
+def list_with_base_rules(
+ rawrules: List[Dict[str, Any]], use_new_defaults: bool = False
+) -> List[Dict[str, Any]]:
"""Combine the list of rules set by the user with the default push rules
Args:
- rawrules(list): The rules the user has modified or set.
- use_new_defaults(bool): Whether to use the new experimental default rules when
+ rawrules: The rules the user has modified or set.
+ use_new_defaults: Whether to use the new experimental default rules when
appending or prepending default rules.
Returns:
@@ -94,7 +97,11 @@ def list_with_base_rules(rawrules, use_new_defaults=False):
return ruleslist
-def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
+def make_base_append_rules(
+ kind: str,
+ modified_base_rules: Dict[str, Dict[str, Any]],
+ use_new_defaults: bool = False,
+) -> List[Dict[str, Any]]:
rules = []
if kind == "override":
@@ -116,6 +123,7 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
rules = copy.deepcopy(rules)
for r in rules:
# Only modify the actions, keep the conditions the same.
+ assert isinstance(r["rule_id"], str)
modified = modified_base_rules.get(r["rule_id"])
if modified:
r["actions"] = modified["actions"]
@@ -123,7 +131,11 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
return rules
-def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False):
+def make_base_prepend_rules(
+ kind: str,
+ modified_base_rules: Dict[str, Dict[str, Any]],
+ use_new_defaults: bool = False,
+) -> List[Dict[str, Any]]:
rules = []
if kind == "override":
@@ -133,6 +145,7 @@ def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False):
rules = copy.deepcopy(rules)
for r in rules:
# Only modify the actions, keep the conditions the same.
+ assert isinstance(r["rule_id"], str)
modified = modified_base_rules.get(r["rule_id"])
if modified:
r["actions"] = modified["actions"]
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 82a72dc3..10f27e43 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
import attr
from prometheus_client import Counter
@@ -25,16 +26,16 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.util.async_helpers import Linearizer
-from synapse.util.caches import register_cache
+from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.descriptors import lru_cache
from synapse.util.caches.lrucache import LruCache
from .push_rule_evaluator import PushRuleEvaluatorForEvent
-logger = logging.getLogger(__name__)
-
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
-rules_by_room = {}
+logger = logging.getLogger(__name__)
push_rules_invalidation_counter = Counter(
@@ -101,7 +102,7 @@ class BulkPushRuleEvaluator:
room at once.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
@@ -113,7 +114,9 @@ class BulkPushRuleEvaluator:
resizable=False,
)
- async def _get_rules_for_event(self, event, context):
+ async def _get_rules_for_event(
+ self, event: EventBase, context: EventContext
+ ) -> Dict[str, List[Dict[str, Any]]]:
"""This gets the rules for all users in the room at the time of the event,
as well as the push rules for the invitee if the event is an invite.
@@ -140,11 +143,8 @@ class BulkPushRuleEvaluator:
return rules_by_user
@lru_cache()
- def _get_rules_for_room(self, room_id):
+ def _get_rules_for_room(self, room_id: str) -> "RulesForRoom":
"""Get the current RulesForRoom object for the given room id
-
- Returns:
- RulesForRoom
"""
# It's important that RulesForRoom gets added to self._get_rules_for_room.cache
# before any lookup methods get called on it as otherwise there may be
@@ -156,20 +156,21 @@ class BulkPushRuleEvaluator:
self.room_push_rule_cache_metrics,
)
- async def _get_power_levels_and_sender_level(self, event, context):
+ async def _get_power_levels_and_sender_level(
+ self, event: EventBase, context: EventContext
+ ) -> Tuple[dict, int]:
prev_state_ids = await context.get_prev_state_ids()
pl_event_id = prev_state_ids.get(POWER_KEY)
if pl_event_id:
# fastpath: if there's a power level event, that's all we need, and
# not having a power level event is an extreme edge case
- pl_event = await self.store.get_event(pl_event_id)
- auth_events = {POWER_KEY: pl_event}
+ auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
else:
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=False
)
- auth_events = await self.store.get_events(auth_events_ids)
- auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
+ auth_events_dict = await self.store.get_events(auth_events_ids)
+ auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()}
sender_level = get_user_power_level(event.sender, auth_events)
@@ -177,7 +178,9 @@ class BulkPushRuleEvaluator:
return pl_event.content if pl_event else {}, sender_level
- async def action_for_event_by_user(self, event, context) -> None:
+ async def action_for_event_by_user(
+ self, event: EventBase, context: EventContext
+ ) -> None:
"""Given an event and context, evaluate the push rules, check if the message
should increment the unread count, and insert the results into the
event_push_actions_staging table.
@@ -185,7 +188,7 @@ class BulkPushRuleEvaluator:
count_as_unread = _should_count_as_unread(event, context)
rules_by_user = await self._get_rules_for_event(event, context)
- actions_by_user = {}
+ actions_by_user = {} # type: Dict[str, List[Union[dict, str]]]
room_members = await self.store.get_joined_users_from_context(event, context)
@@ -198,7 +201,7 @@ class BulkPushRuleEvaluator:
event, len(room_members), sender_power_level, power_levels
)
- condition_cache = {}
+ condition_cache = {} # type: Dict[str, bool]
for uid, rules in rules_by_user.items():
if event.sender == uid:
@@ -249,7 +252,13 @@ class BulkPushRuleEvaluator:
)
-def _condition_checker(evaluator, conditions, uid, display_name, cache):
+def _condition_checker(
+ evaluator: PushRuleEvaluatorForEvent,
+ conditions: List[dict],
+ uid: str,
+ display_name: str,
+ cache: Dict[str, bool],
+) -> bool:
for cond in conditions:
_id = cond.get("_id", None)
if _id:
@@ -277,15 +286,19 @@ class RulesForRoom:
"""
def __init__(
- self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics
+ self,
+ hs: "HomeServer",
+ room_id: str,
+ rules_for_room_cache: LruCache,
+ room_push_rule_cache_metrics: CacheMetric,
):
"""
Args:
- hs (HomeServer)
- room_id (str)
+ hs: The HomeServer object.
+ room_id: The room ID.
rules_for_room_cache: The cache object that caches these
RoomsForUser objects.
- room_push_rule_cache_metrics (CacheMetric)
+ room_push_rule_cache_metrics: The metrics object
"""
self.room_id = room_id
self.is_mine_id = hs.is_mine_id
@@ -294,8 +307,10 @@ class RulesForRoom:
self.linearizer = Linearizer(name="rules_for_room")
- self.member_map = {} # event_id -> (user_id, state)
- self.rules_by_user = {} # user_id -> rules
+ # event_id -> (user_id, state)
+ self.member_map = {} # type: Dict[str, Tuple[str, str]]
+ # user_id -> rules
+ self.rules_by_user = {} # type: Dict[str, List[Dict[str, dict]]]
# The last state group we updated the caches for. If the state_group of
# a new event comes along, we know that we can just return the cached
@@ -315,7 +330,7 @@ class RulesForRoom:
# calculate push for)
# These never need to be invalidated as we will never set up push for
# them.
- self.uninteresting_user_set = set()
+ self.uninteresting_user_set = set() # type: Set[str]
# We need to be clever on the invalidating caches callbacks, as
# otherwise the invalidation callback holds a reference to the object,
@@ -325,7 +340,9 @@ class RulesForRoom:
# to self around in the callback.
self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
- async def get_rules(self, event, context):
+ async def get_rules(
+ self, event: EventBase, context: EventContext
+ ) -> Dict[str, List[Dict[str, dict]]]:
"""Given an event context return the rules for all users who are
currently in the room.
"""
@@ -356,6 +373,8 @@ class RulesForRoom:
else:
current_state_ids = await context.get_current_state_ids()
push_rules_delta_state_cache_metric.inc_misses()
+ # Ensure the state IDs exist.
+ assert current_state_ids is not None
push_rules_state_size_counter.inc(len(current_state_ids))
@@ -420,18 +439,23 @@ class RulesForRoom:
return ret_rules_by_user
async def _update_rules_with_member_event_ids(
- self, ret_rules_by_user, member_event_ids, state_group, event
- ):
+ self,
+ ret_rules_by_user: Dict[str, list],
+ member_event_ids: Dict[str, str],
+ state_group: Optional[int],
+ event: EventBase,
+ ) -> None:
"""Update the partially filled rules_by_user dict by fetching rules for
any newly joined users in the `member_event_ids` list.
Args:
- ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
+ ret_rules_by_user: Partially filled dict of push rules. Gets
updated with any new rules.
- member_event_ids (dict): Dict of user id to event id for membership events
+ member_event_ids: Dict of user id to event id for membership events
that have happened since the last time we filled rules_by_user
state_group: The state group we are currently computing push rules
for. Used when updating the cache.
+ event: The event we are currently computing push rules for.
"""
sequence = self.sequence
@@ -449,19 +473,19 @@ class RulesForRoom:
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values())
- user_ids = {
+ joined_user_ids = {
user_id
for user_id, membership in members.values()
if membership == Membership.JOIN
}
- logger.debug("Joined: %r", user_ids)
+ logger.debug("Joined: %r", joined_user_ids)
# Previously we only considered users with pushers or read receipts in that
# room. We can't do this anymore because we use push actions to calculate unread
# counts, which don't rely on the user having pushers or sent a read receipt into
# the room. Therefore we just need to filter for local users here.
- user_ids = list(filter(self.is_mine_id, user_ids))
+ user_ids = list(filter(self.is_mine_id, joined_user_ids))
rules_by_user = await self.store.bulk_get_push_rules(
user_ids, on_invalidate=self.invalidate_all_cb
@@ -473,7 +497,7 @@ class RulesForRoom:
self.update_cache(sequence, members, ret_rules_by_user, state_group)
- def invalidate_all(self):
+ def invalidate_all(self) -> None:
# Note: Don't hand this function directly to an invalidation callback
# as it keeps a reference to self and will stop this instance from being
# GC'd if it gets dropped from the rules_to_user cache. Instead use
@@ -485,7 +509,7 @@ class RulesForRoom:
self.rules_by_user = {}
push_rules_invalidation_counter.inc()
- def update_cache(self, sequence, members, rules_by_user, state_group):
+ def update_cache(self, sequence, members, rules_by_user, state_group) -> None:
if sequence == self.sequence:
self.member_map.update(members)
self.rules_by_user = rules_by_user
@@ -506,7 +530,7 @@ class _Invalidation:
cache = attr.ib(type=LruCache)
room_id = attr.ib(type=str)
- def __call__(self):
+ def __call__(self) -> None:
rules = self.cache.get(self.room_id, None, update_metrics=False)
if rules:
rules.invalidate_all()
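The `cache` parameter threaded through _condition_checker keys on each condition's `_id` so that a condition shared by many users' rules is evaluated at most once per event. A standalone sketch of that memoisation, with an illustrative evaluate callback:

    from typing import Callable, Dict, List

    def check_conditions(
        conditions: List[dict],
        evaluate: Callable[[dict], bool],
        cache: Dict[str, bool],
    ) -> bool:
        for cond in conditions:
            _id = cond.get("_id")
            if _id is not None and _id in cache:
                if not cache[_id]:
                    return False
                continue
            res = evaluate(cond)
            if _id is not None:
                cache[_id] = res  # reused for every other user's rules
            if not res:
                return False
        return True

    cache = {}  # type: Dict[str, bool]
    check_conditions([{"_id": "c1", "kind": "event_match"}], lambda c: True, cache)
    assert cache == {"c1": True}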
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index a59b639f..0cadba76 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -14,24 +14,27 @@
# limitations under the License.
import copy
+from typing import Any, Dict, List, Optional
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
+from synapse.types import UserID
-def format_push_rules_for_user(user, ruleslist):
+def format_push_rules_for_user(user: UserID, ruleslist) -> Dict[str, Dict[str, list]]:
"""Converts a list of rawrules and a enabled map into nested dictionaries
to match the Matrix client-server format for push rules"""
# We're going to be mutating this a lot, so do a deep copy
ruleslist = copy.deepcopy(ruleslist)
- rules = {"global": {}, "device": {}}
+ rules = {
+ "global": {},
+ "device": {},
+ } # type: Dict[str, Dict[str, List[Dict[str, Any]]]]
rules["global"] = _add_empty_priority_class_arrays(rules["global"])
for r in ruleslist:
- rulearray = None
-
template_name = _priority_class_to_template_name(r["priority_class"])
# Remove internal stuff.
@@ -57,13 +60,13 @@ def format_push_rules_for_user(user, ruleslist):
return rules
-def _add_empty_priority_class_arrays(d):
+def _add_empty_priority_class_arrays(d: Dict[str, list]) -> Dict[str, list]:
for pc in PRIORITY_CLASS_MAP.keys():
d[pc] = []
return d
-def _rule_to_template(rule):
+def _rule_to_template(rule: Dict[str, Any]) -> Optional[Dict[str, Any]]:
unscoped_rule_id = None
if "rule_id" in rule:
unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"])
@@ -82,6 +85,10 @@ def _rule_to_template(rule):
return None
templaterule = {"actions": rule["actions"]}
templaterule["pattern"] = thecond["pattern"]
+ else:
+ # This can only be reached if this function has fallen out of sync
+ # with PRIORITY_CLASS_INVERSE_MAP.
+ raise ValueError("Unexpected template_name: %s" % (template_name,))
if unscoped_rule_id:
templaterule["rule_id"] = unscoped_rule_id
@@ -90,9 +97,9 @@ def _rule_to_template(rule):
return templaterule
-def _rule_id_from_namespaced(in_rule_id):
+def _rule_id_from_namespaced(in_rule_id: str) -> str:
return in_rule_id.split("/")[-1]
-def _priority_class_to_template_name(pc):
+def _priority_class_to_template_name(pc: int) -> str:
return PRIORITY_CLASS_INVERSE_MAP[pc]
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index c6763971..4ac1b317 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -14,11 +14,17 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Dict, List, Optional
+from twisted.internet.base import DelayedCall
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import RoomStreamToken
+from synapse.push import Pusher, PusherConfig, ThrottleParams
+from synapse.push.mailer import Mailer
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@@ -46,7 +52,7 @@ THROTTLE_RESET_AFTER_MS = 12 * 60 * 60 * 1000
INCLUDE_ALL_UNREAD_NOTIFS = False
-class EmailPusher:
+class EmailPusher(Pusher):
"""
A pusher that sends email notifications about events (approximately)
when they happen.
@@ -54,37 +60,30 @@ class EmailPusher:
factor out the common parts
"""
- def __init__(self, hs, pusherdict, mailer):
- self.hs = hs
+ def __init__(self, hs: "HomeServer", pusher_config: PusherConfig, mailer: Mailer):
+ super().__init__(hs, pusher_config)
self.mailer = mailer
self.store = self.hs.get_datastore()
- self.clock = self.hs.get_clock()
- self.pusher_id = pusherdict["id"]
- self.user_id = pusherdict["user_name"]
- self.app_id = pusherdict["app_id"]
- self.email = pusherdict["pushkey"]
- self.last_stream_ordering = pusherdict["last_stream_ordering"]
- self.timed_call = None
- self.throttle_params = None
-
- # See httppusher
- self.max_stream_ordering = None
+ self.email = pusher_config.pushkey
+ self.timed_call = None # type: Optional[DelayedCall]
+ self.throttle_params = {} # type: Dict[str, ThrottleParams]
+ self._inited = False
self._is_processing = False
- def on_started(self, should_check_for_notifs):
+ def on_started(self, should_check_for_notifs: bool) -> None:
"""Called when this pusher has been started.
Args:
- should_check_for_notifs (bool): Whether we should immediately
+ should_check_for_notifs: Whether we should immediately
check for push to send. Set to False only if it's known there
is nothing to send
"""
if should_check_for_notifs and self.mailer is not None:
self._start_processing()
- def on_stop(self):
+ def on_stop(self) -> None:
if self.timed_call:
try:
self.timed_call.cancel()
@@ -92,37 +91,23 @@ class EmailPusher:
pass
self.timed_call = None
- def on_new_notifications(self, max_token: RoomStreamToken):
- # We just use the minimum stream ordering and ignore the vector clock
- # component. This is safe to do as long as we *always* ignore the vector
- # clock components.
- max_stream_ordering = max_token.stream
-
- if self.max_stream_ordering:
- self.max_stream_ordering = max(
- max_stream_ordering, self.max_stream_ordering
- )
- else:
- self.max_stream_ordering = max_stream_ordering
- self._start_processing()
-
- def on_new_receipts(self, min_stream_id, max_stream_id):
+ def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
# We could wake up and cancel the timer but there tend to be quite a
# lot of read receipts so it's probably less work to just let the
# timer fire
pass
- def on_timer(self):
+ def on_timer(self) -> None:
self.timed_call = None
self._start_processing()
- def _start_processing(self):
+ def _start_processing(self) -> None:
if self._is_processing:
return
run_as_background_process("emailpush.process", self._process)
- def _pause_processing(self):
+ def _pause_processing(self) -> None:
"""Used by tests to temporarily pause processing of events.
Asserts that it's not currently processing.
@@ -130,25 +115,27 @@ class EmailPusher:
assert not self._is_processing
self._is_processing = True
- def _resume_processing(self):
+ def _resume_processing(self) -> None:
"""Used by tests to resume processing of events after pausing.
"""
assert self._is_processing
self._is_processing = False
self._start_processing()
- async def _process(self):
+ async def _process(self) -> None:
# we should never get here if we are already processing
assert not self._is_processing
try:
self._is_processing = True
- if self.throttle_params is None:
+ if not self._inited:
# this is our first loop: load up the throttle params
+ assert self.pusher_id is not None
self.throttle_params = await self.store.get_throttle_params_by_room(
self.pusher_id
)
+ self._inited = True
# if the max ordering changes while we're running _unsafe_process,
# call it again, and so on until we've caught up.
@@ -163,17 +150,18 @@ class EmailPusher:
finally:
self._is_processing = False
- async def _unsafe_process(self):
+ async def _unsafe_process(self) -> None:
"""
Main logic of the push loop without the wrapper function that sets
up logging, measures and guards against multiple instances of it
being run.
"""
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
- fn = self.store.get_unread_push_actions_for_user_in_range_for_email
- unprocessed = await fn(self.user_id, start, self.max_stream_ordering)
+ unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
+ self.user_id, start, self.max_stream_ordering
+ )
- soonest_due_at = None
+ soonest_due_at = None # type: Optional[int]
if not unprocessed:
await self.save_last_stream_ordering_and_success(self.max_stream_ordering)
@@ -230,11 +218,9 @@ class EmailPusher:
self.seconds_until(soonest_due_at), self.on_timer
)
- async def save_last_stream_ordering_and_success(self, last_stream_ordering):
- if last_stream_ordering is None:
- # This happens if we haven't yet processed anything
- return
-
+ async def save_last_stream_ordering_and_success(
+ self, last_stream_ordering: int
+ ) -> None:
self.last_stream_ordering = last_stream_ordering
pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
self.app_id,
@@ -248,28 +234,30 @@ class EmailPusher:
# lets just stop and return.
self.on_stop()
- def seconds_until(self, ts_msec):
+ def seconds_until(self, ts_msec: int) -> float:
secs = (ts_msec - self.clock.time_msec()) / 1000
return max(secs, 0)
- def get_room_throttle_ms(self, room_id):
+ def get_room_throttle_ms(self, room_id: str) -> int:
if room_id in self.throttle_params:
- return self.throttle_params[room_id]["throttle_ms"]
+ return self.throttle_params[room_id].throttle_ms
else:
return 0
- def get_room_last_sent_ts(self, room_id):
+ def get_room_last_sent_ts(self, room_id: str) -> int:
if room_id in self.throttle_params:
- return self.throttle_params[room_id]["last_sent_ts"]
+ return self.throttle_params[room_id].last_sent_ts
else:
return 0
- def room_ready_to_notify_at(self, room_id):
+ def room_ready_to_notify_at(self, room_id: str) -> int:
"""
Determines whether throttling should prevent us from sending an email
for the given room
- Returns: The timestamp when we are next allowed to send an email notif
- for this room
+
+ Returns:
+ The timestamp when we are next allowed to send an email notif
+ for this room
"""
last_sent_ts = self.get_room_last_sent_ts(room_id)
throttle_ms = self.get_room_throttle_ms(room_id)
@@ -277,7 +265,9 @@ class EmailPusher:
may_send_at = last_sent_ts + throttle_ms
return may_send_at
- async def sent_notif_update_throttle(self, room_id, notified_push_action):
+ async def sent_notif_update_throttle(
+ self, room_id: str, notified_push_action: dict
+ ) -> None:
# We have sent a notification, so update the throttle accordingly.
# If the event that triggered the notif happened more than
# THROTTLE_RESET_AFTER_MS after the previous one that triggered a
@@ -307,15 +297,15 @@ class EmailPusher:
new_throttle_ms = min(
current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS
)
- self.throttle_params[room_id] = {
- "last_sent_ts": self.clock.time_msec(),
- "throttle_ms": new_throttle_ms,
- }
+ self.throttle_params[room_id] = ThrottleParams(
+ self.clock.time_msec(), new_throttle_ms,
+ )
+ assert self.pusher_id is not None
await self.store.set_throttle_params(
self.pusher_id, room_id, self.throttle_params[room_id]
)
- async def send_notification(self, push_actions, reason):
+ async def send_notification(self, push_actions: List[dict], reason: dict) -> None:
logger.info("Sending notif email for user %r", self.user_id)
await self.mailer.send_notification_mail(
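The escalation step in sent_notif_update_throttle is simple arithmetic: multiply the current per-room delay, clamped to a ceiling. A worked sketch with assumed constants (the real THROTTLE_MULTIPLIER and THROTTLE_MAX_MS are defined earlier in this module and are not shown in this hunk):

    THROTTLE_MULTIPLIER = 6                # assumed value
    THROTTLE_MAX_MS = 24 * 60 * 60 * 1000  # assumed value: 24 h ceiling

    def next_throttle_ms(current_throttle_ms: int) -> int:
        """One escalation step: multiply the delay, never exceeding the cap."""
        return min(current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS)

    # Starting from 10 minutes: 10 min -> 1 h -> 6 h -> 36 h, clamped to 24 h.
    ms = 10 * 60 * 1000
    for _ in range(4):
        ms = next_throttle_ms(ms)
    assert ms == THROTTLE_MAX_MS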
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index eff0975b..e048b0d5 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -14,19 +14,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import urllib.parse
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Union
from prometheus_client import Counter
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.api.constants import EventTypes
+from synapse.events import EventBase
from synapse.logging import opentracing
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.push import PusherConfigException
-from synapse.types import RoomStreamToken
+from synapse.push import Pusher, PusherConfig, PusherConfigException
from . import push_rule_evaluator, push_tools
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
http_push_processed_counter = Counter(
@@ -50,91 +55,76 @@ http_badges_failed_counter = Counter(
)
-class HttpPusher:
+class HttpPusher(Pusher):
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
MAX_BACKOFF_SEC = 60 * 60
# This one's in ms because we compare it against the clock
GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000
- def __init__(self, hs, pusherdict):
- self.hs = hs
- self.store = self.hs.get_datastore()
+ def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
+ super().__init__(hs, pusher_config)
self.storage = self.hs.get_storage()
- self.clock = self.hs.get_clock()
- self.state_handler = self.hs.get_state_handler()
- self.user_id = pusherdict["user_name"]
- self.app_id = pusherdict["app_id"]
- self.app_display_name = pusherdict["app_display_name"]
- self.device_display_name = pusherdict["device_display_name"]
- self.pushkey = pusherdict["pushkey"]
- self.pushkey_ts = pusherdict["ts"]
- self.data = pusherdict["data"]
- self.last_stream_ordering = pusherdict["last_stream_ordering"]
+ self.app_display_name = pusher_config.app_display_name
+ self.device_display_name = pusher_config.device_display_name
+ self.pushkey_ts = pusher_config.ts
+ self.data = pusher_config.data
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
- self.failing_since = pusherdict["failing_since"]
+ self.failing_since = pusher_config.failing_since
self.timed_call = None
self._is_processing = False
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
- # This is the highest stream ordering we know it's safe to process.
- # When new events arrive, we'll be given a window of new events: we
- # should honour this rather than just looking for anything higher
- # because of potential out-of-order event serialisation. This starts
- # off as None though as we don't know any better.
- self.max_stream_ordering = None
-
- if "data" not in pusherdict:
- raise PusherConfigException("No 'data' key for HTTP pusher")
- self.data = pusherdict["data"]
+ self.data = pusher_config.data
+ if self.data is None:
+ raise PusherConfigException("'data' key can not be null for HTTP pusher")
self.name = "%s/%s/%s" % (
- pusherdict["user_name"],
- pusherdict["app_id"],
- pusherdict["pushkey"],
+ pusher_config.user_name,
+ pusher_config.app_id,
+ pusher_config.pushkey,
)
- if self.data is None:
- raise PusherConfigException("data can not be null for HTTP pusher")
-
+ # Validate that there's a URL and it is of the proper form.
if "url" not in self.data:
raise PusherConfigException("'url' required in data for HTTP pusher")
- self.url = self.data["url"]
- self.http_client = hs.get_proxied_http_client()
+
+ url = self.data["url"]
+ if not isinstance(url, str):
+ raise PusherConfigException("'url' must be a string")
+ url_parts = urllib.parse.urlparse(url)
+ # Note that the specification also says the scheme must be HTTPS, but
+ # it isn't up to the homeserver to verify that.
+ if url_parts.path != "/_matrix/push/v1/notify":
+ raise PusherConfigException(
+ "'url' must have a path of '/_matrix/push/v1/notify'"
+ )
+
+ self.url = url
+ self.http_client = hs.get_proxied_blacklisted_http_client()
self.data_minus_url = {}
self.data_minus_url.update(self.data)
del self.data_minus_url["url"]
- def on_started(self, should_check_for_notifs):
+ def on_started(self, should_check_for_notifs: bool) -> None:
"""Called when this pusher has been started.
Args:
- should_check_for_notifs (bool): Whether we should immediately
+ should_check_for_notifs: Whether we should immediately
check for push to send. Set to False only if it's known there
is nothing to send
"""
if should_check_for_notifs:
self._start_processing()
- def on_new_notifications(self, max_token: RoomStreamToken):
- # We just use the minimum stream ordering and ignore the vector clock
- # component. This is safe to do as long as we *always* ignore the vector
- # clock components.
- max_stream_ordering = max_token.stream
-
- self.max_stream_ordering = max(
- max_stream_ordering, self.max_stream_ordering or 0
- )
- self._start_processing()
-
- def on_new_receipts(self, min_stream_id, max_stream_id):
+ def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
# Note that the min here shouldn't be relied upon to be accurate.
# We could check the receipts are actually m.read receipts here,
# but currently that's the only type of receipt anyway...
run_as_background_process("http_pusher.on_new_receipts", self._update_badge)
- async def _update_badge(self):
+ async def _update_badge(self) -> None:
# XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
# to be largely redundant. Perhaps we can remove it.
badge = await push_tools.get_badge_count(
@@ -144,10 +134,10 @@ class HttpPusher:
)
await self._send_badge(badge)
- def on_timer(self):
+ def on_timer(self) -> None:
self._start_processing()
- def on_stop(self):
+ def on_stop(self) -> None:
if self.timed_call:
try:
self.timed_call.cancel()
@@ -155,13 +145,13 @@ class HttpPusher:
pass
self.timed_call = None
- def _start_processing(self):
+ def _start_processing(self) -> None:
if self._is_processing:
return
run_as_background_process("httppush.process", self._process)
- async def _process(self):
+ async def _process(self) -> None:
# we should never get here if we are already processing
assert not self._is_processing
@@ -180,15 +170,13 @@ class HttpPusher:
finally:
self._is_processing = False
- async def _unsafe_process(self):
+ async def _unsafe_process(self) -> None:
"""
Looks for unset notifications and dispatch them, in order
Never call this directly: use _process which will only allow this to
run once per pusher.
"""
-
- fn = self.store.get_unread_push_actions_for_user_in_range_for_http
- unprocessed = await fn(
+ unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http(
self.user_id, self.last_stream_ordering, self.max_stream_ordering
)
@@ -257,17 +245,12 @@ class HttpPusher:
)
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"]
- pusher_still_exists = await self.store.update_pusher_last_stream_ordering(
+ await self.store.update_pusher_last_stream_ordering(
self.app_id,
self.pushkey,
self.user_id,
self.last_stream_ordering,
)
- if not pusher_still_exists:
- # The pusher has been deleted while we were processing, so
- # lets just stop and return.
- self.on_stop()
- return
self.failing_since = None
await self.store.update_pusher_failing_since(
@@ -283,7 +266,7 @@ class HttpPusher:
)
break
- async def _process_one(self, push_action):
+ async def _process_one(self, push_action: dict) -> bool:
if "notify" not in push_action["actions"]:
return True
@@ -314,7 +297,9 @@ class HttpPusher:
await self.hs.remove_pusher(self.app_id, pk, self.user_id)
return True
- async def _build_notification_dict(self, event, tweaks, badge):
+ async def _build_notification_dict(
+ self, event: EventBase, tweaks: Dict[str, bool], badge: int
+ ) -> Dict[str, Any]:
priority = "low"
if (
event.type == EventTypes.Encrypted
@@ -325,6 +310,8 @@ class HttpPusher:
# or may do so (i.e. is encrypted so has unknown effects).
priority = "high"
+ # This was checked in the __init__, but mypy doesn't seem to know that.
+ assert self.data is not None
if self.data.get("format") == "event_id_only":
d = {
"notification": {
@@ -344,9 +331,7 @@ class HttpPusher:
}
return d
- ctx = await push_tools.get_context_for_event(
- self.storage, self.state_handler, event, self.user_id
- )
+ ctx = await push_tools.get_context_for_event(self.storage, event, self.user_id)
d = {
"notification": {
@@ -386,7 +371,9 @@ class HttpPusher:
return d
- async def dispatch_push(self, event, tweaks, badge):
+ async def dispatch_push(
+ self, event: EventBase, tweaks: Dict[str, bool], badge: int
+ ) -> Union[bool, Iterable[str]]:
notification_dict = await self._build_notification_dict(event, tweaks, badge)
if not notification_dict:
return []
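The new constructor validation can be exercised in isolation; this restates the check added above as a small standalone function:

    import urllib.parse

    def validate_push_url(url: str) -> None:
        # The push gateway API pins the notification endpoint path; scheme
        # enforcement is deliberately left out, as noted in the diff above.
        url_parts = urllib.parse.urlparse(url)
        if url_parts.path != "/_matrix/push/v1/notify":
            raise ValueError("'url' must have a path of '/_matrix/push/v1/notify'")

    validate_push_url("https://push.example.com/_matrix/push/v1/notify")  # passes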
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 38195c8e..4d875dcb 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -19,7 +19,7 @@ import logging
import urllib.parse
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
-from typing import Iterable, List, TypeVar
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar
import bleach
import jinja2
@@ -27,16 +27,20 @@ import jinja2
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import StoreError
from synapse.config.emailconfig import EmailSubjectConfig
+from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable
from synapse.push.presentable_names import (
calculate_room_name,
descriptor_from_member_events,
name_from_member_event,
)
-from synapse.types import UserID
+from synapse.types import StateMap, UserID
from synapse.util.async_helpers import concurrently_execute
from synapse.visibility import filter_events_for_client
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
T = TypeVar("T")
@@ -93,7 +97,13 @@ ALLOWED_ATTRS = {
class Mailer:
- def __init__(self, hs, app_name, template_html, template_text):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ app_name: str,
+ template_html: jinja2.Template,
+ template_text: jinja2.Template,
+ ):
self.hs = hs
self.template_html = template_html
self.template_text = template_text
@@ -108,17 +118,19 @@ class Mailer:
logger.info("Created Mailer for app_name %s" % app_name)
- async def send_password_reset_mail(self, email_address, token, client_secret, sid):
+ async def send_password_reset_mail(
+ self, email_address: str, token: str, client_secret: str, sid: str
+ ) -> None:
"""Send an email with a password reset link to a user
Args:
- email_address (str): Email address we're sending the password
+ email_address: Email address we're sending the password
reset to
- token (str): Unique token generated by the server to verify
+ token: Unique token generated by the server to verify
the email was received
- client_secret (str): Unique token generated by the client to
+ client_secret: Unique token generated by the client to
group together multiple email sending attempts
- sid (str): The generated session ID
+ sid: The generated session ID
"""
params = {"token": token, "client_secret": client_secret, "sid": sid}
link = (
@@ -136,17 +148,19 @@ class Mailer:
template_vars,
)
- async def send_registration_mail(self, email_address, token, client_secret, sid):
+ async def send_registration_mail(
+ self, email_address: str, token: str, client_secret: str, sid: str
+ ) -> None:
"""Send an email with a registration confirmation link to a user
Args:
- email_address (str): Email address we're sending the registration
+ email_address: Email address we're sending the registration
link to
- token (str): Unique token generated by the server to verify
+ token: Unique token generated by the server to verify
the email was received
- client_secret (str): Unique token generated by the client to
+ client_secret: Unique token generated by the client to
group together multiple email sending attempts
- sid (str): The generated session ID
+ sid: The generated session ID
"""
params = {"token": token, "client_secret": client_secret, "sid": sid}
link = (
@@ -164,18 +178,20 @@ class Mailer:
template_vars,
)
- async def send_add_threepid_mail(self, email_address, token, client_secret, sid):
+ async def send_add_threepid_mail(
+ self, email_address: str, token: str, client_secret: str, sid: str
+ ) -> None:
"""Send an email with a validation link to a user for adding a 3pid to their account
Args:
- email_address (str): Email address we're sending the validation link to
+ email_address: Email address we're sending the validation link to
- token (str): Unique token generated by the server to verify the email was received
+ token: Unique token generated by the server to verify the email was received
- client_secret (str): Unique token generated by the client to group together
+ client_secret: Unique token generated by the client to group together
multiple email sending attempts
- sid (str): The generated session ID
+ sid: The generated session ID
"""
params = {"token": token, "client_secret": client_secret, "sid": sid}
link = (
@@ -194,8 +210,13 @@ class Mailer:
)
async def send_notification_mail(
- self, app_id, user_id, email_address, push_actions, reason
- ):
+ self,
+ app_id: str,
+ user_id: str,
+ email_address: str,
+ push_actions: Iterable[Dict[str, Any]],
+ reason: Dict[str, Any],
+ ) -> None:
"""Send email regarding a user's room notifications"""
rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions])
@@ -203,7 +224,7 @@ class Mailer:
[pa["event_id"] for pa in push_actions]
)
- notifs_by_room = {}
+ notifs_by_room = {} # type: Dict[str, List[Dict[str, Any]]]
for pa in push_actions:
notifs_by_room.setdefault(pa["room_id"], []).append(pa)
@@ -262,7 +283,9 @@ class Mailer:
await self.send_email(email_address, summary_text, template_vars)
- async def send_email(self, email_address, subject, extra_template_vars):
+ async def send_email(
+ self, email_address: str, subject: str, extra_template_vars: Dict[str, Any]
+ ) -> None:
"""Send an email with the given information and template text"""
try:
from_string = self.hs.config.email_notif_from % {"app": self.app_name}
@@ -315,8 +338,13 @@ class Mailer:
)
async def get_room_vars(
- self, room_id, user_id, notifs, notif_events, room_state_ids
- ):
+ self,
+ room_id: str,
+ user_id: str,
+ notifs: Iterable[Dict[str, Any]],
+ notif_events: Dict[str, EventBase],
+ room_state_ids: StateMap[str],
+ ) -> Dict[str, Any]:
# Check if one of the notifs is an invite event for the user.
is_invite = False
for n in notifs:
@@ -334,7 +362,7 @@ class Mailer:
"notifs": [],
"invite": is_invite,
"link": self.make_room_link(room_id),
- }
+ } # type: Dict[str, Any]
if not is_invite:
for n in notifs:
@@ -365,7 +393,13 @@ class Mailer:
return room_vars
- async def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):
+ async def get_notif_vars(
+ self,
+ notif: Dict[str, Any],
+ user_id: str,
+ notif_event: EventBase,
+ room_state_ids: StateMap[str],
+ ) -> Dict[str, Any]:
results = await self.store.get_events_around(
notif["room_id"],
notif["event_id"],
@@ -391,7 +425,9 @@ class Mailer:
return ret
- async def get_message_vars(self, notif, event, room_state_ids):
+ async def get_message_vars(
+ self, notif: Dict[str, Any], event: EventBase, room_state_ids: StateMap[str]
+ ) -> Optional[Dict[str, Any]]:
if event.type != EventTypes.Message and event.type != EventTypes.Encrypted:
return None
@@ -432,7 +468,9 @@ class Mailer:
return ret
- def add_text_message_vars(self, messagevars, event):
+ def add_text_message_vars(
+ self, messagevars: Dict[str, Any], event: EventBase
+ ) -> None:
msgformat = event.content.get("format")
messagevars["format"] = msgformat
@@ -445,15 +483,22 @@ class Mailer:
elif body:
messagevars["body_text_html"] = safe_text(body)
- return messagevars
-
- def add_image_message_vars(self, messagevars, event):
- messagevars["image_url"] = event.content["url"]
-
- return messagevars
+ def add_image_message_vars(
+ self, messagevars: Dict[str, Any], event: EventBase
+ ) -> None:
+ """
+ Potentially add an image URL to the message variables.
+ """
+ if "url" in event.content:
+ messagevars["image_url"] = event.content["url"]
async def make_summary_text(
- self, notifs_by_room, room_state_ids, notif_events, user_id, reason
+ self,
+ notifs_by_room: Dict[str, List[Dict[str, Any]]],
+ room_state_ids: Dict[str, StateMap[str]],
+ notif_events: Dict[str, EventBase],
+ user_id: str,
+ reason: Dict[str, Any],
):
if len(notifs_by_room) == 1:
# Only one room has new stuff
@@ -580,7 +625,7 @@ class Mailer:
"app": self.app_name,
}
- def make_room_link(self, room_id):
+ def make_room_link(self, room_id: str) -> str:
if self.hs.config.email_riot_base_url:
base_url = "%s/#/room" % (self.hs.config.email_riot_base_url)
elif self.app_name == "Vector":
@@ -590,7 +635,7 @@ class Mailer:
base_url = "https://matrix.to/#"
return "%s/%s" % (base_url, room_id)
- def make_notif_link(self, notif):
+ def make_notif_link(self, notif: Dict[str, str]) -> str:
if self.hs.config.email_riot_base_url:
return "%s/#/room/%s/%s" % (
self.hs.config.email_riot_base_url,
@@ -606,7 +651,9 @@ class Mailer:
else:
return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"])
- def make_unsubscribe_link(self, user_id, app_id, email_address):
+ def make_unsubscribe_link(
+ self, user_id: str, app_id: str, email_address: str
+ ) -> str:
params = {
"access_token": self.macaroon_gen.generate_delete_pusher_token(user_id),
"app_id": app_id,
@@ -620,7 +667,7 @@ class Mailer:
)
-def safe_markup(raw_html):
+def safe_markup(raw_html: str) -> jinja2.Markup:
return jinja2.Markup(
bleach.linkify(
bleach.clean(
@@ -635,7 +682,7 @@ def safe_markup(raw_html):
)
-def safe_text(raw_text):
+def safe_text(raw_text: str) -> jinja2.Markup:
"""
Process text: treat it as HTML but escape any tags (i.e. just escape the
HTML) then linkify it.
@@ -655,7 +702,7 @@ def deduped_ordered_list(it: Iterable[T]) -> List[T]:
return ret
-def string_ordinal_total(s):
+def string_ordinal_total(s: str) -> int:
tot = 0
for c in s:
tot += ord(c)
diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py
index d8f4a453..7e50341d 100644
--- a/synapse/push/presentable_names.py
+++ b/synapse/push/presentable_names.py
@@ -15,8 +15,14 @@
import logging
import re
+from typing import TYPE_CHECKING, Dict, Iterable, Optional
from synapse.api.constants import EventTypes
+from synapse.events import EventBase
+from synapse.types import StateMap
+
+if TYPE_CHECKING:
+ from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@@ -28,25 +34,29 @@ ALL_ALONE = "Empty Room"
async def calculate_room_name(
- store,
- room_state_ids,
- user_id,
- fallback_to_members=True,
- fallback_to_single_member=True,
-):
+ store: "DataStore",
+ room_state_ids: StateMap[str],
+ user_id: str,
+ fallback_to_members: bool = True,
+ fallback_to_single_member: bool = True,
+) -> Optional[str]:
"""
Works out a user-facing name for the given room as per Matrix
spec recommendations.
Does not yet support internationalisation.
Args:
- room_state: Dictionary of the room's state
+ store: The data store to query.
+ room_state_ids: Dictionary of the room's state IDs.
user_id: The ID of the user to whom the room name is being presented
fallback_to_members: If False, return None instead of generating a name
based on the room's members if the room has no
title or aliases.
+ fallback_to_single_member: If False, return None instead of generating a
+ name based on the user who invited this user to the room if the room
+ has no title or aliases.
Returns:
- (string or None) A human readable name for the room.
+ A human readable name for the room, if possible.
"""
# does it have a name?
if (EventTypes.Name, "") in room_state_ids:
@@ -97,7 +107,7 @@ async def calculate_room_name(
name_from_member_event(inviter_member_event),
)
else:
- return
+ return None
else:
return "Room Invite"
@@ -150,19 +160,19 @@ async def calculate_room_name(
else:
return ALL_ALONE
elif len(other_members) == 1 and not fallback_to_single_member:
- return
- else:
- return descriptor_from_member_events(other_members)
+ return None
+
+ return descriptor_from_member_events(other_members)
-def descriptor_from_member_events(member_events):
+def descriptor_from_member_events(member_events: Iterable[EventBase]) -> str:
"""Get a description of the room based on the member events.
Args:
- member_events (Iterable[FrozenEvent])
+ member_events: The events of a room.
Returns:
- str
+ The room description
"""
member_events = list(member_events)
@@ -183,7 +193,7 @@ def descriptor_from_member_events(member_events):
)
-def name_from_member_event(member_event):
+def name_from_member_event(member_event: EventBase) -> str:
if (
member_event.content
and "displayname" in member_event.content
@@ -193,12 +203,12 @@ def name_from_member_event(member_event):
return member_event.state_key
-def _state_as_two_level_dict(state):
- ret = {}
+def _state_as_two_level_dict(state: StateMap[str]) -> Dict[str, Dict[str, str]]:
+ ret = {} # type: Dict[str, Dict[str, str]]
for k, v in state.items():
ret.setdefault(k[0], {})[k[1]] = v
return ret
-def _looks_like_an_alias(string):
+def _looks_like_an_alias(string: str) -> bool:
return ALIAS_RE.match(string) is not None
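The `StateMap[str]` alias used throughout these annotations maps `(event type, state key)` tuples to event IDs, which is exactly the shape `_state_as_two_level_dict` unpacks. A minimal sketch of the conversion, with made-up event IDs for illustration:

    from typing import Dict, Tuple

    StateMap = Dict[Tuple[str, str], str]  # (event type, state key) -> event ID

    state = {
        ("m.room.name", ""): "$name_event",
        ("m.room.member", "@alice:example.com"): "$alice_member",
        ("m.room.member", "@bob:example.com"): "$bob_member",
    }  # type: StateMap

    two_level = {}  # type: Dict[str, Dict[str, str]]
    for (etype, state_key), event_id in state.items():
        two_level.setdefault(etype, {})[state_key] = event_id

    # two_level["m.room.member"] now maps each user ID to their membership
    # event ID: {"@alice:example.com": "$alice_member", ...}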
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 2ce9e444..ba1877ad 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -30,22 +30,30 @@ IS_GLOB = re.compile(r"[\?\*\[\]]")
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
-def _room_member_count(ev, condition, room_member_count):
+def _room_member_count(
+ ev: EventBase, condition: Dict[str, Any], room_member_count: int
+) -> bool:
return _test_ineq_condition(condition, room_member_count)
-def _sender_notification_permission(ev, condition, sender_power_level, power_levels):
+def _sender_notification_permission(
+ ev: EventBase,
+ condition: Dict[str, Any],
+ sender_power_level: int,
+ power_levels: Dict[str, Union[int, Dict[str, int]]],
+) -> bool:
notif_level_key = condition.get("key")
if notif_level_key is None:
return False
notif_levels = power_levels.get("notifications", {})
+ assert isinstance(notif_levels, dict)
room_notif_level = notif_levels.get(notif_level_key, 50)
return sender_power_level >= room_notif_level
-def _test_ineq_condition(condition, number):
+def _test_ineq_condition(condition: Dict[str, Any], number: int) -> bool:
if "is" not in condition:
return False
m = INEQUALITY_EXPR.match(condition["is"])
@@ -110,7 +118,7 @@ class PushRuleEvaluatorForEvent:
event: EventBase,
room_member_count: int,
sender_power_level: int,
- power_levels: dict,
+ power_levels: Dict[str, Union[int, Dict[str, int]]],
):
self._event = event
self._room_member_count = room_member_count
@@ -120,7 +128,9 @@ class PushRuleEvaluatorForEvent:
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event)
- def matches(self, condition: dict, user_id: str, display_name: str) -> bool:
+ def matches(
+ self, condition: Dict[str, Any], user_id: str, display_name: str
+ ) -> bool:
if condition["kind"] == "event_match":
return self._event_match(condition, user_id)
elif condition["kind"] == "contains_display_name":
@@ -261,7 +271,13 @@ def _re_word_boundary(r: str) -> str:
return r"(^|\W)%s(\W|$)" % (r,)
-def _flatten_dict(d, prefix=[], result=None):
+def _flatten_dict(
+ d: Union[EventBase, dict],
+ prefix: Optional[List[str]] = None,
+ result: Optional[Dict[str, str]] = None,
+) -> Dict[str, str]:
+ if prefix is None:
+ prefix = []
if result is None:
result = {}
for key, value in d.items():
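The change from `prefix=[]` to `prefix: Optional[List[str]] = None` avoids Python's shared-mutable-default pitfall: a default list is created once, at function definition time, so mutations leak across calls. (`_flatten_dict` itself only concatenates with `prefix + [key]`, so the old default was a latent rather than active bug, but the pattern is worth removing.) A self-contained illustration of the hazard and the fix:

    from typing import List, Optional

    def bad_append(item: str, acc: List[str] = []) -> List[str]:
        # The same list object is reused by every call that relies on the
        # default, so results accumulate across calls.
        acc.append(item)
        return acc

    def good_append(item: str, acc: Optional[List[str]] = None) -> List[str]:
        if acc is None:
            acc = []  # a fresh list per call
        acc.append(item)
        return acc

    assert bad_append("a") == ["a"]
    assert bad_append("b") == ["a", "b"]   # surprise: state leaked
    assert good_append("a") == ["a"]
    assert good_append("b") == ["b"]       # independent calls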
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 6e7c880d..df341032 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -12,6 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Dict
+
+from synapse.events import EventBase
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage
from synapse.storage.databases.main import DataStore
@@ -46,7 +49,9 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
return badge
-async def get_context_for_event(storage: Storage, state_handler, ev, user_id):
+async def get_context_for_event(
+ storage: Storage, ev: EventBase, user_id: str
+) -> Dict[str, str]:
ctx = {}
room_state_ids = await storage.state.get_state_ids_for_event(ev.event_id)
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index 2a52e226..2aa7918f 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -14,25 +14,31 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Callable, Dict, Optional
+from synapse.push import Pusher, PusherConfig
from synapse.push.emailpusher import EmailPusher
+from synapse.push.httppusher import HttpPusher
from synapse.push.mailer import Mailer
-from .httppusher import HttpPusher
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
class PusherFactory:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.config = hs.config
- self.pusher_types = {"http": HttpPusher}
+ self.pusher_types = {
+ "http": HttpPusher
+ } # type: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]]
logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
if hs.config.email_enable_notifs:
- self.mailers = {} # app_name -> Mailer
+ self.mailers = {} # type: Dict[str, Mailer]
self._notif_template_html = hs.config.email_notif_template_html
self._notif_template_text = hs.config.email_notif_template_text
@@ -41,16 +47,18 @@ class PusherFactory:
logger.info("defined email pusher type")
- def create_pusher(self, pusherdict):
- kind = pusherdict["kind"]
+ def create_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]:
+ kind = pusher_config.kind
f = self.pusher_types.get(kind, None)
if not f:
return None
- logger.debug("creating %s pusher for %r", kind, pusherdict)
- return f(self.hs, pusherdict)
+ logger.debug("creating %s pusher for %r", kind, pusher_config)
+ return f(self.hs, pusher_config)
- def _create_email_pusher(self, _hs, pusherdict):
- app_name = self._app_name_from_pusherdict(pusherdict)
+ def _create_email_pusher(
+ self, _hs: "HomeServer", pusher_config: PusherConfig
+ ) -> EmailPusher:
+ app_name = self._app_name_from_pusherdict(pusher_config)
mailer = self.mailers.get(app_name)
if not mailer:
mailer = Mailer(
@@ -60,10 +68,10 @@ class PusherFactory:
template_text=self._notif_template_text,
)
self.mailers[app_name] = mailer
- return EmailPusher(self.hs, pusherdict, mailer)
+ return EmailPusher(self.hs, pusher_config, mailer)
- def _app_name_from_pusherdict(self, pusherdict):
- data = pusherdict["data"]
+ def _app_name_from_pusherdict(self, pusher_config: PusherConfig) -> str:
+ data = pusher_config.data
if isinstance(data, dict):
brand = data.get("brand")
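The new `pusher_types` annotation makes PusherFactory an explicitly typed registry: each value is a callable taking the homeserver and a PusherConfig and returning a Pusher, and plain classes satisfy that type because calling a class invokes its constructor. A stripped-down sketch of the same dispatch pattern (the names here are illustrative, not Synapse's):

    from typing import Callable, Dict, Optional

    class Config: ...
    class Server: ...

    class Pusher:
        def __init__(self, hs: Server, config: Config):
            self.hs = hs
            self.config = config

    class HttpPusher(Pusher): ...
    class EmailPusher(Pusher): ...

    # kind -> factory callable; classes double as factories here.
    PUSHER_TYPES = {
        "http": HttpPusher,
        "email": EmailPusher,
    }  # type: Dict[str, Callable[[Server, Config], Pusher]]

    def create_pusher(hs: Server, kind: str, config: Config) -> Optional[Pusher]:
        factory = PUSHER_TYPES.get(kind)
        return factory(hs, config) if factory else None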
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index f3259649..eed16dbf 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, Union
+from typing import TYPE_CHECKING, Dict, Iterable, Optional
from prometheus_client import Gauge
@@ -23,11 +23,9 @@ from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
-from synapse.push import PusherConfigException
-from synapse.push.emailpusher import EmailPusher
-from synapse.push.httppusher import HttpPusher
+from synapse.push import Pusher, PusherConfig, PusherConfigException
from synapse.push.pusher import PusherFactory
-from synapse.types import RoomStreamToken
+from synapse.types import JsonDict, RoomStreamToken
from synapse.util.async_helpers import concurrently_execute
if TYPE_CHECKING:
@@ -77,9 +75,9 @@ class PusherPool:
self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering()
# map from user id to app_id:pushkey to pusher
- self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]]
+ self.pushers = {} # type: Dict[str, Dict[str, Pusher]]
- def start(self):
+ def start(self) -> None:
"""Starts the pushers off in a background process.
"""
if not self._should_start_pushers:
@@ -89,52 +87,53 @@ class PusherPool:
async def add_pusher(
self,
- user_id,
- access_token,
- kind,
- app_id,
- app_display_name,
- device_display_name,
- pushkey,
- lang,
- data,
- profile_tag="",
- ):
+ user_id: str,
+ access_token: Optional[int],
+ kind: str,
+ app_id: str,
+ app_display_name: str,
+ device_display_name: str,
+ pushkey: str,
+ lang: Optional[str],
+ data: JsonDict,
+ profile_tag: str = "",
+ ) -> Optional[Pusher]:
"""Creates a new pusher and adds it to the pool
Returns:
- EmailPusher|HttpPusher
+ The newly created pusher.
"""
time_now_msec = self.clock.time_msec()
+ # create the pusher setting last_stream_ordering to the current maximum
+ # stream ordering, so it will process pushes from this point onwards.
+ last_stream_ordering = self.store.get_room_max_stream_ordering()
+
# we try to create the pusher just to validate the config: it
# will then get pulled out of the database,
# recreated, added and started: this means we have only one
# code path adding pushers.
self.pusher_factory.create_pusher(
- {
- "id": None,
- "user_name": user_id,
- "kind": kind,
- "app_id": app_id,
- "app_display_name": app_display_name,
- "device_display_name": device_display_name,
- "pushkey": pushkey,
- "ts": time_now_msec,
- "lang": lang,
- "data": data,
- "last_stream_ordering": None,
- "last_success": None,
- "failing_since": None,
- }
+ PusherConfig(
+ id=None,
+ user_name=user_id,
+ access_token=access_token,
+ profile_tag=profile_tag,
+ kind=kind,
+ app_id=app_id,
+ app_display_name=app_display_name,
+ device_display_name=device_display_name,
+ pushkey=pushkey,
+ ts=time_now_msec,
+ lang=lang,
+ data=data,
+ last_stream_ordering=last_stream_ordering,
+ last_success=None,
+ failing_since=None,
+ )
)
- # create the pusher setting last_stream_ordering to the current maximum
- # stream ordering in event_push_actions, so it will process
- # pushes from this point onwards.
- last_stream_ordering = await self.store.get_latest_push_action_stream_ordering()
-
await self.store.add_pusher(
user_id=user_id,
access_token=access_token,
@@ -154,43 +153,44 @@ class PusherPool:
return pusher
async def remove_pushers_by_app_id_and_pushkey_not_user(
- self, app_id, pushkey, not_user_id
- ):
+ self, app_id: str, pushkey: str, not_user_id: str
+ ) -> None:
to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
for p in to_remove:
- if p["user_name"] != not_user_id:
+ if p.user_name != not_user_id:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
app_id,
pushkey,
- p["user_name"],
+ p.user_name,
)
- await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
+ await self.remove_pusher(p.app_id, p.pushkey, p.user_name)
- async def remove_pushers_by_access_token(self, user_id, access_tokens):
+ async def remove_pushers_by_access_token(
+ self, user_id: str, access_tokens: Iterable[int]
+ ) -> None:
"""Remove the pushers for a given user corresponding to a set of
access_tokens.
Args:
- user_id (str): user to remove pushers for
- access_tokens (Iterable[int]): access token *ids* to remove pushers
- for
+ user_id: user to remove pushers for
+ access_tokens: access token *ids* to remove pushers for
"""
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
return
tokens = set(access_tokens)
for p in await self.store.get_pushers_by_user_id(user_id):
- if p["access_token"] in tokens:
+ if p.access_token in tokens:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
- p["app_id"],
- p["pushkey"],
- p["user_name"],
+ p.app_id,
+ p.pushkey,
+ p.user_name,
)
- await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
+ await self.remove_pusher(p.app_id, p.pushkey, p.user_name)
- def on_new_notifications(self, max_token: RoomStreamToken):
+ def on_new_notifications(self, max_token: RoomStreamToken) -> None:
if not self.pushers:
# nothing to do here.
return
@@ -209,7 +209,7 @@ class PusherPool:
self._on_new_notifications(max_token)
@wrap_as_background_process("on_new_notifications")
- async def _on_new_notifications(self, max_token: RoomStreamToken):
+ async def _on_new_notifications(self, max_token: RoomStreamToken) -> None:
# We just use the minimum stream ordering and ignore the vector clock
# component. This is safe to do as long as we *always* ignore the vector
# clock components.
@@ -239,7 +239,9 @@ class PusherPool:
except Exception:
logger.exception("Exception in pusher on_new_notifications")
- async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
+ async def on_new_receipts(
+ self, min_stream_id: int, max_stream_id: int, affected_room_ids: Iterable[str]
+ ) -> None:
if not self.pushers:
# nothing to do here.
return
@@ -267,28 +269,30 @@ class PusherPool:
except Exception:
logger.exception("Exception in pusher on_new_receipts")
- async def start_pusher_by_id(self, app_id, pushkey, user_id):
+ async def start_pusher_by_id(
+ self, app_id: str, pushkey: str, user_id: str
+ ) -> Optional[Pusher]:
"""Look up the details for the given pusher, and start it
Returns:
- EmailPusher|HttpPusher|None: The pusher started, if any
+ The pusher started, if any
"""
if not self._should_start_pushers:
- return
+ return None
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
- return
+ return None
resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
- pusher_dict = None
+ pusher_config = None
for r in resultlist:
- if r["user_name"] == user_id:
- pusher_dict = r
+ if r.user_name == user_id:
+ pusher_config = r
pusher = None
- if pusher_dict:
- pusher = await self._start_pusher(pusher_dict)
+ if pusher_config:
+ pusher = await self._start_pusher(pusher_config)
return pusher
@@ -303,44 +307,44 @@ class PusherPool:
logger.info("Started pushers")
- async def _start_pusher(self, pusherdict):
+ async def _start_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]:
"""Start the given pusher
Args:
- pusherdict (dict): dict with the values pulled from the db table
+ pusher_config: The pusher configuration with the values pulled from the db table
Returns:
- EmailPusher|HttpPusher
+ The newly created pusher or None.
"""
if not self._pusher_shard_config.should_handle(
- self._instance_name, pusherdict["user_name"]
+ self._instance_name, pusher_config.user_name
):
- return
+ return None
try:
- p = self.pusher_factory.create_pusher(pusherdict)
+ p = self.pusher_factory.create_pusher(pusher_config)
except PusherConfigException as e:
logger.warning(
"Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s",
- pusherdict["id"],
- pusherdict.get("user_name"),
- pusherdict.get("app_id"),
- pusherdict.get("pushkey"),
+ pusher_config.id,
+ pusher_config.user_name,
+ pusher_config.app_id,
+ pusher_config.pushkey,
e,
)
- return
+ return None
except Exception:
logger.exception(
- "Couldn't start pusher id %i: caught Exception", pusherdict["id"],
+ "Couldn't start pusher id %i: caught Exception", pusher_config.id,
)
- return
+ return None
if not p:
- return
+ return None
- appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])
+ appid_pushkey = "%s:%s" % (pusher_config.app_id, pusher_config.pushkey)
- byuser = self.pushers.setdefault(pusherdict["user_name"], {})
+ byuser = self.pushers.setdefault(pusher_config.user_name, {})
if appid_pushkey in byuser:
byuser[appid_pushkey].on_stop()
byuser[appid_pushkey] = p
@@ -350,8 +354,8 @@ class PusherPool:
# Check if there *may* be push to process. We do this as this check is a
# lot cheaper to do than actually fetching the exact rows we need to
# push.
- user_id = pusherdict["user_name"]
- last_stream_ordering = pusherdict["last_stream_ordering"]
+ user_id = pusher_config.user_name
+ last_stream_ordering = pusher_config.last_stream_ordering
if last_stream_ordering:
have_notifs = await self.store.get_if_maybe_push_in_range_for_user(
user_id, last_stream_ordering
@@ -365,7 +369,7 @@ class PusherPool:
return p
- async def remove_pusher(self, app_id, pushkey, user_id):
+ async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
appid_pushkey = "%s:%s" % (app_id, pushkey)
byuser = self.pushers.get(user_id, {})
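Most of this hunk is mechanical: pusher rows that used to be plain dicts (`p["user_name"]`) are now `PusherConfig` objects (`p.user_name`). A rough sketch of the shape implied by the call sites above, using a dataclass as a stand-in for the real definition in synapse/push/__init__.py (which is not shown in this excerpt):

    from dataclasses import asdict, dataclass
    from typing import Any, Dict, Optional

    @dataclass
    class PusherConfigSketch:
        id: Optional[int]
        user_name: str
        app_id: str
        pushkey: str
        last_stream_ordering: Optional[int] = None

        def as_dict(self) -> Dict[str, Any]:
            # Mirrors the p.as_dict() calls used by the REST servlets
            # later in this diff.
            return asdict(self)

    p = PusherConfigSketch(
        id=1, user_name="@alice:example.com", app_id="m.http", pushkey="abc123"
    )
    assert p.user_name == "@alice:example.com"  # attribute access, not p["user_name"]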
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 2b3972cb..1492ac92 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -106,6 +106,25 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
assert self.METHOD in ("PUT", "POST", "GET")
+ self._replication_secret = None
+ if hs.config.worker.worker_replication_secret:
+ self._replication_secret = hs.config.worker.worker_replication_secret
+
+ def _check_auth(self, request) -> None:
+ # Get the authorization header.
+ auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+
+ if not auth_headers:
+ raise RuntimeError("Missing Authorization header.")
+ if len(auth_headers) > 1:
+ raise RuntimeError("Too many Authorization headers.")
+ parts = auth_headers[0].split(b" ")
+ if parts[0] == b"Bearer" and len(parts) == 2:
+ received_secret = parts[1].decode("ascii")
+ if self._replication_secret == received_secret:
+ # Success!
+ return
+
+ raise RuntimeError("Invalid Authorization header.")
+
@abc.abstractmethod
async def _serialize_payload(**kwargs):
"""Static method that is called when creating a request.
@@ -150,6 +169,12 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME)
+ replication_secret = None
+ if hs.config.worker.worker_replication_secret:
+ replication_secret = hs.config.worker.worker_replication_secret.encode(
+ "ascii"
+ )
+
@trace(opname="outgoing_replication_request")
@outgoing_gauge.track_inprogress()
async def send_request(instance_name="master", **kwargs):
@@ -202,6 +227,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# the master, and so whether we should clean up or not.
while True:
headers = {} # type: Dict[bytes, List[bytes]]
+ # Add an authorization header, if configured.
+ if replication_secret:
+ headers[b"Authorization"] = [b"Bearer " + replication_secret]
inject_active_span_byte_dict(headers, None, check_destination=False)
try:
result = await request_func(uri, data, headers=headers)
@@ -236,21 +264,19 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
"""
url_args = list(self.PATH_ARGS)
- handler = self._handle_request
method = self.METHOD
if self.CACHE:
- handler = self._cached_handler # type: ignore
url_args.append("txn_id")
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
http_server.register_paths(
- method, [pattern], handler, self.__class__.__name__,
+ method, [pattern], self._check_auth_and_handle, self.__class__.__name__,
)
- def _cached_handler(self, request, txn_id, **kwargs):
+ def _check_auth_and_handle(self, request, **kwargs):
"""Called on new incoming requests when caching is enabled. Checks
if there is a cached response for the request and returns that,
otherwise calls `_handle_request` and caches its response.
@@ -258,6 +284,15 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# We just use the txn_id here, but we probably also want to use the
# other PATH_ARGS as well.
- assert self.CACHE
+ # Check the authorization headers before handling the request.
+ if self._replication_secret:
+ self._check_auth(request)
+
+ if self.CACHE:
+ txn_id = kwargs.pop("txn_id")
+
+ return self.response_cache.wrap(
+ txn_id, self._handle_request, request, **kwargs
+ )
- return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs)
+ return self._handle_request(request, **kwargs)
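Taken together, the two halves of this hunk implement a shared-secret handshake: outbound replication requests attach `Authorization: Bearer <worker_replication_secret>`, and `_check_auth` verifies it on receipt. A minimal round-trip sketch of just the header handling (HTTP plumbing elided; note that a constant-time comparison such as `hmac.compare_digest` is generally preferable to `==` for secrets, though the code above uses plain equality):

    import hmac
    from typing import Dict, List

    SECRET = "some-long-random-string"  # hypothetical worker_replication_secret

    def build_headers(secret: str) -> Dict[bytes, List[bytes]]:
        return {b"Authorization": [b"Bearer " + secret.encode("ascii")]}

    def check_auth(headers: Dict[bytes, List[bytes]], expected_secret: str) -> None:
        auth_headers = headers.get(b"Authorization")
        if not auth_headers:
            raise RuntimeError("Missing Authorization header.")
        if len(auth_headers) > 1:
            raise RuntimeError("Too many Authorization headers.")
        parts = auth_headers[0].split(b" ")
        if parts[0] == b"Bearer" and len(parts) == 2:
            received = parts[1].decode("ascii")
            if hmac.compare_digest(received, expected_secret):
                return  # success
        raise RuntimeError("Invalid Authorization header.")

    check_auth(build_headers(SECRET), SECRET)  # passes silently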
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 4c81e2d7..36071feb 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -36,7 +36,9 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler()
@staticmethod
- async def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
+ async def _serialize_payload(
+ user_id, device_id, initial_display_name, is_guest, is_appservice_ghost
+ ):
"""
Args:
device_id (str|None): Device ID to use, if None a new one is
@@ -48,6 +50,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
"device_id": device_id,
"initial_display_name": initial_display_name,
"is_guest": is_guest,
+ "is_appservice_ghost": is_appservice_ghost,
}
async def _handle_request(self, request, user_id):
@@ -56,9 +59,14 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
device_id = content["device_id"]
initial_display_name = content["initial_display_name"]
is_guest = content["is_guest"]
+ is_appservice_ghost = content["is_appservice_ghost"]
device_id, access_token = await self.registration_handler.register_device(
- user_id, device_id, initial_display_name, is_guest
+ user_id,
+ device_id,
+ initial_display_name,
+ is_guest,
+ is_appservice_ghost=is_appservice_ghost,
)
return 200, {"device_id": device_id, "access_token": access_token}
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index eb74903d..0d39a93e 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -12,21 +12,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Optional, Tuple
+from synapse.storage.types import Connection
from synapse.storage.util.id_generators import _load_current_id
class SlavedIdTracker:
- def __init__(self, db_conn, table, column, extra_tables=[], step=1):
+ def __init__(
+ self,
+ db_conn: Connection,
+ table: str,
+ column: str,
+ extra_tables: Optional[List[Tuple[str, str]]] = None,
+ step: int = 1,
+ ):
self.step = step
self._current = _load_current_id(db_conn, table, column, step)
- for table, column in extra_tables:
- self.advance(None, _load_current_id(db_conn, table, column))
+ if extra_tables:
+ for table, column in extra_tables:
+ self.advance(None, _load_current_id(db_conn, table, column))
- def advance(self, instance_name, new_id):
+ def advance(self, instance_name: Optional[str], new_id: int):
self._current = (max if self.step > 0 else min)(self._current, new_id)
- def get_current_token(self):
+ def get_current_token(self) -> int:
"""
Returns:
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index c418730b..045bd014 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -13,26 +13,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
from synapse.replication.tcp.streams import PushersStream
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.pusher import PusherWorkerStore
+from synapse.storage.types import Connection
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
- self._pushers_id_gen = SlavedIdTracker(
+ self._pushers_id_gen = SlavedIdTracker( # type: ignore
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
- def get_pushers_stream_token(self):
+ def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self, stream_name: str, instance_name: str, token, rows
+ ) -> None:
if stream_name == PushersStream.NAME:
- self._pushers_id_gen.advance(instance_name, token)
+ self._pushers_id_gen.advance(instance_name, token) # type: ignore
return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index a509e599..804da994 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -172,8 +172,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus.
ctx_name = "replication-conn-%s" % self.conn_id
- self._logging_context = BackgroundProcessLoggingContext(ctx_name)
- self._logging_context.request = ctx_name
+ self._logging_context = BackgroundProcessLoggingContext(ctx_name, ctx_name)
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
diff --git a/synapse/res/templates/notif.html b/synapse/res/templates/notif.html
index 6d76064d..0aaef97d 100644
--- a/synapse/res/templates/notif.html
+++ b/synapse/res/templates/notif.html
@@ -29,7 +29,7 @@
{{ message.body_text_html }}
{%- elif message.msgtype == "m.notice" %}
{{ message.body_text_html }}
- {%- elif message.msgtype == "m.image" %}
+ {%- elif message.msgtype == "m.image" and message.image_url %}
<img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
{%- elif message.msgtype == "m.file" %}
<span class="filename">{{ message.body_text_plain }}</span>
diff --git a/synapse/res/username_picker/index.html b/synapse/res/username_picker/index.html
new file mode 100644
index 00000000..37ea8bb6
--- /dev/null
+++ b/synapse/res/username_picker/index.html
@@ -0,0 +1,19 @@
+<!DOCTYPE html>
+<html lang="en">
+ <head>
+ <title>Synapse Login</title>
+ <link rel="stylesheet" href="style.css" type="text/css" />
+ </head>
+ <body>
+ <div class="card">
+ <form method="post" class="form__input" id="form" action="submit">
+ <label for="field-username">Please pick your username:</label>
+ <input type="text" name="username" id="field-username" autofocus="">
+ <input type="submit" class="button button--full-width" id="button-submit" value="Submit">
+ </form>
+ <!-- this is used for feedback -->
+ <div role=alert class="tooltip hidden" id="message"></div>
+ <script src="script.js"></script>
+ </div>
+ </body>
+</html>
diff --git a/synapse/res/username_picker/script.js b/synapse/res/username_picker/script.js
new file mode 100644
index 00000000..416a7c6f
--- /dev/null
+++ b/synapse/res/username_picker/script.js
@@ -0,0 +1,95 @@
+let inputField = document.getElementById("field-username");
+let inputForm = document.getElementById("form");
+let submitButton = document.getElementById("button-submit");
+let message = document.getElementById("message");
+
+// Submit username and receive response
+function showMessage(messageText) {
+ // Unhide the message text
+ message.classList.remove("hidden");
+
+ message.textContent = messageText;
+};
+
+function doSubmit() {
+ showMessage("Success. Please wait a moment for your browser to redirect.");
+
+ // remove the event handler before re-submitting the form.
+ delete inputForm.onsubmit;
+ inputForm.submit();
+}
+
+function onResponse(response) {
+ // Display message
+ showMessage(response);
+
+ // Enable submit button and input field
+ submitButton.classList.remove('button--disabled');
+ submitButton.value = "Submit";
+};
+
+let allowedUsernameCharacters = RegExp("[^a-z0-9\\.\\_\\=\\-\\/]");
+function usernameIsValid(username) {
+ return !allowedUsernameCharacters.test(username);
+}
+let allowedCharactersString = "lowercase letters, digits, ., _, -, /, =";
+
+function buildQueryString(params) {
+ return Object.keys(params)
+ .map(k => encodeURIComponent(k) + '=' + encodeURIComponent(params[k]))
+ .join('&');
+}
+
+function submitUsername(username) {
+ if(username.length == 0) {
+ onResponse("Please enter a username.");
+ return;
+ }
+ if(!usernameIsValid(username)) {
+ onResponse("Invalid username. Only the following characters are allowed: " + allowedCharactersString);
+ return;
+ }
+
+ // if this browser doesn't support fetch, skip the availability check.
+ if(!window.fetch) {
+ doSubmit();
+ return;
+ }
+
+ let check_uri = 'check?' + buildQueryString({"username": username});
+ fetch(check_uri, {
+ // include the cookie
+ "credentials": "same-origin",
+ }).then((response) => {
+ if(!response.ok) {
+ // for non-200 responses, raise the body of the response as an exception
+ return response.text().then((text) => { throw text; });
+ } else {
+ return response.json();
+ }
+ }).then((json) => {
+ if(json.error) {
+ throw json.error;
+ } else if(json.available) {
+ doSubmit();
+ } else {
+ onResponse("This username is not available, please choose another.");
+ }
+ }).catch((err) => {
+ onResponse("Error checking username availability: " + err);
+ });
+}
+
+function clickSubmit() {
+ event.preventDefault();
+ if(submitButton.classList.contains('button--disabled')) { return; }
+
+ // Disable submit button and input field
+ submitButton.classList.add('button--disabled');
+
+ // Submit username
+ submitButton.value = "Checking...";
+ submitUsername(inputField.value);
+};
+
+inputForm.onsubmit = clickSubmit;
diff --git a/synapse/res/username_picker/style.css b/synapse/res/username_picker/style.css
new file mode 100644
index 00000000..745bd4c6
--- /dev/null
+++ b/synapse/res/username_picker/style.css
@@ -0,0 +1,27 @@
+input[type="text"] {
+ font-size: 100%;
+ background-color: #ededf0;
+ border: 1px solid #fff;
+ border-radius: .2em;
+ padding: .5em .9em;
+ display: block;
+ width: 26em;
+}
+
+.button--disabled {
+ border-color: #fff;
+ background-color: transparent;
+ color: #000;
+ text-transform: none;
+}
+
+.hidden {
+ display: none;
+}
+
+.tooltip {
+ background-color: #f9f9fa;
+ padding: 1em;
+ margin: 1em 0;
+}
+
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 55ddebb4..6f7dc065 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -38,6 +38,7 @@ from synapse.rest.admin.rooms import (
DeleteRoomRestServlet,
JoinRoomAliasServlet,
ListRoomRestServlet,
+ MakeRoomAdminRestServlet,
RoomMembersRestServlet,
RoomRestServlet,
ShutdownRoomRestServlet,
@@ -228,6 +229,7 @@ def register_servlets(hs, http_server):
EventReportDetailRestServlet(hs).register(http_server)
EventReportsRestServlet(hs).register(http_server)
PushersRestServlet(hs).register(http_server)
+ MakeRoomAdminRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 25f89e46..ab7cc910 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -14,10 +14,10 @@
# limitations under the License.
import logging
from http import HTTPStatus
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional, Tuple
-from synapse.api.constants import EventTypes, JoinRules
-from synapse.api.errors import Codes, NotFoundError, SynapseError
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -25,13 +25,18 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
)
from synapse.storage.databases.main.room import RoomSortOrder
-from synapse.types import RoomAlias, RoomID, UserID, create_requester
+from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -45,12 +50,14 @@ class ShutdownRoomRestServlet(RestServlet):
PATTERNS = admin_patterns("/shutdown_room/(?P<room_id>[^/]+)")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.room_shutdown_handler = hs.get_room_shutdown_handler()
- async def on_POST(self, request, room_id):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -86,13 +93,15 @@ class DeleteRoomRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.room_shutdown_handler = hs.get_room_shutdown_handler()
self.pagination_handler = hs.get_pagination_handler()
- async def on_POST(self, request, room_id):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -146,12 +155,12 @@ class ListRoomRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -236,19 +245,24 @@ class RoomRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
ret = await self.store.get_room_with_stats(room_id)
if not ret:
raise NotFoundError("Room not found")
- return 200, ret
+ members = await self.store.get_users_in_room(room_id)
+ ret["joined_local_devices"] = await self.store.count_devices_by_users(members)
+
+ return (200, ret)
class RoomMembersRestServlet(RestServlet):
@@ -258,12 +272,14 @@ class RoomMembersRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
ret = await self.store.get_room(room_id)
@@ -280,14 +296,16 @@ class JoinRoomAliasServlet(RestServlet):
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.room_member_handler = hs.get_room_member_handler()
self.admin_handler = hs.get_admin_handler()
self.state_handler = hs.get_state_handler()
- async def on_POST(self, request, room_identifier):
+ async def on_POST(
+ self, request: SynapseRequest, room_identifier: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -314,7 +332,6 @@ class JoinRoomAliasServlet(RestServlet):
handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier)
room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
- room_id = room_id.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
@@ -351,3 +368,134 @@ class JoinRoomAliasServlet(RestServlet):
)
return 200, {"room_id": room_id}
+
+
+class MakeRoomAdminRestServlet(RestServlet):
+ """Allows a server admin to get power in a room if a local user has power in
+ a room. Will also invite the user if they're not in the room and it's a
+ private room. Can specify another user (rather than the admin user) to be
+ granted power, e.g.:
+
+ POST/_synapse/admin/v1/rooms/<room_id_or_alias>/make_room_admin
+ {
+ "user_id": "@foo:example.com"
+ }
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.room_member_handler = hs.get_room_member_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.state_handler = hs.get_state_handler()
+ self.is_mine_id = hs.is_mine_id
+
+ async def on_POST(self, request, room_identifier):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+ content = parse_json_object_from_request(request, allow_empty_body=True)
+
+ # Resolve to a room ID, if necessary.
+ if RoomID.is_valid(room_identifier):
+ room_id = room_identifier
+ elif RoomAlias.is_valid(room_identifier):
+ room_alias = RoomAlias.from_string(room_identifier)
+ room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
+ room_id = room_id.to_string()
+ else:
+ raise SynapseError(
+ 400, "%s was not legal room ID or room alias" % (room_identifier,)
+ )
+
+ # Which user to grant room admin rights to.
+ user_to_add = content.get("user_id", requester.user.to_string())
+
+ # Figure out which local users currently have power in the room, if any.
+ room_state = await self.state_handler.get_current_state(room_id)
+ if not room_state:
+ raise SynapseError(400, "Server not in room")
+
+ create_event = room_state[(EventTypes.Create, "")]
+ power_levels = room_state.get((EventTypes.PowerLevels, ""))
+
+ if power_levels is not None:
+ # We pick the local user with the highest power.
+ user_power = power_levels.content.get("users", {})
+ admin_users = [
+ user_id for user_id in user_power if self.is_mine_id(user_id)
+ ]
+ admin_users.sort(key=lambda user: user_power[user])
+
+ if not admin_users:
+ raise SynapseError(400, "No local admin user in room")
+
+ admin_user_id = admin_users[-1]
+
+ pl_content = power_levels.content
+ else:
+ # If there are no power level events then the creator has rights.
+ pl_content = {}
+ admin_user_id = create_event.sender
+ if not self.is_mine_id(admin_user_id):
+ raise SynapseError(
+ 400, "No local admin user in room",
+ )
+
+ # Grant the user power equal to the room admin by attempting to send an
+ # updated power level event.
+ new_pl_content = dict(pl_content)
+ new_pl_content["users"] = dict(pl_content.get("users", {}))
+ new_pl_content["users"][user_to_add] = new_pl_content["users"][admin_user_id]
+
+ fake_requester = create_requester(
+ admin_user_id, authenticated_entity=requester.authenticated_entity,
+ )
+
+ try:
+ await self.event_creation_handler.create_and_send_nonmember_event(
+ fake_requester,
+ event_dict={
+ "content": new_pl_content,
+ "sender": admin_user_id,
+ "type": EventTypes.PowerLevels,
+ "state_key": "",
+ "room_id": room_id,
+ },
+ )
+ except AuthError:
+ # The admin user we found turned out not to have enough power.
+ raise SynapseError(
+ 400, "No local admin user in room with power to update power levels."
+ )
+
+ # Now we check if the user we're granting admin rights to is already in
+ # the room. If not and it's not a public room we invite them.
+ member_event = room_state.get((EventTypes.Member, user_to_add))
+ is_joined = False
+ if member_event:
+ is_joined = member_event.content["membership"] in (
+ Membership.JOIN,
+ Membership.INVITE,
+ )
+
+ if is_joined:
+ return 200, {}
+
+ join_rules = room_state.get((EventTypes.JoinRules, ""))
+ is_public = False
+ if join_rules:
+ is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC
+
+ if is_public:
+ return 200, {}
+
+ await self.room_member_handler.update_membership(
+ fake_requester,
+ target=UserID.from_string(user_to_add),
+ room_id=room_id,
+ action=Membership.INVITE,
+ )
+
+ return 200, {}
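For reference, a hedged sketch of exercising the new endpoint with an admin access token, using the requests library (homeserver URL, room ID, and token are placeholders, not values from this diff):

    import requests

    HOMESERVER = "https://matrix.example.com"  # placeholder
    ADMIN_TOKEN = "syt_..."                    # placeholder admin access token

    resp = requests.post(
        f"{HOMESERVER}/_synapse/admin/v1/rooms/!someroom:example.com/make_room_admin",
        headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
        json={"user_id": "@foo:example.com"},  # optional; defaults to the admin
    )
    resp.raise_for_status()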
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index b0ff5e1e..6658c2da 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -42,17 +42,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-_GET_PUSHERS_ALLOWED_KEYS = {
- "app_display_name",
- "app_id",
- "data",
- "device_display_name",
- "kind",
- "lang",
- "profile_tag",
- "pushkey",
-}
-
class UsersRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
@@ -320,9 +309,9 @@ class UserRestServletV2(RestServlet):
data={},
)
- if "avatar_url" in body and type(body["avatar_url"]) == str:
+ if "avatar_url" in body and isinstance(body["avatar_url"], str):
await self.profile_handler.set_avatar_url(
- user_id, requester, body["avatar_url"], True
+ target_user, requester, body["avatar_url"], True
)
ret = await self.admin_handler.get_user(target_user)
@@ -420,6 +409,9 @@ class UserRegisterServlet(RestServlet):
if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
raise SynapseError(400, "Invalid user type")
+ if "mac" not in body:
+ raise SynapseError(400, "mac must be specified", errcode=Codes.BAD_JSON)
+
got_mac = body["mac"]
want_mac_builder = hmac.new(
@@ -767,10 +759,7 @@ class PushersRestServlet(RestServlet):
pushers = await self.store.get_pushers_by_user_id(user_id)
- filtered_pushers = [
- {k: v for k, v in p.items() if k in _GET_PUSHERS_ALLOWED_KEYS}
- for p in pushers
- ]
+ filtered_pushers = [p.as_dict() for p in pushers]
return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)}
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index d7ae1482..5f4c6703 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Awaitable, Callable, Dict, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
@@ -30,6 +30,9 @@ from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -42,7 +45,7 @@ class LoginRestServlet(RestServlet):
JWT_TYPE_DEPRECATED = "m.login.jwt"
APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@@ -105,22 +108,27 @@ class LoginRestServlet(RestServlet):
return 200, {"flows": flows}
async def on_POST(self, request: SynapseRequest):
- self._address_ratelimiter.ratelimit(request.getClientIP())
-
login_submission = parse_json_object_from_request(request)
try:
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
appservice = self.auth.get_appservice_by_req(request)
+
+ if appservice.is_rate_limited():
+ self._address_ratelimiter.ratelimit(request.getClientIP())
+
result = await self._do_appservice_login(login_submission, appservice)
elif self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
):
+ self._address_ratelimiter.ratelimit(request.getClientIP())
result = await self._do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
+ self._address_ratelimiter.ratelimit(request.getClientIP())
result = await self._do_token_login(login_submission)
else:
+ self._address_ratelimiter.ratelimit(request.getClientIP())
result = await self._do_other_login(login_submission)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@@ -159,7 +167,9 @@ class LoginRestServlet(RestServlet):
if not appservice.is_interested_in_user(qualified_user_id):
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
- return await self._complete_login(qualified_user_id, login_submission)
+ return await self._complete_login(
+ qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited()
+ )
async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
"""Handle non-token/saml/jwt logins
@@ -194,6 +204,7 @@ class LoginRestServlet(RestServlet):
login_submission: JsonDict,
callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
create_non_existent_users: bool = False,
+ ratelimit: bool = True,
) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
@@ -208,6 +219,7 @@ class LoginRestServlet(RestServlet):
callback: Callback function to run after login.
create_non_existent_users: Whether to create the user if they don't
exist. Defaults to False.
+ ratelimit: Whether to ratelimit the login request.
Returns:
result: Dictionary of account information after successful login.
@@ -216,7 +228,8 @@ class LoginRestServlet(RestServlet):
# Before we actually log them in we check if they've already logged in
# too often. This happens here rather than before as we don't
# necessarily know the user before now.
- self._account_ratelimiter.ratelimit(user_id.lower())
+ if ratelimit:
+ self._account_ratelimiter.ratelimit(user_id.lower())
if create_non_existent_users:
canonical_uid = await self.auth_handler.check_user_exists(user_id)
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 8fe83f32..89823fcc 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -28,17 +28,6 @@ from synapse.rest.client.v2_alpha._base import client_patterns
logger = logging.getLogger(__name__)
-ALLOWED_KEYS = {
- "app_display_name",
- "app_id",
- "data",
- "device_display_name",
- "kind",
- "lang",
- "profile_tag",
- "pushkey",
-}
-
class PushersRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers$", v1=True)
@@ -54,9 +43,7 @@ class PushersRestServlet(RestServlet):
pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
- filtered_pushers = [
- {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers
- ]
+ filtered_pushers = [p.as_dict() for p in pushers]
return 200, {"pushers": filtered_pushers}
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 93c06afe..5647e8c5 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -963,25 +963,28 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
)
-def register_servlets(hs, http_server):
+def register_servlets(hs, http_server, is_worker=False):
RoomStateEventRestServlet(hs).register(http_server)
- RoomCreateRestServlet(hs).register(http_server)
RoomMemberListRestServlet(hs).register(http_server)
JoinedRoomMemberListRestServlet(hs).register(http_server)
RoomMessageListRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
- RoomForgetRestServlet(hs).register(http_server)
RoomMembershipRestServlet(hs).register(http_server)
RoomSendEventRestServlet(hs).register(http_server)
PublicRoomListRestServlet(hs).register(http_server)
RoomStateRestServlet(hs).register(http_server)
RoomRedactEventRestServlet(hs).register(http_server)
RoomTypingRestServlet(hs).register(http_server)
- SearchRestServlet(hs).register(http_server)
- JoinedRoomsRestServlet(hs).register(http_server)
- RoomEventServlet(hs).register(http_server)
RoomEventContextServlet(hs).register(http_server)
- RoomAliasListServlet(hs).register(http_server)
+
+ # Some servlets only get registered for the main process.
+ if not is_worker:
+ RoomCreateRestServlet(hs).register(http_server)
+ RoomForgetRestServlet(hs).register(http_server)
+ SearchRestServlet(hs).register(http_server)
+ JoinedRoomsRestServlet(hs).register(http_server)
+ RoomEventServlet(hs).register(http_server)
+ RoomAliasListServlet(hs).register(http_server)
def register_deprecated_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index eebee44a..d837bde1 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -254,14 +254,18 @@ class PasswordRestServlet(RestServlet):
logger.error("Auth succeeded but no known type! %r", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
- # If we have a password in this request, prefer it. Otherwise, there
- # must be a password hash from an earlier request.
+ # If we have a password in this request, prefer it. Otherwise, use the
+ # password hash from an earlier request.
if new_password:
password_hash = await self.auth_handler.hash(new_password)
- else:
+ elif session_id is not None:
password_hash = await self.auth_handler.get_session_data(
session_id, "password_hash", None
)
+ else:
+ # UI validation was skipped, but the request did not include a new
+ # password.
+ password_hash = None
if not password_hash:
raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index a3bb095c..5b5da718 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+from functools import wraps
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -25,6 +26,22 @@ from ._base import client_patterns
logger = logging.getLogger(__name__)
+def _validate_group_id(f):
+ """Wrapper to validate the form of the group ID.
+
+ Can be applied to any on_FOO method that accepts a group ID as a URL parameter.
+ """
+
+ @wraps(f)
+ def wrapper(self, request, group_id, *args, **kwargs):
+ if not GroupID.is_valid(group_id):
+ raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
+
+ return f(self, request, group_id, *args, **kwargs)
+
+ return wrapper
+
+
class GroupServlet(RestServlet):
"""Get the group profile
"""
@@ -37,6 +54,7 @@ class GroupServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -47,6 +65,7 @@ class GroupServlet(RestServlet):
return 200, group_description
+ @_validate_group_id
async def on_POST(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -71,6 +90,7 @@ class GroupSummaryServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -102,6 +122,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id, category_id, room_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -117,6 +138,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
return 200, resp
+ @_validate_group_id
async def on_DELETE(self, request, group_id, category_id, room_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -142,6 +164,7 @@ class GroupCategoryServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id, category_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -152,6 +175,7 @@ class GroupCategoryServlet(RestServlet):
return 200, category
+ @_validate_group_id
async def on_PUT(self, request, group_id, category_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -163,6 +187,7 @@ class GroupCategoryServlet(RestServlet):
return 200, resp
+ @_validate_group_id
async def on_DELETE(self, request, group_id, category_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -186,6 +211,7 @@ class GroupCategoriesServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -209,6 +235,7 @@ class GroupRoleServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id, role_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -219,6 +246,7 @@ class GroupRoleServlet(RestServlet):
return 200, category
+ @_validate_group_id
async def on_PUT(self, request, group_id, role_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -230,6 +258,7 @@ class GroupRoleServlet(RestServlet):
return 200, resp
+ @_validate_group_id
async def on_DELETE(self, request, group_id, role_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -253,6 +282,7 @@ class GroupRolesServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -284,6 +314,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id, role_id, user_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -299,6 +330,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
return 200, resp
+ @_validate_group_id
async def on_DELETE(self, request, group_id, role_id, user_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -322,13 +354,11 @@ class GroupRoomServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- if not GroupID.is_valid(group_id):
- raise SynapseError(400, "%s was not legal group ID" % (group_id,))
-
result = await self.groups_handler.get_rooms_in_group(
group_id, requester_user_id
)
@@ -348,6 +378,7 @@ class GroupUsersServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -371,6 +402,7 @@ class GroupInvitedUsersServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -393,6 +425,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
self.auth = hs.get_auth()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -449,6 +482,7 @@ class GroupAdminRoomsServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id, room_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -460,6 +494,7 @@ class GroupAdminRoomsServlet(RestServlet):
return 200, result
+ @_validate_group_id
async def on_DELETE(self, request, group_id, room_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -486,6 +521,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id, room_id, config_key):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -514,6 +550,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id
+ @_validate_group_id
async def on_PUT(self, request, group_id, user_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -541,6 +578,7 @@ class GroupAdminUsersKickServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id, user_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -565,6 +603,7 @@ class GroupSelfLeaveServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -589,6 +628,7 @@ class GroupSelfJoinServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -613,6 +653,7 @@ class GroupSelfAcceptInviteServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -637,6 +678,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
+ @_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
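
The repeated `+ @_validate_group_id` additions above replace per-handler copies of the inline check that the deleted lines in `GroupRoomServlet.on_GET` still show. The decorator's definition lands earlier in the file and is outside this hunk; a minimal sketch of what it plausibly looks like, reconstructed from the removed inline check:

    from functools import wraps

    from synapse.api.errors import SynapseError
    from synapse.types import GroupID

    def _validate_group_id(f):
        """Reject requests whose group_id path component is not a valid group ID."""

        @wraps(f)
        def wrapper(self, request, group_id, *args, **kwargs):
            if not GroupID.is_valid(group_id):
                raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
            # The wrapped handler is async; returning its coroutine here is
            # fine, since the servlet machinery awaits the result.
            return f(self, request, group_id, *args, **kwargs)

        return wrapper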
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index a89ae6dd..6b5a1b71 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -451,7 +451,7 @@ class RegisterRestServlet(RestServlet):
# == Normal User Registration == (everyone else)
if not self._registration_enabled:
- raise SynapseError(403, "Registration has been disabled")
+ raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN)
# For regular registration, convert the provided username to lowercase
# before attempting to register it. This should mean that people who try
@@ -655,9 +655,13 @@ class RegisterRestServlet(RestServlet):
user_id = await self.registration_handler.appservice_register(
username, as_token
)
- return await self._create_registration_details(user_id, body)
+ return await self._create_registration_details(
+ user_id, body, is_appservice_ghost=True,
+ )
- async def _create_registration_details(self, user_id, params):
+ async def _create_registration_details(
+ self, user_id, params, is_appservice_ghost=False
+ ):
"""Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token.
@@ -674,7 +678,11 @@ class RegisterRestServlet(RestServlet):
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
device_id, access_token = await self.registration_handler.register_device(
- user_id, device_id, initial_display_name, is_guest=False
+ user_id,
+ device_id,
+ initial_display_name,
+ is_guest=False,
+ is_appservice_ghost=is_appservice_ghost,
)
result.update({"access_token": access_token, "device_id": device_id})
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index bc4f4363..a3dee14e 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -17,7 +17,7 @@ import logging
from typing import Tuple
from synapse.http import servlet
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.logging.opentracing import set_tag, trace
from synapse.rest.client.transactions import HttpTransactionCache
@@ -54,6 +54,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
+ assert_params_in_dict(content, ("messages",))
sender_user_id = requester.user.to_string()
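
`assert_params_in_dict` is an existing helper in `synapse.http.servlet`; it raises a 400 error naming any required keys absent from the parsed body, so malformed to-device requests now fail before any processing. Roughly (a sketch of the behaviour, not the exact implementation):

    from synapse.api.errors import Codes, SynapseError

    def assert_params_in_dict(body: dict, required) -> None:
        # Collect every missing key so the error names them all at once.
        absent = [key for key in required if key not in body]
        if absent:
            raise SynapseError(400, "Missing params: %r" % (absent,), Codes.MISSING_PARAM)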
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index f843f024..c57ac22e 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import Dict, Set
+from typing import Dict
from signedjson.sign import sign_json
@@ -142,12 +142,13 @@ class RemoteKey(DirectServeJsonResource):
time_now_ms = self.clock.time_msec()
- cache_misses = {} # type: Dict[str, Set[str]]
+ # Note that the value is unused.
+ cache_misses = {} # type: Dict[str, Dict[str, int]]
for (server_name, key_id, from_server), results in cached.items():
results = [(result["ts_added_ms"], result) for result in results]
if not results and key_id is not None:
- cache_misses.setdefault(server_name, set()).add(key_id)
+ cache_misses.setdefault(server_name, {})[key_id] = 0
continue
if key_id is not None:
@@ -201,7 +202,7 @@ class RemoteKey(DirectServeJsonResource):
)
if miss:
- cache_misses.setdefault(server_name, set()).add(key_id)
+ cache_misses.setdefault(server_name, {})[key_id] = 0
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(most_recent_result["key_json"]))
else:
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 67aa993f..47c2b44b 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -155,6 +155,11 @@ def add_file_headers(request, media_type, file_size, upload_name):
request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
request.setHeader(b"Content-Length", b"%d" % (file_size,))
+ # Tell web crawlers to not index, archive, or follow links in media. This
+ # should help to prevent things in the media repo from showing up in web
+ # search results.
+ request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex")
+
# separators as defined in RFC2616. SP and HT are handled separately.
# see _can_encode_filename_as_token.
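
Since `add_file_headers` runs for every media download and thumbnail response, the new `X-Robots-Tag` header is easy to verify against a running homeserver; the host and media ID below are illustrative:

    import requests  # assumes the requests library is available

    resp = requests.get(
        "https://matrix.example.com/_matrix/media/r0/download/example.com/someMediaId"
    )
    # Expect "noindex, nofollow, noarchive, noimageindex" once this change is deployed.
    print(resp.headers.get("X-Robots-Tag"))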
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 9cac74eb..83beb02b 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -66,7 +66,7 @@ class MediaRepository:
def __init__(self, hs):
self.hs = hs
self.auth = hs.get_auth()
- self.client = hs.get_http_client()
+ self.client = hs.get_federation_http_client()
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.store = hs.get_datastore()
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index dce6c4d1..1082389d 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -676,7 +676,11 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("No media removed from url cache")
-def decode_and_calc_og(body, media_uri, request_encoding=None):
+def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]:
+ # If there's no body, nothing useful is going to be found.
+ if not body:
+ return {}
+
from lxml import etree
try:
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 18c9ed48..67f67efd 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import logging
import os
import shutil
@@ -21,6 +20,7 @@ from typing import Optional
from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background
+from synapse.util.async_helpers import maybe_awaitable
from ._base import FileInfo, Responder
from .media_storage import FileResponder
@@ -91,16 +91,14 @@ class StorageProviderWrapper(StorageProvider):
if self.store_synchronous:
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
- result = self.backend.store_file(path, file_info)
- if inspect.isawaitable(result):
- return await result
+ return await maybe_awaitable(self.backend.store_file(path, file_info))
else:
# TODO: Handle errors.
async def store():
try:
- result = self.backend.store_file(path, file_info)
- if inspect.isawaitable(result):
- return await result
+ return await maybe_awaitable(
+ self.backend.store_file(path, file_info)
+ )
except Exception:
logger.exception("Error storing file")
@@ -110,9 +108,7 @@ class StorageProviderWrapper(StorageProvider):
async def fetch(self, path, file_info):
# fetch is supposed to return an Awaitable, but guard
# against improper implementations.
- result = self.backend.fetch(path, file_info)
- if inspect.isawaitable(result):
- return await result
+ return await maybe_awaitable(self.backend.fetch(path, file_info))
class FileStorageProviderBackend(StorageProvider):
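
`maybe_awaitable` from `synapse.util.async_helpers` folds the three copies of the `inspect.isawaitable` dance into one call. Conceptually it behaves like the sketch below (the real helper returns an awaitable rather than being a coroutine, but the effect when awaited is the same):

    import inspect
    from typing import Any

    async def maybe_awaitable(value: Any) -> Any:
        """Await `value` if it is awaitable, otherwise pass it through unchanged."""
        if inspect.isawaitable(value):
            return await value
        return value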
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index d76f7389..42febc9a 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -44,7 +44,7 @@ class UploadResource(DirectServeJsonResource):
requester = await self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
- content_length = request.getHeader(b"Content-Length").decode("ascii")
+ content_length = request.getHeader("Content-Length")
if content_length is None:
raise SynapseError(msg="Request must specify a Content-Length", code=400)
if int(content_length) > self.max_upload_size:
diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py
new file mode 100644
index 00000000..d3b6803e
--- /dev/null
+++ b/synapse/rest/synapse/client/pick_username.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+import pkg_resources
+
+from twisted.web.http import Request
+from twisted.web.resource import Resource
+from twisted.web.static import File
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.sso import USERNAME_MAPPING_SESSION_COOKIE_NAME
+from synapse.http.server import DirectServeHtmlResource, DirectServeJsonResource
+from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+def pick_username_resource(hs: "HomeServer") -> Resource:
+ """Factory method to generate the username picker resource.
+
+ This resource gets mounted under /_synapse/client/pick_username. The top-level
+ resource is just a File resource which serves up the static files in the resources
+ "res" directory, but it has a couple of children:
+
+ * "submit", which does the mechanics of registering the new user, and redirects the
+ browser back to the client URL
+
+ * "check": checks if a userid is free.
+ """
+
+ # XXX should we make this path customisable so that admins can restyle it?
+ base_path = pkg_resources.resource_filename("synapse", "res/username_picker")
+
+ res = File(base_path)
+ res.putChild(b"submit", SubmitResource(hs))
+ res.putChild(b"check", AvailabilityCheckResource(hs))
+
+ return res
+
+
+class AvailabilityCheckResource(DirectServeJsonResource):
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self._sso_handler = hs.get_sso_handler()
+
+ async def _async_render_GET(self, request: Request):
+ localpart = parse_string(request, "username", required=True)
+
+ session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
+ if not session_id:
+ raise SynapseError(code=400, msg="missing session_id")
+
+ is_available = await self._sso_handler.check_username_availability(
+ localpart, session_id.decode("ascii", errors="replace")
+ )
+ return 200, {"available": is_available}
+
+
+class SubmitResource(DirectServeHtmlResource):
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self._sso_handler = hs.get_sso_handler()
+
+ async def _async_render_POST(self, request: SynapseRequest):
+ localpart = parse_string(request, "username", required=True)
+
+ session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
+ if not session_id:
+ raise SynapseError(code=400, msg="missing session_id")
+
+ await self._sso_handler.handle_submit_username_request(
+ request, localpart, session_id.decode("ascii", errors="replace")
+ )
diff --git a/synapse/server.py b/synapse/server.py
index b017e348..a198b0eb 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -350,17 +350,47 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_simple_http_client(self) -> SimpleHttpClient:
+ """
+ An HTTP client with no special configuration.
+ """
return SimpleHttpClient(self)
@cache_in_self
def get_proxied_http_client(self) -> SimpleHttpClient:
+ """
+ An HTTP client that uses configured HTTP(S) proxies.
+ """
+ return SimpleHttpClient(
+ self,
+ http_proxy=os.getenvb(b"http_proxy"),
+ https_proxy=os.getenvb(b"HTTPS_PROXY"),
+ )
+
+ @cache_in_self
+ def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient:
+ """
+ An HTTP client that uses configured HTTP(S) proxies and blacklists IPs
+ based on the IP range blacklist/whitelist.
+ """
return SimpleHttpClient(
self,
+ ip_whitelist=self.config.ip_range_whitelist,
+ ip_blacklist=self.config.ip_range_blacklist,
http_proxy=os.getenvb(b"http_proxy"),
https_proxy=os.getenvb(b"HTTPS_PROXY"),
)
@cache_in_self
+ def get_federation_http_client(self) -> MatrixFederationHttpClient:
+ """
+ An HTTP client for federation.
+ """
+ tls_client_options_factory = context_factory.FederationPolicyForHTTPS(
+ self.config
+ )
+ return MatrixFederationHttpClient(self, tls_client_options_factory)
+
+ @cache_in_self
def get_room_creation_handler(self) -> RoomCreationHandler:
return RoomCreationHandler(self)
@@ -515,13 +545,6 @@ class HomeServer(metaclass=abc.ABCMeta):
return PusherPool(self)
@cache_in_self
- def get_http_client(self) -> MatrixFederationHttpClient:
- tls_client_options_factory = context_factory.FederationPolicyForHTTPS(
- self.config
- )
- return MatrixFederationHttpClient(self, tls_client_options_factory)
-
- @cache_in_self
def get_media_repository_resource(self) -> MediaRepositoryResource:
# build the media repo resource. This indirects through the HomeServer
# to ensure that we only have a single instance of
@@ -595,7 +618,7 @@ class HomeServer(metaclass=abc.ABCMeta):
return StatsHandler(self)
@cache_in_self
- def get_spam_checker(self):
+ def get_spam_checker(self) -> SpamChecker:
return SpamChecker(self)
@cache_in_self
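
With `get_http_client` renamed to `get_federation_http_client` and the new blacklisted variant added, the `HomeServer` now exposes four client getters with distinct jobs, and call sites have to pick deliberately. A quick orientation sketch:

    from synapse.server import HomeServer

    def pick_clients(hs: HomeServer) -> None:
        plain = hs.get_simple_http_client()                 # no proxy, no IP filtering
        proxied = hs.get_proxied_http_client()              # honours http_proxy / HTTPS_PROXY
        guarded = hs.get_proxied_blacklisted_http_client()  # proxy plus IP range blacklist,
                                                            # for URLs from untrusted input
        federation = hs.get_federation_http_client()        # server-to-server federation traffic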
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 1fa3b280..84f59c7d 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -783,7 +783,7 @@ class StateResolutionStore:
)
def get_auth_chain_difference(
- self, state_sets: List[Set[str]]
+ self, room_id: str, state_sets: List[Set[str]]
) -> Awaitable[Set[str]]:
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).
@@ -796,4 +796,4 @@ class StateResolutionStore:
An awaitable that resolves to a set of event IDs.
"""
- return self.store.get_auth_chain_difference(state_sets)
+ return self.store.get_auth_chain_difference(room_id, state_sets)
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index f57df0d7..e585954b 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -38,7 +38,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import Collection, MutableStateMap, StateMap
from synapse.util import Clock
logger = logging.getLogger(__name__)
@@ -97,7 +97,9 @@ async def resolve_events_with_store(
# Also fetch all auth events that appear in only some of the state sets'
# auth chains.
- auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store)
+ auth_diff = await _get_auth_chain_difference(
+ room_id, state_sets, event_map, state_res_store
+ )
full_conflicted_set = set(
itertools.chain(
@@ -236,6 +238,7 @@ async def _get_power_level_for_sender(
async def _get_auth_chain_difference(
+ room_id: str,
state_sets: Sequence[StateMap[str]],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
@@ -252,9 +255,90 @@ async def _get_auth_chain_difference(
Set of event IDs
"""
+ # The `StateResolutionStore.get_auth_chain_difference` function assumes that
+ # all events passed to it (and their auth chains) have been persisted
+ # previously. This is not the case for any events in the `event_map`, and so
+ # we need to manually handle those events.
+ #
+ # We do this by:
+ # 1. calculating the auth chain difference for the state sets based on the
+ # events in `event_map` alone
+ # 2. replacing any events in the state_sets that are also in `event_map`
+ # with their auth events (recursively), and then calling
+ # `store.get_auth_chain_difference` as normal
+ # 3. adding the results of 1 and 2 together.
+
+ # Map from each event ID in `event_map` to its auth chain restricted to
+ # events in `event_map`, *plus* the auth event IDs of those events (the
+ # latter may themselves fall outside `event_map`).
+ events_to_auth_chain = {} # type: Dict[str, Set[str]]
+ for event in event_map.values():
+ chain = {event.event_id}
+ events_to_auth_chain[event.event_id] = chain
+
+ to_search = [event]
+ while to_search:
+ for auth_id in to_search.pop().auth_event_ids():
+ chain.add(auth_id)
+ auth_event = event_map.get(auth_id)
+ if auth_event:
+ to_search.append(auth_event)
+
+ # We now a) calculate the auth chain difference for the unpersisted events
+ # and b) work out the state sets to pass to the store.
+ #
+ # Note: If the `event_map` is empty (which is the common case), we can do a
+ # much simpler calculation.
+ if event_map:
+ # The list of state sets to pass to the store, where each state set is a set
+ # of the event ids making up the state. This is similar to `state_sets`,
+ # except that (a) we only have event ids, not the complete
+ # ((type, state_key)->event_id) mappings; and (b) we have stripped out
+ # unpersisted events and replaced them with the persisted events in
+ # their auth chain.
+ state_sets_ids = [] # type: List[Set[str]]
+
+ # For each state set, the unpersisted event IDs reachable (by their auth
+ # chain) from the events in that set.
+ unpersisted_set_ids = [] # type: List[Set[str]]
+
+ for state_set in state_sets:
+ set_ids = set() # type: Set[str]
+ state_sets_ids.append(set_ids)
+
+ unpersisted_ids = set() # type: Set[str]
+ unpersisted_set_ids.append(unpersisted_ids)
+
+ for event_id in state_set.values():
+ event_chain = events_to_auth_chain.get(event_id)
+ if event_chain is not None:
+ # We have an event in `event_map`. We add all the auth
+ # events that it references (that aren't also in `event_map`).
+ set_ids.update(e for e in event_chain if e not in event_map)
+
+ # We also add the full chain of unpersisted event IDs
+ # referenced by this state set, so that we can work out the
+ # auth chain difference of the unpersisted events.
+ unpersisted_ids.update(e for e in event_chain if e in event_map)
+ else:
+ set_ids.add(event_id)
+
+ # The auth chain difference of the unpersisted events of the state sets
+ # is calculated by taking the difference between the union and
+ # intersections.
+ union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
+ intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
+
+ difference_from_event_map = union - intersection # type: Collection[str]
+ else:
+ difference_from_event_map = ()
+ state_sets_ids = [set(state_set.values()) for state_set in state_sets]
+
difference = await state_res_store.get_auth_chain_difference(
- [set(state_set.values()) for state_set in state_sets]
+ room_id, state_sets_ids
)
+ difference.update(difference_from_event_map)
return difference
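
Step 1 of the comment above reduces to set algebra: an event ID belongs to the auth chain difference of the unpersisted events exactly when it is reachable from some of the state sets but not all of them, i.e. it is in the union but not the intersection. A toy illustration:

    # Three state sets' unpersisted auth chains (event IDs are made up).
    unpersisted_set_ids = [{"$a", "$b"}, {"$a"}, {"$a", "$c"}]

    union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])                # {$a, $b, $c}
    intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])  # {$a}

    print(sorted(union - intersection))  # ['$b', '$c'] — in some chains, but not all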
@@ -574,7 +658,7 @@ async def _get_mainline_depth_for_event(
# We do an iterative search, replacing `event` with the power level in its
# auth events (if any)
while tmp_event:
- depth = mainline_map.get(event.event_id)
+ depth = mainline_map.get(tmp_event.event_id)
if depth is not None:
return depth
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index bbff3c8d..c0d9d124 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -27,6 +27,7 @@ There are also schemas that get applied to every database, regardless of the
data stores associated with them (e.g. the schema version tables), which are
stored in `synapse.storage.schema`.
"""
+from typing import TYPE_CHECKING
from synapse.storage.databases import Databases
from synapse.storage.databases.main import DataStore
@@ -34,14 +35,18 @@ from synapse.storage.persist_events import EventsPersistenceStorage
from synapse.storage.purge_events import PurgeEventsStorage
from synapse.storage.state import StateGroupStorage
-__all__ = ["DataStores", "DataStore"]
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
+
+__all__ = ["Databases", "DataStore"]
class Storage:
"""The high level interfaces for talking to various storage layers.
"""
- def __init__(self, hs, stores: Databases):
+ def __init__(self, hs: "HomeServer", stores: Databases):
# We include the main data store here mainly so that we don't have to
# rewrite all the existing code to split it into high vs low level
# interfaces.
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 2b196ded..a25c4093 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -17,14 +17,18 @@
import logging
import random
from abc import ABCMeta
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Iterable, Optional, Union
from synapse.storage.database import LoggingTransaction # noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
from synapse.storage.database import DatabasePool
-from synapse.types import Collection, get_domain_from_id
+from synapse.storage.types import Connection
+from synapse.types import Collection, StreamToken, get_domain_from_id
from synapse.util import json_decoder
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -36,24 +40,31 @@ class SQLBaseStore(metaclass=ABCMeta):
per data store (and not one per physical database).
"""
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine
self.db_pool = database
self.rand = random.SystemRandom()
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: StreamToken,
+ rows: Iterable[Any],
+ ) -> None:
pass
- def _invalidate_state_caches(self, room_id, members_changed):
+ def _invalidate_state_caches(
+ self, room_id: str, members_changed: Iterable[str]
+ ) -> None:
"""Invalidates caches that are based on the current state, but does
not stream invalidations down replication.
Args:
- room_id (str): Room where state changed
- members_changed (iterable[str]): The user_ids of members that have
- changed
+ room_id: Room where state changed
+ members_changed: The user_ids of members that have changed
"""
for host in {get_domain_from_id(u) for u in members_changed}:
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
@@ -64,7 +75,7 @@ class SQLBaseStore(metaclass=ABCMeta):
def _attempt_to_invalidate_cache(
self, cache_name: str, key: Optional[Collection[Any]]
- ):
+ ) -> None:
"""Attempts to invalidate the cache of the given name, ignoring if the
cache doesn't exist. Mainly used for invalidating caches on workers,
where they may not have the cache.
@@ -88,12 +99,15 @@ class SQLBaseStore(metaclass=ABCMeta):
cache.invalidate(tuple(key))
-def db_to_json(db_content):
+def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
"""
Take some data from a database row and return a JSON-decoded object.
Args:
- db_content (memoryview|buffer|bytes|bytearray|unicode)
+ db_content: The JSON-encoded contents from the database.
+
+ Returns:
+ The object decoded from JSON.
"""
# psycopg2 on Python 3 returns memoryview objects, which we need to
# cast to bytes to decode
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 810721eb..29b8ca67 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,29 +12,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.types import Connection
+from synapse.types import JsonDict
from synapse.util import json_encoder
from . import engines
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+ from synapse.storage.database import DatabasePool, LoggingTransaction
+
logger = logging.getLogger(__name__)
class BackgroundUpdatePerformance:
"""Tracks the how long a background update is taking to update its items"""
- def __init__(self, name):
+ def __init__(self, name: str):
self.name = name
self.total_item_count = 0
- self.total_duration_ms = 0
- self.avg_item_count = 0
- self.avg_duration_ms = 0
+ self.total_duration_ms = 0.0
+ self.avg_item_count = 0.0
+ self.avg_duration_ms = 0.0
- def update(self, item_count, duration_ms):
+ def update(self, item_count: int, duration_ms: float) -> None:
"""Update the stats after doing an update"""
self.total_item_count += item_count
self.total_duration_ms += duration_ms
@@ -44,7 +49,7 @@ class BackgroundUpdatePerformance:
self.avg_item_count += 0.1 * (item_count - self.avg_item_count)
self.avg_duration_ms += 0.1 * (duration_ms - self.avg_duration_ms)
- def average_items_per_ms(self):
+ def average_items_per_ms(self) -> Optional[float]:
"""An estimate of how long it takes to do a single update.
Returns:
A duration in ms as a float
@@ -58,7 +63,7 @@ class BackgroundUpdatePerformance:
# changes in how long the update process takes.
return float(self.avg_item_count) / float(self.avg_duration_ms)
- def total_items_per_ms(self):
+ def total_items_per_ms(self) -> Optional[float]:
"""An estimate of how long it takes to do a single update.
Returns:
A duration in ms as a float
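
The `0.1` factor in `update` above is an exponential moving average: each new batch pulls the estimate a tenth of the way towards its own value, so recent batches dominate. Written out:

    def ema(avg: float, sample: float, alpha: float = 0.1) -> float:
        # Equivalent to avg = (1 - alpha) * avg + alpha * sample.
        return avg + alpha * (sample - avg)

    avg = 0.0
    for duration_ms in [100.0, 100.0, 50.0, 50.0]:
        avg = ema(avg, duration_ms)
    print(round(avg, 1))  # 24.9 — drifts towards the recent durations over time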
@@ -83,21 +88,25 @@ class BackgroundUpdater:
BACKGROUND_UPDATE_INTERVAL_MS = 1000
BACKGROUND_UPDATE_DURATION_MS = 100
- def __init__(self, hs, database):
+ def __init__(self, hs: "HomeServer", database: "DatabasePool"):
self._clock = hs.get_clock()
self.db_pool = database
# if a background update is currently running, its name.
self._current_background_update = None # type: Optional[str]
- self._background_update_performance = {}
- self._background_update_handlers = {}
+ self._background_update_performance = (
+ {}
+ ) # type: Dict[str, BackgroundUpdatePerformance]
+ self._background_update_handlers = (
+ {}
+ ) # type: Dict[str, Callable[[JsonDict, int], Awaitable[int]]]
self._all_done = False
- def start_doing_background_updates(self):
+ def start_doing_background_updates(self) -> None:
run_as_background_process("background_updates", self.run_background_updates)
- async def run_background_updates(self, sleep=True):
+ async def run_background_updates(self, sleep: bool = True) -> None:
logger.info("Starting background schema updates")
while True:
if sleep:
@@ -148,7 +157,7 @@ class BackgroundUpdater:
return False
- async def has_completed_background_update(self, update_name) -> bool:
+ async def has_completed_background_update(self, update_name: str) -> bool:
"""Check if the given background update has finished running.
"""
if self._all_done:
@@ -173,8 +182,7 @@ class BackgroundUpdater:
Returns once some amount of work is done.
Args:
- desired_duration_ms(float): How long we want to spend
- updating.
+ desired_duration_ms: How long we want to spend updating.
Returns:
True if we have finished running all the background updates, otherwise False
"""
@@ -220,6 +228,7 @@ class BackgroundUpdater:
return False
async def _do_background_update(self, desired_duration_ms: float) -> int:
+ assert self._current_background_update is not None
update_name = self._current_background_update
logger.info("Starting update batch on background update '%s'", update_name)
@@ -273,7 +282,11 @@ class BackgroundUpdater:
return len(self._background_update_performance)
- def register_background_update_handler(self, update_name, update_handler):
+ def register_background_update_handler(
+ self,
+ update_name: str,
+ update_handler: Callable[[JsonDict, int], Awaitable[int]],
+ ):
"""Register a handler for doing a background update.
The handler should take two arguments:
@@ -287,12 +300,12 @@ class BackgroundUpdater:
The handler is responsible for updating the progress of the update.
Args:
- update_name(str): The name of the update that this code handles.
- update_handler(function): The function that does the update.
+ update_name: The name of the update that this code handles.
+ update_handler: The function that does the update.
"""
self._background_update_handlers[update_name] = update_handler
- def register_noop_background_update(self, update_name):
+ def register_noop_background_update(self, update_name: str) -> None:
"""Register a noop handler for a background update.
This is useful when we previously did a background update, but no
@@ -302,10 +315,10 @@ class BackgroundUpdater:
also be called to clear the update.
Args:
- update_name (str): Name of update
+ update_name: Name of update
"""
- async def noop_update(progress, batch_size):
+ async def noop_update(progress: JsonDict, batch_size: int) -> int:
await self._end_background_update(update_name)
return 1
@@ -313,14 +326,14 @@ class BackgroundUpdater:
def register_background_index_update(
self,
- update_name,
- index_name,
- table,
- columns,
- where_clause=None,
- unique=False,
- psql_only=False,
- ):
+ update_name: str,
+ index_name: str,
+ table: str,
+ columns: Iterable[str],
+ where_clause: Optional[str] = None,
+ unique: bool = False,
+ psql_only: bool = False,
+ ) -> None:
"""Helper for store classes to do a background index addition
To use:
@@ -332,19 +345,19 @@ class BackgroundUpdater:
2. In the Store constructor, call this method
Args:
- update_name (str): update_name to register for
- index_name (str): name of index to add
- table (str): table to add index to
- columns (list[str]): columns/expressions to include in index
- unique (bool): true to make a UNIQUE index
+ update_name: update_name to register for
+ index_name: name of index to add
+ table: table to add index to
+ columns: columns/expressions to include in index
+ unique: true to make a UNIQUE index
psql_only: true to only create this index on psql databases (useful
for virtual sqlite tables)
"""
- def create_index_psql(conn):
+ def create_index_psql(conn: Connection) -> None:
conn.rollback()
# postgres insists on autocommit for the index
- conn.set_session(autocommit=True)
+ conn.set_session(autocommit=True) # type: ignore
try:
c = conn.cursor()
@@ -371,9 +384,9 @@ class BackgroundUpdater:
logger.debug("[SQL] %s", sql)
c.execute(sql)
finally:
- conn.set_session(autocommit=False)
+ conn.set_session(autocommit=False) # type: ignore
- def create_index_sqlite(conn):
+ def create_index_sqlite(conn: Connection) -> None:
# Sqlite doesn't support concurrent creation of indexes.
#
# We don't use partial indices on SQLite as it wasn't introduced
@@ -399,7 +412,7 @@ class BackgroundUpdater:
c.execute(sql)
if isinstance(self.db_pool.engine, engines.PostgresEngine):
- runner = create_index_psql
+ runner = create_index_psql # type: Optional[Callable[[Connection], None]]
elif psql_only:
runner = None
else:
@@ -433,7 +446,9 @@ class BackgroundUpdater:
"background_updates", keyvalues={"update_name": update_name}
)
- async def _background_update_progress(self, update_name: str, progress: dict):
+ async def _background_update_progress(
+ self, update_name: str, progress: dict
+ ) -> None:
"""Update the progress of a background update
Args:
@@ -441,20 +456,22 @@ class BackgroundUpdater:
progress: The progress of the update.
"""
- return await self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"background_update_progress",
self._background_update_progress_txn,
update_name,
progress,
)
- def _background_update_progress_txn(self, txn, update_name, progress):
+ def _background_update_progress_txn(
+ self, txn: "LoggingTransaction", update_name: str, progress: JsonDict
+ ) -> None:
"""Update the progress of a background update
Args:
- txn(cursor): The transaction.
- update_name(str): The name of the background update task
- progress(dict): The progress of the update.
+ txn: The transaction.
+ update_name: The name of the background update task
+ progress: The progress of the update.
"""
progress_json = json_encoder.encode(progress)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 43660ec4..701748f9 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -149,9 +149,6 @@ class DataStore(
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
- self._pushers_id_gen = StreamIdGenerator(
- db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
- )
self._group_updates_id_gen = StreamIdGenerator(
db_conn, "local_group_updates", "stream_id"
)
@@ -342,12 +339,13 @@ class DataStore(
filters = []
args = [self.hs.config.server_name]
+ # `name` is already stored in the database in lower case
if name:
- filters.append("(name LIKE ? OR displayname LIKE ?)")
- args.extend(["@%" + name + "%:%", "%" + name + "%"])
+ filters.append("(name LIKE ? OR LOWER(displayname) LIKE ?)")
+ args.extend(["@%" + name.lower() + "%:%", "%" + name.lower() + "%"])
elif user_id:
filters.append("name LIKE ?")
- args.extend(["%" + user_id + "%"])
+ args.extend(["%" + user_id.lower() + "%"])
if not guests:
filters.append("is_guest = 0")
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 339bd691..e96a8b3f 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -14,11 +14,12 @@
# limitations under the License.
import logging
-from typing import Dict, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
+from synapse.types import UserID
from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
@@ -546,7 +547,9 @@ class ClientIpStore(ClientIpWorkerStore):
}
return ret
- async def get_user_ip_and_agents(self, user):
+ async def get_user_ip_and_agents(
+ self, user: UserID
+ ) -> List[Dict[str, Union[str, int]]]:
user_id = user.to_string()
results = {}
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index dfb4f87b..90976776 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -57,6 +57,38 @@ class DeviceWorkerStore(SQLBaseStore):
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
)
+ async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
+ """Retrieve number of all devices of given users.
+ Only returns number of devices that are not marked as hidden.
+
+ Args:
+ user_ids: The IDs of the users which owns devices
+ Returns:
+ Number of devices of this users.
+ """
+
+ def count_devices_by_users_txn(txn, user_ids):
+ sql = """
+ SELECT count(*)
+ FROM devices
+ WHERE
+ hidden = '0' AND
+ """
+
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "user_id", user_ids
+ )
+
+ txn.execute(sql + clause, args)
+ return txn.fetchone()[0]
+
+ if not user_ids:
+ return 0
+
+ return await self.db_pool.runInteraction(
+ "count_devices_by_users", count_devices_by_users_txn, user_ids
+ )
+
async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
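
`make_in_list_sql_clause` produces an engine-appropriate membership test plus its bind arguments, which is why the SQL above stops at `AND` and the clause is concatenated on. A sketch of what it expands to (SQLite-style output shown; on Postgres it can instead emit `user_id = ANY(?)` with an array argument):

    # Illustrative expansion only; the real helper lives in synapse.storage.database.
    user_ids = ["@a:example.com", "@b:example.com"]
    clause = "user_id IN (%s)" % ",".join("?" * len(user_ids))
    args = list(user_ids)

    sql = "SELECT count(*) FROM devices WHERE hidden = '0' AND " + clause
    print(sql)   # SELECT count(*) FROM devices WHERE hidden = '0' AND user_id IN (?,?)
    print(args)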
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 2e07c373..ebffd892 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -137,7 +137,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return list(results)
- async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]:
+ async def get_auth_chain_difference(
+ self, room_id: str, state_sets: List[Set[str]]
+ ) -> Set[str]:
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 2e56dfaf..e5c03cc6 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -894,16 +894,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
return push_actions
- async def get_latest_push_action_stream_ordering(self):
- def f(txn):
- txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
- return txn.fetchone()
-
- result = await self.db_pool.runInteraction(
- "get_latest_push_action_stream_ordering", f
- )
- return result[0] or 0
-
def _remove_old_push_actions_before_txn(
self, txn, room_id, user_id, stream_ordering
):
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index f8f4bb9b..04ac2d0c 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -22,6 +22,7 @@ from signedjson.key import decode_verify_key_bytes
from synapse.storage._base import SQLBaseStore
from synapse.storage.keys import FetchKeyResult
+from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
@@ -44,7 +45,7 @@ class KeyStore(SQLBaseStore):
)
async def get_server_verify_keys(
self, server_name_and_key_ids: Iterable[Tuple[str, str]]
- ) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]:
+ ) -> Dict[Tuple[str, str], FetchKeyResult]:
"""
Args:
server_name_and_key_ids:
@@ -56,7 +57,7 @@ class KeyStore(SQLBaseStore):
"""
keys = {}
- def _get_keys(txn, batch):
+ def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
"""Processes a batch of keys to fetch, and adds the result to `keys`."""
# batch_iter always returns tuples so it's safe to do len(batch)
@@ -77,13 +78,12 @@ class KeyStore(SQLBaseStore):
# `ts_valid_until_ms`.
ts_valid_until_ms = 0
- res = FetchKeyResult(
+ keys[(server_name, key_id)] = FetchKeyResult(
verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
valid_until_ts=ts_valid_until_ms,
)
- keys[(server_name, key_id)] = res
- def _txn(txn):
+ def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:
for batch in batch_iter(server_name_and_key_ids, 50):
_get_keys(txn, batch)
return keys
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 7997242d..77ba9d81 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -15,18 +15,32 @@
# limitations under the License.
import logging
-from typing import Iterable, Iterator, List, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
from canonicaljson import encode_canonical_json
+from synapse.push import PusherConfig, ThrottleParams
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import DatabasePool
+from synapse.storage.types import Connection
+from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached, cachedList
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class PusherWorkerStore(SQLBaseStore):
- def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]:
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ super().__init__(database, db_conn, hs)
+ self._pushers_id_gen = StreamIdGenerator(
+ db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
+ )
+
+ def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
"""JSON-decode the data in the rows returned from the `pushers` table
Drops any rows whose data cannot be decoded
@@ -44,21 +58,23 @@ class PusherWorkerStore(SQLBaseStore):
)
continue
- yield r
+ yield PusherConfig(**r)
- async def user_has_pusher(self, user_id):
+ async def user_has_pusher(self, user_id: str) -> bool:
ret = await self.db_pool.simple_select_one_onecol(
"pushers", {"user_name": user_id}, "id", allow_none=True
)
return ret is not None
- def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
- return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
+ async def get_pushers_by_app_id_and_pushkey(
+ self, app_id: str, pushkey: str
+ ) -> Iterator[PusherConfig]:
+ return await self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
- def get_pushers_by_user_id(self, user_id):
- return self.get_pushers_by({"user_name": user_id})
+ async def get_pushers_by_user_id(self, user_id: str) -> Iterator[PusherConfig]:
+ return await self.get_pushers_by({"user_name": user_id})
- async def get_pushers_by(self, keyvalues):
+ async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]:
ret = await self.db_pool.simple_select_list(
"pushers",
keyvalues,
@@ -83,7 +99,7 @@ class PusherWorkerStore(SQLBaseStore):
)
return self._decode_pushers_rows(ret)
- async def get_all_pushers(self):
+ async def get_all_pushers(self) -> Iterator[PusherConfig]:
def get_pushers(txn):
txn.execute("SELECT * FROM pushers")
rows = self.db_pool.cursor_to_dict(txn)
@@ -159,14 +175,16 @@ class PusherWorkerStore(SQLBaseStore):
)
@cached(num_args=1, max_entries=15000)
- async def get_if_user_has_pusher(self, user_id):
+ async def get_if_user_has_pusher(self, user_id: str):
# This only exists for the cachedList decorator
raise NotImplementedError()
@cachedList(
cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
)
- async def get_if_users_have_pushers(self, user_ids):
+ async def get_if_users_have_pushers(
+ self, user_ids: Iterable[str]
+ ) -> Dict[str, bool]:
rows = await self.db_pool.simple_select_many_batch(
table="pushers",
column="user_name",
@@ -224,7 +242,7 @@ class PusherWorkerStore(SQLBaseStore):
return bool(updated)
async def update_pusher_failing_since(
- self, app_id, pushkey, user_id, failing_since
+ self, app_id: str, pushkey: str, user_id: str, failing_since: Optional[int]
) -> None:
await self.db_pool.simple_update(
table="pushers",
@@ -233,7 +251,9 @@ class PusherWorkerStore(SQLBaseStore):
desc="update_pusher_failing_since",
)
- async def get_throttle_params_by_room(self, pusher_id):
+ async def get_throttle_params_by_room(
+ self, pusher_id: str
+ ) -> Dict[str, ThrottleParams]:
res = await self.db_pool.simple_select_list(
"pusher_throttle",
{"pusher": pusher_id},
@@ -243,43 +263,44 @@ class PusherWorkerStore(SQLBaseStore):
params_by_room = {}
for row in res:
- params_by_room[row["room_id"]] = {
- "last_sent_ts": row["last_sent_ts"],
- "throttle_ms": row["throttle_ms"],
- }
+ params_by_room[row["room_id"]] = ThrottleParams(
+ row["last_sent_ts"], row["throttle_ms"],
+ )
return params_by_room
- async def set_throttle_params(self, pusher_id, room_id, params) -> None:
+ async def set_throttle_params(
+ self, pusher_id: str, room_id: str, params: ThrottleParams
+ ) -> None:
# no need to lock because `pusher_throttle` has a primary key on
# (pusher, room_id) so simple_upsert will retry
await self.db_pool.simple_upsert(
"pusher_throttle",
{"pusher": pusher_id, "room_id": room_id},
- params,
+ {"last_sent_ts": params.last_sent_ts, "throttle_ms": params.throttle_ms},
desc="set_throttle_params",
lock=False,
)
class PusherStore(PusherWorkerStore):
- def get_pushers_stream_token(self):
+ def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
async def add_pusher(
self,
- user_id,
- access_token,
- kind,
- app_id,
- app_display_name,
- device_display_name,
- pushkey,
- pushkey_ts,
- lang,
- data,
- last_stream_ordering,
- profile_tag="",
+ user_id: str,
+ access_token: Optional[int],
+ kind: str,
+ app_id: str,
+ app_display_name: str,
+ device_display_name: str,
+ pushkey: str,
+ pushkey_ts: int,
+ lang: Optional[str],
+ data: Optional[JsonDict],
+ last_stream_ordering: int,
+ profile_tag: str = "",
) -> None:
async with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
@@ -311,16 +332,16 @@ class PusherStore(PusherWorkerStore):
# invalidate, since the user might not have had a pusher before
await self.db_pool.runInteraction(
"add_pusher",
- self._invalidate_cache_and_stream,
+ self._invalidate_cache_and_stream, # type: ignore
self.get_if_user_has_pusher,
(user_id,),
)
async def delete_pusher_by_app_id_pushkey_user_id(
- self, app_id, pushkey, user_id
+ self, app_id: str, pushkey: str, user_id: str
) -> None:
def delete_pusher_txn(txn, stream_id):
- self._invalidate_cache_and_stream(
+ self._invalidate_cache_and_stream( # type: ignore
txn, self.get_if_user_has_pusher, (user_id,)
)
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index fedb8a6c..8d05288e 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -463,6 +463,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="get_user_by_external_id",
)
+ async def get_external_ids_by_user(self, mxid: str) -> List[Tuple[str, str]]:
+ """Look up external ids for the given user
+
+ Args:
+ mxid: the MXID to be looked up
+
+ Returns:
+ Tuples of (auth_provider, external_id)
+ """
+ res = await self.db_pool.simple_select_list(
+ table="user_external_ids",
+ keyvalues={"user_id": mxid},
+ retcols=("auth_provider", "external_id"),
+ desc="get_external_ids_by_user",
+ )
+ return [(r["auth_provider"], r["external_id"]) for r in res]
+
async def count_all_users(self):
"""Counts all users registered on the homeserver."""
@@ -926,6 +943,42 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="del_user_pending_deactivation",
)
+ async def get_access_token_last_validated(self, token_id: int) -> int:
+ """Retrieves the time (in milliseconds) of the last validation of an access token.
+
+ Args:
+ token_id: The ID of the access token to update.
+ Raises:
+ StoreError if the access token was not found.
+
+ Returns:
+ The last validation time.
+ """
+ result = await self.db_pool.simple_select_one_onecol(
+ "access_tokens", {"id": token_id}, "last_validated"
+ )
+
+ # If this token has not been validated (since starting to track this),
+ # return 0 instead of None.
+ return result or 0
+
+ async def update_access_token_last_validated(self, token_id: int) -> None:
+ """Updates the last time an access token was validated.
+
+ Args:
+ token_id: The ID of the access token to update.
+ Raises:
+ StoreError if there was a problem updating this.
+ """
+ now = self._clock.time_msec()
+
+ await self.db_pool.simple_update_one(
+ "access_tokens",
+ {"id": token_id},
+ {"last_validated": now},
+ desc="update_access_token_last_validated",
+ )
+
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
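
The two new methods are designed to be paired: read the stored timestamp, and only write a fresh one when it has gone stale, so hot tokens do not cost a database write per request. A sketch of that pattern (the one-day threshold is illustrative, not something this diff fixes):

    ONE_DAY_MS = 24 * 60 * 60 * 1000  # illustrative staleness threshold

    async def maybe_refresh_validation(store, clock, token_id: int) -> None:
        # get_access_token_last_validated returns 0 for never-validated
        # tokens, so those get refreshed on first use.
        last = await store.get_access_token_last_validated(token_id)
        if clock.time_msec() - last > ONE_DAY_MS:
            await store.update_access_token_last_validated(token_id)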
@@ -963,6 +1016,14 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
)
+ self.db_pool.updates.register_background_index_update(
+ "user_external_ids_user_id_idx",
+ index_name="user_external_ids_user_id_idx",
+ table="user_external_ids",
+ columns=["user_id"],
+ unique=False,
+ )
+
async def _background_update_set_deactivated_flag(self, progress, batch_size):
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
for each of them.
@@ -1125,6 +1186,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
The token ID
"""
next_id = self._access_tokens_id_gen.get_next()
+ now = self._clock.time_msec()
await self.db_pool.simple_insert(
"access_tokens",
@@ -1135,6 +1197,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"device_id": device_id,
"valid_until_ms": valid_until_ms,
"puppets_user_id": puppets_user_id,
+ "last_validated": now,
},
desc="add_access_token_to_user",
)
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 6b89db15..4650d068 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -379,14 +379,14 @@ class RoomWorkerStore(SQLBaseStore):
# Filter room names by a string
where_statement = ""
if search_term:
- where_statement = "WHERE state.name LIKE ?"
+ where_statement = "WHERE LOWER(state.name) LIKE ?"
# Our postgres db driver converts ? -> %s in SQL strings as that's the
# placeholder for postgres.
# HOWEVER, if you put a % into your SQL then everything goes wibbly.
# To get around this, we're going to surround search_term with %'s
# before giving it to the database in python instead
- search_term = "%" + search_term + "%"
+ search_term = "%" + search_term.lower() + "%"
# Set ordering
if RoomSortOrder(order_by) == RoomSortOrder.SIZE:
diff --git a/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql b/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql
new file mode 100644
index 00000000..8f5e65aa
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5825, 'user_external_ids_user_id_idx', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql b/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql
new file mode 100644
index 00000000..1a101cd5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- The last time this access token was "validated" (i.e. logged in or succeeded
+-- at user-interactive authentication).
+ALTER TABLE access_tokens ADD COLUMN last_validated BIGINT;
diff --git a/synapse/storage/databases/main/schema/delta/58/27local_invites.sql b/synapse/storage/databases/main/schema/delta/58/27local_invites.sql
new file mode 100644
index 00000000..44b2a057
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/27local_invites.sql
@@ -0,0 +1,18 @@
+/*
+ * Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This is unused since Synapse v1.17.0.
+DROP TABLE local_invites;
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index d87ceec6..ef11f1c3 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -17,7 +17,7 @@ import logging
import re
from typing import Any, Dict, Iterable, Optional, Set, Tuple
-from synapse.api.constants import EventTypes, JoinRules
+from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.state import StateFilter
from synapse.storage.databases.main.state_deltas import StateDeltasStore
@@ -360,7 +360,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
if hist_vis_id:
hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)
if hist_vis_ev:
- if hist_vis_ev.content.get("history_visibility") == "world_readable":
+ if (
+ hist_vis_ev.content.get("history_visibility")
+ == HistoryVisibility.WORLD_READABLE
+ ):
return True
return False
@@ -393,9 +396,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
sql = """
INSERT INTO user_directory_search(user_id, vector)
VALUES (?,
- setweight(to_tsvector('english', ?), 'A')
- || setweight(to_tsvector('english', ?), 'D')
- || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ setweight(to_tsvector('simple', ?), 'A')
+ || setweight(to_tsvector('simple', ?), 'D')
+ || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
"""
txn.execute(
@@ -415,9 +418,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
sql = """
INSERT INTO user_directory_search(user_id, vector)
VALUES (?,
- setweight(to_tsvector('english', ?), 'A')
- || setweight(to_tsvector('english', ?), 'D')
- || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ setweight(to_tsvector('simple', ?), 'A')
+ || setweight(to_tsvector('simple', ?), 'D')
+ || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
)
"""
txn.execute(
@@ -432,9 +435,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
elif new_entry is False:
sql = """
UPDATE user_directory_search
- SET vector = setweight(to_tsvector('english', ?), 'A')
- || setweight(to_tsvector('english', ?), 'D')
- || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ SET vector = setweight(to_tsvector('simple', ?), 'A')
+ || setweight(to_tsvector('simple', ?), 'D')
+ || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
WHERE user_id = ?
"""
txn.execute(
@@ -761,7 +764,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
INNER JOIN user_directory AS d USING (user_id)
WHERE
%s
- AND vector @@ to_tsquery('english', ?)
+ AND vector @@ to_tsquery('simple', ?)
ORDER BY
(CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
* (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END)
@@ -770,13 +773,13 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
3 * ts_rank_cd(
'{0.1, 0.1, 0.9, 1.0}',
vector,
- to_tsquery('english', ?),
+ to_tsquery('simple', ?),
8
)
+ ts_rank_cd(
'{0.1, 0.1, 0.9, 1.0}',
vector,
- to_tsquery('english', ?),
+ to_tsquery('simple', ?),
8
)
)
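Switching the text-search configuration from 'english' to 'simple' matters because 'english' stems tokens and drops stop words, which distorts personal names and non-English display names, while 'simple' only lowercases and splits. An illustration that can be run against any postgres (the DSN is a placeholder):

import psycopg2

conn = psycopg2.connect("dbname=synapse")  # placeholder DSN
with conn.cursor() as cur:
    cur.execute("SELECT to_tsvector('english', %s)", ("Michaela Running Deer",))
    print(cur.fetchone()[0])  # 'deer':3 'michaela':1 'run':2     -- stemmed
    cur.execute("SELECT to_tsvector('simple', %s)", ("Michaela Running Deer",))
    print(cur.fetchone()[0])  # 'deer':3 'michaela':1 'running':2 -- verbatim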
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index afd10f7b..c03871f3 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -17,11 +17,12 @@
import logging
import attr
+from signedjson.types import VerifyKey
logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True)
class FetchKeyResult:
- verify_key = attr.ib() # VerifyKey: the key itself
- valid_until_ts = attr.ib() # int: how long we can use this key for
+ verify_key = attr.ib(type=VerifyKey) # the key itself
+ valid_until_ts = attr.ib(type=int) # how long we can use this key for
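attr.ib(type=...) is the annotation-free spelling of an attrs field type; mypy's attrs plugin treats it the same as an annotated field, which is what turns these previously comment-only types into checkable ones. A standalone illustration:

import attr

@attr.s(slots=True, frozen=True)
class Example:
    value = attr.ib(type=int)  # equivalent to `value: int = attr.ib()`

e = Example(value=5)
# e.value = 6  # would raise attr.exceptions.FrozenInstanceError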
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 70e636b0..61fc49c6 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -31,7 +31,14 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases import Databases
from synapse.storage.databases.main.events import DeltaState
-from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap
+from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.types import (
+ Collection,
+ PersistedEventPosition,
+ RoomStreamToken,
+ StateMap,
+ get_domain_from_id,
+)
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.metrics import Measure
@@ -68,6 +75,21 @@ stale_forward_extremities_counter = Histogram(
buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
)
+state_resolutions_during_persistence = Counter(
+ "synapse_storage_events_state_resolutions_during_persistence",
+ "Number of times we had to do state res to calculate new current state",
+)
+
+potential_times_prune_extremities = Counter(
+ "synapse_storage_events_potential_times_prune_extremities",
+ "Number of times we might be able to prune extremities",
+)
+
+times_pruned_extremities = Counter(
+ "synapse_storage_events_times_pruned_extremities",
+ "Number of times we were actually be able to prune extremities",
+)
+
class _EventPeristenceQueue:
"""Queues up events so that they can be persisted in bulk with only one
@@ -454,7 +476,15 @@ class EventsPersistenceStorage:
latest_event_ids,
new_latest_event_ids,
)
- current_state, delta_ids = res
+ current_state, delta_ids, new_latest_event_ids = res
+
+ # there should always be at least one forward extremity.
+ # (except during the initial persistence of the send_join
+ # results, in which case there will be no existing
+ # extremities, so we'll `continue` above and skip this bit.)
+ assert new_latest_event_ids, "No forward extremities left!"
+
+ new_forward_extremeties[room_id] = new_latest_event_ids
# If either are not None then there has been a change,
# and we need to work out the delta (or use that
@@ -573,29 +603,35 @@ class EventsPersistenceStorage:
self,
room_id: str,
events_context: List[Tuple[EventBase, EventContext]],
- old_latest_event_ids: Iterable[str],
- new_latest_event_ids: Iterable[str],
- ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
+ old_latest_event_ids: Set[str],
+ new_latest_event_ids: Set[str],
+ ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]], Set[str]]:
"""Calculate the current state dict after adding some new events to
a room
Args:
- room_id (str):
+ room_id:
room to which the events are being added. Used for logging etc
- events_context (list[(EventBase, EventContext)]):
+ events_context:
events and contexts which are being added to the room
- old_latest_event_ids (iterable[str]):
+ old_latest_event_ids:
the old forward extremities for the room.
- new_latest_event_ids (iterable[str]):
+ new_latest_event_ids:
the new forward extremities for the room.
Returns:
- Returns a tuple of two state maps, the first being the full new current
- state and the second being the delta to the existing current state.
- If both are None then there has been no change.
+ Returns a tuple of two state maps and a set of new forward
+ extremities.
+
+ The first state map is the full new current state and the second
+ is the delta to the existing current state. If both are None then
+ there has been no change.
+
+ The function may prune some old entries from the set of new
+ forward extremities if it's safe to do so.
If there has been a change then we only return the delta if its
already been calculated. Conversely if we do know the delta then
@@ -672,7 +708,7 @@ class EventsPersistenceStorage:
# If the old and new groups are the same then we don't need to do
# anything.
if old_state_groups == new_state_groups:
- return None, None
+ return None, None, new_latest_event_ids
if len(new_state_groups) == 1 and len(old_state_groups) == 1:
# If we're going from one state group to another, lets check if
@@ -689,7 +725,7 @@ class EventsPersistenceStorage:
# the current state in memory then lets also return that,
# but it doesn't matter if we don't.
new_state = state_groups_map.get(new_state_group)
- return new_state, delta_ids
+ return new_state, delta_ids, new_latest_event_ids
# Now that we have calculated new_state_groups we need to get
# their state IDs so we can resolve to a single state set.
@@ -701,7 +737,7 @@ class EventsPersistenceStorage:
if len(new_state_groups) == 1:
# If there is only one state group, then we know what the current
# state is.
- return state_groups_map[new_state_groups.pop()], None
+ return state_groups_map[new_state_groups.pop()], None, new_latest_event_ids
# Ok, we need to defer to the state handler to resolve our state sets.
@@ -734,7 +770,139 @@ class EventsPersistenceStorage:
state_res_store=StateResolutionStore(self.main_store),
)
- return res.state, None
+ state_resolutions_during_persistence.inc()
+
+ # If the returned state matches the state group of one of the new
+ # forward extremities then we check if we are able to prune some state
+ # extremities.
+ if res.state_group and res.state_group in new_state_groups:
+ new_latest_event_ids = await self._prune_extremities(
+ room_id,
+ new_latest_event_ids,
+ res.state_group,
+ event_id_to_state_group,
+ events_context,
+ )
+
+ return res.state, None, new_latest_event_ids
+
+ async def _prune_extremities(
+ self,
+ room_id: str,
+ new_latest_event_ids: Set[str],
+ resolved_state_group: int,
+ event_id_to_state_group: Dict[str, int],
+ events_context: List[Tuple[EventBase, EventContext]],
+ ) -> Set[str]:
+ """See if we can prune any of the extremities after calculating the
+ resolved state.
+ """
+ potential_times_prune_extremities.inc()
+
+ # We keep all the extremities that have the same state group, and
+ # see if we can drop the others.
+ new_new_extrems = {
+ e
+ for e in new_latest_event_ids
+ if event_id_to_state_group[e] == resolved_state_group
+ }
+
+ dropped_extrems = set(new_latest_event_ids) - new_new_extrems
+
+ logger.debug("Might drop extremities: %s", dropped_extrems)
+
+ # We only drop events from the extremities list if:
+ # 1. we're not currently persisting them;
+ # 2. they're not our own events (or are dummy events); and
+ # 3. they're either:
+ # 1. over N hours old and more than N events ago (we use depth to
+ # calculate); or
+ # 2. we are persisting an event from the same domain and more than
+ # M events ago.
+ #
+ # The idea is that we don't want to drop events that are "legitimate"
+ # extremities (that we would want to include as prev events), only
+ # "stuck" extremities that are e.g. due to a gap in the graph.
+ #
+ # Note that we either drop all of them or none of them. If we only drop
+ # some of the events we don't know if state res would come to the same
+ # conclusion.
+
+ for ev, _ in events_context:
+ if ev.event_id in dropped_extrems:
+ logger.debug(
+ "Not dropping extremities: %s is being persisted", ev.event_id
+ )
+ return new_latest_event_ids
+
+ dropped_events = await self.main_store.get_events(
+ dropped_extrems,
+ allow_rejected=True,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ )
+
+ new_senders = {get_domain_from_id(e.sender) for e, _ in events_context}
+
+ one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
+ current_depth = max(e.depth for e, _ in events_context)
+ for event in dropped_events.values():
+ # If the event is a local dummy event then we should check it
+ # doesn't reference any local events, as we want to reference those
+ # if we send any new events.
+ #
+ # Note we do this recursively to handle the case where a dummy event
+ # references a dummy event that only references remote events.
+ #
+ # Ideally we'd figure out a way of still being able to drop old
+ # dummy events that reference local events, but this is good enough
+ # as a first cut.
+ events_to_check = [event]
+ while events_to_check:
+ new_events = set()
+ for event_to_check in events_to_check:
+ if self.is_mine_id(event_to_check.sender):
+ if event_to_check.type != EventTypes.Dummy:
+ logger.debug("Not dropping own event")
+ return new_latest_event_ids
+ new_events.update(event_to_check.prev_event_ids())
+
+ prev_events = await self.main_store.get_events(
+ new_events,
+ allow_rejected=True,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ )
+ events_to_check = prev_events.values()
+
+ if (
+ event.origin_server_ts < one_day_ago
+ and event.depth < current_depth - 100
+ ):
+ continue
+
+ # We can be less conservative about dropping extremities from the
+ # same domain, though we do want to wait a little bit (otherwise
+ # we'll immediately remove all extremities from a given server).
+ if (
+ get_domain_from_id(event.sender) in new_senders
+ and event.depth < current_depth - 20
+ ):
+ continue
+
+ logger.debug(
+ "Not dropping as too new and not in new_senders: %s", new_senders,
+ )
+
+ return new_latest_event_ids
+
+ times_pruned_extremities.inc()
+
+ logger.info(
+ "Pruning forward extremities in room %s: from %s -> %s",
+ room_id,
+ new_latest_event_ids,
+ new_new_extrems,
+ )
+ return new_new_extrems
async def _calculate_state_delta(
self, room_id: str, current_state: StateMap[str]
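The age/depth rules described in the long comment inside _prune_extremities reduce to a per-event predicate; restated as a standalone function for clarity (thresholds copied from the hunk above, helper names are placeholders):

def may_drop_extremity(event, current_depth, new_senders, now_ms, domain_of):
    # True if this extremity looks "stuck": either over a day old and far
    # behind in depth, or from a domain we are currently persisting events
    # for and somewhat behind in depth.
    one_day_ago = now_ms - 24 * 60 * 60 * 1000
    if event.origin_server_ts < one_day_ago and event.depth < current_depth - 100:
        return True
    if domain_of(event.sender) in new_senders and event.depth < current_depth - 20:
        return True
    return False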
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 459754fe..f91a2eae 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -18,9 +18,10 @@ import logging
import os
import re
from collections import Counter
-from typing import Optional, TextIO
+from typing import Generator, Iterable, List, Optional, TextIO, Tuple
import attr
+from typing_extensions import Counter as CounterType
from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import LoggingDatabaseConnection
@@ -70,7 +71,7 @@ def prepare_database(
db_conn: LoggingDatabaseConnection,
database_engine: BaseDatabaseEngine,
config: Optional[HomeServerConfig],
- databases: Collection[str] = ["main", "state"],
+ databases: Collection[str] = ("main", "state"),
):
"""Prepares a physical database for usage. Will either create all necessary tables
or upgrade from an older schema version.
@@ -155,7 +156,9 @@ def prepare_database(
raise
-def _setup_new_database(cur, database_engine, databases):
+def _setup_new_database(
+ cur: Cursor, database_engine: BaseDatabaseEngine, databases: Collection[str]
+) -> None:
"""Sets up the physical database by finding a base set of "full schemas" and
then applying any necessary deltas, including schemas from the given data
stores.
@@ -188,10 +191,9 @@ def _setup_new_database(cur, database_engine, databases):
folder as well those in the data stores specified.
Args:
- cur (Cursor): a database cursor
- database_engine (DatabaseEngine)
- databases (list[str]): The names of the databases to instantiate
- on the given physical database.
+ cur: a database cursor
+ database_engine
+ databases: The names of the databases to instantiate on the given physical database.
"""
# We're about to set up a brand new database so we check that it's
@@ -199,12 +201,11 @@ def _setup_new_database(cur, database_engine, databases):
database_engine.check_new_database(cur)
current_dir = os.path.join(dir_path, "schema", "full_schemas")
- directory_entries = os.listdir(current_dir)
# First we find the highest full schema version we have
valid_versions = []
- for filename in directory_entries:
+ for filename in os.listdir(current_dir):
try:
ver = int(filename)
except ValueError:
@@ -237,7 +238,7 @@ def _setup_new_database(cur, database_engine, databases):
for database in databases
)
- directory_entries = []
+ directory_entries = [] # type: List[_DirectoryListing]
for directory in directories:
directory_entries.extend(
_DirectoryListing(file_name, os.path.join(directory, file_name))
@@ -275,15 +276,15 @@ def _setup_new_database(cur, database_engine, databases):
def _upgrade_existing_database(
- cur,
- current_version,
- applied_delta_files,
- upgraded,
- database_engine,
- config,
- databases,
- is_empty=False,
-):
+ cur: Cursor,
+ current_version: int,
+ applied_delta_files: List[str],
+ upgraded: bool,
+ database_engine: BaseDatabaseEngine,
+ config: Optional[HomeServerConfig],
+ databases: Collection[str],
+ is_empty: bool = False,
+) -> None:
"""Upgrades an existing physical database.
Delta files can either be SQL stored in *.sql files, or python modules
@@ -323,21 +324,20 @@ def _upgrade_existing_database(
for a version before applying those in the next version.
Args:
- cur (Cursor)
- current_version (int): The current version of the schema.
- applied_delta_files (list): A list of deltas that have already been
- applied.
- upgraded (bool): Whether the current version was generated by having
+ cur
+ current_version: The current version of the schema.
+ applied_delta_files: A list of deltas that have already been applied.
+ upgraded: Whether the current version was generated by having
applied deltas or from full schema file. If `True` the function
will never apply delta files for the given `current_version`, since
the current_version wasn't generated by applying those delta files.
- database_engine (DatabaseEngine)
- config (synapse.config.homeserver.HomeServerConfig|None):
+ database_engine
+ config:
None if we are initialising a blank database, otherwise the application
config
- databases (list[str]): The names of the databases to instantiate
+ databases: The names of the databases to instantiate
on the given physical database.
- is_empty (bool): Is this a blank database? I.e. do we need to run the
+ is_empty: Is this a blank database? I.e. do we need to run the
upgrade portions of the delta scripts.
"""
if is_empty:
@@ -358,6 +358,7 @@ def _upgrade_existing_database(
if not is_empty and "main" in databases:
from synapse.storage.databases.main import check_database_before_upgrade
+ assert config is not None
check_database_before_upgrade(cur, database_engine, config)
start_ver = current_version
@@ -388,10 +389,10 @@ def _upgrade_existing_database(
)
# Used to check if we have any duplicate file names
- file_name_counter = Counter()
+ file_name_counter = Counter() # type: CounterType[str]
# Now find which directories have anything of interest.
- directory_entries = []
+ directory_entries = [] # type: List[_DirectoryListing]
for directory in directories:
logger.debug("Looking for schema deltas in %s", directory)
try:
@@ -445,11 +446,11 @@ def _upgrade_existing_database(
module_name = "synapse.storage.v%d_%s" % (v, root_name)
with open(absolute_path) as python_file:
- module = imp.load_source(module_name, absolute_path, python_file)
+ module = imp.load_source(module_name, absolute_path, python_file) # type: ignore
logger.info("Running script %s", relative_path)
- module.run_create(cur, database_engine)
+ module.run_create(cur, database_engine) # type: ignore
if not is_empty:
- module.run_upgrade(cur, database_engine, config=config)
+ module.run_upgrade(cur, database_engine, config=config) # type: ignore
elif ext == ".pyc" or file_name == "__pycache__":
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package
@@ -497,14 +498,15 @@ def _upgrade_existing_database(
logger.info("Schema now up to date")
-def _apply_module_schemas(txn, database_engine, config):
+def _apply_module_schemas(
+ txn: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig
+) -> None:
"""Apply the module schemas for the dynamic modules, if any
Args:
cur: database cursor
- database_engine: synapse database engine class
- config (synapse.config.homeserver.HomeServerConfig):
- application config
+ database_engine:
+ config: application config
"""
for (mod, _config) in config.password_providers:
if not hasattr(mod, "get_db_schema_files"):
@@ -515,15 +517,19 @@ def _apply_module_schemas(txn, database_engine, config):
)
-def _apply_module_schema_files(cur, database_engine, modname, names_and_streams):
+def _apply_module_schema_files(
+ cur: Cursor,
+ database_engine: BaseDatabaseEngine,
+ modname: str,
+ names_and_streams: Iterable[Tuple[str, TextIO]],
+) -> None:
"""Apply the module schemas for a single module
Args:
cur: database cursor
database_engine: synapse database engine class
- modname (str): fully qualified name of the module
- names_and_streams (Iterable[(str, file)]): the names and streams of
- schemas to be applied
+ modname: fully qualified name of the module
+ names_and_streams: the names and streams of schemas to be applied
"""
cur.execute(
"SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
@@ -549,7 +555,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
)
-def get_statements(f):
+def get_statements(f: Iterable[str]) -> Generator[str, None, None]:
statement_buffer = ""
in_comment = False # If we're in a /* ... */ style comment
@@ -594,17 +600,19 @@ def get_statements(f):
statement_buffer = statements[-1].strip()
-def executescript(txn, schema_path):
+def executescript(txn: Cursor, schema_path: str) -> None:
with open(schema_path, "r") as f:
execute_statements_from_stream(txn, f)
-def execute_statements_from_stream(cur: Cursor, f: TextIO):
+def execute_statements_from_stream(cur: Cursor, f: TextIO) -> None:
for statement in get_statements(f):
cur.execute(statement)
-def _get_or_create_schema_state(txn, database_engine):
+def _get_or_create_schema_state(
+ txn: Cursor, database_engine: BaseDatabaseEngine
+) -> Optional[Tuple[int, List[str], bool]]:
# Bluntly try creating the schema_version tables.
schema_path = os.path.join(dir_path, "schema", "schema_version.sql")
executescript(txn, schema_path)
@@ -612,7 +620,6 @@ def _get_or_create_schema_state(txn, database_engine):
txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone()
current_version = int(row[0]) if row else None
- upgraded = bool(row[1]) if row else None
if current_version:
txn.execute(
@@ -620,6 +627,7 @@ def _get_or_create_schema_state(txn, database_engine):
(current_version,),
)
applied_deltas = [d for d, in txn]
+ upgraded = bool(row[1])
return current_version, applied_deltas, upgraded
return None
@@ -634,5 +642,5 @@ class _DirectoryListing:
`file_name` attr is kept first.
"""
- file_name = attr.ib()
- absolute_path = attr.ib()
+ file_name = attr.ib(type=str)
+ absolute_path = attr.ib(type=str)
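One small but real fix in this file: the databases=["main", "state"] default became a tuple. A mutable default is created once at definition time and shared by every call, so any mutation leaks between callers; a minimal demonstration of the hazard the tuple sidesteps:

def broken(items=[]):   # one shared list for every call
    items.append("x")
    return items

def fixed(items=()):    # immutable default; copy before mutating
    return list(items) + ["x"]

first, second = broken(), broken()
print(first, second, first is second)   # ['x', 'x'] ['x', 'x'] True
print(fixed(), fixed())                 # ['x'] ['x']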
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index bfa0a9fd..6c359c1a 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -15,7 +15,12 @@
import itertools
import logging
-from typing import Set
+from typing import TYPE_CHECKING, Set
+
+from synapse.storage.databases import Databases
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@@ -24,10 +29,10 @@ class PurgeEventsStorage:
"""High level interface for purging rooms and event history.
"""
- def __init__(self, hs, stores):
+ def __init__(self, hs: "HomeServer", stores: Databases):
self.stores = stores
- async def purge_room(self, room_id: str):
+ async def purge_room(self, room_id: str) -> None:
"""Deletes all record of a room
"""
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index cec96ad6..2564f34b 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -14,10 +14,12 @@
# limitations under the License.
import logging
+from typing import Any, Dict, List, Optional, Tuple
import attr
from synapse.api.errors import SynapseError
+from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -27,18 +29,18 @@ class PaginationChunk:
"""Returned by relation pagination APIs.
Attributes:
- chunk (list): The rows returned by pagination
- next_batch (Any|None): Token to fetch next set of results with, if
+ chunk: The rows returned by pagination
+ next_batch: Token to fetch next set of results with, if
None then there are no more results.
- prev_batch (Any|None): Token to fetch previous set of results with, if
+ prev_batch: Token to fetch previous set of results with, if
None then there are no previous results.
"""
- chunk = attr.ib()
- next_batch = attr.ib(default=None)
- prev_batch = attr.ib(default=None)
+ chunk = attr.ib(type=List[JsonDict])
+ next_batch = attr.ib(type=Optional[Any], default=None)
+ prev_batch = attr.ib(type=Optional[Any], default=None)
- def to_dict(self):
+ def to_dict(self) -> Dict[str, Any]:
d = {"chunk": self.chunk}
if self.next_batch:
@@ -59,25 +61,25 @@ class RelationPaginationToken:
boundaries of the chunk as pagination tokens.
Attributes:
- topological (int): The topological ordering of the boundary event
- stream (int): The stream ordering of the boundary event.
+ topological: The topological ordering of the boundary event
+ stream: The stream ordering of the boundary event.
"""
- topological = attr.ib()
- stream = attr.ib()
+ topological = attr.ib(type=int)
+ stream = attr.ib(type=int)
@staticmethod
- def from_string(string):
+ def from_string(string: str) -> "RelationPaginationToken":
try:
t, s = string.split("-")
return RelationPaginationToken(int(t), int(s))
except ValueError:
raise SynapseError(400, "Invalid token")
- def to_string(self):
+ def to_string(self) -> str:
return "%d-%d" % (self.topological, self.stream)
- def as_tuple(self):
+ def as_tuple(self) -> Tuple[Any, ...]:
return attr.astuple(self)
@@ -89,23 +91,23 @@ class AggregationPaginationToken:
aggregation groups, we can just use them as our pagination token.
Attributes:
- count (int): The count of relations in the boundar group.
- stream (int): The MAX stream ordering in the boundary group.
+ count: The count of relations in the boundary group.
+ stream: The MAX stream ordering in the boundary group.
"""
- count = attr.ib()
- stream = attr.ib()
+ count = attr.ib(type=int)
+ stream = attr.ib(type=int)
@staticmethod
- def from_string(string):
+ def from_string(string: str) -> "AggregationPaginationToken":
try:
c, s = string.split("-")
return AggregationPaginationToken(int(c), int(s))
except ValueError:
raise SynapseError(400, "Invalid token")
- def to_string(self):
+ def to_string(self) -> str:
return "%d-%d" % (self.count, self.stream)
- def as_tuple(self):
+ def as_tuple(self) -> Tuple[Any, ...]:
return attr.astuple(self)
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 08a69f2f..31ccbf23 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -12,9 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
-from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar
+from typing import (
+ TYPE_CHECKING,
+ Awaitable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ TypeVar,
+)
import attr
@@ -22,6 +31,10 @@ from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+ from synapse.storage.databases import Databases
+
logger = logging.getLogger(__name__)
# Used for generic functions below
@@ -330,10 +343,12 @@ class StateGroupStorage:
"""High level interface to fetching state for event.
"""
- def __init__(self, hs, stores):
+ def __init__(self, hs: "HomeServer", stores: "Databases"):
self.stores = stores
- async def get_state_group_delta(self, state_group: int):
+ async def get_state_group_delta(
+ self, state_group: int
+ ) -> Tuple[Optional[int], Optional[StateMap[str]]]:
"""Given a state group try to return a previous group and a delta between
the old and the new.
@@ -341,8 +356,8 @@ class StateGroupStorage:
state_group: The state group used to retrieve state deltas.
Returns:
- Tuple[Optional[int], Optional[StateMap[str]]]:
- (prev_group, delta_ids)
+ A tuple of the previous group and a state map of the event IDs which
+ make up the delta between the old and new state groups.
"""
return await self.stores.state.get_state_group_delta(state_group)
@@ -436,7 +451,7 @@ class StateGroupStorage:
async def get_state_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> Dict[str, StateMap[EventBase]]:
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
@@ -472,7 +487,7 @@ class StateGroupStorage:
async def get_state_ids_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> Dict[str, StateMap[str]]:
"""
Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves)
@@ -500,7 +515,7 @@ class StateGroupStorage:
async def get_state_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> StateMap[EventBase]:
"""
Get the state dict corresponding to a particular event
@@ -516,7 +531,7 @@ class StateGroupStorage:
async def get_state_ids_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> StateMap[str]:
"""
Get the state dict corresponding to a particular event
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 02d71302..133c0e7a 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -153,12 +153,12 @@ class StreamIdGenerator:
return _AsyncCtxManagerWrapper(manager())
- def get_current_token(self):
+ def get_current_token(self) -> int:
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
Returns:
- int
+ The maximum stream id.
"""
with self._lock:
if self._unfinished_ids:
diff --git a/synapse/types.py b/synapse/types.py
index 3ab6bdbe..c7d4e958 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -349,15 +349,17 @@ NON_MXID_CHARACTER_PATTERN = re.compile(
)
-def map_username_to_mxid_localpart(username, case_sensitive=False):
+def map_username_to_mxid_localpart(
+ username: Union[str, bytes], case_sensitive: bool = False
+) -> str:
"""Map a username onto a string suitable for a MXID
This follows the algorithm laid out at
https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets.
Args:
- username (unicode|bytes): username to be mapped
- case_sensitive (bool): true if TEST and test should be mapped
+ username: username to be mapped
+ case_sensitive: true if TEST and test should be mapped
onto different mxids
Returns:
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 382f0cf3..9a873c8e 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -15,10 +15,12 @@
# limitations under the License.
import collections
+import inspect
import logging
from contextlib import contextmanager
from typing import (
Any,
+ Awaitable,
Callable,
Dict,
Hashable,
@@ -542,11 +544,11 @@ class DoneAwaitable:
raise StopIteration(self.value)
-def maybe_awaitable(value):
+def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
"""Convert a value to an awaitable if not already an awaitable.
"""
-
- if hasattr(value, "__await__"):
+ if inspect.isawaitable(value):
+ assert isinstance(value, Awaitable)
return value
return DoneAwaitable(value)
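A small usage sketch of the maybe_awaitable contract after this change: callers can await the result uniformly whether they were handed a plain value or a coroutine (assumes synapse is importable):

import asyncio
from synapse.util.async_helpers import maybe_awaitable

async def demo():
    async def slow():
        return "from coroutine"
    print(await maybe_awaitable(42))      # plain value, wrapped in DoneAwaitable
    print(await maybe_awaitable(slow()))  # real awaitable, passed through

asyncio.run(demo())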
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index f73e9539..a6ee9eda 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -12,13 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import logging
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@@ -105,10 +105,7 @@ class Signal:
async def do(observer):
try:
- result = observer(*args, **kwargs)
- if inspect.isawaitable(result):
- result = await result
- return result
+ return await maybe_awaitable(observer(*args, **kwargs))
except Exception as e:
logger.warning(
"%s signal observer %s failed: %r", self.name, observer, e,
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 94b59afb..1ee61851 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -15,28 +15,56 @@
import importlib
import importlib.util
+import itertools
+from typing import Any, Iterable, Tuple, Type
+
+import jsonschema
from synapse.config._base import ConfigError
+from synapse.config._util import json_error_to_config_error
-def load_module(provider):
+def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
""" Loads a synapse module with its config
- Take a dict with keys 'module' (the module name) and 'config'
- (the config dict).
+
+ Args:
+ provider: a dict with keys 'module' (the module name) and 'config'
+ (the config dict).
+ config_path: the path within the config file. This will be used as a basis
+ for any error message.
Returns:
Tuple of (provider class, parsed config object)
"""
+
+ modulename = provider.get("module")
+ if not isinstance(modulename, str):
+ raise ConfigError(
+ "expected a string", path=itertools.chain(config_path, ("module",))
+ )
+
# We need to import the module, and then pick the class out of
# that, so we split based on the last dot.
- module, clz = provider["module"].rsplit(".", 1)
+ module, clz = modulename.rsplit(".", 1)
module = importlib.import_module(module)
provider_class = getattr(module, clz)
+ module_config = provider.get("config")
try:
- provider_config = provider_class.parse_config(provider.get("config"))
+ provider_config = provider_class.parse_config(module_config)
+ except jsonschema.ValidationError as e:
+ raise json_error_to_config_error(e, itertools.chain(config_path, ("config",)))
+ except ConfigError as e:
+ raise _wrap_config_error(
+ "Failed to parse config for module %r" % (modulename,),
+ prefix=itertools.chain(config_path, ("config",)),
+ e=e,
+ )
except Exception as e:
- raise ConfigError("Failed to parse config for %r: %s" % (provider["module"], e))
+ raise ConfigError(
+ "Failed to parse config for module %r" % (modulename,),
+ path=itertools.chain(config_path, ("config",)),
+ ) from e
return provider_class, provider_config
@@ -56,3 +84,27 @@ def load_python_module(location: str):
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod) # type: ignore
return mod
+
+
+def _wrap_config_error(
+ msg: str, prefix: Iterable[str], e: ConfigError
+) -> "ConfigError":
+ """Wrap a relative ConfigError with a new path
+
+ This is useful when we have a ConfigError with a relative path due to a problem
+ parsing part of the config, and we now need to set it in context.
+ """
+ path = prefix
+ if e.path:
+ path = itertools.chain(prefix, e.path)
+
+ e1 = ConfigError(msg, path)
+
+ # ideally we would set the 'cause' of the new exception to the original exception;
+ # however now that we have merged the path into our own, the stringification of
+ # e will be incorrect, so instead we create a new exception with just the "msg"
+ # part.
+
+ e1.__cause__ = Exception(e.msg)
+ e1.__cause__.__cause__ = e.__cause__
+ return e1
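A hedged sketch of the new load_module calling convention (the module name and config path here are made up for illustration): callers now pass the path of the config section so that errors name the offending key.

from synapse.util.module_loader import load_module

provider = {
    "module": "my_project.auth.ExampleProvider",    # hypothetical module
    "config": {"endpoint": "https://example.com"},
}
# A bad config now raises a ConfigError whose path points at
# password_providers.<item 0>.config rather than a bare message.
provider_class, parsed_config = load_module(
    provider, ("password_providers", "<item 0>")
)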
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 52736549..ec50e7e9 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -12,11 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
import operator
-from synapse.api.constants import AccountDataTypes, EventTypes, Membership
+from synapse.api.constants import (
+ AccountDataTypes,
+ EventTypes,
+ HistoryVisibility,
+ Membership,
+)
from synapse.events.utils import prune_event
from synapse.storage import Storage
from synapse.storage.state import StateFilter
@@ -25,7 +29,12 @@ from synapse.types import get_domain_from_id
logger = logging.getLogger(__name__)
-VISIBILITY_PRIORITY = ("world_readable", "shared", "invited", "joined")
+VISIBILITY_PRIORITY = (
+ HistoryVisibility.WORLD_READABLE,
+ HistoryVisibility.SHARED,
+ HistoryVisibility.INVITED,
+ HistoryVisibility.JOINED,
+)
MEMBERSHIP_PRIORITY = (
@@ -116,7 +125,7 @@ async def filter_events_for_client(
# see events in the room at that point in the DAG, and that shouldn't be decided
# on those checks.
if filter_send_to_client:
- if event.type == "org.matrix.dummy_event":
+ if event.type == EventTypes.Dummy:
return None
if not event.is_state() and event.sender in ignore_list:
@@ -150,12 +159,14 @@ async def filter_events_for_client(
# get the room_visibility at the time of the event.
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if visibility_event:
- visibility = visibility_event.content.get("history_visibility", "shared")
+ visibility = visibility_event.content.get(
+ "history_visibility", HistoryVisibility.SHARED
+ )
else:
- visibility = "shared"
+ visibility = HistoryVisibility.SHARED
if visibility not in VISIBILITY_PRIORITY:
- visibility = "shared"
+ visibility = HistoryVisibility.SHARED
# Always allow history visibility events on boundaries. This is done
# by setting the effective visibility to the least restrictive
@@ -165,7 +176,7 @@ async def filter_events_for_client(
prev_visibility = prev_content.get("history_visibility", None)
if prev_visibility not in VISIBILITY_PRIORITY:
- prev_visibility = "shared"
+ prev_visibility = HistoryVisibility.SHARED
new_priority = VISIBILITY_PRIORITY.index(visibility)
old_priority = VISIBILITY_PRIORITY.index(prev_visibility)
@@ -210,17 +221,17 @@ async def filter_events_for_client(
# otherwise, it depends on the room visibility.
- if visibility == "joined":
+ if visibility == HistoryVisibility.JOINED:
# we weren't a member at the time of the event, so we can't
# see this event.
return None
- elif visibility == "invited":
+ elif visibility == HistoryVisibility.INVITED:
# user can also see the event if they were *invited* at the time
# of the event.
return event if membership == Membership.INVITE else None
- elif visibility == "shared" and is_peeking:
+ elif visibility == HistoryVisibility.SHARED and is_peeking:
# if the visibility is shared, users cannot see the event unless
# they have *subsequently* joined the room (or were members at the
# time, of course)
@@ -284,8 +295,10 @@ async def filter_events_for_server(
def check_event_is_visible(event, state):
history = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if history:
- visibility = history.content.get("history_visibility", "shared")
- if visibility in ["invited", "joined"]:
+ visibility = history.content.get(
+ "history_visibility", HistoryVisibility.SHARED
+ )
+ if visibility in [HistoryVisibility.INVITED, HistoryVisibility.JOINED]:
# We now loop through all state events looking for
# membership states for the requesting server to determine
# if the server is either in the room or has been invited
@@ -305,7 +318,7 @@ async def filter_events_for_server(
if memtype == Membership.JOIN:
return True
elif memtype == Membership.INVITE:
- if visibility == "invited":
+ if visibility == HistoryVisibility.INVITED:
return True
else:
# server has no users in the room: redact
@@ -336,7 +349,8 @@ async def filter_events_for_server(
else:
event_map = await storage.main.get_events(visibility_ids)
all_open = all(
- e.content.get("history_visibility") in (None, "shared", "world_readable")
+ e.content.get("history_visibility")
+ in (None, HistoryVisibility.SHARED, HistoryVisibility.WORLD_READABLE)
for e in event_map.values()
)
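For reference, the HistoryVisibility constants threaded through this file are plain string constants whose values match the old literals, so the index-into-VISIBILITY_PRIORITY comparison is unchanged. A minimal restatement of the boundary rule (the least restrictive setting wins):

class HistoryVisibility:
    # mirrors synapse.api.constants.HistoryVisibility
    WORLD_READABLE = "world_readable"
    SHARED = "shared"
    INVITED = "invited"
    JOINED = "joined"

VISIBILITY_PRIORITY = (
    HistoryVisibility.WORLD_READABLE,  # least restrictive
    HistoryVisibility.SHARED,
    HistoryVisibility.INVITED,
    HistoryVisibility.JOINED,          # most restrictive
)

def effective_visibility(current, previous):
    # On a visibility-change boundary the event stays visible under the
    # less restrictive of the two settings, i.e. the lower index wins.
    idx = min(VISIBILITY_PRIORITY.index(current),
              VISIBILITY_PRIORITY.index(previous))
    return VISIBILITY_PRIORITY[idx]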