summaryrefslogtreecommitdiff
path: root/synapse
diff options
context:
space:
mode:
authorAndrej Shadura <andrewsh@debian.org>2021-09-14 21:38:49 +0100
committerAndrej Shadura <andrewsh@debian.org>2021-09-14 21:38:49 +0100
commit255062dc0befee4b07ad1f566e75ffbab3c8ac50 (patch)
treea5a4fe8006e1be63140cea69eeddb353a958aecd /synapse
parent8f729bf8d275ec12051a30581f5eae996eadf802 (diff)
New upstream version 1.42.0
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/constants.py1
-rw-r--r--synapse/api/errors.py8
-rw-r--r--synapse/api/room_versions.py31
-rw-r--r--synapse/app/_base.py2
-rw-r--r--synapse/app/generic_worker.py8
-rw-r--r--synapse/config/account_validity.py14
-rw-r--r--synapse/config/emailconfig.py14
-rw-r--r--synapse/config/experimental.py3
-rw-r--r--synapse/config/ratelimiting.py11
-rw-r--r--synapse/config/registration.py15
-rw-r--r--synapse/config/server.py15
-rw-r--r--synapse/config/sso.py13
-rw-r--r--synapse/events/presence_router.py192
-rw-r--r--synapse/events/utils.py7
-rw-r--r--synapse/events/validator.py77
-rw-r--r--synapse/federation/federation_client.py113
-rw-r--r--synapse/federation/federation_server.py9
-rw-r--r--synapse/handlers/auth.py28
-rw-r--r--synapse/handlers/federation.py1927
-rw-r--r--synapse/handlers/federation_event.py1825
-rw-r--r--synapse/handlers/initial_sync.py2
-rw-r--r--synapse/handlers/presence.py5
-rw-r--r--synapse/handlers/register.py18
-rw-r--r--synapse/handlers/room_member.py17
-rw-r--r--synapse/handlers/room_summary.py76
-rw-r--r--synapse/handlers/sso.py2
-rw-r--r--synapse/handlers/sync.py198
-rw-r--r--synapse/handlers/ui_auth/__init__.py5
-rw-r--r--synapse/handlers/ui_auth/checkers.py75
-rw-r--r--synapse/http/additional_resource.py13
-rw-r--r--synapse/http/federation/srv_resolver.py45
-rw-r--r--synapse/http/proxyagent.py4
-rw-r--r--synapse/module_api/__init__.py10
-rw-r--r--synapse/python_dependencies.py3
-rw-r--r--synapse/replication/http/federation.py4
-rw-r--r--synapse/res/templates/recaptcha.html3
-rw-r--r--synapse/res/templates/registration_token.html23
-rw-r--r--synapse/res/templates/terms.html3
-rw-r--r--synapse/rest/admin/__init__.py17
-rw-r--r--synapse/rest/admin/purge_room_servlet.py58
-rw-r--r--synapse/rest/admin/registration_tokens.py321
-rw-r--r--synapse/rest/admin/rooms.py35
-rw-r--r--synapse/rest/admin/server_notice_servlet.py19
-rw-r--r--synapse/rest/admin/users.py55
-rw-r--r--synapse/rest/client/account_validity.py39
-rw-r--r--synapse/rest/client/auth.py74
-rw-r--r--synapse/rest/client/capabilities.py14
-rw-r--r--synapse/rest/client/devices.py48
-rw-r--r--synapse/rest/client/directory.py78
-rw-r--r--synapse/rest/client/events.py38
-rw-r--r--synapse/rest/client/filter.py26
-rw-r--r--synapse/rest/client/groups.py3
-rw-r--r--synapse/rest/client/initial_sync.py16
-rw-r--r--synapse/rest/client/keys.py71
-rw-r--r--synapse/rest/client/knock.py3
-rw-r--r--synapse/rest/client/login.py48
-rw-r--r--synapse/rest/client/logout.py17
-rw-r--r--synapse/rest/client/notifications.py13
-rw-r--r--synapse/rest/client/openid.py16
-rw-r--r--synapse/rest/client/password_policy.py18
-rw-r--r--synapse/rest/client/presence.py24
-rw-r--r--synapse/rest/client/profile.py37
-rw-r--r--synapse/rest/client/pusher.py22
-rw-r--r--synapse/rest/client/read_marker.py15
-rw-r--r--synapse/rest/client/register.py85
-rw-r--r--synapse/rest/client/room_upgrade_rest_servlet.py18
-rw-r--r--synapse/rest/client/shared_rooms.py16
-rw-r--r--synapse/rest/client/sync.py132
-rw-r--r--synapse/rest/client/tags.py25
-rw-r--r--synapse/rest/client/thirdparty.py36
-rw-r--r--synapse/rest/client/tokenrefresh.py14
-rw-r--r--synapse/rest/client/user_directory.py17
-rw-r--r--synapse/rest/client/versions.py14
-rw-r--r--synapse/rest/client/voip.py13
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py26
-rw-r--r--synapse/server.py5
-rw-r--r--synapse/server_notices/server_notices_manager.py17
-rw-r--r--synapse/static/client/register/style.css6
-rw-r--r--synapse/storage/databases/main/__init__.py2
-rw-r--r--synapse/storage/databases/main/events_worker.py29
-rw-r--r--synapse/storage/databases/main/purge_events.py1
-rw-r--r--synapse/storage/databases/main/pusher.py73
-rw-r--r--synapse/storage/databases/main/registration.py341
-rw-r--r--synapse/storage/databases/main/roommember.py16
-rw-r--r--synapse/storage/databases/main/session.py145
-rw-r--r--synapse/storage/databases/main/ui_auth.py43
-rw-r--r--synapse/storage/databases/main/user_directory.py2
-rw-r--r--synapse/storage/roommember.py44
-rw-r--r--synapse/storage/schema/__init__.py2
-rw-r--r--synapse/storage/schema/main/delta/63/01create_registration_tokens.sql23
-rw-r--r--synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql20
-rw-r--r--synapse/storage/schema/main/delta/63/03session_store.sql23
93 files changed, 4448 insertions, 2686 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 06d80f79..dc7ae242 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
except ImportError:
pass
-__version__ = "1.41.1"
+__version__ = "1.42.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index e0e24fdd..829061c8 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -79,6 +79,7 @@ class LoginType:
TERMS = "m.login.terms"
SSO = "m.login.sso"
DUMMY = "m.login.dummy"
+ REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"
# This is used in the `type` parameter for /register when called by
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index dc662bca..9480f448 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -147,6 +147,14 @@ class SynapseError(CodeMessageException):
return cs_error(self.msg, self.errcode)
+class InvalidAPICallError(SynapseError):
+ """You called an existing API endpoint, but fed that endpoint
+ invalid or incomplete data."""
+
+ def __init__(self, msg: str):
+ super().__init__(HTTPStatus.BAD_REQUEST, msg, Codes.BAD_JSON)
+
+
class ProxiedRequestError(SynapseError):
"""An error from a general matrix endpoint, eg. from a proxied Matrix API call.
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index 8abcdfd4..a19be670 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -70,6 +70,9 @@ class RoomVersion:
msc2176_redaction_rules = attr.ib(type=bool)
# MSC3083: Support the 'restricted' join_rule.
msc3083_join_rules = attr.ib(type=bool)
+ # MSC3375: Support for the proper redaction rules for MSC3083. This mustn't
+ # be enabled if MSC3083 is not.
+ msc3375_redaction_rules = attr.ib(type=bool)
# MSC2403: Allows join_rules to be set to 'knock', changes auth rules to allow sending
# m.room.membership event with membership 'knock'.
msc2403_knocking = attr.ib(type=bool)
@@ -92,6 +95,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
+ msc3375_redaction_rules=False,
msc2403_knocking=False,
msc2716_historical=False,
msc2716_redactions=False,
@@ -107,6 +111,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
+ msc3375_redaction_rules=False,
msc2403_knocking=False,
msc2716_historical=False,
msc2716_redactions=False,
@@ -122,6 +127,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
+ msc3375_redaction_rules=False,
msc2403_knocking=False,
msc2716_historical=False,
msc2716_redactions=False,
@@ -137,6 +143,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
+ msc3375_redaction_rules=False,
msc2403_knocking=False,
msc2716_historical=False,
msc2716_redactions=False,
@@ -152,6 +159,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
+ msc3375_redaction_rules=False,
msc2403_knocking=False,
msc2716_historical=False,
msc2716_redactions=False,
@@ -167,6 +175,7 @@ class RoomVersions:
limit_notifications_power_levels=True,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
+ msc3375_redaction_rules=False,
msc2403_knocking=False,
msc2716_historical=False,
msc2716_redactions=False,
@@ -182,6 +191,7 @@ class RoomVersions:
limit_notifications_power_levels=True,
msc2176_redaction_rules=True,
msc3083_join_rules=False,
+ msc3375_redaction_rules=False,
msc2403_knocking=False,
msc2716_historical=False,
msc2716_redactions=False,
@@ -197,6 +207,7 @@ class RoomVersions:
limit_notifications_power_levels=True,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
+ msc3375_redaction_rules=False,
msc2403_knocking=True,
msc2716_historical=False,
msc2716_redactions=False,
@@ -212,6 +223,23 @@ class RoomVersions:
limit_notifications_power_levels=True,
msc2176_redaction_rules=False,
msc3083_join_rules=True,
+ msc3375_redaction_rules=False,
+ msc2403_knocking=True,
+ msc2716_historical=False,
+ msc2716_redactions=False,
+ )
+ V9 = RoomVersion(
+ "9",
+ RoomDisposition.STABLE,
+ EventFormatVersions.V3,
+ StateResolutionVersions.V2,
+ enforce_key_validity=True,
+ special_case_aliases_auth=False,
+ strict_canonicaljson=True,
+ limit_notifications_power_levels=True,
+ msc2176_redaction_rules=False,
+ msc3083_join_rules=True,
+ msc3375_redaction_rules=True,
msc2403_knocking=True,
msc2716_historical=False,
msc2716_redactions=False,
@@ -227,6 +255,7 @@ class RoomVersions:
limit_notifications_power_levels=True,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
+ msc3375_redaction_rules=False,
msc2403_knocking=True,
msc2716_historical=True,
msc2716_redactions=False,
@@ -242,6 +271,7 @@ class RoomVersions:
limit_notifications_power_levels=True,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
+ msc3375_redaction_rules=False,
msc2403_knocking=True,
msc2716_historical=True,
msc2716_redactions=True,
@@ -261,6 +291,7 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
RoomVersions.V7,
RoomVersions.MSC2716,
RoomVersions.V8,
+ RoomVersions.V9,
)
}
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 50a02f51..39e28aff 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -37,6 +37,7 @@ from synapse.app import check_bind_error
from synapse.app.phone_stats_home import start_phone_stats_home
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
+from synapse.events.presence_router import load_legacy_presence_router
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.logging.context import PreserveLoggingContext
@@ -370,6 +371,7 @@ async def start(hs: "HomeServer"):
load_legacy_spam_checkers(hs)
load_legacy_third_party_event_rules(hs)
+ load_legacy_presence_router(hs)
# If we've configured an expiry time for caches, start the background job now.
setup_expire_lru_cache_entries(hs)
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 845e6a82..9b71dd75 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -95,7 +95,10 @@ from synapse.rest.client.profile import (
ProfileRestServlet,
)
from synapse.rest.client.push_rule import PushRuleRestServlet
-from synapse.rest.client.register import RegisterRestServlet
+from synapse.rest.client.register import (
+ RegisterRestServlet,
+ RegistrationTokenValidityRestServlet,
+)
from synapse.rest.client.sendtodevice import SendToDeviceRestServlet
from synapse.rest.client.versions import VersionsRestServlet
from synapse.rest.client.voip import VoipRestServlet
@@ -115,6 +118,7 @@ from synapse.storage.databases.main.monthly_active_users import (
from synapse.storage.databases.main.presence import PresenceStore
from synapse.storage.databases.main.room import RoomWorkerStore
from synapse.storage.databases.main.search import SearchStore
+from synapse.storage.databases.main.session import SessionStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.databases.main.transactions import TransactionWorkerStore
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
@@ -250,6 +254,7 @@ class GenericWorkerSlavedStore(
SearchStore,
TransactionWorkerStore,
LockStore,
+ SessionStore,
BaseSlavedStore,
):
pass
@@ -279,6 +284,7 @@ class GenericWorkerServer(HomeServer):
resource = JsonResource(self, canonical_json=False)
RegisterRestServlet(self).register(resource)
+ RegistrationTokenValidityRestServlet(self).register(resource)
login.register_servlets(self, resource)
ThreepidRestServlet(self).register(resource)
DevicesRestServlet(self).register(resource)
diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py
index 52e63ab1..ffaffc49 100644
--- a/synapse/config/account_validity.py
+++ b/synapse/config/account_validity.py
@@ -11,8 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from synapse.config._base import Config, ConfigError
+logger = logging.getLogger(__name__)
+
+LEGACY_TEMPLATE_DIR_WARNING = """
+This server's configuration file is using the deprecated 'template_dir' setting in the
+'account_validity' section. Support for this setting has been deprecated and will be
+removed in a future version of Synapse. Server admins should instead use the new
+'custom_templates_directory' setting documented here:
+https://matrix-org.github.io/synapse/latest/templates.html
+---------------------------------------------------------------------------------------"""
+
class AccountValidityConfig(Config):
section = "account_validity"
@@ -69,6 +81,8 @@ class AccountValidityConfig(Config):
# Load account validity templates.
account_validity_template_dir = account_validity_config.get("template_dir")
+ if account_validity_template_dir is not None:
+ logger.warning(LEGACY_TEMPLATE_DIR_WARNING)
account_renewed_template_filename = account_validity_config.get(
"account_renewed_html_path", "account_renewed.html"
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 44774191..936abe61 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -16,6 +16,7 @@
# This file can't be called email.py because if it is, we cannot:
import email.utils
+import logging
import os
from enum import Enum
from typing import Optional
@@ -24,6 +25,8 @@ import attr
from ._base import Config, ConfigError
+logger = logging.getLogger(__name__)
+
MISSING_PASSWORD_RESET_CONFIG_ERROR = """\
Password reset emails are enabled on this homeserver due to a partial
'email' block. However, the following required keys are missing:
@@ -44,6 +47,14 @@ DEFAULT_SUBJECTS = {
"email_validation": "[%(server_name)s] Validate your email",
}
+LEGACY_TEMPLATE_DIR_WARNING = """
+This server's configuration file is using the deprecated 'template_dir' setting in the
+'email' section. Support for this setting has been deprecated and will be removed in a
+future version of Synapse. Server admins should instead use the new
+'custom_templates_directory' setting documented here:
+https://matrix-org.github.io/synapse/latest/templates.html
+---------------------------------------------------------------------------------------"""
+
@attr.s(slots=True, frozen=True)
class EmailSubjectConfig:
@@ -105,6 +116,9 @@ class EmailConfig(Config):
# A user-configurable template directory
template_dir = email_config.get("template_dir")
+ if template_dir is not None:
+ logger.warning(LEGACY_TEMPLATE_DIR_WARNING)
+
if isinstance(template_dir, str):
# We need an absolute path, because we change directory after starting (and
# we don't yet know what auxiliary templates like mail.css we will need).
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 907df959..95deda11 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -39,5 +39,8 @@ class ExperimentalConfig(Config):
# MSC3244 (room version capabilities)
self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True)
+ # MSC3283 (set displayname, avatar_url and change 3pid capabilities)
+ self.msc3283_enabled: bool = experimental.get("msc3283_enabled", False)
+
# MSC3266 (room summary api)
self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False)
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 7a8d5851..f856327b 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -79,6 +79,11 @@ class RatelimitConfig(Config):
self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
+ self.rc_registration_token_validity = RateLimitConfig(
+ config.get("rc_registration_token_validity", {}),
+ defaults={"per_second": 0.1, "burst_count": 5},
+ )
+
rc_login_config = config.get("rc_login", {})
self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {}))
@@ -143,6 +148,8 @@ class RatelimitConfig(Config):
# is using
# - one for registration that ratelimits registration requests based on the
# client's IP address.
+ # - one for checking the validity of registration tokens that ratelimits
+ # requests based on the client's IP address.
# - one for login that ratelimits login requests based on the client's IP
# address.
# - one for login that ratelimits login requests based on the account the
@@ -171,6 +178,10 @@ class RatelimitConfig(Config):
# per_second: 0.17
# burst_count: 3
#
+ #rc_registration_token_validity:
+ # per_second: 0.1
+ # burst_count: 5
+ #
#rc_login:
# address:
# per_second: 0.17
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 0ad919b1..7cffdacf 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -33,6 +33,9 @@ class RegistrationConfig(Config):
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
+ self.registration_requires_token = config.get(
+ "registration_requires_token", False
+ )
self.registration_shared_secret = config.get("registration_shared_secret")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
@@ -140,6 +143,9 @@ class RegistrationConfig(Config):
"mechanism by removing the `access_token_lifetime` option."
)
+ # The fallback template used for authenticating using a registration token
+ self.registration_token_template = self.read_template("registration_token.html")
+
# The success template used during fallback auth.
self.fallback_success_template = self.read_template("auth_success.html")
@@ -199,6 +205,15 @@ class RegistrationConfig(Config):
#
#enable_3pid_lookup: true
+ # Require users to submit a token during registration.
+ # Tokens can be managed using the admin API:
+ # https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html
+ # Note that `enable_registration` must be set to `true`.
+ # Disabling this option will not delete any tokens previously generated.
+ # Defaults to false. Uncomment the following to require tokens:
+ #
+ #registration_requires_token: true
+
# If set, allows registration of standard or admin accounts by anyone who
# has the shared secret, even if registration is otherwise disabled.
#
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 84947959..d2c900f5 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -248,6 +248,7 @@ class ServerConfig(Config):
self.use_presence = config.get("use_presence", True)
# Custom presence router module
+ # This is the legacy way of configuring it (the config should now be put in the modules section)
self.presence_router_module_class = None
self.presence_router_config = None
presence_router_config = presence_config.get("presence_router")
@@ -870,20 +871,6 @@ class ServerConfig(Config):
#
#enabled: false
- # Presence routers are third-party modules that can specify additional logic
- # to where presence updates from users are routed.
- #
- presence_router:
- # The custom module's class. Uncomment to use a custom presence router module.
- #
- #module: "my_custom_router.PresenceRouter"
-
- # Configuration options of the custom module. Refer to your module's
- # documentation for available options.
- #
- #config:
- # example_option: 'something'
-
# Whether to require authentication to retrieve profile data (avatars,
# display names) of other users through the client API. Defaults to
# 'false'. Note that profile data is also available via the federation
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index fe1177ab..524a7ff3 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -11,12 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
from typing import Any, Dict, Optional
import attr
from ._base import Config
+logger = logging.getLogger(__name__)
+
+LEGACY_TEMPLATE_DIR_WARNING = """
+This server's configuration file is using the deprecated 'template_dir' setting in the
+'sso' section. Support for this setting has been deprecated and will be removed in a
+future version of Synapse. Server admins should instead use the new
+'custom_templates_directory' setting documented here:
+https://matrix-org.github.io/synapse/latest/templates.html
+---------------------------------------------------------------------------------------"""
+
@attr.s(frozen=True)
class SsoAttributeRequirement:
@@ -43,6 +54,8 @@ class SSOConfig(Config):
# The sso-specific template_dir
self.sso_template_dir = sso_config.get("template_dir")
+ if self.sso_template_dir is not None:
+ logger.warning(LEGACY_TEMPLATE_DIR_WARNING)
# Read templates from disk
custom_template_directories = (
diff --git a/synapse/events/presence_router.py b/synapse/events/presence_router.py
index 6c37c8a7..eb4556cd 100644
--- a/synapse/events/presence_router.py
+++ b/synapse/events/presence_router.py
@@ -11,45 +11,115 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from typing import TYPE_CHECKING, Dict, Iterable, Set, Union
+import logging
+from typing import (
+ TYPE_CHECKING,
+ Awaitable,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Union,
+)
from synapse.api.presence import UserPresenceState
+from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
from synapse.server import HomeServer
+GET_USERS_FOR_STATES_CALLBACK = Callable[
+ [Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]]
+]
+GET_INTERESTED_USERS_CALLBACK = Callable[
+ [str], Awaitable[Union[Set[str], "PresenceRouter.ALL_USERS"]]
+]
+
+logger = logging.getLogger(__name__)
+
+
+def load_legacy_presence_router(hs: "HomeServer"):
+ """Wrapper that loads a presence router module configured using the old
+ configuration, and registers the hooks they implement.
+ """
+
+ if hs.config.presence_router_module_class is None:
+ return
+
+ module = hs.config.presence_router_module_class
+ config = hs.config.presence_router_config
+ api = hs.get_module_api()
+
+ presence_router = module(config=config, module_api=api)
+
+ # The known hooks. If a module implements a method which name appears in this set,
+ # we'll want to register it.
+ presence_router_methods = {
+ "get_users_for_states",
+ "get_interested_users",
+ }
+
+ # All methods that the module provides should be async, but this wasn't enforced
+ # in the old module system, so we wrap them if needed
+ def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
+ # f might be None if the callback isn't implemented by the module. In this
+ # case we don't want to register a callback at all so we return None.
+ if f is None:
+ return None
+
+ def run(*args, **kwargs):
+ # mypy doesn't do well across function boundaries so we need to tell it
+ # f is definitely not None.
+ assert f is not None
+
+ return maybe_awaitable(f(*args, **kwargs))
+
+ return run
+
+ # Register the hooks through the module API.
+ hooks = {
+ hook: async_wrapper(getattr(presence_router, hook, None))
+ for hook in presence_router_methods
+ }
+
+ api.register_presence_router_callbacks(**hooks)
+
class PresenceRouter:
"""
A module that the homeserver will call upon to help route user presence updates to
- additional destinations. If a custom presence router is configured, calls will be
- passed to that instead.
+ additional destinations.
"""
ALL_USERS = "ALL"
def __init__(self, hs: "HomeServer"):
- self.custom_presence_router = None
+ # Initially there are no callbacks
+ self._get_users_for_states_callbacks: List[GET_USERS_FOR_STATES_CALLBACK] = []
+ self._get_interested_users_callbacks: List[GET_INTERESTED_USERS_CALLBACK] = []
- # Check whether a custom presence router module has been configured
- if hs.config.presence_router_module_class:
- # Initialise the module
- self.custom_presence_router = hs.config.presence_router_module_class(
- config=hs.config.presence_router_config, module_api=hs.get_module_api()
+ def register_presence_router_callbacks(
+ self,
+ get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None,
+ get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None,
+ ):
+ # PresenceRouter modules are required to implement both of these methods
+ # or neither of them as they are assumed to act in a complementary manner
+ paired_methods = [get_users_for_states, get_interested_users]
+ if paired_methods.count(None) == 1:
+ raise RuntimeError(
+ "PresenceRouter modules must register neither or both of the paired callbacks: "
+ "[get_users_for_states, get_interested_users]"
)
- # Ensure the module has implemented the required methods
- required_methods = ["get_users_for_states", "get_interested_users"]
- for method_name in required_methods:
- if not hasattr(self.custom_presence_router, method_name):
- raise Exception(
- "PresenceRouter module '%s' must implement all required methods: %s"
- % (
- hs.config.presence_router_module_class.__name__,
- ", ".join(required_methods),
- )
- )
+ # Append the methods provided to the lists of callbacks
+ if get_users_for_states is not None:
+ self._get_users_for_states_callbacks.append(get_users_for_states)
+
+ if get_interested_users is not None:
+ self._get_interested_users_callbacks.append(get_interested_users)
async def get_users_for_states(
self,
@@ -66,14 +136,40 @@ class PresenceRouter:
A dictionary of user_id -> set of UserPresenceState, indicating which
presence updates each user should receive.
"""
- if self.custom_presence_router is not None:
- # Ask the custom module
- return await self.custom_presence_router.get_users_for_states(
- state_updates=state_updates
- )
- # Don't include any extra destinations for presence updates
- return {}
+ # Bail out early if we don't have any callbacks to run.
+ if len(self._get_users_for_states_callbacks) == 0:
+ # Don't include any extra destinations for presence updates
+ return {}
+
+ users_for_states = {}
+ # run all the callbacks for get_users_for_states and combine the results
+ for callback in self._get_users_for_states_callbacks:
+ try:
+ result = await callback(state_updates)
+ except Exception as e:
+ logger.warning("Failed to run module API callback %s: %s", callback, e)
+ continue
+
+ if not isinstance(result, Dict):
+ logger.warning(
+ "Wrong type returned by module API callback %s: %s, expected Dict",
+ callback,
+ result,
+ )
+ continue
+
+ for key, new_entries in result.items():
+ if not isinstance(new_entries, Set):
+ logger.warning(
+ "Wrong type returned by module API callback %s: %s, expected Set",
+ callback,
+ new_entries,
+ )
+ break
+ users_for_states.setdefault(key, set()).update(new_entries)
+
+ return users_for_states
async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]:
"""
@@ -92,12 +188,36 @@ class PresenceRouter:
A set of user IDs to return presence updates for, or ALL_USERS to return all
known updates.
"""
- if self.custom_presence_router is not None:
- # Ask the custom module for interested users
- return await self.custom_presence_router.get_interested_users(
- user_id=user_id
- )
- # A custom presence router is not defined.
- # Don't report any additional interested users
- return set()
+ # Bail out early if we don't have any callbacks to run.
+ if len(self._get_interested_users_callbacks) == 0:
+ # Don't report any additional interested users
+ return set()
+
+ interested_users = set()
+ # run all the callbacks for get_interested_users and combine the results
+ for callback in self._get_interested_users_callbacks:
+ try:
+ result = await callback(user_id)
+ except Exception as e:
+ logger.warning("Failed to run module API callback %s: %s", callback, e)
+ continue
+
+ # If one of the callbacks returns ALL_USERS then we can stop calling all
+ # of the other callbacks, since the set of interested_users is already as
+ # large as it can possibly be
+ if result == PresenceRouter.ALL_USERS:
+ return PresenceRouter.ALL_USERS
+
+ if not isinstance(result, Set):
+ logger.warning(
+ "Wrong type returned by module API callback %s: %s, expected set",
+ callback,
+ result,
+ )
+ continue
+
+ # Add the new interested users to the set
+ interested_users.update(result)
+
+ return interested_users
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index b6da2f60..fb22337e 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -32,6 +32,9 @@ from . import EventBase
# the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar"
SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.")
+CANONICALJSON_MAX_INT = (2 ** 53) - 1
+CANONICALJSON_MIN_INT = -CANONICALJSON_MAX_INT
+
def prune_event(event: EventBase) -> EventBase:
"""Returns a pruned version of the given event, which removes all keys we
@@ -101,6 +104,8 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
if event_type == EventTypes.Member:
add_fields("membership")
+ if room_version.msc3375_redaction_rules:
+ add_fields("join_authorised_via_users_server")
elif event_type == EventTypes.Create:
# MSC2176 rules state that create events cannot be redacted.
if room_version.msc2176_redaction_rules:
@@ -505,7 +510,7 @@ def validate_canonicaljson(value: Any):
* NaN, Infinity, -Infinity
"""
if isinstance(value, int):
- if value <= -(2 ** 53) or 2 ** 53 <= value:
+ if value < CANONICALJSON_MIN_INT or CANONICALJSON_MAX_INT < value:
raise SynapseError(400, "JSON integer out of range", Codes.BAD_JSON)
elif isinstance(value, float):
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index fa6987d7..33954b4f 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -11,16 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import collections.abc
from typing import Union
+import jsonschema
+
from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import EventFormatVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
-from synapse.events.utils import validate_canonicaljson
+from synapse.events.utils import (
+ CANONICALJSON_MAX_INT,
+ CANONICALJSON_MIN_INT,
+ validate_canonicaljson,
+)
from synapse.federation.federation_server import server_matches_acl_event
from synapse.types import EventID, RoomID, UserID
@@ -87,6 +93,29 @@ class EventValidator:
400, "Can't create an ACL event that denies the local server"
)
+ if event.type == EventTypes.PowerLevels:
+ try:
+ jsonschema.validate(
+ instance=event.content,
+ schema=POWER_LEVELS_SCHEMA,
+ cls=plValidator,
+ )
+ except jsonschema.ValidationError as e:
+ if e.path:
+ # example: "users_default": '0' is not of type 'integer'
+ message = '"' + e.path[-1] + '": ' + e.message # noqa: B306
+ # jsonschema.ValidationError.message is a valid attribute
+ else:
+ # example: '0' is not of type 'integer'
+ message = e.message # noqa: B306
+ # jsonschema.ValidationError.message is a valid attribute
+
+ raise SynapseError(
+ code=400,
+ msg=message,
+ errcode=Codes.BAD_JSON,
+ )
+
def _validate_retention(self, event: EventBase):
"""Checks that an event that defines the retention policy for a room respects the
format enforced by the spec.
@@ -185,3 +214,47 @@ class EventValidator:
def _ensure_state_event(self, event):
if not event.is_state():
raise SynapseError(400, "'%s' must be state events" % (event.type,))
+
+
+POWER_LEVELS_SCHEMA = {
+ "type": "object",
+ "properties": {
+ "ban": {"$ref": "#/definitions/int"},
+ "events": {"$ref": "#/definitions/objectOfInts"},
+ "events_default": {"$ref": "#/definitions/int"},
+ "invite": {"$ref": "#/definitions/int"},
+ "kick": {"$ref": "#/definitions/int"},
+ "notifications": {"$ref": "#/definitions/objectOfInts"},
+ "redact": {"$ref": "#/definitions/int"},
+ "state_default": {"$ref": "#/definitions/int"},
+ "users": {"$ref": "#/definitions/objectOfInts"},
+ "users_default": {"$ref": "#/definitions/int"},
+ },
+ "definitions": {
+ "int": {
+ "type": "integer",
+ "minimum": CANONICALJSON_MIN_INT,
+ "maximum": CANONICALJSON_MAX_INT,
+ },
+ "objectOfInts": {
+ "type": "object",
+ "additionalProperties": {"$ref": "#/definitions/int"},
+ },
+ },
+}
+
+
+def _create_power_level_validator():
+ validator = jsonschema.validators.validator_for(POWER_LEVELS_SCHEMA)
+
+ # by default jsonschema does not consider a frozendict to be an object so
+ # we need to use a custom type checker
+ # https://python-jsonschema.readthedocs.io/en/stable/validate/?highlight=object#validating-with-additional-types
+ type_checker = validator.TYPE_CHECKER.redefine(
+ "object", lambda checker, thing: isinstance(thing, collections.abc.Mapping)
+ )
+
+ return jsonschema.validators.extend(validator, type_checker=type_checker)
+
+
+plValidator = _create_power_level_validator()
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 29979414..1416abd0 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -43,6 +43,7 @@ from synapse.api.errors import (
Codes,
FederationDeniedError,
HttpResponseException,
+ RequestSendFailed,
SynapseError,
UnsupportedRoomVersionError,
)
@@ -110,6 +111,23 @@ class FederationClient(FederationBase):
reset_expiry_on_get=False,
)
+ # A cache for fetching the room hierarchy over federation.
+ #
+ # Some stale data over federation is OK, but must be refreshed
+ # periodically since the local server is in the room.
+ #
+ # It is a map of (room ID, suggested-only) -> the response of
+ # get_room_hierarchy.
+ self._get_room_hierarchy_cache: ExpiringCache[
+ Tuple[str, bool], Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]
+ ] = ExpiringCache(
+ cache_name="get_room_hierarchy_cache",
+ clock=self._clock,
+ max_len=1000,
+ expiry_ms=5 * 60 * 1000,
+ reset_expiry_on_get=False,
+ )
+
def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache"""
now = self._clock.time_msec()
@@ -558,7 +576,11 @@ class FederationClient(FederationBase):
try:
return await callback(destination)
- except InvalidResponseError as e:
+ except (
+ RequestSendFailed,
+ InvalidResponseError,
+ NotRetryingDestination,
+ ) as e:
logger.warning("Failed to %s via %s: %s", description, destination, e)
except UnsupportedRoomVersionError:
raise
@@ -1319,6 +1341,10 @@ class FederationClient(FederationBase):
remote servers
"""
+ cached_result = self._get_room_hierarchy_cache.get((room_id, suggested_only))
+ if cached_result:
+ return cached_result
+
async def send_request(
destination: str,
) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]:
@@ -1365,58 +1391,63 @@ class FederationClient(FederationBase):
return room, children, inaccessible_children
try:
- return await self._try_destination_list(
+ result = await self._try_destination_list(
"fetch room hierarchy",
destinations,
send_request,
failover_on_unknown_endpoint=True,
)
except SynapseError as e:
+ # If an unexpected error occurred, re-raise it.
+ if e.code != 502:
+ raise
+
# Fallback to the old federation API and translate the results if
# no servers implement the new API.
#
# The algorithm below is a bit inefficient as it only attempts to
- # get information for the requested room, but the legacy API may
+ # parse information for the requested room, but the legacy API may
# return additional layers.
- if e.code == 502:
- legacy_result = await self.get_space_summary(
- destinations,
- room_id,
- suggested_only,
- max_rooms_per_space=None,
- exclude_rooms=[],
- )
+ legacy_result = await self.get_space_summary(
+ destinations,
+ room_id,
+ suggested_only,
+ max_rooms_per_space=None,
+ exclude_rooms=[],
+ )
- # Find the requested room in the response (and remove it).
- for _i, room in enumerate(legacy_result.rooms):
- if room.get("room_id") == room_id:
- break
- else:
- # The requested room was not returned, nothing we can do.
- raise
- requested_room = legacy_result.rooms.pop(_i)
-
- # Find any children events of the requested room.
- children_events = []
- children_room_ids = set()
- for event in legacy_result.events:
- if event.room_id == room_id:
- children_events.append(event.data)
- children_room_ids.add(event.state_key)
- # And add them under the requested room.
- requested_room["children_state"] = children_events
-
- # Find the children rooms.
- children = []
- for room in legacy_result.rooms:
- if room.get("room_id") in children_room_ids:
- children.append(room)
-
- # It isn't clear from the response whether some of the rooms are
- # not accessible.
- return requested_room, children, ()
-
- raise
+ # Find the requested room in the response (and remove it).
+ for _i, room in enumerate(legacy_result.rooms):
+ if room.get("room_id") == room_id:
+ break
+ else:
+ # The requested room was not returned, nothing we can do.
+ raise
+ requested_room = legacy_result.rooms.pop(_i)
+
+ # Find any children events of the requested room.
+ children_events = []
+ children_room_ids = set()
+ for event in legacy_result.events:
+ if event.room_id == room_id:
+ children_events.append(event.data)
+ children_room_ids.add(event.state_key)
+ # And add them under the requested room.
+ requested_room["children_state"] = children_events
+
+ # Find the children rooms.
+ children = []
+ for room in legacy_result.rooms:
+ if room.get("room_id") in children_room_ids:
+ children.append(room)
+
+ # It isn't clear from the response whether some of the rooms are
+ # not accessible.
+ result = (requested_room, children, ())
+
+ # Cache the result to avoid fetching data over federation every time.
+ self._get_room_hierarchy_cache[(room_id, suggested_only)] = result
+ return result
@attr.s(frozen=True, slots=True, auto_attribs=True)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index afd8f858..214ee948 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -110,6 +110,7 @@ class FederationServer(FederationBase):
super().__init__(hs)
self.handler = hs.get_federation_handler()
+ self._federation_event_handler = hs.get_federation_event_handler()
self.state = hs.get_state_handler()
self._event_auth_handler = hs.get_event_auth_handler()
@@ -787,7 +788,9 @@ class FederationServer(FederationBase):
event = await self._check_sigs_and_hash(room_version, event)
- return await self.handler.on_send_membership_event(origin, event)
+ return await self._federation_event_handler.on_send_membership_event(
+ origin, event
+ )
async def on_event_auth(
self, origin: str, room_id: str, event_id: str
@@ -1005,9 +1008,7 @@ class FederationServer(FederationBase):
async with lock:
logger.info("handling received PDU: %s", event)
try:
- await self.handler.on_receive_pdu(
- origin, event, sent_to_us_directly=True
- )
+ await self._federation_event_handler.on_receive_pdu(origin, event)
except FederationError as e:
# XXX: Ideally we'd inform the remote we failed to process
# the event, but we can't return an error in the transaction
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 161b3c93..34725324 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -627,23 +627,28 @@ class AuthHandler(BaseHandler):
async def add_oob_auth(
self, stagetype: str, authdict: Dict[str, Any], clientip: str
- ) -> bool:
+ ) -> None:
"""
Adds the result of out-of-band authentication into an existing auth
session. Currently used for adding the result of fallback auth.
+
+ Raises:
+ LoginError if the stagetype is unknown or the session is missing.
+ LoginError is raised by check_auth if authentication fails.
"""
if stagetype not in self.checkers:
- raise LoginError(400, "", Codes.MISSING_PARAM)
+ raise LoginError(
+ 400, f"Unknown UIA stage type: {stagetype}", Codes.INVALID_PARAM
+ )
if "session" not in authdict:
- raise LoginError(400, "", Codes.MISSING_PARAM)
+ raise LoginError(400, "Missing session ID", Codes.MISSING_PARAM)
+ # If authentication fails a LoginError is raised. Otherwise, store
+ # the successful result.
result = await self.checkers[stagetype].check_auth(authdict, clientip)
- if result:
- await self.store.mark_ui_auth_stage_complete(
- authdict["session"], stagetype, result
- )
- return True
- return False
+ await self.store.mark_ui_auth_stage_complete(
+ authdict["session"], stagetype, result
+ )
def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]:
"""
@@ -1459,6 +1464,10 @@ class AuthHandler(BaseHandler):
)
await self.store.user_delete_threepid(user_id, medium, address)
+ if medium == "email":
+ await self.store.delete_pusher_by_app_id_pushkey_user_id(
+ app_id="m.email", pushkey=address, user_id=user_id
+ )
return result
async def hash(self, password: str) -> str:
@@ -1727,7 +1736,6 @@ class AuthHandler(BaseHandler):
@attr.s(slots=True)
class MacaroonGenerator:
-
hs = attr.ib()
def generate_guest_access_token(self, user_id: str) -> str:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c0e13bda..daf1d3bf 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -17,23 +17,9 @@
import itertools
import logging
-from collections.abc import Container
from http import HTTPStatus
-from typing import (
- TYPE_CHECKING,
- Collection,
- Dict,
- Iterable,
- List,
- Optional,
- Sequence,
- Set,
- Tuple,
- Union,
-)
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
-import attr
-from prometheus_client import Counter
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
@@ -41,19 +27,12 @@ from unpaddedbase64 import decode_base64
from twisted.internet import defer
from synapse import event_auth
-from synapse.api.constants import (
- EventContentFields,
- EventTypes,
- Membership,
- RejectedReason,
- RoomEncryptionAlgorithms,
-)
+from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.api.errors import (
AuthError,
CodeMessageException,
Codes,
FederationDeniedError,
- FederationError,
HttpResponseException,
NotFoundError,
RequestSendFailed,
@@ -61,10 +40,10 @@ from synapse.api.errors import (
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions
from synapse.crypto.event_signing import compute_event_signature
-from synapse.event_auth import auth_types_for_event
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
+from synapse.federation.federation_client import InvalidResponseError
from synapse.handlers._base import BaseHandler
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
@@ -74,28 +53,14 @@ from synapse.logging.context import (
run_in_background,
)
from synapse.logging.utils import log_function
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.replication.http.federation import (
ReplicationCleanRoomRestServlet,
- ReplicationFederationSendEventsRestServlet,
ReplicationStoreRoomOnOutlierMembershipRestServlet,
)
-from synapse.state import StateResolutionStore
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.types import (
- JsonDict,
- MutableStateMap,
- PersistedEventPosition,
- RoomStreamToken,
- StateMap,
- UserID,
- get_domain_from_id,
-)
-from synapse.util.async_helpers import Linearizer, concurrently_execute
-from synapse.util.iterutils import batch_iter
+from synapse.types import JsonDict, StateMap, get_domain_from_id
+from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
-from synapse.util.stringutils import shortstr
from synapse.visibility import filter_events_for_server
if TYPE_CHECKING:
@@ -103,50 +68,11 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-soft_failed_event_counter = Counter(
- "synapse_federation_soft_failed_events_total",
- "Events received over federation that we marked as soft_failed",
-)
-
-
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class _NewEventInfo:
- """Holds information about a received event, ready for passing to _auth_and_persist_events
-
- Attributes:
- event: the received event
-
- state: the state at that event, according to /state_ids from a remote
- homeserver. Only populated for backfilled events which are going to be a
- new backwards extremity.
-
- claimed_auth_event_map: a map of (type, state_key) => event for the event's
- claimed auth_events.
-
- This can include events which have not yet been persisted, in the case that
- we are backfilling a batch of events.
-
- Note: May be incomplete: if we were unable to find all of the claimed auth
- events. Also, treat the contents with caution: the events might also have
- been rejected, might not yet have been authorized themselves, or they might
- be in the wrong room.
-
- """
-
- event: EventBase
- state: Optional[Sequence[EventBase]]
- claimed_auth_event_map: StateMap[EventBase]
-
class FederationHandler(BaseHandler):
- """Handles events that originated from federation.
- Responsible for:
- a) handling received Pdus before handing them on as Events to the rest
- of the homeserver (including auth and state conflict resolutions)
- b) converting events that were produced by local clients that may need
- to be sent to remote homeservers.
- c) doing the necessary dances to invite remote users and join remote
- rooms.
+ """Handles general incoming federation requests
+
+ Incoming events are *not* handled here, for which see FederationEventHandler.
"""
def __init__(self, hs: "HomeServer"):
@@ -159,979 +85,35 @@ class FederationHandler(BaseHandler):
self.state_store = self.storage.state
self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler()
- self._state_resolution_handler = hs.get_state_resolution_handler()
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
- self.action_generator = hs.get_action_generator()
self.is_mine_id = hs.is_mine_id
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self._event_auth_handler = hs.get_event_auth_handler()
- self._message_handler = hs.get_message_handler()
self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config
self.http_client = hs.get_proxied_blacklisted_http_client()
- self._instance_name = hs.get_instance_name()
self._replication = hs.get_replication_data_handler()
+ self._federation_event_handler = hs.get_federation_event_handler()
- self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
hs
)
if hs.config.worker_app:
- self._user_device_resync = (
- ReplicationUserDevicesResyncRestServlet.make_client(hs)
- )
self._maybe_store_room_on_outlier_membership = (
ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(hs)
)
else:
- self._device_list_updater = hs.get_device_handler().device_list_updater
self._maybe_store_room_on_outlier_membership = (
self.store.maybe_store_room_on_outlier_membership
)
- # When joining a room we need to queue any events for that room up.
- # For each room, a list of (pdu, origin) tuples.
- self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {}
- self._room_pdu_linearizer = Linearizer("fed_room_pdu")
-
self._room_backfill = Linearizer("room_backfill")
self.third_party_event_rules = hs.get_third_party_event_rules()
- self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
-
- async def on_receive_pdu(
- self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False
- ) -> None:
- """Process a PDU received via a federation /send/ transaction, or
- via backfill of missing prev_events
-
- Args:
- origin: server which initiated the /send/ transaction. Will
- be used to fetch missing events or state.
- pdu: received PDU
- sent_to_us_directly: True if this event was pushed to us; False if
- we pulled it as the result of a missing prev_event.
- """
-
- room_id = pdu.room_id
- event_id = pdu.event_id
-
- # We reprocess pdus when we have seen them only as outliers
- existing = await self.store.get_event(
- event_id, allow_none=True, allow_rejected=True
- )
-
- # FIXME: Currently we fetch an event again when we already have it
- # if it has been marked as an outlier.
- if existing:
- if not existing.internal_metadata.is_outlier():
- logger.info(
- "Ignoring received event %s which we have already seen", event_id
- )
- return
- if pdu.internal_metadata.is_outlier():
- logger.info(
- "Ignoring received outlier %s which we already have as an outlier",
- event_id,
- )
- return
- logger.info("De-outliering event %s", event_id)
-
- # do some initial sanity-checking of the event. In particular, make
- # sure it doesn't have hundreds of prev_events or auth_events, which
- # could cause a huge state resolution or cascade of event fetches.
- try:
- self._sanity_check_event(pdu)
- except SynapseError as err:
- logger.warning("Received event failed sanity checks")
- raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id)
-
- # If we are currently in the process of joining this room, then we
- # queue up events for later processing.
- if room_id in self.room_queues:
- logger.info(
- "Queuing PDU from %s for now: join in progress",
- origin,
- )
- self.room_queues[room_id].append((pdu, origin))
- return
-
- # If we're not in the room just ditch the event entirely. This is
- # probably an old server that has come back and thinks we're still in
- # the room (or we've been rejoined to the room by a state reset).
- #
- # Note that if we were never in the room then we would have already
- # dropped the event, since we wouldn't know the room version.
- is_in_room = await self._event_auth_handler.check_host_in_room(
- room_id, self.server_name
- )
- if not is_in_room:
- logger.info(
- "Ignoring PDU from %s as we're not in the room",
- origin,
- )
- return None
-
- state = None
-
- # Check that the event passes auth based on the state at the event. This is
- # done for events that are to be added to the timeline (non-outliers).
- #
- # Get missing pdus if necessary:
- # - Fetching any missing prev events to fill in gaps in the graph
- # - Fetching state if we have a hole in the graph
- if not pdu.internal_metadata.is_outlier():
- # We only backfill backwards to the min depth.
- min_depth = await self.get_min_depth_for_context(pdu.room_id)
-
- logger.debug("min_depth: %d", min_depth)
-
- prevs = set(pdu.prev_event_ids())
- seen = await self.store.have_events_in_timeline(prevs)
-
- if min_depth is not None and pdu.depth < min_depth:
- # This is so that we don't notify the user about this
- # message, to work around the fact that some events will
- # reference really really old events we really don't want to
- # send to the clients.
- pdu.internal_metadata.outlier = True
- elif min_depth is not None and pdu.depth > min_depth:
- missing_prevs = prevs - seen
- if sent_to_us_directly and missing_prevs:
- # If we're missing stuff, ensure we only fetch stuff one
- # at a time.
- logger.info(
- "Acquiring room lock to fetch %d missing prev_events: %s",
- len(missing_prevs),
- shortstr(missing_prevs),
- )
- with (await self._room_pdu_linearizer.queue(pdu.room_id)):
- logger.info(
- "Acquired room lock to fetch %d missing prev_events",
- len(missing_prevs),
- )
-
- try:
- await self._get_missing_events_for_pdu(
- origin, pdu, prevs, min_depth
- )
- except Exception as e:
- raise Exception(
- "Error fetching missing prev_events for %s: %s"
- % (event_id, e)
- ) from e
-
- # Update the set of things we've seen after trying to
- # fetch the missing stuff
- seen = await self.store.have_events_in_timeline(prevs)
-
- if not prevs - seen:
- logger.info(
- "Found all missing prev_events",
- )
-
- missing_prevs = prevs - seen
- if missing_prevs:
- # We've still not been able to get all of the prev_events for this event.
- #
- # In this case, we need to fall back to asking another server in the
- # federation for the state at this event. That's ok provided we then
- # resolve the state against other bits of the DAG before using it (which
- # will ensure that you can't just take over a room by sending an event,
- # withholding its prev_events, and declaring yourself to be an admin in
- # the subsequent state request).
- #
- # Now, if we're pulling this event as a missing prev_event, then clearly
- # this event is not going to become the only forward-extremity and we are
- # guaranteed to resolve its state against our existing forward
- # extremities, so that should be fine.
- #
- # On the other hand, if this event was pushed to us, it is possible for
- # it to become the only forward-extremity in the room, and we would then
- # trust its state to be the state for the whole room. This is very bad.
- # Further, if the event was pushed to us, there is no excuse for us not to
- # have all the prev_events. We therefore reject any such events.
- #
- # XXX this really feels like it could/should be merged with the above,
- # but there is an interaction with min_depth that I'm not really
- # following.
-
- if sent_to_us_directly:
- logger.warning(
- "Rejecting: failed to fetch %d prev events: %s",
- len(missing_prevs),
- shortstr(missing_prevs),
- )
- raise FederationError(
- "ERROR",
- 403,
- (
- "Your server isn't divulging details about prev_events "
- "referenced in this event."
- ),
- affected=pdu.event_id,
- )
-
- logger.info(
- "Event %s is missing prev_events %s: calculating state for a "
- "backwards extremity",
- event_id,
- shortstr(missing_prevs),
- )
-
- # Calculate the state after each of the previous events, and
- # resolve them to find the correct state at the current event.
- event_map = {event_id: pdu}
- try:
- # Get the state of the events we know about
- ours = await self.state_store.get_state_groups_ids(room_id, seen)
-
- # state_maps is a list of mappings from (type, state_key) to event_id
- state_maps: List[StateMap[str]] = list(ours.values())
-
- # we don't need this any more, let's delete it.
- del ours
-
- # Ask the remote server for the states we don't
- # know about
- for p in missing_prevs:
- logger.info("Requesting state after missing prev_event %s", p)
-
- with nested_logging_context(p):
- # note that if any of the missing prevs share missing state or
- # auth events, the requests to fetch those events are deduped
- # by the get_pdu_cache in federation_client.
- remote_state = (
- await self._get_state_after_missing_prev_event(
- origin, room_id, p
- )
- )
-
- remote_state_map = {
- (x.type, x.state_key): x.event_id for x in remote_state
- }
- state_maps.append(remote_state_map)
-
- for x in remote_state:
- event_map[x.event_id] = x
-
- room_version = await self.store.get_room_version_id(room_id)
- state_map = (
- await self._state_resolution_handler.resolve_events_with_store(
- room_id,
- room_version,
- state_maps,
- event_map,
- state_res_store=StateResolutionStore(self.store),
- )
- )
-
- # We need to give _process_received_pdu the actual state events
- # rather than event ids, so generate that now.
-
- # First though we need to fetch all the events that are in
- # state_map, so we can build up the state below.
- evs = await self.store.get_events(
- list(state_map.values()),
- get_prev_content=False,
- redact_behaviour=EventRedactBehaviour.AS_IS,
- )
- event_map.update(evs)
-
- state = [event_map[e] for e in state_map.values()]
- except Exception:
- logger.warning(
- "Error attempting to resolve state at missing " "prev_events",
- exc_info=True,
- )
- raise FederationError(
- "ERROR",
- 403,
- "We can't get valid state history.",
- affected=event_id,
- )
-
- # A second round of checks for all events. Check that the event passes auth
- # based on `auth_events`, this allows us to assert that the event would
- # have been allowed at some point. If an event passes this check its OK
- # for it to be used as part of a returned `/state` request, as either
- # a) we received the event as part of the original join and so trust it, or
- # b) we'll do a state resolution with existing state before it becomes
- # part of the "current state", which adds more protection.
- await self._process_received_pdu(origin, pdu, state=state)
-
- async def _get_missing_events_for_pdu(
- self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
- ) -> None:
- """
- Args:
- origin: Origin of the pdu. Will be called to get the missing events
- pdu: received pdu
- prevs: List of event ids which we are missing
- min_depth: Minimum depth of events to return.
- """
-
- room_id = pdu.room_id
- event_id = pdu.event_id
-
- seen = await self.store.have_events_in_timeline(prevs)
-
- if not prevs - seen:
- return
-
- latest_list = await self.store.get_latest_event_ids_in_room(room_id)
-
- # We add the prev events that we have seen to the latest
- # list to ensure the remote server doesn't give them to us
- latest = set(latest_list)
- latest |= seen
-
- logger.info(
- "Requesting missing events between %s and %s",
- shortstr(latest),
- event_id,
- )
-
- # XXX: we set timeout to 10s to help workaround
- # https://github.com/matrix-org/synapse/issues/1733.
- # The reason is to avoid holding the linearizer lock
- # whilst processing inbound /send transactions, causing
- # FDs to stack up and block other inbound transactions
- # which empirically can currently take up to 30 minutes.
- #
- # N.B. this explicitly disables retry attempts.
- #
- # N.B. this also increases our chances of falling back to
- # fetching fresh state for the room if the missing event
- # can't be found, which slightly reduces our security.
- # it may also increase our DAG extremity count for the room,
- # causing additional state resolution? See #1760.
- # However, fetching state doesn't hold the linearizer lock
- # apparently.
- #
- # see https://github.com/matrix-org/synapse/pull/1744
- #
- # ----
- #
- # Update richvdh 2018/09/18: There are a number of problems with timing this
- # request out aggressively on the client side:
- #
- # - it plays badly with the server-side rate-limiter, which starts tarpitting you
- # if you send too many requests at once, so you end up with the server carefully
- # working through the backlog of your requests, which you have already timed
- # out.
- #
- # - for this request in particular, we now (as of
- # https://github.com/matrix-org/synapse/pull/3456) reject any PDUs where the
- # server can't produce a plausible-looking set of prev_events - so we becone
- # much more likely to reject the event.
- #
- # - contrary to what it says above, we do *not* fall back to fetching fresh state
- # for the room if get_missing_events times out. Rather, we give up processing
- # the PDU whose prevs we are missing, which then makes it much more likely that
- # we'll end up back here for the *next* PDU in the list, which exacerbates the
- # problem.
- #
- # - the aggressive 10s timeout was introduced to deal with incoming federation
- # requests taking 8 hours to process. It's not entirely clear why that was going
- # on; certainly there were other issues causing traffic storms which are now
- # resolved, and I think in any case we may be more sensible about our locking
- # now. We're *certainly* more sensible about our logging.
- #
- # All that said: Let's try increasing the timeout to 60s and see what happens.
-
- try:
- missing_events = await self.federation_client.get_missing_events(
- origin,
- room_id,
- earliest_events_ids=list(latest),
- latest_events=[pdu],
- limit=10,
- min_depth=min_depth,
- timeout=60000,
- )
- except (RequestSendFailed, HttpResponseException, NotRetryingDestination) as e:
- # We failed to get the missing events, but since we need to handle
- # the case of `get_missing_events` not returning the necessary
- # events anyway, it is safe to simply log the error and continue.
- logger.warning("Failed to get prev_events: %s", e)
- return
-
- logger.info("Got %d prev_events", len(missing_events))
-
- # We want to sort these by depth so we process them and
- # tell clients about them in order.
- missing_events.sort(key=lambda x: x.depth)
-
- for ev in missing_events:
- logger.info("Handling received prev_event %s", ev)
- with nested_logging_context(ev.event_id):
- try:
- await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
- except FederationError as e:
- if e.code == 403:
- logger.warning(
- "Received prev_event %s failed history check.",
- ev.event_id,
- )
- else:
- raise
-
- async def _get_state_for_room(
- self,
- destination: str,
- room_id: str,
- event_id: str,
- ) -> List[EventBase]:
- """Requests all of the room state at a given event from a remote
- homeserver.
-
- Will also fetch any missing events reported in the `auth_chain_ids`
- section of `/state_ids`.
-
- Args:
- destination: The remote homeserver to query for the state.
- room_id: The id of the room we're interested in.
- event_id: The id of the event we want the state at.
-
- Returns:
- A list of events in the state, not including the event itself.
- """
- (
- state_event_ids,
- auth_event_ids,
- ) = await self.federation_client.get_room_state_ids(
- destination, room_id, event_id=event_id
- )
-
- # Fetch the state events from the DB, and check we have the auth events.
- event_map = await self.store.get_events(state_event_ids, allow_rejected=True)
- auth_events_in_store = await self.store.have_seen_events(
- room_id, auth_event_ids
- )
-
- # Check for missing events. We handle state and auth event seperately,
- # as we want to pull the state from the DB, but we don't for the auth
- # events. (Note: we likely won't use the majority of the auth chain, and
- # it can be *huge* for large rooms, so it's worth ensuring that we don't
- # unnecessarily pull it from the DB).
- missing_state_events = set(state_event_ids) - set(event_map)
- missing_auth_events = set(auth_event_ids) - set(auth_events_in_store)
- if missing_state_events or missing_auth_events:
- await self._get_events_and_persist(
- destination=destination,
- room_id=room_id,
- events=missing_state_events | missing_auth_events,
- )
-
- if missing_state_events:
- new_events = await self.store.get_events(
- missing_state_events, allow_rejected=True
- )
- event_map.update(new_events)
-
- missing_state_events.difference_update(new_events)
-
- if missing_state_events:
- logger.warning(
- "Failed to fetch missing state events for %s %s",
- event_id,
- missing_state_events,
- )
-
- if missing_auth_events:
- auth_events_in_store = await self.store.have_seen_events(
- room_id, missing_auth_events
- )
- missing_auth_events.difference_update(auth_events_in_store)
-
- if missing_auth_events:
- logger.warning(
- "Failed to fetch missing auth events for %s %s",
- event_id,
- missing_auth_events,
- )
-
- remote_state = list(event_map.values())
-
- # check for events which were in the wrong room.
- #
- # this can happen if a remote server claims that the state or
- # auth_events at an event in room A are actually events in room B
-
- bad_events = [
- (event.event_id, event.room_id)
- for event in remote_state
- if event.room_id != room_id
- ]
-
- for bad_event_id, bad_room_id in bad_events:
- # This is a bogus situation, but since we may only discover it a long time
- # after it happened, we try our best to carry on, by just omitting the
- # bad events from the returned auth/state set.
- logger.warning(
- "Remote server %s claims event %s in room %s is an auth/state "
- "event in room %s",
- destination,
- bad_event_id,
- bad_room_id,
- room_id,
- )
-
- if bad_events:
- remote_state = [e for e in remote_state if e.room_id == room_id]
-
- return remote_state
-
- async def _get_state_after_missing_prev_event(
- self,
- destination: str,
- room_id: str,
- event_id: str,
- ) -> List[EventBase]:
- """Requests all of the room state at a given event from a remote homeserver.
-
- Args:
- destination: The remote homeserver to query for the state.
- room_id: The id of the room we're interested in.
- event_id: The id of the event we want the state at.
-
- Returns:
- A list of events in the state, including the event itself
- """
- # TODO: This function is basically the same as _get_state_for_room. Can
- # we make backfill() use it, rather than having two code paths? I think the
- # only difference is that backfill() persists the prev events separately.
-
- (
- state_event_ids,
- auth_event_ids,
- ) = await self.federation_client.get_room_state_ids(
- destination, room_id, event_id=event_id
- )
-
- logger.debug(
- "state_ids returned %i state events, %i auth events",
- len(state_event_ids),
- len(auth_event_ids),
- )
-
- # start by just trying to fetch the events from the store
- desired_events = set(state_event_ids)
- desired_events.add(event_id)
- logger.debug("Fetching %i events from cache/store", len(desired_events))
- fetched_events = await self.store.get_events(
- desired_events, allow_rejected=True
- )
-
- missing_desired_events = desired_events - fetched_events.keys()
- logger.debug(
- "We are missing %i events (got %i)",
- len(missing_desired_events),
- len(fetched_events),
- )
-
- # We probably won't need most of the auth events, so let's just check which
- # we have for now, rather than thrashing the event cache with them all
- # unnecessarily.
-
- # TODO: we probably won't actually need all of the auth events, since we
- # already have a bunch of the state events. It would be nice if the
- # federation api gave us a way of finding out which we actually need.
-
- missing_auth_events = set(auth_event_ids) - fetched_events.keys()
- missing_auth_events.difference_update(
- await self.store.have_seen_events(room_id, missing_auth_events)
- )
- logger.debug("We are also missing %i auth events", len(missing_auth_events))
-
- missing_events = missing_desired_events | missing_auth_events
- logger.debug("Fetching %i events from remote", len(missing_events))
- await self._get_events_and_persist(
- destination=destination, room_id=room_id, events=missing_events
- )
-
- # we need to make sure we re-load from the database to get the rejected
- # state correct.
- fetched_events.update(
- await self.store.get_events(missing_desired_events, allow_rejected=True)
- )
-
- # check for events which were in the wrong room.
- #
- # this can happen if a remote server claims that the state or
- # auth_events at an event in room A are actually events in room B
-
- bad_events = [
- (event_id, event.room_id)
- for event_id, event in fetched_events.items()
- if event.room_id != room_id
- ]
-
- for bad_event_id, bad_room_id in bad_events:
- # This is a bogus situation, but since we may only discover it a long time
- # after it happened, we try our best to carry on, by just omitting the
- # bad events from the returned state set.
- logger.warning(
- "Remote server %s claims event %s in room %s is an auth/state "
- "event in room %s",
- destination,
- bad_event_id,
- bad_room_id,
- room_id,
- )
-
- del fetched_events[bad_event_id]
-
- # if we couldn't get the prev event in question, that's a problem.
- remote_event = fetched_events.get(event_id)
- if not remote_event:
- raise Exception("Unable to get missing prev_event %s" % (event_id,))
-
- # missing state at that event is a warning, not a blocker
- # XXX: this doesn't sound right? it means that we'll end up with incomplete
- # state.
- failed_to_fetch = desired_events - fetched_events.keys()
- if failed_to_fetch:
- logger.warning(
- "Failed to fetch missing state events for %s %s",
- event_id,
- failed_to_fetch,
- )
-
- remote_state = [
- fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
- ]
-
- if remote_event.is_state() and remote_event.rejected_reason is None:
- remote_state.append(remote_event)
-
- return remote_state
-
- async def _process_received_pdu(
- self,
- origin: str,
- event: EventBase,
- state: Optional[Iterable[EventBase]],
- ) -> None:
- """Called when we have a new pdu. We need to do auth checks and put it
- through the StateHandler.
-
- Args:
- origin: server sending the event
-
- event: event to be persisted
-
- state: Normally None, but if we are handling a gap in the graph
- (ie, we are missing one or more prev_events), the resolved state at the
- event
- """
- logger.debug("Processing event: %s", event)
-
- try:
- context = await self.state_handler.compute_event_context(
- event, old_state=state
- )
- await self._auth_and_persist_event(origin, event, context, state=state)
- except AuthError as e:
- raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
-
- # For encrypted messages we check that we know about the sending device,
- # if we don't then we mark the device cache for that user as stale.
- if event.type == EventTypes.Encrypted:
- device_id = event.content.get("device_id")
- sender_key = event.content.get("sender_key")
-
- cached_devices = await self.store.get_cached_devices_for_user(event.sender)
-
- resync = False # Whether we should resync device lists.
-
- device = None
- if device_id is not None:
- device = cached_devices.get(device_id)
- if device is None:
- logger.info(
- "Received event from remote device not in our cache: %s %s",
- event.sender,
- device_id,
- )
- resync = True
-
- # We also check if the `sender_key` matches what we expect.
- if sender_key is not None:
- # Figure out what sender key we're expecting. If we know the
- # device and recognize the algorithm then we can work out the
- # exact key to expect. Otherwise check it matches any key we
- # have for that device.
-
- current_keys: Container[str] = []
-
- if device:
- keys = device.get("keys", {}).get("keys", {})
-
- if (
- event.content.get("algorithm")
- == RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2
- ):
- # For this algorithm we expect a curve25519 key.
- key_name = "curve25519:%s" % (device_id,)
- current_keys = [keys.get(key_name)]
- else:
- # We don't know understand the algorithm, so we just
- # check it matches a key for the device.
- current_keys = keys.values()
- elif device_id:
- # We don't have any keys for the device ID.
- pass
- else:
- # The event didn't include a device ID, so we just look for
- # keys across all devices.
- current_keys = [
- key
- for device in cached_devices.values()
- for key in device.get("keys", {}).get("keys", {}).values()
- ]
-
- # We now check that the sender key matches (one of) the expected
- # keys.
- if sender_key not in current_keys:
- logger.info(
- "Received event from remote device with unexpected sender key: %s %s: %s",
- event.sender,
- device_id or "<no device_id>",
- sender_key,
- )
- resync = True
-
- if resync:
- run_as_background_process(
- "resync_device_due_to_pdu", self._resync_device, event.sender
- )
-
- await self._handle_marker_event(origin, event)
-
- async def _handle_marker_event(self, origin: str, marker_event: EventBase):
- """Handles backfilling the insertion event when we receive a marker
- event that points to one.
-
- Args:
- origin: Origin of the event. Will be called to get the insertion event
- marker_event: The event to process
- """
-
- if marker_event.type != EventTypes.MSC2716_MARKER:
- # Not a marker event
- return
-
- if marker_event.rejected_reason is not None:
- # Rejected event
- return
-
- # Skip processing a marker event if the room version doesn't
- # support it.
- room_version = await self.store.get_room_version(marker_event.room_id)
- if not room_version.msc2716_historical:
- return
-
- logger.debug("_handle_marker_event: received %s", marker_event)
-
- insertion_event_id = marker_event.content.get(
- EventContentFields.MSC2716_MARKER_INSERTION
- )
-
- if insertion_event_id is None:
- # Nothing to retrieve then (invalid marker)
- return
-
- logger.debug(
- "_handle_marker_event: backfilling insertion event %s", insertion_event_id
- )
-
- await self._get_events_and_persist(
- origin,
- marker_event.room_id,
- [insertion_event_id],
- )
-
- insertion_event = await self.store.get_event(
- insertion_event_id, allow_none=True
- )
- if insertion_event is None:
- logger.warning(
- "_handle_marker_event: server %s didn't return insertion event %s for marker %s",
- origin,
- insertion_event_id,
- marker_event.event_id,
- )
- return
-
- logger.debug(
- "_handle_marker_event: succesfully backfilled insertion event %s from marker event %s",
- insertion_event,
- marker_event,
- )
-
- await self.store.insert_insertion_extremity(
- insertion_event_id, marker_event.room_id
- )
-
- logger.debug(
- "_handle_marker_event: insertion extremity added for %s from marker event %s",
- insertion_event,
- marker_event,
- )
-
- async def _resync_device(self, sender: str) -> None:
- """We have detected that the device list for the given user may be out
- of sync, so we try and resync them.
- """
-
- try:
- await self.store.mark_remote_user_device_cache_as_stale(sender)
-
- # Immediately attempt a resync in the background
- if self.config.worker_app:
- await self._user_device_resync(user_id=sender)
- else:
- await self._device_list_updater.user_device_resync(sender)
- except Exception:
- logger.exception("Failed to resync device for %s", sender)
-
- @log_function
- async def backfill(
- self, dest: str, room_id: str, limit: int, extremities: List[str]
- ) -> List[EventBase]:
- """Trigger a backfill request to `dest` for the given `room_id`
-
- This will attempt to get more events from the remote. If the other side
- has no new events to offer, this will return an empty list.
-
- As the events are received, we check their signatures, and also do some
- sanity-checking on them. If any of the backfilled events are invalid,
- this method throws a SynapseError.
-
- TODO: make this more useful to distinguish failures of the remote
- server from invalid events (there is probably no point in trying to
- re-fetch invalid events from every other HS in the room.)
- """
- if dest == self.server_name:
- raise SynapseError(400, "Can't backfill from self.")
-
- events = await self.federation_client.backfill(
- dest, room_id, limit=limit, extremities=extremities
- )
-
- if not events:
- return []
-
- # ideally we'd sanity check the events here for excess prev_events etc,
- # but it's hard to reject events at this point without completely
- # breaking backfill in the same way that it is currently broken by
- # events whose signature we cannot verify (#3121).
- #
- # So for now we accept the events anyway. #3124 tracks this.
- #
- # for ev in events:
- # self._sanity_check_event(ev)
-
- # Don't bother processing events we already have.
- seen_events = await self.store.have_events_in_timeline(
- {e.event_id for e in events}
- )
-
- events = [e for e in events if e.event_id not in seen_events]
-
- if not events:
- return []
-
- event_map = {e.event_id: e for e in events}
-
- event_ids = {e.event_id for e in events}
-
- # build a list of events whose prev_events weren't in the batch.
- # (XXX: this will include events whose prev_events we already have; that doesn't
- # sound right?)
- edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids]
-
- logger.info("backfill: Got %d events with %d edges", len(events), len(edges))
-
- # For each edge get the current state.
-
- state_events = {}
- events_to_state = {}
- for e_id in edges:
- state = await self._get_state_for_room(
- destination=dest,
- room_id=room_id,
- event_id=e_id,
- )
- state_events.update({s.event_id: s for s in state})
- events_to_state[e_id] = state
-
- required_auth = {
- a_id
- for event in events + list(state_events.values())
- for a_id in event.auth_event_ids()
- }
- auth_events = await self.store.get_events(required_auth, allow_rejected=True)
- auth_events.update(
- {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
- )
-
- ev_infos = []
-
- # Step 1: persist the events in the chunk we fetched state for (i.e.
- # the backwards extremities), with custom auth events and state
- for e_id in events_to_state:
- # For paranoia we ensure that these events are marked as
- # non-outliers
- ev = event_map[e_id]
- assert not ev.internal_metadata.is_outlier()
-
- ev_infos.append(
- _NewEventInfo(
- event=ev,
- state=events_to_state[e_id],
- claimed_auth_event_map={
- (
- auth_events[a_id].type,
- auth_events[a_id].state_key,
- ): auth_events[a_id]
- for a_id in ev.auth_event_ids()
- if a_id in auth_events
- },
- )
- )
-
- if ev_infos:
- await self._auth_and_persist_events(
- dest, room_id, ev_infos, backfilled=True
- )
-
- # Step 2: Persist the rest of the events in the chunk one by one
- events.sort(key=lambda e: e.depth)
-
- for event in events:
- if event in events_to_state:
- continue
-
- # For paranoia we ensure that these events are marked as
- # non-outliers
- assert not event.internal_metadata.is_outlier()
-
- context = await self.state_handler.compute_event_context(event)
-
- # We store these one at a time since each event depends on the
- # previous to work out the state.
- # TODO: We can probably do something more clever here.
- await self._auth_and_persist_event(dest, event, context, backfilled=True)
-
- return events
-
async def maybe_backfill(
self, room_id: str, current_depth: int, limit: int
) -> bool:
@@ -1326,14 +308,14 @@ class FederationHandler(BaseHandler):
# TODO: Should we try multiple of these at a time?
for dom in domains:
try:
- await self.backfill(
+ await self._federation_event_handler.backfill(
dom, room_id, limit=100, extremities=extremities
)
# If this succeeded then we probably already have the
# appropriate stuff.
# TODO: We can probably do something more intelligent here.
return True
- except SynapseError as e:
+ except (SynapseError, InvalidResponseError) as e:
logger.info("Failed to backfill from %s because %s", dom, e)
continue
except HttpResponseException as e:
@@ -1417,115 +399,6 @@ class FederationHandler(BaseHandler):
return False
- async def _get_events_and_persist(
- self, destination: str, room_id: str, events: Iterable[str]
- ) -> None:
- """Fetch the given events from a server, and persist them as outliers.
-
- This function *does not* recursively get missing auth events of the
- newly fetched events. Callers must include in the `events` argument
- any missing events from the auth chain.
-
- Logs a warning if we can't find the given event.
- """
-
- room_version = await self.store.get_room_version(room_id)
-
- event_map: Dict[str, EventBase] = {}
-
- async def get_event(event_id: str):
- with nested_logging_context(event_id):
- try:
- event = await self.federation_client.get_pdu(
- [destination],
- event_id,
- room_version,
- outlier=True,
- )
- if event is None:
- logger.warning(
- "Server %s didn't return event %s",
- destination,
- event_id,
- )
- return
-
- event_map[event.event_id] = event
-
- except Exception as e:
- logger.warning(
- "Error fetching missing state/auth event %s: %s %s",
- event_id,
- type(e),
- e,
- )
-
- await concurrently_execute(get_event, events, 5)
-
- # Make a map of auth events for each event. We do this after fetching
- # all the events as some of the events' auth events will be in the list
- # of requested events.
-
- auth_events = [
- aid
- for event in event_map.values()
- for aid in event.auth_event_ids()
- if aid not in event_map
- ]
- persisted_events = await self.store.get_events(
- auth_events,
- allow_rejected=True,
- )
-
- event_infos = []
- for event in event_map.values():
- auth = {}
- for auth_event_id in event.auth_event_ids():
- ae = persisted_events.get(auth_event_id) or event_map.get(auth_event_id)
- if ae:
- auth[(ae.type, ae.state_key)] = ae
- else:
- logger.info("Missing auth event %s", auth_event_id)
-
- event_infos.append(_NewEventInfo(event, None, auth))
-
- if event_infos:
- await self._auth_and_persist_events(
- destination,
- room_id,
- event_infos,
- )
-
- def _sanity_check_event(self, ev: EventBase) -> None:
- """
- Do some early sanity checks of a received event
-
- In particular, checks it doesn't have an excessive number of
- prev_events or auth_events, which could cause a huge state resolution
- or cascade of event fetches.
-
- Args:
- ev: event to be checked
-
- Raises:
- SynapseError if the event does not pass muster
- """
- if len(ev.prev_event_ids()) > 20:
- logger.warning(
- "Rejecting event %s which has %i prev_events",
- ev.event_id,
- len(ev.prev_event_ids()),
- )
- raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many prev_events")
-
- if len(ev.auth_event_ids()) > 10:
- logger.warning(
- "Rejecting event %s which has %i auth_events",
- ev.event_id,
- len(ev.auth_event_ids()),
- )
- raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
-
async def send_invite(self, target_host: str, event: EventBase) -> EventBase:
"""Sends the invite to the remote server for signing.
@@ -1591,9 +464,9 @@ class FederationHandler(BaseHandler):
# This shouldn't happen, because the RoomMemberHandler has a
# linearizer lock which only allows one operation per user per room
# at a time - so this is just paranoia.
- assert room_id not in self.room_queues
+ assert room_id not in self._federation_event_handler.room_queues
- self.room_queues[room_id] = []
+ self._federation_event_handler.room_queues[room_id] = []
await self._clean_room_for_join(room_id)
@@ -1667,8 +540,8 @@ class FederationHandler(BaseHandler):
logger.debug("Finished joining %s to %s", joinee, room_id)
return event.event_id, max_stream_id
finally:
- room_queue = self.room_queues[room_id]
- del self.room_queues[room_id]
+ room_queue = self._federation_event_handler.room_queues[room_id]
+ del self._federation_event_handler.room_queues[room_id]
# we don't need to wait for the queued events to be processed -
# it's just a best-effort thing at this point. We do want to do
@@ -1744,7 +617,7 @@ class FederationHandler(BaseHandler):
event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
context = await self.state_handler.compute_event_context(event)
- stream_id = await self.persist_events_and_notify(
+ stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
return event.event_id, stream_id
@@ -1764,7 +637,7 @@ class FederationHandler(BaseHandler):
p,
)
with nested_logging_context(p.event_id):
- await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
+ await self._federation_event_handler.on_receive_pdu(origin, p)
except Exception as e:
logger.warning(
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e
@@ -1857,7 +730,7 @@ class FederationHandler(BaseHandler):
raise
# Ensure the user can even join the room.
- await self._check_join_restrictions(context, event)
+ await self._federation_event_handler.check_join_restrictions(context, event)
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
@@ -1934,7 +807,9 @@ class FederationHandler(BaseHandler):
)
context = await self.state_handler.compute_event_context(event)
- await self.persist_events_and_notify(event.room_id, [(event, context)])
+ await self._federation_event_handler.persist_events_and_notify(
+ event.room_id, [(event, context)]
+ )
return event
@@ -1961,7 +836,7 @@ class FederationHandler(BaseHandler):
await self.federation_client.send_leave(host_list, event)
context = await self.state_handler.compute_event_context(event)
- stream_id = await self.persist_events_and_notify(
+ stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
@@ -2104,116 +979,6 @@ class FederationHandler(BaseHandler):
return event
- @log_function
- async def on_send_membership_event(
- self, origin: str, event: EventBase
- ) -> Tuple[EventBase, EventContext]:
- """
- We have received a join/leave/knock event for a room via send_join/leave/knock.
-
- Verify that event and send it into the room on the remote homeserver's behalf.
-
- This is quite similar to on_receive_pdu, with the following principal
- differences:
- * only membership events are permitted (and only events with
- sender==state_key -- ie, no kicks or bans)
- * *We* send out the event on behalf of the remote server.
- * We enforce the membership restrictions of restricted rooms.
- * Rejected events result in an exception rather than being stored.
-
- There are also other differences, however it is not clear if these are by
- design or omission. In particular, we do not attempt to backfill any missing
- prev_events.
-
- Args:
- origin: The homeserver of the remote (joining/invited/knocking) user.
- event: The member event that has been signed by the remote homeserver.
-
- Returns:
- The event and context of the event after inserting it into the room graph.
-
- Raises:
- SynapseError if the event is not accepted into the room
- """
- logger.debug(
- "on_send_membership_event: Got event: %s, signatures: %s",
- event.event_id,
- event.signatures,
- )
-
- if get_domain_from_id(event.sender) != origin:
- logger.info(
- "Got send_membership request for user %r from different origin %s",
- event.sender,
- origin,
- )
- raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
-
- if event.sender != event.state_key:
- raise SynapseError(400, "state_key and sender must match", Codes.BAD_JSON)
-
- assert not event.internal_metadata.outlier
-
- # Send this event on behalf of the other server.
- #
- # The remote server isn't a full participant in the room at this point, so
- # may not have an up-to-date list of the other homeservers participating in
- # the room, so we send it on their behalf.
- event.internal_metadata.send_on_behalf_of = origin
-
- context = await self.state_handler.compute_event_context(event)
- context = await self._check_event_auth(origin, event, context)
- if context.rejected:
- raise SynapseError(
- 403, f"{event.membership} event was rejected", Codes.FORBIDDEN
- )
-
- # for joins, we need to check the restrictions of restricted rooms
- if event.membership == Membership.JOIN:
- await self._check_join_restrictions(context, event)
-
- # for knock events, we run the third-party event rules. It's not entirely clear
- # why we don't do this for other sorts of membership events.
- if event.membership == Membership.KNOCK:
- event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
- event, context
- )
- if not event_allowed:
- logger.info("Sending of knock %s forbidden by third-party rules", event)
- raise SynapseError(
- 403, "This event is not allowed in this context", Codes.FORBIDDEN
- )
-
- # all looks good, we can persist the event.
- await self._run_push_actions_and_persist_event(event, context)
- return event, context
-
- async def _check_join_restrictions(
- self, context: EventContext, event: EventBase
- ) -> None:
- """Check that restrictions in restricted join rules are matched
-
- Called when we receive a join event via send_join.
-
- Raises an auth error if the restrictions are not matched.
- """
- prev_state_ids = await context.get_prev_state_ids()
-
- # Check if the user is already in the room or invited to the room.
- user_id = event.state_key
- prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
- prev_member_event = None
- if prev_member_event_id:
- prev_member_event = await self.store.get_event(prev_member_event_id)
-
- # Check if the member should be allowed access via membership in a space.
- await self._event_auth_handler.check_restricted_join_rules(
- prev_state_ids,
- event.room_version,
- user_id,
- prev_member_event,
- )
-
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
"""Returns the state at the event. i.e. not including said event."""
@@ -2314,131 +1079,6 @@ class FederationHandler(BaseHandler):
else:
return None
- async def get_min_depth_for_context(self, context: str) -> int:
- return await self.store.get_min_depth(context)
-
- async def _auth_and_persist_event(
- self,
- origin: str,
- event: EventBase,
- context: EventContext,
- state: Optional[Iterable[EventBase]] = None,
- claimed_auth_event_map: Optional[StateMap[EventBase]] = None,
- backfilled: bool = False,
- ) -> None:
- """
- Process an event by performing auth checks and then persisting to the database.
-
- Args:
- origin: The host the event originates from.
- event: The event itself.
- context:
- The event context.
-
- state:
- The state events used to check the event for soft-fail. If this is
- not provided the current state events will be used.
-
- claimed_auth_event_map:
- A map of (type, state_key) => event for the event's claimed auth_events.
- Possibly incomplete, and possibly including events that are not yet
- persisted, or authed, or in the right room.
-
- Only populated where we may not already have persisted these events -
- for example, when populating outliers.
-
- backfilled: True if the event was backfilled.
- """
- context = await self._check_event_auth(
- origin,
- event,
- context,
- state=state,
- claimed_auth_event_map=claimed_auth_event_map,
- backfilled=backfilled,
- )
-
- await self._run_push_actions_and_persist_event(event, context, backfilled)
-
- async def _run_push_actions_and_persist_event(
- self, event: EventBase, context: EventContext, backfilled: bool = False
- ):
- """Run the push actions for a received event, and persist it.
-
- Args:
- event: The event itself.
- context: The event context.
- backfilled: True if the event was backfilled.
- """
- try:
- if (
- not event.internal_metadata.is_outlier()
- and not backfilled
- and not context.rejected
- ):
- await self.action_generator.handle_push_actions_for_event(
- event, context
- )
-
- await self.persist_events_and_notify(
- event.room_id, [(event, context)], backfilled=backfilled
- )
- except Exception:
- run_in_background(
- self.store.remove_push_actions_from_staging, event.event_id
- )
- raise
-
- async def _auth_and_persist_events(
- self,
- origin: str,
- room_id: str,
- event_infos: Collection[_NewEventInfo],
- backfilled: bool = False,
- ) -> None:
- """Creates the appropriate contexts and persists events. The events
- should not depend on one another, e.g. this should be used to persist
- a bunch of outliers, but not a chunk of individual events that depend
- on each other for state calculations.
-
- Notifies about the events where appropriate.
- """
-
- if not event_infos:
- return
-
- async def prep(ev_info: _NewEventInfo):
- event = ev_info.event
- with nested_logging_context(suffix=event.event_id):
- res = await self.state_handler.compute_event_context(
- event, old_state=ev_info.state
- )
- res = await self._check_event_auth(
- origin,
- event,
- res,
- state=ev_info.state,
- claimed_auth_event_map=ev_info.claimed_auth_event_map,
- backfilled=backfilled,
- )
- return res
-
- contexts = await make_deferred_yieldable(
- defer.gatherResults(
- [run_in_background(prep, ev_info) for ev_info in event_infos],
- consumeErrors=True,
- )
- )
-
- await self.persist_events_and_notify(
- room_id,
- [
- (ev_info.event, context)
- for ev_info, context in zip(event_infos, contexts)
- ],
- backfilled=backfilled,
- )
-
async def _persist_auth_tree(
self,
origin: str,
@@ -2536,7 +1176,7 @@ class FederationHandler(BaseHandler):
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
if auth_events or state:
- await self.persist_events_and_notify(
+ await self._federation_event_handler.persist_events_and_notify(
room_id,
[
(e, events_to_context[e.event_id])
@@ -2548,108 +1188,10 @@ class FederationHandler(BaseHandler):
event, old_state=state
)
- return await self.persist_events_and_notify(
+ return await self._federation_event_handler.persist_events_and_notify(
room_id, [(event, new_event_context)]
)
- async def _check_for_soft_fail(
- self,
- event: EventBase,
- state: Optional[Iterable[EventBase]],
- backfilled: bool,
- origin: str,
- ) -> None:
- """Checks if we should soft fail the event; if so, marks the event as
- such.
-
- Args:
- event
- state: The state at the event if we don't have all the event's prev events
- backfilled: Whether the event is from backfill
- origin: The host the event originates from.
- """
- # For new (non-backfilled and non-outlier) events we check if the event
- # passes auth based on the current state. If it doesn't then we
- # "soft-fail" the event.
- if backfilled or event.internal_metadata.is_outlier():
- return
-
- extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id)
- extrem_ids = set(extrem_ids_list)
- prev_event_ids = set(event.prev_event_ids())
-
- if extrem_ids == prev_event_ids:
- # If they're the same then the current state is the same as the
- # state at the event, so no point rechecking auth for soft fail.
- return
-
- room_version = await self.store.get_room_version_id(event.room_id)
- room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
-
- # Calculate the "current state".
- if state is not None:
- # If we're explicitly given the state then we won't have all the
- # prev events, and so we have a gap in the graph. In this case
- # we want to be a little careful as we might have been down for
- # a while and have an incorrect view of the current state,
- # however we still want to do checks as gaps are easy to
- # maliciously manufacture.
- #
- # So we use a "current state" that is actually a state
- # resolution across the current forward extremities and the
- # given state at the event. This should correctly handle cases
- # like bans, especially with state res v2.
-
- state_sets_d = await self.state_store.get_state_groups(
- event.room_id, extrem_ids
- )
- state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
- state_sets.append(state)
- current_states = await self.state_handler.resolve_events(
- room_version, state_sets, event
- )
- current_state_ids: StateMap[str] = {
- k: e.event_id for k, e in current_states.items()
- }
- else:
- current_state_ids = await self.state_handler.get_current_state_ids(
- event.room_id, latest_event_ids=extrem_ids
- )
-
- logger.debug(
- "Doing soft-fail check for %s: state %s",
- event.event_id,
- current_state_ids,
- )
-
- # Now check if event pass auth against said current state
- auth_types = auth_types_for_event(room_version_obj, event)
- current_state_ids_list = [
- e for k, e in current_state_ids.items() if k in auth_types
- ]
-
- auth_events_map = await self.store.get_events(current_state_ids_list)
- current_auth_events = {
- (e.type, e.state_key): e for e in auth_events_map.values()
- }
-
- try:
- event_auth.check(room_version_obj, event, auth_events=current_auth_events)
- except AuthError as e:
- logger.warning(
- "Soft-failing %r (from %s) because %s",
- event,
- e,
- origin,
- extra={
- "room_id": event.room_id,
- "mxid": event.sender,
- "hs": origin,
- },
- )
- soft_failed_event_counter.inc()
- event.internal_metadata.soft_failed = True
-
async def on_get_missing_events(
self,
origin: str,
@@ -2678,334 +1220,6 @@ class FederationHandler(BaseHandler):
return missing_events
- async def _check_event_auth(
- self,
- origin: str,
- event: EventBase,
- context: EventContext,
- state: Optional[Iterable[EventBase]] = None,
- claimed_auth_event_map: Optional[StateMap[EventBase]] = None,
- backfilled: bool = False,
- ) -> EventContext:
- """
- Checks whether an event should be rejected (for failing auth checks).
-
- Args:
- origin: The host the event originates from.
- event: The event itself.
- context:
- The event context.
-
- state:
- The state events used to check the event for soft-fail. If this is
- not provided the current state events will be used.
-
- claimed_auth_event_map:
- A map of (type, state_key) => event for the event's claimed auth_events.
- Possibly incomplete, and possibly including events that are not yet
- persisted, or authed, or in the right room.
-
- Only populated where we may not already have persisted these events -
- for example, when populating outliers, or the state for a backwards
- extremity.
-
- backfilled: True if the event was backfilled.
-
- Returns:
- The updated context object.
- """
- room_version = await self.store.get_room_version_id(event.room_id)
- room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
-
- if claimed_auth_event_map:
- # if we have a copy of the auth events from the event, use that as the
- # basis for auth.
- auth_events = claimed_auth_event_map
- else:
- # otherwise, we calculate what the auth events *should* be, and use that
- prev_state_ids = await context.get_prev_state_ids()
- auth_events_ids = self._event_auth_handler.compute_auth_events(
- event, prev_state_ids, for_verification=True
- )
- auth_events_x = await self.store.get_events(auth_events_ids)
- auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
-
- try:
- (
- context,
- auth_events_for_auth,
- ) = await self._update_auth_events_and_context_for_auth(
- origin, event, context, auth_events
- )
- except Exception:
- # We don't really mind if the above fails, so lets not fail
- # processing if it does. However, it really shouldn't fail so
- # let's still log as an exception since we'll still want to fix
- # any bugs.
- logger.exception(
- "Failed to double check auth events for %s with remote. "
- "Ignoring failure and continuing processing of event.",
- event.event_id,
- )
- auth_events_for_auth = auth_events
-
- try:
- event_auth.check(room_version_obj, event, auth_events=auth_events_for_auth)
- except AuthError as e:
- logger.warning("Failed auth resolution for %r because %s", event, e)
- context.rejected = RejectedReason.AUTH_ERROR
-
- if not context.rejected:
- await self._check_for_soft_fail(event, state, backfilled, origin=origin)
-
- if event.type == EventTypes.GuestAccess and not context.rejected:
- await self.maybe_kick_guest_users(event)
-
- # If we are going to send this event over federation we precaclculate
- # the joined hosts.
- if event.internal_metadata.get_send_on_behalf_of():
- await self.event_creation_handler.cache_joined_hosts_for_event(
- event, context
- )
-
- return context
-
- async def _update_auth_events_and_context_for_auth(
- self,
- origin: str,
- event: EventBase,
- context: EventContext,
- input_auth_events: StateMap[EventBase],
- ) -> Tuple[EventContext, StateMap[EventBase]]:
- """Helper for _check_event_auth. See there for docs.
-
- Checks whether a given event has the expected auth events. If it
- doesn't then we talk to the remote server to compare state to see if
- we can come to a consensus (e.g. if one server missed some valid
- state).
-
- This attempts to resolve any potential divergence of state between
- servers, but is not essential and so failures should not block further
- processing of the event.
-
- Args:
- origin:
- event:
- context:
-
- input_auth_events:
- Map from (event_type, state_key) to event
-
- Normally, our calculated auth_events based on the state of the room
- at the event's position in the DAG, though occasionally (eg if the
- event is an outlier), may be the auth events claimed by the remote
- server.
-
- Returns:
- updated context, updated auth event map
- """
- # take a copy of input_auth_events before we modify it.
- auth_events: MutableStateMap[EventBase] = dict(input_auth_events)
-
- event_auth_events = set(event.auth_event_ids())
-
- # missing_auth is the set of the event's auth_events which we don't yet have
- # in auth_events.
- missing_auth = event_auth_events.difference(
- e.event_id for e in auth_events.values()
- )
-
- # if we have missing events, we need to fetch those events from somewhere.
- #
- # we start by checking if they are in the store, and then try calling /event_auth/.
- if missing_auth:
- have_events = await self.store.have_seen_events(event.room_id, missing_auth)
- logger.debug("Events %s are in the store", have_events)
- missing_auth.difference_update(have_events)
-
- if missing_auth:
- # If we don't have all the auth events, we need to get them.
- logger.info("auth_events contains unknown events: %s", missing_auth)
- try:
- try:
- remote_auth_chain = await self.federation_client.get_event_auth(
- origin, event.room_id, event.event_id
- )
- except RequestSendFailed as e1:
- # The other side isn't around or doesn't implement the
- # endpoint, so lets just bail out.
- logger.info("Failed to get event auth from remote: %s", e1)
- return context, auth_events
-
- seen_remotes = await self.store.have_seen_events(
- event.room_id, [e.event_id for e in remote_auth_chain]
- )
-
- for e in remote_auth_chain:
- if e.event_id in seen_remotes:
- continue
-
- if e.event_id == event.event_id:
- continue
-
- try:
- auth_ids = e.auth_event_ids()
- auth = {
- (e.type, e.state_key): e
- for e in remote_auth_chain
- if e.event_id in auth_ids or e.type == EventTypes.Create
- }
- e.internal_metadata.outlier = True
-
- logger.debug(
- "_check_event_auth %s missing_auth: %s",
- event.event_id,
- e.event_id,
- )
- missing_auth_event_context = (
- await self.state_handler.compute_event_context(e)
- )
- await self._auth_and_persist_event(
- origin,
- e,
- missing_auth_event_context,
- claimed_auth_event_map=auth,
- )
-
- if e.event_id in event_auth_events:
- auth_events[(e.type, e.state_key)] = e
- except AuthError:
- pass
-
- except Exception:
- logger.exception("Failed to get auth chain")
-
- if event.internal_metadata.is_outlier():
- # XXX: given that, for an outlier, we'll be working with the
- # event's *claimed* auth events rather than those we calculated:
- # (a) is there any point in this test, since different_auth below will
- # obviously be empty
- # (b) alternatively, why don't we do it earlier?
- logger.info("Skipping auth_event fetch for outlier")
- return context, auth_events
-
- different_auth = event_auth_events.difference(
- e.event_id for e in auth_events.values()
- )
-
- if not different_auth:
- return context, auth_events
-
- logger.info(
- "auth_events refers to events which are not in our calculated auth "
- "chain: %s",
- different_auth,
- )
-
- # XXX: currently this checks for redactions but I'm not convinced that is
- # necessary?
- different_events = await self.store.get_events_as_list(different_auth)
-
- for d in different_events:
- if d.room_id != event.room_id:
- logger.warning(
- "Event %s refers to auth_event %s which is in a different room",
- event.event_id,
- d.event_id,
- )
-
- # don't attempt to resolve the claimed auth events against our own
- # in this case: just use our own auth events.
- #
- # XXX: should we reject the event in this case? It feels like we should,
- # but then shouldn't we also do so if we've failed to fetch any of the
- # auth events?
- return context, auth_events
-
- # now we state-resolve between our own idea of the auth events, and the remote's
- # idea of them.
-
- local_state = auth_events.values()
- remote_auth_events = dict(auth_events)
- remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
- remote_state = remote_auth_events.values()
-
- room_version = await self.store.get_room_version_id(event.room_id)
- new_state = await self.state_handler.resolve_events(
- room_version, (local_state, remote_state), event
- )
-
- logger.info(
- "After state res: updating auth_events with new state %s",
- {
- (d.type, d.state_key): d.event_id
- for d in new_state.values()
- if auth_events.get((d.type, d.state_key)) != d
- },
- )
-
- auth_events.update(new_state)
-
- context = await self._update_context_for_auth_events(
- event, context, auth_events
- )
-
- return context, auth_events
-
- async def _update_context_for_auth_events(
- self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
- ) -> EventContext:
- """Update the state_ids in an event context after auth event resolution,
- storing the changes as a new state group.
-
- Args:
- event: The event we're handling the context for
-
- context: initial event context
-
- auth_events: Events to update in the event context.
-
- Returns:
- new event context
- """
- # exclude the state key of the new event from the current_state in the context.
- if event.is_state():
- event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
- else:
- event_key = None
- state_updates = {
- k: a.event_id for k, a in auth_events.items() if k != event_key
- }
-
- current_state_ids = await context.get_current_state_ids()
- current_state_ids = dict(current_state_ids) # type: ignore
-
- current_state_ids.update(state_updates)
-
- prev_state_ids = await context.get_prev_state_ids()
- prev_state_ids = dict(prev_state_ids)
-
- prev_state_ids.update({k: a.event_id for k, a in auth_events.items()})
-
- # create a new state group as a delta from the existing one.
- prev_group = context.state_group
- state_group = await self.state_store.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=prev_group,
- delta_ids=state_updates,
- current_state_ids=current_state_ids,
- )
-
- return EventContext.with_state(
- state_group=state_group,
- state_group_before_event=context.state_group_before_event,
- current_state_ids=current_state_ids,
- prev_state_ids=prev_state_ids,
- prev_group=prev_group,
- delta_ids=state_updates,
- )
-
async def construct_auth_difference(
self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase]
) -> Dict:
@@ -3392,99 +1606,6 @@ class FederationHandler(BaseHandler):
if "valid" not in response or not response["valid"]:
raise AuthError(403, "Third party certificate was invalid")
- async def persist_events_and_notify(
- self,
- room_id: str,
- event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
- backfilled: bool = False,
- ) -> int:
- """Persists events and tells the notifier/pushers about them, if
- necessary.
-
- Args:
- room_id: The room ID of events being persisted.
- event_and_contexts: Sequence of events with their associated
- context that should be persisted. All events must belong to
- the same room.
- backfilled: Whether these events are a result of
- backfilling or not
-
- Returns:
- The stream ID after which all events have been persisted.
- """
- if not event_and_contexts:
- return self.store.get_current_events_token()
-
- instance = self.config.worker.events_shard_config.get_instance(room_id)
- if instance != self._instance_name:
- # Limit the number of events sent over replication. We choose 200
- # here as that is what we default to in `max_request_body_size(..)`
- for batch in batch_iter(event_and_contexts, 200):
- result = await self._send_events(
- instance_name=instance,
- store=self.store,
- room_id=room_id,
- event_and_contexts=batch,
- backfilled=backfilled,
- )
- return result["max_stream_id"]
- else:
- assert self.storage.persistence
-
- # Note that this returns the events that were persisted, which may not be
- # the same as were passed in if some were deduplicated due to transaction IDs.
- events, max_stream_token = await self.storage.persistence.persist_events(
- event_and_contexts, backfilled=backfilled
- )
-
- if self._ephemeral_messages_enabled:
- for event in events:
- # If there's an expiry timestamp on the event, schedule its expiry.
- self._message_handler.maybe_schedule_expiry(event)
-
- if not backfilled: # Never notify for backfilled events
- for event in events:
- await self._notify_persisted_event(event, max_stream_token)
-
- return max_stream_token.stream
-
- async def _notify_persisted_event(
- self, event: EventBase, max_stream_token: RoomStreamToken
- ) -> None:
- """Checks to see if notifier/pushers should be notified about the
- event or not.
-
- Args:
- event:
- max_stream_id: The max_stream_id returned by persist_events
- """
-
- extra_users = []
- if event.type == EventTypes.Member:
- target_user_id = event.state_key
-
- # We notify for memberships if its an invite for one of our
- # users
- if event.internal_metadata.is_outlier():
- if event.membership != Membership.INVITE:
- if not self.is_mine_id(target_user_id):
- return
-
- target_user = UserID.from_string(target_user_id)
- extra_users.append(target_user)
- elif event.internal_metadata.is_outlier():
- return
-
- # the event has been persisted so it should have a stream ordering.
- assert event.internal_metadata.stream_ordering
-
- event_pos = PersistedEventPosition(
- self._instance_name, event.internal_metadata.stream_ordering
- )
- self.notifier.on_new_room_event(
- event, event_pos, max_stream_token, extra_users=extra_users
- )
-
async def _clean_room_for_join(self, room_id: str) -> None:
"""Called to clean up any data in DB for a given room, ready for the
server to join the room.
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
new file mode 100644
index 00000000..9f055f00
--- /dev/null
+++ b/synapse/handlers/federation_event.py
@@ -0,0 +1,1825 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from http import HTTPStatus
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Container,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+)
+
+import attr
+from prometheus_client import Counter
+
+from twisted.internet import defer
+
+from synapse import event_auth
+from synapse.api.constants import (
+ EventContentFields,
+ EventTypes,
+ Membership,
+ RejectedReason,
+ RoomEncryptionAlgorithms,
+)
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ FederationError,
+ HttpResponseException,
+ RequestSendFailed,
+ SynapseError,
+)
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.event_auth import auth_types_for_event
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
+from synapse.federation.federation_client import InvalidResponseError
+from synapse.handlers._base import BaseHandler
+from synapse.logging.context import (
+ make_deferred_yieldable,
+ nested_logging_context,
+ run_in_background,
+)
+from synapse.logging.utils import log_function
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
+from synapse.replication.http.federation import (
+ ReplicationFederationSendEventsRestServlet,
+)
+from synapse.state import StateResolutionStore
+from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.types import (
+ MutableStateMap,
+ PersistedEventPosition,
+ RoomStreamToken,
+ StateMap,
+ UserID,
+ get_domain_from_id,
+)
+from synapse.util.async_helpers import Linearizer, concurrently_execute
+from synapse.util.iterutils import batch_iter
+from synapse.util.retryutils import NotRetryingDestination
+from synapse.util.stringutils import shortstr
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+logger = logging.getLogger(__name__)
+
+soft_failed_event_counter = Counter(
+ "synapse_federation_soft_failed_events_total",
+ "Events received over federation that we marked as soft_failed",
+)
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _NewEventInfo:
+ """Holds information about a received event, ready for passing to _auth_and_persist_events
+
+ Attributes:
+ event: the received event
+
+ claimed_auth_event_map: a map of (type, state_key) => event for the event's
+ claimed auth_events.
+
+ This can include events which have not yet been persisted, in the case that
+ we are backfilling a batch of events.
+
+ Note: May be incomplete: if we were unable to find all of the claimed auth
+ events. Also, treat the contents with caution: the events might also have
+ been rejected, might not yet have been authorized themselves, or they might
+ be in the wrong room.
+
+ """
+
+ event: EventBase
+ claimed_auth_event_map: StateMap[EventBase]
+
+
+class FederationEventHandler(BaseHandler):
+ """Handles events that originated from federation.
+
+ Responsible for handing incoming events and passing them on to the rest
+ of the homeserver (including auth and state conflict resolutions)
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+
+ self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
+
+ self.state_handler = hs.get_state_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self._event_auth_handler = hs.get_event_auth_handler()
+ self._message_handler = hs.get_message_handler()
+ self.action_generator = hs.get_action_generator()
+ self._state_resolution_handler = hs.get_state_resolution_handler()
+
+ self.federation_client = hs.get_federation_client()
+ self.third_party_event_rules = hs.get_third_party_event_rules()
+
+ self.is_mine_id = hs.is_mine_id
+ self._instance_name = hs.get_instance_name()
+
+ self.config = hs.config
+ self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
+
+ self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
+ if hs.config.worker_app:
+ self._user_device_resync = (
+ ReplicationUserDevicesResyncRestServlet.make_client(hs)
+ )
+ else:
+ self._device_list_updater = hs.get_device_handler().device_list_updater
+
+ # When joining a room we need to queue any events for that room up.
+ # For each room, a list of (pdu, origin) tuples.
+ # TODO: replace this with something more elegant, probably based around the
+ # federation event staging area.
+ self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {}
+
+ self._room_pdu_linearizer = Linearizer("fed_room_pdu")
+
+ async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None:
+ """Process a PDU received via a federation /send/ transaction
+
+ Args:
+ origin: server which initiated the /send/ transaction. Will
+ be used to fetch missing events or state.
+ pdu: received PDU
+ """
+
+ room_id = pdu.room_id
+ event_id = pdu.event_id
+
+ # We reprocess pdus when we have seen them only as outliers
+ existing = await self.store.get_event(
+ event_id, allow_none=True, allow_rejected=True
+ )
+
+ # FIXME: Currently we fetch an event again when we already have it
+ # if it has been marked as an outlier.
+ if existing:
+ if not existing.internal_metadata.is_outlier():
+ logger.info(
+ "Ignoring received event %s which we have already seen", event_id
+ )
+ return
+ if pdu.internal_metadata.is_outlier():
+ logger.info(
+ "Ignoring received outlier %s which we already have as an outlier",
+ event_id,
+ )
+ return
+ logger.info("De-outliering event %s", event_id)
+
+ # do some initial sanity-checking of the event. In particular, make
+ # sure it doesn't have hundreds of prev_events or auth_events, which
+ # could cause a huge state resolution or cascade of event fetches.
+ try:
+ self._sanity_check_event(pdu)
+ except SynapseError as err:
+ logger.warning("Received event failed sanity checks")
+ raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id)
+
+ # If we are currently in the process of joining this room, then we
+ # queue up events for later processing.
+ if room_id in self.room_queues:
+ logger.info(
+ "Queuing PDU from %s for now: join in progress",
+ origin,
+ )
+ self.room_queues[room_id].append((pdu, origin))
+ return
+
+ # If we're not in the room just ditch the event entirely. This is
+ # probably an old server that has come back and thinks we're still in
+ # the room (or we've been rejoined to the room by a state reset).
+ #
+ # Note that if we were never in the room then we would have already
+ # dropped the event, since we wouldn't know the room version.
+ is_in_room = await self._event_auth_handler.check_host_in_room(
+ room_id, self.server_name
+ )
+ if not is_in_room:
+ logger.info(
+ "Ignoring PDU from %s as we're not in the room",
+ origin,
+ )
+ return None
+
+ # Check that the event passes auth based on the state at the event. This is
+ # done for events that are to be added to the timeline (non-outliers).
+ #
+ # Get missing pdus if necessary:
+ # - Fetching any missing prev events to fill in gaps in the graph
+ # - Fetching state if we have a hole in the graph
+ if not pdu.internal_metadata.is_outlier():
+ prevs = set(pdu.prev_event_ids())
+ seen = await self.store.have_events_in_timeline(prevs)
+ missing_prevs = prevs - seen
+
+ if missing_prevs:
+ # We only backfill backwards to the min depth.
+ min_depth = await self.get_min_depth_for_context(pdu.room_id)
+ logger.debug("min_depth: %d", min_depth)
+
+ if min_depth is not None and pdu.depth > min_depth:
+ # If we're missing stuff, ensure we only fetch stuff one
+ # at a time.
+ logger.info(
+ "Acquiring room lock to fetch %d missing prev_events: %s",
+ len(missing_prevs),
+ shortstr(missing_prevs),
+ )
+ with (await self._room_pdu_linearizer.queue(pdu.room_id)):
+ logger.info(
+ "Acquired room lock to fetch %d missing prev_events",
+ len(missing_prevs),
+ )
+
+ try:
+ await self._get_missing_events_for_pdu(
+ origin, pdu, prevs, min_depth
+ )
+ except Exception as e:
+ raise Exception(
+ "Error fetching missing prev_events for %s: %s"
+ % (event_id, e)
+ ) from e
+
+ # Update the set of things we've seen after trying to
+ # fetch the missing stuff
+ seen = await self.store.have_events_in_timeline(prevs)
+ missing_prevs = prevs - seen
+
+ if not missing_prevs:
+ logger.info("Found all missing prev_events")
+
+ if missing_prevs:
+ # since this event was pushed to us, it is possible for it to
+ # become the only forward-extremity in the room, and we would then
+ # trust its state to be the state for the whole room. This is very
+ # bad. Further, if the event was pushed to us, there is no excuse
+ # for us not to have all the prev_events. (XXX: apart from
+ # min_depth?)
+ #
+ # We therefore reject any such events.
+ logger.warning(
+ "Rejecting: failed to fetch %d prev events: %s",
+ len(missing_prevs),
+ shortstr(missing_prevs),
+ )
+ raise FederationError(
+ "ERROR",
+ 403,
+ (
+ "Your server isn't divulging details about prev_events "
+ "referenced in this event."
+ ),
+ affected=pdu.event_id,
+ )
+
+ await self._process_received_pdu(origin, pdu, state=None)
+
+ @log_function
+ async def on_send_membership_event(
+ self, origin: str, event: EventBase
+ ) -> Tuple[EventBase, EventContext]:
+ """
+ We have received a join/leave/knock event for a room via send_join/leave/knock.
+
+ Verify that event and send it into the room on the remote homeserver's behalf.
+
+ This is quite similar to on_receive_pdu, with the following principal
+ differences:
+ * only membership events are permitted (and only events with
+ sender==state_key -- ie, no kicks or bans)
+ * *We* send out the event on behalf of the remote server.
+ * We enforce the membership restrictions of restricted rooms.
+ * Rejected events result in an exception rather than being stored.
+
+ There are also other differences, however it is not clear if these are by
+ design or omission. In particular, we do not attempt to backfill any missing
+ prev_events.
+
+ Args:
+ origin: The homeserver of the remote (joining/invited/knocking) user.
+ event: The member event that has been signed by the remote homeserver.
+
+ Returns:
+ The event and context of the event after inserting it into the room graph.
+
+ Raises:
+ SynapseError if the event is not accepted into the room
+ """
+ logger.debug(
+ "on_send_membership_event: Got event: %s, signatures: %s",
+ event.event_id,
+ event.signatures,
+ )
+
+ if get_domain_from_id(event.sender) != origin:
+ logger.info(
+ "Got send_membership request for user %r from different origin %s",
+ event.sender,
+ origin,
+ )
+ raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
+
+ if event.sender != event.state_key:
+ raise SynapseError(400, "state_key and sender must match", Codes.BAD_JSON)
+
+ assert not event.internal_metadata.outlier
+
+ # Send this event on behalf of the other server.
+ #
+ # The remote server isn't a full participant in the room at this point, so
+ # may not have an up-to-date list of the other homeservers participating in
+ # the room, so we send it on their behalf.
+ event.internal_metadata.send_on_behalf_of = origin
+
+ context = await self.state_handler.compute_event_context(event)
+ context = await self._check_event_auth(origin, event, context)
+ if context.rejected:
+ raise SynapseError(
+ 403, f"{event.membership} event was rejected", Codes.FORBIDDEN
+ )
+
+ # for joins, we need to check the restrictions of restricted rooms
+ if event.membership == Membership.JOIN:
+ await self.check_join_restrictions(context, event)
+
+ # for knock events, we run the third-party event rules. It's not entirely clear
+ # why we don't do this for other sorts of membership events.
+ if event.membership == Membership.KNOCK:
+ event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
+ event, context
+ )
+ if not event_allowed:
+ logger.info("Sending of knock %s forbidden by third-party rules", event)
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
+
+ # all looks good, we can persist the event.
+ await self._run_push_actions_and_persist_event(event, context)
+ return event, context
+
+ async def check_join_restrictions(
+ self, context: EventContext, event: EventBase
+ ) -> None:
+ """Check that restrictions in restricted join rules are matched
+
+ Called when we receive a join event via send_join.
+
+ Raises an auth error if the restrictions are not matched.
+ """
+ prev_state_ids = await context.get_prev_state_ids()
+
+ # Check if the user is already in the room or invited to the room.
+ user_id = event.state_key
+ prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
+ prev_member_event = None
+ if prev_member_event_id:
+ prev_member_event = await self.store.get_event(prev_member_event_id)
+
+ # Check if the member should be allowed access via membership in a space.
+ await self._event_auth_handler.check_restricted_join_rules(
+ prev_state_ids,
+ event.room_version,
+ user_id,
+ prev_member_event,
+ )
+
+ @log_function
+ async def backfill(
+ self, dest: str, room_id: str, limit: int, extremities: List[str]
+ ) -> None:
+ """Trigger a backfill request to `dest` for the given `room_id`
+
+ This will attempt to get more events from the remote. If the other side
+ has no new events to offer, this will return an empty list.
+
+ As the events are received, we check their signatures, and also do some
+ sanity-checking on them. If any of the backfilled events are invalid,
+ this method throws a SynapseError.
+
+ We might also raise an InvalidResponseError if the response from the remote
+ server is just bogus.
+
+ TODO: make this more useful to distinguish failures of the remote
+ server from invalid events (there is probably no point in trying to
+ re-fetch invalid events from every other HS in the room.)
+ """
+ if dest == self.server_name:
+ raise SynapseError(400, "Can't backfill from self.")
+
+ events = await self.federation_client.backfill(
+ dest, room_id, limit=limit, extremities=extremities
+ )
+
+ if not events:
+ return
+
+ # if there are any events in the wrong room, the remote server is buggy and
+ # should not be trusted.
+ for ev in events:
+ if ev.room_id != room_id:
+ raise InvalidResponseError(
+ f"Remote server {dest} returned event {ev.event_id} which is in "
+ f"room {ev.room_id}, when we were backfilling in {room_id}"
+ )
+
+ await self._process_pulled_events(dest, events, backfilled=True)
+
+ async def _get_missing_events_for_pdu(
+ self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
+ ) -> None:
+ """
+ Args:
+ origin: Origin of the pdu. Will be called to get the missing events
+ pdu: received pdu
+ prevs: List of event ids which we are missing
+ min_depth: Minimum depth of events to return.
+ """
+
+ room_id = pdu.room_id
+ event_id = pdu.event_id
+
+ seen = await self.store.have_events_in_timeline(prevs)
+
+ if not prevs - seen:
+ return
+
+ latest_list = await self.store.get_latest_event_ids_in_room(room_id)
+
+ # We add the prev events that we have seen to the latest
+ # list to ensure the remote server doesn't give them to us
+ latest = set(latest_list)
+ latest |= seen
+
+ logger.info(
+ "Requesting missing events between %s and %s",
+ shortstr(latest),
+ event_id,
+ )
+
+ # XXX: we set timeout to 10s to help workaround
+ # https://github.com/matrix-org/synapse/issues/1733.
+ # The reason is to avoid holding the linearizer lock
+ # whilst processing inbound /send transactions, causing
+ # FDs to stack up and block other inbound transactions
+ # which empirically can currently take up to 30 minutes.
+ #
+ # N.B. this explicitly disables retry attempts.
+ #
+ # N.B. this also increases our chances of falling back to
+ # fetching fresh state for the room if the missing event
+ # can't be found, which slightly reduces our security.
+ # it may also increase our DAG extremity count for the room,
+ # causing additional state resolution? See #1760.
+ # However, fetching state doesn't hold the linearizer lock
+ # apparently.
+ #
+ # see https://github.com/matrix-org/synapse/pull/1744
+ #
+ # ----
+ #
+ # Update richvdh 2018/09/18: There are a number of problems with timing this
+ # request out aggressively on the client side:
+ #
+ # - it plays badly with the server-side rate-limiter, which starts tarpitting you
+ # if you send too many requests at once, so you end up with the server carefully
+ # working through the backlog of your requests, which you have already timed
+ # out.
+ #
+ # - for this request in particular, we now (as of
+ # https://github.com/matrix-org/synapse/pull/3456) reject any PDUs where the
+ # server can't produce a plausible-looking set of prev_events - so we becone
+ # much more likely to reject the event.
+ #
+ # - contrary to what it says above, we do *not* fall back to fetching fresh state
+ # for the room if get_missing_events times out. Rather, we give up processing
+ # the PDU whose prevs we are missing, which then makes it much more likely that
+ # we'll end up back here for the *next* PDU in the list, which exacerbates the
+ # problem.
+ #
+ # - the aggressive 10s timeout was introduced to deal with incoming federation
+ # requests taking 8 hours to process. It's not entirely clear why that was going
+ # on; certainly there were other issues causing traffic storms which are now
+ # resolved, and I think in any case we may be more sensible about our locking
+ # now. We're *certainly* more sensible about our logging.
+ #
+ # All that said: Let's try increasing the timeout to 60s and see what happens.
+
+ try:
+ missing_events = await self.federation_client.get_missing_events(
+ origin,
+ room_id,
+ earliest_events_ids=list(latest),
+ latest_events=[pdu],
+ limit=10,
+ min_depth=min_depth,
+ timeout=60000,
+ )
+ except (RequestSendFailed, HttpResponseException, NotRetryingDestination) as e:
+ # We failed to get the missing events, but since we need to handle
+ # the case of `get_missing_events` not returning the necessary
+ # events anyway, it is safe to simply log the error and continue.
+ logger.warning("Failed to get prev_events: %s", e)
+ return
+
+ logger.info("Got %d prev_events", len(missing_events))
+ await self._process_pulled_events(origin, missing_events, backfilled=False)
+
+ async def _process_pulled_events(
+ self, origin: str, events: Iterable[EventBase], backfilled: bool
+ ) -> None:
+ """Process a batch of events we have pulled from a remote server
+
+ Pulls in any events required to auth the events, persists the received events,
+ and notifies clients, if appropriate.
+
+ Assumes the events have already had their signatures and hashes checked.
+
+ Params:
+ origin: The server we received these events from
+ events: The received events.
+ backfilled: True if this is part of a historical batch of events (inhibits
+ notification to clients, and validation of device keys.)
+ """
+
+ # We want to sort these by depth so we process them and
+ # tell clients about them in order.
+ sorted_events = sorted(events, key=lambda x: x.depth)
+
+ for ev in sorted_events:
+ with nested_logging_context(ev.event_id):
+ await self._process_pulled_event(origin, ev, backfilled=backfilled)
+
+ async def _process_pulled_event(
+ self, origin: str, event: EventBase, backfilled: bool
+ ) -> None:
+ """Process a single event that we have pulled from a remote server
+
+ Pulls in any events required to auth the event, persists the received event,
+ and notifies clients, if appropriate.
+
+ Assumes the event has already had its signatures and hashes checked.
+
+ This is somewhat equivalent to on_receive_pdu, but applies somewhat different
+ logic in the case that we are missing prev_events (in particular, it just
+ requests the state at that point, rather than triggering a get_missing_events) -
+ so is appropriate when we have pulled the event from a remote server, rather
+ than having it pushed to us.
+
+ Params:
+ origin: The server we received this event from
+ events: The received event
+ backfilled: True if this is part of a historical batch of events (inhibits
+ notification to clients, and validation of device keys.)
+ """
+ logger.info("Processing pulled event %s", event)
+
+ # these should not be outliers.
+ assert not event.internal_metadata.is_outlier()
+
+ event_id = event.event_id
+
+ existing = await self.store.get_event(
+ event_id, allow_none=True, allow_rejected=True
+ )
+ if existing:
+ if not existing.internal_metadata.is_outlier():
+ logger.info(
+ "Ignoring received event %s which we have already seen",
+ event_id,
+ )
+ return
+ logger.info("De-outliering event %s", event_id)
+
+ try:
+ self._sanity_check_event(event)
+ except SynapseError as err:
+ logger.warning("Event %s failed sanity check: %s", event_id, err)
+ return
+
+ try:
+ state = await self._resolve_state_at_missing_prevs(origin, event)
+ await self._process_received_pdu(
+ origin, event, state=state, backfilled=backfilled
+ )
+ except FederationError as e:
+ if e.code == 403:
+ logger.warning("Pulled event %s failed history check.", event_id)
+ else:
+ raise
+
+ async def _resolve_state_at_missing_prevs(
+ self, dest: str, event: EventBase
+ ) -> Optional[Iterable[EventBase]]:
+ """Calculate the state at an event with missing prev_events.
+
+ This is used when we have pulled a batch of events from a remote server, and
+ still don't have all the prev_events.
+
+ If we already have all the prev_events for `event`, this method does nothing.
+
+ Otherwise, the missing prevs become new backwards extremities, and we fall back
+ to asking the remote server for the state after each missing `prev_event`,
+ and resolving across them.
+
+ That's ok provided we then resolve the state against other bits of the DAG
+ before using it - in other words, that the received event `event` is not going
+ to become the only forwards_extremity in the room (which will ensure that you
+ can't just take over a room by sending an event, withholding its prev_events,
+ and declaring yourself to be an admin in the subsequent state request).
+
+ In other words: we should only call this method if `event` has been *pulled*
+ as part of a batch of missing prev events, or similar.
+
+ Params:
+ dest: the remote server to ask for state at the missing prevs. Typically,
+ this will be the server we got `event` from.
+ event: an event to check for missing prevs.
+
+ Returns:
+ if we already had all the prev events, `None`. Otherwise, returns a list of
+ the events in the state at `event`.
+ """
+ room_id = event.room_id
+ event_id = event.event_id
+
+ prevs = set(event.prev_event_ids())
+ seen = await self.store.have_events_in_timeline(prevs)
+ missing_prevs = prevs - seen
+
+ if not missing_prevs:
+ return None
+
+ logger.info(
+ "Event %s is missing prev_events %s: calculating state for a "
+ "backwards extremity",
+ event_id,
+ shortstr(missing_prevs),
+ )
+ # Calculate the state after each of the previous events, and
+ # resolve them to find the correct state at the current event.
+ event_map = {event_id: event}
+ try:
+ # Get the state of the events we know about
+ ours = await self.state_store.get_state_groups_ids(room_id, seen)
+
+ # state_maps is a list of mappings from (type, state_key) to event_id
+ state_maps: List[StateMap[str]] = list(ours.values())
+
+ # we don't need this any more, let's delete it.
+ del ours
+
+ # Ask the remote server for the states we don't
+ # know about
+ for p in missing_prevs:
+ logger.info("Requesting state after missing prev_event %s", p)
+
+ with nested_logging_context(p):
+ # note that if any of the missing prevs share missing state or
+ # auth events, the requests to fetch those events are deduped
+ # by the get_pdu_cache in federation_client.
+ remote_state = await self._get_state_after_missing_prev_event(
+ dest, room_id, p
+ )
+
+ remote_state_map = {
+ (x.type, x.state_key): x.event_id for x in remote_state
+ }
+ state_maps.append(remote_state_map)
+
+ for x in remote_state:
+ event_map[x.event_id] = x
+
+ room_version = await self.store.get_room_version_id(room_id)
+ state_map = await self._state_resolution_handler.resolve_events_with_store(
+ room_id,
+ room_version,
+ state_maps,
+ event_map,
+ state_res_store=StateResolutionStore(self.store),
+ )
+
+ # We need to give _process_received_pdu the actual state events
+ # rather than event ids, so generate that now.
+
+ # First though we need to fetch all the events that are in
+ # state_map, so we can build up the state below.
+ evs = await self.store.get_events(
+ list(state_map.values()),
+ get_prev_content=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ )
+ event_map.update(evs)
+
+ state = [event_map[e] for e in state_map.values()]
+ except Exception:
+ logger.warning(
+ "Error attempting to resolve state at missing prev_events",
+ exc_info=True,
+ )
+ raise FederationError(
+ "ERROR",
+ 403,
+ "We can't get valid state history.",
+ affected=event_id,
+ )
+ return state
+
+ async def _get_state_after_missing_prev_event(
+ self,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ ) -> List[EventBase]:
+ """Requests all of the room state at a given event from a remote homeserver.
+
+ Args:
+ destination: The remote homeserver to query for the state.
+ room_id: The id of the room we're interested in.
+ event_id: The id of the event we want the state at.
+
+ Returns:
+ A list of events in the state, including the event itself
+ """
+ (
+ state_event_ids,
+ auth_event_ids,
+ ) = await self.federation_client.get_room_state_ids(
+ destination, room_id, event_id=event_id
+ )
+
+ logger.debug(
+ "state_ids returned %i state events, %i auth events",
+ len(state_event_ids),
+ len(auth_event_ids),
+ )
+
+ # start by just trying to fetch the events from the store
+ desired_events = set(state_event_ids)
+ desired_events.add(event_id)
+ logger.debug("Fetching %i events from cache/store", len(desired_events))
+ fetched_events = await self.store.get_events(
+ desired_events, allow_rejected=True
+ )
+
+ missing_desired_events = desired_events - fetched_events.keys()
+ logger.debug(
+ "We are missing %i events (got %i)",
+ len(missing_desired_events),
+ len(fetched_events),
+ )
+
+ # We probably won't need most of the auth events, so let's just check which
+ # we have for now, rather than thrashing the event cache with them all
+ # unnecessarily.
+
+ # TODO: we probably won't actually need all of the auth events, since we
+ # already have a bunch of the state events. It would be nice if the
+ # federation api gave us a way of finding out which we actually need.
+
+ missing_auth_events = set(auth_event_ids) - fetched_events.keys()
+ missing_auth_events.difference_update(
+ await self.store.have_seen_events(room_id, missing_auth_events)
+ )
+ logger.debug("We are also missing %i auth events", len(missing_auth_events))
+
+ missing_events = missing_desired_events | missing_auth_events
+ logger.debug("Fetching %i events from remote", len(missing_events))
+ await self._get_events_and_persist(
+ destination=destination, room_id=room_id, events=missing_events
+ )
+
+ # we need to make sure we re-load from the database to get the rejected
+ # state correct.
+ fetched_events.update(
+ await self.store.get_events(missing_desired_events, allow_rejected=True)
+ )
+
+ # check for events which were in the wrong room.
+ #
+ # this can happen if a remote server claims that the state or
+ # auth_events at an event in room A are actually events in room B
+
+ bad_events = [
+ (event_id, event.room_id)
+ for event_id, event in fetched_events.items()
+ if event.room_id != room_id
+ ]
+
+ for bad_event_id, bad_room_id in bad_events:
+ # This is a bogus situation, but since we may only discover it a long time
+ # after it happened, we try our best to carry on, by just omitting the
+ # bad events from the returned state set.
+ logger.warning(
+ "Remote server %s claims event %s in room %s is an auth/state "
+ "event in room %s",
+ destination,
+ bad_event_id,
+ bad_room_id,
+ room_id,
+ )
+
+ del fetched_events[bad_event_id]
+
+ # if we couldn't get the prev event in question, that's a problem.
+ remote_event = fetched_events.get(event_id)
+ if not remote_event:
+ raise Exception("Unable to get missing prev_event %s" % (event_id,))
+
+ # missing state at that event is a warning, not a blocker
+ # XXX: this doesn't sound right? it means that we'll end up with incomplete
+ # state.
+ failed_to_fetch = desired_events - fetched_events.keys()
+ if failed_to_fetch:
+ logger.warning(
+ "Failed to fetch missing state events for %s %s",
+ event_id,
+ failed_to_fetch,
+ )
+
+ remote_state = [
+ fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
+ ]
+
+ if remote_event.is_state() and remote_event.rejected_reason is None:
+ remote_state.append(remote_event)
+
+ return remote_state
+
+ async def _process_received_pdu(
+ self,
+ origin: str,
+ event: EventBase,
+ state: Optional[Iterable[EventBase]],
+ backfilled: bool = False,
+ ) -> None:
+ """Called when we have a new pdu. We need to do auth checks and put it
+ through the StateHandler.
+
+ Args:
+ origin: server sending the event
+
+ event: event to be persisted
+
+ state: Normally None, but if we are handling a gap in the graph
+ (ie, we are missing one or more prev_events), the resolved state at the
+ event
+
+ backfilled: True if this is part of a historical batch of events (inhibits
+ notification to clients, and validation of device keys.)
+ """
+ logger.debug("Processing event: %s", event)
+
+ try:
+ context = await self.state_handler.compute_event_context(
+ event, old_state=state
+ )
+ await self._auth_and_persist_event(
+ origin, event, context, state=state, backfilled=backfilled
+ )
+ except AuthError as e:
+ raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
+
+ if backfilled:
+ return
+
+ # For encrypted messages we check that we know about the sending device,
+ # if we don't then we mark the device cache for that user as stale.
+ if event.type == EventTypes.Encrypted:
+ device_id = event.content.get("device_id")
+ sender_key = event.content.get("sender_key")
+
+ cached_devices = await self.store.get_cached_devices_for_user(event.sender)
+
+ resync = False # Whether we should resync device lists.
+
+ device = None
+ if device_id is not None:
+ device = cached_devices.get(device_id)
+ if device is None:
+ logger.info(
+ "Received event from remote device not in our cache: %s %s",
+ event.sender,
+ device_id,
+ )
+ resync = True
+
+ # We also check if the `sender_key` matches what we expect.
+ if sender_key is not None:
+ # Figure out what sender key we're expecting. If we know the
+ # device and recognize the algorithm then we can work out the
+ # exact key to expect. Otherwise check it matches any key we
+ # have for that device.
+
+ current_keys: Container[str] = []
+
+ if device:
+ keys = device.get("keys", {}).get("keys", {})
+
+ if (
+ event.content.get("algorithm")
+ == RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2
+ ):
+ # For this algorithm we expect a curve25519 key.
+ key_name = "curve25519:%s" % (device_id,)
+ current_keys = [keys.get(key_name)]
+ else:
+ # We don't know understand the algorithm, so we just
+ # check it matches a key for the device.
+ current_keys = keys.values()
+ elif device_id:
+ # We don't have any keys for the device ID.
+ pass
+ else:
+ # The event didn't include a device ID, so we just look for
+ # keys across all devices.
+ current_keys = [
+ key
+ for device in cached_devices.values()
+ for key in device.get("keys", {}).get("keys", {}).values()
+ ]
+
+ # We now check that the sender key matches (one of) the expected
+ # keys.
+ if sender_key not in current_keys:
+ logger.info(
+ "Received event from remote device with unexpected sender key: %s %s: %s",
+ event.sender,
+ device_id or "<no device_id>",
+ sender_key,
+ )
+ resync = True
+
+ if resync:
+ run_as_background_process(
+ "resync_device_due_to_pdu",
+ self._resync_device,
+ event.sender,
+ )
+
+ await self._handle_marker_event(origin, event)
+
+ async def _resync_device(self, sender: str) -> None:
+ """We have detected that the device list for the given user may be out
+ of sync, so we try and resync them.
+ """
+
+ try:
+ await self.store.mark_remote_user_device_cache_as_stale(sender)
+
+ # Immediately attempt a resync in the background
+ if self.config.worker_app:
+ await self._user_device_resync(user_id=sender)
+ else:
+ await self._device_list_updater.user_device_resync(sender)
+ except Exception:
+ logger.exception("Failed to resync device for %s", sender)
+
+ async def _handle_marker_event(self, origin: str, marker_event: EventBase):
+ """Handles backfilling the insertion event when we receive a marker
+ event that points to one.
+
+ Args:
+ origin: Origin of the event. Will be called to get the insertion event
+ marker_event: The event to process
+ """
+
+ if marker_event.type != EventTypes.MSC2716_MARKER:
+ # Not a marker event
+ return
+
+ if marker_event.rejected_reason is not None:
+ # Rejected event
+ return
+
+ # Skip processing a marker event if the room version doesn't
+ # support it.
+ room_version = await self.store.get_room_version(marker_event.room_id)
+ if not room_version.msc2716_historical:
+ return
+
+ logger.debug("_handle_marker_event: received %s", marker_event)
+
+ insertion_event_id = marker_event.content.get(
+ EventContentFields.MSC2716_MARKER_INSERTION
+ )
+
+ if insertion_event_id is None:
+ # Nothing to retrieve then (invalid marker)
+ return
+
+ logger.debug(
+ "_handle_marker_event: backfilling insertion event %s", insertion_event_id
+ )
+
+ await self._get_events_and_persist(
+ origin,
+ marker_event.room_id,
+ [insertion_event_id],
+ )
+
+ insertion_event = await self.store.get_event(
+ insertion_event_id, allow_none=True
+ )
+ if insertion_event is None:
+ logger.warning(
+ "_handle_marker_event: server %s didn't return insertion event %s for marker %s",
+ origin,
+ insertion_event_id,
+ marker_event.event_id,
+ )
+ return
+
+ logger.debug(
+ "_handle_marker_event: succesfully backfilled insertion event %s from marker event %s",
+ insertion_event,
+ marker_event,
+ )
+
+ await self.store.insert_insertion_extremity(
+ insertion_event_id, marker_event.room_id
+ )
+
+ logger.debug(
+ "_handle_marker_event: insertion extremity added for %s from marker event %s",
+ insertion_event,
+ marker_event,
+ )
+
+ async def _get_events_and_persist(
+ self, destination: str, room_id: str, events: Iterable[str]
+ ) -> None:
+ """Fetch the given events from a server, and persist them as outliers.
+
+ This function *does not* recursively get missing auth events of the
+ newly fetched events. Callers must include in the `events` argument
+ any missing events from the auth chain.
+
+ Logs a warning if we can't find the given event.
+ """
+
+ room_version = await self.store.get_room_version(room_id)
+
+ event_map: Dict[str, EventBase] = {}
+
+ async def get_event(event_id: str):
+ with nested_logging_context(event_id):
+ try:
+ event = await self.federation_client.get_pdu(
+ [destination],
+ event_id,
+ room_version,
+ outlier=True,
+ )
+ if event is None:
+ logger.warning(
+ "Server %s didn't return event %s",
+ destination,
+ event_id,
+ )
+ return
+
+ event_map[event.event_id] = event
+
+ except Exception as e:
+ logger.warning(
+ "Error fetching missing state/auth event %s: %s %s",
+ event_id,
+ type(e),
+ e,
+ )
+
+ await concurrently_execute(get_event, events, 5)
+
+ # Make a map of auth events for each event. We do this after fetching
+ # all the events as some of the events' auth events will be in the list
+ # of requested events.
+
+ auth_events = [
+ aid
+ for event in event_map.values()
+ for aid in event.auth_event_ids()
+ if aid not in event_map
+ ]
+ persisted_events = await self.store.get_events(
+ auth_events,
+ allow_rejected=True,
+ )
+
+ event_infos = []
+ for event in event_map.values():
+ auth = {}
+ for auth_event_id in event.auth_event_ids():
+ ae = persisted_events.get(auth_event_id) or event_map.get(auth_event_id)
+ if ae:
+ auth[(ae.type, ae.state_key)] = ae
+ else:
+ logger.info("Missing auth event %s", auth_event_id)
+
+ event_infos.append(_NewEventInfo(event, auth))
+
+ if event_infos:
+ await self._auth_and_persist_events(
+ destination,
+ room_id,
+ event_infos,
+ )
+
+ async def _auth_and_persist_events(
+ self,
+ origin: str,
+ room_id: str,
+ event_infos: Collection[_NewEventInfo],
+ ) -> None:
+ """Creates the appropriate contexts and persists events. The events
+ should not depend on one another, e.g. this should be used to persist
+ a bunch of outliers, but not a chunk of individual events that depend
+ on each other for state calculations.
+
+ Notifies about the events where appropriate.
+ """
+
+ if not event_infos:
+ return
+
+ async def prep(ev_info: _NewEventInfo):
+ event = ev_info.event
+ with nested_logging_context(suffix=event.event_id):
+ res = await self.state_handler.compute_event_context(event)
+ res = await self._check_event_auth(
+ origin,
+ event,
+ res,
+ claimed_auth_event_map=ev_info.claimed_auth_event_map,
+ )
+ return res
+
+ contexts = await make_deferred_yieldable(
+ defer.gatherResults(
+ [run_in_background(prep, ev_info) for ev_info in event_infos],
+ consumeErrors=True,
+ )
+ )
+
+ await self.persist_events_and_notify(
+ room_id,
+ [
+ (ev_info.event, context)
+ for ev_info, context in zip(event_infos, contexts)
+ ],
+ )
+
+ async def _auth_and_persist_event(
+ self,
+ origin: str,
+ event: EventBase,
+ context: EventContext,
+ state: Optional[Iterable[EventBase]] = None,
+ claimed_auth_event_map: Optional[StateMap[EventBase]] = None,
+ backfilled: bool = False,
+ ) -> None:
+ """
+ Process an event by performing auth checks and then persisting to the database.
+
+ Args:
+ origin: The host the event originates from.
+ event: The event itself.
+ context:
+ The event context.
+
+ state:
+ The state events used to check the event for soft-fail. If this is
+ not provided the current state events will be used.
+
+ claimed_auth_event_map:
+ A map of (type, state_key) => event for the event's claimed auth_events.
+ Possibly incomplete, and possibly including events that are not yet
+ persisted, or authed, or in the right room.
+
+ Only populated where we may not already have persisted these events -
+ for example, when populating outliers.
+
+ backfilled: True if the event was backfilled.
+ """
+ context = await self._check_event_auth(
+ origin,
+ event,
+ context,
+ state=state,
+ claimed_auth_event_map=claimed_auth_event_map,
+ backfilled=backfilled,
+ )
+
+ await self._run_push_actions_and_persist_event(event, context, backfilled)
+
+ async def _check_event_auth(
+ self,
+ origin: str,
+ event: EventBase,
+ context: EventContext,
+ state: Optional[Iterable[EventBase]] = None,
+ claimed_auth_event_map: Optional[StateMap[EventBase]] = None,
+ backfilled: bool = False,
+ ) -> EventContext:
+ """
+ Checks whether an event should be rejected (for failing auth checks).
+
+ Args:
+ origin: The host the event originates from.
+ event: The event itself.
+ context:
+ The event context.
+
+ state:
+ The state events used to check the event for soft-fail. If this is
+ not provided the current state events will be used.
+
+ claimed_auth_event_map:
+ A map of (type, state_key) => event for the event's claimed auth_events.
+ Possibly incomplete, and possibly including events that are not yet
+ persisted, or authed, or in the right room.
+
+ Only populated where we may not already have persisted these events -
+ for example, when populating outliers, or the state for a backwards
+ extremity.
+
+ backfilled: True if the event was backfilled.
+
+ Returns:
+ The updated context object.
+ """
+ room_version = await self.store.get_room_version_id(event.room_id)
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+
+ if claimed_auth_event_map:
+ # if we have a copy of the auth events from the event, use that as the
+ # basis for auth.
+ auth_events = claimed_auth_event_map
+ else:
+ # otherwise, we calculate what the auth events *should* be, and use that
+ prev_state_ids = await context.get_prev_state_ids()
+ auth_events_ids = self._event_auth_handler.compute_auth_events(
+ event, prev_state_ids, for_verification=True
+ )
+ auth_events_x = await self.store.get_events(auth_events_ids)
+ auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
+
+ try:
+ (
+ context,
+ auth_events_for_auth,
+ ) = await self._update_auth_events_and_context_for_auth(
+ origin, event, context, auth_events
+ )
+ except Exception:
+ # We don't really mind if the above fails, so lets not fail
+ # processing if it does. However, it really shouldn't fail so
+ # let's still log as an exception since we'll still want to fix
+ # any bugs.
+ logger.exception(
+ "Failed to double check auth events for %s with remote. "
+ "Ignoring failure and continuing processing of event.",
+ event.event_id,
+ )
+ auth_events_for_auth = auth_events
+
+ try:
+ event_auth.check(room_version_obj, event, auth_events=auth_events_for_auth)
+ except AuthError as e:
+ logger.warning("Failed auth resolution for %r because %s", event, e)
+ context.rejected = RejectedReason.AUTH_ERROR
+
+ if not context.rejected:
+ await self._check_for_soft_fail(event, state, backfilled, origin=origin)
+
+ if event.type == EventTypes.GuestAccess and not context.rejected:
+ await self.maybe_kick_guest_users(event)
+
+ # If we are going to send this event over federation we precaclculate
+ # the joined hosts.
+ if event.internal_metadata.get_send_on_behalf_of():
+ await self.event_creation_handler.cache_joined_hosts_for_event(
+ event, context
+ )
+
+ return context
+
+ async def _check_for_soft_fail(
+ self,
+ event: EventBase,
+ state: Optional[Iterable[EventBase]],
+ backfilled: bool,
+ origin: str,
+ ) -> None:
+ """Checks if we should soft fail the event; if so, marks the event as
+ such.
+
+ Args:
+ event
+ state: The state at the event if we don't have all the event's prev events
+ backfilled: Whether the event is from backfill
+ origin: The host the event originates from.
+ """
+ # For new (non-backfilled and non-outlier) events we check if the event
+ # passes auth based on the current state. If it doesn't then we
+ # "soft-fail" the event.
+ if backfilled or event.internal_metadata.is_outlier():
+ return
+
+ extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id)
+ extrem_ids = set(extrem_ids_list)
+ prev_event_ids = set(event.prev_event_ids())
+
+ if extrem_ids == prev_event_ids:
+ # If they're the same then the current state is the same as the
+ # state at the event, so no point rechecking auth for soft fail.
+ return
+
+ room_version = await self.store.get_room_version_id(event.room_id)
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+
+ # Calculate the "current state".
+ if state is not None:
+ # If we're explicitly given the state then we won't have all the
+ # prev events, and so we have a gap in the graph. In this case
+ # we want to be a little careful as we might have been down for
+ # a while and have an incorrect view of the current state,
+ # however we still want to do checks as gaps are easy to
+ # maliciously manufacture.
+ #
+ # So we use a "current state" that is actually a state
+ # resolution across the current forward extremities and the
+ # given state at the event. This should correctly handle cases
+ # like bans, especially with state res v2.
+
+ state_sets_d = await self.state_store.get_state_groups(
+ event.room_id, extrem_ids
+ )
+ state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
+ state_sets.append(state)
+ current_states = await self.state_handler.resolve_events(
+ room_version, state_sets, event
+ )
+ current_state_ids: StateMap[str] = {
+ k: e.event_id for k, e in current_states.items()
+ }
+ else:
+ current_state_ids = await self.state_handler.get_current_state_ids(
+ event.room_id, latest_event_ids=extrem_ids
+ )
+
+ logger.debug(
+ "Doing soft-fail check for %s: state %s",
+ event.event_id,
+ current_state_ids,
+ )
+
+ # Now check if event pass auth against said current state
+ auth_types = auth_types_for_event(room_version_obj, event)
+ current_state_ids_list = [
+ e for k, e in current_state_ids.items() if k in auth_types
+ ]
+
+ auth_events_map = await self.store.get_events(current_state_ids_list)
+ current_auth_events = {
+ (e.type, e.state_key): e for e in auth_events_map.values()
+ }
+
+ try:
+ event_auth.check(room_version_obj, event, auth_events=current_auth_events)
+ except AuthError as e:
+ logger.warning(
+ "Soft-failing %r (from %s) because %s",
+ event,
+ e,
+ origin,
+ extra={
+ "room_id": event.room_id,
+ "mxid": event.sender,
+ "hs": origin,
+ },
+ )
+ soft_failed_event_counter.inc()
+ event.internal_metadata.soft_failed = True
+
+ async def _update_auth_events_and_context_for_auth(
+ self,
+ origin: str,
+ event: EventBase,
+ context: EventContext,
+ input_auth_events: StateMap[EventBase],
+ ) -> Tuple[EventContext, StateMap[EventBase]]:
+ """Helper for _check_event_auth. See there for docs.
+
+ Checks whether a given event has the expected auth events. If it
+ doesn't then we talk to the remote server to compare state to see if
+ we can come to a consensus (e.g. if one server missed some valid
+ state).
+
+ This attempts to resolve any potential divergence of state between
+ servers, but is not essential and so failures should not block further
+ processing of the event.
+
+ Args:
+ origin:
+ event:
+ context:
+
+ input_auth_events:
+ Map from (event_type, state_key) to event
+
+ Normally, our calculated auth_events based on the state of the room
+ at the event's position in the DAG, though occasionally (eg if the
+ event is an outlier), may be the auth events claimed by the remote
+ server.
+
+ Returns:
+ updated context, updated auth event map
+ """
+ # take a copy of input_auth_events before we modify it.
+ auth_events: MutableStateMap[EventBase] = dict(input_auth_events)
+
+ event_auth_events = set(event.auth_event_ids())
+
+ # missing_auth is the set of the event's auth_events which we don't yet have
+ # in auth_events.
+ missing_auth = event_auth_events.difference(
+ e.event_id for e in auth_events.values()
+ )
+
+ # if we have missing events, we need to fetch those events from somewhere.
+ #
+ # we start by checking if they are in the store, and then try calling /event_auth/.
+ if missing_auth:
+ have_events = await self.store.have_seen_events(event.room_id, missing_auth)
+ logger.debug("Events %s are in the store", have_events)
+ missing_auth.difference_update(have_events)
+
+ if missing_auth:
+ # If we don't have all the auth events, we need to get them.
+ logger.info("auth_events contains unknown events: %s", missing_auth)
+ try:
+ try:
+ remote_auth_chain = await self.federation_client.get_event_auth(
+ origin, event.room_id, event.event_id
+ )
+ except RequestSendFailed as e1:
+ # The other side isn't around or doesn't implement the
+ # endpoint, so lets just bail out.
+ logger.info("Failed to get event auth from remote: %s", e1)
+ return context, auth_events
+
+ seen_remotes = await self.store.have_seen_events(
+ event.room_id, [e.event_id for e in remote_auth_chain]
+ )
+
+ for e in remote_auth_chain:
+ if e.event_id in seen_remotes:
+ continue
+
+ if e.event_id == event.event_id:
+ continue
+
+ try:
+ auth_ids = e.auth_event_ids()
+ auth = {
+ (e.type, e.state_key): e
+ for e in remote_auth_chain
+ if e.event_id in auth_ids or e.type == EventTypes.Create
+ }
+ e.internal_metadata.outlier = True
+
+ logger.debug(
+ "_check_event_auth %s missing_auth: %s",
+ event.event_id,
+ e.event_id,
+ )
+ missing_auth_event_context = (
+ await self.state_handler.compute_event_context(e)
+ )
+ await self._auth_and_persist_event(
+ origin,
+ e,
+ missing_auth_event_context,
+ claimed_auth_event_map=auth,
+ )
+
+ if e.event_id in event_auth_events:
+ auth_events[(e.type, e.state_key)] = e
+ except AuthError:
+ pass
+
+ except Exception:
+ logger.exception("Failed to get auth chain")
+
+ if event.internal_metadata.is_outlier():
+ # XXX: given that, for an outlier, we'll be working with the
+ # event's *claimed* auth events rather than those we calculated:
+ # (a) is there any point in this test, since different_auth below will
+ # obviously be empty
+ # (b) alternatively, why don't we do it earlier?
+ logger.info("Skipping auth_event fetch for outlier")
+ return context, auth_events
+
+ different_auth = event_auth_events.difference(
+ e.event_id for e in auth_events.values()
+ )
+
+ if not different_auth:
+ return context, auth_events
+
+ logger.info(
+ "auth_events refers to events which are not in our calculated auth "
+ "chain: %s",
+ different_auth,
+ )
+
+ # XXX: currently this checks for redactions but I'm not convinced that is
+ # necessary?
+ different_events = await self.store.get_events_as_list(different_auth)
+
+ for d in different_events:
+ if d.room_id != event.room_id:
+ logger.warning(
+ "Event %s refers to auth_event %s which is in a different room",
+ event.event_id,
+ d.event_id,
+ )
+
+ # don't attempt to resolve the claimed auth events against our own
+ # in this case: just use our own auth events.
+ #
+ # XXX: should we reject the event in this case? It feels like we should,
+ # but then shouldn't we also do so if we've failed to fetch any of the
+ # auth events?
+ return context, auth_events
+
+ # now we state-resolve between our own idea of the auth events, and the remote's
+ # idea of them.
+
+ local_state = auth_events.values()
+ remote_auth_events = dict(auth_events)
+ remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
+ remote_state = remote_auth_events.values()
+
+ room_version = await self.store.get_room_version_id(event.room_id)
+ new_state = await self.state_handler.resolve_events(
+ room_version, (local_state, remote_state), event
+ )
+
+ logger.info(
+ "After state res: updating auth_events with new state %s",
+ {
+ (d.type, d.state_key): d.event_id
+ for d in new_state.values()
+ if auth_events.get((d.type, d.state_key)) != d
+ },
+ )
+
+ auth_events.update(new_state)
+
+ context = await self._update_context_for_auth_events(
+ event, context, auth_events
+ )
+
+ return context, auth_events
+
+ async def _update_context_for_auth_events(
+ self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
+ ) -> EventContext:
+ """Update the state_ids in an event context after auth event resolution,
+ storing the changes as a new state group.
+
+ Args:
+ event: The event we're handling the context for
+
+ context: initial event context
+
+ auth_events: Events to update in the event context.
+
+ Returns:
+ new event context
+ """
+ # exclude the state key of the new event from the current_state in the context.
+ if event.is_state():
+ event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
+ else:
+ event_key = None
+ state_updates = {
+ k: a.event_id for k, a in auth_events.items() if k != event_key
+ }
+
+ current_state_ids = await context.get_current_state_ids()
+ current_state_ids = dict(current_state_ids) # type: ignore
+
+ current_state_ids.update(state_updates)
+
+ prev_state_ids = await context.get_prev_state_ids()
+ prev_state_ids = dict(prev_state_ids)
+
+ prev_state_ids.update({k: a.event_id for k, a in auth_events.items()})
+
+ # create a new state group as a delta from the existing one.
+ prev_group = context.state_group
+ state_group = await self.state_store.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=prev_group,
+ delta_ids=state_updates,
+ current_state_ids=current_state_ids,
+ )
+
+ return EventContext.with_state(
+ state_group=state_group,
+ state_group_before_event=context.state_group_before_event,
+ current_state_ids=current_state_ids,
+ prev_state_ids=prev_state_ids,
+ prev_group=prev_group,
+ delta_ids=state_updates,
+ )
+
+ async def _run_push_actions_and_persist_event(
+ self, event: EventBase, context: EventContext, backfilled: bool = False
+ ):
+ """Run the push actions for a received event, and persist it.
+
+ Args:
+ event: The event itself.
+ context: The event context.
+ backfilled: True if the event was backfilled.
+ """
+ try:
+ if (
+ not event.internal_metadata.is_outlier()
+ and not backfilled
+ and not context.rejected
+ and (await self.store.get_min_depth(event.room_id)) <= event.depth
+ ):
+ await self.action_generator.handle_push_actions_for_event(
+ event, context
+ )
+
+ await self.persist_events_and_notify(
+ event.room_id, [(event, context)], backfilled=backfilled
+ )
+ except Exception:
+ run_in_background(
+ self.store.remove_push_actions_from_staging, event.event_id
+ )
+ raise
+
+ async def persist_events_and_notify(
+ self,
+ room_id: str,
+ event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
+ backfilled: bool = False,
+ ) -> int:
+ """Persists events and tells the notifier/pushers about them, if
+ necessary.
+
+ Args:
+ room_id: The room ID of events being persisted.
+ event_and_contexts: Sequence of events with their associated
+ context that should be persisted. All events must belong to
+ the same room.
+ backfilled: Whether these events are a result of
+ backfilling or not
+
+ Returns:
+ The stream ID after which all events have been persisted.
+ """
+ if not event_and_contexts:
+ return self.store.get_current_events_token()
+
+ instance = self.config.worker.events_shard_config.get_instance(room_id)
+ if instance != self._instance_name:
+ # Limit the number of events sent over replication. We choose 200
+ # here as that is what we default to in `max_request_body_size(..)`
+ for batch in batch_iter(event_and_contexts, 200):
+ result = await self._send_events(
+ instance_name=instance,
+ store=self.store,
+ room_id=room_id,
+ event_and_contexts=batch,
+ backfilled=backfilled,
+ )
+ return result["max_stream_id"]
+ else:
+ assert self.storage.persistence
+
+ # Note that this returns the events that were persisted, which may not be
+ # the same as were passed in if some were deduplicated due to transaction IDs.
+ events, max_stream_token = await self.storage.persistence.persist_events(
+ event_and_contexts, backfilled=backfilled
+ )
+
+ if self._ephemeral_messages_enabled:
+ for event in events:
+ # If there's an expiry timestamp on the event, schedule its expiry.
+ self._message_handler.maybe_schedule_expiry(event)
+
+ if not backfilled: # Never notify for backfilled events
+ for event in events:
+ await self._notify_persisted_event(event, max_stream_token)
+
+ return max_stream_token.stream
+
+ async def _notify_persisted_event(
+ self, event: EventBase, max_stream_token: RoomStreamToken
+ ) -> None:
+ """Checks to see if notifier/pushers should be notified about the
+ event or not.
+
+ Args:
+ event:
+ max_stream_token: The max_stream_id returned by persist_events
+ """
+
+ extra_users = []
+ if event.type == EventTypes.Member:
+ target_user_id = event.state_key
+
+ # We notify for memberships if its an invite for one of our
+ # users
+ if event.internal_metadata.is_outlier():
+ if event.membership != Membership.INVITE:
+ if not self.is_mine_id(target_user_id):
+ return
+
+ target_user = UserID.from_string(target_user_id)
+ extra_users.append(target_user)
+ elif event.internal_metadata.is_outlier():
+ return
+
+ # the event has been persisted so it should have a stream ordering.
+ assert event.internal_metadata.stream_ordering
+
+ event_pos = PersistedEventPosition(
+ self._instance_name, event.internal_metadata.stream_ordering
+ )
+ self.notifier.on_new_room_event(
+ event, event_pos, max_stream_token, extra_users=extra_users
+ )
+
+ def _sanity_check_event(self, ev: EventBase) -> None:
+ """
+ Do some early sanity checks of a received event
+
+ In particular, checks it doesn't have an excessive number of
+ prev_events or auth_events, which could cause a huge state resolution
+ or cascade of event fetches.
+
+ Args:
+ ev: event to be checked
+
+ Raises:
+ SynapseError if the event does not pass muster
+ """
+ if len(ev.prev_event_ids()) > 20:
+ logger.warning(
+ "Rejecting event %s which has %i prev_events",
+ ev.event_id,
+ len(ev.prev_event_ids()),
+ )
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many prev_events")
+
+ if len(ev.auth_event_ids()) > 10:
+ logger.warning(
+ "Rejecting event %s which has %i auth_events",
+ ev.event_id,
+ len(ev.auth_event_ids()),
+ )
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
+
+ async def get_min_depth_for_context(self, context: str) -> int:
+ return await self.store.get_min_depth(context)
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index e1c544a3..4e8f7f1d 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -151,7 +151,7 @@ class InitialSyncHandler(BaseHandler):
limit = 10
async def handle_room(event: RoomsForUser):
- d = {
+ d: JsonDict = {
"room_id": event.room_id,
"membership": event.membership,
"visibility": (
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 7ca14e1d..4418d63d 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -353,6 +353,11 @@ class BasePresenceHandler(abc.ABC):
# otherwise would not do).
await self.set_state(UserID.from_string(user_id), state, force_notify=True)
+ async def is_visible(self, observed_user: UserID, observer_user: UserID) -> bool:
+ raise NotImplementedError(
+ "Attempting to check presence on a non-presence worker."
+ )
+
class _NullContextManager(ContextManager[None]):
"""A context manager which does nothing."""
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 8cf61413..0ed59d75 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -56,6 +56,22 @@ login_counter = Counter(
)
+def init_counters_for_auth_provider(auth_provider_id: str) -> None:
+ """Ensure the prometheus counters for the given auth provider are initialised
+
+ This fixes a problem where the counters are not reported for a given auth provider
+ until the user first logs in/registers.
+ """
+ for is_guest in (True, False):
+ login_counter.labels(guest=is_guest, auth_provider=auth_provider_id)
+ for shadow_banned in (True, False):
+ registration_counter.labels(
+ guest=is_guest,
+ shadow_banned=shadow_banned,
+ auth_provider=auth_provider_id,
+ )
+
+
class LoginDict(TypedDict):
device_id: str
access_token: str
@@ -96,6 +112,8 @@ class RegistrationHandler(BaseHandler):
self.session_lifetime = hs.config.session_lifetime
self.access_token_lifetime = hs.config.access_token_lifetime
+ init_counters_for_auth_provider("")
+
async def check_username(
self,
localpart: str,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index ba131962..401b84aa 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -36,6 +36,7 @@ from synapse.api.ratelimiting import Ratelimiter
from synapse.event_auth import get_named_level, get_power_level_event
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
+from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.types import (
JsonDict,
Requester,
@@ -79,7 +80,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.account_data_handler = hs.get_account_data_handler()
self.event_auth_handler = hs.get_event_auth_handler()
- self.member_linearizer = Linearizer(name="member")
+ self.member_linearizer: Linearizer = Linearizer(name="member")
self.clock = hs.get_clock()
self.spam_checker = hs.get_spam_checker()
@@ -556,6 +557,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
content.pop("displayname", None)
content.pop("avatar_url", None)
+ if len(content.get("displayname") or "") > MAX_DISPLAYNAME_LEN:
+ raise SynapseError(
+ 400,
+ f"Displayname is too long (max {MAX_DISPLAYNAME_LEN})",
+ errcode=Codes.BAD_JSON,
+ )
+
+ if len(content.get("avatar_url") or "") > MAX_AVATAR_URL_LEN:
+ raise SynapseError(
+ 400,
+ f"Avatar URL is too long (max {MAX_AVATAR_URL_LEN})",
+ errcode=Codes.BAD_JSON,
+ )
+
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index ac6cfc0d..906985c7 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -28,12 +28,11 @@ from synapse.api.constants import (
Membership,
RoomTypes,
)
-from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
+from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
from synapse.events import EventBase
from synapse.events.utils import format_event_for_client_v2
from synapse.types import JsonDict
from synapse.util.caches.response_cache import ResponseCache
-from synapse.util.stringutils import random_string
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -76,6 +75,9 @@ class _PaginationSession:
class RoomSummaryHandler:
+ # A unique key used for pagination sessions for the room hierarchy endpoint.
+ _PAGINATION_SESSION_TYPE = "room_hierarchy_pagination"
+
# The time a pagination session remains valid for.
_PAGINATION_SESSION_VALIDITY_PERIOD_MS = 5 * 60 * 1000
@@ -87,12 +89,6 @@ class RoomSummaryHandler:
self._server_name = hs.hostname
self._federation_client = hs.get_federation_client()
- # A map of query information to the current pagination state.
- #
- # TODO Allow for multiple workers to share this data.
- # TODO Expire pagination tokens.
- self._pagination_sessions: Dict[_PaginationKey, _PaginationSession] = {}
-
# If a user tries to fetch the same page multiple times in quick succession,
# only process the first attempt and return its result to subsequent requests.
self._pagination_response_cache: ResponseCache[
@@ -102,21 +98,6 @@ class RoomSummaryHandler:
"get_room_hierarchy",
)
- def _expire_pagination_sessions(self):
- """Expire pagination session which are old."""
- expire_before = (
- self._clock.time_msec() - self._PAGINATION_SESSION_VALIDITY_PERIOD_MS
- )
- to_expire = []
-
- for key, value in self._pagination_sessions.items():
- if value.creation_time_ms < expire_before:
- to_expire.append(key)
-
- for key in to_expire:
- logger.debug("Expiring pagination session id %s", key)
- del self._pagination_sessions[key]
-
async def get_space_summary(
self,
requester: str,
@@ -327,18 +308,29 @@ class RoomSummaryHandler:
# If this is continuing a previous session, pull the persisted data.
if from_token:
- self._expire_pagination_sessions()
+ try:
+ pagination_session = await self._store.get_session(
+ session_type=self._PAGINATION_SESSION_TYPE,
+ session_id=from_token,
+ )
+ except StoreError:
+ raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM)
- pagination_key = _PaginationKey(
- requested_room_id, suggested_only, max_depth, from_token
- )
- if pagination_key not in self._pagination_sessions:
+ # If the requester, room ID, suggested-only, or max depth were modified
+ # the session is invalid.
+ if (
+ requester != pagination_session["requester"]
+ or requested_room_id != pagination_session["room_id"]
+ or suggested_only != pagination_session["suggested_only"]
+ or max_depth != pagination_session["max_depth"]
+ ):
raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM)
# Load the previous state.
- pagination_session = self._pagination_sessions[pagination_key]
- room_queue = pagination_session.room_queue
- processed_rooms = pagination_session.processed_rooms
+ room_queue = [
+ _RoomQueueEntry(*fields) for fields in pagination_session["room_queue"]
+ ]
+ processed_rooms = set(pagination_session["processed_rooms"])
else:
# The queue of rooms to process, the next room is last on the stack.
room_queue = [_RoomQueueEntry(requested_room_id, ())]
@@ -456,13 +448,21 @@ class RoomSummaryHandler:
# If there's additional data, generate a pagination token (and persist state).
if room_queue:
- next_batch = random_string(24)
- result["next_batch"] = next_batch
- pagination_key = _PaginationKey(
- requested_room_id, suggested_only, max_depth, next_batch
- )
- self._pagination_sessions[pagination_key] = _PaginationSession(
- self._clock.time_msec(), room_queue, processed_rooms
+ result["next_batch"] = await self._store.create_session(
+ session_type=self._PAGINATION_SESSION_TYPE,
+ value={
+ # Information which must be identical across pagination.
+ "requester": requester,
+ "room_id": requested_room_id,
+ "suggested_only": suggested_only,
+ "max_depth": max_depth,
+ # The stored state.
+ "room_queue": [
+ attr.astuple(room_entry) for room_entry in room_queue
+ ],
+ "processed_rooms": list(processed_rooms),
+ },
+ expiry_ms=self._PAGINATION_SESSION_VALIDITY_PERIOD_MS,
)
return result
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 1b855a68..0e6ebb57 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -37,6 +37,7 @@ from twisted.web.server import Request
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.config.sso import SsoAttributeRequirement
+from synapse.handlers.register import init_counters_for_auth_provider
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
from synapse.http.server import respond_with_html, respond_with_redirect
@@ -213,6 +214,7 @@ class SsoHandler:
p_id = p.idp_id
assert p_id not in self._identity_providers
self._identity_providers[p_id] = p
+ init_counters_for_auth_provider(p_id)
def get_identity_providers(self) -> Mapping[str, SsoIdentityProvider]:
"""Get the configured identity providers"""
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 590642f5..86c3c7f0 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1,5 +1,4 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-# Copyright 2018, 2019 New Vector Ltd
+# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,6 +30,8 @@ from prometheus_client import Counter
from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.api.filtering import FilterCollection
+from synapse.api.presence import UserPresenceState
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.logging.context import current_context
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
@@ -86,20 +87,20 @@ LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100
SyncRequestKey = Tuple[Any, ...]
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class SyncConfig:
- user = attr.ib(type=UserID)
- filter_collection = attr.ib(type=FilterCollection)
- is_guest = attr.ib(type=bool)
- request_key = attr.ib(type=SyncRequestKey)
- device_id = attr.ib(type=Optional[str])
+ user: UserID
+ filter_collection: FilterCollection
+ is_guest: bool
+ request_key: SyncRequestKey
+ device_id: Optional[str]
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class TimelineBatch:
- prev_batch = attr.ib(type=StreamToken)
- events = attr.ib(type=List[EventBase])
- limited = attr.ib(type=bool)
+ prev_batch: StreamToken
+ events: List[EventBase]
+ limited: bool
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -113,16 +114,16 @@ class TimelineBatch:
# if there are updates for it, which we check after the instance has been created.
# This should not be a big deal because we update the notification counts afterwards as
# well anyway.
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class JoinedSyncResult:
- room_id = attr.ib(type=str)
- timeline = attr.ib(type=TimelineBatch)
- state = attr.ib(type=StateMap[EventBase])
- ephemeral = attr.ib(type=List[JsonDict])
- account_data = attr.ib(type=List[JsonDict])
- unread_notifications = attr.ib(type=JsonDict)
- summary = attr.ib(type=Optional[JsonDict])
- unread_count = attr.ib(type=int)
+ room_id: str
+ timeline: TimelineBatch
+ state: StateMap[EventBase]
+ ephemeral: List[JsonDict]
+ account_data: List[JsonDict]
+ unread_notifications: JsonDict
+ summary: Optional[JsonDict]
+ unread_count: int
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -138,12 +139,12 @@ class JoinedSyncResult:
)
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class ArchivedSyncResult:
- room_id = attr.ib(type=str)
- timeline = attr.ib(type=TimelineBatch)
- state = attr.ib(type=StateMap[EventBase])
- account_data = attr.ib(type=List[JsonDict])
+ room_id: str
+ timeline: TimelineBatch
+ state: StateMap[EventBase]
+ account_data: List[JsonDict]
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -152,37 +153,37 @@ class ArchivedSyncResult:
return bool(self.timeline or self.state or self.account_data)
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class InvitedSyncResult:
- room_id = attr.ib(type=str)
- invite = attr.ib(type=EventBase)
+ room_id: str
+ invite: EventBase
def __bool__(self) -> bool:
"""Invited rooms should always be reported to the client"""
return True
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class KnockedSyncResult:
- room_id = attr.ib(type=str)
- knock = attr.ib(type=EventBase)
+ room_id: str
+ knock: EventBase
def __bool__(self) -> bool:
"""Knocked rooms should always be reported to the client"""
return True
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class GroupsSyncResult:
- join = attr.ib(type=JsonDict)
- invite = attr.ib(type=JsonDict)
- leave = attr.ib(type=JsonDict)
+ join: JsonDict
+ invite: JsonDict
+ leave: JsonDict
def __bool__(self) -> bool:
return bool(self.join or self.invite or self.leave)
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceLists:
"""
Attributes:
@@ -190,27 +191,27 @@ class DeviceLists:
left: List of user_ids whose devices we no longer track
"""
- changed = attr.ib(type=Collection[str])
- left = attr.ib(type=Collection[str])
+ changed: Collection[str]
+ left: Collection[str]
def __bool__(self) -> bool:
return bool(self.changed or self.left)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class _RoomChanges:
"""The set of room entries to include in the sync, plus the set of joined
and left room IDs since last sync.
"""
- room_entries = attr.ib(type=List["RoomSyncResultBuilder"])
- invited = attr.ib(type=List[InvitedSyncResult])
- knocked = attr.ib(type=List[KnockedSyncResult])
- newly_joined_rooms = attr.ib(type=List[str])
- newly_left_rooms = attr.ib(type=List[str])
+ room_entries: List["RoomSyncResultBuilder"]
+ invited: List[InvitedSyncResult]
+ knocked: List[KnockedSyncResult]
+ newly_joined_rooms: List[str]
+ newly_left_rooms: List[str]
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class SyncResult:
"""
Attributes:
@@ -230,18 +231,18 @@ class SyncResult:
groups: Group updates, if any
"""
- next_batch = attr.ib(type=StreamToken)
- presence = attr.ib(type=List[JsonDict])
- account_data = attr.ib(type=List[JsonDict])
- joined = attr.ib(type=List[JoinedSyncResult])
- invited = attr.ib(type=List[InvitedSyncResult])
- knocked = attr.ib(type=List[KnockedSyncResult])
- archived = attr.ib(type=List[ArchivedSyncResult])
- to_device = attr.ib(type=List[JsonDict])
- device_lists = attr.ib(type=DeviceLists)
- device_one_time_keys_count = attr.ib(type=JsonDict)
- device_unused_fallback_key_types = attr.ib(type=List[str])
- groups = attr.ib(type=Optional[GroupsSyncResult])
+ next_batch: StreamToken
+ presence: List[UserPresenceState]
+ account_data: List[JsonDict]
+ joined: List[JoinedSyncResult]
+ invited: List[InvitedSyncResult]
+ knocked: List[KnockedSyncResult]
+ archived: List[ArchivedSyncResult]
+ to_device: List[JsonDict]
+ device_lists: DeviceLists
+ device_one_time_keys_count: JsonDict
+ device_unused_fallback_key_types: List[str]
+ groups: Optional[GroupsSyncResult]
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -701,7 +702,7 @@ class SyncHandler:
name_id = state_ids.get((EventTypes.Name, ""))
canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, ""))
- summary = {}
+ summary: JsonDict = {}
empty_ms = MemberSummary([], 0)
# TODO: only send these when they change.
@@ -1843,6 +1844,9 @@ class SyncHandler:
knocked = []
for event in room_list:
+ if event.room_version_id not in KNOWN_ROOM_VERSIONS:
+ continue
+
if event.membership == Membership.JOIN:
room_entries.append(
RoomSyncResultBuilder(
@@ -2076,21 +2080,23 @@ class SyncHandler:
# If the membership's stream ordering is after the given stream
# ordering, we need to go and work out if the user was in the room
# before.
- for room_id, event_pos in joined_rooms:
- if not event_pos.persisted_after(room_key):
- joined_room_ids.add(room_id)
+ for joined_room in joined_rooms:
+ if not joined_room.event_pos.persisted_after(room_key):
+ joined_room_ids.add(joined_room.room_id)
continue
- logger.info("User joined room after current token: %s", room_id)
+ logger.info("User joined room after current token: %s", joined_room.room_id)
extrems = (
await self.store.get_forward_extremities_for_room_at_stream_ordering(
- room_id, event_pos.stream
+ joined_room.room_id, joined_room.event_pos.stream
)
)
- users_in_room = await self.state.get_current_users_in_room(room_id, extrems)
+ users_in_room = await self.state.get_current_users_in_room(
+ joined_room.room_id, extrems
+ )
if user_id in users_in_room:
- joined_room_ids.add(room_id)
+ joined_room_ids.add(joined_room.room_id)
return frozenset(joined_room_ids)
@@ -2160,7 +2166,7 @@ def _calculate_state(
return {event_id_to_key[e]: e for e in state_ids}
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class SyncResultBuilder:
"""Used to help build up a new SyncResult for a user
@@ -2172,33 +2178,33 @@ class SyncResultBuilder:
joined_room_ids: List of rooms the user is joined to
# The following mirror the fields in a sync response
- presence (list)
- account_data (list)
- joined (list[JoinedSyncResult])
- invited (list[InvitedSyncResult])
- knocked (list[KnockedSyncResult])
- archived (list[ArchivedSyncResult])
- groups (GroupsSyncResult|None)
- to_device (list)
+ presence
+ account_data
+ joined
+ invited
+ knocked
+ archived
+ groups
+ to_device
"""
- sync_config = attr.ib(type=SyncConfig)
- full_state = attr.ib(type=bool)
- since_token = attr.ib(type=Optional[StreamToken])
- now_token = attr.ib(type=StreamToken)
- joined_room_ids = attr.ib(type=FrozenSet[str])
+ sync_config: SyncConfig
+ full_state: bool
+ since_token: Optional[StreamToken]
+ now_token: StreamToken
+ joined_room_ids: FrozenSet[str]
- presence = attr.ib(type=List[JsonDict], default=attr.Factory(list))
- account_data = attr.ib(type=List[JsonDict], default=attr.Factory(list))
- joined = attr.ib(type=List[JoinedSyncResult], default=attr.Factory(list))
- invited = attr.ib(type=List[InvitedSyncResult], default=attr.Factory(list))
- knocked = attr.ib(type=List[KnockedSyncResult], default=attr.Factory(list))
- archived = attr.ib(type=List[ArchivedSyncResult], default=attr.Factory(list))
- groups = attr.ib(type=Optional[GroupsSyncResult], default=None)
- to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list))
+ presence: List[UserPresenceState] = attr.Factory(list)
+ account_data: List[JsonDict] = attr.Factory(list)
+ joined: List[JoinedSyncResult] = attr.Factory(list)
+ invited: List[InvitedSyncResult] = attr.Factory(list)
+ knocked: List[KnockedSyncResult] = attr.Factory(list)
+ archived: List[ArchivedSyncResult] = attr.Factory(list)
+ groups: Optional[GroupsSyncResult] = None
+ to_device: List[JsonDict] = attr.Factory(list)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class RoomSyncResultBuilder:
"""Stores information needed to create either a `JoinedSyncResult` or
`ArchivedSyncResult`.
@@ -2214,10 +2220,10 @@ class RoomSyncResultBuilder:
upto_token: Latest point to return events from.
"""
- room_id = attr.ib(type=str)
- rtype = attr.ib(type=str)
- events = attr.ib(type=Optional[List[EventBase]])
- newly_joined = attr.ib(type=bool)
- full_state = attr.ib(type=bool)
- since_token = attr.ib(type=Optional[StreamToken])
- upto_token = attr.ib(type=StreamToken)
+ room_id: str
+ rtype: str
+ events: Optional[List[EventBase]]
+ newly_joined: bool
+ full_state: bool
+ since_token: Optional[StreamToken]
+ upto_token: StreamToken
diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py
index 4c3b669f..13b0c61d 100644
--- a/synapse/handlers/ui_auth/__init__.py
+++ b/synapse/handlers/ui_auth/__init__.py
@@ -34,3 +34,8 @@ class UIAuthSessionDataConstants:
# used by validate_user_via_ui_auth to store the mxid of the user we are validating
# for.
REQUEST_USER_ID = "request_user_id"
+
+ # used during registration to store the registration token used (if required) so that:
+ # - we can prevent a token being used twice by one session
+ # - we can 'use up' the token after registration has successfully completed
+ REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 5414ce77..d3828dec 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -49,7 +49,7 @@ class UserInteractiveAuthChecker:
clientip: The IP address of the client.
Raises:
- SynapseError if authentication failed
+ LoginError if authentication failed.
Returns:
The result of authentication (to pass back to the client?)
@@ -131,7 +131,9 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
)
if resp_body["success"]:
return True
- raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
+ raise LoginError(
+ 401, "Captcha authentication failed", errcode=Codes.UNAUTHORIZED
+ )
class _BaseThreepidAuthChecker:
@@ -191,7 +193,9 @@ class _BaseThreepidAuthChecker:
raise AssertionError("Unrecognized threepid medium: %s" % (medium,))
if not threepid:
- raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
+ raise LoginError(
+ 401, "Unable to get validated threepid", errcode=Codes.UNAUTHORIZED
+ )
if threepid["medium"] != medium:
raise LoginError(
@@ -237,11 +241,76 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
return await self._check_threepid("msisdn", authdict)
+class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
+ AUTH_TYPE = LoginType.REGISTRATION_TOKEN
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+ self.hs = hs
+ self._enabled = bool(hs.config.registration_requires_token)
+ self.store = hs.get_datastore()
+
+ def is_enabled(self) -> bool:
+ return self._enabled
+
+ async def check_auth(self, authdict: dict, clientip: str) -> Any:
+ if "token" not in authdict:
+ raise LoginError(400, "Missing registration token", Codes.MISSING_PARAM)
+ if not isinstance(authdict["token"], str):
+ raise LoginError(
+ 400, "Registration token must be a string", Codes.INVALID_PARAM
+ )
+ if "session" not in authdict:
+ raise LoginError(400, "Missing UIA session", Codes.MISSING_PARAM)
+
+ # Get these here to avoid cyclic dependencies
+ from synapse.handlers.ui_auth import UIAuthSessionDataConstants
+
+ auth_handler = self.hs.get_auth_handler()
+
+ session = authdict["session"]
+ token = authdict["token"]
+
+ # If the LoginType.REGISTRATION_TOKEN stage has already been completed,
+ # return early to avoid incrementing `pending` again.
+ stored_token = await auth_handler.get_session_data(
+ session, UIAuthSessionDataConstants.REGISTRATION_TOKEN
+ )
+ if stored_token:
+ if token != stored_token:
+ raise LoginError(
+ 400, "Registration token has changed", Codes.INVALID_PARAM
+ )
+ else:
+ return token
+
+ if await self.store.registration_token_is_valid(token):
+ # Increment pending counter, so that if token has limited uses it
+ # can't be used up by someone else in the meantime.
+ await self.store.set_registration_token_pending(token)
+ # Store the token in the UIA session, so that once registration
+ # is complete `completed` can be incremented.
+ await auth_handler.set_session_data(
+ session,
+ UIAuthSessionDataConstants.REGISTRATION_TOKEN,
+ token,
+ )
+ # The token will be stored as the result of the authentication stage
+ # in ui_auth_sessions_credentials. This allows the pending counter
+ # for tokens to be decremented when expired sessions are deleted.
+ return token
+ else:
+ raise LoginError(
+ 401, "Invalid registration token", errcode=Codes.UNAUTHORIZED
+ )
+
+
INTERACTIVE_AUTH_CHECKERS = [
DummyAuthChecker,
TermsAuthChecker,
RecaptchaAuthChecker,
EmailIdentityAuthChecker,
MsisdnAuthChecker,
+ RegistrationTokenAuthChecker,
]
"""A list of UserInteractiveAuthChecker classes"""
diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py
index 55ea97a0..9a2684ac 100644
--- a/synapse/http/additional_resource.py
+++ b/synapse/http/additional_resource.py
@@ -12,8 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
+
+from twisted.web.server import Request
+
from synapse.http.server import DirectServeJsonResource
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
class AdditionalResource(DirectServeJsonResource):
"""Resource wrapper for additional_resources
@@ -25,7 +32,7 @@ class AdditionalResource(DirectServeJsonResource):
and exception handling.
"""
- def __init__(self, hs, handler):
+ def __init__(self, hs: "HomeServer", handler):
"""Initialise AdditionalResource
The ``handler`` should return a deferred which completes when it has
@@ -33,14 +40,14 @@ class AdditionalResource(DirectServeJsonResource):
``request.write()``, and call ``request.finish()``.
Args:
- hs (synapse.server.HomeServer): homeserver
+ hs: homeserver
handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
function to be called to handle the request.
"""
super().__init__()
self._handler = handler
- def _async_render(self, request):
+ def _async_render(self, request: Request):
# Cheekily pass the result straight through, so we don't need to worry
# if its an awaitable or not.
return self._handler(request)
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index b8ed4ec9..f68646fd 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -16,7 +16,7 @@
import logging
import random
import time
-from typing import List
+from typing import Callable, Dict, List
import attr
@@ -28,35 +28,35 @@ from synapse.logging.context import make_deferred_yieldable
logger = logging.getLogger(__name__)
-SERVER_CACHE = {}
+SERVER_CACHE: Dict[bytes, List["Server"]] = {}
-@attr.s(slots=True, frozen=True)
+@attr.s(auto_attribs=True, slots=True, frozen=True)
class Server:
"""
Our record of an individual server which can be tried to reach a destination.
Attributes:
- host (bytes): target hostname
- port (int):
- priority (int):
- weight (int):
- expires (int): when the cache should expire this record - in *seconds* since
+ host: target hostname
+ port:
+ priority:
+ weight:
+ expires: when the cache should expire this record - in *seconds* since
the epoch
"""
- host = attr.ib()
- port = attr.ib()
- priority = attr.ib(default=0)
- weight = attr.ib(default=0)
- expires = attr.ib(default=0)
+ host: bytes
+ port: int
+ priority: int = 0
+ weight: int = 0
+ expires: int = 0
-def _sort_server_list(server_list):
+def _sort_server_list(server_list: List[Server]) -> List[Server]:
"""Given a list of SRV records sort them into priority order and shuffle
each priority with the given weight.
"""
- priority_map = {}
+ priority_map: Dict[int, List[Server]] = {}
for server in server_list:
priority_map.setdefault(server.priority, []).append(server)
@@ -103,11 +103,16 @@ class SrvResolver:
Args:
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
- cache (dict): cache object
- get_time (callable): clock implementation. Should return seconds since the epoch
+ cache: cache object
+ get_time: clock implementation. Should return seconds since the epoch
"""
- def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
+ def __init__(
+ self,
+ dns_client=client,
+ cache: Dict[bytes, List[Server]] = SERVER_CACHE,
+ get_time: Callable[[], float] = time.time,
+ ):
self._dns_client = dns_client
self._cache = cache
self._get_time = get_time
@@ -116,7 +121,7 @@ class SrvResolver:
"""Look up a SRV record
Args:
- service_name (bytes): record to look up
+ service_name: record to look up
Returns:
a list of the SRV records, or an empty list if none found
@@ -158,7 +163,7 @@ class SrvResolver:
and answers[0].payload
and answers[0].payload.target == dns.Name(b".")
):
- raise ConnectError("Service %s unavailable" % service_name)
+ raise ConnectError(f"Service {service_name!r} unavailable")
servers = []
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index a3f31452..6fd88bde 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -173,7 +173,7 @@ class ProxyAgent(_AgentBase):
raise ValueError(f"Invalid URI {uri!r}")
parsed_uri = URI.fromBytes(uri)
- pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
+ pool_key = f"{parsed_uri.scheme!r}{parsed_uri.host!r}{parsed_uri.port}"
request_path = parsed_uri.originForm
should_skip_proxy = False
@@ -199,7 +199,7 @@ class ProxyAgent(_AgentBase):
)
# Cache *all* connections under the same key, since we are only
# connecting to a single destination, the proxy:
- pool_key = ("http-proxy", self.http_proxy_endpoint)
+ pool_key = "http-proxy"
endpoint = self.http_proxy_endpoint
request_path = uri
elif (
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 2d2ed229..b11fa639 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -32,6 +32,7 @@ from twisted.internet import defer
from twisted.web.resource import IResource
from synapse.events import EventBase
+from synapse.events.presence_router import PresenceRouter
from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
DirectServeHtmlResource,
@@ -57,6 +58,8 @@ This package defines the 'stable' API which can be used by extension modules whi
are loaded into Synapse.
"""
+PRESENCE_ALL_USERS = PresenceRouter.ALL_USERS
+
__all__ = [
"errors",
"make_deferred_yieldable",
@@ -70,6 +73,7 @@ __all__ = [
"DirectServeHtmlResource",
"DirectServeJsonResource",
"ModuleApi",
+ "PRESENCE_ALL_USERS",
]
logger = logging.getLogger(__name__)
@@ -112,6 +116,7 @@ class ModuleApi:
self._spam_checker = hs.get_spam_checker()
self._account_validity_handler = hs.get_account_validity_handler()
self._third_party_event_rules = hs.get_third_party_event_rules()
+ self._presence_router = hs.get_presence_router()
#################################################################################
# The following methods should only be called during the module's initialisation.
@@ -131,6 +136,11 @@ class ModuleApi:
"""Registers callbacks for third party event rules capabilities."""
return self._third_party_event_rules.register_third_party_rules_callbacks
+ @property
+ def register_presence_router_callbacks(self):
+ """Registers callbacks for presence router capabilities."""
+ return self._presence_router.register_presence_router_callbacks
+
def register_web_resource(self, path: str, resource: IResource):
"""Registers a web resource to be served at the given path.
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index cdcbdd77..154e5b70 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -48,7 +48,8 @@ logger = logging.getLogger(__name__)
# [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers.
REQUIREMENTS = [
- "jsonschema>=2.5.1",
+ # we use the TYPE_CHECKER.redefine method added in jsonschema 3.0.0
+ "jsonschema>=3.0.0",
"frozendict>=1",
"unpaddedbase64>=1.1.0",
"canonicaljson>=1.4.0",
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 79cadb7b..a0b3145f 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -62,7 +62,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.clock = hs.get_clock()
- self.federation_handler = hs.get_federation_handler()
+ self.federation_event_handler = hs.get_federation_event_handler()
@staticmethod
async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
@@ -127,7 +127,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
logger.info("Got %d events from federation", len(event_and_contexts))
- max_stream_id = await self.federation_handler.persist_events_and_notify(
+ max_stream_id = await self.federation_event_handler.persist_events_and_notify(
room_id, event_and_contexts, backfilled
)
diff --git a/synapse/res/templates/recaptcha.html b/synapse/res/templates/recaptcha.html
index 63944dc6..b3db06ef 100644
--- a/synapse/res/templates/recaptcha.html
+++ b/synapse/res/templates/recaptcha.html
@@ -16,6 +16,9 @@ function captchaDone() {
<body>
<form id="registrationForm" method="post" action="{{ myurl }}">
<div>
+ {% if error is defined %}
+ <p class="error"><strong>Error: {{ error }}</strong></p>
+ {% endif %}
<p>
Hello! We need to prevent computer programs and other automated
things from creating accounts on this server.
diff --git a/synapse/res/templates/registration_token.html b/synapse/res/templates/registration_token.html
new file mode 100644
index 00000000..4577ce17
--- /dev/null
+++ b/synapse/res/templates/registration_token.html
@@ -0,0 +1,23 @@
+<html>
+<head>
+<title>Authentication</title>
+<meta name='viewport' content='width=device-width, initial-scale=1,
+ user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
+</head>
+<body>
+<form id="registrationForm" method="post" action="{{ myurl }}">
+ <div>
+ {% if error is defined %}
+ <p class="error"><strong>Error: {{ error }}</strong></p>
+ {% endif %}
+ <p>
+ Please enter a registration token.
+ </p>
+ <input type="hidden" name="session" value="{{ session }}" />
+ <input type="text" name="token" />
+ <input type="submit" value="Authenticate" />
+ </div>
+</form>
+</body>
+</html>
diff --git a/synapse/res/templates/terms.html b/synapse/res/templates/terms.html
index dfef9897..369ff446 100644
--- a/synapse/res/templates/terms.html
+++ b/synapse/res/templates/terms.html
@@ -8,6 +8,9 @@
<body>
<form id="registrationForm" method="post" action="{{ myurl }}">
<div>
+ {% if error is defined %}
+ <p class="error"><strong>Error: {{ error }}</strong></p>
+ {% endif %}
<p>
Please click the button below if you agree to the
<a href="{{ terms_url }}">privacy policy of this homeserver.</a>
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index d5862a4d..b2514d9d 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -36,7 +36,11 @@ from synapse.rest.admin.event_reports import (
)
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
-from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
+from synapse.rest.admin.registration_tokens import (
+ ListRegistrationTokensRestServlet,
+ NewRegistrationTokenRestServlet,
+ RegistrationTokenRestServlet,
+)
from synapse.rest.admin.rooms import (
DeleteRoomRestServlet,
ForwardExtremitiesRestServlet,
@@ -47,7 +51,6 @@ from synapse.rest.admin.rooms import (
RoomMembersRestServlet,
RoomRestServlet,
RoomStateRestServlet,
- ShutdownRoomRestServlet,
)
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet
@@ -220,8 +223,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomMembersRestServlet(hs).register(http_server)
DeleteRoomRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
- PurgeRoomServlet(hs).register(http_server)
- SendServerNoticeServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)
UserAdminServlet(hs).register(http_server)
UserMembershipRestServlet(hs).register(http_server)
@@ -241,6 +242,13 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomEventContextServlet(hs).register(http_server)
RateLimitRestServlet(hs).register(http_server)
UsernameAvailableRestServlet(hs).register(http_server)
+ ListRegistrationTokensRestServlet(hs).register(http_server)
+ NewRegistrationTokenRestServlet(hs).register(http_server)
+ RegistrationTokenRestServlet(hs).register(http_server)
+
+ # Some servlets only get registered for the main process.
+ if hs.config.worker_app is None:
+ SendServerNoticeServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(
@@ -253,7 +261,6 @@ def register_servlets_for_client_rest_resource(
PurgeHistoryRestServlet(hs).register(http_server)
ResetPasswordRestServlet(hs).register(http_server)
SearchUsersRestServlet(hs).register(http_server)
- ShutdownRoomRestServlet(hs).register(http_server)
UserRegisterServlet(hs).register(http_server)
DeleteGroupAdminRestServlet(hs).register(http_server)
AccountValidityRenewServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/purge_room_servlet.py b/synapse/rest/admin/purge_room_servlet.py
deleted file mode 100644
index 2365ff7a..00000000
--- a/synapse/rest/admin/purge_room_servlet.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import TYPE_CHECKING, Tuple
-
-from synapse.http.servlet import (
- RestServlet,
- assert_params_in_dict,
- parse_json_object_from_request,
-)
-from synapse.http.site import SynapseRequest
-from synapse.rest.admin import assert_requester_is_admin
-from synapse.rest.admin._base import admin_patterns
-from synapse.types import JsonDict
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-
-class PurgeRoomServlet(RestServlet):
- """Servlet which will remove all trace of a room from the database
-
- POST /_synapse/admin/v1/purge_room
- {
- "room_id": "!room:id"
- }
-
- returns:
-
- {}
- """
-
- PATTERNS = admin_patterns("/purge_room$")
-
- def __init__(self, hs: "HomeServer"):
- self.hs = hs
- self.auth = hs.get_auth()
- self.pagination_handler = hs.get_pagination_handler()
-
- async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- await assert_requester_is_admin(self.auth, request)
-
- body = parse_json_object_from_request(request)
- assert_params_in_dict(body, ("room_id",))
-
- await self.pagination_handler.purge_room(body["room_id"])
-
- return 200, {}
diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
new file mode 100644
index 00000000..5a1c929d
--- /dev/null
+++ b/synapse/rest/admin/registration_tokens.py
@@ -0,0 +1,321 @@
+# Copyright 2021 Callum Brown
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import string
+from typing import TYPE_CHECKING, Tuple
+
+from synapse.api.errors import Codes, NotFoundError, SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ parse_boolean,
+ parse_json_object_from_request,
+)
+from synapse.http.site import SynapseRequest
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class ListRegistrationTokensRestServlet(RestServlet):
+ """List registration tokens.
+
+ To list all tokens:
+
+ GET /_synapse/admin/v1/registration_tokens
+
+ 200 OK
+
+ {
+ "registration_tokens": [
+ {
+ "token": "abcd",
+ "uses_allowed": 3,
+ "pending": 0,
+ "completed": 1,
+ "expiry_time": null
+ },
+ {
+ "token": "wxyz",
+ "uses_allowed": null,
+ "pending": 0,
+ "completed": 9,
+ "expiry_time": 1625394937000
+ }
+ ]
+ }
+
+ The optional query parameter `valid` can be used to filter the response.
+ If it is `true`, only valid tokens are returned. If it is `false`, only
+ tokens that have expired or have had all uses exhausted are returned.
+ If it is omitted, all tokens are returned regardless of validity.
+ """
+
+ PATTERNS = admin_patterns("/registration_tokens$")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self.auth, request)
+ valid = parse_boolean(request, "valid")
+ token_list = await self.store.get_registration_tokens(valid)
+ return 200, {"registration_tokens": token_list}
+
+
+class NewRegistrationTokenRestServlet(RestServlet):
+ """Create a new registration token.
+
+ For example, to create a token specifying some fields:
+
+ POST /_synapse/admin/v1/registration_tokens/new
+
+ {
+ "token": "defg",
+ "uses_allowed": 1
+ }
+
+ 200 OK
+
+ {
+ "token": "defg",
+ "uses_allowed": 1,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": null
+ }
+
+ Defaults are used for any fields not specified.
+ """
+
+ PATTERNS = admin_patterns("/registration_tokens/new$")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ # A string of all the characters allowed to be in a registration_token
+ self.allowed_chars = string.ascii_letters + string.digits + "-_"
+ self.allowed_chars_set = set(self.allowed_chars)
+
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self.auth, request)
+ body = parse_json_object_from_request(request)
+
+ if "token" in body:
+ token = body["token"]
+ if not isinstance(token, str):
+ raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM)
+ if not (0 < len(token) <= 64):
+ raise SynapseError(
+ 400,
+ "token must not be empty and must not be longer than 64 characters",
+ Codes.INVALID_PARAM,
+ )
+ if not set(token).issubset(self.allowed_chars_set):
+ raise SynapseError(
+ 400,
+ "token must consist only of characters matched by the regex [A-Za-z0-9-_]",
+ Codes.INVALID_PARAM,
+ )
+
+ else:
+ # Get length of token to generate (default is 16)
+ length = body.get("length", 16)
+ if not isinstance(length, int):
+ raise SynapseError(
+ 400, "length must be an integer", Codes.INVALID_PARAM
+ )
+ if not (0 < length <= 64):
+ raise SynapseError(
+ 400,
+ "length must be greater than zero and not greater than 64",
+ Codes.INVALID_PARAM,
+ )
+
+ # Generate token
+ token = await self.store.generate_registration_token(
+ length, self.allowed_chars
+ )
+
+ uses_allowed = body.get("uses_allowed", None)
+ if not (
+ uses_allowed is None
+ or (isinstance(uses_allowed, int) and uses_allowed >= 0)
+ ):
+ raise SynapseError(
+ 400,
+ "uses_allowed must be a non-negative integer or null",
+ Codes.INVALID_PARAM,
+ )
+
+ expiry_time = body.get("expiry_time", None)
+ if not isinstance(expiry_time, (int, type(None))):
+ raise SynapseError(
+ 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
+ )
+ if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
+ raise SynapseError(
+ 400, "expiry_time must not be in the past", Codes.INVALID_PARAM
+ )
+
+ created = await self.store.create_registration_token(
+ token, uses_allowed, expiry_time
+ )
+ if not created:
+ raise SynapseError(
+ 400, f"Token already exists: {token}", Codes.INVALID_PARAM
+ )
+
+ resp = {
+ "token": token,
+ "uses_allowed": uses_allowed,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": expiry_time,
+ }
+ return 200, resp
+
+
+class RegistrationTokenRestServlet(RestServlet):
+ """Retrieve, update, or delete the given token.
+
+ For example,
+
+ to retrieve a token:
+
+ GET /_synapse/admin/v1/registration_tokens/abcd
+
+ 200 OK
+
+ {
+ "token": "abcd",
+ "uses_allowed": 3,
+ "pending": 0,
+ "completed": 1,
+ "expiry_time": null
+ }
+
+
+ to update a token:
+
+ PUT /_synapse/admin/v1/registration_tokens/defg
+
+ {
+ "uses_allowed": 5,
+ "expiry_time": 4781243146000
+ }
+
+ 200 OK
+
+ {
+ "token": "defg",
+ "uses_allowed": 5,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": 4781243146000
+ }
+
+
+ to delete a token:
+
+ DELETE /_synapse/admin/v1/registration_tokens/wxyz
+
+ 200 OK
+
+ {}
+ """
+
+ PATTERNS = admin_patterns("/registration_tokens/(?P<token>[^/]*)$")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.clock = hs.get_clock()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
+ """Retrieve a registration token."""
+ await assert_requester_is_admin(self.auth, request)
+ token_info = await self.store.get_one_registration_token(token)
+
+ # If no result return a 404
+ if token_info is None:
+ raise NotFoundError(f"No such registration token: {token}")
+
+ return 200, token_info
+
+ async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
+ """Update a registration token."""
+ await assert_requester_is_admin(self.auth, request)
+ body = parse_json_object_from_request(request)
+ new_attributes = {}
+
+ # Only add uses_allowed to new_attributes if it is present and valid
+ if "uses_allowed" in body:
+ uses_allowed = body["uses_allowed"]
+ if not (
+ uses_allowed is None
+ or (isinstance(uses_allowed, int) and uses_allowed >= 0)
+ ):
+ raise SynapseError(
+ 400,
+ "uses_allowed must be a non-negative integer or null",
+ Codes.INVALID_PARAM,
+ )
+ new_attributes["uses_allowed"] = uses_allowed
+
+ if "expiry_time" in body:
+ expiry_time = body["expiry_time"]
+ if not isinstance(expiry_time, (int, type(None))):
+ raise SynapseError(
+ 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
+ )
+ if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
+ raise SynapseError(
+ 400, "expiry_time must not be in the past", Codes.INVALID_PARAM
+ )
+ new_attributes["expiry_time"] = expiry_time
+
+ if len(new_attributes) == 0:
+ # Nothing to update, get token info to return
+ token_info = await self.store.get_one_registration_token(token)
+ else:
+ token_info = await self.store.update_registration_token(
+ token, new_attributes
+ )
+
+ # If no result return a 404
+ if token_info is None:
+ raise NotFoundError(f"No such registration token: {token}")
+
+ return 200, token_info
+
+ async def on_DELETE(
+ self, request: SynapseRequest, token: str
+ ) -> Tuple[int, JsonDict]:
+ """Delete a registration token."""
+ await assert_requester_is_admin(self.auth, request)
+
+ if await self.store.delete_registration_token(token):
+ return 200, {}
+
+ raise NotFoundError(f"No such registration token: {token}")
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 975c28b2..ad83d4b5 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -46,41 +46,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class ShutdownRoomRestServlet(RestServlet):
- """Shuts down a room by removing all local users from the room and blocking
- all future invites and joins to the room. Any local aliases will be repointed
- to a new room created by `new_room_user_id` and kicked users will be auto
- joined to the new room.
- """
-
- PATTERNS = admin_patterns("/shutdown_room/(?P<room_id>[^/]+)")
-
- def __init__(self, hs: "HomeServer"):
- self.hs = hs
- self.auth = hs.get_auth()
- self.room_shutdown_handler = hs.get_room_shutdown_handler()
-
- async def on_POST(
- self, request: SynapseRequest, room_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
-
- content = parse_json_object_from_request(request)
- assert_params_in_dict(content, ["new_room_user_id"])
-
- ret = await self.room_shutdown_handler.shutdown_room(
- room_id=room_id,
- new_room_user_id=content["new_room_user_id"],
- new_room_name=content.get("room_name"),
- message=content.get("message"),
- requester_user_id=requester.user.to_string(),
- block=True,
- )
-
- return (200, ret)
-
-
class DeleteRoomRestServlet(RestServlet):
"""Delete a room from server.
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index b5e4c474..42201afc 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -14,7 +14,7 @@
from typing import TYPE_CHECKING, Optional, Tuple
from synapse.api.constants import EventTypes
-from synapse.api.errors import SynapseError
+from synapse.api.errors import NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@@ -53,6 +53,8 @@ class SendServerNoticeServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
+ self.server_notices_manager = hs.get_server_notices_manager()
+ self.admin_handler = hs.get_admin_handler()
self.txns = HttpTransactionCache(hs)
def register(self, json_resource: HttpServer):
@@ -79,19 +81,22 @@ class SendServerNoticeServlet(RestServlet):
# We grab the server notices manager here as its initialisation has a check for worker processes,
# but worker processes still need to initialise SendServerNoticeServlet (as it is part of the
# admin api).
- if not self.hs.get_server_notices_manager().is_enabled():
+ if not self.server_notices_manager.is_enabled():
raise SynapseError(400, "Server notices are not enabled on this server")
- user_id = body["user_id"]
- UserID.from_string(user_id)
- if not self.hs.is_mine_id(user_id):
+ target_user = UserID.from_string(body["user_id"])
+ if not self.hs.is_mine(target_user):
raise SynapseError(400, "Server notices can only be sent to local users")
- event = await self.hs.get_server_notices_manager().send_notice(
- user_id=body["user_id"],
+ if not await self.admin_handler.get_user(target_user):
+ raise NotFoundError("User not found")
+
+ event = await self.server_notices_manager.send_notice(
+ user_id=target_user.to_string(),
type=event_type,
state_key=state_key,
event_content=body["content"],
+ txn_id=txn_id,
)
return 200, {"event_id": event.event_id}
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 3c8a0c68..c1a1ba64 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -228,13 +228,18 @@ class UserRestServletV2(RestServlet):
if not isinstance(deactivate, bool):
raise SynapseError(400, "'deactivated' parameter is not of type boolean")
- # convert into List[Tuple[str, str]]
+ # convert List[Dict[str, str]] into Set[Tuple[str, str]]
if external_ids is not None:
- new_external_ids = []
- for external_id in external_ids:
- new_external_ids.append(
- (external_id["auth_provider"], external_id["external_id"])
- )
+ new_external_ids = {
+ (external_id["auth_provider"], external_id["external_id"])
+ for external_id in external_ids
+ }
+
+ # convert List[Dict[str, str]] into Set[Tuple[str, str]]
+ if threepids is not None:
+ new_threepids = {
+ (threepid["medium"], threepid["address"]) for threepid in threepids
+ }
if user: # modify user
if "displayname" in body:
@@ -243,29 +248,39 @@ class UserRestServletV2(RestServlet):
)
if threepids is not None:
- # remove old threepids from user
- old_threepids = await self.store.user_get_threepids(user_id)
- for threepid in old_threepids:
+ # get changed threepids (added and removed)
+ # convert List[Dict[str, Any]] into Set[Tuple[str, str]]
+ cur_threepids = {
+ (threepid["medium"], threepid["address"])
+ for threepid in await self.store.user_get_threepids(user_id)
+ }
+ add_threepids = new_threepids - cur_threepids
+ del_threepids = cur_threepids - new_threepids
+
+ # remove old threepids
+ for medium, address in del_threepids:
try:
await self.auth_handler.delete_threepid(
- user_id, threepid["medium"], threepid["address"], None
+ user_id, medium, address, None
)
except Exception:
logger.exception("Failed to remove threepids")
raise SynapseError(500, "Failed to remove threepids")
- # add new threepids to user
+ # add new threepids
current_time = self.hs.get_clock().time_msec()
- for threepid in threepids:
+ for medium, address in add_threepids:
await self.auth_handler.add_threepid(
- user_id, threepid["medium"], threepid["address"], current_time
+ user_id, medium, address, current_time
)
if external_ids is not None:
# get changed external_ids (added and removed)
- cur_external_ids = await self.store.get_external_ids_by_user(user_id)
- add_external_ids = set(new_external_ids) - set(cur_external_ids)
- del_external_ids = set(cur_external_ids) - set(new_external_ids)
+ cur_external_ids = set(
+ await self.store.get_external_ids_by_user(user_id)
+ )
+ add_external_ids = new_external_ids - cur_external_ids
+ del_external_ids = cur_external_ids - new_external_ids
# remove old external_ids
for auth_provider, external_id in del_external_ids:
@@ -348,9 +363,9 @@ class UserRestServletV2(RestServlet):
if threepids is not None:
current_time = self.hs.get_clock().time_msec()
- for threepid in threepids:
+ for medium, address in new_threepids:
await self.auth_handler.add_threepid(
- user_id, threepid["medium"], threepid["address"], current_time
+ user_id, medium, address, current_time
)
if (
self.hs.config.email_enable_notifs
@@ -362,8 +377,8 @@ class UserRestServletV2(RestServlet):
kind="email",
app_id="m.email",
app_display_name="Email Notifications",
- device_display_name=threepid["address"],
- pushkey=threepid["address"],
+ device_display_name=address,
+ pushkey=address,
lang=None, # We don't know a user's language here
data={},
)
diff --git a/synapse/rest/client/account_validity.py b/synapse/rest/client/account_validity.py
index 3ebe4018..6c24b96c 100644
--- a/synapse/rest/client/account_validity.py
+++ b/synapse/rest/client/account_validity.py
@@ -13,24 +13,27 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
-from synapse.api.errors import SynapseError
-from synapse.http.server import respond_with_html
-from synapse.http.servlet import RestServlet
+from twisted.web.server import Request
+
+from synapse.http.server import HttpServer, respond_with_html
+from synapse.http.servlet import RestServlet, parse_string
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class AccountValidityRenewServlet(RestServlet):
PATTERNS = client_patterns("/account_validity/renew$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@@ -46,18 +49,14 @@ class AccountValidityRenewServlet(RestServlet):
hs.config.account_validity.account_validity_invalid_token_template
)
- async def on_GET(self, request):
- if b"token" not in request.args:
- raise SynapseError(400, "Missing renewal token")
- renewal_token = request.args[b"token"][0]
+ async def on_GET(self, request: Request) -> None:
+ renewal_token = parse_string(request, "token", required=True)
(
token_valid,
token_stale,
expiration_ts,
- ) = await self.account_activity_handler.renew_account(
- renewal_token.decode("utf8")
- )
+ ) = await self.account_activity_handler.renew_account(renewal_token)
if token_valid:
status_code = 200
@@ -77,11 +76,7 @@ class AccountValidityRenewServlet(RestServlet):
class AccountValiditySendMailServlet(RestServlet):
PATTERNS = client_patterns("/account_validity/send_mail$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@@ -91,7 +86,7 @@ class AccountValiditySendMailServlet(RestServlet):
hs.config.account_validity.account_validity_renew_by_email_enabled
)
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string()
await self.account_activity_handler.send_renewal_email_to_user(user_id)
@@ -99,6 +94,6 @@ class AccountValiditySendMailServlet(RestServlet):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
AccountValidityRenewServlet(hs).register(http_server)
AccountValiditySendMailServlet(hs).register(http_server)
diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index 6ea1b50a..df8cc4ac 100644
--- a/synapse/rest/client/auth.py
+++ b/synapse/rest/client/auth.py
@@ -15,11 +15,14 @@
import logging
from typing import TYPE_CHECKING
+from twisted.web.server import Request
+
from synapse.api.constants import LoginType
-from synapse.api.errors import SynapseError
+from synapse.api.errors import LoginError, SynapseError
from synapse.api.urls import CLIENT_API_PREFIX
-from synapse.http.server import respond_with_html
+from synapse.http.server import HttpServer, respond_with_html
from synapse.http.servlet import RestServlet, parse_string
+from synapse.http.site import SynapseRequest
from ._base import client_patterns
@@ -46,9 +49,10 @@ class AuthRestServlet(RestServlet):
self.registration_handler = hs.get_registration_handler()
self.recaptcha_template = hs.config.recaptcha_template
self.terms_template = hs.config.terms_template
+ self.registration_token_template = hs.config.registration_token_template
self.success_template = hs.config.fallback_success_template
- async def on_GET(self, request, stagetype):
+ async def on_GET(self, request: SynapseRequest, stagetype: str) -> None:
session = parse_string(request, "session")
if not session:
raise SynapseError(400, "No session supplied")
@@ -74,6 +78,12 @@ class AuthRestServlet(RestServlet):
# re-authenticate with their SSO provider.
html = await self.auth_handler.start_sso_ui_auth(request, session)
+ elif stagetype == LoginType.REGISTRATION_TOKEN:
+ html = self.registration_token_template.render(
+ session=session,
+ myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
+ )
+
else:
raise SynapseError(404, "Unknown auth stage type")
@@ -81,7 +91,7 @@ class AuthRestServlet(RestServlet):
respond_with_html(request, 200, html)
return None
- async def on_POST(self, request, stagetype):
+ async def on_POST(self, request: Request, stagetype: str) -> None:
session = parse_string(request, "session")
if not session:
@@ -95,29 +105,32 @@ class AuthRestServlet(RestServlet):
authdict = {"response": response, "session": session}
- success = await self.auth_handler.add_oob_auth(
- LoginType.RECAPTCHA, authdict, request.getClientIP()
- )
-
- if success:
- html = self.success_template.render()
- else:
+ try:
+ await self.auth_handler.add_oob_auth(
+ LoginType.RECAPTCHA, authdict, request.getClientIP()
+ )
+ except LoginError as e:
+ # Authentication failed, let user try again
html = self.recaptcha_template.render(
session=session,
myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
sitekey=self.hs.config.recaptcha_public_key,
+ error=e.msg,
)
+ else:
+ # No LoginError was raised, so authentication was successful
+ html = self.success_template.render()
+
elif stagetype == LoginType.TERMS:
authdict = {"session": session}
- success = await self.auth_handler.add_oob_auth(
- LoginType.TERMS, authdict, request.getClientIP()
- )
-
- if success:
- html = self.success_template.render()
- else:
+ try:
+ await self.auth_handler.add_oob_auth(
+ LoginType.TERMS, authdict, request.getClientIP()
+ )
+ except LoginError as e:
+ # Authentication failed, let user try again
html = self.terms_template.render(
session=session,
terms_url="%s_matrix/consent?v=%s"
@@ -127,10 +140,33 @@ class AuthRestServlet(RestServlet):
),
myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
+ error=e.msg,
)
+ else:
+ # No LoginError was raised, so authentication was successful
+ html = self.success_template.render()
+
elif stagetype == LoginType.SSO:
# The SSO fallback workflow should not post here,
raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
+
+ elif stagetype == LoginType.REGISTRATION_TOKEN:
+ token = parse_string(request, "token", required=True)
+ authdict = {"session": session, "token": token}
+
+ try:
+ await self.auth_handler.add_oob_auth(
+ LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP()
+ )
+ except LoginError as e:
+ html = self.registration_token_template.render(
+ session=session,
+ myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
+ error=e.msg,
+ )
+ else:
+ html = self.success_template.render()
+
else:
raise SynapseError(404, "Unknown auth stage type")
@@ -139,5 +175,5 @@ class AuthRestServlet(RestServlet):
return None
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
AuthRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py
index 88e3aac7..65b3b5ce 100644
--- a/synapse/rest/client/capabilities.py
+++ b/synapse/rest/client/capabilities.py
@@ -15,6 +15,7 @@ import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, MSC3244_CAPABILITIES
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
@@ -61,8 +62,19 @@ class CapabilitiesRestServlet(RestServlet):
"org.matrix.msc3244.room_capabilities"
] = MSC3244_CAPABILITIES
+ if self.config.experimental.msc3283_enabled:
+ response["capabilities"]["org.matrix.msc3283.set_displayname"] = {
+ "enabled": self.config.enable_set_displayname
+ }
+ response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = {
+ "enabled": self.config.enable_set_avatar_url
+ }
+ response["capabilities"]["org.matrix.msc3283.3pid_changes"] = {
+ "enabled": self.config.enable_3pid_changes
+ }
+
return 200, response
-def register_servlets(hs: "HomeServer", http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
CapabilitiesRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 8b9674db..25bc3c8f 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -14,34 +14,36 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api import errors
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns, interactive_auth_handler
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class DevicesRestServlet(RestServlet):
PATTERNS = client_patterns("/devices$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
devices = await self.device_handler.get_devices_by_user(
requester.user.to_string()
@@ -57,7 +59,7 @@ class DeleteDevicesRestServlet(RestServlet):
PATTERNS = client_patterns("/delete_devices")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -65,7 +67,7 @@ class DeleteDevicesRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
try:
@@ -100,18 +102,16 @@ class DeleteDevicesRestServlet(RestServlet):
class DeviceRestServlet(RestServlet):
PATTERNS = client_patterns("/devices/(?P<device_id>[^/]*)$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler()
- async def on_GET(self, request, device_id):
+ async def on_GET(
+ self, request: SynapseRequest, device_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
device = await self.device_handler.get_device(
requester.user.to_string(), device_id
@@ -119,7 +119,9 @@ class DeviceRestServlet(RestServlet):
return 200, device
@interactive_auth_handler
- async def on_DELETE(self, request, device_id):
+ async def on_DELETE(
+ self, request: SynapseRequest, device_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
try:
@@ -146,7 +148,9 @@ class DeviceRestServlet(RestServlet):
await self.device_handler.delete_device(requester.user.to_string(), device_id)
return 200, {}
- async def on_PUT(self, request, device_id):
+ async def on_PUT(
+ self, request: SynapseRequest, device_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
body = parse_json_object_from_request(request)
@@ -193,13 +197,13 @@ class DehydratedDeviceServlet(RestServlet):
PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=())
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
- async def on_GET(self, request: SynapseRequest):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
dehydrated_device = await self.device_handler.get_dehydrated_device(
requester.user.to_string()
@@ -211,7 +215,7 @@ class DehydratedDeviceServlet(RestServlet):
else:
raise errors.NotFoundError("No dehydrated device available")
- async def on_PUT(self, request: SynapseRequest):
+ async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
submission = parse_json_object_from_request(request)
requester = await self.auth.get_user_by_req(request)
@@ -259,13 +263,13 @@ class ClaimDehydratedDeviceServlet(RestServlet):
"/org.matrix.msc2697.v2/dehydrated_device/claim", releases=()
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
- async def on_POST(self, request: SynapseRequest):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
submission = parse_json_object_from_request(request)
@@ -292,7 +296,7 @@ class ClaimDehydratedDeviceServlet(RestServlet):
return (200, result)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
DeleteDevicesRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py
index ffa075c8..ee247e3d 100644
--- a/synapse/rest/client/directory.py
+++ b/synapse/rest/client/directory.py
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.server import Request
from synapse.api.errors import (
AuthError,
@@ -22,14 +24,19 @@ from synapse.api.errors import (
NotFoundError,
SynapseError,
)
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
-from synapse.types import RoomAlias
+from synapse.types import JsonDict, RoomAlias
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ClientDirectoryServer(hs).register(http_server)
ClientDirectoryListServer(hs).register(http_server)
ClientAppserviceDirectoryListServer(hs).register(http_server)
@@ -38,21 +45,23 @@ def register_servlets(hs, http_server):
class ClientDirectoryServer(RestServlet):
PATTERNS = client_patterns("/directory/room/(?P<room_alias>[^/]*)$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_alias):
- room_alias = RoomAlias.from_string(room_alias)
+ async def on_GET(self, request: Request, room_alias: str) -> Tuple[int, JsonDict]:
+ room_alias_obj = RoomAlias.from_string(room_alias)
- res = await self.directory_handler.get_association(room_alias)
+ res = await self.directory_handler.get_association(room_alias_obj)
return 200, res
- async def on_PUT(self, request, room_alias):
- room_alias = RoomAlias.from_string(room_alias)
+ async def on_PUT(
+ self, request: SynapseRequest, room_alias: str
+ ) -> Tuple[int, JsonDict]:
+ room_alias_obj = RoomAlias.from_string(room_alias)
content = parse_json_object_from_request(request)
if "room_id" not in content:
@@ -61,7 +70,7 @@ class ClientDirectoryServer(RestServlet):
)
logger.debug("Got content: %s", content)
- logger.debug("Got room name: %s", room_alias.to_string())
+ logger.debug("Got room name: %s", room_alias_obj.to_string())
room_id = content["room_id"]
servers = content["servers"] if "servers" in content else None
@@ -78,22 +87,25 @@ class ClientDirectoryServer(RestServlet):
requester = await self.auth.get_user_by_req(request)
await self.directory_handler.create_association(
- requester, room_alias, room_id, servers
+ requester, room_alias_obj, room_id, servers
)
return 200, {}
- async def on_DELETE(self, request, room_alias):
+ async def on_DELETE(
+ self, request: SynapseRequest, room_alias: str
+ ) -> Tuple[int, JsonDict]:
+ room_alias_obj = RoomAlias.from_string(room_alias)
+
try:
service = self.auth.get_appservice_by_req(request)
- room_alias = RoomAlias.from_string(room_alias)
await self.directory_handler.delete_appservice_association(
- service, room_alias
+ service, room_alias_obj
)
logger.info(
"Application service at %s deleted alias %s",
service.url,
- room_alias.to_string(),
+ room_alias_obj.to_string(),
)
return 200, {}
except InvalidClientCredentialsError:
@@ -103,12 +115,10 @@ class ClientDirectoryServer(RestServlet):
requester = await self.auth.get_user_by_req(request)
user = requester.user
- room_alias = RoomAlias.from_string(room_alias)
-
- await self.directory_handler.delete_association(requester, room_alias)
+ await self.directory_handler.delete_association(requester, room_alias_obj)
logger.info(
- "User %s deleted alias %s", user.to_string(), room_alias.to_string()
+ "User %s deleted alias %s", user.to_string(), room_alias_obj.to_string()
)
return 200, {}
@@ -117,20 +127,22 @@ class ClientDirectoryServer(RestServlet):
class ClientDirectoryListServer(RestServlet):
PATTERNS = client_patterns("/directory/list/room/(?P<room_id>[^/]*)$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id):
+ async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
room = await self.store.get_room(room_id)
if room is None:
raise NotFoundError("Unknown room")
return 200, {"visibility": "public" if room["is_public"] else "private"}
- async def on_PUT(self, request, room_id):
+ async def on_PUT(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
@@ -142,7 +154,9 @@ class ClientDirectoryListServer(RestServlet):
return 200, {}
- async def on_DELETE(self, request, room_id):
+ async def on_DELETE(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await self.directory_handler.edit_published_room_list(
@@ -157,21 +171,27 @@ class ClientAppserviceDirectoryListServer(RestServlet):
"/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$", v1=True
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()
- def on_PUT(self, request, network_id, room_id):
+ async def on_PUT(
+ self, request: SynapseRequest, network_id: str, room_id: str
+ ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
visibility = content.get("visibility", "public")
- return self._edit(request, network_id, room_id, visibility)
+ return await self._edit(request, network_id, room_id, visibility)
- def on_DELETE(self, request, network_id, room_id):
- return self._edit(request, network_id, room_id, "private")
+ async def on_DELETE(
+ self, request: SynapseRequest, network_id: str, room_id: str
+ ) -> Tuple[int, JsonDict]:
+ return await self._edit(request, network_id, room_id, "private")
- async def _edit(self, request, network_id, room_id, visibility):
+ async def _edit(
+ self, request: SynapseRequest, network_id: str, room_id: str, visibility: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if not requester.app_service:
raise AuthError(
diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py
index 52bb579c..13b72a04 100644
--- a/synapse/rest/client/events.py
+++ b/synapse/rest/client/events.py
@@ -14,11 +14,18 @@
"""This module contains REST servlets to do with event streaming, /events."""
import logging
+from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet
+from synapse.http.server import HttpServer
+from synapse.http.servlet import RestServlet, parse_string
+from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.streams.config import PaginationConfig
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -28,31 +35,30 @@ class EventStreamRestServlet(RestServlet):
DEFAULT_LONGPOLL_TIME_MS = 30000
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.event_stream_handler = hs.get_event_stream_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
is_guest = requester.is_guest
- room_id = None
+ args: Dict[bytes, List[bytes]] = request.args # type: ignore
if is_guest:
- if b"room_id" not in request.args:
+ if b"room_id" not in args:
raise SynapseError(400, "Guest users must specify room_id param")
- if b"room_id" in request.args:
- room_id = request.args[b"room_id"][0].decode("ascii")
+ room_id = parse_string(request, "room_id")
pagin_config = await PaginationConfig.from_request(self.store, request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
- if b"timeout" in request.args:
+ if b"timeout" in args:
try:
- timeout = int(request.args[b"timeout"][0])
+ timeout = int(args[b"timeout"][0])
except ValueError:
raise SynapseError(400, "timeout must be in milliseconds.")
- as_client_event = b"raw" not in request.args
+ as_client_event = b"raw" not in args
chunk = await self.event_stream_handler.get_stream(
requester.user.to_string(),
@@ -70,25 +76,27 @@ class EventStreamRestServlet(RestServlet):
class EventRestServlet(RestServlet):
PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
self.auth = hs.get_auth()
self._event_serializer = hs.get_event_client_serializer()
- async def on_GET(self, request, event_id):
+ async def on_GET(
+ self, request: SynapseRequest, event_id: str
+ ) -> Tuple[int, Union[str, JsonDict]]:
requester = await self.auth.get_user_by_req(request)
event = await self.event_handler.get_event(requester.user, None, event_id)
time_now = self.clock.time_msec()
if event:
- event = await self._event_serializer.serialize_event(event, time_now)
- return 200, event
+ result = await self._event_serializer.serialize_event(event, time_now)
+ return 200, result
else:
return 404, "Event not found."
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
EventStreamRestServlet(hs).register(http_server)
EventRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py
index 411667a9..6ed60c74 100644
--- a/synapse/rest/client/filter.py
+++ b/synapse/rest/client/filter.py
@@ -13,26 +13,34 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.types import UserID
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict, UserID
from ._base import client_patterns, set_timeline_upper_limit
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class GetFilterRestServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
- async def on_GET(self, request, user_id, filter_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str, filter_id: str
+ ) -> Tuple[int, JsonDict]:
target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)
@@ -43,13 +51,13 @@ class GetFilterRestServlet(RestServlet):
raise AuthError(403, "Can only get filters for local users")
try:
- filter_id = int(filter_id)
+ filter_id_int = int(filter_id)
except Exception:
raise SynapseError(400, "Invalid filter_id")
try:
filter_collection = await self.filtering.get_user_filter(
- user_localpart=target_user.localpart, filter_id=filter_id
+ user_localpart=target_user.localpart, filter_id=filter_id_int
)
except StoreError as e:
if e.code != 404:
@@ -62,13 +70,15 @@ class GetFilterRestServlet(RestServlet):
class CreateFilterRestServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
- async def on_POST(self, request, user_id):
+ async def on_POST(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)
@@ -89,6 +99,6 @@ class CreateFilterRestServlet(RestServlet):
return 200, {"filter_id": str(filter_id)}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
GetFilterRestServlet(hs).register(http_server)
CreateFilterRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py
index 6285680c..c3667ff8 100644
--- a/synapse/rest/client/groups.py
+++ b/synapse/rest/client/groups.py
@@ -26,6 +26,7 @@ from synapse.api.constants import (
)
from synapse.api.errors import Codes, SynapseError
from synapse.handlers.groups_local import GroupsLocalHandler
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -930,7 +931,7 @@ class GroupsForUserServlet(RestServlet):
return 200, result
-def register_servlets(hs: "HomeServer", http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
GroupServlet(hs).register(http_server)
GroupSummaryServlet(hs).register(http_server)
GroupInvitedUsersServlet(hs).register(http_server)
diff --git a/synapse/rest/client/initial_sync.py b/synapse/rest/client/initial_sync.py
index 12ba0e91..49b1037b 100644
--- a/synapse/rest/client/initial_sync.py
+++ b/synapse/rest/client/initial_sync.py
@@ -12,25 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING, Dict, List, Tuple
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean
+from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.streams.config import PaginationConfig
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
# TODO: Needs unit testing
class InitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/initialSync$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- as_client_event = b"raw" not in request.args
+ args: Dict[bytes, List[bytes]] = request.args # type: ignore
+ as_client_event = b"raw" not in args
pagination_config = await PaginationConfig.from_request(self.store, request)
include_archived = parse_boolean(request, "archived", default=False)
content = await self.initial_sync_handler.snapshot_all_rooms(
@@ -43,5 +51,5 @@ class InitialSyncRestServlet(RestServlet):
return 200, content
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
InitialSyncRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index d0d9d30d..7281b2ee 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -15,19 +15,25 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Any, Optional, Tuple
-from synapse.api.errors import SynapseError
+from synapse.api.errors import InvalidAPICallError, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_integer,
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.types import StreamToken
+from synapse.types import JsonDict, StreamToken
from ._base import client_patterns, interactive_auth_handler
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -59,18 +65,16 @@ class KeyUploadServlet(RestServlet):
PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = hs.get_device_handler()
@trace(opname="upload_keys")
- async def on_POST(self, request, device_id):
+ async def on_POST(
+ self, request: SynapseRequest, device_id: Optional[str]
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -148,21 +152,30 @@ class KeyQueryServlet(RestServlet):
PATTERNS = client_patterns("/keys/query$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer):
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
device_id = requester.device_id
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
+
+ device_keys = body.get("device_keys")
+ if not isinstance(device_keys, dict):
+ raise InvalidAPICallError("'device_keys' must be a JSON object")
+
+ def is_list_of_strings(values: Any) -> bool:
+ return isinstance(values, list) and all(isinstance(v, str) for v in values)
+
+ if any(not is_list_of_strings(keys) for keys in device_keys.values()):
+ raise InvalidAPICallError(
+ "'device_keys' values must be a list of strings",
+ )
+
result = await self.e2e_keys_handler.query_devices(
body, timeout, user_id, device_id
)
@@ -181,17 +194,13 @@ class KeyChangesServlet(RestServlet):
PATTERNS = client_patterns("/keys/changes$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer):
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
from_token_string = parse_string(request, "from", required=True)
@@ -231,12 +240,12 @@ class OneTimeKeyServlet(RestServlet):
PATTERNS = client_patterns("/keys/claim$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
@@ -255,11 +264,7 @@ class SigningKeyUploadServlet(RestServlet):
PATTERNS = client_patterns("/keys/device_signing/upload$", releases=())
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -267,7 +272,7 @@ class SigningKeyUploadServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -315,16 +320,12 @@ class SignaturesUploadServlet(RestServlet):
PATTERNS = client_patterns("/keys/signatures/upload$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -335,7 +336,7 @@ class SignaturesUploadServlet(RestServlet):
return 200, result
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
KeyUploadServlet(hs).register(http_server)
KeyQueryServlet(hs).register(http_server)
KeyChangesServlet(hs).register(http_server)
diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py
index 7d1bc406..68fb08d0 100644
--- a/synapse/rest/client/knock.py
+++ b/synapse/rest/client/knock.py
@@ -19,6 +19,7 @@ from twisted.web.server import Request
from synapse.api.constants import Membership
from synapse.api.errors import SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_json_object_from_request,
@@ -103,5 +104,5 @@ class KnockRoomAliasServlet(RestServlet):
)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
KnockRoomAliasServlet(hs).register(http_server)
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index 0c8d8967..4be502a7 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -1,4 +1,4 @@
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
import logging
import re
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple
from typing_extensions import TypedDict
@@ -104,7 +104,13 @@ class LoginRestServlet(RestServlet):
burst_count=self.hs.config.rc_login_account.burst_count,
)
- def on_GET(self, request: SynapseRequest):
+ # ensure the CAS/SAML/OIDC handlers are loaded on this worker instance.
+ # The reason for this is to ensure that the auth_provider_ids are registered
+ # with SsoHandler, which in turn ensures that the login/registration prometheus
+ # counters are initialised for the auth_provider_ids.
+ _load_sso_handlers(hs)
+
+ def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
@@ -151,7 +157,7 @@ class LoginRestServlet(RestServlet):
return 200, {"flows": flows}
- async def on_POST(self, request: SynapseRequest):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:
login_submission = parse_json_object_from_request(request)
if self._msc2918_enabled:
@@ -211,7 +217,7 @@ class LoginRestServlet(RestServlet):
login_submission: JsonDict,
appservice: ApplicationService,
should_issue_refresh_token: bool = False,
- ):
+ ) -> LoginResponse:
identifier = login_submission.get("identifier")
logger.info("Got appservice login request with identifier: %r", identifier)
@@ -461,10 +467,7 @@ class RefreshTokenServlet(RestServlet):
self._clock = hs.get_clock()
self.access_token_lifetime = hs.config.access_token_lifetime
- async def on_POST(
- self,
- request: SynapseRequest,
- ):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
refresh_submission = parse_json_object_from_request(request)
assert_params_in_dict(refresh_submission, ["refresh_token"])
@@ -499,12 +502,7 @@ class SsoRedirectServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
# make sure that the relevant handlers are instantiated, so that they
# register themselves with the main SSOHandler.
- if hs.config.cas_enabled:
- hs.get_cas_handler()
- if hs.config.saml2_enabled:
- hs.get_saml_handler()
- if hs.config.oidc_enabled:
- hs.get_oidc_handler()
+ _load_sso_handlers(hs)
self._sso_handler = hs.get_sso_handler()
self._msc2858_enabled = hs.config.experimental.msc2858_enabled
self._public_baseurl = hs.config.public_baseurl
@@ -569,7 +567,7 @@ class SsoRedirectServlet(RestServlet):
class CasTicketServlet(RestServlet):
PATTERNS = client_patterns("/login/cas/ticket", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self._cas_handler = hs.get_cas_handler()
@@ -591,10 +589,26 @@ class CasTicketServlet(RestServlet):
)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
LoginRestServlet(hs).register(http_server)
if hs.config.access_token_lifetime is not None:
RefreshTokenServlet(hs).register(http_server)
SsoRedirectServlet(hs).register(http_server)
if hs.config.cas_enabled:
CasTicketServlet(hs).register(http_server)
+
+
+def _load_sso_handlers(hs: "HomeServer") -> None:
+ """Ensure that the SSO handlers are loaded, if they are enabled by configuration.
+
+ This is mostly useful to ensure that the CAS/SAML/OIDC handlers register themselves
+ with the main SsoHandler.
+
+ It's safe to call this multiple times.
+ """
+ if hs.config.cas.cas_enabled:
+ hs.get_cas_handler()
+ if hs.config.saml2.saml2_enabled:
+ hs.get_saml_handler()
+ if hs.config.oidc.oidc_enabled:
+ hs.get_oidc_handler()
diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py
index 6055cac2..193a6951 100644
--- a/synapse/rest/client/logout.py
+++ b/synapse/rest/client/logout.py
@@ -13,9 +13,16 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
+from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -23,13 +30,13 @@ logger = logging.getLogger(__name__)
class LogoutRestServlet(RestServlet):
PATTERNS = client_patterns("/logout$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True)
if requester.device_id is None:
@@ -48,13 +55,13 @@ class LogoutRestServlet(RestServlet):
class LogoutAllRestServlet(RestServlet):
PATTERNS = client_patterns("/logout/all$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string()
@@ -67,6 +74,6 @@ class LogoutAllRestServlet(RestServlet):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
LogoutRestServlet(hs).register(http_server)
LogoutAllRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py
index 0ede643c..d1d8a984 100644
--- a/synapse/rest/client/notifications.py
+++ b/synapse/rest/client/notifications.py
@@ -13,26 +13,33 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.events.utils import format_event_for_client_v2_without_room_id
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class NotificationsServlet(RestServlet):
PATTERNS = client_patterns("/notifications$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
@@ -87,5 +94,5 @@ class NotificationsServlet(RestServlet):
return 200, {"notifications": returned_push_actions, "next_token": next_token}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
NotificationsServlet(hs).register(http_server)
diff --git a/synapse/rest/client/openid.py b/synapse/rest/client/openid.py
index e8d26738..4dda6dce 100644
--- a/synapse/rest/client/openid.py
+++ b/synapse/rest/client/openid.py
@@ -12,15 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from synapse.util.stringutils import random_string
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -58,14 +64,16 @@ class IdTokenServlet(RestServlet):
EXPIRES_MS = 3600 * 1000
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.server_name = hs.config.server_name
- async def on_POST(self, request, user_id):
+ async def on_POST(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot request tokens for other users.")
@@ -90,5 +98,5 @@ class IdTokenServlet(RestServlet):
)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
IdTokenServlet(hs).register(http_server)
diff --git a/synapse/rest/client/password_policy.py b/synapse/rest/client/password_policy.py
index a83927ae..6d64efb1 100644
--- a/synapse/rest/client/password_policy.py
+++ b/synapse/rest/client/password_policy.py
@@ -13,28 +13,32 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
+from twisted.web.server import Request
+
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class PasswordPolicyServlet(RestServlet):
PATTERNS = client_patterns("/password_policy$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.policy = hs.config.password_policy
self.enabled = hs.config.password_policy_enabled
- def on_GET(self, request):
+ def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
if not self.enabled or not self.policy:
return (200, {})
@@ -53,5 +57,5 @@ class PasswordPolicyServlet(RestServlet):
return (200, policy)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
PasswordPolicyServlet(hs).register(http_server)
diff --git a/synapse/rest/client/presence.py b/synapse/rest/client/presence.py
index 6c27e5fa..94dd4fe2 100644
--- a/synapse/rest/client/presence.py
+++ b/synapse/rest/client/presence.py
@@ -15,12 +15,18 @@
""" This module contains REST servlets to do with presence: /presence/<paths>
"""
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError, SynapseError
from synapse.handlers.presence import format_user_presence_state
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -28,7 +34,7 @@ logger = logging.getLogger(__name__)
class PresenceStatusRestServlet(RestServlet):
PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.presence_handler = hs.get_presence_handler()
@@ -37,7 +43,9 @@ class PresenceStatusRestServlet(RestServlet):
self._use_presence = hs.config.server.use_presence
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
@@ -53,13 +61,15 @@ class PresenceStatusRestServlet(RestServlet):
raise AuthError(403, "You are not allowed to see their presence.")
state = await self.presence_handler.get_state(target_user=user)
- state = format_user_presence_state(
+ result = format_user_presence_state(
state, self.clock.time_msec(), include_user_id=False
)
- return 200, state
+ return 200, result
- async def on_PUT(self, request, user_id):
+ async def on_PUT(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
@@ -91,5 +101,5 @@ class PresenceStatusRestServlet(RestServlet):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
PresenceStatusRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py
index 5463ed2c..d0f20de5 100644
--- a/synapse/rest/client/profile.py
+++ b/synapse/rest/client/profile.py
@@ -14,22 +14,31 @@
""" This module contains REST servlets to do with profile: /profile/<paths> """
+from typing import TYPE_CHECKING, Tuple
+
from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
class ProfileDisplaynameRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester_user = None
if self.hs.config.require_auth_for_profile_requests:
@@ -48,7 +57,9 @@ class ProfileDisplaynameRestServlet(RestServlet):
return 200, ret
- async def on_PUT(self, request, user_id):
+ async def on_PUT(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id)
is_admin = await self.auth.is_server_admin(requester.user)
@@ -72,13 +83,15 @@ class ProfileDisplaynameRestServlet(RestServlet):
class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester_user = None
if self.hs.config.require_auth_for_profile_requests:
@@ -97,7 +110,9 @@ class ProfileAvatarURLRestServlet(RestServlet):
return 200, ret
- async def on_PUT(self, request, user_id):
+ async def on_PUT(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
is_admin = await self.auth.is_server_admin(requester.user)
@@ -120,13 +135,15 @@ class ProfileAvatarURLRestServlet(RestServlet):
class ProfileRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester_user = None
if self.hs.config.require_auth_for_profile_requests:
@@ -149,7 +166,7 @@ class ProfileRestServlet(RestServlet):
return 200, ret
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ProfileDisplaynameRestServlet(hs).register(http_server)
ProfileAvatarURLRestServlet(hs).register(http_server)
ProfileRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py
index 84619c5e..98604a93 100644
--- a/synapse/rest/client/pusher.py
+++ b/synapse/rest/client/pusher.py
@@ -13,17 +13,23 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, StoreError, SynapseError
-from synapse.http.server import respond_with_html_bytes
+from synapse.http.server import HttpServer, respond_with_html_bytes
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.push import PusherConfigException
from synapse.rest.client._base import client_patterns
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -31,12 +37,12 @@ logger = logging.getLogger(__name__)
class PushersRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = requester.user
@@ -50,14 +56,14 @@ class PushersRestServlet(RestServlet):
class PushersSetRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers/set$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = requester.user
@@ -132,14 +138,14 @@ class PushersRemoveRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers/remove$", v1=True)
SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>"
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.notifier = hs.get_notifier()
self.auth = hs.get_auth()
self.pusher_pool = self.hs.get_pusherpool()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request, rights="delete_pusher")
user = requester.user
@@ -165,7 +171,7 @@ class PushersRemoveRestServlet(RestServlet):
return None
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
PushersRestServlet(hs).register(http_server)
PushersSetRestServlet(hs).register(http_server)
PushersRemoveRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py
index 027f8b81..43c04fac 100644
--- a/synapse/rest/client/read_marker.py
+++ b/synapse/rest/client/read_marker.py
@@ -13,27 +13,36 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import ReadReceiptEventFields
from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class ReadMarkerRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.receipts_handler = hs.get_receipts_handler()
self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler()
- async def on_POST(self, request, room_id):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await self.presence_handler.bump_presence_active_time(requester.user)
@@ -70,5 +79,5 @@ class ReadMarkerRestServlet(RestServlet):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReadMarkerRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 58b8e8f2..7b5f49d6 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import hmac
import logging
import random
from typing import List, Union
@@ -28,6 +27,7 @@ from synapse.api.errors import (
ThreepidValidationError,
UnrecognizedRequestError,
)
+from synapse.api.ratelimiting import Ratelimiter
from synapse.config import ConfigError
from synapse.config.captcha import CaptchaConfig
from synapse.config.consent import ConsentConfig
@@ -59,18 +59,6 @@ from synapse.util.threepids import (
from ._base import client_patterns, interactive_auth_handler
-# We ought to be using hmac.compare_digest() but on older pythons it doesn't
-# exist. It's a _really minor_ security flaw to use plain string comparison
-# because the timing attack is so obscured by all the other code here it's
-# unlikely to make much difference
-if hasattr(hmac, "compare_digest"):
- compare_digest = hmac.compare_digest
-else:
-
- def compare_digest(a, b):
- return a == b
-
-
logger = logging.getLogger(__name__)
@@ -379,6 +367,55 @@ class UsernameAvailabilityRestServlet(RestServlet):
return 200, {"available": True}
+class RegistrationTokenValidityRestServlet(RestServlet):
+ """Check the validity of a registration token.
+
+ Example:
+
+ GET /_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity?token=abcd
+
+ 200 OK
+
+ {
+ "valid": true
+ }
+ """
+
+ PATTERNS = client_patterns(
+ f"/org.matrix.msc3231/register/{LoginType.REGISTRATION_TOKEN}/validity",
+ releases=(),
+ unstable=True,
+ )
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super().__init__()
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.ratelimiter = Ratelimiter(
+ store=self.store,
+ clock=hs.get_clock(),
+ rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second,
+ burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count,
+ )
+
+ async def on_GET(self, request):
+ await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
+
+ if not self.hs.config.enable_registration:
+ raise SynapseError(
+ 403, "Registration has been disabled", errcode=Codes.FORBIDDEN
+ )
+
+ token = parse_string(request, "token", required=True)
+ valid = await self.store.registration_token_is_valid(token)
+
+ return 200, {"valid": valid}
+
+
class RegisterRestServlet(RestServlet):
PATTERNS = client_patterns("/register$")
@@ -686,6 +723,22 @@ class RegisterRestServlet(RestServlet):
)
if registered:
+ # Check if a token was used to authenticate registration
+ registration_token = await self.auth_handler.get_session_data(
+ session_id,
+ UIAuthSessionDataConstants.REGISTRATION_TOKEN,
+ )
+ if registration_token:
+ # Increment the `completed` counter for the token
+ await self.store.use_registration_token(registration_token)
+ # Indicate that the token has been successfully used so that
+ # pending is not decremented again when expiring old UIA sessions.
+ await self.store.mark_ui_auth_stage_complete(
+ session_id,
+ LoginType.REGISTRATION_TOKEN,
+ True,
+ )
+
await self.registration_handler.post_registration_actions(
user_id=registered_user_id,
auth_result=auth_result,
@@ -868,6 +921,11 @@ def _calculate_registration_flows(
for flow in flows:
flow.insert(0, LoginType.RECAPTCHA)
+ # Prepend registration token to all flows if we're requiring a token
+ if config.registration_requires_token:
+ for flow in flows:
+ flow.insert(0, LoginType.REGISTRATION_TOKEN)
+
return flows
@@ -876,4 +934,5 @@ def register_servlets(hs, http_server):
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
UsernameAvailabilityRestServlet(hs).register(http_server)
RegistrationSubmitTokenServlet(hs).register(http_server)
+ RegistrationTokenValidityRestServlet(hs).register(http_server)
RegisterRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/room_upgrade_rest_servlet.py b/synapse/rest/client/room_upgrade_rest_servlet.py
index 6d1b083a..6a7792e1 100644
--- a/synapse/rest/client/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/room_upgrade_rest_servlet.py
@@ -13,18 +13,25 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, ShadowBanError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from synapse.util import stringutils
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -41,9 +48,6 @@ class RoomUpgradeRestServlet(RestServlet):
}
Creates a new room and shuts down the old one. Returns the ID of the new room.
-
- Args:
- hs (synapse.server.HomeServer):
"""
PATTERNS = client_patterns(
@@ -51,13 +55,15 @@ class RoomUpgradeRestServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)/upgrade$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self._hs = hs
self._room_creation_handler = hs.get_room_creation_handler()
self._auth = hs.get_auth()
- async def on_POST(self, request, room_id):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
@@ -84,5 +90,5 @@ class RoomUpgradeRestServlet(RestServlet):
return 200, ret
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomUpgradeRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/shared_rooms.py
index d2e7f04b..1d90493e 100644
--- a/synapse/rest/client/shared_rooms.py
+++ b/synapse/rest/client/shared_rooms.py
@@ -12,13 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
-from synapse.types import UserID
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict, UserID
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -32,13 +38,15 @@ class UserSharedRoomsServlet(RestServlet):
releases=(), # This is an unstable feature
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.user_directory_active = hs.config.update_user_directory
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
if not self.user_directory_active:
raise SynapseError(
@@ -63,5 +71,5 @@ class UserSharedRoomsServlet(RestServlet):
return 200, {"joined": list(rooms)}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserSharedRoomsServlet(hs).register(http_server)
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index e18f4d01..65c37be3 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -14,17 +14,26 @@
import itertools
import logging
from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from synapse.api.constants import Membership, PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
+from synapse.api.presence import UserPresenceState
from synapse.events.utils import (
format_event_for_client_v2_without_room_id,
format_event_raw,
)
from synapse.handlers.presence import format_user_presence_state
-from synapse.handlers.sync import KnockedSyncResult, SyncConfig
+from synapse.handlers.sync import (
+ ArchivedSyncResult,
+ InvitedSyncResult,
+ JoinedSyncResult,
+ KnockedSyncResult,
+ SyncConfig,
+ SyncResult,
+)
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, StreamToken
@@ -192,6 +201,8 @@ class SyncRestServlet(RestServlet):
return 200, {}
time_now = self.clock.time_msec()
+ # We know that the the requester has an access token since appservices
+ # cannot use sync.
response_content = await self.encode_response(
time_now, sync_result, requester.access_token_id, filter_collection
)
@@ -199,7 +210,13 @@ class SyncRestServlet(RestServlet):
logger.debug("Event formatting complete")
return 200, response_content
- async def encode_response(self, time_now, sync_result, access_token_id, filter):
+ async def encode_response(
+ self,
+ time_now: int,
+ sync_result: SyncResult,
+ access_token_id: Optional[int],
+ filter: FilterCollection,
+ ) -> JsonDict:
logger.debug("Formatting events in sync response")
if filter.event_format == "client":
event_formatter = format_event_for_client_v2_without_room_id
@@ -234,7 +251,7 @@ class SyncRestServlet(RestServlet):
logger.debug("building sync response dict")
- response: dict = defaultdict(dict)
+ response: JsonDict = defaultdict(dict)
response["next_batch"] = await sync_result.next_batch.to_string(self.store)
if sync_result.account_data:
@@ -274,6 +291,8 @@ class SyncRestServlet(RestServlet):
if archived:
response["rooms"][Membership.LEAVE] = archived
+ # By the time we get here groups is no longer optional.
+ assert sync_result.groups is not None
if sync_result.groups.join:
response["groups"][Membership.JOIN] = sync_result.groups.join
if sync_result.groups.invite:
@@ -284,7 +303,7 @@ class SyncRestServlet(RestServlet):
return response
@staticmethod
- def encode_presence(events, time_now):
+ def encode_presence(events: List[UserPresenceState], time_now: int) -> JsonDict:
return {
"events": [
{
@@ -299,25 +318,27 @@ class SyncRestServlet(RestServlet):
}
async def encode_joined(
- self, rooms, time_now, token_id, event_fields, event_formatter
- ):
+ self,
+ rooms: List[JoinedSyncResult],
+ time_now: int,
+ token_id: Optional[int],
+ event_fields: List[str],
+ event_formatter: Callable[[JsonDict], JsonDict],
+ ) -> JsonDict:
"""
Encode the joined rooms in a sync result
Args:
- rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync
- results for rooms this user is joined to
- time_now(int): current time - used as a baseline for age
- calculations
- token_id(int): ID of the user's auth token - used for namespacing
+ rooms: list of sync results for rooms this user is joined to
+ time_now: current time - used as a baseline for age calculations
+ token_id: ID of the user's auth token - used for namespacing
of transaction IDs
- event_fields(list<str>): List of event fields to include. If empty,
+ event_fields: List of event fields to include. If empty,
all fields will be returned.
- event_formatter (func[dict]): function to convert from federation format
+ event_formatter: function to convert from federation format
to client format
Returns:
- dict[str, dict[str, object]]: the joined rooms list, in our
- response format
+ The joined rooms list, in our response format
"""
joined = {}
for room in rooms:
@@ -332,23 +353,26 @@ class SyncRestServlet(RestServlet):
return joined
- async def encode_invited(self, rooms, time_now, token_id, event_formatter):
+ async def encode_invited(
+ self,
+ rooms: List[InvitedSyncResult],
+ time_now: int,
+ token_id: Optional[int],
+ event_formatter: Callable[[JsonDict], JsonDict],
+ ) -> JsonDict:
"""
Encode the invited rooms in a sync result
Args:
- rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of
- sync results for rooms this user is invited to
- time_now(int): current time - used as a baseline for age
- calculations
- token_id(int): ID of the user's auth token - used for namespacing
+ rooms: list of sync results for rooms this user is invited to
+ time_now: current time - used as a baseline for age calculations
+ token_id: ID of the user's auth token - used for namespacing
of transaction IDs
- event_formatter (func[dict]): function to convert from federation format
+ event_formatter: function to convert from federation format
to client format
Returns:
- dict[str, dict[str, object]]: the invited rooms list, in our
- response format
+ The invited rooms list, in our response format
"""
invited = {}
for room in rooms:
@@ -371,7 +395,7 @@ class SyncRestServlet(RestServlet):
self,
rooms: List[KnockedSyncResult],
time_now: int,
- token_id: int,
+ token_id: Optional[int],
event_formatter: Callable[[Dict], Dict],
) -> Dict[str, Dict[str, Any]]:
"""
@@ -422,25 +446,26 @@ class SyncRestServlet(RestServlet):
return knocked
async def encode_archived(
- self, rooms, time_now, token_id, event_fields, event_formatter
- ):
+ self,
+ rooms: List[ArchivedSyncResult],
+ time_now: int,
+ token_id: Optional[int],
+ event_fields: List[str],
+ event_formatter: Callable[[JsonDict], JsonDict],
+ ) -> JsonDict:
"""
Encode the archived rooms in a sync result
Args:
- rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of
- sync results for rooms this user is joined to
- time_now(int): current time - used as a baseline for age
- calculations
- token_id(int): ID of the user's auth token - used for namespacing
+ rooms: list of sync results for rooms this user is joined to
+ time_now: current time - used as a baseline for age calculations
+ token_id: ID of the user's auth token - used for namespacing
of transaction IDs
- event_fields(list<str>): List of event fields to include. If empty,
+ event_fields: List of event fields to include. If empty,
all fields will be returned.
- event_formatter (func[dict]): function to convert from federation format
- to client format
+ event_formatter: function to convert from federation format to client format
Returns:
- dict[str, dict[str, object]]: The invited rooms list, in our
- response format
+ The archived rooms list, in our response format
"""
joined = {}
for room in rooms:
@@ -456,23 +481,27 @@ class SyncRestServlet(RestServlet):
return joined
async def encode_room(
- self, room, time_now, token_id, joined, only_fields, event_formatter
- ):
+ self,
+ room: Union[JoinedSyncResult, ArchivedSyncResult],
+ time_now: int,
+ token_id: Optional[int],
+ joined: bool,
+ only_fields: Optional[List[str]],
+ event_formatter: Callable[[JsonDict], JsonDict],
+ ) -> JsonDict:
"""
Args:
- room (JoinedSyncResult|ArchivedSyncResult): sync result for a
- single room
- time_now (int): current time - used as a baseline for age
- calculations
- token_id (int): ID of the user's auth token - used for namespacing
+ room: sync result for a single room
+ time_now: current time - used as a baseline for age calculations
+ token_id: ID of the user's auth token - used for namespacing
of transaction IDs
- joined (bool): True if the user is joined to this room - will mean
+ joined: True if the user is joined to this room - will mean
we handle ephemeral events
- only_fields(list<str>): Optional. The list of event fields to include.
- event_formatter (func[dict]): function to convert from federation format
+ only_fields: Optional. The list of event fields to include.
+ event_formatter: function to convert from federation format
to client format
Returns:
- dict[str, object]: the room, encoded in our response format
+ The room, encoded in our response format
"""
def serialize(events):
@@ -508,7 +537,7 @@ class SyncRestServlet(RestServlet):
account_data = room.account_data
- result = {
+ result: JsonDict = {
"timeline": {
"events": serialized_timeline,
"prev_batch": await room.timeline.prev_batch.to_string(self.store),
@@ -519,6 +548,7 @@ class SyncRestServlet(RestServlet):
}
if joined:
+ assert isinstance(room, JoinedSyncResult)
ephemeral_events = room.ephemeral
result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications
@@ -528,5 +558,5 @@ class SyncRestServlet(RestServlet):
return result
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/tags.py b/synapse/rest/client/tags.py
index c14f83be..c88cb936 100644
--- a/synapse/rest/client/tags.py
+++ b/synapse/rest/client/tags.py
@@ -13,12 +13,19 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -29,12 +36,14 @@ class TagListServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, user_id, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get tags for other users.")
@@ -54,12 +63,14 @@ class TagServlet(RestServlet):
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.handler = hs.get_account_data_handler()
- async def on_PUT(self, request, user_id, room_id, tag):
+ async def on_PUT(
+ self, request: SynapseRequest, user_id: str, room_id: str, tag: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
@@ -70,7 +81,9 @@ class TagServlet(RestServlet):
return 200, {}
- async def on_DELETE(self, request, user_id, room_id, tag):
+ async def on_DELETE(
+ self, request: SynapseRequest, user_id: str, room_id: str, tag: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
@@ -80,6 +93,6 @@ class TagServlet(RestServlet):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
TagListServlet(hs).register(http_server)
TagServlet(hs).register(http_server)
diff --git a/synapse/rest/client/thirdparty.py b/synapse/rest/client/thirdparty.py
index b5c67c9b..b895c73a 100644
--- a/synapse/rest/client/thirdparty.py
+++ b/synapse/rest/client/thirdparty.py
@@ -12,27 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
+from typing import TYPE_CHECKING, Dict, List, Tuple
from synapse.api.constants import ThirdPartyEntityKind
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class ThirdPartyProtocolsServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/protocols")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
protocols = await self.appservice_handler.get_3pe_protocols()
@@ -42,13 +48,15 @@ class ThirdPartyProtocolsServlet(RestServlet):
class ThirdPartyProtocolServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- async def on_GET(self, request, protocol):
+ async def on_GET(
+ self, request: SynapseRequest, protocol: str
+ ) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
protocols = await self.appservice_handler.get_3pe_protocols(
@@ -63,16 +71,18 @@ class ThirdPartyProtocolServlet(RestServlet):
class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- async def on_GET(self, request, protocol):
+ async def on_GET(
+ self, request: SynapseRequest, protocol: str
+ ) -> Tuple[int, List[JsonDict]]:
await self.auth.get_user_by_req(request, allow_guest=True)
- fields = request.args
+ fields: Dict[bytes, List[bytes]] = request.args # type: ignore[assignment]
fields.pop(b"access_token", None)
results = await self.appservice_handler.query_3pe(
@@ -85,16 +95,18 @@ class ThirdPartyUserServlet(RestServlet):
class ThirdPartyLocationServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- async def on_GET(self, request, protocol):
+ async def on_GET(
+ self, request: SynapseRequest, protocol: str
+ ) -> Tuple[int, List[JsonDict]]:
await self.auth.get_user_by_req(request, allow_guest=True)
- fields = request.args
+ fields: Dict[bytes, List[bytes]] = request.args # type: ignore[assignment]
fields.pop(b"access_token", None)
results = await self.appservice_handler.query_3pe(
@@ -104,7 +116,7 @@ class ThirdPartyLocationServlet(RestServlet):
return 200, results
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ThirdPartyProtocolsServlet(hs).register(http_server)
ThirdPartyProtocolServlet(hs).register(http_server)
ThirdPartyUserServlet(hs).register(http_server)
diff --git a/synapse/rest/client/tokenrefresh.py b/synapse/rest/client/tokenrefresh.py
index b2f85854..c8c3b25b 100644
--- a/synapse/rest/client/tokenrefresh.py
+++ b/synapse/rest/client/tokenrefresh.py
@@ -12,11 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
+
+from twisted.web.server import Request
+
from synapse.api.errors import AuthError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
class TokenRefreshRestServlet(RestServlet):
"""
@@ -26,12 +34,12 @@ class TokenRefreshRestServlet(RestServlet):
PATTERNS = client_patterns("/tokenrefresh")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
- async def on_POST(self, request):
+ async def on_POST(self, request: Request) -> None:
raise AuthError(403, "tokenrefresh is no longer supported.")
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
TokenRefreshRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py
index 7e8912f0..88528111 100644
--- a/synapse/rest/client/user_directory.py
+++ b/synapse/rest/client/user_directory.py
@@ -13,29 +13,32 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class UserDirectorySearchRestServlet(RestServlet):
PATTERNS = client_patterns("/user_directory/search$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""Searches for users in directory
Returns:
@@ -75,5 +78,5 @@ class UserDirectorySearchRestServlet(RestServlet):
return 200, results
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserDirectorySearchRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index fa2e4e9c..a1a815cf 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -17,9 +17,17 @@
import logging
import re
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.server import Request
from synapse.api.constants import RoomCreationPreset
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -27,7 +35,7 @@ logger = logging.getLogger(__name__)
class VersionsRestServlet(RestServlet):
PATTERNS = [re.compile("^/_matrix/client/versions$")]
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.config = hs.config
@@ -45,7 +53,7 @@ class VersionsRestServlet(RestServlet):
in self.config.encryption_enabled_by_default_for_room_presets
)
- def on_GET(self, request):
+ def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
return (
200,
{
@@ -89,5 +97,5 @@ class VersionsRestServlet(RestServlet):
)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
VersionsRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/voip.py b/synapse/rest/client/voip.py
index f5302052..9d46ed3a 100644
--- a/synapse/rest/client/voip.py
+++ b/synapse/rest/client/voip.py
@@ -15,20 +15,27 @@
import base64
import hashlib
import hmac
+from typing import TYPE_CHECKING, Tuple
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
+from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
class VoipRestServlet(RestServlet):
PATTERNS = client_patterns("/voip/turnServer$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(
request, self.hs.config.turn_allow_guests
)
@@ -69,5 +76,5 @@ class VoipRestServlet(RestServlet):
)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
VoipRestServlet(hs).register(http_server)
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index a029d426..12bd745c 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -15,7 +15,7 @@
import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from twisted.web.server import Request
@@ -414,9 +414,9 @@ class ThumbnailResource(DirectServeJsonResource):
if desired_method == "crop":
# Thumbnails that match equal or larger sizes of desired width/height.
- crop_info_list = []
+ crop_info_list: List[Tuple[int, int, int, bool, int, Dict[str, Any]]] = []
# Other thumbnails.
- crop_info_list2 = []
+ crop_info_list2: List[Tuple[int, int, int, bool, int, Dict[str, Any]]] = []
for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
if info["thumbnail_method"] != "crop":
@@ -451,15 +451,19 @@ class ThumbnailResource(DirectServeJsonResource):
info,
)
)
+ # Pick the most appropriate thumbnail. Some values of `desired_width` and
+ # `desired_height` may result in a tie, in which case we avoid comparing on
+ # the thumbnail info dictionary and pick the thumbnail that appears earlier
+ # in the list of candidates.
if crop_info_list:
- thumbnail_info = min(crop_info_list)[-1]
+ thumbnail_info = min(crop_info_list, key=lambda t: t[:-1])[-1]
elif crop_info_list2:
- thumbnail_info = min(crop_info_list2)[-1]
+ thumbnail_info = min(crop_info_list2, key=lambda t: t[:-1])[-1]
elif desired_method == "scale":
# Thumbnails that match equal or larger sizes of desired width/height.
- info_list = []
+ info_list: List[Tuple[int, bool, int, Dict[str, Any]]] = []
# Other thumbnails.
- info_list2 = []
+ info_list2: List[Tuple[int, bool, int, Dict[str, Any]]] = []
for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
@@ -477,10 +481,14 @@ class ThumbnailResource(DirectServeJsonResource):
info_list2.append(
(size_quality, type_quality, length_quality, info)
)
+ # Pick the most appropriate thumbnail. Some values of `desired_width` and
+ # `desired_height` may result in a tie, in which case we avoid comparing on
+ # the thumbnail info dictionary and pick the thumbnail that appears earlier
+ # in the list of candidates.
if info_list:
- thumbnail_info = min(info_list)[-1]
+ thumbnail_info = min(info_list, key=lambda t: t[:-1])[-1]
elif info_list2:
- thumbnail_info = min(info_list2)[-1]
+ thumbnail_info = min(info_list2, key=lambda t: t[:-1])[-1]
if thumbnail_info:
return FileInfo(
diff --git a/synapse/server.py b/synapse/server.py
index de651766..5adeeff6 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -76,6 +76,7 @@ from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler
from synapse.handlers.event_auth import EventAuthHandler
from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.handlers.federation import FederationHandler
+from synapse.handlers.federation_event import FederationEventHandler
from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerHandler
from synapse.handlers.identity import IdentityHandler
from synapse.handlers.initial_sync import InitialSyncHandler
@@ -547,6 +548,10 @@ class HomeServer(metaclass=abc.ABCMeta):
return FederationHandler(self)
@cache_in_self
+ def get_federation_event_handler(self) -> FederationEventHandler:
+ return FederationEventHandler(self)
+
+ @cache_in_self
def get_identity_handler(self) -> IdentityHandler:
return IdentityHandler(self)
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index f19075b7..d87a5389 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -12,26 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
from synapse.events import EventBase
from synapse.types import UserID, create_requester
from synapse.util.caches.descriptors import cached
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
SERVER_NOTICE_ROOM_TAG = "m.server_notice"
class ServerNoticesManager:
- def __init__(self, hs):
- """
-
- Args:
- hs (synapse.server.HomeServer):
- """
-
+ def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastore()
self._config = hs.config
self._account_data_handler = hs.get_account_data_handler()
@@ -58,6 +55,7 @@ class ServerNoticesManager:
event_content: dict,
type: str = EventTypes.Message,
state_key: Optional[str] = None,
+ txn_id: Optional[str] = None,
) -> EventBase:
"""Send a notice to the given user
@@ -68,6 +66,7 @@ class ServerNoticesManager:
event_content: content of event to send
type: type of event
is_state_event: Is the event a state event
+ txn_id: The transaction ID.
"""
room_id = await self.get_or_create_notice_room_for_user(user_id)
await self.maybe_invite_user_to_room(user_id, room_id)
@@ -90,7 +89,7 @@ class ServerNoticesManager:
event_dict["state_key"] = state_key
event, _ = await self._event_creation_handler.create_and_send_nonmember_event(
- requester, event_dict, ratelimit=False
+ requester, event_dict, ratelimit=False, txn_id=txn_id
)
return event
diff --git a/synapse/static/client/register/style.css b/synapse/static/client/register/style.css
index 5a7b6eeb..8a39b5d0 100644
--- a/synapse/static/client/register/style.css
+++ b/synapse/static/client/register/style.css
@@ -57,4 +57,8 @@ textarea, input {
background-color: #f8f8f8;
border: 1px #ccc solid;
-} \ No newline at end of file
+}
+
+.error {
+ color: red;
+}
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 01b918e1..00a644e8 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -63,6 +63,7 @@ from .relations import RelationsStore
from .room import RoomStore
from .roommember import RoomMemberStore
from .search import SearchStore
+from .session import SessionStore
from .signatures import SignatureStore
from .state import StateStore
from .stats import StatsStore
@@ -121,6 +122,7 @@ class DataStore(
ServerMetricsStore,
EventForwardExtremitiesStore,
LockStore,
+ SessionStore,
):
def __init__(self, database: DatabasePool, db_conn, hs):
self.hs = hs
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 375463e4..9501f00f 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -520,16 +520,26 @@ class EventsWorkerStore(SQLBaseStore):
# We now look up if we're already fetching some of the events in the DB,
# if so we wait for those lookups to finish instead of pulling the same
# events out of the DB multiple times.
- already_fetching: Dict[str, defer.Deferred] = {}
+ #
+ # Note: we might get the same `ObservableDeferred` back for multiple
+ # events we're already fetching, so we deduplicate the deferreds to
+ # avoid extraneous work (if we don't do this we can end up in a n^2 mode
+ # when we wait on the same Deferred N times, then try and merge the
+ # same dict into itself N times).
+ already_fetching_ids: Set[str] = set()
+ already_fetching_deferreds: Set[
+ ObservableDeferred[Dict[str, _EventCacheEntry]]
+ ] = set()
for event_id in missing_events_ids:
deferred = self._current_event_fetches.get(event_id)
if deferred is not None:
# We're already pulling the event out of the DB. Add the deferred
# to the collection of deferreds to wait on.
- already_fetching[event_id] = deferred.observe()
+ already_fetching_ids.add(event_id)
+ already_fetching_deferreds.add(deferred)
- missing_events_ids.difference_update(already_fetching)
+ missing_events_ids.difference_update(already_fetching_ids)
if missing_events_ids:
log_ctx = current_context()
@@ -569,18 +579,25 @@ class EventsWorkerStore(SQLBaseStore):
with PreserveLoggingContext():
fetching_deferred.callback(missing_events)
- if already_fetching:
+ if already_fetching_deferreds:
# Wait for the other event requests to finish and add their results
# to ours.
results = await make_deferred_yieldable(
defer.gatherResults(
- already_fetching.values(),
+ (d.observe() for d in already_fetching_deferreds),
consumeErrors=True,
)
).addErrback(unwrapFirstError)
for result in results:
- event_entry_map.update(result)
+ # We filter out events that we haven't asked for as we might get
+ # a *lot* of superfluous events back, and there is no point
+ # going through and inserting them all (which can take time).
+ event_entry_map.update(
+ (event_id, entry)
+ for event_id, entry in result.items()
+ if event_id in already_fetching_ids
+ )
if not allow_rejected:
event_entry_map = {
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 664c65da..bccff5e5 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -295,6 +295,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
self._invalidate_cache_and_stream(
txn, self.have_seen_event, (room_id, event_id)
)
+ self._invalidate_get_event_cache(event_id)
logger.info("[purge] done")
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index b48fe086..63ac09c6 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -48,6 +48,11 @@ class PusherWorkerStore(SQLBaseStore):
self._remove_stale_pushers,
)
+ self.db_pool.updates.register_background_update_handler(
+ "remove_deleted_email_pushers",
+ self._remove_deleted_email_pushers,
+ )
+
def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
"""JSON-decode the data in the rows returned from the `pushers` table
@@ -388,6 +393,74 @@ class PusherWorkerStore(SQLBaseStore):
return number_deleted
+ async def _remove_deleted_email_pushers(
+ self, progress: dict, batch_size: int
+ ) -> int:
+ """A background update that deletes all pushers for deleted email addresses.
+
+ In previous versions of synapse, when users deleted their email address, it didn't
+ also delete all the pushers for that email address. This background update removes
+ those to prevent unwanted emails. This should only need to be run once (when users
+ upgrade to v1.42.0
+
+ Args:
+ progress: dict used to store progress of this background update
+ batch_size: the maximum number of rows to retrieve in a single select query
+
+ Returns:
+ The number of deleted rows
+ """
+
+ last_pusher = progress.get("last_pusher", 0)
+
+ def _delete_pushers(txn) -> int:
+
+ sql = """
+ SELECT p.id, p.user_name, p.app_id, p.pushkey
+ FROM pushers AS p
+ LEFT JOIN user_threepids AS t
+ ON t.user_id = p.user_name
+ AND t.medium = 'email'
+ AND t.address = p.pushkey
+ WHERE t.user_id is NULL
+ AND p.app_id = 'm.email'
+ AND p.id > ?
+ ORDER BY p.id ASC
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_pusher, batch_size))
+ rows = txn.fetchall()
+
+ last = None
+ num_deleted = 0
+ for row in rows:
+ last = row[0]
+ num_deleted += 1
+ self.db_pool.simple_delete_txn(
+ txn,
+ "pushers",
+ {"user_name": row[1], "app_id": row[2], "pushkey": row[3]},
+ )
+
+ if last is not None:
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "remove_deleted_email_pushers", {"last_pusher": last}
+ )
+
+ return num_deleted
+
+ number_deleted = await self.db_pool.runInteraction(
+ "_remove_deleted_email_pushers", _delete_pushers
+ )
+
+ if number_deleted < batch_size:
+ await self.db_pool.updates._end_background_update(
+ "remove_deleted_email_pushers"
+ )
+
+ return number_deleted
+
class PusherStore(PusherWorkerStore):
def get_pushers_stream_token(self) -> int:
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index c67bea81..a6517962 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -754,16 +754,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
return user_id
- def get_user_id_by_threepid_txn(self, txn, medium, address):
+ def get_user_id_by_threepid_txn(
+ self, txn, medium: str, address: str
+ ) -> Optional[str]:
"""Returns user id from threepid
Args:
txn (cursor):
- medium (str): threepid medium e.g. email
- address (str): threepid address e.g. me@example.com
+ medium: threepid medium e.g. email
+ address: threepid address e.g. me@example.com
Returns:
- str|None: user id or None if no user id/threepid mapping exists
+ user id, or None if no user id/threepid mapping exists
"""
ret = self.db_pool.simple_select_one_txn(
txn,
@@ -776,14 +778,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return ret["user_id"]
return None
- async def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
+ async def user_add_threepid(
+ self,
+ user_id: str,
+ medium: str,
+ address: str,
+ validated_at: int,
+ added_at: int,
+ ) -> None:
await self.db_pool.simple_upsert(
"user_threepids",
{"medium": medium, "address": address},
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
)
- async def user_get_threepids(self, user_id):
+ async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]:
return await self.db_pool.simple_select_list(
"user_threepids",
{"user_id": user_id},
@@ -791,7 +800,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"user_get_threepids",
)
- async def user_delete_threepid(self, user_id, medium, address) -> None:
+ async def user_delete_threepid(
+ self, user_id: str, medium: str, address: str
+ ) -> None:
await self.db_pool.simple_delete(
"user_threepids",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
@@ -1157,6 +1168,322 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="update_access_token_last_validated",
)
+ async def registration_token_is_valid(self, token: str) -> bool:
+ """Checks if a token can be used to authenticate a registration.
+
+ Args:
+ token: The registration token to be checked
+ Returns:
+ True if the token is valid, False otherwise.
+ """
+ res = await self.db_pool.simple_select_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["uses_allowed", "pending", "completed", "expiry_time"],
+ allow_none=True,
+ )
+
+ # Check if the token exists
+ if res is None:
+ return False
+
+ # Check if the token has expired
+ now = self._clock.time_msec()
+ if res["expiry_time"] and res["expiry_time"] < now:
+ return False
+
+ # Check if the token has been used up
+ if (
+ res["uses_allowed"]
+ and res["pending"] + res["completed"] >= res["uses_allowed"]
+ ):
+ return False
+
+ # Otherwise, the token is valid
+ return True
+
+ async def set_registration_token_pending(self, token: str) -> None:
+ """Increment the pending registrations counter for a token.
+
+ Args:
+ token: The registration token pending use
+ """
+
+ def _set_registration_token_pending_txn(txn):
+ pending = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcol="pending",
+ )
+ self.db_pool.simple_update_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ updatevalues={"pending": pending + 1},
+ )
+
+ return await self.db_pool.runInteraction(
+ "set_registration_token_pending", _set_registration_token_pending_txn
+ )
+
+ async def use_registration_token(self, token: str) -> None:
+ """Complete a use of the given registration token.
+
+ The `pending` counter will be decremented, and the `completed`
+ counter will be incremented.
+
+ Args:
+ token: The registration token to be 'used'
+ """
+
+ def _use_registration_token_txn(txn):
+ # Normally, res is Optional[Dict[str, Any]].
+ # Override type because the return type is only optional if
+ # allow_none is True, and we don't want mypy throwing errors
+ # about None not being indexable.
+ res: Dict[str, Any] = self.db_pool.simple_select_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["pending", "completed"],
+ ) # type: ignore
+
+ # Decrement pending and increment completed
+ self.db_pool.simple_update_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ updatevalues={
+ "completed": res["completed"] + 1,
+ "pending": res["pending"] - 1,
+ },
+ )
+
+ return await self.db_pool.runInteraction(
+ "use_registration_token", _use_registration_token_txn
+ )
+
+ async def get_registration_tokens(
+ self, valid: Optional[bool] = None
+ ) -> List[Dict[str, Any]]:
+ """List all registration tokens. Used by the admin API.
+
+ Args:
+ valid: If True, only valid tokens are returned.
+ If False, only invalid tokens are returned.
+ Default is None: return all tokens regardless of validity.
+
+ Returns:
+ A list of dicts, each containing details of a token.
+ """
+
+ def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]):
+ if valid is None:
+ # Return all tokens regardless of validity
+ txn.execute("SELECT * FROM registration_tokens")
+
+ elif valid:
+ # Select valid tokens only
+ sql = (
+ "SELECT * FROM registration_tokens WHERE "
+ "(uses_allowed > pending + completed OR uses_allowed IS NULL) "
+ "AND (expiry_time > ? OR expiry_time IS NULL)"
+ )
+ txn.execute(sql, [now])
+
+ else:
+ # Select invalid tokens only
+ sql = (
+ "SELECT * FROM registration_tokens WHERE "
+ "uses_allowed <= pending + completed OR expiry_time <= ?"
+ )
+ txn.execute(sql, [now])
+
+ return self.db_pool.cursor_to_dict(txn)
+
+ return await self.db_pool.runInteraction(
+ "select_registration_tokens",
+ select_registration_tokens_txn,
+ self._clock.time_msec(),
+ valid,
+ )
+
+ async def get_one_registration_token(self, token: str) -> Optional[Dict[str, Any]]:
+ """Get info about the given registration token. Used by the admin API.
+
+ Args:
+ token: The token to retrieve information about.
+
+ Returns:
+ A dict, or None if token doesn't exist.
+ """
+ return await self.db_pool.simple_select_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
+ allow_none=True,
+ desc="get_one_registration_token",
+ )
+
+ async def generate_registration_token(
+ self, length: int, chars: str
+ ) -> Optional[str]:
+ """Generate a random registration token. Used by the admin API.
+
+ Args:
+ length: The length of the token to generate.
+ chars: A string of the characters allowed in the generated token.
+
+ Returns:
+ The generated token.
+
+ Raises:
+ SynapseError if a unique registration token could still not be
+ generated after a few tries.
+ """
+ # Make a few attempts at generating a unique token of the required
+ # length before failing.
+ for _i in range(3):
+ # Generate token
+ token = "".join(random.choices(chars, k=length))
+
+ # Check if the token already exists
+ existing_token = await self.db_pool.simple_select_one_onecol(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcol="token",
+ allow_none=True,
+ desc="check_if_registration_token_exists",
+ )
+
+ if existing_token is None:
+ # The generated token doesn't exist yet, return it
+ return token
+
+ raise SynapseError(
+ 500,
+ "Unable to generate a unique registration token. Try again with a greater length",
+ Codes.UNKNOWN,
+ )
+
+ async def create_registration_token(
+ self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int]
+ ) -> bool:
+ """Create a new registration token. Used by the admin API.
+
+ Args:
+ token: The token to create.
+ uses_allowed: The number of times the token can be used to complete
+ a registration before it becomes invalid. A value of None indicates
+ unlimited uses.
+ expiry_time: The latest time the token is valid. Given as the
+ number of milliseconds since 1970-01-01 00:00:00 UTC. A value of
+ None indicates that the token does not expire.
+
+ Returns:
+ Whether the row was inserted or not.
+ """
+
+ def _create_registration_token_txn(txn):
+ row = self.db_pool.simple_select_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["token"],
+ allow_none=True,
+ )
+
+ if row is not None:
+ # Token already exists
+ return False
+
+ self.db_pool.simple_insert_txn(
+ txn,
+ "registration_tokens",
+ values={
+ "token": token,
+ "uses_allowed": uses_allowed,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": expiry_time,
+ },
+ )
+
+ return True
+
+ return await self.db_pool.runInteraction(
+ "create_registration_token", _create_registration_token_txn
+ )
+
+ async def update_registration_token(
+ self, token: str, updatevalues: Dict[str, Optional[int]]
+ ) -> Optional[Dict[str, Any]]:
+ """Update a registration token. Used by the admin API.
+
+ Args:
+ token: The token to update.
+ updatevalues: A dict with the fields to update. E.g.:
+ `{"uses_allowed": 3}` to update just uses_allowed, or
+ `{"uses_allowed": 3, "expiry_time": None}` to update both.
+ This is passed straight to simple_update_one.
+
+ Returns:
+ A dict with all info about the token, or None if token doesn't exist.
+ """
+
+ def _update_registration_token_txn(txn):
+ try:
+ self.db_pool.simple_update_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ updatevalues=updatevalues,
+ )
+ except StoreError:
+ # Update failed because token does not exist
+ return None
+
+ # Get all info about the token so it can be sent in the response
+ return self.db_pool.simple_select_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=[
+ "token",
+ "uses_allowed",
+ "pending",
+ "completed",
+ "expiry_time",
+ ],
+ allow_none=True,
+ )
+
+ return await self.db_pool.runInteraction(
+ "update_registration_token", _update_registration_token_txn
+ )
+
+ async def delete_registration_token(self, token: str) -> bool:
+ """Delete a registration token. Used by the admin API.
+
+ Args:
+ token: The token to delete.
+
+ Returns:
+ Whether the token was successfully deleted or not.
+ """
+ try:
+ await self.db_pool.simple_delete_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ desc="delete_registration_token",
+ )
+ except StoreError:
+ # Deletion failed because token does not exist
+ return False
+
+ return True
+
@cached()
async def mark_access_token_as_used(self, token_id: int) -> None:
"""
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index e8157ba3..c58a4b86 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -307,7 +307,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@cached()
- async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser:
+ async def get_invited_rooms_for_local_user(
+ self, user_id: str
+ ) -> List[RoomsForUser]:
"""Get all the rooms the *local* user is invited to.
Args:
@@ -384,9 +386,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
sql = """
- SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
+ SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering, r.room_version
FROM local_current_membership AS c
INNER JOIN events AS e USING (room_id, event_id)
+ INNER JOIN rooms AS r USING (room_id)
WHERE
user_id = ?
AND %s
@@ -395,7 +398,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
txn.execute(sql, (user_id, *args))
- results = [RoomsForUser(**r) for r in self.db_pool.cursor_to_dict(txn)]
+ results = [RoomsForUser(*r) for r in txn]
return results
@@ -445,7 +448,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
Returns:
Returns the rooms the user is in currently, along with the stream
- ordering of the most recent join for that user and room.
+ ordering of the most recent join for that user and room, along with
+ the room version of the room.
"""
return await self.db_pool.runInteraction(
"get_rooms_for_user_with_stream_ordering",
@@ -522,7 +526,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
_get_users_server_still_shares_room_with_txn,
)
- async def get_rooms_for_user(self, user_id: str, on_invalidate=None):
+ async def get_rooms_for_user(
+ self, user_id: str, on_invalidate=None
+ ) -> FrozenSet[str]:
"""Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently
diff --git a/synapse/storage/databases/main/session.py b/synapse/storage/databases/main/session.py
new file mode 100644
index 00000000..172f27d1
--- /dev/null
+++ b/synapse/storage/databases/main/session.py
@@ -0,0 +1,145 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+import synapse.util.stringutils as stringutils
+from synapse.api.errors import StoreError
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
+from synapse.types import JsonDict
+from synapse.util import json_encoder
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+class SessionStore(SQLBaseStore):
+ """
+ A store for generic session data.
+
+ Each type of session should provide a unique type (to separate sessions).
+
+ Sessions are automatically removed when they expire.
+ """
+
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
+
+ # Create a background job for culling expired sessions.
+ if hs.config.run_background_tasks:
+ self._clock.looping_call(self._delete_expired_sessions, 30 * 60 * 1000)
+
+ async def create_session(
+ self, session_type: str, value: JsonDict, expiry_ms: int
+ ) -> str:
+ """
+ Creates a new pagination session for the room hierarchy endpoint.
+
+ Args:
+ session_type: The type for this session.
+ value: The value to store.
+ expiry_ms: How long before an item is evicted from the cache
+ in milliseconds. Default is 0, indicating items never get
+ evicted based on time.
+
+ Returns:
+ The newly created session ID.
+
+ Raises:
+ StoreError if a unique session ID cannot be generated.
+ """
+ # autogen a session ID and try to create it. We may clash, so just
+ # try a few times till one goes through, giving up eventually.
+ attempts = 0
+ while attempts < 5:
+ session_id = stringutils.random_string(24)
+
+ try:
+ await self.db_pool.simple_insert(
+ table="sessions",
+ values={
+ "session_id": session_id,
+ "session_type": session_type,
+ "value": json_encoder.encode(value),
+ "expiry_time_ms": self.hs.get_clock().time_msec() + expiry_ms,
+ },
+ desc="create_session",
+ )
+
+ return session_id
+ except self.db_pool.engine.module.IntegrityError:
+ attempts += 1
+ raise StoreError(500, "Couldn't generate a session ID.")
+
+ async def get_session(self, session_type: str, session_id: str) -> JsonDict:
+ """
+ Retrieve data stored with create_session
+
+ Args:
+ session_type: The type for this session.
+ session_id: The session ID returned from create_session.
+
+ Raises:
+ StoreError if the session cannot be found.
+ """
+
+ def _get_session(
+ txn: LoggingTransaction, session_type: str, session_id: str, ts: int
+ ) -> JsonDict:
+ # This includes the expiry time since items are only periodically
+ # deleted, not upon expiry.
+ select_sql = """
+ SELECT value FROM sessions WHERE
+ session_type = ? AND session_id = ? AND expiry_time_ms > ?
+ """
+ txn.execute(select_sql, [session_type, session_id, ts])
+ row = txn.fetchone()
+
+ if not row:
+ raise StoreError(404, "No session")
+
+ return db_to_json(row[0])
+
+ return await self.db_pool.runInteraction(
+ "get_session",
+ _get_session,
+ session_type,
+ session_id,
+ self._clock.time_msec(),
+ )
+
+ @wrap_as_background_process("delete_expired_sessions")
+ async def _delete_expired_sessions(self) -> None:
+ """Remove sessions with expiry dates that have passed."""
+
+ def _delete_expired_sessions_txn(txn: LoggingTransaction, ts: int) -> None:
+ sql = "DELETE FROM sessions WHERE expiry_time_ms <= ?"
+ txn.execute(sql, (ts,))
+
+ await self.db_pool.runInteraction(
+ "delete_expired_sessions",
+ _delete_expired_sessions_txn,
+ self._clock.time_msec(),
+ )
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 38bfdf5d..4d6bbc94 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import attr
+from synapse.api.constants import LoginType
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
@@ -329,6 +330,48 @@ class UIAuthWorkerStore(SQLBaseStore):
keyvalues={},
)
+ # If a registration token was used, decrement the pending counter
+ # before deleting the session.
+ rows = self.db_pool.simple_select_many_txn(
+ txn,
+ table="ui_auth_sessions_credentials",
+ column="session_id",
+ iterable=session_ids,
+ keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
+ retcols=["result"],
+ )
+
+ # Get the tokens used and how much pending needs to be decremented by.
+ token_counts: Dict[str, int] = {}
+ for r in rows:
+ # If registration was successfully completed, the result of the
+ # registration token stage for that session will be True.
+ # If a token was used to authenticate, but registration was
+ # never completed, the result will be the token used.
+ token = db_to_json(r["result"])
+ if isinstance(token, str):
+ token_counts[token] = token_counts.get(token, 0) + 1
+
+ # Update the `pending` counters.
+ if len(token_counts) > 0:
+ token_rows = self.db_pool.simple_select_many_txn(
+ txn,
+ table="registration_tokens",
+ column="token",
+ iterable=list(token_counts.keys()),
+ keyvalues={},
+ retcols=["token", "pending"],
+ )
+ for token_row in token_rows:
+ token = token_row["token"]
+ new_pending = token_row["pending"] - token_counts[token]
+ self.db_pool.simple_update_one_txn(
+ txn,
+ table="registration_tokens",
+ keyvalues={"token": token},
+ updatevalues={"pending": new_pending},
+ )
+
# Delete the corresponding completed credentials.
self.db_pool.simple_delete_many_txn(
txn,
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 9d28d69a..65dde67a 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -365,7 +365,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return False
async def update_profile_in_user_dir(
- self, user_id: str, display_name: str, avatar_url: str
+ self, user_id: str, display_name: Optional[str], avatar_url: Optional[str]
) -> None:
"""
Update or add a user's profile in the user directory.
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index c34fbf21..2500381b 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -14,25 +14,41 @@
# limitations under the License.
import logging
-from collections import namedtuple
+from typing import List, Optional, Tuple
+
+import attr
+
+from synapse.types import PersistedEventPosition
logger = logging.getLogger(__name__)
-RoomsForUser = namedtuple(
- "RoomsForUser", ("room_id", "sender", "membership", "event_id", "stream_ordering")
-)
+@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
+class RoomsForUser:
+ room_id: str
+ sender: str
+ membership: str
+ event_id: str
+ stream_ordering: int
+ room_version_id: str
+
+
+@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
+class GetRoomsForUserWithStreamOrdering:
+ room_id: str
+ event_pos: PersistedEventPosition
-GetRoomsForUserWithStreamOrdering = namedtuple(
- "GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
-)
+@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
+class ProfileInfo:
+ avatar_url: Optional[str]
+ display_name: Optional[str]
-# We store this using a namedtuple so that we save about 3x space over using a
-# dict.
-ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name"))
-# "members" points to a truncated list of (user_id, event_id) tuples for users of
-# a given membership type, suitable for use in calculating heroes for a room.
-# "count" points to the total numberr of users of a given membership type.
-MemberSummary = namedtuple("MemberSummary", ("members", "count"))
+@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
+class MemberSummary:
+ # A truncated list of (user_id, event_id) tuples for users of a given
+ # membership type, suitable for use in calculating heroes for a room.
+ members: List[Tuple[str, str]]
+ # The total number of users of a given membership type.
+ count: int
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index a5bc0ee8..af9cc699 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+# When updating these values, please leave a short summary of the changes below.
+
SCHEMA_VERSION = 63
"""Represents the expectations made by the codebase about the database schema
diff --git a/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql b/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql
new file mode 100644
index 00000000..ee6cf958
--- /dev/null
+++ b/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql
@@ -0,0 +1,23 @@
+/* Copyright 2021 Callum Brown
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS registration_tokens(
+ token TEXT NOT NULL, -- The token that can be used for authentication.
+ uses_allowed INT, -- The total number of times this token can be used. NULL if no limit.
+ pending INT NOT NULL, -- The number of in progress registrations using this token.
+ completed INT NOT NULL, -- The number of times this token has been used to complete a registration.
+ expiry_time BIGINT, -- The latest time this token will be valid (epoch time in milliseconds). NULL if token doesn't expire.
+ UNIQUE (token)
+);
diff --git a/synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql b/synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql
new file mode 100644
index 00000000..611c4b95
--- /dev/null
+++ b/synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql
@@ -0,0 +1,20 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- We may not have deleted all pushers for emails that are no longer linked
+-- to an account, so we set up a background job to delete them.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (6302, 'remove_deleted_email_pushers', '{}');
diff --git a/synapse/storage/schema/main/delta/63/03session_store.sql b/synapse/storage/schema/main/delta/63/03session_store.sql
new file mode 100644
index 00000000..535fb34c
--- /dev/null
+++ b/synapse/storage/schema/main/delta/63/03session_store.sql
@@ -0,0 +1,23 @@
+/*
+ * Copyright 2021 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS sessions(
+ session_type TEXT NOT NULL, -- The unique key for this type of session.
+ session_id TEXT NOT NULL, -- The session ID passed to the client.
+ value TEXT NOT NULL, -- A JSON dictionary to persist.
+ expiry_time_ms BIGINT NOT NULL, -- The time this session will expire (epoch time in milliseconds).
+ UNIQUE (session_type, session_id)
+);