author    Andrej Shadura <andrewsh@debian.org>  2021-06-29 12:59:58 +0200
committer Andrej Shadura <andrewsh@debian.org>  2021-06-29 12:59:58 +0200
commit    364c37238258580e132178cc7b35acabce3ff326
tree      dadc23431f7f55fd8bcd8780c64b519dae7d5a76 /synapse
parent    219af4a8aef838c5e3689a2aa71cf72f2fd75aa2
New upstream version 1.37.0
Diffstat (limited to 'synapse')
-rw-r--r--  synapse/__init__.py | 2
-rw-r--r--  synapse/api/auth.py | 11
-rw-r--r--  synapse/api/constants.py | 24
-rw-r--r--  synapse/api/errors.py | 2
-rw-r--r--  synapse/api/room_versions.py | 28
-rw-r--r--  synapse/app/_base.py | 144
-rw-r--r--  synapse/app/admin_cmd.py | 2
-rw-r--r--  synapse/app/generic_worker.py | 26
-rw-r--r--  synapse/app/homeserver.py | 68
-rw-r--r--  synapse/appservice/api.py | 11
-rw-r--r--  synapse/config/_base.py | 5
-rw-r--r--  synapse/config/_base.pyi | 5
-rw-r--r--  synapse/config/account_validity.py | 1
-rw-r--r--  synapse/config/auth.py | 4
-rw-r--r--  synapse/config/experimental.py | 3
-rw-r--r--  synapse/config/homeserver.py | 5
-rw-r--r--  synapse/config/logger.py | 4
-rw-r--r--  synapse/config/modules.py | 49
-rw-r--r--  synapse/config/repository.py | 4
-rw-r--r--  synapse/config/server.py | 27
-rw-r--r--  synapse/config/spam_checker.py | 28
-rw-r--r--  synapse/config/sso.py | 15
-rw-r--r--  synapse/config/tls.py | 151
-rw-r--r--  synapse/event_auth.py | 33
-rw-r--r--  synapse/events/__init__.py | 9
-rw-r--r--  synapse/events/builder.py | 17
-rw-r--r--  synapse/events/spamcheck.py | 307
-rw-r--r--  synapse/events/utils.py | 19
-rw-r--r--  synapse/federation/federation_client.py | 69
-rw-r--r--  synapse/federation/federation_server.py | 105
-rw-r--r--  synapse/federation/transport/client.py | 41
-rw-r--r--  synapse/federation/transport/server.py | 269
-rw-r--r--  synapse/handlers/acme.py | 117
-rw-r--r--  synapse/handlers/acme_issuing_service.py | 127
-rw-r--r--  synapse/handlers/auth.py | 7
-rw-r--r--  synapse/handlers/e2e_keys.py | 350
-rw-r--r--  synapse/handlers/event_auth.py | 45
-rw-r--r--  synapse/handlers/federation.py | 214
-rw-r--r--  synapse/handlers/message.py | 134
-rw-r--r--  synapse/handlers/register.py | 2
-rw-r--r--  synapse/handlers/room_list.py | 7
-rw-r--r--  synapse/handlers/room_member.py | 284
-rw-r--r--  synapse/handlers/room_member_worker.py | 55
-rw-r--r--  synapse/handlers/space_summary.py | 45
-rw-r--r--  synapse/handlers/sso.py | 25
-rw-r--r--  synapse/handlers/stats.py | 7
-rw-r--r--  synapse/handlers/sync.py | 123
-rw-r--r--  synapse/http/matrixfederationclient.py | 14
-rw-r--r--  synapse/http/servlet.py | 27
-rw-r--r--  synapse/logging/_terse_json.py | 9
-rw-r--r--  synapse/logging/opentracing.py | 156
-rw-r--r--  synapse/module_api/__init__.py | 30
-rw-r--r--  synapse/module_api/errors.py | 1
-rw-r--r--  synapse/python_dependencies.py | 11
-rw-r--r--  synapse/replication/http/_base.py | 11
-rw-r--r--  synapse/replication/http/membership.py | 141
-rw-r--r--  synapse/replication/tcp/handler.py | 2
-rw-r--r--  synapse/rest/__init__.py | 2
-rw-r--r--  synapse/rest/client/v1/room.py | 316
-rw-r--r--  synapse/rest/client/v2_alpha/devices.py | 6
-rw-r--r--  synapse/rest/client/v2_alpha/keys.py | 8
-rw-r--r--  synapse/rest/client/v2_alpha/knock.py | 107
-rw-r--r--  synapse/rest/client/v2_alpha/openid.py | 2
-rw-r--r--  synapse/rest/client/v2_alpha/sync.py | 82
-rw-r--r--  synapse/server.py | 44
-rw-r--r--  synapse/storage/databases/main/end_to_end_keys.py | 9
-rw-r--r--  synapse/storage/databases/main/event_federation.py | 50
-rw-r--r--  synapse/storage/databases/main/room.py | 14
-rw-r--r--  synapse/storage/databases/main/roommember.py | 2
-rw-r--r--  synapse/storage/databases/main/stats.py | 1
-rw-r--r--  synapse/storage/persist_events.py | 201
-rw-r--r--  synapse/storage/prepare_database.py | 121
-rw-r--r--  synapse/storage/schema/README.md | 37
-rw-r--r--  synapse/storage/schema/__init__.py | 19
-rw-r--r--  synapse/storage/schema/common/schema_version.sql | 7
-rw-r--r--  synapse/storage/schema/main/delta/59/11add_knock_members_to_stats.sql | 20
-rw-r--r--  synapse/types.py | 6
-rw-r--r--  synapse/util/caches/response_cache.py | 99
-rw-r--r--  synapse/util/metrics.py | 5
-rw-r--r--  synapse/util/module_loader.py | 35
80 files changed, 3178 insertions(+), 1447 deletions(-)
diff --git a/synapse/__init__.py b/synapse/__init__.py
index c3016fc6..c865d2e1 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
except ImportError:
pass
-__version__ = "1.36.0"
+__version__ = "1.37.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 26a3b389..edf1b918 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -92,11 +92,8 @@ class Auth:
async def check_from_context(
self, room_version: str, event, context, do_sig_check=True
) -> None:
- prev_state_ids = await context.get_prev_state_ids()
- auth_events_ids = self.compute_auth_events(
- event, prev_state_ids, for_verification=True
- )
- auth_events_by_id = await self.store.get_events(auth_events_ids)
+ auth_event_ids = event.auth_event_ids()
+ auth_events_by_id = await self.store.get_events(auth_event_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
@@ -207,7 +204,7 @@ class Auth:
request.requester = user_id
if user_id in self._force_tracing_for_users:
- opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
+ opentracing.force_tracing()
opentracing.set_tag("authenticated_entity", user_id)
opentracing.set_tag("user_id", user_id)
opentracing.set_tag("appservice_id", app_service.id)
@@ -260,7 +257,7 @@ class Auth:
request.requester = requester
if user_info.token_owner in self._force_tracing_for_users:
- opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
+ opentracing.force_tracing()
opentracing.set_tag("authenticated_entity", user_info.token_owner)
opentracing.set_tag("user_id", user_info.user_id)
if device_id:
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 3940da5c..414e4c01 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -65,6 +65,12 @@ class JoinRules:
MSC3083_RESTRICTED = "restricted"
+class RestrictedJoinRuleTypes:
+ """Understood types for the allow rules in restricted join rules."""
+
+ ROOM_MEMBERSHIP = "m.room_membership"
+
+
class LoginType:
PASSWORD = "m.login.password"
EMAIL_IDENTITY = "m.login.email.identity"
@@ -112,8 +118,9 @@ class EventTypes:
SpaceChild = "m.space.child"
SpaceParent = "m.space.parent"
- MSC1772_SPACE_CHILD = "org.matrix.msc1772.space.child"
- MSC1772_SPACE_PARENT = "org.matrix.msc1772.space.parent"
+
+ MSC2716_INSERTION = "org.matrix.msc2716.insertion"
+ MSC2716_MARKER = "org.matrix.msc2716.marker"
class ToDeviceEventTypes:
@@ -180,7 +187,18 @@ class EventContentFields:
# cf https://github.com/matrix-org/matrix-doc/pull/1772
ROOM_TYPE = "type"
- MSC1772_ROOM_TYPE = "org.matrix.msc1772.type"
+
+ # Used on normal messages to indicate they were historically imported after the fact
+ MSC2716_HISTORICAL = "org.matrix.msc2716.historical"
+ # For "insertion" events
+ MSC2716_NEXT_CHUNK_ID = "org.matrix.msc2716.next_chunk_id"
+ # Used on normal message events to indicate where the chunk connects to
+ MSC2716_CHUNK_ID = "org.matrix.msc2716.chunk_id"
+ # For "marker" events
+ MSC2716_MARKER_INSERTION = "org.matrix.msc2716.marker.insertion"
+ MSC2716_MARKER_INSERTION_PREV_EVENTS = (
+ "org.matrix.msc2716.marker.insertion_prev_events"
+ )
class RoomEncryptionAlgorithms:
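
Note on the new RestrictedJoinRuleTypes constant above: under MSC3083, allow rules of type ROOM_MEMBERSHIP live in the content of the room's m.room.join_rules event. A minimal sketch of such content, assuming the shape described by MSC3083 (the room ID is a hypothetical example):

from synapse.api.constants import JoinRules, RestrictedJoinRuleTypes

# Sketch of an MSC3083 m.room.join_rules event content: members of the
# listed room (typically a space) may join without an explicit invite.
join_rules_content = {
    "join_rule": JoinRules.MSC3083_RESTRICTED,
    "allow": [
        {
            "type": RestrictedJoinRuleTypes.ROOM_MEMBERSHIP,
            "room_id": "!space:example.com",  # hypothetical space room ID
        }
    ],
}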
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 0231c790..4cb8bbaf 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -449,7 +449,7 @@ class IncompatibleRoomVersionError(SynapseError):
super().__init__(
code=400,
msg="Your homeserver does not support the features required to "
- "join this room",
+ "interact with this room",
errcode=Codes.INCOMPATIBLE_ROOM_VERSION,
)
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index 373a4669..f6c1c97b 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -56,7 +56,7 @@ class RoomVersion:
state_res = attr.ib(type=int) # one of the StateResolutionVersions
enforce_key_validity = attr.ib(type=bool)
- # Before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules
+ # Before MSC2432, m.room.aliases had special auth rules and redaction rules
special_case_aliases_auth = attr.ib(type=bool)
# Strictly enforce canonicaljson, do not allow:
# * Integers outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1]
@@ -70,6 +70,9 @@ class RoomVersion:
msc2176_redaction_rules = attr.ib(type=bool)
# MSC3083: Support the 'restricted' join_rule.
msc3083_join_rules = attr.ib(type=bool)
+ # MSC2403: Allows join_rules to be set to 'knock', changes auth rules to allow sending
+ # an m.room.member event with membership 'knock'.
+ msc2403_knocking = attr.ib(type=bool)
class RoomVersions:
@@ -84,6 +87,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
+ msc2403_knocking=False,
)
V2 = RoomVersion(
"2",
@@ -96,6 +100,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
+ msc2403_knocking=False,
)
V3 = RoomVersion(
"3",
@@ -108,6 +113,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
+ msc2403_knocking=False,
)
V4 = RoomVersion(
"4",
@@ -120,6 +126,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
+ msc2403_knocking=False,
)
V5 = RoomVersion(
"5",
@@ -132,6 +139,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
+ msc2403_knocking=False,
)
V6 = RoomVersion(
"6",
@@ -144,6 +152,7 @@ class RoomVersions:
limit_notifications_power_levels=True,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
+ msc2403_knocking=False,
)
MSC2176 = RoomVersion(
"org.matrix.msc2176",
@@ -156,6 +165,7 @@ class RoomVersions:
limit_notifications_power_levels=True,
msc2176_redaction_rules=True,
msc3083_join_rules=False,
+ msc2403_knocking=False,
)
MSC3083 = RoomVersion(
"org.matrix.msc3083",
@@ -168,6 +178,20 @@ class RoomVersions:
limit_notifications_power_levels=True,
msc2176_redaction_rules=False,
msc3083_join_rules=True,
+ msc2403_knocking=False,
+ )
+ V7 = RoomVersion(
+ "7",
+ RoomDisposition.STABLE,
+ EventFormatVersions.V3,
+ StateResolutionVersions.V2,
+ enforce_key_validity=True,
+ special_case_aliases_auth=False,
+ strict_canonicaljson=True,
+ limit_notifications_power_levels=True,
+ msc2176_redaction_rules=False,
+ msc3083_join_rules=False,
+ msc2403_knocking=True,
)
@@ -182,5 +206,7 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V6,
RoomVersions.MSC2176,
RoomVersions.MSC3083,
+ RoomVersions.V7,
)
+ # Note that we do not include MSC2403 here unless it is enabled in the config.
} # type: Dict[str, RoomVersion]
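
Note on the new msc2403_knocking flag and the V7 entry above: a hedged usage sketch of gating knock handling on the room version (the helper name is hypothetical):

from synapse.api.room_versions import KNOWN_ROOM_VERSIONS

def room_version_supports_knocking(room_version_id: str) -> bool:
    """Return True if the given room version identifier (e.g. "7") allows
    the 'knock' join rule and the 'knock' membership."""
    version = KNOWN_ROOM_VERSIONS.get(room_version_id)
    return version is not None and version.msc2403_knocking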
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 1329af2e..88791368 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -26,7 +26,9 @@ from typing import Awaitable, Callable, Iterable
from cryptography.utils import CryptographyDeprecationWarning
from typing_extensions import NoReturn
+import twisted
from twisted.internet import defer, error, reactor
+from twisted.logger import LoggingFile, LogLevel
from twisted.protocols.tls import TLSMemoryBIOFactory
import synapse
@@ -35,10 +37,10 @@ from synapse.app import check_bind_error
from synapse.app.phone_stats_home import start_phone_stats_home
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
+from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.logging.context import PreserveLoggingContext
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.metrics.jemalloc import setup_jemalloc_stats
-from synapse.util.async_helpers import Linearizer
from synapse.util.daemonize import daemonize_process
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
@@ -112,8 +114,6 @@ def start_reactor(
run_command (Callable[]): callable that actually runs the reactor
"""
- install_dns_limiter(reactor)
-
def run():
logger.info("Running")
setup_jemalloc_stats()
@@ -141,7 +141,7 @@ def start_reactor(
def quit_with_error(error_string: str) -> NoReturn:
message_lines = error_string.split("\n")
- line_length = max(len(line) for line in message_lines if len(line) < 80) + 2
+ line_length = min(max(len(line) for line in message_lines), 80) + 2
sys.stderr.write("*" * line_length + "\n")
for line in message_lines:
sys.stderr.write(" %s\n" % (line.rstrip(),))
@@ -149,6 +149,30 @@ def quit_with_error(error_string: str) -> NoReturn:
sys.exit(1)
+def handle_startup_exception(e: Exception) -> NoReturn:
+ # Exceptions that occur between setting up the logging and forking or starting
+ # the reactor are written to the logs, followed by a summary to stderr.
+ logger.exception("Exception during startup")
+ quit_with_error(
+ f"Error during initialisation:\n {e}\nThere may be more information in the logs."
+ )
+
+
+def redirect_stdio_to_logs() -> None:
+ streams = [("stdout", LogLevel.info), ("stderr", LogLevel.error)]
+
+ for (stream, level) in streams:
+ oldStream = getattr(sys, stream)
+ loggingFile = LoggingFile(
+ logger=twisted.logger.Logger(namespace=stream),
+ level=level,
+ encoding=getattr(oldStream, "encoding", None),
+ )
+ setattr(sys, stream, loggingFile)
+
+ print("Redirected stdout/stderr to logs")
+
+
def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None:
"""Register a callback with the reactor, to be called once it is running
@@ -292,8 +316,7 @@ async def start(hs: "synapse.server.HomeServer"):
"""
Start a Synapse server or worker.
- Should be called once the reactor is running and (if we're using ACME) the
- TLS certificates are in place.
+ Should be called once the reactor is running.
Will start the main HTTP listeners and do some other startup tasks, and then
notify systemd.
@@ -334,6 +357,14 @@ async def start(hs: "synapse.server.HomeServer"):
# Start the tracer
synapse.logging.opentracing.init_tracer(hs) # type: ignore[attr-defined] # noqa
+ # Instantiate the modules so they can register their web resources to the module API
+ # before we start the listeners.
+ module_api = hs.get_module_api()
+ for module, config in hs.config.modules.loaded_modules:
+ module(config=config, api=module_api)
+
+ load_legacy_spam_checkers(hs)
+
# It is now safe to start your Synapse.
hs.start_listening()
hs.get_datastore().db_pool.start_profiling()
@@ -398,107 +429,6 @@ def setup_sdnotify(hs):
)
-def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
- """Replaces the resolver with one that limits the number of in flight DNS
- requests.
-
- This is to workaround https://twistedmatrix.com/trac/ticket/9620, where we
- can run out of file descriptors and infinite loop if we attempt to do too
- many DNS queries at once
-
- XXX: I'm confused by this. reactor.nameResolver does not use twisted.names unless
- you explicitly install twisted.names as the resolver; rather it uses a GAIResolver
- backed by the reactor's default threadpool (which is limited to 10 threads). So
- (a) I don't understand why twisted ticket 9620 is relevant, and (b) I don't
- understand why we would run out of FDs if we did too many lookups at once.
- -- richvdh 2020/08/29
- """
- new_resolver = _LimitedHostnameResolver(
- reactor.nameResolver, max_dns_requests_in_flight
- )
-
- reactor.installNameResolver(new_resolver)
-
-
-class _LimitedHostnameResolver:
- """Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups."""
-
- def __init__(self, resolver, max_dns_requests_in_flight):
- self._resolver = resolver
- self._limiter = Linearizer(
- name="dns_client_limiter", max_count=max_dns_requests_in_flight
- )
-
- def resolveHostName(
- self,
- resolutionReceiver,
- hostName,
- portNumber=0,
- addressTypes=None,
- transportSemantics="TCP",
- ):
- # We need this function to return `resolutionReceiver` so we do all the
- # actual logic involving deferreds in a separate function.
-
- # even though this is happening within the depths of twisted, we need to drop
- # our logcontext before starting _resolve, otherwise: (a) _resolve will drop
- # the logcontext if it returns an incomplete deferred; (b) _resolve will
- # call the resolutionReceiver *with* a logcontext, which it won't be expecting.
- with PreserveLoggingContext():
- self._resolve(
- resolutionReceiver,
- hostName,
- portNumber,
- addressTypes,
- transportSemantics,
- )
-
- return resolutionReceiver
-
- @defer.inlineCallbacks
- def _resolve(
- self,
- resolutionReceiver,
- hostName,
- portNumber=0,
- addressTypes=None,
- transportSemantics="TCP",
- ):
-
- with (yield self._limiter.queue(())):
- # resolveHostName doesn't return a Deferred, so we need to hook into
- # the receiver interface to get told when resolution has finished.
-
- deferred = defer.Deferred()
- receiver = _DeferredResolutionReceiver(resolutionReceiver, deferred)
-
- self._resolver.resolveHostName(
- receiver, hostName, portNumber, addressTypes, transportSemantics
- )
-
- yield deferred
-
-
-class _DeferredResolutionReceiver:
- """Wraps a IResolutionReceiver and simply resolves the given deferred when
- resolution is complete
- """
-
- def __init__(self, receiver, deferred):
- self._receiver = receiver
- self._deferred = deferred
-
- def resolutionBegan(self, resolutionInProgress):
- self._receiver.resolutionBegan(resolutionInProgress)
-
- def addressResolved(self, address):
- self._receiver.addressResolved(address)
-
- def resolutionComplete(self):
- self._deferred.callback(())
- self._receiver.resolutionComplete()
-
-
sdnotify_sockaddr = os.getenv("NOTIFY_SOCKET")
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 68ae19c9..2878d2c1 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -36,7 +36,6 @@ from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.filtering import SlavedFilteringStore
from synapse.replication.slave.storage.groups import SlavedGroupServerStore
-from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
@@ -54,7 +53,6 @@ class AdminCmdSlavedStore(
SlavedApplicationServiceStore,
SlavedRegistrationStore,
SlavedFilteringStore,
- SlavedPresenceStore,
SlavedGroupServerStore,
SlavedDeviceInboxStore,
SlavedDeviceStore,
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 57c2fc2e..af8a1833 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -32,7 +32,12 @@ from synapse.api.urls import (
SERVER_KEY_V2_PREFIX,
)
from synapse.app import _base
-from synapse.app._base import max_request_body_size, register_start
+from synapse.app._base import (
+ handle_startup_exception,
+ max_request_body_size,
+ redirect_stdio_to_logs,
+ register_start,
+)
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
@@ -354,6 +359,10 @@ class GenericWorkerServer(HomeServer):
if name == "replication":
resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
+ # Attach additional resources registered by modules.
+ resources.update(self._module_web_resources)
+ self._module_web_resources_consumed = True
+
root_resource = create_resource_tree(resources, OptionsResource())
_base.listen_tcp(
@@ -465,14 +474,21 @@ def start(config_options):
setup_logging(hs, config, use_worker_options=True)
- hs.setup()
+ try:
+ hs.setup()
- # Ensure the replication streamer is always started in case we write to any
- # streams. Will no-op if no streams can be written to by this worker.
- hs.get_replication_streamer()
+ # Ensure the replication streamer is always started in case we write to any
+ # streams. Will no-op if no streams can be written to by this worker.
+ hs.get_replication_streamer()
+ except Exception as e:
+ handle_startup_exception(e)
register_start(_base.start, hs)
+ # redirect stdio to the logs, if configured.
+ if not hs.config.no_redirect_stdio:
+ redirect_stdio_to_logs()
+
_base.start_worker_reactor("synapse-generic-worker", config)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index b2501ee4..7af56ac1 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -37,10 +37,11 @@ from synapse.api.urls import (
)
from synapse.app import _base
from synapse.app._base import (
+ handle_startup_exception,
listen_ssl,
listen_tcp,
max_request_body_size,
- quit_with_error,
+ redirect_stdio_to_logs,
register_start,
)
from synapse.config._base import ConfigError
@@ -69,8 +70,6 @@ from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.rest.well_known import WellKnownResource
from synapse.server import HomeServer
from synapse.storage import DataStore
-from synapse.storage.engines import IncorrectDatabaseSetup
-from synapse.storage.prepare_database import UpgradeDatabaseException
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.module_loader import load_module
from synapse.util.versionstring import get_version_string
@@ -124,6 +123,10 @@ class SynapseHomeServer(HomeServer):
)
resources[path] = resource
+ # Attach additional resources registered by modules.
+ resources.update(self._module_web_resources)
+ self._module_web_resources_consumed = True
+
# try to find something useful to redirect '/' to
if WEB_CLIENT_PREFIX in resources:
root_resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX)
@@ -358,60 +361,10 @@ def setup(config_options):
try:
hs.setup()
- except IncorrectDatabaseSetup as e:
- quit_with_error(str(e))
- except UpgradeDatabaseException as e:
- quit_with_error("Failed to upgrade database: %s" % (e,))
-
- async def do_acme() -> bool:
- """
- Reprovision an ACME certificate, if it's required.
-
- Returns:
- Whether the cert has been updated.
- """
- acme = hs.get_acme_handler()
-
- # Check how long the certificate is active for.
- cert_days_remaining = hs.config.is_disk_cert_valid(allow_self_signed=False)
-
- # We want to reprovision if cert_days_remaining is None (meaning no
- # certificate exists), or the days remaining number it returns
- # is less than our re-registration threshold.
- provision = False
-
- if (
- cert_days_remaining is None
- or cert_days_remaining < hs.config.acme_reprovision_threshold
- ):
- provision = True
-
- if provision:
- await acme.provision_certificate()
-
- return provision
-
- async def reprovision_acme():
- """
- Provision a certificate from ACME, if required, and reload the TLS
- certificate if it's renewed.
- """
- reprovisioned = await do_acme()
- if reprovisioned:
- _base.refresh_certificate(hs)
+ except Exception as e:
+ handle_startup_exception(e)
async def start():
- # Run the ACME provisioning code, if it's enabled.
- if hs.config.acme_enabled:
- acme = hs.get_acme_handler()
- # Start up the webservices which we will respond to ACME
- # challenges with, and then provision.
- await acme.start_listening()
- await do_acme()
-
- # Check if it needs to be reprovisioned every day.
- hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
-
# Load the OIDC provider metadatas, if OIDC is enabled.
if hs.config.oidc_enabled:
oidc = hs.get_oidc_handler()
@@ -500,6 +453,11 @@ def main():
# check base requirements
check_requirements()
hs = setup(sys.argv[1:])
+
+ # redirect stdio to the logs, if configured.
+ if not hs.config.no_redirect_stdio:
+ redirect_stdio_to_logs()
+
run(hs)
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index fe04d7a6..61152b2c 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
from prometheus_client import Counter
-from synapse.api.constants import EventTypes, ThirdPartyEntityKind
+from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException
from synapse.events import EventBase
from synapse.events.utils import serialize_event
@@ -247,9 +247,14 @@ class ApplicationServiceApi(SimpleHttpClient):
e,
time_now,
as_client_event=True,
- is_invite=(
+ # If this is an invite or a knock membership event, and we're interested
+ # in this user, then include any stripped state alongside the event.
+ include_stripped_room_state=(
e.type == EventTypes.Member
- and e.membership == "invite"
+ and (
+ e.membership == Membership.INVITE
+ or e.membership == Membership.KNOCK
+ )
and service.is_interested_in_user(e.state_key)
),
)
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 08e2c2c5..d6ec618f 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -405,7 +405,6 @@ class RootConfig:
listeners=None,
tls_certificate_path=None,
tls_private_key_path=None,
- acme_domain=None,
):
"""
Build a default configuration file
@@ -457,9 +456,6 @@ class RootConfig:
tls_private_key_path (str|None): The path to the tls private key.
- acme_domain (str|None): The domain acme will try to validate. If
- specified acme will be enabled.
-
Returns:
str: the yaml config file
"""
@@ -477,7 +473,6 @@ class RootConfig:
listeners=listeners,
tls_certificate_path=tls_certificate_path,
tls_private_key_path=tls_private_key_path,
- acme_domain=acme_domain,
).values()
)
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index ff9abbc2..23ca0c83 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -11,11 +11,13 @@ from synapse.config import (
database,
emailconfig,
experimental,
+ federation,
groups,
jwt,
key,
logger,
metrics,
+ modules,
oidc,
password_auth_providers,
push,
@@ -85,6 +87,8 @@ class RootConfig:
thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
tracer: tracer.TracerConfig
redis: redis.RedisConfig
+ modules: modules.ModulesConfig
+ federation: federation.FederationConfig
config_classes: List = ...
def __init__(self) -> None: ...
@@ -111,7 +115,6 @@ class RootConfig:
database_conf: Optional[Any] = ...,
tls_certificate_path: Optional[str] = ...,
tls_private_key_path: Optional[str] = ...,
- acme_domain: Optional[str] = ...,
): ...
@classmethod
def load_or_generate_config(cls, description: Any, argv: Any): ...
diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py
index c58a7d95..957de7f3 100644
--- a/synapse/config/account_validity.py
+++ b/synapse/config/account_validity.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/config/auth.py b/synapse/config/auth.py
index e10d641a..53809cee 100644
--- a/synapse/config/auth.py
+++ b/synapse/config/auth.py
@@ -103,6 +103,10 @@ class AuthConfig(Config):
# the user-interactive authentication process, by allowing for multiple
# (and potentially different) operations to use the same validation session.
#
+ # This is ignored for potentially "dangerous" operations (including
+ # deactivating an account, modifying an account password, and
+ # adding a 3PID).
+ #
# Uncomment below to allow for credential validation to last for 15
# seconds.
#
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 6ebce4b2..7fb1f702 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -29,3 +29,6 @@ class ExperimentalConfig(Config):
# MSC3026 (busy presence state)
self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool
+
+ # MSC2716 (backfill existing history)
+ self.msc2716_enabled = experimental.get("msc2716_enabled", False) # type: bool
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 5ae0f55b..1f42a518 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -1,5 +1,4 @@
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -30,6 +29,7 @@ from .jwt import JWTConfig
from .key import KeyConfig
from .logger import LoggingConfig
from .metrics import MetricsConfig
+from .modules import ModulesConfig
from .oidc import OIDCConfig
from .password_auth_providers import PasswordAuthProviderConfig
from .push import PushConfig
@@ -56,6 +56,7 @@ from .workers import WorkerConfig
class HomeServerConfig(RootConfig):
config_classes = [
+ ModulesConfig,
ServerConfig,
TlsConfig,
FederationConfig,
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 813076df..91d9bcf3 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -259,9 +259,7 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
finally:
threadlocal.active = False
- logBeginner.beginLoggingTo([_log], redirectStandardIO=not config.no_redirect_stdio)
- if not config.no_redirect_stdio:
- print("Redirected stdout/stderr to logs")
+ logBeginner.beginLoggingTo([_log], redirectStandardIO=False)
def _load_logging_config(log_config_path: str) -> None:
diff --git a/synapse/config/modules.py b/synapse/config/modules.py
new file mode 100644
index 00000000..3209e1c4
--- /dev/null
+++ b/synapse/config/modules.py
@@ -0,0 +1,49 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, List, Tuple
+
+from synapse.config._base import Config, ConfigError
+from synapse.util.module_loader import load_module
+
+
+class ModulesConfig(Config):
+ section = "modules"
+
+ def read_config(self, config: dict, **kwargs):
+ self.loaded_modules: List[Tuple[Any, Dict]] = []
+
+ configured_modules = config.get("modules") or []
+ for i, module in enumerate(configured_modules):
+ config_path = ("modules", "<item %i>" % i)
+ if not isinstance(module, dict):
+ raise ConfigError("expected a mapping", config_path)
+
+ self.loaded_modules.append(load_module(module, config_path))
+
+ def generate_config_section(self, **kwargs):
+ return """
+ ## Modules ##
+
+ # Server admins can expand Synapse's functionality with external modules.
+ #
+ # See https://matrix-org.github.io/synapse/develop/modules.html for more
+ # documentation on how to configure or create custom modules for Synapse.
+ #
+ modules:
+ # - module: my_super_module.MySuperClass
+ # config:
+ # do_thing: true
+ # - module: my_other_super_module.SomeClass
+ # config: {}
+ """
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index c78a83ab..2f77d670 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -248,6 +248,10 @@ class ContentRepositoryConfig(Config):
# The largest allowed upload size in bytes
#
+ # If you are using a reverse proxy you may also need to set this value in
+ # your reverse proxy's config. Notably Nginx has a small max body size by default.
+ # See https://matrix-org.github.io/synapse/develop/reverse_proxy.html.
+ #
#max_upload_size: 50M
# Maximum number of pixels that will be thumbnailed
diff --git a/synapse/config/server.py b/synapse/config/server.py
index c290a35a..0833a5f7 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -397,19 +397,22 @@ class ServerConfig(Config):
self.ip_range_whitelist = generate_ip_set(
config.get("ip_range_whitelist", ()), config_path=("ip_range_whitelist",)
)
-
# The federation_ip_range_blacklist is used for backwards-compatibility
- # and only applies to federation and identity servers. If it is not given,
- # default to ip_range_blacklist.
- federation_ip_range_blacklist = config.get(
- "federation_ip_range_blacklist", ip_range_blacklist
- )
- # Always blacklist 0.0.0.0, ::
- self.federation_ip_range_blacklist = generate_ip_set(
- federation_ip_range_blacklist,
- ["0.0.0.0", "::"],
- config_path=("federation_ip_range_blacklist",),
- )
+ # and only applies to federation and identity servers.
+ if "federation_ip_range_blacklist" in config:
+ # Always blacklist 0.0.0.0, ::
+ self.federation_ip_range_blacklist = generate_ip_set(
+ config["federation_ip_range_blacklist"],
+ ["0.0.0.0", "::"],
+ config_path=("federation_ip_range_blacklist",),
+ )
+ # 'federation_ip_range_whitelist' was never a supported configuration option.
+ self.federation_ip_range_whitelist = None
+ else:
+ # No backwards-compatibility required, as federation_ip_range_blacklist
+ # is not given. Default to ip_range_blacklist and ip_range_whitelist.
+ self.federation_ip_range_blacklist = self.ip_range_blacklist
+ self.federation_ip_range_whitelist = self.ip_range_whitelist
# (undocumented) option for torturing the worker-mode replication a bit,
# for testing. The value defines the number of milliseconds to pause before
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
index 447ba330..d0311d64 100644
--- a/synapse/config/spam_checker.py
+++ b/synapse/config/spam_checker.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
from typing import Any, Dict, List, Tuple
from synapse.config import ConfigError
@@ -19,6 +20,15 @@ from synapse.util.module_loader import load_module
from ._base import Config
+logger = logging.getLogger(__name__)
+
+LEGACY_SPAM_CHECKER_WARNING = """
+This server is using a spam checker module that is implementing the deprecated spam
+checker interface. Please check with the module's maintainer to see if a new version
+supporting Synapse's generic modules system is available.
+For more information, please see https://matrix-org.github.io/synapse/develop/modules.html
+---------------------------------------------------------------------------------------"""
+
class SpamCheckerConfig(Config):
section = "spamchecker"
@@ -43,17 +53,7 @@ class SpamCheckerConfig(Config):
else:
raise ConfigError("spam_checker syntax is incorrect")
- def generate_config_section(self, **kwargs):
- return """\
- # Spam checkers are third-party modules that can block specific actions
- # of local users, such as creating rooms and registering undesirable
- # usernames, as well as remote users by redacting incoming events.
- #
- spam_checker:
- #- module: "my_custom_project.SuperSpamChecker"
- # config:
- # example_option: 'things'
- #- module: "some_other_project.BadEventStopper"
- # config:
- # example_stop_events_from: ['@bad:example.com']
- """
+ # If this configuration is being used in any way, warn the admin that it is going
+ # away soon.
+ if self.spam_checkers:
+ logger.warning(LEGACY_SPAM_CHECKER_WARNING)
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index af645c93..e4346e02 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -74,6 +74,10 @@ class SSOConfig(Config):
self.sso_client_whitelist = sso_config.get("client_whitelist") or []
+ self.sso_update_profile_information = (
+ sso_config.get("update_profile_information") or False
+ )
+
# Attempt to also whitelist the server's login fallback, since that fallback sets
# the redirect URL to itself (so it can process the login token then return
# gracefully to the client). This would make it pointless to ask the user for
@@ -111,6 +115,17 @@ class SSOConfig(Config):
# - https://riot.im/develop
# - https://my.custom.client/
+ # Uncomment to keep a user's profile fields in sync with information from
+ # the identity provider. Currently only syncing the displayname is
+ # supported. Fields are checked on every SSO login, and are updated
+ # if necessary.
+ #
+ # Note that enabling this option will override user profile information,
# regardless of whether users have opted out of syncing that
+ # information when first signing in. Defaults to false.
+ #
+ #update_profile_information: true
+
# Directory in which Synapse will try to find the template files below.
# If not set, or the files named below are not found within the template
# directory, default templates from within the Synapse package will be used.
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 0e9bba53..9a16a8fb 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -14,7 +14,6 @@
import logging
import os
-import warnings
from datetime import datetime
from typing import List, Optional, Pattern
@@ -26,45 +25,12 @@ from synapse.util import glob_to_regex
logger = logging.getLogger(__name__)
-ACME_SUPPORT_ENABLED_WARN = """\
-This server uses Synapse's built-in ACME support. Note that ACME v1 has been
-deprecated by Let's Encrypt, and that Synapse doesn't currently support ACME v2,
-which means that this feature will not work with Synapse installs set up after
-November 2019, and that it may stop working on June 2020 for installs set up
-before that date.
-
-For more info and alternative solutions, see
-https://github.com/matrix-org/synapse/blob/master/docs/ACME.md#deprecation-of-acme-v1
---------------------------------------------------------------------------------"""
-
class TlsConfig(Config):
section = "tls"
def read_config(self, config: dict, config_dir_path: str, **kwargs):
- acme_config = config.get("acme", None)
- if acme_config is None:
- acme_config = {}
-
- self.acme_enabled = acme_config.get("enabled", False)
-
- if self.acme_enabled:
- logger.warning(ACME_SUPPORT_ENABLED_WARN)
-
- # hyperlink complains on py2 if this is not a Unicode
- self.acme_url = str(
- acme_config.get("url", "https://acme-v01.api.letsencrypt.org/directory")
- )
- self.acme_port = acme_config.get("port", 80)
- self.acme_bind_addresses = acme_config.get("bind_addresses", ["::", "0.0.0.0"])
- self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30)
- self.acme_domain = acme_config.get("domain", config.get("server_name"))
-
- self.acme_account_key_file = self.abspath(
- acme_config.get("account_key_file", config_dir_path + "/client.key")
- )
-
self.tls_certificate_file = self.abspath(config.get("tls_certificate_path"))
self.tls_private_key_file = self.abspath(config.get("tls_private_key_path"))
@@ -229,11 +195,9 @@ class TlsConfig(Config):
data_dir_path,
tls_certificate_path,
tls_private_key_path,
- acme_domain,
**kwargs,
):
- """If the acme_domain is specified acme will be enabled.
- If the TLS paths are not specified the default will be certs in the
+ """If the TLS paths are not specified the default will be certs in the
config directory"""
base_key_name = os.path.join(config_dir_path, server_name)
@@ -243,28 +207,15 @@ class TlsConfig(Config):
"Please specify both a cert path and a key path or neither."
)
- tls_enabled = (
- "" if tls_certificate_path and tls_private_key_path or acme_domain else "#"
- )
+ tls_enabled = "" if tls_certificate_path and tls_private_key_path else "#"
if not tls_certificate_path:
tls_certificate_path = base_key_name + ".tls.crt"
if not tls_private_key_path:
tls_private_key_path = base_key_name + ".tls.key"
- acme_enabled = bool(acme_domain)
- acme_domain = "matrix.example.com"
-
- default_acme_account_file = os.path.join(data_dir_path, "acme_account.key")
-
- # this is to avoid the max line length. Sorrynotsorry
- proxypassline = (
- "ProxyPass /.well-known/acme-challenge "
- "http://localhost:8009/.well-known/acme-challenge"
- )
-
# flake8 doesn't recognise that variables are used in the below string
- _ = tls_enabled, proxypassline, acme_enabled, default_acme_account_file
+ _ = tls_enabled
return (
"""\
@@ -274,13 +225,9 @@ class TlsConfig(Config):
# This certificate, as of Synapse 1.0, will need to be a valid and verifiable
# certificate, signed by a recognised Certificate Authority.
#
- # See 'ACME support' below to enable auto-provisioning this certificate via
- # Let's Encrypt.
- #
- # If supplying your own, be sure to use a `.pem` file that includes the
- # full certificate chain including any intermediate certificates (for
- # instance, if using certbot, use `fullchain.pem` as your certificate,
- # not `cert.pem`).
+ # Be sure to use a `.pem` file that includes the full certificate chain including
+ # any intermediate certificates (for instance, if using certbot, use
+ # `fullchain.pem` as your certificate, not `cert.pem`).
#
%(tls_enabled)stls_certificate_path: "%(tls_certificate_path)s"
@@ -330,80 +277,6 @@ class TlsConfig(Config):
# - myCA1.pem
# - myCA2.pem
# - myCA3.pem
-
- # ACME support: This will configure Synapse to request a valid TLS certificate
- # for your configured `server_name` via Let's Encrypt.
- #
- # Note that ACME v1 is now deprecated, and Synapse currently doesn't support
- # ACME v2. This means that this feature currently won't work with installs set
- # up after November 2019. For more info, and alternative solutions, see
- # https://github.com/matrix-org/synapse/blob/master/docs/ACME.md#deprecation-of-acme-v1
- #
- # Note that provisioning a certificate in this way requires port 80 to be
- # routed to Synapse so that it can complete the http-01 ACME challenge.
- # By default, if you enable ACME support, Synapse will attempt to listen on
- # port 80 for incoming http-01 challenges - however, this will likely fail
- # with 'Permission denied' or a similar error.
- #
- # There are a couple of potential solutions to this:
- #
- # * If you already have an Apache, Nginx, or similar listening on port 80,
- # you can configure Synapse to use an alternate port, and have your web
- # server forward the requests. For example, assuming you set 'port: 8009'
- # below, on Apache, you would write:
- #
- # %(proxypassline)s
- #
- # * Alternatively, you can use something like `authbind` to give Synapse
- # permission to listen on port 80.
- #
- acme:
- # ACME support is disabled by default. Set this to `true` and uncomment
- # tls_certificate_path and tls_private_key_path above to enable it.
- #
- enabled: %(acme_enabled)s
-
- # Endpoint to use to request certificates. If you only want to test,
- # use Let's Encrypt's staging url:
- # https://acme-staging.api.letsencrypt.org/directory
- #
- #url: https://acme-v01.api.letsencrypt.org/directory
-
- # Port number to listen on for the HTTP-01 challenge. Change this if
- # you are forwarding connections through Apache/Nginx/etc.
- #
- port: 80
-
- # Local addresses to listen on for incoming connections.
- # Again, you may want to change this if you are forwarding connections
- # through Apache/Nginx/etc.
- #
- bind_addresses: ['::', '0.0.0.0']
-
- # How many days remaining on a certificate before it is renewed.
- #
- reprovision_threshold: 30
-
- # The domain that the certificate should be for. Normally this
- # should be the same as your Matrix domain (i.e., 'server_name'), but,
- # by putting a file at 'https://<server_name>/.well-known/matrix/server',
- # you can delegate incoming traffic to another server. If you do that,
- # you should give the target of the delegation here.
- #
- # For example: if your 'server_name' is 'example.com', but
- # 'https://example.com/.well-known/matrix/server' delegates to
- # 'matrix.example.com', you should put 'matrix.example.com' here.
- #
- # If not set, defaults to your 'server_name'.
- #
- domain: %(acme_domain)s
-
- # file to use for the account key. This will be generated if it doesn't
- # exist.
- #
- # If unspecified, we will use CONFDIR/client.key.
- #
- account_key_file: %(default_acme_account_file)s
"""
# Lowercase the string representation of boolean values
% {
@@ -415,8 +288,6 @@ class TlsConfig(Config):
def read_tls_certificate(self) -> crypto.X509:
"""Reads the TLS certificate from the configured file, and returns it
- Also checks if it is self-signed, and warns if so
-
Returns:
The certificate
"""
@@ -425,16 +296,6 @@ class TlsConfig(Config):
cert_pem = self.read_file(cert_path, "tls_certificate_path")
cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
- # Check if it is self-signed, and issue a warning if so.
- if cert.get_issuer() == cert.get_subject():
- warnings.warn(
- (
- "Self-signed TLS certificates will not be accepted by Synapse 1.0. "
- "Please either provide a valid certificate, or use Synapse's ACME "
- "support to provision one."
- )
- )
-
return cert
def read_tls_private_key(self) -> crypto.PKey:
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 70c55656..33d7c602 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -160,6 +160,7 @@ def check(
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()])
+ # 5. If type is m.room.membership
if event.type == EventTypes.Member:
_is_membership_change_allowed(room_version_obj, event, auth_events)
logger.debug("Allowing! %s", event)
@@ -257,6 +258,11 @@ def _is_membership_change_allowed(
caller_in_room = caller and caller.membership == Membership.JOIN
caller_invited = caller and caller.membership == Membership.INVITE
+ caller_knocked = (
+ caller
+ and room_version.msc2403_knocking
+ and caller.membership == Membership.KNOCK
+ )
# get info about the target
key = (EventTypes.Member, target_user_id)
@@ -283,6 +289,7 @@ def _is_membership_change_allowed(
{
"caller_in_room": caller_in_room,
"caller_invited": caller_invited,
+ "caller_knocked": caller_knocked,
"target_banned": target_banned,
"target_in_room": target_in_room,
"membership": membership,
@@ -299,9 +306,14 @@ def _is_membership_change_allowed(
raise AuthError(403, "%s is banned from the room" % (target_user_id,))
return
- if Membership.JOIN != membership:
+ # Require the user to be in the room for membership changes other than join/knock.
+ if Membership.JOIN != membership and (
+ not room_version.msc2403_knocking or Membership.KNOCK != membership
+ ):
+ # If the user has been invited or has knocked, they are allowed to change their
+ # membership event to leave
if (
- caller_invited
+ (caller_invited or caller_knocked)
and Membership.LEAVE == membership
and target_user_id == event.user_id
):
@@ -339,7 +351,9 @@ def _is_membership_change_allowed(
and join_rule == JoinRules.MSC3083_RESTRICTED
):
pass
- elif join_rule == JoinRules.INVITE:
+ elif join_rule == JoinRules.INVITE or (
+ room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
+ ):
if not caller_in_room and not caller_invited:
raise AuthError(403, "You are not invited to this room.")
else:
@@ -358,6 +372,17 @@ def _is_membership_change_allowed(
elif Membership.BAN == membership:
if user_level < ban_level or user_level <= target_level:
raise AuthError(403, "You don't have permission to ban")
+ elif room_version.msc2403_knocking and Membership.KNOCK == membership:
+ if join_rule != JoinRules.KNOCK:
+ raise AuthError(403, "You don't have permission to knock")
+ elif target_user_id != event.user_id:
+ raise AuthError(403, "You cannot knock for other users")
+ elif target_in_room:
+ raise AuthError(403, "You cannot knock on a room you are already in")
+ elif caller_invited:
+ raise AuthError(403, "You are already invited to this room")
+ elif target_banned:
+ raise AuthError(403, "You are banned from this room")
else:
raise AuthError(500, "Unknown membership %s" % membership)
@@ -718,7 +743,7 @@ def auth_types_for_event(event: EventBase) -> Set[Tuple[str, str]]:
if event.type == EventTypes.Member:
membership = event.content["membership"]
- if membership in [Membership.JOIN, Membership.INVITE]:
+ if membership in [Membership.JOIN, Membership.INVITE, Membership.KNOCK]:
auth_types.add((EventTypes.JoinRules, ""))
auth_types.add((EventTypes.Member, event.state_key))
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index c8b52cbc..0cb9c1cc 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -119,6 +119,7 @@ class _EventInternalMetadata:
redacted = DictProperty("redacted") # type: bool
txn_id = DictProperty("txn_id") # type: str
token_id = DictProperty("token_id") # type: str
+ historical = DictProperty("historical") # type: bool
# XXX: These are set by StreamWorkerStore._set_before_and_after.
# I'm pretty sure that these are never persisted to the database, so shouldn't
@@ -204,6 +205,14 @@ class _EventInternalMetadata:
"""
return self._dict.get("redacted", False)
+ def is_historical(self) -> bool:
+ """Whether this is a historical message.
+ This is used by the batch send historical message endpoint and
+ is needed to mark the event as backfilled and skip some checks,
+ like push notifications.
+ """
+ return self._dict.get("historical", False)
+
class EventBase(metaclass=abc.ABCMeta):
@property
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 5793553a..81bf8615 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
from typing import Any, Dict, List, Optional, Tuple, Union
import attr
@@ -33,6 +34,8 @@ from synapse.types import EventID, JsonDict
from synapse.util import Clock
from synapse.util.stringutils import random_string
+logger = logging.getLogger(__name__)
+
@attr.s(slots=True, cmp=False, frozen=True)
class EventBuilder:
@@ -100,6 +103,7 @@ class EventBuilder:
self,
prev_event_ids: List[str],
auth_event_ids: Optional[List[str]],
+ depth: Optional[int] = None,
) -> EventBase:
"""Transform into a fully signed and hashed event
@@ -108,6 +112,9 @@ class EventBuilder:
auth_event_ids: The event IDs to use as the auth events.
Should normally be set to None, which will cause them to be calculated
based on the room state at the prev_events.
+ depth: Override the depth used to order the event in the DAG.
+ Should normally be set to None, which will cause the depth to be calculated
+ based on the prev_events.
Returns:
The signed and hashed event.
@@ -131,8 +138,14 @@ class EventBuilder:
auth_events = auth_event_ids
prev_events = prev_event_ids
- old_depth = await self._store.get_max_depth_of(prev_event_ids)
- depth = old_depth + 1
+ # If no depth was supplied, progress the depth as normal
+ if depth is None:
+ (
+ _,
+ most_recent_prev_event_depth,
+ ) = await self._store.get_max_depth_of(prev_event_ids)
+
+ depth = most_recent_prev_event_depth + 1
# we cap depth of generated events, to ensure that they are not
# rejected by other servers (and so that they can be persisted in
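
Note on the new depth parameter above: a hedged usage sketch (the builder and surrounding variables are hypothetical) for callers, such as MSC2716 batch inserts, that need to pin an event's position in the DAG:

# Sketch: build an event at an explicit depth instead of the default
# max(prev event depths) + 1.
event = await builder.build(
    prev_event_ids=prev_event_ids,
    auth_event_ids=auth_event_ids,
    depth=insertion_depth,  # pass None to fall back to the computed depth
)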
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index d5fa1950..efec16c2 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,7 +15,18 @@
import inspect
import logging
-from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ Collection,
+ Dict,
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.media_storage import ReadableFileWrapper
@@ -29,20 +40,187 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+CHECK_EVENT_FOR_SPAM_CALLBACK = Callable[
+ ["synapse.events.EventBase"],
+ Awaitable[Union[bool, str]],
+]
+USER_MAY_INVITE_CALLBACK = Callable[[str, str, str], Awaitable[bool]]
+USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]]
+USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[[str, RoomAlias], Awaitable[bool]]
+USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]]
+CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[Dict[str, str]], Awaitable[bool]]
+LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[
+ [
+ Optional[dict],
+ Optional[str],
+ Collection[Tuple[str, str]],
+ ],
+ Awaitable[RegistrationBehaviour],
+]
+CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[
+ [
+ Optional[dict],
+ Optional[str],
+ Collection[Tuple[str, str]],
+ Optional[str],
+ ],
+ Awaitable[RegistrationBehaviour],
+]
+CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[
+ [ReadableFileWrapper, FileInfo],
+ Awaitable[bool],
+]
+
+
+def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
+ """Wrapper that loads spam checkers configured using the old configuration, and
+ registers the spam checker hooks they implement.
+ """
+ spam_checkers = [] # type: List[Any]
+ api = hs.get_module_api()
+ for module, config in hs.config.spam_checkers:
+ # Older spam checkers don't accept the `api` argument, so we
+ # try and detect support.
+ spam_args = inspect.getfullargspec(module)
+ if "api" in spam_args.args:
+ spam_checkers.append(module(config=config, api=api))
+ else:
+ spam_checkers.append(module(config=config))
+
+ # The known spam checker hooks. If a spam checker module implements a method
+ # whose name appears in this set, we'll want to register it.
+ spam_checker_methods = {
+ "check_event_for_spam",
+ "user_may_invite",
+ "user_may_create_room",
+ "user_may_create_room_alias",
+ "user_may_publish_room",
+ "check_username_for_spam",
+ "check_registration_for_spam",
+ "check_media_file_for_spam",
+ }
+
+ for spam_checker in spam_checkers:
+ # Methods on legacy spam checkers might not be async, so we wrap them in a
+ # wrapper that calls maybe_awaitable on the result.
+ def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
+ # f might be None if the callback isn't implemented by the module. In this
+ # case we don't want to register a callback at all so we return None.
+ if f is None:
+ return None
+
+ wrapped_func = f
+
+ if f.__name__ == "check_registration_for_spam":
+ checker_args = inspect.signature(f)
+ if len(checker_args.parameters) == 3:
+ # Backwards compatibility; some modules might implement a hook that
+ # doesn't expect a 4th argument. In this case, wrap it in a function
+ # that gives it only 3 arguments and drops the auth_provider_id on
+ # the floor.
+ def wrapper(
+ email_threepid: Optional[dict],
+ username: Optional[str],
+ request_info: Collection[Tuple[str, str]],
+ auth_provider_id: Optional[str],
+ ) -> Union[Awaitable[RegistrationBehaviour], RegistrationBehaviour]:
+ # We've already made sure f is not None above, but mypy doesn't
+ # do well across function boundaries so we need to tell it f is
+ # definitely not None.
+ assert f is not None
+
+ return f(
+ email_threepid,
+ username,
+ request_info,
+ )
+
+ wrapped_func = wrapper
+ elif len(checker_args.parameters) != 4:
+ raise RuntimeError(
+ "Bad signature for callback check_registration_for_spam",
+ )
+
+ def run(*args, **kwargs):
+ # mypy doesn't do well across function boundaries so we need to tell it
+ # wrapped_func is definitely not None.
+ assert wrapped_func is not None
+
+ return maybe_awaitable(wrapped_func(*args, **kwargs))
+
+ return run
+
+ # Register the hooks through the module API.
+ hooks = {
+ hook: async_wrapper(getattr(spam_checker, hook, None))
+ for hook in spam_checker_methods
+ }
+
+ api.register_spam_checker_callbacks(**hooks)
+
class SpamChecker:
- def __init__(self, hs: "synapse.server.HomeServer"):
- self.spam_checkers = [] # type: List[Any]
- api = hs.get_module_api()
-
- for module, config in hs.config.spam_checkers:
- # Older spam checkers don't accept the `api` argument, so we
- # try and detect support.
- spam_args = inspect.getfullargspec(module)
- if "api" in spam_args.args:
- self.spam_checkers.append(module(config=config, api=api))
- else:
- self.spam_checkers.append(module(config=config))
+ def __init__(self):
+ self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = []
+ self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
+ self._user_may_create_room_callbacks: List[USER_MAY_CREATE_ROOM_CALLBACK] = []
+ self._user_may_create_room_alias_callbacks: List[
+ USER_MAY_CREATE_ROOM_ALIAS_CALLBACK
+ ] = []
+ self._user_may_publish_room_callbacks: List[USER_MAY_PUBLISH_ROOM_CALLBACK] = []
+ self._check_username_for_spam_callbacks: List[
+ CHECK_USERNAME_FOR_SPAM_CALLBACK
+ ] = []
+ self._check_registration_for_spam_callbacks: List[
+ CHECK_REGISTRATION_FOR_SPAM_CALLBACK
+ ] = []
+ self._check_media_file_for_spam_callbacks: List[
+ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK
+ ] = []
+
+ def register_callbacks(
+ self,
+ check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None,
+ user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None,
+ user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None,
+ user_may_create_room_alias: Optional[
+ USER_MAY_CREATE_ROOM_ALIAS_CALLBACK
+ ] = None,
+ user_may_publish_room: Optional[USER_MAY_PUBLISH_ROOM_CALLBACK] = None,
+ check_username_for_spam: Optional[CHECK_USERNAME_FOR_SPAM_CALLBACK] = None,
+ check_registration_for_spam: Optional[
+ CHECK_REGISTRATION_FOR_SPAM_CALLBACK
+ ] = None,
+ check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
+ ):
+ """Register callbacks from module for each hook."""
+ if check_event_for_spam is not None:
+ self._check_event_for_spam_callbacks.append(check_event_for_spam)
+
+ if user_may_invite is not None:
+ self._user_may_invite_callbacks.append(user_may_invite)
+
+ if user_may_create_room is not None:
+ self._user_may_create_room_callbacks.append(user_may_create_room)
+
+ if user_may_create_room_alias is not None:
+ self._user_may_create_room_alias_callbacks.append(
+ user_may_create_room_alias,
+ )
+
+ if user_may_publish_room is not None:
+ self._user_may_publish_room_callbacks.append(user_may_publish_room)
+
+ if check_username_for_spam is not None:
+ self._check_username_for_spam_callbacks.append(check_username_for_spam)
+
+ if check_registration_for_spam is not None:
+ self._check_registration_for_spam_callbacks.append(
+ check_registration_for_spam,
+ )
+
+ if check_media_file_for_spam is not None:
+ self._check_media_file_for_spam_callbacks.append(check_media_file_for_spam)
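
# [Illustrative sketch, not part of this commit] Registering a single hook
# with the new callback-based API; hooks that are not passed stay None and
# are skipped by register_callbacks above. The callback body is hypothetical.
async def my_user_may_invite(inviter: str, invitee: str, room_id: str) -> bool:
    # Block invites originating from a hypothetical bad server.
    return not inviter.endswith(":spam.example.com")

spam_checker = SpamChecker()
spam_checker.register_callbacks(user_may_invite=my_user_may_invite)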
async def check_event_for_spam(
self, event: "synapse.events.EventBase"
@@ -60,9 +238,10 @@ class SpamChecker:
True or a string if the event is spammy. If a string is returned it
will be used as the error message returned to the user.
"""
- for spam_checker in self.spam_checkers:
- if await maybe_awaitable(spam_checker.check_event_for_spam(event)):
- return True
+ for callback in self._check_event_for_spam_callbacks:
+ res = await callback(event) # type: Union[bool, str]
+ if res:
+ return res
return False
@@ -81,15 +260,8 @@ class SpamChecker:
Returns:
True if the user may send an invite, otherwise False
"""
- for spam_checker in self.spam_checkers:
- if (
- await maybe_awaitable(
- spam_checker.user_may_invite(
- inviter_userid, invitee_userid, room_id
- )
- )
- is False
- ):
+ for callback in self._user_may_invite_callbacks:
+ if await callback(inviter_userid, invitee_userid, room_id) is False:
return False
return True
@@ -105,11 +277,8 @@ class SpamChecker:
Returns:
True if the user may create a room, otherwise False
"""
- for spam_checker in self.spam_checkers:
- if (
- await maybe_awaitable(spam_checker.user_may_create_room(userid))
- is False
- ):
+ for callback in self._user_may_create_room_callbacks:
+ if await callback(userid) is False:
return False
return True
@@ -128,13 +297,8 @@ class SpamChecker:
Returns:
True if the user may create a room alias, otherwise False
"""
- for spam_checker in self.spam_checkers:
- if (
- await maybe_awaitable(
- spam_checker.user_may_create_room_alias(userid, room_alias)
- )
- is False
- ):
+ for callback in self._user_may_create_room_alias_callbacks:
+ if await callback(userid, room_alias) is False:
return False
return True
@@ -151,13 +315,8 @@ class SpamChecker:
Returns:
True if the user may publish the room, otherwise False
"""
- for spam_checker in self.spam_checkers:
- if (
- await maybe_awaitable(
- spam_checker.user_may_publish_room(userid, room_id)
- )
- is False
- ):
+ for callback in self._user_may_publish_room_callbacks:
+ if await callback(userid, room_id) is False:
return False
return True
@@ -177,15 +336,11 @@ class SpamChecker:
Returns:
True if the user is spammy.
"""
- for spam_checker in self.spam_checkers:
- # For backwards compatibility, only run if the method exists on the
- # spam checker
- checker = getattr(spam_checker, "check_username_for_spam", None)
- if checker:
- # Make a copy of the user profile object to ensure the spam checker
- # cannot modify it.
- if await maybe_awaitable(checker(user_profile.copy())):
- return True
+ for callback in self._check_username_for_spam_callbacks:
+ # Make a copy of the user profile object to ensure the spam checker cannot
+ # modify it.
+ if await callback(user_profile.copy()):
+ return True
return False
@@ -211,33 +366,13 @@ class SpamChecker:
Enum for how the request should be handled
"""
- for spam_checker in self.spam_checkers:
- # For backwards compatibility, only run if the method exists on the
- # spam checker
- checker = getattr(spam_checker, "check_registration_for_spam", None)
- if checker:
- # Provide auth_provider_id if the function supports it
- checker_args = inspect.signature(checker)
- if len(checker_args.parameters) == 4:
- d = checker(
- email_threepid,
- username,
- request_info,
- auth_provider_id,
- )
- elif len(checker_args.parameters) == 3:
- d = checker(email_threepid, username, request_info)
- else:
- logger.error(
- "Invalid signature for %s.check_registration_for_spam. Denying registration",
- spam_checker.__module__,
- )
- return RegistrationBehaviour.DENY
-
- behaviour = await maybe_awaitable(d)
- assert isinstance(behaviour, RegistrationBehaviour)
- if behaviour != RegistrationBehaviour.ALLOW:
- return behaviour
+ for callback in self._check_registration_for_spam_callbacks:
+ behaviour = await (
+ callback(email_threepid, username, request_info, auth_provider_id)
+ )
+ assert isinstance(behaviour, RegistrationBehaviour)
+ if behaviour != RegistrationBehaviour.ALLOW:
+ return behaviour
return RegistrationBehaviour.ALLOW
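
# [Illustrative sketch, not part of this commit] The callback shape consumed
# by the loop above: it receives the four registration arguments and must
# return a RegistrationBehaviour; anything other than ALLOW short-circuits.
async def my_check_registration_for_spam(
    email_threepid, username, request_info, auth_provider_id
):
    if username is not None and username.startswith("spam"):
        return RegistrationBehaviour.DENY  # hypothetical policy
    return RegistrationBehaviour.ALLOW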
@@ -275,13 +410,9 @@ class SpamChecker:
allowed.
"""
- for spam_checker in self.spam_checkers:
- # For backwards compatibility, only run if the method exists on the
- # spam checker
- checker = getattr(spam_checker, "check_media_file_for_spam", None)
- if checker:
- spam = await maybe_awaitable(checker(file_wrapper, file_info))
- if spam:
- return True
+ for callback in self._check_media_file_for_spam_callbacks:
+ spam = await callback(file_wrapper, file_info)
+ if spam:
+ return True
return False
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 7d7cd9aa..ec96999e 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -242,6 +242,7 @@ def format_event_for_client_v1(d):
"replaces_state",
"prev_content",
"invite_room_state",
+ "knock_room_state",
)
for key in copy_keys:
if key in d["unsigned"]:
@@ -278,7 +279,7 @@ def serialize_event(
event_format=format_event_for_client_v1,
token_id=None,
only_event_fields=None,
- is_invite=False,
+ include_stripped_room_state=False,
):
"""Serialize event for clients
@@ -289,8 +290,10 @@ def serialize_event(
event_format
token_id
only_event_fields
- is_invite (bool): Whether this is an invite that is being sent to the
- invitee
+ include_stripped_room_state (bool): Some events can have stripped room state
+ stored in the `unsigned` field. This is required for invite and knock
+ functionality. If this option is False, that state will be removed from the
+ event before it is returned. Otherwise, it will be kept.
Returns:
dict
@@ -322,11 +325,13 @@ def serialize_event(
if txn_id is not None:
d["unsigned"]["transaction_id"] = txn_id
- # If this is an invite for somebody else, then we don't care about the
- # invite_room_state as that's meant solely for the invitee. Other clients
- # will already have the state since they're in the room.
- if not is_invite:
+ # invite_room_state and knock_room_state are lists of stripped room state events
+ # that are meant to provide metadata about a room to an invitee/knocker. They are
+ # intended to only be included in specific circumstances, such as when sent down
+ # /sync, and should not be included in any other case.
+ if not include_stripped_room_state:
d["unsigned"].pop("invite_room_state", None)
+ d["unsigned"].pop("knock_room_state", None)
if as_client_event:
d = event_format(d)
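
# [Illustrative sketch, not part of this commit] The effect of the renamed
# flag on a serialized event dict, using the exact pops shown above.
d = {"unsigned": {"invite_room_state": [], "knock_room_state": [], "age": 42}}
include_stripped_room_state = False  # the default for ordinary clients
if not include_stripped_room_state:
    d["unsigned"].pop("invite_room_state", None)
    d["unsigned"].pop("knock_room_state", None)
assert d["unsigned"] == {"age": 42}  # stripped state removed, rest kept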
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 1076ebc0..ed09c6af 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -1,4 +1,5 @@
-# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -619,7 +620,8 @@ class FederationClient(FederationBase):
SynapseError: if the chosen remote server returns a 300/400 code, or
no servers successfully handle the request.
"""
- valid_memberships = {Membership.JOIN, Membership.LEAVE}
+ valid_memberships = {Membership.JOIN, Membership.LEAVE, Membership.KNOCK}
+
if membership not in valid_memberships:
raise RuntimeError(
"make_membership_event called with membership='%s', must be one of %s"
@@ -638,6 +640,13 @@ class FederationClient(FederationBase):
if not room_version:
raise UnsupportedRoomVersionError()
+ if not room_version.msc2403_knocking and membership == Membership.KNOCK:
+ raise SynapseError(
+ 400,
+ "This room version does not support knocking",
+ errcode=Codes.FORBIDDEN,
+ )
+
pdu_dict = ret.get("event", None)
if not isinstance(pdu_dict, dict):
raise InvalidResponseError("Bad 'event' field in response")
@@ -946,6 +955,62 @@ class FederationClient(FederationBase):
# content.
return resp[1]
+ async def send_knock(self, destinations: List[str], pdu: EventBase) -> JsonDict:
+ """Attempts to send a knock event to given a list of servers. Iterates
+ through the list until one attempt succeeds.
+
+ Doing so will cause the remote server to add the event to the graph,
+ and send the event out to the rest of the federation.
+
+ Args:
+ destinations: A list of candidate homeservers which are likely to be
+ participating in the room.
+ pdu: The event to be sent.
+
+ Returns:
+ The remote homeserver can optionally return some state from the room. The response
+ dictionary is in the form:
+
+ {"knock_state_events": [<state event dict>, ...]}
+
+ The list of state events may be empty.
+
+ Raises:
+ SynapseError: If the chosen remote server returns a 3xx/4xx code.
+ RuntimeError: If no servers were reachable.
+ """
+
+ async def send_request(destination: str) -> JsonDict:
+ return await self._do_send_knock(destination, pdu)
+
+ return await self._try_destination_list(
+ "send_knock", destinations, send_request
+ )
+
+ async def _do_send_knock(self, destination: str, pdu: EventBase) -> JsonDict:
+ """Send a knock event to a remote homeserver.
+
+ Args:
+ destination: The homeserver to send to.
+ pdu: The event to send.
+
+ Returns:
+ The remote homeserver can optionally return some state from the room. The response
+ dictionary is in the form:
+
+ {"knock_state_events": [<state event dict>, ...]}
+
+ The list of state events may be empty.
+ """
+ time_now = self._clock.time_msec()
+
+ return await self.transport_layer.send_knock_v1(
+ destination=destination,
+ room_id=pdu.room_id,
+ event_id=pdu.event_id,
+ content=pdu.get_pdu_json(time_now),
+ )
+
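
# [Illustrative sketch, not part of this commit] The two-step client flow the
# methods above implement. Signatures are simplified and the orchestration
# shown here is hypothetical; in Synapse it lives in the federation handler.
async def knock_on_room(client, destinations, room_id, user_id):
    # Step 1 (/make_knock): ask a resident server to draft a partial knock
    # event; the real handler then completes and signs it.
    origin, event, room_version = await client.make_membership_event(
        destinations, room_id, user_id, membership="knock", content={}
    )
    # Step 2 (/send_knock): submit the signed event; the resident server
    # replies with stripped state for the knocker's client to display.
    response = await client.send_knock(destinations, event)
    return response.get("knock_state_events", [])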
async def get_public_rooms(
self,
remote_server: str,
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index ace30aa4..2b07f185 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -129,7 +129,7 @@ class FederationServer(FederationBase):
# come in waves.
self._state_resp_cache = ResponseCache(
hs.get_clock(), "state_resp", timeout_ms=30000
- ) # type: ResponseCache[Tuple[str, str]]
+ ) # type: ResponseCache[Tuple[str, Optional[str]]]
self._state_ids_resp_cache = ResponseCache(
hs.get_clock(), "state_ids_resp", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
@@ -138,6 +138,8 @@ class FederationServer(FederationBase):
hs.config.federation.federation_metrics_domains
)
+ self._room_prejoin_state_types = hs.config.api.room_prejoin_state
+
async def on_backfill_request(
self, origin: str, room_id: str, versions: List[str], limit: int
) -> Tuple[int, Dict[str, Any]]:
@@ -406,7 +408,7 @@ class FederationServer(FederationBase):
)
async def on_room_state_request(
- self, origin: str, room_id: str, event_id: str
+ self, origin: str, room_id: str, event_id: Optional[str]
) -> Tuple[int, Dict[str, Any]]:
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
@@ -463,7 +465,7 @@ class FederationServer(FederationBase):
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
async def _on_context_state_request_compute(
- self, room_id: str, event_id: str
+ self, room_id: str, event_id: Optional[str]
) -> Dict[str, list]:
if event_id:
pdus = await self.handler.get_state_for_pdu(
@@ -586,6 +588,103 @@ class FederationServer(FederationBase):
await self.handler.on_send_leave_request(origin, pdu)
return {}
+ async def on_make_knock_request(
+ self, origin: str, room_id: str, user_id: str, supported_versions: List[str]
+ ) -> Dict[str, Union[EventBase, str]]:
+ """We've received a /make_knock/ request, so we create a partial knock
+ event for the room and hand that back, along with the room version, to the knocking
+ homeserver. We do *not* persist or process this event until the other server has
+ signed it and sent it back.
+
+ Args:
+ origin: The (verified) server name of the requesting server.
+ room_id: The room to create the knock event in.
+ user_id: The user to create the knock for.
+ supported_versions: The room versions supported by the requesting server.
+
+ Returns:
+ The partial knock event, along with the room version identifier.
+ """
+ origin_host, _ = parse_server_name(origin)
+ await self.check_server_matches_acl(origin_host, room_id)
+
+ room_version = await self.store.get_room_version(room_id)
+
+ # Check that this room version is supported by the remote homeserver
+ if room_version.identifier not in supported_versions:
+ logger.warning(
+ "Room version %s not in %s", room_version.identifier, supported_versions
+ )
+ raise IncompatibleRoomVersionError(room_version=room_version.identifier)
+
+ # Check that this room supports knocking as defined by its room version
+ if not room_version.msc2403_knocking:
+ raise SynapseError(
+ 403,
+ "This room version does not support knocking",
+ errcode=Codes.FORBIDDEN,
+ )
+
+ pdu = await self.handler.on_make_knock_request(origin, room_id, user_id)
+ time_now = self._clock.time_msec()
+ return {
+ "event": pdu.get_pdu_json(time_now),
+ "room_version": room_version.identifier,
+ }
+
+ async def on_send_knock_request(
+ self,
+ origin: str,
+ content: JsonDict,
+ room_id: str,
+ ) -> Dict[str, List[JsonDict]]:
+ """
+ We have received a knock event for a room. Verify and send the event into the room
+ on the knocking homeserver's behalf. Then reply with some stripped state from the
+ room for the knockee.
+
+ Args:
+ origin: The remote homeserver of the knocking user.
+ content: The content of the request.
+ room_id: The ID of the room to knock on.
+
+ Returns:
+ The stripped room state.
+ """
+ logger.debug("on_send_knock_request: content: %s", content)
+
+ room_version = await self.store.get_room_version(room_id)
+
+ # Check that this room supports knocking as defined by its room version
+ if not room_version.msc2403_knocking:
+ raise SynapseError(
+ 403,
+ "This room version does not support knocking",
+ errcode=Codes.FORBIDDEN,
+ )
+
+ pdu = event_from_pdu_json(content, room_version)
+
+ origin_host, _ = parse_server_name(origin)
+ await self.check_server_matches_acl(origin_host, pdu.room_id)
+
+ logger.debug("on_send_knock_request: pdu sigs: %s", pdu.signatures)
+
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
+
+ # Handle the event, and retrieve the EventContext
+ event_context = await self.handler.on_send_knock_request(origin, pdu)
+
+ # Retrieve stripped state events from the room and send them back to the remote
+ # server. This will allow the remote server's clients to display information
+ # related to the room while the knock request is pending.
+ stripped_room_state = (
+ await self.store.get_stripped_room_state_from_event_context(
+ event_context, self._room_prejoin_state_types
+ )
+ )
+ return {"knock_state_events": stripped_room_state}
+
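
# [Illustrative sketch, not part of this commit] The wire shapes handled
# above, with hypothetical identifiers. The knocking server PUTs the full
# signed membership event; the resident server answers with stripped state.
knock_event = {
    "type": "m.room.member",
    "room_id": "!lobby:resident.example.org",
    "sender": "@visitor:knocking.example.org",
    "state_key": "@visitor:knocking.example.org",
    "content": {"membership": "knock"},
}
knock_response = {
    "knock_state_events": [
        {"type": "m.room.name", "state_key": "", "content": {"name": "Lobby"}},
    ],
}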
async def on_event_auth(
self, origin: str, room_id: str, event_id: str
) -> Tuple[int, Dict[str, Any]]:
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 5b4f5d17..c9e7c574 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -1,5 +1,5 @@
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -220,7 +220,8 @@ class TransportLayerClient:
Fails with ``FederationDeniedError`` if the remote destination
is not in our federation whitelist
"""
- valid_memberships = {Membership.JOIN, Membership.LEAVE}
+ valid_memberships = {Membership.JOIN, Membership.LEAVE, Membership.KNOCK}
+
if membership not in valid_memberships:
raise RuntimeError(
"make_membership_event called with membership='%s', must be one of %s"
@@ -322,6 +323,40 @@ class TransportLayerClient:
return response
@log_function
+ async def send_knock_v1(
+ self,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ content: JsonDict,
+ ) -> JsonDict:
+ """
+ Sends a signed knock membership event to a remote server. This is the second
+ step for knocking after make_knock.
+
+ Args:
+ destination: The remote homeserver.
+ room_id: The ID of the room to knock on.
+ event_id: The ID of the knock membership event that we're sending.
+ content: The knock membership event that we're sending. Note that this is not the
+ `content` field of the membership event, but the entire signed membership event
+ itself represented as a JSON dict.
+
+ Returns:
+ The remote homeserver can optionally return some state from the room. The response
+ dictionary is in the form:
+
+ {"knock_state_events": [<state event dict>, ...]}
+
+ The list of state events may be empty.
+ """
+ path = _create_v1_path("/send_knock/%s/%s", room_id, event_id)
+
+ return await self.client.put_json(
+ destination=destination, path=path, data=content
+ )
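
# [Illustrative sketch, not part of this commit] What _create_v1_path yields
# for the call above; this demo emulates its prefix-and-quote behaviour.
from urllib.parse import quote

def demo_v1_path(template, *args):
    return "/_matrix/federation/v1" + template % tuple(
        quote(str(a), safe="") for a in args
    )

print(demo_v1_path("/send_knock/%s/%s", "!lobby:example.org", "$abc123"))
# -> /_matrix/federation/v1/send_knock/%21lobby%3Aexample.org/%24abc123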
+
+ @log_function
async def send_invite_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 5756fcb5..bed47f8a 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -1,6 +1,5 @@
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import functools
import logging
import re
@@ -28,13 +26,16 @@ from synapse.api.urls import (
FEDERATION_V1_PREFIX,
FEDERATION_V2_PREFIX,
)
+from synapse.handlers.groups_local import GroupsLocalHandler
from synapse.http.server import HttpServer, JsonResource
from synapse.http.servlet import (
parse_boolean_from_args,
parse_integer_from_args,
parse_json_object_from_request,
parse_string_from_args,
+ parse_strings_from_args,
)
+from synapse.logging import opentracing
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
SynapseTags,
@@ -275,10 +276,17 @@ class BaseFederationServlet:
RATELIMIT = True # Whether to rate limit requests or not
- def __init__(self, handler, authenticator, ratelimiter, server_name):
- self.handler = handler
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ ):
+ self.hs = hs
self.authenticator = authenticator
self.ratelimiter = ratelimiter
+ self.server_name = server_name
def _wrap(self, func):
authenticator = self.authenticator
@@ -338,6 +346,8 @@ class BaseFederationServlet:
)
with scope:
+ opentracing.inject_response_headers(request.responseHeaders)
+
if origin and self.RATELIMIT:
with ratelimiter.ratelimit(origin) as d:
await d
@@ -375,17 +385,30 @@ class BaseFederationServlet:
)
-class FederationSendServlet(BaseFederationServlet):
+class BaseFederationServerServlet(BaseFederationServlet):
+ """Abstract base class for federation servlet classes which provides a federation server handler.
+
+ See BaseFederationServlet for more information.
+ """
+
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self.handler = hs.get_federation_server()
+
+
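
# [Illustrative sketch, not part of this commit] The pattern this refactor
# establishes: each servlet family gets a base class that resolves its own
# handler from the HomeServer, so register_servlets below only passes hs.
# The "widget" names are hypothetical.
class BaseWidgetServlet(BaseFederationServlet):
    def __init__(self, hs, authenticator, ratelimiter, server_name):
        super().__init__(hs, authenticator, ratelimiter, server_name)
        self.handler = hs.get_widget_handler()  # hypothetical getter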
+class FederationSendServlet(BaseFederationServerServlet):
PATH = "/send/(?P<transaction_id>[^/]*)/?"
# We ratelimit manually in the handler as we queue up the requests and we
# don't want to fill up the ratelimiter with blocked requests.
RATELIMIT = False
- def __init__(self, handler, server_name, **kwargs):
- super().__init__(handler, server_name=server_name, **kwargs)
- self.server_name = server_name
-
# This is when someone is trying to send us a bunch of data.
async def on_PUT(self, origin, content, query, transaction_id):
"""Called on PUT /send/<transaction_id>/
@@ -434,7 +457,7 @@ class FederationSendServlet(BaseFederationServlet):
return code, response
-class FederationEventServlet(BaseFederationServlet):
+class FederationEventServlet(BaseFederationServerServlet):
PATH = "/event/(?P<event_id>[^/]*)/?"
# This is when someone asks for a data item for a given server data_id pair.
@@ -442,7 +465,7 @@ class FederationEventServlet(BaseFederationServlet):
return await self.handler.on_pdu_request(origin, event_id)
-class FederationStateV1Servlet(BaseFederationServlet):
+class FederationStateV1Servlet(BaseFederationServerServlet):
PATH = "/state/(?P<room_id>[^/]*)/?"
# This is when someone asks for all data for a given room.
@@ -454,7 +477,7 @@ class FederationStateV1Servlet(BaseFederationServlet):
)
-class FederationStateIdsServlet(BaseFederationServlet):
+class FederationStateIdsServlet(BaseFederationServerServlet):
PATH = "/state_ids/(?P<room_id>[^/]*)/?"
async def on_GET(self, origin, content, query, room_id):
@@ -465,7 +488,7 @@ class FederationStateIdsServlet(BaseFederationServlet):
)
-class FederationBackfillServlet(BaseFederationServlet):
+class FederationBackfillServlet(BaseFederationServerServlet):
PATH = "/backfill/(?P<room_id>[^/]*)/?"
async def on_GET(self, origin, content, query, room_id):
@@ -478,7 +501,7 @@ class FederationBackfillServlet(BaseFederationServlet):
return await self.handler.on_backfill_request(origin, room_id, versions, limit)
-class FederationQueryServlet(BaseFederationServlet):
+class FederationQueryServlet(BaseFederationServerServlet):
PATH = "/query/(?P<query_type>[^/]*)"
# This is when we receive a server-server Query
@@ -488,7 +511,7 @@ class FederationQueryServlet(BaseFederationServlet):
return await self.handler.on_query_request(query_type, args)
-class FederationMakeJoinServlet(BaseFederationServlet):
+class FederationMakeJoinServlet(BaseFederationServerServlet):
PATH = "/make_join/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
async def on_GET(self, origin, _content, query, room_id, user_id):
@@ -518,7 +541,7 @@ class FederationMakeJoinServlet(BaseFederationServlet):
return 200, content
-class FederationMakeLeaveServlet(BaseFederationServlet):
+class FederationMakeLeaveServlet(BaseFederationServerServlet):
PATH = "/make_leave/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
async def on_GET(self, origin, content, query, room_id, user_id):
@@ -526,7 +549,7 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
return 200, content
-class FederationV1SendLeaveServlet(BaseFederationServlet):
+class FederationV1SendLeaveServlet(BaseFederationServerServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id, event_id):
@@ -534,7 +557,7 @@ class FederationV1SendLeaveServlet(BaseFederationServlet):
return 200, (200, content)
-class FederationV2SendLeaveServlet(BaseFederationServlet):
+class FederationV2SendLeaveServlet(BaseFederationServerServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
PREFIX = FEDERATION_V2_PREFIX
@@ -544,14 +567,38 @@ class FederationV2SendLeaveServlet(BaseFederationServlet):
return 200, content
-class FederationEventAuthServlet(BaseFederationServlet):
+class FederationMakeKnockServlet(BaseFederationServerServlet):
+ PATH = "/make_knock/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
+
+ async def on_GET(self, origin, content, query, room_id, user_id):
+ try:
+ # Retrieve the room versions the remote homeserver claims to support
+ supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8")
+ except KeyError:
+ raise SynapseError(400, "Missing required query parameter 'ver'")
+
+ content = await self.handler.on_make_knock_request(
+ origin, room_id, user_id, supported_versions=supported_versions
+ )
+ return 200, content
+
+
+class FederationV1SendKnockServlet(BaseFederationServerServlet):
+ PATH = "/send_knock/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+ async def on_PUT(self, origin, content, query, room_id, event_id):
+ content = await self.handler.on_send_knock_request(origin, content, room_id)
+ return 200, content
+
+
+class FederationEventAuthServlet(BaseFederationServerServlet):
PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_GET(self, origin, content, query, room_id, event_id):
return await self.handler.on_event_auth(origin, room_id, event_id)
-class FederationV1SendJoinServlet(BaseFederationServlet):
+class FederationV1SendJoinServlet(BaseFederationServerServlet):
PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id, event_id):
@@ -561,7 +608,7 @@ class FederationV1SendJoinServlet(BaseFederationServlet):
return 200, (200, content)
-class FederationV2SendJoinServlet(BaseFederationServlet):
+class FederationV2SendJoinServlet(BaseFederationServerServlet):
PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
PREFIX = FEDERATION_V2_PREFIX
@@ -573,7 +620,7 @@ class FederationV2SendJoinServlet(BaseFederationServlet):
return 200, content
-class FederationV1InviteServlet(BaseFederationServlet):
+class FederationV1InviteServlet(BaseFederationServerServlet):
PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id, event_id):
@@ -590,7 +637,7 @@ class FederationV1InviteServlet(BaseFederationServlet):
return 200, (200, content)
-class FederationV2InviteServlet(BaseFederationServlet):
+class FederationV2InviteServlet(BaseFederationServerServlet):
PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
PREFIX = FEDERATION_V2_PREFIX
@@ -614,7 +661,7 @@ class FederationV2InviteServlet(BaseFederationServlet):
return 200, content
-class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
+class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet):
PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id):
@@ -622,21 +669,21 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
return 200, {}
-class FederationClientKeysQueryServlet(BaseFederationServlet):
+class FederationClientKeysQueryServlet(BaseFederationServerServlet):
PATH = "/user/keys/query"
async def on_POST(self, origin, content, query):
return await self.handler.on_query_client_keys(origin, content)
-class FederationUserDevicesQueryServlet(BaseFederationServlet):
+class FederationUserDevicesQueryServlet(BaseFederationServerServlet):
PATH = "/user/devices/(?P<user_id>[^/]*)"
async def on_GET(self, origin, content, query, user_id):
return await self.handler.on_query_user_devices(origin, user_id)
-class FederationClientKeysClaimServlet(BaseFederationServlet):
+class FederationClientKeysClaimServlet(BaseFederationServerServlet):
PATH = "/user/keys/claim"
async def on_POST(self, origin, content, query):
@@ -644,7 +691,7 @@ class FederationClientKeysClaimServlet(BaseFederationServlet):
return 200, response
-class FederationGetMissingEventsServlet(BaseFederationServlet):
+class FederationGetMissingEventsServlet(BaseFederationServerServlet):
# TODO(paul): Why does this path alone end with "/?" optional?
PATH = "/get_missing_events/(?P<room_id>[^/]*)/?"
@@ -664,7 +711,7 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
return 200, content
-class On3pidBindServlet(BaseFederationServlet):
+class On3pidBindServlet(BaseFederationServerServlet):
PATH = "/3pid/onbind"
REQUIRE_AUTH = False
@@ -694,7 +741,7 @@ class On3pidBindServlet(BaseFederationServlet):
return 200, {}
-class OpenIdUserInfo(BaseFederationServlet):
+class OpenIdUserInfo(BaseFederationServerServlet):
"""
Exchange a bearer token for information about a user.
@@ -770,8 +817,16 @@ class PublicRoomList(BaseFederationServlet):
PATH = "/publicRooms"
- def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access):
- super().__init__(handler, authenticator, ratelimiter, server_name)
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ allow_access: bool,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self.handler = hs.get_room_list_handler()
self.allow_access = allow_access
async def on_GET(self, origin, content, query):
@@ -856,7 +911,24 @@ class FederationVersionServlet(BaseFederationServlet):
)
-class FederationGroupsProfileServlet(BaseFederationServlet):
+class BaseGroupsServerServlet(BaseFederationServlet):
+ """Abstract base class for federation servlet classes which provides a groups server handler.
+
+ See BaseFederationServlet for more information.
+ """
+
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self.handler = hs.get_groups_server_handler()
+
+
+class FederationGroupsProfileServlet(BaseGroupsServerServlet):
"""Get/set the basic profile of a group on behalf of a user"""
PATH = "/groups/(?P<group_id>[^/]*)/profile"
@@ -882,7 +954,7 @@ class FederationGroupsProfileServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsSummaryServlet(BaseFederationServlet):
+class FederationGroupsSummaryServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/summary"
async def on_GET(self, origin, content, query, group_id):
@@ -895,7 +967,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsRoomsServlet(BaseFederationServlet):
+class FederationGroupsRoomsServlet(BaseGroupsServerServlet):
"""Get the rooms in a group on behalf of a user"""
PATH = "/groups/(?P<group_id>[^/]*)/rooms"
@@ -910,7 +982,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsAddRoomsServlet(BaseFederationServlet):
+class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet):
"""Add/remove room from group"""
PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
@@ -938,7 +1010,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
+class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet):
"""Update room config in group"""
PATH = (
@@ -958,7 +1030,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
return 200, result
-class FederationGroupsUsersServlet(BaseFederationServlet):
+class FederationGroupsUsersServlet(BaseGroupsServerServlet):
"""Get the users in a group on behalf of a user"""
PATH = "/groups/(?P<group_id>[^/]*)/users"
@@ -973,7 +1045,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
+class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet):
"""Get the users that have been invited to a group"""
PATH = "/groups/(?P<group_id>[^/]*)/invited_users"
@@ -990,7 +1062,7 @@ class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsInviteServlet(BaseFederationServlet):
+class FederationGroupsInviteServlet(BaseGroupsServerServlet):
"""Ask a group server to invite someone to the group"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
@@ -1007,7 +1079,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
+class FederationGroupsAcceptInviteServlet(BaseGroupsServerServlet):
"""Accept an invitation from the group server"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite"
@@ -1021,7 +1093,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsJoinServlet(BaseFederationServlet):
+class FederationGroupsJoinServlet(BaseGroupsServerServlet):
"""Attempt to join a group"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join"
@@ -1035,7 +1107,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsRemoveUserServlet(BaseFederationServlet):
+class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet):
"""Leave or kick a user from the group"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
@@ -1052,7 +1124,24 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsLocalInviteServlet(BaseFederationServlet):
+class BaseGroupsLocalServlet(BaseFederationServlet):
+ """Abstract base class for federation servlet classes which provides a groups local handler.
+
+ See BaseFederationServlet for more information.
+ """
+
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self.handler = hs.get_groups_local_handler()
+
+
+class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet):
"""A group server has invited a local user"""
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
@@ -1061,12 +1150,16 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet):
if get_domain_from_id(group_id) != origin:
raise SynapseError(403, "group_id doesn't match origin")
+ assert isinstance(
+ self.handler, GroupsLocalHandler
+ ), "Workers cannot handle group invites."
+
new_content = await self.handler.on_invite(group_id, user_id, content)
return 200, new_content
-class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
+class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet):
"""A group server has removed a local user"""
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
@@ -1075,6 +1168,10 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
if get_domain_from_id(group_id) != origin:
raise SynapseError(403, "user_id doesn't match origin")
+ assert isinstance(
+ self.handler, GroupsLocalHandler
+ ), "Workers cannot handle group removals."
+
new_content = await self.handler.user_removed_from_group(
group_id, user_id, content
)
@@ -1087,6 +1184,16 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)"
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self.handler = hs.get_groups_attestation_renewer()
+
async def on_POST(self, origin, content, query, group_id, user_id):
# We don't need to check auth here as we check the attestation signatures
@@ -1097,7 +1204,7 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
+class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet):
"""Add/remove a room from the group summary, with optional category.
Matches both:
@@ -1154,7 +1261,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
return 200, resp
-class FederationGroupsCategoriesServlet(BaseFederationServlet):
+class FederationGroupsCategoriesServlet(BaseGroupsServerServlet):
"""Get all categories for a group"""
PATH = "/groups/(?P<group_id>[^/]*)/categories/?"
@@ -1169,7 +1276,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet):
return 200, resp
-class FederationGroupsCategoryServlet(BaseFederationServlet):
+class FederationGroupsCategoryServlet(BaseGroupsServerServlet):
"""Add/remove/get a category in a group"""
PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
@@ -1222,7 +1329,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
return 200, resp
-class FederationGroupsRolesServlet(BaseFederationServlet):
+class FederationGroupsRolesServlet(BaseGroupsServerServlet):
"""Get roles in a group"""
PATH = "/groups/(?P<group_id>[^/]*)/roles/?"
@@ -1237,7 +1344,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet):
return 200, resp
-class FederationGroupsRoleServlet(BaseFederationServlet):
+class FederationGroupsRoleServlet(BaseGroupsServerServlet):
"""Add/remove/get a role in a group"""
PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
@@ -1290,7 +1397,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
return 200, resp
-class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
+class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet):
"""Add/remove a user from the group summary, with optional role.
Matches both:
@@ -1345,7 +1452,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
return 200, resp
-class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
+class FederationGroupsBulkPublicisedServlet(BaseGroupsLocalServlet):
"""Get roles in a group"""
PATH = "/get_groups_publicised"
@@ -1358,7 +1465,7 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
return 200, resp
-class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
+class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet):
"""Sets whether a group is joinable without an invite or knock"""
PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy"
@@ -1379,6 +1486,16 @@ class FederationSpaceSummaryServlet(BaseFederationServlet):
PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
PATH = "/spaces/(?P<room_id>[^/]*)"
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self.handler = hs.get_space_summary_handler()
+
async def on_GET(
self,
origin: str,
@@ -1444,16 +1561,25 @@ class RoomComplexityServlet(BaseFederationServlet):
PATH = "/rooms/(?P<room_id>[^/]*)/complexity"
PREFIX = FEDERATION_UNSTABLE_PREFIX
- async def on_GET(self, origin, content, query, room_id):
-
- store = self.handler.hs.get_datastore()
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self._store = self.hs.get_datastore()
- is_public = await store.is_room_world_readable_or_publicly_joinable(room_id)
+ async def on_GET(self, origin, content, query, room_id):
+ is_public = await self._store.is_room_world_readable_or_publicly_joinable(
+ room_id
+ )
if not is_public:
raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM)
- complexity = await store.get_room_complexity(room_id)
+ complexity = await self._store.get_room_complexity(room_id)
return 200, complexity
@@ -1482,6 +1608,9 @@ FEDERATION_SERVLET_CLASSES = (
On3pidBindServlet,
FederationVersionServlet,
RoomComplexityServlet,
+ FederationSpaceSummaryServlet,
+ FederationV1SendKnockServlet,
+ FederationMakeKnockServlet,
) # type: Tuple[Type[BaseFederationServlet], ...]
OPENID_SERVLET_CLASSES = (
@@ -1523,6 +1652,7 @@ GROUP_ATTESTATION_SERVLET_CLASSES = (
FederationGroupsRenewAttestaionServlet,
) # type: Tuple[Type[BaseFederationServlet], ...]
+
DEFAULT_SERVLET_GROUPS = (
"federation",
"room_list",
@@ -1559,23 +1689,16 @@ def register_servlets(
if "federation" in servlet_groups:
for servletclass in FEDERATION_SERVLET_CLASSES:
servletclass(
- handler=hs.get_federation_server(),
+ hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)
- FederationSpaceSummaryServlet(
- handler=hs.get_space_summary_handler(),
- authenticator=authenticator,
- ratelimiter=ratelimiter,
- server_name=hs.hostname,
- ).register(resource)
-
if "openid" in servlet_groups:
for servletclass in OPENID_SERVLET_CLASSES:
servletclass(
- handler=hs.get_federation_server(),
+ hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
@@ -1584,7 +1707,7 @@ def register_servlets(
if "room_list" in servlet_groups:
for servletclass in ROOM_LIST_CLASSES:
servletclass(
- handler=hs.get_room_list_handler(),
+ hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
@@ -1594,7 +1717,7 @@ def register_servlets(
if "group_server" in servlet_groups:
for servletclass in GROUP_SERVER_SERVLET_CLASSES:
servletclass(
- handler=hs.get_groups_server_handler(),
+ hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
@@ -1603,7 +1726,7 @@ def register_servlets(
if "group_local" in servlet_groups:
for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
servletclass(
- handler=hs.get_groups_local_handler(),
+ hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
@@ -1612,7 +1735,7 @@ def register_servlets(
if "group_attestation" in servlet_groups:
for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
servletclass(
- handler=hs.get_groups_attestation_renewer(),
+ hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
deleted file mode 100644
index 16ab93f5..00000000
--- a/synapse/handlers/acme.py
+++ /dev/null
@@ -1,117 +0,0 @@
-# Copyright 2019 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from typing import TYPE_CHECKING
-
-import twisted
-import twisted.internet.error
-from twisted.web import server, static
-from twisted.web.resource import Resource
-
-from synapse.app import check_bind_error
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-ACME_REGISTER_FAIL_ERROR = """
---------------------------------------------------------------------------------
-Failed to register with the ACME provider. This is likely happening because the installation
-is new, and ACME v1 has been deprecated by Let's Encrypt and disabled for
-new installations since November 2019.
-At the moment, Synapse doesn't support ACME v2. For more information and alternative
-solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/ACME.md#deprecation-of-acme-v1
---------------------------------------------------------------------------------"""
-
-
-class AcmeHandler:
- def __init__(self, hs: "HomeServer"):
- self.hs = hs
- self.reactor = hs.get_reactor()
- self._acme_domain = hs.config.acme_domain
-
- async def start_listening(self) -> None:
- from synapse.handlers import acme_issuing_service
-
- # Configure logging for txacme, if you need to debug
- # from eliot import add_destinations
- # from eliot.twisted import TwistedDestination
- #
- # add_destinations(TwistedDestination())
-
- well_known = Resource()
-
- self._issuer = acme_issuing_service.create_issuing_service(
- self.reactor,
- acme_url=self.hs.config.acme_url,
- account_key_file=self.hs.config.acme_account_key_file,
- well_known_resource=well_known,
- )
-
- responder_resource = Resource()
- responder_resource.putChild(b".well-known", well_known)
- responder_resource.putChild(b"check", static.Data(b"OK", b"text/plain"))
- srv = server.Site(responder_resource)
-
- bind_addresses = self.hs.config.acme_bind_addresses
- for host in bind_addresses:
- logger.info(
- "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port
- )
- try:
- self.reactor.listenTCP(
- self.hs.config.acme_port, srv, backlog=50, interface=host
- )
- except twisted.internet.error.CannotListenError as e:
- check_bind_error(e, host, bind_addresses)
-
- # Make sure we are registered to the ACME server. There's no public API
- # for this, it is usually triggered by startService, but since we don't
- # want it to control where we save the certificates, we have to reach in
- # and trigger the registration machinery ourselves.
- self._issuer._registered = False
-
- try:
- await self._issuer._ensure_registered()
- except Exception:
- logger.error(ACME_REGISTER_FAIL_ERROR)
- raise
-
- async def provision_certificate(self) -> None:
-
- logger.warning("Reprovisioning %s", self._acme_domain)
-
- try:
- await self._issuer.issue_cert(self._acme_domain)
- except Exception:
- logger.exception("Fail!")
- raise
- logger.warning("Reprovisioned %s, saving.", self._acme_domain)
- cert_chain = self._issuer.cert_store.certs[self._acme_domain]
-
- try:
- with open(self.hs.config.tls_private_key_file, "wb") as private_key_file:
- for x in cert_chain:
- if x.startswith(b"-----BEGIN RSA PRIVATE KEY-----"):
- private_key_file.write(x)
-
- with open(self.hs.config.tls_certificate_file, "wb") as certificate_file:
- for x in cert_chain:
- if x.startswith(b"-----BEGIN CERTIFICATE-----"):
- certificate_file.write(x)
- except Exception:
- logger.exception("Failed saving!")
- raise
diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py
deleted file mode 100644
index a972d3fa..00000000
--- a/synapse/handlers/acme_issuing_service.py
+++ /dev/null
@@ -1,127 +0,0 @@
-# Copyright 2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Utility function to create an ACME issuing service.
-
-This file contains the unconditional imports on the acme and cryptography bits that we
-only need (and may only have available) if we are doing ACME, so is designed to be
-imported conditionally.
-"""
-import logging
-from typing import Dict, Iterable, List
-
-import attr
-import pem
-from cryptography.hazmat.backends import default_backend
-from cryptography.hazmat.primitives import serialization
-from josepy import JWKRSA
-from josepy.jwa import RS256
-from txacme.challenges import HTTP01Responder
-from txacme.client import Client
-from txacme.interfaces import ICertificateStore
-from txacme.service import AcmeIssuingService
-from txacme.util import generate_private_key
-from zope.interface import implementer
-
-from twisted.internet import defer
-from twisted.internet.interfaces import IReactorTCP
-from twisted.python.filepath import FilePath
-from twisted.python.url import URL
-from twisted.web.resource import IResource
-
-logger = logging.getLogger(__name__)
-
-
-def create_issuing_service(
- reactor: IReactorTCP,
- acme_url: str,
- account_key_file: str,
- well_known_resource: IResource,
-) -> AcmeIssuingService:
- """Create an ACME issuing service, and attach it to a web Resource
-
- Args:
- reactor: twisted reactor
- acme_url: URL to use to request certificates
- account_key_file: where to store the account key
- well_known_resource: web resource for .well-known.
- we will attach a child resource for "acme-challenge".
-
- Returns:
- AcmeIssuingService
- """
- responder = HTTP01Responder()
-
- well_known_resource.putChild(b"acme-challenge", responder.resource)
-
- store = ErsatzStore()
-
- return AcmeIssuingService(
- cert_store=store,
- client_creator=(
- lambda: Client.from_url(
- reactor=reactor,
- url=URL.from_text(acme_url),
- key=load_or_create_client_key(account_key_file),
- alg=RS256,
- )
- ),
- clock=reactor,
- responders=[responder],
- )
-
-
-@attr.s(slots=True)
-@implementer(ICertificateStore)
-class ErsatzStore:
- """
- A store that only stores in memory.
- """
-
- certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict))
-
- def store(
- self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject]
- ) -> defer.Deferred:
- self.certs[server_name] = [o.as_bytes() for o in pem_objects]
- return defer.succeed(None)
-
-
-def load_or_create_client_key(key_file: str) -> JWKRSA:
- """Load the ACME account key from a file, creating it if it does not exist.
-
- Args:
- key_file: name of the file to use as the account key
- """
- # this is based on txacme.endpoint.load_or_create_client_key, but doesn't
- # hardcode the 'client.key' filename
- acme_key_file = FilePath(key_file)
- if acme_key_file.exists():
- logger.info("Loading ACME account key from '%s'", acme_key_file)
- key = serialization.load_pem_private_key(
- acme_key_file.getContent(), password=None, backend=default_backend()
- )
- else:
- logger.info("Saving new ACME account key to '%s'", acme_key_file)
- key = generate_private_key("rsa")
- acme_key_file.setContent(
- key.private_bytes(
- encoding=serialization.Encoding.PEM,
- format=serialization.PrivateFormat.TraditionalOpenSSL,
- encryption_algorithm=serialization.NoEncryption(),
- )
- )
- return JWKRSA(key=key)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 8a6666a4..1971e373 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -302,6 +302,7 @@ class AuthHandler(BaseHandler):
request: SynapseRequest,
request_body: Dict[str, Any],
description: str,
+ can_skip_ui_auth: bool = False,
) -> Tuple[dict, Optional[str]]:
"""
Checks that the user is who they claim to be, via a UI auth.
@@ -320,6 +321,10 @@ class AuthHandler(BaseHandler):
description: A human readable string to be displayed to the user that
describes the operation happening on their account.
+ can_skip_ui_auth: True if the UI auth session timeout applies to this
+ action. Should be set to False for any "dangerous"
+ actions (e.g. deactivating an account).
+
Returns:
A tuple of (params, session_id).
@@ -343,7 +348,7 @@ class AuthHandler(BaseHandler):
"""
if not requester.access_token_id:
raise ValueError("Cannot validate a user without an access token")
- if self._ui_auth_session_timeout:
+ if can_skip_ui_auth and self._ui_auth_session_timeout:
last_validated = await self.store.get_access_token_last_validated(
requester.access_token_id
)
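
# [Illustrative sketch, not part of this commit] How a low-risk endpoint
# might opt in to the shortcut above; the surrounding names are assumptions.
params, session_id = await auth_handler.validate_user_via_ui_auth(
    requester,
    request,
    body,
    "remove a device from your account",
    can_skip_ui_auth=True,  # recent UI auth within the timeout is reused
)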
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 97448780..3972849d 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -79,9 +79,15 @@ class E2eKeysHandler:
"client_keys", self.on_federation_query_client_keys
)
+ # Limit the number of in-flight requests from a single device.
+ self._query_devices_linearizer = Linearizer(
+ name="query_devices",
+ max_count=10,
+ )
+
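
# [Illustrative sketch, not part of this commit] The Linearizer idiom used in
# query_devices below: queue(key) yields a context manager, and at most
# max_count callers per (user, device) key hold it concurrently.
async def limited(self, from_user_id, from_device_id, work):
    with await self._query_devices_linearizer.queue(
        (from_user_id, from_device_id)
    ):
        return await work()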
@trace
async def query_devices(
- self, query_body: JsonDict, timeout: int, from_user_id: str
+ self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
) -> JsonDict:
"""Handle a device key query from a client
@@ -105,191 +111,197 @@ class E2eKeysHandler:
from_user_id: the user making the query. This is used when
adding cross-signing signatures to limit what signatures users
can see.
+ from_device_id: the device making the query. This is used to limit
+ the number of in-flight queries at a time.
"""
-
- device_keys_query = query_body.get(
- "device_keys", {}
- ) # type: Dict[str, Iterable[str]]
-
- # separate users by domain.
- # make a map from domain to user_id to device_ids
- local_query = {}
- remote_queries = {}
-
- for user_id, device_ids in device_keys_query.items():
- # we use UserID.from_string to catch invalid user ids
- if self.is_mine(UserID.from_string(user_id)):
- local_query[user_id] = device_ids
- else:
- remote_queries[user_id] = device_ids
-
- set_tag("local_key_query", local_query)
- set_tag("remote_key_query", remote_queries)
-
- # First get local devices.
- # A map of destination -> failure response.
- failures = {} # type: Dict[str, JsonDict]
- results = {}
- if local_query:
- local_result = await self.query_local_devices(local_query)
- for user_id, keys in local_result.items():
- if user_id in local_query:
- results[user_id] = keys
-
- # Get cached cross-signing keys
- cross_signing_keys = await self.get_cross_signing_keys_from_cache(
- device_keys_query, from_user_id
- )
-
- # Now attempt to get any remote devices from our local cache.
- # A map of destination -> user ID -> device IDs.
- remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]]
- if remote_queries:
- query_list = [] # type: List[Tuple[str, Optional[str]]]
- for user_id, device_ids in remote_queries.items():
- if device_ids:
- query_list.extend((user_id, device_id) for device_id in device_ids)
+ with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
+ device_keys_query = query_body.get(
+ "device_keys", {}
+ ) # type: Dict[str, Iterable[str]]
+
+ # separate users by domain.
+ # make a map from domain to user_id to device_ids
+ local_query = {}
+ remote_queries = {}
+
+ for user_id, device_ids in device_keys_query.items():
+ # we use UserID.from_string to catch invalid user ids
+ if self.is_mine(UserID.from_string(user_id)):
+ local_query[user_id] = device_ids
else:
- query_list.append((user_id, None))
-
- (
- user_ids_not_in_cache,
- remote_results,
- ) = await self.store.get_user_devices_from_cache(query_list)
- for user_id, devices in remote_results.items():
- user_devices = results.setdefault(user_id, {})
- for device_id, device in devices.items():
- keys = device.get("keys", None)
- device_display_name = device.get("device_display_name", None)
- if keys:
- result = dict(keys)
- unsigned = result.setdefault("unsigned", {})
- if device_display_name:
- unsigned["device_display_name"] = device_display_name
- user_devices[device_id] = result
-
- # check for missing cross-signing keys.
- for user_id in remote_queries.keys():
- cached_cross_master = user_id in cross_signing_keys["master_keys"]
- cached_cross_selfsigning = (
- user_id in cross_signing_keys["self_signing_keys"]
- )
-
- # check if we are missing only one of cross-signing master or
- # self-signing key, but the other one is cached.
- # as we need both, this will issue a federation request.
- # if we don't have any of the keys, either the user doesn't have
- # cross-signing set up, or the cached device list
- # is not (yet) updated.
- if cached_cross_master ^ cached_cross_selfsigning:
- user_ids_not_in_cache.add(user_id)
-
- # add those users to the list to fetch over federation.
- for user_id in user_ids_not_in_cache:
- domain = get_domain_from_id(user_id)
- r = remote_queries_not_in_cache.setdefault(domain, {})
- r[user_id] = remote_queries[user_id]
-
- # Now fetch any devices that we don't have in our cache
- @trace
- async def do_remote_query(destination):
- """This is called when we are querying the device list of a user on
- a remote homeserver and their device list is not in the device list
- cache. If we share a room with this user and we're not querying for
- specific user we will update the cache with their device list.
- """
-
- destination_query = remote_queries_not_in_cache[destination]
-
- # We first consider whether we wish to update the device list cache with
- # the users device list. We want to track a user's devices when the
- # authenticated user shares a room with the queried user and the query
- # has not specified a particular device.
- # If we update the cache for the queried user we remove them from further
- # queries. We use the more efficient batched query_client_keys for all
- # remaining users
- user_ids_updated = []
- for (user_id, device_list) in destination_query.items():
- if user_id in user_ids_updated:
- continue
-
- if device_list:
- continue
+ remote_queries[user_id] = device_ids
+
+ set_tag("local_key_query", local_query)
+ set_tag("remote_key_query", remote_queries)
+
+ # First get local devices.
+ # A map of destination -> failure response.
+ failures = {} # type: Dict[str, JsonDict]
+ results = {}
+ if local_query:
+ local_result = await self.query_local_devices(local_query)
+ for user_id, keys in local_result.items():
+ if user_id in local_query:
+ results[user_id] = keys
- room_ids = await self.store.get_rooms_for_user(user_id)
- if not room_ids:
- continue
+ # Get cached cross-signing keys
+ cross_signing_keys = await self.get_cross_signing_keys_from_cache(
+ device_keys_query, from_user_id
+ )
- # We've decided we're sharing a room with this user and should
- # probably be tracking their device lists. However, we haven't
- # done an initial sync on the device list so we do it now.
- try:
- if self._is_master:
- user_devices = await self.device_handler.device_list_updater.user_device_resync(
- user_id
+ # Now attempt to get any remote devices from our local cache.
+ # A map of destination -> user ID -> device IDs.
+ remote_queries_not_in_cache = (
+ {}
+ ) # type: Dict[str, Dict[str, Iterable[str]]]
+ if remote_queries:
+ query_list = [] # type: List[Tuple[str, Optional[str]]]
+ for user_id, device_ids in remote_queries.items():
+ if device_ids:
+ query_list.extend(
+ (user_id, device_id) for device_id in device_ids
)
else:
- user_devices = await self._user_device_resync_client(
- user_id=user_id
- )
-
- user_devices = user_devices["devices"]
- user_results = results.setdefault(user_id, {})
- for device in user_devices:
- user_results[device["device_id"]] = device["keys"]
- user_ids_updated.append(user_id)
- except Exception as e:
- failures[destination] = _exception_to_failure(e)
-
- if len(destination_query) == len(user_ids_updated):
- # We've updated all the users in the query and we do not need to
- # make any further remote calls.
- return
+ query_list.append((user_id, None))
- # Remove all the users from the query which we have updated
- for user_id in user_ids_updated:
- destination_query.pop(user_id)
+ (
+ user_ids_not_in_cache,
+ remote_results,
+ ) = await self.store.get_user_devices_from_cache(query_list)
+ for user_id, devices in remote_results.items():
+ user_devices = results.setdefault(user_id, {})
+ for device_id, device in devices.items():
+ keys = device.get("keys", None)
+ device_display_name = device.get("device_display_name", None)
+ if keys:
+ result = dict(keys)
+ unsigned = result.setdefault("unsigned", {})
+ if device_display_name:
+ unsigned["device_display_name"] = device_display_name
+ user_devices[device_id] = result
+
+ # check for missing cross-signing keys.
+ for user_id in remote_queries.keys():
+ cached_cross_master = user_id in cross_signing_keys["master_keys"]
+ cached_cross_selfsigning = (
+ user_id in cross_signing_keys["self_signing_keys"]
+ )
- try:
- remote_result = await self.federation.query_client_keys(
- destination, {"device_keys": destination_query}, timeout=timeout
- )
+ # check if we are missing only one of cross-signing master or
+ # self-signing key, but the other one is cached.
+ # as we need both, this will issue a federation request.
+ # if we don't have any of the keys, either the user doesn't have
+ # cross-signing set up, or the cached device list
+ # is not (yet) updated.
+ if cached_cross_master ^ cached_cross_selfsigning:
+ user_ids_not_in_cache.add(user_id)
+
+ # add those users to the list to fetch over federation.
+ for user_id in user_ids_not_in_cache:
+ domain = get_domain_from_id(user_id)
+ r = remote_queries_not_in_cache.setdefault(domain, {})
+ r[user_id] = remote_queries[user_id]
+
+ # Now fetch any devices that we don't have in our cache
+ @trace
+ async def do_remote_query(destination):
+ """This is called when we are querying the device list of a user on
+ a remote homeserver and their device list is not in the device list
+ cache. If we share a room with this user and the query does not ask
+ for specific devices, we will update the cache with their device list.
+ """
+
+ destination_query = remote_queries_not_in_cache[destination]
+
+ # We first consider whether we wish to update the device list cache with
+ # the user's device list. We want to track a user's devices when the
+ # authenticated user shares a room with the queried user and the query
+ # has not specified a particular device.
+ # If we update the cache for the queried user we remove them from further
+ # queries. We use the more efficient batched query_client_keys for all
+ # remaining users.
+ user_ids_updated = []
+ for (user_id, device_list) in destination_query.items():
+ if user_id in user_ids_updated:
+ continue
+
+ if device_list:
+ continue
+
+ room_ids = await self.store.get_rooms_for_user(user_id)
+ if not room_ids:
+ continue
+
+ # We've decided we're sharing a room with this user and should
+ # probably be tracking their device lists. However, we haven't
+ # done an initial sync on the device list so we do it now.
+ try:
+ if self._is_master:
+ user_devices = await self.device_handler.device_list_updater.user_device_resync(
+ user_id
+ )
+ else:
+ user_devices = await self._user_device_resync_client(
+ user_id=user_id
+ )
+
+ user_devices = user_devices["devices"]
+ user_results = results.setdefault(user_id, {})
+ for device in user_devices:
+ user_results[device["device_id"]] = device["keys"]
+ user_ids_updated.append(user_id)
+ except Exception as e:
+ failures[destination] = _exception_to_failure(e)
+
+ if len(destination_query) == len(user_ids_updated):
+ # We've updated all the users in the query and we do not need to
+ # make any further remote calls.
+ return
+
+ # Remove all the users from the query which we have updated
+ for user_id in user_ids_updated:
+ destination_query.pop(user_id)
- for user_id, keys in remote_result["device_keys"].items():
- if user_id in destination_query:
- results[user_id] = keys
+ try:
+ remote_result = await self.federation.query_client_keys(
+ destination, {"device_keys": destination_query}, timeout=timeout
+ )
- if "master_keys" in remote_result:
- for user_id, key in remote_result["master_keys"].items():
+ for user_id, keys in remote_result["device_keys"].items():
if user_id in destination_query:
- cross_signing_keys["master_keys"][user_id] = key
+ results[user_id] = keys
- if "self_signing_keys" in remote_result:
- for user_id, key in remote_result["self_signing_keys"].items():
- if user_id in destination_query:
- cross_signing_keys["self_signing_keys"][user_id] = key
+ if "master_keys" in remote_result:
+ for user_id, key in remote_result["master_keys"].items():
+ if user_id in destination_query:
+ cross_signing_keys["master_keys"][user_id] = key
- except Exception as e:
- failure = _exception_to_failure(e)
- failures[destination] = failure
- set_tag("error", True)
- set_tag("reason", failure)
+ if "self_signing_keys" in remote_result:
+ for user_id, key in remote_result["self_signing_keys"].items():
+ if user_id in destination_query:
+ cross_signing_keys["self_signing_keys"][user_id] = key
- await make_deferred_yieldable(
- defer.gatherResults(
- [
- run_in_background(do_remote_query, destination)
- for destination in remote_queries_not_in_cache
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError)
- )
+ except Exception as e:
+ failure = _exception_to_failure(e)
+ failures[destination] = failure
+ set_tag("error", True)
+ set_tag("reason", failure)
+
+ await make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(do_remote_query, destination)
+ for destination in remote_queries_not_in_cache
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+ )
- ret = {"device_keys": results, "failures": failures}
+ ret = {"device_keys": results, "failures": failures}
- ret.update(cross_signing_keys)
+ ret.update(cross_signing_keys)
- return ret
+ return ret
async def get_cross_signing_keys_from_cache(
self, query: Iterable[str], from_user_id: Optional[str]
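To make the local/remote split performed by `_query_devices` above concrete, here is a minimal, self-contained sketch (a hypothetical helper for illustration; the real code uses `UserID.from_string` and `self.is_mine`, as shown in the hunk):

    from typing import Dict, Iterable, Tuple

    def split_device_query(
        device_keys_query: Dict[str, Iterable[str]], my_domain: str
    ) -> Tuple[Dict[str, Iterable[str]], Dict[str, Iterable[str]]]:
        """Split a device_keys query into local and remote halves by user domain."""
        local_query: Dict[str, Iterable[str]] = {}
        remote_queries: Dict[str, Iterable[str]] = {}
        for user_id, device_ids in device_keys_query.items():
            # User IDs look like "@alice:example.org"; the domain follows the colon.
            domain = user_id.split(":", 1)[1]
            if domain == my_domain:
                local_query[user_id] = device_ids
            else:
                remote_queries[user_id] = device_ids
        return local_query, remote_queries

    # split_device_query({"@a:hs1": [], "@b:hs2": ["DEV1"]}, "hs1")
    # -> ({"@a:hs1": []}, {"@b:hs2": ["DEV1"]})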
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index a0df16a3..989996b6 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -13,7 +13,12 @@
# limitations under the License.
from typing import TYPE_CHECKING, Collection, Optional
-from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.constants import (
+ EventTypes,
+ JoinRules,
+ Membership,
+ RestrictedJoinRuleTypes,
+)
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
@@ -42,7 +47,7 @@ class EventAuthHandler:
Check whether a user can join a room without an invite due to restricted join rules.
When joining a room with restricted join rules (as defined in MSC3083),
- the membership of spaces must be checked during a room join.
+ the membership of rooms must be checked during a room join.
Args:
state_ids: The state of the room as it currently is.
@@ -67,20 +72,20 @@ class EventAuthHandler:
if not await self.has_restricted_join_rules(state_ids, room_version):
return
- # Get the spaces which allow access to this room and check if the user is
+ # Get the rooms which allow access to this room and check if the user is
# in any of them.
- allowed_spaces = await self.get_spaces_that_allow_join(state_ids)
- if not await self.is_user_in_rooms(allowed_spaces, user_id):
+ allowed_rooms = await self.get_rooms_that_allow_join(state_ids)
+ if not await self.is_user_in_rooms(allowed_rooms, user_id):
raise AuthError(
403,
- "You do not belong to any of the required spaces to join this room.",
+ "You do not belong to any of the required rooms to join this room.",
)
async def has_restricted_join_rules(
self, state_ids: StateMap[str], room_version: RoomVersion
) -> bool:
"""
- Return if the room has the proper join rules set for access via spaces.
+ Return if the room has the proper join rules set for access via rooms.
Args:
state_ids: The state of the room as it currently is.
@@ -102,17 +107,17 @@ class EventAuthHandler:
join_rules_event = await self._store.get_event(join_rules_event_id)
return join_rules_event.content.get("join_rule") == JoinRules.MSC3083_RESTRICTED
- async def get_spaces_that_allow_join(
+ async def get_rooms_that_allow_join(
self, state_ids: StateMap[str]
) -> Collection[str]:
"""
- Generate a list of spaces which allow access to a room.
+ Generate a list of rooms in which membership allows access to a room.
Args:
- state_ids: The state of the room as it currently is.
+ state_ids: The current state of the room the user wishes to join.
Returns:
- A collection of spaces which provide membership to the room.
+ A collection of room IDs. Membership in any of the rooms in the list grants the ability to join the target room.
"""
# If there's no join rule, then it defaults to invite (so this doesn't apply).
join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
@@ -123,21 +128,25 @@ class EventAuthHandler:
join_rules_event = await self._store.get_event(join_rules_event_id)
# If allowed is of the wrong form, then only allow invited users.
- allowed_spaces = join_rules_event.content.get("allow", [])
- if not isinstance(allowed_spaces, list):
+ allow_list = join_rules_event.content.get("allow", [])
+ if not isinstance(allow_list, list):
return ()
# Pull out the other room IDs, invalid data gets filtered.
result = []
- for space in allowed_spaces:
- if not isinstance(space, dict):
+ for allow in allow_list:
+ if not isinstance(allow, dict):
+ continue
+
+ # If the type is unexpected, skip it.
+ if allow.get("type") != RestrictedJoinRuleTypes.ROOM_MEMBERSHIP:
continue
- space_id = space.get("space")
- if not isinstance(space_id, str):
+ room_id = allow.get("room_id")
+ if not isinstance(room_id, str):
continue
- result.append(space_id)
+ result.append(room_id)
return result
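For reference, a hedged example of the `m.room.join_rules` content that `get_rooms_that_allow_join` accepts (assuming `RestrictedJoinRuleTypes.ROOM_MEMBERSHIP` resolves to "m.room_membership" per MSC3083; room IDs are illustrative):

    join_rules_content = {
        "join_rule": "restricted",
        "allow": [
            # Kept: expected type and a string room_id.
            {"type": "m.room_membership", "room_id": "!space:example.org"},
            # Skipped: unexpected type.
            {"type": "some.other.type", "room_id": "!other:example.org"},
            # Skipped: not a dict.
            "not-a-dict",
        ],
    }
    # get_rooms_that_allow_join would return ["!space:example.org"].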
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index abbb7142..1b566dbf 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1,6 +1,5 @@
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -34,6 +33,7 @@ from typing import (
)
import attr
+from prometheus_client import Counter
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
@@ -102,6 +102,11 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+soft_failed_event_counter = Counter(
+ "synapse_federation_soft_failed_events_total",
+ "Events received over federation that we marked as soft_failed",
+)
+
@attr.s(slots=True)
class _NewEventInfo:
@@ -1550,6 +1555,77 @@ class FederationHandler(BaseHandler):
run_in_background(self._handle_queued_pdus, room_queue)
+ @log_function
+ async def do_knock(
+ self,
+ target_hosts: List[str],
+ room_id: str,
+ knockee: str,
+ content: JsonDict,
+ ) -> Tuple[str, int]:
+ """Sends the knock to the remote server.
+
+ This first triggers a make_knock request that returns a partial
+ event that we can fill out and sign. This is then sent to the
+ remote server via send_knock.
+
+ Knock events must be signed by the knockee's server before being distributed.
+
+ Args:
+ target_hosts: A list of hosts that we want to try knocking through.
+ room_id: The ID of the room to knock on.
+ knockee: The ID of the user who is knocking.
+ content: The content of the knock event.
+
+ Returns:
+ A tuple of (event ID, stream ID).
+
+ Raises:
+ SynapseError: If the chosen remote server returns a 3xx/4xx code.
+ RuntimeError: If no servers were reachable.
+ """
+ logger.debug("Knocking on room %s on behalf of user %s", room_id, knockee)
+
+ # Inform the remote server of the room versions we support
+ supported_room_versions = list(KNOWN_ROOM_VERSIONS.keys())
+
+ # Ask the remote server to create a valid knock event for us. Once received,
+ # we sign the event
+ params = {"ver": supported_room_versions} # type: Dict[str, Iterable[str]]
+ origin, event, event_format_version = await self._make_and_verify_event(
+ target_hosts, room_id, knockee, Membership.KNOCK, content, params=params
+ )
+
+ # Record the room ID and its version so that we have a record of the room
+ await self._maybe_store_room_on_outlier_membership(
+ room_id=event.room_id, room_version=event_format_version
+ )
+
+ # Initially try the host that we successfully called /make_knock on
+ try:
+ target_hosts.remove(origin)
+ target_hosts.insert(0, origin)
+ except ValueError:
+ pass
+
+ # Send the signed event back to the room, and potentially receive some
+ # further information about the room in the form of partial state events
+ stripped_room_state = await self.federation_client.send_knock(
+ target_hosts, event
+ )
+
+ # Store any stripped room state events in the "unsigned" key of the event.
+ # This is a bit of a hack and is cribbing off of invites. Basically we
+ # store the room state here and retrieve it again when this event appears
+ # in the invitee's sync stream. It is stripped out for all other local users.
+ event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
+
+ context = await self.state_handler.compute_event_context(event)
+ stream_id = await self.persist_events_and_notify(
+ event.room_id, [(event, context)]
+ )
+ return event.event_id, stream_id
+
async def _handle_queued_pdus(
self, room_queue: List[Tuple[EventBase, str]]
) -> None:
@@ -1885,7 +1961,7 @@ class FederationHandler(BaseHandler):
return event
async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None:
- """ We have received a leave event for a room. Fully process it."""
+ """We have received a leave event for a room. Fully process it."""
event = pdu
logger.debug(
@@ -1915,6 +1991,114 @@ class FederationHandler(BaseHandler):
return None
+ @log_function
+ async def on_make_knock_request(
+ self, origin: str, room_id: str, user_id: str
+ ) -> EventBase:
+ """We've received a make_knock request, so we create a partial
+ knock event for the room and return that. We do *not* persist or
+ process it until the other server has signed it and sent it back.
+
+ Args:
+ origin: The (verified) server name of the requesting server.
+ room_id: The room to create the knock event in.
+ user_id: The user to create the knock for.
+
+ Returns:
+ The partial knock event.
+ """
+ if get_domain_from_id(user_id) != origin:
+ logger.info(
+ "Get /make_knock request for user %r from different origin %s, ignoring",
+ user_id,
+ origin,
+ )
+ raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
+
+ room_version = await self.store.get_room_version_id(room_id)
+
+ builder = self.event_builder_factory.new(
+ room_version,
+ {
+ "type": EventTypes.Member,
+ "content": {"membership": Membership.KNOCK},
+ "room_id": room_id,
+ "sender": user_id,
+ "state_key": user_id,
+ },
+ )
+
+ event, context = await self.event_creation_handler.create_new_client_event(
+ builder=builder
+ )
+
+ event_allowed = await self.third_party_event_rules.check_event_allowed(
+ event, context
+ )
+ if not event_allowed:
+ logger.warning("Creation of knock %s forbidden by third-party rules", event)
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
+
+ try:
+ # The remote hasn't signed it yet, obviously. We'll do the full checks
+ # when we get the event back in `on_send_knock_request`
+ await self.auth.check_from_context(
+ room_version, event, context, do_sig_check=False
+ )
+ except AuthError as e:
+ logger.warning("Failed to create new knock %r because %s", event, e)
+ raise e
+
+ return event
+
+ @log_function
+ async def on_send_knock_request(
+ self, origin: str, event: EventBase
+ ) -> EventContext:
+ """
+ We have received a knock event for a room. Verify that event and send it into the room
+ on the knocking homeserver's behalf.
+
+ Args:
+ origin: The remote homeserver of the knocking user.
+ event: The knocking member event that has been signed by the remote homeserver.
+
+ Returns:
+ The context of the event after inserting it into the room graph.
+ """
+ logger.debug(
+ "on_send_knock_request: Got event: %s, signatures: %s",
+ event.event_id,
+ event.signatures,
+ )
+
+ if get_domain_from_id(event.sender) != origin:
+ logger.info(
+ "Got /send_knock request for user %r from different origin %s",
+ event.sender,
+ origin,
+ )
+ raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
+
+ event.internal_metadata.outlier = False
+
+ context = await self.state_handler.compute_event_context(event)
+
+ event_allowed = await self.third_party_event_rules.check_event_allowed(
+ event, context
+ )
+ if not event_allowed:
+ logger.info("Sending of knock %s forbidden by third-party rules", event)
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
+
+ await self._auth_and_persist_event(origin, event, context)
+
+ return context
+
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
"""Returns the state at the event. i.e. not including said event."""
@@ -2239,7 +2423,11 @@ class FederationHandler(BaseHandler):
)
async def _check_for_soft_fail(
- self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool
+ self,
+ event: EventBase,
+ state: Optional[Iterable[EventBase]],
+ backfilled: bool,
+ origin: str,
) -> None:
"""Checks if we should soft fail the event; if so, marks the event as
such.
@@ -2248,6 +2436,7 @@ class FederationHandler(BaseHandler):
event
state: The state at the event if we don't have all the event's prev events
backfilled: Whether the event is from backfill
+ origin: The host the event originates from.
"""
# For new (non-backfilled and non-outlier) events we check if the event
# passes auth based on the current state. If it doesn't then we
@@ -2317,7 +2506,18 @@ class FederationHandler(BaseHandler):
try:
event_auth.check(room_version_obj, event, auth_events=current_auth_events)
except AuthError as e:
- logger.warning("Soft-failing %r because %s", event, e)
+ logger.warning(
+ "Soft-failing %r (from %s) because %s",
+ event,
+ origin,
+ e,
+ extra={
+ "room_id": event.room_id,
+ "mxid": event.sender,
+ "hs": origin,
+ },
+ )
+ soft_failed_event_counter.inc()
event.internal_metadata.soft_failed = True
async def on_get_missing_events(
@@ -2429,7 +2629,7 @@ class FederationHandler(BaseHandler):
context.rejected = RejectedReason.AUTH_ERROR
if not context.rejected:
- await self._check_for_soft_fail(event, state, backfilled)
+ await self._check_for_soft_fail(event, state, backfilled, origin=origin)
if event.type == EventTypes.GuestAccess and not context.rejected:
await self.maybe_kick_guest_users(event)
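As a hedged illustration of the knock handshake added above (`do_knock` on the knocking server, `on_make_knock_request`/`on_send_knock_request` on the remote), the event and stripped state involved might look roughly like this (IDs and values are illustrative, not verbatim wire format):

    # Partial event produced by /make_knock, then filled out and signed by
    # the knockee's server before being sent back via /send_knock:
    knock_event = {
        "type": "m.room.member",
        "room_id": "!room:remote.example",
        "sender": "@alice:local.example",
        "state_key": "@alice:local.example",
        "content": {"membership": "knock", "displayname": "Alice"},
    }

    # Stripped room state returned by /send_knock; do_knock stashes it under
    # event.unsigned["knock_room_state"] so it can appear in the knocker's /sync:
    knock_state_events = [
        {"type": "m.room.name", "state_key": "", "content": {"name": "A room"}},
        {"type": "m.room.join_rules", "state_key": "", "content": {"join_rule": "knock"}},
    ]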
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9f365eb5..db12abd5 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1,6 +1,7 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -398,13 +399,14 @@ class EventCreationHandler:
self._events_shard_config = self.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
- self.room_invite_state_types = self.hs.config.api.room_prejoin_state
+ self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
- self.membership_types_to_include_profile_data_in = (
- {Membership.JOIN, Membership.INVITE}
- if self.hs.config.include_profile_data_on_invite
- else {Membership.JOIN}
- )
+ self.membership_types_to_include_profile_data_in = {
+ Membership.JOIN,
+ Membership.KNOCK,
+ }
+ if self.hs.config.include_profile_data_on_invite:
+ self.membership_types_to_include_profile_data_in.add(Membership.INVITE)
self.send_event = ReplicationSendEventRestServlet.make_client(hs)
@@ -480,6 +482,9 @@ class EventCreationHandler:
prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
require_consent: bool = True,
+ outlier: bool = False,
+ historical: bool = False,
+ depth: Optional[int] = None,
) -> Tuple[EventBase, EventContext]:
"""
Given a dict from a client, create a new event.
@@ -506,6 +511,14 @@ class EventCreationHandler:
require_consent: Whether to check if the requester has
consented to the privacy policy.
+
+ outlier: Indicates whether the event is an `outlier`, i.e. if
+ it's from an arbitrary point and floating in the DAG as
+ opposed to being inline with the current DAG.
+ historical: Indicates whether the message is being inserted back
+ in time around some existing events. Historical events are
+ persisted as backfilled and skip push-notification handling.
+ depth: Override the depth used to order the event in the DAG.
+ Should normally be set to None, which will cause the depth to be calculated
+ based on the prev_events.
+
Raises:
ResourceLimitError if server is blocked to some resource being
exceeded
@@ -561,11 +574,36 @@ class EventCreationHandler:
if txn_id is not None:
builder.internal_metadata.txn_id = txn_id
+ builder.internal_metadata.outlier = outlier
+
+ builder.internal_metadata.historical = historical
+
+ # Strip down the auth_event_ids to only what we need to auth the event.
+ # For example, we don't need extra m.room.member events that don't match event.sender.
+ if auth_event_ids is not None:
+ temp_event = await builder.build(
+ prev_event_ids=prev_event_ids,
+ auth_event_ids=auth_event_ids,
+ depth=depth,
+ )
+ auth_events = await self.store.get_events_as_list(auth_event_ids)
+ # Create a StateMap[str]
+ auth_event_state_map = {
+ (e.type, e.state_key): e.event_id for e in auth_events
+ }
+ # Actually strip down and use the necessary auth events
+ auth_event_ids = self.auth.compute_auth_events(
+ event=temp_event,
+ current_state_ids=auth_event_state_map,
+ for_verification=False,
+ )
+
event, context = await self.create_new_client_event(
builder=builder,
requester=requester,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
+ depth=depth,
)
# In an ideal world we wouldn't need the second part of this condition. However,
@@ -722,9 +760,13 @@ class EventCreationHandler:
self,
requester: Requester,
event_dict: dict,
+ prev_event_ids: Optional[List[str]] = None,
+ auth_event_ids: Optional[List[str]] = None,
ratelimit: bool = True,
txn_id: Optional[str] = None,
ignore_shadow_ban: bool = False,
+ outlier: bool = False,
+ depth: Optional[int] = None,
) -> Tuple[EventBase, int]:
"""
Creates an event, then sends it.
@@ -734,10 +776,24 @@ class EventCreationHandler:
Args:
requester: The requester sending the event.
event_dict: An entire event.
+ prev_event_ids:
+ The event IDs to use as the prev events.
+ Should normally be left as None to automatically request them
+ from the database.
+ auth_event_ids:
+ The event ids to use as the auth_events for the new event.
+ Should normally be left as None, which will cause them to be calculated
+ based on the room state at the prev_events.
ratelimit: Whether to rate limit this send.
txn_id: The transaction ID.
ignore_shadow_ban: True if shadow-banned users should be allowed to
send this event.
+ outlier: Indicates whether the event is an `outlier`, i.e. if
+ it's from an arbitrary point and floating in the DAG as
+ opposed to being inline with the current DAG.
+ depth: Override the depth used to order the event in the DAG.
+ Should normally be set to None, which will cause the depth to be calculated
+ based on the prev_events.
Returns:
The event, and its stream ordering (if deduplication happened,
@@ -777,7 +833,13 @@ class EventCreationHandler:
return event, event.internal_metadata.stream_ordering
event, context = await self.create_event(
- requester, event_dict, txn_id=txn_id
+ requester,
+ event_dict,
+ txn_id=txn_id,
+ prev_event_ids=prev_event_ids,
+ auth_event_ids=auth_event_ids,
+ outlier=outlier,
+ depth=depth,
)
assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
@@ -809,6 +871,7 @@ class EventCreationHandler:
requester: Optional[Requester] = None,
prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
+ depth: Optional[int] = None,
) -> Tuple[EventBase, EventContext]:
"""Create a new event for a local client
@@ -826,6 +889,10 @@ class EventCreationHandler:
Should normally be left as None, which will cause them to be calculated
based on the room state at the prev_events.
+ depth: Override the depth used to order the event in the DAG.
+ Should normally be set to None, which will cause the depth to be calculated
+ based on the prev_events.
+
Returns:
Tuple of created event, context
"""
@@ -849,9 +916,24 @@ class EventCreationHandler:
), "Attempting to create an event with no prev_events"
event = await builder.build(
- prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids
+ prev_event_ids=prev_event_ids,
+ auth_event_ids=auth_event_ids,
+ depth=depth,
)
- context = await self.state.compute_event_context(event)
+
+ old_state = None
+
+ # Pass on the outlier property from the builder to the event
+ # after it is created
+ if builder.internal_metadata.outlier:
+ event.internal_metadata.outlier = builder.internal_metadata.outlier
+
+ # Calculate the state for outliers that pass in their own `auth_event_ids`
+ if auth_event_ids:
+ old_state = await self.store.get_events_as_list(auth_event_ids)
+
+ context = await self.state.compute_event_context(event, old_state=old_state)
+
if requester:
context.app_service = requester.app_service
@@ -961,8 +1043,8 @@ class EventCreationHandler:
room_version = await self.store.get_room_version_id(event.room_id)
if event.internal_metadata.is_out_of_band_membership():
- # the only sort of out-of-band-membership events we expect to see here
- # are invite rejections we have generated ourselves.
+ # the only sort of out-of-band-membership events we expect to see here are
+ # invite rejections and rescinded knocks that we have generated ourselves.
assert event.type == EventTypes.Member
assert event.content["membership"] == Membership.LEAVE
else:
@@ -1016,7 +1098,13 @@ class EventCreationHandler:
the arguments.
"""
- await self.action_generator.handle_push_actions_for_event(event, context)
+ # Skip push notification actions for historical messages
+ # because we don't want to notify people about old history back in time.
+ # The historical messages also do not have the proper `context.current_state_ids`
+ # and `state_groups` because they have `prev_events` that aren't persisted yet
+ # (historical messages are persisted in reverse-chronological order).
+ if not event.internal_metadata.is_historical():
+ await self.action_generator.handle_push_actions_for_event(event, context)
try:
# If we're a worker we need to hit out to the master.
@@ -1239,7 +1327,7 @@ class EventCreationHandler:
"invite_room_state"
] = await self.store.get_stripped_room_state_from_event_context(
context,
- self.room_invite_state_types,
+ self.room_prejoin_state_types,
membership_user_id=event.sender,
)
@@ -1257,6 +1345,14 @@ class EventCreationHandler:
# TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures)
+ if event.content["membership"] == Membership.KNOCK:
+ event.unsigned[
+ "knock_room_state"
+ ] = await self.store.get_stripped_room_state_from_event_context(
+ context,
+ self.room_prejoin_state_types,
+ )
+
if event.type == EventTypes.Redaction:
original_event = await self.store.get_event(
event.redacts,
@@ -1307,13 +1403,21 @@ class EventCreationHandler:
if prev_state_ids:
raise AuthError(403, "Changing the room create event is forbidden")
+ # Mark any `m.historical` messages as backfilled so they don't appear
+ # in `/sync` and have the proper decrementing `stream_ordering` as we import them.
+ backfilled = False
+ if event.internal_metadata.is_historical():
+ backfilled = True
+
# Note that this returns the event that was persisted, which may not be
# the same as we passed in if it was deduplicated due to transaction IDs.
(
event,
event_pos,
max_stream_token,
- ) = await self.storage.persistence.persist_event(event, context=context)
+ ) = await self.storage.persistence.persist_event(
+ event, context=context, backfilled=backfilled
+ )
if self._ephemeral_events_enabled:
# If there's an expiry timestamp on the event, schedule its expiry.
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 4ceef3fa..ca1ed6a5 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -195,7 +195,7 @@ class RegistrationHandler(BaseHandler):
bind_emails: list of emails to bind to this account.
by_admin: True if this registration is being made via the
admin api, otherwise False.
- user_agent_ips: Tuples of IP addresses and user-agents used
+ user_agent_ips: Tuples of user-agents and IP addresses used
during the registration process.
auth_provider_id: The SSO IdP the user used, if any.
Returns:
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 141c9c04..5e3ef7ce 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -44,7 +44,7 @@ class RoomListHandler(BaseHandler):
self.enable_room_list_search = hs.config.enable_room_list_search
self.response_cache = ResponseCache(
hs.get_clock(), "room_list"
- ) # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]]
+ ) # type: ResponseCache[Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]]
self.remote_response_cache = ResponseCache(
hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000
) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
@@ -54,7 +54,7 @@ class RoomListHandler(BaseHandler):
limit: Optional[int] = None,
since_token: Optional[str] = None,
search_filter: Optional[dict] = None,
- network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID,
+ network_tuple: Optional[ThirdPartyInstanceID] = EMPTY_THIRD_PARTY_ID,
from_federation: bool = False,
) -> JsonDict:
"""Generate a local public room list.
@@ -111,7 +111,7 @@ class RoomListHandler(BaseHandler):
limit: Optional[int] = None,
since_token: Optional[str] = None,
search_filter: Optional[dict] = None,
- network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID,
+ network_tuple: Optional[ThirdPartyInstanceID] = EMPTY_THIRD_PARTY_ID,
from_federation: bool = False,
) -> JsonDict:
"""Generate a public room list.
@@ -169,6 +169,7 @@ class RoomListHandler(BaseHandler):
"world_readable": room["history_visibility"]
== HistoryVisibility.WORLD_READABLE,
"guest_can_join": room["guest_access"] == "can_join",
+ "join_rule": room["join_rules"],
}
# Filter out Nones – rather omit the field altogether
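With the `join_rule` field added above, a single entry in the generated room list might look like the following sketch (values illustrative):

    room_entry = {
        "room_id": "!abc:example.org",
        "name": "Some room",
        "world_readable": False,
        "guest_can_join": False,
        # New: exposes the join rule so clients can, e.g., offer a knock option.
        "join_rule": "knock",
    }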
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index d6fc43e7..11925916 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1,4 +1,5 @@
# Copyright 2016-2020 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import abc
import logging
import random
@@ -30,7 +30,15 @@ from synapse.api.errors import (
from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
+from synapse.types import (
+ JsonDict,
+ Requester,
+ RoomAlias,
+ RoomID,
+ StateMap,
+ UserID,
+ get_domain_from_id,
+)
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_left_room
@@ -126,6 +134,24 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
+ async def remote_knock(
+ self,
+ remote_room_hosts: List[str],
+ room_id: str,
+ user: UserID,
+ content: dict,
+ ) -> Tuple[str, int]:
+ """Try and knock on a room that this server is not in
+
+ Args:
+ remote_room_hosts: List of servers through which to try knocking.
+ room_id: Room that we are trying to knock on.
+ user: User who is trying to knock.
+ content: A dict that should be used as the content of the knock event.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
async def remote_reject_invite(
self,
invite_event_id: str,
@@ -149,6 +175,27 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
+ async def remote_rescind_knock(
+ self,
+ knock_event_id: str,
+ txn_id: Optional[str],
+ requester: Requester,
+ content: JsonDict,
+ ) -> Tuple[str, int]:
+ """Rescind a local knock made on a remote room.
+
+ Args:
+ knock_event_id: The ID of the knock event to rescind.
+ txn_id: An optional transaction ID supplied by the client.
+ requester: The user making the request, according to the access token.
+ content: The content of the generated leave event.
+
+ Returns:
+ A tuple containing (event_id, stream_id of the leave event).
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Notifies distributor on master process that the user has left the
room.
@@ -210,11 +257,42 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
room_id: str,
membership: str,
prev_event_ids: List[str],
+ auth_event_ids: Optional[List[str]] = None,
txn_id: Optional[str] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
require_consent: bool = True,
+ outlier: bool = False,
) -> Tuple[str, int]:
+ """
+ Internal membership update function to get an existing event or create
+ and persist a new event for the new membership change.
+
+ Args:
+ requester: The user performing the membership update.
+ target: The user whose membership is being updated.
+ room_id: The room in which the membership is being updated.
+ membership: The new membership state.
+ prev_event_ids: The event IDs to use as the prev events
+
+ auth_event_ids:
+ The event ids to use as the auth_events for the new event.
+ Should normally be left as None, which will cause them to be calculated
+ based on the room state at the prev_events.
+
+ txn_id: An optional transaction ID supplied by the client.
+ ratelimit: Whether to rate limit the request.
+ content: The content of the membership event.
+ require_consent: Whether consent is required.
+
+ outlier: Indicates whether the event is an `outlier`, i.e. if
+ it's from an arbitrary point and floating in the DAG as
+ opposed to being inline with the current DAG.
+
+ Returns:
+ Tuple of event ID and stream ordering position
+ """
+
user_id = target.to_string()
if content is None:
@@ -251,7 +329,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
},
txn_id=txn_id,
prev_event_ids=prev_event_ids,
+ auth_event_ids=auth_event_ids,
require_consent=require_consent,
+ outlier=outlier,
)
prev_state_ids = await context.get_prev_state_ids()
@@ -352,6 +432,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
ratelimit: bool = True,
content: Optional[dict] = None,
require_consent: bool = True,
+ outlier: bool = False,
+ prev_event_ids: Optional[List[str]] = None,
+ auth_event_ids: Optional[List[str]] = None,
) -> Tuple[str, int]:
"""Update a user's membership in a room.
@@ -366,6 +449,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
ratelimit: Whether to rate limit the request.
content: The content of the created event.
require_consent: Whether consent is required.
+ outlier: Indicates whether the event is an `outlier`, i.e. if
+ it's from an arbitrary point and floating in the DAG as
+ opposed to being inline with the current DAG.
+ prev_event_ids: The event IDs to use as the prev events
+ auth_event_ids:
+ The event ids to use as the auth_events for the new event.
+ Should normally be left as None, which will cause them to be calculated
+ based on the room state at the prev_events.
Returns:
A tuple of the new event ID and stream ID.
@@ -392,6 +483,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
ratelimit=ratelimit,
content=content,
require_consent=require_consent,
+ outlier=outlier,
+ prev_event_ids=prev_event_ids,
+ auth_event_ids=auth_event_ids,
)
return result
@@ -408,10 +502,36 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
ratelimit: bool = True,
content: Optional[dict] = None,
require_consent: bool = True,
+ outlier: bool = False,
+ prev_event_ids: Optional[List[str]] = None,
+ auth_event_ids: Optional[List[str]] = None,
) -> Tuple[str, int]:
"""Helper for update_membership.
Assumes that the membership linearizer is already held for the room.
+
+ Args:
+ requester: The user performing the membership change.
+ target: The user whose membership is being changed.
+ room_id: The room in which the membership is being changed.
+ action: The requested membership change ("join", "leave", "knock", ...).
+ txn_id: An optional transaction ID supplied by the client.
+ remote_room_hosts: Homeservers to try if the change must go via federation.
+ third_party_signed: Signed data from a third-party invite, if any.
+ ratelimit: Whether to rate limit the request.
+ content: The content of the membership event.
+ require_consent: Whether consent is required.
+ outlier: Indicates whether the event is an `outlier`, i.e. if
+ it's from an arbitrary point and floating in the DAG as
+ opposed to being inline with the current DAG.
+ prev_event_ids: The event IDs to use as the prev events
+ auth_event_ids:
+ The event ids to use as the auth_events for the new event.
+ Should normally be left as None, which will cause them to be calculated
+ based on the room state at the prev_events.
+
+ Returns:
+ A tuple of the new event ID and stream ID.
"""
content_specified = bool(content)
if content is None:
@@ -496,6 +616,21 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if block_invite:
raise SynapseError(403, "Invites have been disabled on this server")
+ if prev_event_ids:
+ return await self._local_membership_update(
+ requester=requester,
+ target=target,
+ room_id=room_id,
+ membership=effective_membership_state,
+ txn_id=txn_id,
+ ratelimit=ratelimit,
+ prev_event_ids=prev_event_ids,
+ auth_event_ids=auth_event_ids,
+ content=content,
+ require_consent=require_consent,
+ outlier=outlier,
+ )
+
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
current_state_ids = await self.state_handler.get_current_state_ids(
@@ -603,53 +738,79 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
- # perhaps we've been invited
+ # Figure out the user's current membership state for the room
(
current_membership_type,
current_membership_event_id,
) = await self.store.get_local_current_membership_for_user_in_room(
target.to_string(), room_id
)
- if (
- current_membership_type != Membership.INVITE
- or not current_membership_event_id
- ):
+ if not current_membership_type or not current_membership_event_id:
logger.info(
"%s sent a leave request to %s, but that is not an active room "
- "on this server, and there is no pending invite",
+ "on this server, or there is no pending invite or knock",
target,
room_id,
)
raise SynapseError(404, "Not a known room")
- invite = await self.store.get_event(current_membership_event_id)
- logger.info(
- "%s rejects invite to %s from %s", target, room_id, invite.sender
- )
+ # perhaps we've been invited
+ if current_membership_type == Membership.INVITE:
+ invite = await self.store.get_event(current_membership_event_id)
+ logger.info(
+ "%s rejects invite to %s from %s",
+ target,
+ room_id,
+ invite.sender,
+ )
- if not self.hs.is_mine_id(invite.sender):
- # send the rejection to the inviter's HS (with fallback to
- # local event)
- return await self.remote_reject_invite(
- invite.event_id,
- txn_id,
- requester,
- content,
+ if not self.hs.is_mine_id(invite.sender):
+ # send the rejection to the inviter's HS (with fallback to
+ # local event)
+ return await self.remote_reject_invite(
+ invite.event_id,
+ txn_id,
+ requester,
+ content,
+ )
+
+ # the inviter was on our server, but has now left. Carry on
+ # with the normal rejection codepath, which will also send the
+ # rejection out to any other servers we believe are still in the room.
+
+ # thanks to overzealous cleaning up of event_forward_extremities in
+ # `delete_old_current_state_events`, it's possible to end up with no
+ # forward extremities here. If that happens, let's just hang the
+ # rejection off the invite event.
+ #
+ # see: https://github.com/matrix-org/synapse/issues/7139
+ if len(latest_event_ids) == 0:
+ latest_event_ids = [invite.event_id]
+
+ # or perhaps this is a remote room that a local user has knocked on
+ elif current_membership_type == Membership.KNOCK:
+ knock = await self.store.get_event(current_membership_event_id)
+ return await self.remote_rescind_knock(
+ knock.event_id, txn_id, requester, content
)
- # the inviter was on our server, but has now left. Carry on
- # with the normal rejection codepath, which will also send the
- # rejection out to any other servers we believe are still in the room.
+ elif effective_membership_state == Membership.KNOCK:
+ if not is_host_in_room:
+ # The knock needs to be sent over federation instead
+ remote_room_hosts.append(get_domain_from_id(room_id))
+
+ content["membership"] = Membership.KNOCK
+
+ profile = self.profile_handler
+ if "displayname" not in content:
+ content["displayname"] = await profile.get_displayname(target)
+ if "avatar_url" not in content:
+ content["avatar_url"] = await profile.get_avatar_url(target)
- # thanks to overzealous cleaning up of event_forward_extremities in
- # `delete_old_current_state_events`, it's possible to end up with no
- # forward extremities here. If that happens, let's just hang the
- # rejection off the invite event.
- #
- # see: https://github.com/matrix-org/synapse/issues/7139
- if len(latest_event_ids) == 0:
- latest_event_ids = [invite.event_id]
+ return await self.remote_knock(
+ remote_room_hosts, room_id, target, content
+ )
return await self._local_membership_update(
requester=requester,
@@ -659,8 +820,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
txn_id=txn_id,
ratelimit=ratelimit,
prev_event_ids=latest_event_ids,
+ auth_event_ids=auth_event_ids,
content=content,
require_consent=require_consent,
+ outlier=outlier,
)
async def transfer_room_state_on_room_upgrade(
@@ -1209,6 +1372,35 @@ class RoomMemberMasterHandler(RoomMemberHandler):
invite_event, txn_id, requester, content
)
+ async def remote_rescind_knock(
+ self,
+ knock_event_id: str,
+ txn_id: Optional[str],
+ requester: Requester,
+ content: JsonDict,
+ ) -> Tuple[str, int]:
+ """
+ Rescinds a local knock made on a remote room
+
+ Args:
+ knock_event_id: The ID of the knock event to rescind.
+ txn_id: The transaction ID to use.
+ requester: The originator of the request.
+ content: The content of the leave event.
+
+ Implements RoomMemberHandler.remote_rescind_knock
+ """
+ # TODO: We don't yet support rescinding knocks over federation
+ # as we don't know which homeserver to send it to. An obvious
+ # candidate is the remote homeserver we originally knocked through,
+ # however we don't currently store that information.
+
+ # Just rescind the knock locally
+ knock_event = await self.store.get_event(knock_event_id)
+ return await self._generate_local_out_of_band_leave(
+ knock_event, txn_id, requester, content
+ )
+
async def _generate_local_out_of_band_leave(
self,
previous_membership_event: EventBase,
@@ -1272,6 +1464,36 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return result_event.event_id, result_event.internal_metadata.stream_ordering
+ async def remote_knock(
+ self,
+ remote_room_hosts: List[str],
+ room_id: str,
+ user: UserID,
+ content: dict,
+ ) -> Tuple[str, int]:
+ """Sends a knock to a room. Attempts to do so via one remote out of a given list.
+
+ Args:
+ remote_room_hosts: A list of homeservers to try knocking through.
+ room_id: The ID of the room to knock on.
+ user: The user to knock on behalf of.
+ content: The content of the knock event.
+
+ Returns:
+ A tuple of (event ID, stream ID).
+ """
+ # filter ourselves out of remote_room_hosts
+ remote_room_hosts = [
+ host for host in remote_room_hosts if host != self.hs.hostname
+ ]
+
+ if len(remote_room_hosts) == 0:
+ raise SynapseError(404, "No known servers")
+
+ return await self.federation_handler.do_knock(
+ remote_room_hosts, room_id, user.to_string(), content=content
+ )
+
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room"""
user_left_room(self.distributor, target, room_id)
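A hedged usage sketch of the knock path through `update_membership` (parameter names follow the signatures in this diff; the surrounding handler wiring is assumed):

    # A local user knocks on a remote room. update_membership routes this to
    # remote_knock, which calls FederationHandler.do_knock over federation.
    event_id, stream_id = await room_member_handler.update_membership(
        requester=requester,
        target=UserID.from_string("@alice:local.example"),
        room_id="!room:remote.example",
        action=Membership.KNOCK,
        content={"reason": "Let me in, please?"},
    )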
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 3e89dd23..221552a2 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -1,4 +1,4 @@
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,10 +19,12 @@ from synapse.api.errors import SynapseError
from synapse.handlers.room_member import RoomMemberHandler
from synapse.replication.http.membership import (
ReplicationRemoteJoinRestServlet as ReplRemoteJoin,
+ ReplicationRemoteKnockRestServlet as ReplRemoteKnock,
ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite,
+ ReplicationRemoteRescindKnockRestServlet as ReplRescindKnock,
ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft,
)
-from synapse.types import Requester, UserID
+from synapse.types import JsonDict, Requester, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -35,7 +37,9 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
super().__init__(hs)
self._remote_join_client = ReplRemoteJoin.make_client(hs)
+ self._remote_knock_client = ReplRemoteKnock.make_client(hs)
self._remote_reject_client = ReplRejectInvite.make_client(hs)
+ self._remote_rescind_client = ReplRescindKnock.make_client(hs)
self._notify_change_client = ReplJoinedLeft.make_client(hs)
async def _remote_join(
@@ -80,6 +84,53 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
)
return ret["event_id"], ret["stream_id"]
+ async def remote_rescind_knock(
+ self,
+ knock_event_id: str,
+ txn_id: Optional[str],
+ requester: Requester,
+ content: JsonDict,
+ ) -> Tuple[str, int]:
+ """
+ Rescinds a local knock made on a remote room
+
+ Args:
+ knock_event_id: the ID of the knock event to rescind
+ txn_id: optional transaction ID supplied by the client
+ requester: user making the request, according to the access token
+ content: additional content to include in the leave event.
+ Normally an empty dict.
+
+ Returns:
+ A tuple containing (event_id, stream_id of the leave event)
+ """
+ ret = await self._remote_rescind_client(
+ knock_event_id=knock_event_id,
+ txn_id=txn_id,
+ requester=requester,
+ content=content,
+ )
+ return ret["event_id"], ret["stream_id"]
+
+ async def remote_knock(
+ self,
+ remote_room_hosts: List[str],
+ room_id: str,
+ user: UserID,
+ content: dict,
+ ) -> Tuple[str, int]:
+ """Sends a knock to a room.
+
+ Implements RoomMemberHandler.remote_knock
+ """
+ ret = await self._remote_knock_client(
+ remote_room_hosts=remote_room_hosts,
+ room_id=room_id,
+ user=user,
+ content=content,
+ )
+ return ret["event_id"], ret["stream_id"]
+
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room"""
await self._notify_change_client(
diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py
index 046dba6f..17fc47ce 100644
--- a/synapse/handlers/space_summary.py
+++ b/synapse/handlers/space_summary.py
@@ -160,14 +160,14 @@ class SpaceSummaryHandler:
# Check if the user is a member of any of the allowed spaces
# from the response.
- allowed_spaces = room.get("allowed_spaces")
+ allowed_rooms = room.get("allowed_spaces")
if (
not include_room
- and allowed_spaces
- and isinstance(allowed_spaces, list)
+ and allowed_rooms
+ and isinstance(allowed_rooms, list)
):
include_room = await self._event_auth_handler.is_user_in_rooms(
- allowed_spaces, requester
+ allowed_rooms, requester
)
# Finally, if this isn't the requested room, check ourselves
@@ -402,10 +402,7 @@ class SpaceSummaryHandler:
return (), ()
return res.rooms, tuple(
- ev.data
- for ev in res.events
- if ev.event_type == EventTypes.MSC1772_SPACE_CHILD
- or ev.event_type == EventTypes.SpaceChild
+ ev.data for ev in res.events if ev.event_type == EventTypes.SpaceChild
)
async def _is_room_accessible(
@@ -448,21 +445,20 @@ class SpaceSummaryHandler:
member_event_id = state_ids.get((EventTypes.Member, requester), None)
# If they're in the room they can see info on it.
- member_event = None
if member_event_id:
member_event = await self._store.get_event(member_event_id)
if member_event.membership in (Membership.JOIN, Membership.INVITE):
return True
# Otherwise, check if they should be allowed access via membership in a space.
- if self._event_auth_handler.has_restricted_join_rules(
+ if await self._event_auth_handler.has_restricted_join_rules(
state_ids, room_version
):
- allowed_spaces = (
- await self._event_auth_handler.get_spaces_that_allow_join(state_ids)
+ allowed_rooms = (
+ await self._event_auth_handler.get_rooms_that_allow_join(state_ids)
)
if await self._event_auth_handler.is_user_in_rooms(
- allowed_spaces, requester
+ allowed_rooms, requester
):
return True
@@ -478,10 +474,10 @@ class SpaceSummaryHandler:
if await self._event_auth_handler.has_restricted_join_rules(
state_ids, room_version
):
- allowed_spaces = (
- await self._event_auth_handler.get_spaces_that_allow_join(state_ids)
+ allowed_rooms = (
+ await self._event_auth_handler.get_rooms_that_allow_join(state_ids)
)
- for space_id in allowed_spaces:
+ for space_id in allowed_rooms:
if await self._auth.check_host_in_room(space_id, origin):
return True
@@ -514,17 +510,12 @@ class SpaceSummaryHandler:
current_state_ids[(EventTypes.Create, "")]
)
- # TODO: update once MSC1772 lands
- room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
- if not room_type:
- room_type = create_event.content.get(EventContentFields.MSC1772_ROOM_TYPE)
-
room_version = await self._store.get_room_version(room_id)
- allowed_spaces = None
+ allowed_rooms = None
if await self._event_auth_handler.has_restricted_join_rules(
current_state_ids, room_version
):
- allowed_spaces = await self._event_auth_handler.get_spaces_that_allow_join(
+ allowed_rooms = await self._event_auth_handler.get_rooms_that_allow_join(
current_state_ids
)
@@ -540,8 +531,8 @@ class SpaceSummaryHandler:
),
"guest_can_join": stats["guest_access"] == "can_join",
"creation_ts": create_event.origin_server_ts,
- "room_type": room_type,
- "allowed_spaces": allowed_spaces,
+ "room_type": create_event.content.get(EventContentFields.ROOM_TYPE),
+ "allowed_spaces": allowed_rooms,
}
# Filter out Nones – rather omit the field altogether
@@ -569,9 +560,7 @@ class SpaceSummaryHandler:
[
event_id
for key, event_id in current_state_ids.items()
- # TODO: update once MSC1772 has been FCP for a period of time.
- if key[0] == EventTypes.MSC1772_SPACE_CHILD
- or key[0] == EventTypes.SpaceChild
+ if key[0] == EventTypes.SpaceChild
]
)
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 044ff06d..0b297e54 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -41,7 +41,12 @@ from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
from synapse.http.server import respond_with_html, respond_with_redirect
from synapse.http.site import SynapseRequest
-from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
+from synapse.types import (
+ JsonDict,
+ UserID,
+ contains_invalid_mxid_characters,
+ create_requester,
+)
from synapse.util.async_helpers import Linearizer
from synapse.util.stringutils import random_string
@@ -185,11 +190,14 @@ class SsoHandler:
self._auth_handler = hs.get_auth_handler()
self._error_template = hs.config.sso_error_template
self._bad_user_template = hs.config.sso_auth_bad_user_template
+ self._profile_handler = hs.get_profile_handler()
# The following template is shown after a successful user interactive
# authentication session. It tells the user they can close the window.
self._sso_auth_success_template = hs.config.sso_auth_success_template
+ self._sso_update_profile_information = hs.config.sso_update_profile_information
+
# a lock on the mappings
self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
@@ -458,6 +466,21 @@ class SsoHandler:
request.getClientIP(),
)
new_user = True
+ elif self._sso_update_profile_information:
+ attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
+ if attributes.display_name:
+ user_id_obj = UserID.from_string(user_id)
+ profile_display_name = await self._profile_handler.get_displayname(
+ user_id_obj
+ )
+ if profile_display_name != attributes.display_name:
+ requester = create_requester(
+ user_id,
+ authenticated_entity=user_id,
+ )
+ await self._profile_handler.set_displayname(
+ user_id_obj, requester, attributes.display_name, True
+ )
await self._auth_handler.complete_sso_login(
user_id,
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 383e3402..4e45d1da 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -1,4 +1,5 @@
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -230,6 +231,8 @@ class StatsHandler:
room_stats_delta["left_members"] -= 1
elif prev_membership == Membership.BAN:
room_stats_delta["banned_members"] -= 1
+ elif prev_membership == Membership.KNOCK:
+ room_stats_delta["knocked_members"] -= 1
else:
raise ValueError(
"%r is not a valid prev_membership" % (prev_membership,)
@@ -251,6 +254,8 @@ class StatsHandler:
room_stats_delta["left_members"] += 1
elif membership == Membership.BAN:
room_stats_delta["banned_members"] += 1
+ elif membership == Membership.KNOCK:
+ room_stats_delta["knocked_members"] += 1
else:
raise ValueError("%r is not a valid membership" % (membership,))
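To illustrate the bookkeeping above, a knock that is later accepted into a join would produce a delta roughly like this (a sketch; counter names as used in the handler):

    # prev_membership == Membership.KNOCK decrements knocked_members,
    # membership == Membership.JOIN increments joined_members.
    room_stats_delta = {"knocked_members": -1, "joined_members": +1}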
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index b1c58ffd..b9a03610 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -49,7 +49,7 @@ from synapse.types import (
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.lrucache import LruCache
-from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext
from synapse.util.metrics import Measure, measure_func
from synapse.visibility import filter_events_for_client
@@ -83,12 +83,15 @@ LAZY_LOADED_MEMBERS_CACHE_MAX_AGE = 30 * 60 * 1000
LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100
+SyncRequestKey = Tuple[Any, ...]
+
+
@attr.s(slots=True, frozen=True)
class SyncConfig:
user = attr.ib(type=UserID)
filter_collection = attr.ib(type=FilterCollection)
is_guest = attr.ib(type=bool)
- request_key = attr.ib(type=Tuple[Any, ...])
+ request_key = attr.ib(type=SyncRequestKey)
device_id = attr.ib(type=Optional[str])
@@ -160,6 +163,16 @@ class InvitedSyncResult:
@attr.s(slots=True, frozen=True)
+class KnockedSyncResult:
+ room_id = attr.ib(type=str)
+ knock = attr.ib(type=EventBase)
+
+ def __bool__(self) -> bool:
+ """Knocked rooms should always be reported to the client"""
+ return True
+
+
+@attr.s(slots=True, frozen=True)
class GroupsSyncResult:
join = attr.ib(type=JsonDict)
invite = attr.ib(type=JsonDict)
@@ -192,6 +205,7 @@ class _RoomChanges:
room_entries = attr.ib(type=List["RoomSyncResultBuilder"])
invited = attr.ib(type=List[InvitedSyncResult])
+ knocked = attr.ib(type=List[KnockedSyncResult])
newly_joined_rooms = attr.ib(type=List[str])
newly_left_rooms = attr.ib(type=List[str])
@@ -205,6 +219,7 @@ class SyncResult:
account_data: List of account_data events for the user.
joined: JoinedSyncResult for each joined room.
invited: InvitedSyncResult for each invited room.
+        knocked: KnockedSyncResult for each room the user has knocked on.
archived: ArchivedSyncResult for each archived room.
to_device: List of direct messages for the device.
device_lists: List of user_ids whose devices have changed
@@ -220,6 +235,7 @@ class SyncResult:
account_data = attr.ib(type=List[JsonDict])
joined = attr.ib(type=List[JoinedSyncResult])
invited = attr.ib(type=List[InvitedSyncResult])
+ knocked = attr.ib(type=List[KnockedSyncResult])
archived = attr.ib(type=List[ArchivedSyncResult])
to_device = attr.ib(type=List[JsonDict])
device_lists = attr.ib(type=DeviceLists)
@@ -236,6 +252,7 @@ class SyncResult:
self.presence
or self.joined
or self.invited
+ or self.knocked
or self.archived
or self.account_data
or self.to_device
@@ -252,9 +269,9 @@ class SyncHandler:
self.presence_handler = hs.get_presence_handler()
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
- self.response_cache = ResponseCache(
+ self.response_cache: ResponseCache[SyncRequestKey] = ResponseCache(
hs.get_clock(), "sync"
- ) # type: ResponseCache[Tuple[Any, ...]]
+ )
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
self.storage = hs.get_storage()
@@ -293,6 +310,7 @@ class SyncHandler:
since_token,
timeout,
full_state,
+ cache_context=True,
)
logger.debug("Returning sync response for %s", user_id)
return res
@@ -300,9 +318,10 @@ class SyncHandler:
async def _wait_for_sync_for_user(
self,
sync_config: SyncConfig,
- since_token: Optional[StreamToken] = None,
- timeout: int = 0,
- full_state: bool = False,
+ since_token: Optional[StreamToken],
+ timeout: int,
+ full_state: bool,
+ cache_context: ResponseCacheContext[SyncRequestKey],
) -> SyncResult:
if since_token is None:
sync_type = "initial_sync"
@@ -329,13 +348,13 @@ class SyncHandler:
if timeout == 0 or since_token is None or full_state:
# we are going to return immediately, so don't bother calling
# notifier.wait_for_events.
- result = await self.current_sync_for_user(
+ result: SyncResult = await self.current_sync_for_user(
sync_config, since_token, full_state=full_state
)
else:
- def current_sync_callback(before_token, after_token):
- return self.current_sync_for_user(sync_config, since_token)
+ async def current_sync_callback(before_token, after_token) -> SyncResult:
+ return await self.current_sync_for_user(sync_config, since_token)
result = await self.notifier.wait_for_events(
sync_config.user.to_string(),
@@ -344,6 +363,17 @@ class SyncHandler:
from_token=since_token,
)
+ # if nothing has happened in any of the users' rooms since /sync was called,
+ # the resultant next_batch will be the same as since_token (since the result
+ # is generated when wait_for_events is first called, and not regenerated
+ # when wait_for_events times out).
+ #
+ # If that happens, we mustn't cache it, so that when the client comes back
+ # with the same cache token, we don't immediately return the same empty
+ # result, causing a tightloop. (#8518)
+ if result.next_batch == since_token:
+ cache_context.should_cache = False
+
if result:
if sync_config.filter_collection.lazy_load_members():
lazy_loaded = "true"
@@ -1031,7 +1061,7 @@ class SyncHandler:
res = await self._generate_sync_entry_for_rooms(
sync_result_builder, account_data_by_room
)
- newly_joined_rooms, newly_joined_or_invited_users, _, _ = res
+ newly_joined_rooms, newly_joined_or_invited_or_knocked_users, _, _ = res
_, _, newly_left_rooms, newly_left_users = res
block_all_presence_data = (
@@ -1040,7 +1070,9 @@ class SyncHandler:
if self.hs_config.use_presence and not block_all_presence_data:
logger.debug("Fetching presence data")
await self._generate_sync_entry_for_presence(
- sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
+ sync_result_builder,
+ newly_joined_rooms,
+ newly_joined_or_invited_or_knocked_users,
)
logger.debug("Fetching to-device data")
@@ -1049,7 +1081,7 @@ class SyncHandler:
device_lists = await self._generate_sync_entry_for_device_list(
sync_result_builder,
newly_joined_rooms=newly_joined_rooms,
- newly_joined_or_invited_users=newly_joined_or_invited_users,
+ newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users,
newly_left_rooms=newly_left_rooms,
newly_left_users=newly_left_users,
)
@@ -1083,6 +1115,7 @@ class SyncHandler:
account_data=sync_result_builder.account_data,
joined=sync_result_builder.joined,
invited=sync_result_builder.invited,
+ knocked=sync_result_builder.knocked,
archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device,
device_lists=device_lists,
@@ -1142,7 +1175,7 @@ class SyncHandler:
self,
sync_result_builder: "SyncResultBuilder",
newly_joined_rooms: Set[str],
- newly_joined_or_invited_users: Set[str],
+ newly_joined_or_invited_or_knocked_users: Set[str],
newly_left_rooms: Set[str],
newly_left_users: Set[str],
) -> DeviceLists:
@@ -1151,8 +1184,9 @@ class SyncHandler:
Args:
sync_result_builder
newly_joined_rooms: Set of rooms user has joined since previous sync
- newly_joined_or_invited_users: Set of users that have joined or
- been invited to a room since previous sync.
+ newly_joined_or_invited_or_knocked_users: Set of users that have joined,
+ been invited to a room or are knocking on a room since
+ previous sync.
newly_left_rooms: Set of rooms user has left since previous sync
newly_left_users: Set of users that have left a room we're in since
previous sync
@@ -1163,7 +1197,9 @@ class SyncHandler:
# We're going to mutate these fields, so lets copy them rather than
# assume they won't get used later.
- newly_joined_or_invited_users = set(newly_joined_or_invited_users)
+ newly_joined_or_invited_or_knocked_users = set(
+ newly_joined_or_invited_or_knocked_users
+ )
newly_left_users = set(newly_left_users)
if since_token and since_token.device_list_key:
@@ -1202,11 +1238,11 @@ class SyncHandler:
# Step 1b, check for newly joined rooms
for room_id in newly_joined_rooms:
joined_users = await self.store.get_users_in_room(room_id)
- newly_joined_or_invited_users.update(joined_users)
+ newly_joined_or_invited_or_knocked_users.update(joined_users)
# TODO: Check that these users are actually new, i.e. either they
# weren't in the previous sync *or* they left and rejoined.
- users_that_have_changed.update(newly_joined_or_invited_users)
+ users_that_have_changed.update(newly_joined_or_invited_or_knocked_users)
user_signatures_changed = (
await self.store.get_users_whose_signatures_changed(
@@ -1452,6 +1488,7 @@ class SyncHandler:
room_entries = room_changes.room_entries
invited = room_changes.invited
+ knocked = room_changes.knocked
newly_joined_rooms = room_changes.newly_joined_rooms
newly_left_rooms = room_changes.newly_left_rooms
@@ -1472,9 +1509,10 @@ class SyncHandler:
await concurrently_execute(handle_room_entries, room_entries, 10)
sync_result_builder.invited.extend(invited)
+ sync_result_builder.knocked.extend(knocked)
- # Now we want to get any newly joined or invited users
- newly_joined_or_invited_users = set()
+ # Now we want to get any newly joined, invited or knocking users
+ newly_joined_or_invited_or_knocked_users = set()
newly_left_users = set()
if since_token:
for joined_sync in sync_result_builder.joined:
@@ -1486,19 +1524,22 @@ class SyncHandler:
if (
event.membership == Membership.JOIN
or event.membership == Membership.INVITE
+ or event.membership == Membership.KNOCK
):
- newly_joined_or_invited_users.add(event.state_key)
+ newly_joined_or_invited_or_knocked_users.add(
+ event.state_key
+ )
else:
prev_content = event.unsigned.get("prev_content", {})
prev_membership = prev_content.get("membership", None)
if prev_membership == Membership.JOIN:
newly_left_users.add(event.state_key)
- newly_left_users -= newly_joined_or_invited_users
+ newly_left_users -= newly_joined_or_invited_or_knocked_users
return (
set(newly_joined_rooms),
- newly_joined_or_invited_users,
+ newly_joined_or_invited_or_knocked_users,
set(newly_left_rooms),
newly_left_users,
)
@@ -1553,6 +1594,7 @@ class SyncHandler:
newly_left_rooms = []
room_entries = []
invited = []
+ knocked = []
for room_id, events in mem_change_events_by_room_id.items():
logger.debug(
"Membership changes in %s: [%s]",
@@ -1632,9 +1674,17 @@ class SyncHandler:
should_invite = non_joins[-1].membership == Membership.INVITE
if should_invite:
if event.sender not in ignored_users:
- room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
- if room_sync:
- invited.append(room_sync)
+ invite_room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
+ if invite_room_sync:
+ invited.append(invite_room_sync)
+
+ # Only bother if our latest membership in the room is knock (and we haven't
+ # been accepted/rejected in the meantime).
+ should_knock = non_joins[-1].membership == Membership.KNOCK
+ if should_knock:
+ knock_room_sync = KnockedSyncResult(room_id, knock=non_joins[-1])
+ if knock_room_sync:
+ knocked.append(knock_room_sync)
# Always include leave/ban events. Just take the last one.
# TODO: How do we handle ban -> leave in same batch?
@@ -1738,7 +1788,13 @@ class SyncHandler:
)
room_entries.append(entry)
- return _RoomChanges(room_entries, invited, newly_joined_rooms, newly_left_rooms)
+ return _RoomChanges(
+ room_entries,
+ invited,
+ knocked,
+ newly_joined_rooms,
+ newly_left_rooms,
+ )
async def _get_all_rooms(
self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
@@ -1758,6 +1814,7 @@ class SyncHandler:
membership_list = (
Membership.INVITE,
+ Membership.KNOCK,
Membership.JOIN,
Membership.LEAVE,
Membership.BAN,
@@ -1769,6 +1826,7 @@ class SyncHandler:
room_entries = []
invited = []
+ knocked = []
for event in room_list:
if event.membership == Membership.JOIN:
@@ -1788,8 +1846,11 @@ class SyncHandler:
continue
invite = await self.store.get_event(event.event_id)
invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite))
+ elif event.membership == Membership.KNOCK:
+ knock = await self.store.get_event(event.event_id)
+ knocked.append(KnockedSyncResult(room_id=event.room_id, knock=knock))
elif event.membership in (Membership.LEAVE, Membership.BAN):
- # Always send down rooms we were banned or kicked from.
+ # Always send down rooms we were banned from or kicked from.
if not sync_config.filter_collection.include_leave:
if event.membership == Membership.LEAVE:
if user_id == event.sender:
@@ -1810,7 +1871,7 @@ class SyncHandler:
)
)
- return _RoomChanges(room_entries, invited, [], [])
+ return _RoomChanges(room_entries, invited, knocked, [], [])
async def _generate_room_entry(
self,
@@ -2101,6 +2162,7 @@ class SyncResultBuilder:
account_data (list)
joined (list[JoinedSyncResult])
invited (list[InvitedSyncResult])
+ knocked (list[KnockedSyncResult])
archived (list[ArchivedSyncResult])
groups (GroupsSyncResult|None)
to_device (list)
@@ -2116,6 +2178,7 @@ class SyncResultBuilder:
account_data = attr.ib(type=List[JsonDict], default=attr.Factory(list))
joined = attr.ib(type=List[JoinedSyncResult], default=attr.Factory(list))
invited = attr.ib(type=List[InvitedSyncResult], default=attr.Factory(list))
+ knocked = attr.ib(type=List[KnockedSyncResult], default=attr.Factory(list))
archived = attr.ib(type=List[ArchivedSyncResult], default=attr.Factory(list))
groups = attr.ib(type=Optional[GroupsSyncResult], default=None)
to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list))
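The `cache_context=True` argument threaded through `response_cache.wrap` above
hands the callback a `ResponseCacheContext` whose `should_cache` flag it can
clear; that is how the empty-sync tightloop fix opts out of caching. A minimal
sketch of the pattern (not Synapse's actual `ResponseCache`, which caches
deferreds and handles logging contexts) might look like:

    import attr

    @attr.s
    class ResponseCacheContext:
        cache_key = attr.ib()
        should_cache = attr.ib(default=True)

    class ResponseCache:
        def __init__(self):
            self._cache = {}

        async def wrap(self, key, callback, *args, cache_context=False, **kwargs):
            # Return a cached result if we have one for this key.
            if key in self._cache:
                return self._cache[key]
            context = ResponseCacheContext(cache_key=key)
            if cache_context:
                kwargs["cache_context"] = context
            result = await callback(*args, **kwargs)
            # The callback may have cleared should_cache to veto caching.
            if context.should_cache:
                self._cache[key] = result
            return result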
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 1998990a..b8849c01 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -65,13 +65,9 @@ from synapse.http.client import (
read_body_with_max_size,
)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
+from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable
-from synapse.logging.opentracing import (
- inject_active_span_byte_dict,
- set_tag,
- start_active_span,
- tags,
-)
+from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.types import ISynapseReactor, JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
@@ -322,7 +318,9 @@ class MatrixFederationHttpClient:
# We need to use a DNS resolver which filters out blacklisted IP
# addresses, to prevent DNS rebinding.
self.reactor = BlacklistingReactorWrapper(
- hs.get_reactor(), None, hs.config.federation_ip_range_blacklist
+ hs.get_reactor(),
+ hs.config.federation_ip_range_whitelist,
+ hs.config.federation_ip_range_blacklist,
) # type: ISynapseReactor
user_agent = hs.version_string
@@ -497,7 +495,7 @@ class MatrixFederationHttpClient:
# Inject the span into the headers
headers_dict = {} # type: Dict[bytes, List[bytes]]
- inject_active_span_byte_dict(headers_dict, request.destination)
+ opentracing.inject_header_dict(headers_dict, request.destination)
headers_dict[b"User-Agent"] = [self.version_string_bytes]
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index d61563d3..fda8da21 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -13,7 +13,6 @@
# limitations under the License.
""" This module contains base REST classes for constructing REST servlets. """
-
import logging
from typing import Dict, Iterable, List, Optional, overload
@@ -295,6 +294,30 @@ def parse_strings_from_args(
return default
+@overload
+def parse_string_from_args(
+ args: Dict[bytes, List[bytes]],
+ name: str,
+ default: Optional[str] = None,
+ required: Literal[True] = True,
+ allowed_values: Optional[Iterable[str]] = None,
+ encoding: str = "ascii",
+) -> str:
+ ...
+
+
+@overload
+def parse_string_from_args(
+ args: Dict[bytes, List[bytes]],
+ name: str,
+ default: Optional[str] = None,
+ required: bool = False,
+ allowed_values: Optional[Iterable[str]] = None,
+ encoding: str = "ascii",
+) -> Optional[str]:
+ ...
+
+
def parse_string_from_args(
args: Dict[bytes, List[bytes]],
name: str,
@@ -431,7 +454,7 @@ class RestServlet:
"""
def register(self, http_server):
- """ Register this servlet with the given HTTP server. """
+ """Register this servlet with the given HTTP server."""
patterns = getattr(self, "PATTERNS", None)
if patterns:
for method in ("GET", "PUT", "POST", "DELETE"):
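The two overloads added above exist purely for type checking: when
`required=True` is passed as a literal, mypy selects the `Literal[True]`
overload and infers a non-optional `str`. A rough illustration (the query dict
is made up):

    from typing import Dict, List

    args: Dict[bytes, List[bytes]] = {b"server_name": [b"example.org"]}

    host = parse_string_from_args(args, "server_name", required=True)
    # mypy infers `str` here, so no None-check is needed.

    maybe_host = parse_string_from_args(args, "server_name")
    # mypy infers `Optional[str]` here.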
diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py
index 8002a250..6e82f7c7 100644
--- a/synapse/logging/_terse_json.py
+++ b/synapse/logging/_terse_json.py
@@ -20,8 +20,9 @@ import logging
_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))
-# The properties of a standard LogRecord.
-_LOG_RECORD_ATTRIBUTES = {
+# The properties of a standard LogRecord that should be ignored when generating
+# JSON logs.
+_IGNORED_LOG_RECORD_ATTRIBUTES = {
"args",
"asctime",
"created",
@@ -59,9 +60,9 @@ class JsonFormatter(logging.Formatter):
return self._format(record, event)
def _format(self, record: logging.LogRecord, event: dict) -> str:
- # Add any extra attributes to the event.
+ # Add attributes specified via the extra keyword to the logged event.
for key, value in record.__dict__.items():
- if key not in _LOG_RECORD_ATTRIBUTES:
+ if key not in _IGNORED_LOG_RECORD_ATTRIBUTES:
event[key] = value
return _encoder.encode(event)
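With the rename above, the intent is clearer: every LogRecord attribute *not*
in the ignore set — i.e. anything attached via logging's `extra` mechanism —
is copied into the JSON event. For example (field name invented):

    import logging

    logger = logging.getLogger(__name__)
    # With a handler configured to use JsonFormatter, the emitted JSON object
    # includes the extra "request_id" field alongside the standard ones.
    logger.info("handling request", extra={"request_id": "GET-42"})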
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index dd937734..140ed711 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -168,11 +168,12 @@ import inspect
import logging
import re
from functools import wraps
-from typing import TYPE_CHECKING, Dict, Optional, Pattern, Type
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Pattern, Type
import attr
from twisted.internet import defer
+from twisted.web.http_headers import Headers
from synapse.config import ConfigError
from synapse.util import json_decoder, json_encoder
@@ -278,6 +279,10 @@ class SynapseTags:
DB_TXN_ID = "db.txn_id"
+class SynapseBaggage:
+ FORCE_TRACING = "synapse-force-tracing"
+
+
# Block everything by default
# A regex which matches the server_names to expose traces for.
# None means 'block everything'.
@@ -285,6 +290,8 @@ _homeserver_whitelist = None # type: Optional[Pattern[str]]
# Util methods
+Sentinel = object()
+
def only_if_tracing(func):
"""Executes the function only if we're tracing. Otherwise returns None."""
@@ -447,12 +454,28 @@ def start_active_span(
)
-def start_active_span_follows_from(operation_name, contexts):
+def start_active_span_follows_from(
+ operation_name: str, contexts: Collection, inherit_force_tracing=False
+):
+ """Starts an active opentracing span, with additional references to previous spans
+
+ Args:
+ operation_name: name of the operation represented by the new span
+ contexts: the previous spans to inherit from
+ inherit_force_tracing: if set, and any of the previous contexts have had tracing
+ forced, the new span will also have tracing forced.
+ """
if opentracing is None:
return noop_context_manager()
references = [opentracing.follows_from(context) for context in contexts]
scope = start_active_span(operation_name, references=references)
+
+ if inherit_force_tracing and any(
+ is_context_forced_tracing(ctx) for ctx in contexts
+ ):
+ force_tracing(scope.span)
+
return scope
@@ -551,6 +574,10 @@ def start_active_span_from_edu(
# Opentracing setters for tags, logs, etc
+@only_if_tracing
+def active_span():
+ """Get the currently active span, if any"""
+ return opentracing.tracer.active_span
@ensure_active_span("set a tag")
@@ -571,62 +598,52 @@ def set_operation_name(operation_name):
opentracing.tracer.active_span.set_operation_name(operation_name)
-# Injection and extraction
-
+@only_if_tracing
+def force_tracing(span=Sentinel) -> None:
+ """Force sampling for the active/given span and its children.
-@ensure_active_span("inject the span into a header")
-def inject_active_span_twisted_headers(headers, destination, check_destination=True):
+ Args:
+ span: span to force tracing for. By default, the active span.
"""
- Injects a span context into twisted headers in-place
+ if span is Sentinel:
+ span = opentracing.tracer.active_span
+ if span is None:
+ logger.error("No active span in force_tracing")
+ return
- Args:
- headers (twisted.web.http_headers.Headers)
- destination (str): address of entity receiving the span context. If check_destination
- is true the context will only be injected if the destination matches the
- opentracing whitelist
- check_destination (bool): If false, destination will be ignored and the context
- will always be injected.
- span (opentracing.Span)
+ span.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
- Returns:
- In-place modification of headers
+ # also set a bit of baggage, so that we have a way of figuring out if
+ # it is enabled later
+ span.set_baggage_item(SynapseBaggage.FORCE_TRACING, "1")
- Note:
- The headers set by the tracer are custom to the tracer implementation which
- should be unique enough that they don't interfere with any headers set by
- synapse or twisted. If we're still using jaeger these headers would be those
- here:
- https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
- """
- if check_destination and not whitelisted_homeserver(destination):
- return
+def is_context_forced_tracing(span_context) -> bool:
+    """Check if sampling has been forced for the given span context."""
+ if span_context is None:
+ return False
+ return span_context.baggage.get(SynapseBaggage.FORCE_TRACING) is not None
- span = opentracing.tracer.active_span
- carrier = {} # type: Dict[str, str]
- opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier)
- for key, value in carrier.items():
- headers.addRawHeaders(key, value)
+# Injection and extraction
-@ensure_active_span("inject the span into a byte dict")
-def inject_active_span_byte_dict(headers, destination, check_destination=True):
+@ensure_active_span("inject the span into a header dict")
+def inject_header_dict(
+ headers: Dict[bytes, List[bytes]],
+ destination: Optional[str] = None,
+ check_destination: bool = True,
+) -> None:
"""
- Injects a span context into a dict where the headers are encoded as byte
- strings
+ Injects a span context into a dict of HTTP headers
Args:
- headers (dict)
- destination (str): address of entity receiving the span context. If check_destination
- is true the context will only be injected if the destination matches the
- opentracing whitelist
+ headers: the dict to inject headers into
+ destination: address of entity receiving the span context. Must be given unless
+ check_destination is False. The context will only be injected if the
+ destination matches the opentracing whitelist
check_destination (bool): If false, destination will be ignored and the context
will always be injected.
- span (opentracing.Span)
-
- Returns:
- In-place modification of headers
Note:
The headers set by the tracer are custom to the tracer implementation which
@@ -635,8 +652,13 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
here:
https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
"""
- if check_destination and not whitelisted_homeserver(destination):
- return
+ if check_destination:
+ if destination is None:
+ raise ValueError(
+ "destination must be given unless check_destination is False"
+ )
+ if not whitelisted_homeserver(destination):
+ return
span = opentracing.tracer.active_span
@@ -647,36 +669,23 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
headers[key.encode()] = [value.encode()]
-@ensure_active_span("inject the span into a text map")
-def inject_active_span_text_map(carrier, destination, check_destination=True):
- """
- Injects a span context into a dict
-
- Args:
- carrier (dict)
- destination (str): address of entity receiving the span context. If check_destination
- is true the context will only be injected if the destination matches the
- opentracing whitelist
- check_destination (bool): If false, destination will be ignored and the context
- will always be injected.
-
- Returns:
- In-place modification of carrier
-
- Note:
- The headers set by the tracer are custom to the tracer implementation which
- should be unique enough that they don't interfere with any headers set by
- synapse or twisted. If we're still using jaeger these headers would be those
- here:
- https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
- """
-
- if check_destination and not whitelisted_homeserver(destination):
+def inject_response_headers(response_headers: Headers) -> None:
+ """Inject the current trace id into the HTTP response headers"""
+ if not opentracing:
+ return
+ span = opentracing.tracer.active_span
+ if not span:
return
- opentracing.tracer.inject(
- opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
- )
+ # This is a bit implementation-specific.
+ #
+ # Jaeger's Spans have a trace_id property; other implementations (including the
+ # dummy opentracing.span.Span which we use if init_tracer is not called) do not
+ # expose it
+ trace_id = getattr(span, "trace_id", None)
+
+ if trace_id is not None:
+ response_headers.addRawHeader("Synapse-Trace-Id", f"{trace_id:x}")
@ensure_active_span("get the active span context as a dict", ret={})
@@ -854,6 +863,7 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
scope = start_active_span(request_name)
with scope:
+ inject_response_headers(request.responseHeaders)
try:
yield
finally:
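Since `trace_servlet` now calls `inject_response_headers`, a Jaeger-traced
Synapse hands the trace id of each request back to the caller. A sketch of
consuming it (the homeserver URL is a placeholder):

    import requests

    response = requests.get("https://homeserver.example/_matrix/client/versions")
    trace_id = response.headers.get("Synapse-Trace-Id")
    if trace_id:
        # Look this request up in the Jaeger UI by its trace id.
        print(f"trace id for this request: {trace_id}")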
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index cecdc96b..58b255eb 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -16,6 +16,7 @@ import logging
from typing import TYPE_CHECKING, Any, Generator, Iterable, List, Optional, Tuple
from twisted.internet import defer
+from twisted.web.resource import IResource
from synapse.events import EventBase
from synapse.http.client import SimpleHttpClient
@@ -42,7 +43,7 @@ class ModuleApi:
can register new users etc if necessary.
"""
- def __init__(self, hs, auth_handler):
+ def __init__(self, hs: "HomeServer", auth_handler):
self._hs = hs
self._store = hs.get_datastore()
@@ -56,6 +57,33 @@ class ModuleApi:
self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient
self._public_room_list_manager = PublicRoomListManager(hs)
+ self._spam_checker = hs.get_spam_checker()
+
+ #################################################################################
+ # The following methods should only be called during the module's initialisation.
+
+ @property
+ def register_spam_checker_callbacks(self):
+ """Registers callbacks for spam checking capabilities."""
+ return self._spam_checker.register_callbacks
+
+ def register_web_resource(self, path: str, resource: IResource):
+ """Registers a web resource to be served at the given path.
+
+ This function should be called during initialisation of the module.
+
+ If multiple modules register a resource for the same path, the module that
+ appears the highest in the configuration file takes priority.
+
+ Args:
+ path: The path to register the resource for.
+ resource: The resource to attach to this path.
+ """
+ self._hs.register_module_web_resource(path, resource)
+
+ #########################################################################
+ # The following methods can be called by the module at any point in time.
+
@property
def http_client(self):
"""Allows making outbound HTTP requests to remote resources.
diff --git a/synapse/module_api/errors.py b/synapse/module_api/errors.py
index d24864c5..02bbb0be 100644
--- a/synapse/module_api/errors.py
+++ b/synapse/module_api/errors.py
@@ -15,3 +15,4 @@
"""Exception types which are exposed as part of the stable module API"""
from synapse.api.errors import RedirectException, SynapseError # noqa: F401
+from synapse.config._base import ConfigError # noqa: F401
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 546231be..271c17c2 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -75,11 +75,9 @@ REQUIREMENTS = [
"phonenumbers>=8.2.0",
# we use GaugeHistogramMetric, which was added in prom-client 0.4.0.
"prometheus_client>=0.4.0",
- # we use attr.validators.deep_iterable, which arrived in 19.1.0 (Note:
- # Fedora 31 only has 19.1, so if we want to upgrade we should wait until 33
- # is out in November.)
+ # we use `order`, which arrived in attrs 19.2.0.
# Note: 21.1.0 broke `/sync`, see #9936
- "attrs>=19.1.0,!=21.1.0",
+ "attrs>=19.2.0,!=21.1.0",
"netaddr>=0.7.18",
"Jinja2>=2.9",
"bleach>=1.4.3",
@@ -98,11 +96,6 @@ CONDITIONAL_REQUIREMENTS = {
"psycopg2cffi>=2.8 ; platform_python_implementation == 'PyPy'",
"psycopg2cffi-compat==1.1 ; platform_python_implementation == 'PyPy'",
],
- # ACME support is required to provision TLS certificates from authorities
- # that use the protocol, such as Let's Encrypt.
- "acme": [
- "txacme>=0.9.2",
- ],
"saml2": [
"pysaml2>=4.5.0",
],
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 5685cf21..f13a7c23 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -23,7 +23,8 @@ from prometheus_client import Counter, Gauge
from synapse.api.errors import HttpResponseException, SynapseError
from synapse.http import RequestTimedOutError
-from synapse.logging.opentracing import inject_active_span_byte_dict, trace
+from synapse.logging import opentracing
+from synapse.logging.opentracing import trace
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
@@ -235,7 +236,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# Add an authorization header, if configured.
if replication_secret:
headers[b"Authorization"] = [b"Bearer " + replication_secret]
- inject_active_span_byte_dict(headers, None, check_destination=False)
+ opentracing.inject_header_dict(headers, check_destination=False)
try:
result = await request_func(uri, data, headers=headers)
break
@@ -284,7 +285,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
self.__class__.__name__,
)
- def _check_auth_and_handle(self, request, **kwargs):
+ async def _check_auth_and_handle(self, request, **kwargs):
"""Called on new incoming requests when caching is enabled. Checks
if there is a cached response for the request and returns that,
otherwise calls `_handle_request` and caches its response.
@@ -299,8 +300,8 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
if self.CACHE:
txn_id = kwargs.pop("txn_id")
- return self.response_cache.wrap(
+ return await self.response_cache.wrap(
txn_id, self._handle_request, request, **kwargs
)
- return self._handle_request(request, **kwargs)
+ return await self._handle_request(request, **kwargs)
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 289a397d..34206c50 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -97,6 +97,76 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
return 200, {"event_id": event_id, "stream_id": stream_id}
+class ReplicationRemoteKnockRestServlet(ReplicationEndpoint):
+ """Perform a remote knock for the given user on the given room
+
+ Request format:
+
+ POST /_synapse/replication/remote_knock/:room_id/:user_id
+
+ {
+ "requester": ...,
+ "remote_room_hosts": [...],
+ "content": { ... }
+ }
+ """
+
+ NAME = "remote_knock"
+ PATH_ARGS = ("room_id", "user_id")
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+
+ self.federation_handler = hs.get_federation_handler()
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+
+ @staticmethod
+ async def _serialize_payload( # type: ignore
+ requester: Requester,
+ room_id: str,
+ user_id: str,
+ remote_room_hosts: List[str],
+ content: JsonDict,
+ ):
+ """
+ Args:
+ requester: The user making the request, according to the access token.
+ room_id: The ID of the room to knock on.
+ user_id: The ID of the knocking user.
+ remote_room_hosts: Servers to try and send the knock via.
+ content: The event content to use for the knock event.
+ """
+ return {
+ "requester": requester.serialize(),
+ "remote_room_hosts": remote_room_hosts,
+ "content": content,
+ }
+
+ async def _handle_request( # type: ignore
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ user_id: str,
+ ):
+ content = parse_json_object_from_request(request)
+
+ remote_room_hosts = content["remote_room_hosts"]
+ event_content = content["content"]
+
+ requester = Requester.deserialize(self.store, content["requester"])
+
+ request.requester = requester
+
+ logger.debug("remote_knock: %s on room: %s", user_id, room_id)
+
+ event_id, stream_id = await self.federation_handler.do_knock(
+ remote_room_hosts, room_id, user_id, event_content
+ )
+
+ return 200, {"event_id": event_id, "stream_id": stream_id}
+
+
class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
"""Rejects an out-of-band invite we have received from a remote server
@@ -167,6 +237,75 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
return 200, {"event_id": event_id, "stream_id": stream_id}
+class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint):
+ """Rescinds a local knock made on a remote room
+
+ Request format:
+
+ POST /_synapse/replication/remote_rescind_knock/:event_id
+
+ {
+ "txn_id": ...,
+ "requester": ...,
+ "content": { ... }
+ }
+ """
+
+ NAME = "remote_rescind_knock"
+ PATH_ARGS = ("knock_event_id",)
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self.member_handler = hs.get_room_member_handler()
+
+ @staticmethod
+ async def _serialize_payload( # type: ignore
+ knock_event_id: str,
+ txn_id: Optional[str],
+ requester: Requester,
+ content: JsonDict,
+ ):
+ """
+ Args:
+ knock_event_id: The ID of the knock to be rescinded.
+ txn_id: An optional transaction ID supplied by the client.
+ requester: The user making the rescind request, according to the access token.
+ content: The content to include in the rescind event.
+ """
+ return {
+ "txn_id": txn_id,
+ "requester": requester.serialize(),
+ "content": content,
+ }
+
+ async def _handle_request( # type: ignore
+ self,
+ request: SynapseRequest,
+ knock_event_id: str,
+ ):
+ content = parse_json_object_from_request(request)
+
+ txn_id = content["txn_id"]
+ event_content = content["content"]
+
+ requester = Requester.deserialize(self.store, content["requester"])
+
+ request.requester = requester
+
+ # hopefully we're now on the master, so this won't recurse!
+ event_id, stream_id = await self.member_handler.remote_rescind_knock(
+ knock_event_id,
+ txn_id,
+ requester,
+ event_content,
+ )
+
+ return 200, {"event_id": event_id, "stream_id": stream_id}
+
+
class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
"""Notifies that a user has joined or left the room
@@ -206,7 +345,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
return {}
- def _handle_request( # type: ignore
+ async def _handle_request( # type: ignore
self, request: Request, room_id: str, user_id: str, change: str
) -> Tuple[int, JsonDict]:
logger.info("user membership change: %s in %s", user_id, room_id)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 7ced4c54..2ad7a200 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -571,7 +571,7 @@ class ReplicationCommandHandler:
def on_REMOTE_SERVER_UP(
self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
):
- """"Called when get a new REMOTE_SERVER_UP command."""
+        """Called when we get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data)
self._notifier.notify_remote_server_up(cmd.data)
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 79d52d2d..d29f2fea 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -38,6 +38,7 @@ from synapse.rest.client.v2_alpha import (
filter,
groups,
keys,
+ knock,
notifications,
openid,
password_policy,
@@ -120,6 +121,7 @@ class ClientRestResource(JsonResource):
account_validity.register_servlets(hs, client_resource)
relations.register_servlets(hs, client_resource)
password_policy.register_servlets(hs, client_resource)
+ knock.register_servlets(hs, client_resource)
# moving to /_synapse/admin
admin.register_servlets_for_client_rest_resource(hs, client_resource)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 12210585..92ebe838 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -14,13 +14,12 @@
# limitations under the License.
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
-
import logging
import re
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from urllib import parse as urlparse
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.api.errors import (
AuthError,
Codes,
@@ -38,6 +37,7 @@ from synapse.http.servlet import (
parse_integer,
parse_json_object_from_request,
parse_string,
+ parse_strings_from_args,
)
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import set_tag
@@ -266,6 +266,288 @@ class RoomSendEventRestServlet(TransactionRestServlet):
)
+class RoomBatchSendEventRestServlet(TransactionRestServlet):
+ """
+ API endpoint which can insert a chunk of events historically back in time
+ next to the given `prev_event`.
+
+    `chunk_id` comes from `next_chunk_id` in the response of the batch send
+ endpoint and is derived from the "insertion" events added to each chunk.
+ It's not required for the first batch send.
+
+    `state_events_at_start` is used to define the historical state events
+    needed to auth the events, like join events. These events will float
+    outside of the normal DAG as outliers and won't be visible in the chat
+    history, which also allows us to insert multiple chunks without having a
+    bunch of `@mxid joined the room` noise between each chunk.
+
+    `events` is a chronological chunk/list of events you want to insert.
+    There is a reverse-chronological constraint on chunks, so once you insert
+    some messages, you can only insert older ones after that.
+    tl;dr: insert chunks from your most recent history -> oldest history.
+
+ POST /_matrix/client/unstable/org.matrix.msc2716/rooms/<roomID>/batch_send?prev_event=<eventID>&chunk_id=<chunkID>
+ {
+ "events": [ ... ],
+ "state_events_at_start": [ ... ]
+ }
+ """
+
+ PATTERNS = (
+ re.compile(
+ "^/_matrix/client/unstable/org.matrix.msc2716"
+ "/rooms/(?P<room_id>[^/]*)/batch_send$"
+ ),
+ )
+
+ def __init__(self, hs):
+ super().__init__(hs)
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.state_store = hs.get_storage().state
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.room_member_handler = hs.get_room_member_handler()
+ self.auth = hs.get_auth()
+
+ async def inherit_depth_from_prev_ids(self, prev_event_ids) -> int:
+ (
+ most_recent_prev_event_id,
+ most_recent_prev_event_depth,
+ ) = await self.store.get_max_depth_of(prev_event_ids)
+
+ # We want to insert the historical event after the `prev_event` but before the successor event
+ #
+ # We inherit depth from the successor event instead of the `prev_event`
+ # because events returned from `/messages` are first sorted by `topological_ordering`
+ # which is just the `depth` and then tie-break with `stream_ordering`.
+ #
+ # We mark these inserted historical events as "backfilled" which gives them a
+ # negative `stream_ordering`. If we use the same depth as the `prev_event`,
+ # then our historical event will tie-break and be sorted before the `prev_event`
+ # when it should come after.
+ #
+ # We want to use the successor event depth so they appear after `prev_event` because
+ # it has a larger `depth` but before the successor event because the `stream_ordering`
+ # is negative before the successor event.
+ successor_event_ids = await self.store.get_successor_events(
+ [most_recent_prev_event_id]
+ )
+
+        # If we can't find any successor events, then it's a forward extremity of
+        # historical messages and we can just inherit the depth from the previous
+        # historical event, which we can already assume has the correct depth for
+        # where we want to insert.
+ if not successor_event_ids:
+ depth = most_recent_prev_event_depth
+ else:
+ (
+ _,
+ oldest_successor_depth,
+ ) = await self.store.get_min_depth_of(successor_event_ids)
+
+ depth = oldest_successor_depth
+
+ return depth
+
+ async def on_POST(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
+
+ if not requester.app_service:
+ raise AuthError(
+ 403,
+                "Only application services can use the /batch_send endpoint",
+ )
+
+ body = parse_json_object_from_request(request)
+ assert_params_in_dict(body, ["state_events_at_start", "events"])
+
+ prev_events_from_query = parse_strings_from_args(request.args, "prev_event")
+ chunk_id_from_query = parse_string(request, "chunk_id", default=None)
+
+ if prev_events_from_query is None:
+ raise SynapseError(
+ 400,
+ "prev_event query parameter is required when inserting historical messages back in time",
+ errcode=Codes.MISSING_PARAM,
+ )
+
+ # For the event we are inserting next to (`prev_events_from_query`),
+ # find the most recent auth events (derived from state events) that
+ # allowed that message to be sent. We will use that as a base
+ # to auth our historical messages against.
+ (
+ most_recent_prev_event_id,
+ _,
+ ) = await self.store.get_max_depth_of(prev_events_from_query)
+ # mapping from (type, state_key) -> state_event_id
+ prev_state_map = await self.state_store.get_state_ids_for_event(
+ most_recent_prev_event_id
+ )
+        # List of state event IDs
+ prev_state_ids = list(prev_state_map.values())
+ auth_event_ids = prev_state_ids
+
+ for state_event in body["state_events_at_start"]:
+ assert_params_in_dict(
+ state_event, ["type", "origin_server_ts", "content", "sender"]
+ )
+
+ logger.debug(
+ "RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s",
+ state_event,
+ auth_event_ids,
+ )
+
+ event_dict = {
+ "type": state_event["type"],
+ "origin_server_ts": state_event["origin_server_ts"],
+ "content": state_event["content"],
+ "room_id": room_id,
+ "sender": state_event["sender"],
+ "state_key": state_event["state_key"],
+ }
+
+ # Make the state events float off on their own
+ fake_prev_event_id = "$" + random_string(43)
+
+ # TODO: This is pretty much the same as some other code to handle inserting state in this file
+ if event_dict["type"] == EventTypes.Member:
+ membership = event_dict["content"].get("membership", None)
+ event_id, _ = await self.room_member_handler.update_membership(
+ requester,
+ target=UserID.from_string(event_dict["state_key"]),
+ room_id=room_id,
+ action=membership,
+ content=event_dict["content"],
+ outlier=True,
+ prev_event_ids=[fake_prev_event_id],
+                # Make sure to use a copy of this list because we modify it
+                # later in the loop. Otherwise it will be the same reference, and
+                # the event's auth events will also change when we append later.
+ auth_event_ids=auth_event_ids.copy(),
+ )
+ else:
+ # TODO: Add some complement tests that adds state that is not member joins
+ # and will use this code path. Maybe we only want to support join state events
+ # and can get rid of this `else`?
+ (
+ event,
+ _,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
+ requester,
+ event_dict,
+ outlier=True,
+ prev_event_ids=[fake_prev_event_id],
+                    # Make sure to use a copy of this list because we modify it
+                    # later in the loop. Otherwise it will be the same reference,
+                    # and the event's auth events will also change when we append
+                    # later.
+ auth_event_ids=auth_event_ids.copy(),
+ )
+ event_id = event.event_id
+
+ auth_event_ids.append(event_id)
+
+ events_to_create = body["events"]
+
+ # If provided, connect the chunk to the last insertion point
+ # The chunk ID passed in comes from the chunk_id in the
+ # "insertion" event from the previous chunk.
+ if chunk_id_from_query:
+ last_event_in_chunk = events_to_create[-1]
+ last_event_in_chunk["content"][
+ EventContentFields.MSC2716_CHUNK_ID
+ ] = chunk_id_from_query
+
+ # Add an "insertion" event to the start of each chunk (next to the oldest
+ # event in the chunk) so the next chunk can be connected to this one.
+ next_chunk_id = random_string(64)
+ insertion_event = {
+ "type": EventTypes.MSC2716_INSERTION,
+ "sender": requester.user.to_string(),
+ "content": {
+ EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id,
+ EventContentFields.MSC2716_HISTORICAL: True,
+ },
+ # Since the insertion event is put at the start of the chunk,
+ # where the oldest event is, copy the origin_server_ts from
+ # the first event we're inserting
+ "origin_server_ts": events_to_create[0]["origin_server_ts"],
+ }
+ # Prepend the insertion event to the start of the chunk
+ events_to_create = [insertion_event] + events_to_create
+
+ inherited_depth = await self.inherit_depth_from_prev_ids(prev_events_from_query)
+
+ event_ids = []
+ prev_event_ids = prev_events_from_query
+ events_to_persist = []
+ for ev in events_to_create:
+ assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"])
+
+            # Mark all events as historical. This has important semantics within
+            # the Synapse internals so that backfill works properly.
+ ev["content"][EventContentFields.MSC2716_HISTORICAL] = True
+
+ event_dict = {
+ "type": ev["type"],
+ "origin_server_ts": ev["origin_server_ts"],
+ "content": ev["content"],
+ "room_id": room_id,
+ "sender": ev["sender"], # requester.user.to_string(),
+ "prev_events": prev_event_ids.copy(),
+ }
+
+ event, context = await self.event_creation_handler.create_event(
+ requester,
+ event_dict,
+ prev_event_ids=event_dict.get("prev_events"),
+ auth_event_ids=auth_event_ids,
+ historical=True,
+ depth=inherited_depth,
+ )
+ logger.debug(
+ "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s",
+ event,
+ prev_event_ids,
+ auth_event_ids,
+ )
+
+ assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
+ event.sender,
+ )
+
+ events_to_persist.append((event, context))
+ event_id = event.event_id
+
+ event_ids.append(event_id)
+ prev_event_ids = [event_id]
+
+ # Persist events in reverse-chronological order so they have the
+ # correct stream_ordering as they are backfilled (which decrements).
+ # Events are sorted by (topological_ordering, stream_ordering)
+ # where topological_ordering is just depth.
+ for (event, context) in reversed(events_to_persist):
+ ev = await self.event_creation_handler.handle_new_client_event(
+ requester=requester,
+ event=event,
+ context=context,
+ )
+
+ return 200, {
+ "state_events": auth_event_ids,
+ "events": event_ids,
+ "next_chunk_id": next_chunk_id,
+ }
+
+ def on_GET(self, request, room_id):
+ return 501, "Not implemented"
+
+ def on_PUT(self, request, room_id):
+ return self.txns.fetch_or_execute_request(
+ request, self.on_POST, request, room_id
+ )
+
+
# TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(TransactionRestServlet):
def __init__(self, hs):
@@ -278,7 +560,12 @@ class JoinRoomAliasServlet(TransactionRestServlet):
PATTERNS = "/join/(?P<room_identifier>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
- async def on_POST(self, request, room_identifier, txn_id=None):
+ async def on_POST(
+ self,
+ request: SynapseRequest,
+ room_identifier: str,
+ txn_id: Optional[str] = None,
+ ):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
try:
@@ -290,17 +577,18 @@ class JoinRoomAliasServlet(TransactionRestServlet):
if RoomID.is_valid(room_identifier):
room_id = room_identifier
- try:
- remote_room_hosts = [
- x.decode("ascii") for x in request.args[b"server_name"]
- ] # type: Optional[List[str]]
- except Exception:
- remote_room_hosts = None
+
+ # twisted.web.server.Request.args is incorrectly defined as Optional[Any]
+ args: Dict[bytes, List[bytes]] = request.args # type: ignore
+
+ remote_room_hosts = parse_strings_from_args(
+ args, "server_name", required=False
+ )
elif RoomAlias.is_valid(room_identifier):
handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier)
- room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
- room_id = room_id.to_string()
+ room_id_obj, remote_room_hosts = await handler.lookup_room_alias(room_alias)
+ room_id = room_id_obj.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
@@ -1048,6 +1336,8 @@ class RoomSpaceSummaryRestServlet(RestServlet):
def register_servlets(hs: "HomeServer", http_server, is_worker=False):
+ msc2716_enabled = hs.config.experimental.msc2716_enabled
+
RoomStateEventRestServlet(hs).register(http_server)
RoomMemberListRestServlet(hs).register(http_server)
JoinedRoomMemberListRestServlet(hs).register(http_server)
@@ -1055,6 +1345,8 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False):
JoinRoomAliasServlet(hs).register(http_server)
RoomMembershipRestServlet(hs).register(http_server)
RoomSendEventRestServlet(hs).register(http_server)
+ if msc2716_enabled:
+ RoomBatchSendEventRestServlet(hs).register(http_server)
PublicRoomListRestServlet(hs).register(http_server)
RoomStateRestServlet(hs).register(http_server)
RoomRedactEventRestServlet(hs).register(http_server)
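To make the `RoomBatchSendEventRestServlet` docstring above concrete, a request
from an application service might look like the following (identifiers and
event contents are illustrative; URL-encoding of the room ID and `prev_event`
is elided):

    POST /_matrix/client/unstable/org.matrix.msc2716/rooms/!abc:example.org/batch_send?prev_event=$anchor
    {
        "state_events_at_start": [
            {
                "type": "m.room.member",
                "state_key": "@alice:example.org",
                "sender": "@alice:example.org",
                "origin_server_ts": 1500000000000,
                "content": {"membership": "join"}
            }
        ],
        "events": [
            {
                "type": "m.room.message",
                "sender": "@alice:example.org",
                "origin_server_ts": 1500000001000,
                "content": {"msgtype": "m.text", "body": "an old message"}
            }
        ]
    }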
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index 9af05f9b..8b9674db 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -86,6 +86,9 @@ class DeleteDevicesRestServlet(RestServlet):
request,
body,
"remove device(s) from your account",
+            # Users might call this multiple times in a row while cleaning up
+            # devices; allow a single UI auth session to be re-used.
+ can_skip_ui_auth=True,
)
await self.device_handler.delete_devices(
@@ -135,6 +138,9 @@ class DeviceRestServlet(RestServlet):
request,
body,
"remove a device from your account",
+            # Users might call this multiple times in a row while cleaning up
+            # devices; allow a single UI auth session to be re-used.
+ can_skip_ui_auth=True,
)
await self.device_handler.delete_device(requester.user.to_string(), device_id)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index a57ccbb5..33cf8de1 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -160,9 +160,12 @@ class KeyQueryServlet(RestServlet):
async def on_POST(self, request):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
+ device_id = requester.device_id
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
- result = await self.e2e_keys_handler.query_devices(body, timeout, user_id)
+ result = await self.e2e_keys_handler.query_devices(
+ body, timeout, user_id, device_id
+ )
return 200, result
@@ -274,6 +277,9 @@ class SigningKeyUploadServlet(RestServlet):
request,
body,
"add a device signing key to your account",
+ # Allow skipping of UI auth since this is frequently called directly
+ # after login and it is silly to ask users to re-auth immediately.
+ can_skip_ui_auth=True,
)
result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
diff --git a/synapse/rest/client/v2_alpha/knock.py b/synapse/rest/client/v2_alpha/knock.py
new file mode 100644
index 00000000..7d1bc406
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/knock.py
@@ -0,0 +1,107 @@
+# Copyright 2020 Sorunome
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+
+from twisted.web.server import Request
+
+from synapse.api.constants import Membership
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ parse_json_object_from_request,
+ parse_strings_from_args,
+)
+from synapse.http.site import SynapseRequest
+from synapse.logging.opentracing import set_tag
+from synapse.rest.client.transactions import HttpTransactionCache
+from synapse.types import JsonDict, RoomAlias, RoomID
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
+from ._base import client_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class KnockRoomAliasServlet(RestServlet):
+ """
+ POST /knock/{roomIdOrAlias}
+ """
+
+ PATTERNS = client_patterns("/knock/(?P<room_identifier>[^/]*)")
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self.txns = HttpTransactionCache(hs)
+ self.room_member_handler = hs.get_room_member_handler()
+ self.auth = hs.get_auth()
+
+ async def on_POST(
+ self,
+ request: SynapseRequest,
+ room_identifier: str,
+ txn_id: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+
+ content = parse_json_object_from_request(request)
+ event_content = None
+ if "reason" in content:
+ event_content = {"reason": content["reason"]}
+
+ if RoomID.is_valid(room_identifier):
+ room_id = room_identifier
+
+ # twisted.web.server.Request.args is incorrectly defined as Optional[Any]
+ args: Dict[bytes, List[bytes]] = request.args # type: ignore
+
+ remote_room_hosts = parse_strings_from_args(
+ args, "server_name", required=False
+ )
+ elif RoomAlias.is_valid(room_identifier):
+ handler = self.room_member_handler
+ room_alias = RoomAlias.from_string(room_identifier)
+ room_id_obj, remote_room_hosts = await handler.lookup_room_alias(room_alias)
+ room_id = room_id_obj.to_string()
+ else:
+ raise SynapseError(
+                400, "%s was not a legal room ID or room alias" % (room_identifier,)
+ )
+
+ await self.room_member_handler.update_membership(
+ requester=requester,
+ target=requester.user,
+ room_id=room_id,
+ action=Membership.KNOCK,
+ txn_id=txn_id,
+ third_party_signed=None,
+ remote_room_hosts=remote_room_hosts,
+ content=event_content,
+ )
+
+ return 200, {"room_id": room_id}
+
+ def on_PUT(self, request: Request, room_identifier: str, txn_id: str):
+ set_tag("txn_id", txn_id)
+
+ return self.txns.fetch_or_execute_request(
+ request, self.on_POST, request, room_identifier, txn_id
+ )
+
+
+def register_servlets(hs, http_server):
+ KnockRoomAliasServlet(hs).register(http_server)
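In the same request-format style used elsewhere in this changeset, knocking on
a room by alias might look like this (alias and response are illustrative; the
exact URL prefix depends on what `client_patterns` generates):

    POST /_matrix/client/r0/knock/%23cool-room%3Aexample.org?server_name=example.org
    {
        "reason": "let me in, please"
    }

    200 OK
    {
        "room_id": "!abc:example.org"
    }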
diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py
index d3322acc..e8d26738 100644
--- a/synapse/rest/client/v2_alpha/openid.py
+++ b/synapse/rest/client/v2_alpha/openid.py
@@ -85,7 +85,7 @@ class IdTokenServlet(RestServlet):
"access_token": token,
"token_type": "Bearer",
"matrix_server_name": self.server_name,
- "expires_in": self.EXPIRES_MS / 1000,
+ "expires_in": self.EXPIRES_MS // 1000,
},
)
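The switch from `/` to `//` matters because Python 3's true division always
produces a float, so the serialised `expires_in` previously came out as e.g.
`3600.0` rather than the integer `3600`:

    >>> 3600000 / 1000
    3600.0
    >>> 3600000 // 1000
    3600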
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 95ee3f1b..042e1788 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -11,12 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import itertools
import logging
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
-from synapse.api.constants import PresenceState
+from synapse.api.constants import Membership, PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
from synapse.events.utils import (
@@ -24,7 +23,7 @@ from synapse.events.utils import (
format_event_raw,
)
from synapse.handlers.presence import format_user_presence_state
-from synapse.handlers.sync import SyncConfig
+from synapse.handlers.sync import KnockedSyncResult, SyncConfig
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, StreamToken
@@ -220,6 +219,10 @@ class SyncRestServlet(RestServlet):
sync_result.invited, time_now, access_token_id, event_formatter
)
+ knocked = await self.encode_knocked(
+ sync_result.knocked, time_now, access_token_id, event_formatter
+ )
+
archived = await self.encode_archived(
sync_result.archived,
time_now,
@@ -237,11 +240,16 @@ class SyncRestServlet(RestServlet):
"left": list(sync_result.device_lists.left),
},
"presence": SyncRestServlet.encode_presence(sync_result.presence, time_now),
- "rooms": {"join": joined, "invite": invited, "leave": archived},
+ "rooms": {
+ Membership.JOIN: joined,
+ Membership.INVITE: invited,
+ Membership.KNOCK: knocked,
+ Membership.LEAVE: archived,
+ },
"groups": {
- "join": sync_result.groups.join,
- "invite": sync_result.groups.invite,
- "leave": sync_result.groups.leave,
+ Membership.JOIN: sync_result.groups.join,
+ Membership.INVITE: sync_result.groups.invite,
+ Membership.LEAVE: sync_result.groups.leave,
},
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
"org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types,
@@ -303,7 +311,7 @@ class SyncRestServlet(RestServlet):
Args:
rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of
- sync results for rooms this user is joined to
+ sync results for rooms this user is invited to
time_now(int): current time - used as a baseline for age
calculations
token_id(int): ID of the user's auth token - used for namespacing
@@ -322,7 +330,7 @@ class SyncRestServlet(RestServlet):
time_now,
token_id=token_id,
event_format=event_formatter,
- is_invite=True,
+ include_stripped_room_state=True,
)
unsigned = dict(invite.get("unsigned", {}))
invite["unsigned"] = unsigned
@@ -332,6 +340,60 @@ class SyncRestServlet(RestServlet):
return invited
+ async def encode_knocked(
+ self,
+ rooms: List[KnockedSyncResult],
+ time_now: int,
+ token_id: int,
+ event_formatter: Callable[[Dict], Dict],
+ ) -> Dict[str, Dict[str, Any]]:
+ """
+ Encode the rooms we've knocked on in a sync result.
+
+ Args:
+ rooms: list of sync results for rooms this user is knocking on
+ time_now: current time - used as a baseline for age calculations
+ token_id: ID of the user's auth token - used for namespacing of transaction IDs
+ event_formatter: function to convert from federation format to client format
+
+ Returns:
+            The rooms the user has knocked on, keyed by room ID, in our response format.
+ """
+ knocked = {}
+ for room in rooms:
+ knock = await self._event_serializer.serialize_event(
+ room.knock,
+ time_now,
+ token_id=token_id,
+ event_format=event_formatter,
+ include_stripped_room_state=True,
+ )
+
+ # Extract the `unsigned` key from the knock event.
+ # This is where we (cheekily) store the knock state events
+ unsigned = knock.setdefault("unsigned", {})
+
+ # Duplicate the dictionary in order to avoid modifying the original
+ unsigned = dict(unsigned)
+
+ # Extract the stripped room state from the unsigned dict
+ # This is for clients to get a little bit of information about
+ # the room they've knocked on, without revealing any sensitive information
+ knocked_state = list(unsigned.pop("knock_room_state", []))
+
+ # Append the actual knock membership event itself as well. This provides
+ # the client with:
+ #
+ # * A knock state event that they can use for easier internal tracking
+ # * The rough timestamp of when the knock occurred contained within the event
+ knocked_state.append(knock)
+
+ # Build the `knock_state` dictionary, which will contain the state of the
+ # room that the client has knocked on
+ knocked[room.room_id] = {"knock_state": {"events": knocked_state}}
+
+ return knocked
+
async def encode_archived(
self, rooms, time_now, token_id, event_fields, event_formatter
):
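Putting encode_knocked together with the new Membership.KNOCK key added to the "rooms" dict above, a knocking client should expect a sync fragment shaped roughly like the following (a sketch; the events and identifiers are invented):

    sync_fragment = {
        "rooms": {
            "knock": {
                "!abc:example.org": {
                    "knock_state": {
                        "events": [
                            # stripped state popped from unsigned["knock_room_state"]...
                            {
                                "type": "m.room.name",
                                "state_key": "",
                                "content": {"name": "Some room"},
                            },
                            # ...followed by the knock membership event itself
                            {
                                "type": "m.room.member",
                                "state_key": "@alice:example.org",
                                "content": {"membership": "knock"},
                            },
                        ]
                    }
                }
            }
        }
    }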
diff --git a/synapse/server.py b/synapse/server.py
index fec0024c..2c27d2a7 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -1,6 +1,4 @@
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -39,6 +37,7 @@ import twisted.internet.tcp
from twisted.internet import defer
from twisted.mail.smtp import sendmail
from twisted.web.iweb import IPolicyForHTTPS
+from twisted.web.resource import IResource
from synapse.api.auth import Auth
from synapse.api.filtering import Filtering
@@ -66,7 +65,6 @@ from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionR
from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler
from synapse.handlers.account_data import AccountDataHandler
from synapse.handlers.account_validity import AccountValidityHandler
-from synapse.handlers.acme import AcmeHandler
from synapse.handlers.admin import AdminHandler
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler, MacaroonGenerator
@@ -259,6 +257,38 @@ class HomeServer(metaclass=abc.ABCMeta):
self.datastores = None # type: Optional[Databases]
+ self._module_web_resources: Dict[str, IResource] = {}
+ self._module_web_resources_consumed = False
+
+ def register_module_web_resource(self, path: str, resource: IResource):
+ """Allows a module to register a web resource to be served at the given path.
+
+ If multiple modules register a resource for the same path, the module that
+ appears the highest in the configuration file takes priority.
+
+ Args:
+ path: The path to register the resource for.
+ resource: The resource to attach to this path.
+
+ Raises:
+ SynapseError(500): A module tried to register a web resource after the HTTP
+ listeners have been started.
+ """
+ if self._module_web_resources_consumed:
+ raise RuntimeError(
+ "Tried to register a web resource from a module after startup",
+ )
+
+ # Don't register a resource that's already been registered.
+        if path not in self._module_web_resources:
+ self._module_web_resources[path] = resource
+ else:
+ logger.warning(
+ "Module tried to register a web resource for path %s but another module"
+ " has already registered a resource for this path.",
+ path,
+ )
+
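A sketch of how a module might use this hook, assuming it holds a reference to the HomeServer as hs (in practice a module would normally go through the module API rather than calling this method directly):

    from twisted.web.resource import Resource

    class HelloResource(Resource):
        # A trivial leaf resource returning a fixed plain-text body.
        isLeaf = True

        def render_GET(self, request):
            request.setHeader(b"Content-Type", b"text/plain")
            return b"hello from a module\n"

    # Must run before the HTTP listeners start, otherwise the RuntimeError
    # above is raised.
    hs.register_module_web_resource("/_synapse/client/hello", HelloResource())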
def get_instance_id(self) -> str:
"""A unique ID for this synapse process instance.
@@ -495,10 +525,6 @@ class HomeServer(metaclass=abc.ABCMeta):
return E2eRoomKeysHandler(self)
@cache_in_self
- def get_acme_handler(self) -> AcmeHandler:
- return AcmeHandler(self)
-
- @cache_in_self
def get_admin_handler(self) -> AdminHandler:
return AdminHandler(self)
@@ -651,7 +677,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_spam_checker(self) -> SpamChecker:
- return SpamChecker(self)
+ return SpamChecker()
@cache_in_self
def get_third_party_event_rules(self) -> ThirdPartyEventRules:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 9ba5778a..0e3dd4e9 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -62,6 +62,13 @@ class EndToEndKeyBackgroundStore(SQLBaseStore):
class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ super().__init__(database, db_conn, hs)
+
+ self._allow_device_name_lookup_over_federation = (
+ self.hs.config.federation.allow_device_name_lookup_over_federation
+ )
+
async def get_e2e_device_keys_for_federation_query(
self, user_id: str
) -> Tuple[int, List[JsonDict]]:
@@ -85,7 +92,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
result["keys"] = keys
device_display_name = None
- if self.hs.config.allow_device_name_lookup_over_federation:
+ if self._allow_device_name_lookup_over_federation:
device_display_name = device.display_name
if device_display_name:
result["device_display_name"] = device_display_name
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ff81d5cd..c0ea4455 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -16,6 +16,7 @@ import logging
from queue import Empty, PriorityQueue
from typing import Collection, Dict, Iterable, List, Set, Tuple
+from synapse.api.constants import MAX_DEPTH
from synapse.api.errors import StoreError
from synapse.events import EventBase
from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -670,8 +671,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return dict(txn)
- async def get_max_depth_of(self, event_ids: List[str]) -> int:
- """Returns the max depth of a set of event IDs
+ async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
+ """Returns the event ID and depth for the event that has the max depth from a set of event IDs
Args:
event_ids: The event IDs to calculate the max depth of.
@@ -680,14 +681,53 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
table="events",
column="event_id",
iterable=event_ids,
- retcols=("depth",),
+ retcols=(
+ "event_id",
+ "depth",
+ ),
desc="get_max_depth_of",
)
if not rows:
- return 0
+ return None, 0
else:
- return max(row["depth"] for row in rows)
+ max_depth_event_id = ""
+ current_max_depth = 0
+ for row in rows:
+ if row["depth"] > current_max_depth:
+ max_depth_event_id = row["event_id"]
+ current_max_depth = row["depth"]
+
+ return max_depth_event_id, current_max_depth
+
+ async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
+ """Returns the event ID and depth for the event that has the min depth from a set of event IDs
+
+ Args:
+            event_ids: The event IDs to calculate the min depth of.
+ """
+ rows = await self.db_pool.simple_select_many_batch(
+ table="events",
+ column="event_id",
+ iterable=event_ids,
+ retcols=(
+ "event_id",
+ "depth",
+ ),
+ desc="get_min_depth_of",
+ )
+
+ if not rows:
+ return None, 0
+ else:
+ min_depth_event_id = ""
+ current_min_depth = MAX_DEPTH
+ for row in rows:
+ if row["depth"] < current_min_depth:
+ min_depth_event_id = row["event_id"]
+ current_min_depth = row["depth"]
+
+ return min_depth_event_id, current_min_depth
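One subtlety in both loops above: the comparisons are strict (> and <), so when several events share the extreme depth, the first matching row returned by the database wins. A standalone re-implementation for illustration:

    rows = [
        {"event_id": "$a", "depth": 5},
        {"event_id": "$b", "depth": 7},
        {"event_id": "$c", "depth": 7},  # ties with $b; $b wins under strict '>'
    ]

    max_depth_event_id, current_max_depth = "", 0
    for row in rows:
        if row["depth"] > current_max_depth:
            max_depth_event_id = row["event_id"]
            current_max_depth = row["depth"]

    assert (max_depth_event_id, current_max_depth) == ("$b", 7)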
async def get_prev_events_for_room(self, room_id: str) -> List[str]:
"""
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 2a96bcd3..9f0d64a3 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -19,7 +19,7 @@ from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -177,11 +177,13 @@ class RoomWorkerStore(SQLBaseStore):
INNER JOIN room_stats_current USING (room_id)
WHERE
(
- join_rules = 'public' OR history_visibility = 'world_readable'
+ join_rules = 'public' OR join_rules = '%(knock_join_rule)s'
+ OR history_visibility = 'world_readable'
)
AND joined_members > 0
""" % {
- "published_sql": published_sql
+ "published_sql": published_sql,
+ "knock_join_rule": JoinRules.KNOCK,
}
txn.execute(sql, query_args)
@@ -303,7 +305,7 @@ class RoomWorkerStore(SQLBaseStore):
sql = """
SELECT
room_id, name, topic, canonical_alias, joined_members,
- avatar, history_visibility, joined_members, guest_access
+ avatar, history_visibility, guest_access, join_rules
FROM (
%(published_sql)s
) published
@@ -311,7 +313,8 @@ class RoomWorkerStore(SQLBaseStore):
INNER JOIN room_stats_current USING (room_id)
WHERE
(
- join_rules = 'public' OR history_visibility = 'world_readable'
+ join_rules = 'public' OR join_rules = '%(knock_join_rule)s'
+ OR history_visibility = 'world_readable'
)
AND joined_members > 0
%(where_clause)s
@@ -320,6 +323,7 @@ class RoomWorkerStore(SQLBaseStore):
"published_sql": published_sql,
"where_clause": where_clause,
"dir": "DESC" if forwards else "ASC",
+ "knock_join_rule": JoinRules.KNOCK,
}
if limit is not None:
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 5fc3bb5a..2796354a 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -90,7 +90,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
60 * 1000,
)
self.hs.get_clock().call_later(
- 1000,
+ 1,
self._count_known_servers,
)
LaterGauge(
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index ae9f8809..82a18335 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -41,6 +41,7 @@ ABSOLUTE_STATS_FIELDS = {
"current_state_events",
"joined_members",
"invited_members",
+ "knocked_members",
"left_members",
"banned_members",
"local_users_in_room",
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 33dc752d..051095fe 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -16,9 +16,24 @@
import itertools
import logging
-from collections import deque, namedtuple
-from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple
+from collections import deque
+from typing import (
+ Any,
+ Awaitable,
+ Callable,
+ Collection,
+ Deque,
+ Dict,
+ Generic,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ TypeVar,
+)
+import attr
from prometheus_client import Counter, Histogram
from twisted.internet import defer
@@ -26,6 +41,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
+from synapse.logging import opentracing
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases import Databases
@@ -37,7 +53,7 @@ from synapse.types import (
StateMap,
get_domain_from_id,
)
-from synapse.util.async_helpers import ObservableDeferred
+from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -89,25 +105,53 @@ times_pruned_extremities = Counter(
)
-class _EventPeristenceQueue:
+@attr.s(auto_attribs=True, slots=True)
+class _EventPersistQueueItem:
+ events_and_contexts: List[Tuple[EventBase, EventContext]]
+ backfilled: bool
+ deferred: ObservableDeferred
+
+ parent_opentracing_span_contexts: List = attr.ib(factory=list)
+ """A list of opentracing spans waiting for this batch"""
+
+ opentracing_span_context: Any = None
+ """The opentracing span under which the persistence actually happened"""
+
+
+_PersistResult = TypeVar("_PersistResult")
+
+
+class _EventPeristenceQueue(Generic[_PersistResult]):
"""Queues up events so that they can be persisted in bulk with only one
concurrent transaction per room.
"""
- _EventPersistQueueItem = namedtuple(
- "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred")
- )
+ def __init__(
+ self,
+ per_item_callback: Callable[
+ [List[Tuple[EventBase, EventContext]], bool],
+ Awaitable[_PersistResult],
+ ],
+ ):
+ """Create a new event persistence queue
- def __init__(self):
- self._event_persist_queues = {}
- self._currently_persisting_rooms = set()
+ The per_item_callback will be called for each item added via add_to_queue,
+ and its result will be returned via the Deferreds returned from add_to_queue.
+ """
+ self._event_persist_queues: Dict[str, Deque[_EventPersistQueueItem]] = {}
+ self._currently_persisting_rooms: Set[str] = set()
+ self._per_item_callback = per_item_callback
- def add_to_queue(self, room_id, events_and_contexts, backfilled):
+ async def add_to_queue(
+ self,
+ room_id: str,
+ events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
+ backfilled: bool,
+ ) -> _PersistResult:
"""Add events to the queue, with the given persist_event options.
- NB: due to the normal usage pattern of this method, it does *not*
- follow the synapse logcontext rules, and leaves the logcontext in
- place whether or not the returned deferred is ready.
+ If we are not already processing events in this room, starts off a background
+        process to do so, calling the per_item_callback for each item.
Args:
room_id (str):
@@ -115,38 +159,54 @@ class _EventPeristenceQueue:
backfilled (bool):
Returns:
- defer.Deferred: a deferred which will resolve once the events are
- persisted. Runs its callbacks *without* a logcontext. The result
- is the same as that returned by the callback passed to
- `handle_queue`.
+ the result returned by the `_per_item_callback` passed to
+ `__init__`.
"""
queue = self._event_persist_queues.setdefault(room_id, deque())
- if queue:
- # if the last item in the queue has the same `backfilled` setting,
- # we can just add these new events to that item.
- end_item = queue[-1]
- if end_item.backfilled == backfilled:
- end_item.events_and_contexts.extend(events_and_contexts)
- return end_item.deferred.observe()
- deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
+ # if the last item in the queue has the same `backfilled` setting,
+ # we can just add these new events to that item.
+ if queue and queue[-1].backfilled == backfilled:
+ end_item = queue[-1]
+ else:
+ # need to make a new queue item
+ deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
- queue.append(
- self._EventPersistQueueItem(
- events_and_contexts=events_and_contexts,
+ end_item = _EventPersistQueueItem(
+ events_and_contexts=[],
backfilled=backfilled,
deferred=deferred,
)
- )
+ queue.append(end_item)
+
+ # add our events to the queue item
+ end_item.events_and_contexts.extend(events_and_contexts)
+
+ # also add our active opentracing span to the item so that we get a link back
+ span = opentracing.active_span()
+ if span:
+ end_item.parent_opentracing_span_contexts.append(span.context)
+
+ # start a processor for the queue, if there isn't one already
+ self._handle_queue(room_id)
+
+ # wait for the queue item to complete
+ res = await make_deferred_yieldable(end_item.deferred.observe())
- return deferred.observe()
+ # add another opentracing span which links to the persist trace.
+ with opentracing.start_active_span_follows_from(
+ "persist_event_batch_complete", (end_item.opentracing_span_context,)
+ ):
+ pass
+
+ return res
- def handle_queue(self, room_id, per_item_callback):
+ def _handle_queue(self, room_id):
"""Attempts to handle the queue for a room if not already being handled.
- The given callback will be invoked with for each item in the queue,
+        The queue's callback will be invoked for each item in the queue,
of type _EventPersistQueueItem. The per_item_callback will continuously
- be called with new items, unless the queue becomnes empty. The return
+ be called with new items, unless the queue becomes empty. The return
value of the function will be given to the deferreds waiting on the item,
exceptions will be passed to the deferreds as well.
@@ -156,7 +216,6 @@ class _EventPeristenceQueue:
If another callback is currently handling the queue then it will not be
invoked.
"""
-
if room_id in self._currently_persisting_rooms:
return
@@ -167,7 +226,17 @@ class _EventPeristenceQueue:
queue = self._get_drainining_queue(room_id)
for item in queue:
try:
- ret = await per_item_callback(item)
+ with opentracing.start_active_span_follows_from(
+ "persist_event_batch",
+ item.parent_opentracing_span_contexts,
+ inherit_force_tracing=True,
+ ) as scope:
+ if scope:
+ item.opentracing_span_context = scope.span.context
+
+ ret = await self._per_item_callback(
+ item.events_and_contexts, item.backfilled
+ )
except Exception:
with PreserveLoggingContext():
item.deferred.errback()
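The shape of this queue — callers append items to a per-key queue and kick off a drainer only if one is not already running — is independent of Twisted. A minimal asyncio sketch of the same idea (all names here are ours, not Synapse's):

    import asyncio
    from collections import deque

    class PerKeyBatchQueue:
        # Batches work per key; at most one drainer task runs per key.
        def __init__(self, per_batch_callback):
            self._queues = {}       # key -> deque of (items, future) batches
            self._draining = set()  # keys with an active drainer task
            self._per_batch_callback = per_batch_callback

        async def add(self, key, item):
            queue = self._queues.setdefault(key, deque())
            if queue:
                # coalesce into the batch that has not been picked up yet
                batch, fut = queue[-1]
            else:
                batch, fut = [], asyncio.get_running_loop().create_future()
                queue.append((batch, fut))
            batch.append(item)
            if key not in self._draining:
                self._draining.add(key)
                asyncio.create_task(self._drain(key))
            return await fut  # resolves with the batch's result

        async def _drain(self, key):
            queue = self._queues[key]
            try:
                while queue:
                    batch, fut = queue.popleft()
                    try:
                        fut.set_result(await self._per_batch_callback(batch))
                    except Exception as e:
                        fut.set_exception(e)
            finally:
                self._draining.discard(key)

Because asyncio is cooperatively scheduled, the check-then-append in add cannot race with the drainer, mirroring the single-reactor assumption in the Twisted code.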
@@ -214,9 +283,10 @@ class EventsPersistenceStorage:
self._clock = hs.get_clock()
self._instance_name = hs.get_instance_name()
self.is_mine_id = hs.is_mine_id
- self._event_persist_queue = _EventPeristenceQueue()
+ self._event_persist_queue = _EventPeristenceQueue(self._persist_event_batch)
self._state_resolution_handler = hs.get_state_resolution_handler()
+ @opentracing.trace
async def persist_events(
self,
events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
@@ -241,26 +311,21 @@ class EventsPersistenceStorage:
for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx))
- deferreds = []
- for room_id, evs_ctxs in partitioned.items():
- d = self._event_persist_queue.add_to_queue(
+ async def enqueue(item):
+ room_id, evs_ctxs = item
+ return await self._event_persist_queue.add_to_queue(
room_id, evs_ctxs, backfilled=backfilled
)
- deferreds.append(d)
- for room_id in partitioned:
- self._maybe_start_persisting(room_id)
+ ret_vals = await yieldable_gather_results(enqueue, partitioned.items())
- # Each deferred returns a map from event ID to existing event ID if the
- # event was deduplicated. (The dict may also include other entries if
+ # Each call to add_to_queue returns a map from event ID to existing event ID if
+ # the event was deduplicated. (The dict may also include other entries if
# the event was persisted in a batch with other events).
#
- # Since we use `defer.gatherResults` we need to merge the returned list
+ # Since we use `yieldable_gather_results` we need to merge the returned list
# of dicts into one.
- ret_vals = await make_deferred_yieldable(
- defer.gatherResults(deferreds, consumeErrors=True)
- )
- replaced_events = {}
+ replaced_events: Dict[str, str] = {}
for d in ret_vals:
replaced_events.update(d)
@@ -277,6 +342,7 @@ class EventsPersistenceStorage:
self.main_store.get_room_max_token(),
)
+ @opentracing.trace
async def persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]:
@@ -287,16 +353,12 @@ class EventsPersistenceStorage:
event if it was deduplicated due to an existing event matching the
transaction ID.
"""
- deferred = self._event_persist_queue.add_to_queue(
- event.room_id, [(event, context)], backfilled=backfilled
- )
-
- self._maybe_start_persisting(event.room_id)
-
- # The deferred returns a map from event ID to existing event ID if the
+ # add_to_queue returns a map from event ID to existing event ID if the
# event was deduplicated. (The dict may also include other entries if
# the event was persisted in a batch with other events.)
- replaced_events = await make_deferred_yieldable(deferred)
+ replaced_events = await self._event_persist_queue.add_to_queue(
+ event.room_id, [(event, context)], backfilled=backfilled
+ )
replaced_event = replaced_events.get(event.event_id)
if replaced_event:
event = await self.main_store.get_event(replaced_event)
@@ -308,29 +370,14 @@ class EventsPersistenceStorage:
pos = PersistedEventPosition(self._instance_name, event_stream_id)
return event, pos, self.main_store.get_room_max_token()
- def _maybe_start_persisting(self, room_id: str):
- """Pokes the `_event_persist_queue` to start handling new items in the
- queue, if not already in progress.
-
- Causes the deferreds returned by `add_to_queue` to resolve with: a
- dictionary of event ID to event ID we didn't persist as we already had
- another event persisted with the same TXN ID.
- """
-
- async def persisting_queue(item):
- with Measure(self._clock, "persist_events"):
- return await self._persist_events(
- item.events_and_contexts, backfilled=item.backfilled
- )
-
- self._event_persist_queue.handle_queue(room_id, persisting_queue)
-
- async def _persist_events(
+ async def _persist_event_batch(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False,
) -> Dict[str, str]:
- """Calculates the change to current state and forward extremities, and
+ """Callback for the _event_persist_queue
+
+ Calculates the change to current state and forward extremities, and
        persists the given events along with those updates.
Returns:
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 3799d467..683e5e3b 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -1,5 +1,4 @@
-# Copyright 2014 - 2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2014 - 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -26,7 +25,7 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.engines.postgres import PostgresEngine
-from synapse.storage.schema import SCHEMA_VERSION
+from synapse.storage.schema import SCHEMA_COMPAT_VERSION, SCHEMA_VERSION
from synapse.storage.types import Cursor
logger = logging.getLogger(__name__)
@@ -59,6 +58,28 @@ UNAPPLIED_DELTA_ON_WORKER_ERROR = (
)
+@attr.s
+class _SchemaState:
+ current_version: int = attr.ib()
+ """The current schema version of the database"""
+
+ compat_version: Optional[int] = attr.ib()
+ """The SCHEMA_VERSION of the oldest version of Synapse for this database
+
+ If this is None, we have an old version of the database without the necessary
+ table.
+ """
+
+ applied_deltas: Collection[str] = attr.ib(factory=tuple)
+ """Any delta files for `current_version` which have already been applied"""
+
+ upgraded: bool = attr.ib(default=False)
+ """Whether the current state was reached by applying deltas.
+
+ If False, we have run the full schema for `current_version`, and have applied no
+ deltas since. If True, we have run some deltas since the original creation."""
+
+
def prepare_database(
db_conn: LoggingDatabaseConnection,
database_engine: BaseDatabaseEngine,
@@ -96,12 +117,11 @@ def prepare_database(
version_info = _get_or_create_schema_state(cur, database_engine)
if version_info:
- user_version, delta_files, upgraded = version_info
logger.info(
"%r: Existing schema is %i (+%i deltas)",
databases,
- user_version,
- len(delta_files),
+ version_info.current_version,
+ len(version_info.applied_deltas),
)
# config should only be None when we are preparing an in-memory SQLite db,
@@ -113,16 +133,18 @@ def prepare_database(
# if it's a worker app, refuse to upgrade the database, to avoid multiple
# workers doing it at once.
- if config.worker_app is not None and user_version != SCHEMA_VERSION:
+ if (
+ config.worker_app is not None
+ and version_info.current_version != SCHEMA_VERSION
+ ):
raise UpgradeDatabaseException(
- OUTDATED_SCHEMA_ON_WORKER_ERROR % (SCHEMA_VERSION, user_version)
+ OUTDATED_SCHEMA_ON_WORKER_ERROR
+ % (SCHEMA_VERSION, version_info.current_version)
)
_upgrade_existing_database(
cur,
- user_version,
- delta_files,
- upgraded,
+ version_info,
database_engine,
config,
databases=databases,
@@ -261,9 +283,7 @@ def _setup_new_database(
_upgrade_existing_database(
cur,
- current_version=max_current_ver,
- applied_delta_files=[],
- upgraded=False,
+ _SchemaState(current_version=max_current_ver, compat_version=None),
database_engine=database_engine,
config=None,
databases=databases,
@@ -273,9 +293,7 @@ def _setup_new_database(
def _upgrade_existing_database(
cur: Cursor,
- current_version: int,
- applied_delta_files: List[str],
- upgraded: bool,
+ current_schema_state: _SchemaState,
database_engine: BaseDatabaseEngine,
config: Optional[HomeServerConfig],
databases: Collection[str],
@@ -321,12 +339,8 @@ def _upgrade_existing_database(
Args:
cur
- current_version: The current version of the schema.
- applied_delta_files: A list of deltas that have already been applied.
- upgraded: Whether the current version was generated by having
- applied deltas or from full schema file. If `True` the function
- will never apply delta files for the given `current_version`, since
- the current_version wasn't generated by applying those delta files.
+ current_schema_state: The current version of the schema, as
+ returned by _get_or_create_schema_state
database_engine
config:
None if we are initialising a blank database, otherwise the application
@@ -337,13 +351,16 @@ def _upgrade_existing_database(
upgrade portions of the delta scripts.
"""
if is_empty:
- assert not applied_delta_files
+ assert not current_schema_state.applied_deltas
else:
assert config
is_worker = config and config.worker_app is not None
- if current_version > SCHEMA_VERSION:
+ if (
+ current_schema_state.compat_version is not None
+ and current_schema_state.compat_version > SCHEMA_VERSION
+ ):
raise ValueError(
"Cannot use this database as it is too "
+ "new for the server to understand"
@@ -357,14 +374,26 @@ def _upgrade_existing_database(
assert config is not None
check_database_before_upgrade(cur, database_engine, config)
- start_ver = current_version
+ # update schema_compat_version before we run any upgrades, so that if synapse
+ # gets downgraded again, it won't try to run against the upgraded database.
+ if (
+ current_schema_state.compat_version is None
+ or current_schema_state.compat_version < SCHEMA_COMPAT_VERSION
+ ):
+ cur.execute("DELETE FROM schema_compat_version")
+ cur.execute(
+ "INSERT INTO schema_compat_version(compat_version) VALUES (?)",
+ (SCHEMA_COMPAT_VERSION,),
+ )
+
+ start_ver = current_schema_state.current_version
# if we got to this schema version by running a full_schema rather than a series
# of deltas, we should not run the deltas for this version.
- if not upgraded:
+ if not current_schema_state.upgraded:
start_ver += 1
- logger.debug("applied_delta_files: %s", applied_delta_files)
+ logger.debug("applied_delta_files: %s", current_schema_state.applied_deltas)
if isinstance(database_engine, PostgresEngine):
specific_engine_extension = ".postgres"
@@ -440,7 +469,7 @@ def _upgrade_existing_database(
absolute_path = entry.absolute_path
logger.debug("Found file: %s (%s)", relative_path, absolute_path)
- if relative_path in applied_delta_files:
+ if relative_path in current_schema_state.applied_deltas:
continue
root_name, ext = os.path.splitext(file_name)
@@ -621,7 +650,7 @@ def execute_statements_from_stream(cur: Cursor, f: TextIO) -> None:
def _get_or_create_schema_state(
txn: Cursor, database_engine: BaseDatabaseEngine
-) -> Optional[Tuple[int, List[str], bool]]:
+) -> Optional[_SchemaState]:
# Bluntly try creating the schema_version tables.
sql_path = os.path.join(schema_path, "common", "schema_version.sql")
executescript(txn, sql_path)
@@ -629,17 +658,31 @@ def _get_or_create_schema_state(
txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone()
+ if row is None:
+ # new database
+ return None
+
+ current_version = int(row[0])
+ upgraded = bool(row[1])
+
+ compat_version: Optional[int] = None
+ txn.execute("SELECT compat_version FROM schema_compat_version")
+ row = txn.fetchone()
if row is not None:
- current_version = int(row[0])
- txn.execute(
- "SELECT file FROM applied_schema_deltas WHERE version >= ?",
- (current_version,),
- )
- applied_deltas = [d for d, in txn]
- upgraded = bool(row[1])
- return current_version, applied_deltas, upgraded
+ compat_version = int(row[0])
+
+ txn.execute(
+ "SELECT file FROM applied_schema_deltas WHERE version >= ?",
+ (current_version,),
+ )
+ applied_deltas = tuple(d for d, in txn)
- return None
+ return _SchemaState(
+ current_version=current_version,
+ compat_version=compat_version,
+ applied_deltas=applied_deltas,
+ upgraded=upgraded,
+ )
@attr.s(slots=True)
diff --git a/synapse/storage/schema/README.md b/synapse/storage/schema/README.md
index 030153db..729f44ea 100644
--- a/synapse/storage/schema/README.md
+++ b/synapse/storage/schema/README.md
@@ -1,37 +1,4 @@
# Synapse Database Schemas
-This directory contains the schema files used to build Synapse databases.
-
-Synapse supports splitting its datastore across multiple physical databases (which can
-be useful for large installations), and the schema files are therefore split according
-to the logical database they are apply to.
-
-At the time of writing, the following "logical" databases are supported:
-
-* `state` - used to store Matrix room state (more specifically, `state_groups`,
- their relationships and contents.)
-* `main` - stores everything else.
-
-Addionally, the `common` directory contains schema files for tables which must be
-present on *all* physical databases.
-
-## Full schema dumps
-
-In the `full_schemas` directories, only the most recently-numbered snapshot is useful
-(`54` at the time of writing). Older snapshots (eg, `16`) are present for historical
-reference only.
-
-## Building full schema dumps
-
-If you want to recreate these schemas, they need to be made from a database that
-has had all background updates run.
-
-To do so, use `scripts-dev/make_full_schema.sh`. This will produce new
-`full.sql.postgres` and `full.sql.sqlite` files.
-
-Ensure postgres is installed, then run:
-
- ./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/
-
-NB at the time of writing, this script predates the split into separate `state`/`main`
-databases so will require updates to handle that correctly.
+This directory contains the schema files used to build Synapse databases. For more
+information, see /docs/development/database_schema.md.
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index f0d9f231..d36ba1d7 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,6 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# Remember to update this number every time a change is made to database
-# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 59
+"""Represents the expectations made by the codebase about the database schema
+
+This should be incremented whenever the codebase changes its requirements on the
+shape of the database schema (even if those requirements are backwards-compatible with
+older versions of Synapse).
+
+See `README.md <synapse/storage/schema/README.md>`_ for more information on how this
+works.
+"""
+
+
+SCHEMA_COMPAT_VERSION = 59
+"""Limit on how far the synapse codebase can be rolled back without breaking db compat
+
+This value is stored in the database, and checked on startup. If the value in the
+database is greater than SCHEMA_VERSION, then Synapse will refuse to start.
+"""
diff --git a/synapse/storage/schema/common/schema_version.sql b/synapse/storage/schema/common/schema_version.sql
index 42e5cb6d..f41fde5d 100644
--- a/synapse/storage/schema/common/schema_version.sql
+++ b/synapse/storage/schema/common/schema_version.sql
@@ -20,6 +20,13 @@ CREATE TABLE IF NOT EXISTS schema_version(
CHECK (Lock='X')
);
+CREATE TABLE IF NOT EXISTS schema_compat_version(
+ Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
+ -- The SCHEMA_VERSION of the oldest synapse this database can be used with
+ compat_version INTEGER NOT NULL,
+ CHECK (Lock='X')
+);
+
CREATE TABLE IF NOT EXISTS applied_schema_deltas(
version INTEGER NOT NULL,
file TEXT NOT NULL,
diff --git a/synapse/storage/schema/main/delta/59/11add_knock_members_to_stats.sql b/synapse/storage/schema/main/delta/59/11add_knock_members_to_stats.sql
new file mode 100644
index 00000000..8eb2196f
--- /dev/null
+++ b/synapse/storage/schema/main/delta/59/11add_knock_members_to_stats.sql
@@ -0,0 +1,20 @@
+/* Copyright 2020 Sorunome
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Existing rows will default to NULL, so anything reading from these tables
+-- needs to interpret NULL as 0. This is fine here as no existing rooms can have
+-- any knocked members.
+ALTER TABLE room_stats_current ADD COLUMN knocked_members INT;
+ALTER TABLE room_stats_historical ADD COLUMN knocked_members BIGINT;
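On the reading side, that NULL-means-0 convention looks something like this (a sketch; the row handling is ours, only the column name comes from the delta):

    # A pre-existing room: the new column defaults to NULL (None in Python).
    row = {"room_id": "!abc:example.org", "knocked_members": None}

    knocked_members = row["knocked_members"] or 0  # NULL and 0 both read as 0
    assert knocked_members == 0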
diff --git a/synapse/types.py b/synapse/types.py
index e52cd7ff..8d2fa00f 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -284,14 +284,14 @@ class RoomAlias(DomainSpecificString):
@attr.s(slots=True, frozen=True, repr=False)
class RoomID(DomainSpecificString):
- """Structure representing a room id. """
+ """Structure representing a room id."""
SIGIL = "!"
@attr.s(slots=True, frozen=True, repr=False)
class EventID(DomainSpecificString):
- """Structure representing an event id. """
+ """Structure representing an event id."""
SIGIL = "$"
@@ -404,7 +404,7 @@ def map_username_to_mxid_localpart(
return username.decode("ascii")
-@attr.s(frozen=True, slots=True, cmp=False)
+@attr.s(frozen=True, slots=True, order=False)
class RoomStreamToken:
"""Tokens are positions between events. The token "s1" comes after event 1.
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 25ea1bcc..34c662c4 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, Dict, Generic, Optional, TypeVar
+from typing import Any, Awaitable, Callable, Dict, Generic, Optional, TypeVar
+
+import attr
from twisted.internet import defer
@@ -23,10 +25,36 @@ from synapse.util.caches import register_cache
logger = logging.getLogger(__name__)
-T = TypeVar("T")
+# the type of the key in the cache
+KV = TypeVar("KV")
+
+# the type of the result from the operation
+RV = TypeVar("RV")
+
+@attr.s(auto_attribs=True)
+class ResponseCacheContext(Generic[KV]):
+ """Information about a missed ResponseCache hit
-class ResponseCache(Generic[T]):
+ This object can be passed into the callback for additional feedback
+ """
+
+ cache_key: KV
+ """The cache key that caused the cache miss
+
+ This should be considered read-only.
+
+ TODO: in attrs 20.1, make it frozen with an on_setattr.
+ """
+
+ should_cache: bool = True
+ """Whether the result should be cached once the request completes.
+
+ This can be modified by the callback if it decides its result should not be cached.
+ """
+
+
+class ResponseCache(Generic[KV]):
"""
This caches a deferred response. Until the deferred completes it will be
returned from the cache. This means that if the client retries the request
@@ -35,8 +63,10 @@ class ResponseCache(Generic[T]):
"""
def __init__(self, clock: Clock, name: str, timeout_ms: float = 0):
- # Requests that haven't finished yet.
- self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]
+ # This is poorly-named: it includes both complete and incomplete results.
+ # We keep complete results rather than switching to absolute values because
+ # that makes it easier to cache Failure results.
+ self.pending_result_cache = {} # type: Dict[KV, ObservableDeferred]
self.clock = clock
self.timeout_sec = timeout_ms / 1000.0
@@ -50,16 +80,13 @@ class ResponseCache(Generic[T]):
def __len__(self) -> int:
return self.size()
- def get(self, key: T) -> Optional[defer.Deferred]:
+ def get(self, key: KV) -> Optional[defer.Deferred]:
"""Look up the given key.
- Can return either a new Deferred (which also doesn't follow the synapse
- logcontext rules), or, if the request has completed, the actual
- result. You will probably want to make_deferred_yieldable the result.
+ Returns a new Deferred (which also doesn't follow the synapse
+ logcontext rules). You will probably want to make_deferred_yieldable the result.
- If there is no entry for the key, returns None. It is worth noting that
- this means there is no way to distinguish a completed result of None
- from an absent cache entry.
+ If there is no entry for the key, returns None.
Args:
key: key to get/set in the cache
@@ -76,42 +103,56 @@ class ResponseCache(Generic[T]):
self._metrics.inc_misses()
return None
- def set(self, key: T, deferred: defer.Deferred) -> defer.Deferred:
+ def _set(
+ self, context: ResponseCacheContext[KV], deferred: defer.Deferred
+ ) -> defer.Deferred:
"""Set the entry for the given key to the given deferred.
*deferred* should run its callbacks in the sentinel logcontext (ie,
you should wrap normal synapse deferreds with
synapse.logging.context.run_in_background).
- Can return either a new Deferred (which also doesn't follow the synapse
- logcontext rules), or, if *deferred* was already complete, the actual
- result. You will probably want to make_deferred_yieldable the result.
+ Returns a new Deferred (which also doesn't follow the synapse logcontext rules).
+ You will probably want to make_deferred_yieldable the result.
Args:
- key: key to get/set in the cache
+ context: Information about the cache miss
deferred: The deferred which resolves to the result.
Returns:
A new deferred which resolves to the actual result.
"""
result = ObservableDeferred(deferred, consumeErrors=True)
+ key = context.cache_key
self.pending_result_cache[key] = result
- def remove(r):
- if self.timeout_sec:
+ def on_complete(r):
+ # if this cache has a non-zero timeout, and the callback has not cleared
+ # the should_cache bit, we leave it in the cache for now and schedule
+ # its removal later.
+ if self.timeout_sec and context.should_cache:
self.clock.call_later(
self.timeout_sec, self.pending_result_cache.pop, key, None
)
else:
+ # otherwise, remove the result immediately.
self.pending_result_cache.pop(key, None)
return r
- result.addBoth(remove)
+ # make sure we do this *after* adding the entry to pending_result_cache,
+ # in case the result is already complete (in which case flipping the order would
+ # leave us with a stuck entry in the cache).
+ result.addBoth(on_complete)
return result.observe()
- def wrap(
- self, key: T, callback: Callable[..., Any], *args: Any, **kwargs: Any
- ) -> defer.Deferred:
+ async def wrap(
+ self,
+ key: KV,
+ callback: Callable[..., Awaitable[RV]],
+ *args: Any,
+ cache_context: bool = False,
+ **kwargs: Any,
+ ) -> RV:
"""Wrap together a *get* and *set* call, taking care of logcontexts
First looks up the key in the cache, and if it is present makes it
@@ -140,22 +181,28 @@ class ResponseCache(Generic[T]):
*args: positional parameters to pass to the callback, if it is used
+ cache_context: if set, the callback will be given a `cache_context` kw arg,
+ which will be a ResponseCacheContext object.
+
**kwargs: named parameters to pass to the callback, if it is used
Returns:
- Deferred which resolves to the result
+ The result of the callback (from the cache, or otherwise)
"""
result = self.get(key)
if not result:
logger.debug(
"[%s]: no cached result for [%s], calculating new one", self._name, key
)
+ context = ResponseCacheContext(cache_key=key)
+ if cache_context:
+ kwargs["cache_context"] = context
d = run_in_background(callback, *args, **kwargs)
- result = self.set(key, d)
+ result = self._set(context, d)
elif not isinstance(result, defer.Deferred) or result.called:
logger.info("[%s]: using completed cached result for [%s]", self._name, key)
else:
logger.info(
"[%s]: using incomplete cached result for [%s]", self._name, key
)
- return make_deferred_yieldable(result)
+ return await make_deferred_yieldable(result)
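A usage sketch for the new cache_context keyword (the lookup function is invented; wrap and ResponseCacheContext are the pieces introduced above):

    from synapse.util.caches.response_cache import ResponseCache

    async def expensive_lookup(user_id: str):
        ...  # stub: e.g. query a remote server; may return None

    async def do_lookup(user_id: str, cache_context):
        result = await expensive_lookup(user_id)
        if result is None:
            # Don't cache a miss: let the next caller retry immediately.
            cache_context.should_cache = False
        return result

    async def handle_request(cache: ResponseCache, user_id: str):
        # Concurrent calls with the same key share a single do_lookup call.
        return await cache.wrap(user_id, do_lookup, user_id, cache_context=True)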
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 6d14351b..45353d41 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -133,12 +133,17 @@ class Measure:
self.start = self.clock.time()
self._logging_context.__enter__()
in_flight.register((self.name,), self._update_in_flight)
+
+ logger.debug("Entering block %s", self.name)
+
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.start is None:
raise RuntimeError("Measure() block exited without being entered")
+ logger.debug("Exiting block %s", self.name)
+
duration = self.clock.time() - self.start
usage = self.get_resource_usage()
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index cbfbd097..5a638c6e 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -51,21 +51,26 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
# Load the module config. If None, pass an empty dictionary instead
module_config = provider.get("config") or {}
- try:
- provider_config = provider_class.parse_config(module_config)
- except jsonschema.ValidationError as e:
- raise json_error_to_config_error(e, itertools.chain(config_path, ("config",)))
- except ConfigError as e:
- raise _wrap_config_error(
- "Failed to parse config for module %r" % (modulename,),
- prefix=itertools.chain(config_path, ("config",)),
- e=e,
- )
- except Exception as e:
- raise ConfigError(
- "Failed to parse config for module %r" % (modulename,),
- path=itertools.chain(config_path, ("config",)),
- ) from e
+ if hasattr(provider_class, "parse_config"):
+ try:
+ provider_config = provider_class.parse_config(module_config)
+ except jsonschema.ValidationError as e:
+ raise json_error_to_config_error(
+ e, itertools.chain(config_path, ("config",))
+ )
+ except ConfigError as e:
+ raise _wrap_config_error(
+ "Failed to parse config for module %r" % (modulename,),
+ prefix=itertools.chain(config_path, ("config",)),
+ e=e,
+ )
+ except Exception as e:
+ raise ConfigError(
+ "Failed to parse config for module %r" % (modulename,),
+ path=itertools.chain(config_path, ("config",)),
+ ) from e
+ else:
+ provider_config = module_config
return provider_class, provider_config
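The practical effect: both of the following (invented) module classes are now accepted by load_module, where previously the second would have failed when the loader tried to call a parse_config that does not exist:

    class ModuleWithParser:
        @staticmethod
        def parse_config(config: dict) -> dict:
            # validate/normalise here; errors are wrapped as ConfigError
            if "api_key" not in config:
                raise Exception("api_key is required")
            return config

    class ModuleWithoutParser:
        # No parse_config: load_module now passes the raw config dict through.
        pass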