author     Andrej Shadura <andrewsh@debian.org>  2019-07-26 12:56:53 -0300
committer  Andrej Shadura <andrewsh@debian.org>  2019-07-26 12:56:53 -0300
commit     510a9ebe26ed04c0a834f0b3ba2057896fc3aedb (patch)
tree       865b30880f1a611dbed3958da1121b5ccccc989e /synapse
parent     2c2556601d5da4ffb4205200d95e77439dc5f560 (diff)
New upstream version 1.2.1
Diffstat (limited to 'synapse')
-rw-r--r--  synapse/__init__.py | 2
-rw-r--r--  synapse/api/auth.py | 154
-rw-r--r--  synapse/api/errors.py | 55
-rw-r--r--  synapse/app/_base.py | 79
-rw-r--r--  synapse/app/admin_cmd.py | 264
-rw-r--r--  synapse/app/appservice.py | 5
-rw-r--r--  synapse/app/client_reader.py | 5
-rw-r--r--  synapse/app/event_creator.py | 5
-rw-r--r--  synapse/app/federation_reader.py | 5
-rw-r--r--  synapse/app/federation_sender.py | 5
-rw-r--r--  synapse/app/frontend_proxy.py | 5
-rwxr-xr-x  synapse/app/homeserver.py | 5
-rw-r--r--  synapse/app/media_repository.py | 5
-rw-r--r--  synapse/app/pusher.py | 5
-rw-r--r--  synapse/app/synchrotron.py | 5
-rw-r--r--  synapse/app/user_dir.py | 5
-rw-r--r--  synapse/appservice/scheduler.py | 2
-rw-r--r--  synapse/config/_base.py | 77
-rw-r--r--  synapse/config/database.py | 3
-rw-r--r--  synapse/config/emailconfig.py | 19
-rw-r--r--  synapse/config/homeserver.py | 2
-rw-r--r--  synapse/config/logger.py | 7
-rw-r--r--  synapse/config/ratelimiting.py | 4
-rw-r--r--  synapse/config/registration.py | 24
-rw-r--r--  synapse/config/server.py | 5
-rw-r--r--  synapse/config/tracer.py | 59
-rw-r--r--  synapse/crypto/keyring.py | 15
-rw-r--r--  synapse/events/__init__.py | 11
-rw-r--r--  synapse/events/snapshot.py | 2
-rw-r--r--  synapse/events/utils.py | 22
-rw-r--r--  synapse/federation/federation_base.py | 24
-rw-r--r--  synapse/federation/federation_client.py | 8
-rw-r--r--  synapse/federation/federation_server.py | 8
-rw-r--r--  synapse/federation/persistence.py | 34
-rw-r--r--  synapse/federation/sender/__init__.py | 12
-rw-r--r--  synapse/federation/sender/transaction_manager.py | 9
-rw-r--r--  synapse/federation/transport/client.py | 2
-rw-r--r--  synapse/federation/transport/server.py | 458
-rw-r--r--  synapse/groups/attestations.py | 2
-rw-r--r--  synapse/handlers/account_validity.py | 2
-rw-r--r--  synapse/handlers/admin.py | 183
-rw-r--r--  synapse/handlers/appservice.py | 2
-rw-r--r--  synapse/handlers/auth.py | 40
-rw-r--r--  synapse/handlers/e2e_keys.py | 7
-rw-r--r--  synapse/handlers/events.py | 2
-rw-r--r--  synapse/handlers/federation.py | 80
-rw-r--r--  synapse/handlers/identity.py | 2
-rw-r--r--  synapse/handlers/initial_sync.py | 2
-rw-r--r--  synapse/handlers/message.py | 35
-rw-r--r--  synapse/handlers/pagination.py | 2
-rw-r--r--  synapse/handlers/presence.py | 4
-rw-r--r--  synapse/handlers/profile.py | 4
-rw-r--r--  synapse/handlers/receipts.py | 35
-rw-r--r--  synapse/handlers/register.py | 152
-rw-r--r--  synapse/handlers/room_member.py | 55
-rw-r--r--  synapse/handlers/room_member_worker.py | 12
-rw-r--r--  synapse/handlers/sync.py | 2
-rw-r--r--  synapse/handlers/typing.py | 2
-rw-r--r--  synapse/http/client.py | 2
-rw-r--r--  synapse/http/federation/matrix_federation_agent.py | 2
-rw-r--r--  synapse/http/federation/srv_resolver.py | 2
-rw-r--r--  synapse/http/matrixfederationclient.py | 30
-rw-r--r--  synapse/http/request_metrics.py | 2
-rw-r--r--  synapse/http/server.py | 43
-rw-r--r--  synapse/http/servlet.py | 9
-rw-r--r--  synapse/http/site.py | 2
-rw-r--r--  synapse/logging/__init__.py | 0
-rw-r--r--  synapse/logging/context.py | 697
-rw-r--r--  synapse/logging/formatter.py | 53
-rw-r--r--  synapse/logging/opentracing.py | 483
-rw-r--r--  synapse/logging/scopecontextmanager.py | 138
-rw-r--r--  synapse/logging/utils.py (renamed from synapse/util/logutils.py) | 0
-rw-r--r--  synapse/metrics/__init__.py | 17
-rw-r--r--  synapse/metrics/_exposition.py | 258
-rw-r--r--  synapse/metrics/background_process_metrics.py | 2
-rw-r--r--  synapse/metrics/resource.py | 20
-rw-r--r--  synapse/module_api/__init__.py | 56
-rw-r--r--  synapse/notifier.py | 4
-rw-r--r--  synapse/push/baserules.py | 13
-rw-r--r--  synapse/push/mailer.py | 2
-rw-r--r--  synapse/python_dependencies.py | 5
-rw-r--r--  synapse/replication/http/_base.py | 2
-rw-r--r--  synapse/replication/http/membership.py | 65
-rw-r--r--  synapse/replication/http/register.py | 6
-rw-r--r--  synapse/replication/tcp/protocol.py | 2
-rw-r--r--  synapse/rest/admin/__init__.py | 3
-rw-r--r--  synapse/rest/admin/server_notice_servlet.py | 9
-rw-r--r--  synapse/rest/client/transactions.py | 2
-rw-r--r--  synapse/rest/client/v1/directory.py | 10
-rw-r--r--  synapse/rest/client/v1/login.py | 55
-rw-r--r--  synapse/rest/client/v1/room.py | 46
-rw-r--r--  synapse/rest/client/v2_alpha/register.py | 11
-rw-r--r--  synapse/rest/client/v2_alpha/relations.py | 97
-rw-r--r--  synapse/rest/media/v1/_base.py | 6
-rw-r--r--  synapse/rest/media/v1/media_repository.py | 12
-rw-r--r--  synapse/rest/media/v1/media_storage.py | 5
-rw-r--r--  synapse/rest/media/v1/preview_url_resource.py | 2
-rw-r--r--  synapse/rest/media/v1/storage_provider.py | 7
-rw-r--r--  synapse/state/__init__.py | 2
-rw-r--r--  synapse/storage/_base.py | 2
-rw-r--r--  synapse/storage/events.py | 4
-rw-r--r--  synapse/storage/events_worker.py | 415
-rw-r--r--  synapse/storage/registration.py | 98
-rw-r--r--  synapse/storage/relations.py | 2
-rw-r--r--  synapse/storage/roommember.py | 20
-rw-r--r--  synapse/storage/schema/delta/55/access_token_expiry.sql | 18
-rw-r--r--  synapse/storage/stream.py | 16
-rw-r--r--  synapse/storage/transactions.py | 28
-rw-r--r--  synapse/util/__init__.py | 8
-rw-r--r--  synapse/util/async_helpers.py | 9
-rw-r--r--  synapse/util/caches/descriptors.py | 11
-rw-r--r--  synapse/util/caches/response_cache.py | 4
-rw-r--r--  synapse/util/distributor.py | 2
-rw-r--r--  synapse/util/file_consumer.py | 2
-rw-r--r--  synapse/util/logcontext.py | 693
-rw-r--r--  synapse/util/logformatter.py | 44
-rw-r--r--  synapse/util/metrics.py | 2
-rw-r--r--  synapse/util/ratelimitutils.py | 18
-rw-r--r--  synapse/util/retryutils.py | 4
119 files changed, 3604 insertions, 2019 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index cf22fabd..8301a13d 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -35,4 +35,4 @@ try:
except ImportError:
pass
-__version__ = "1.1.0"
+__version__ = "1.2.1"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 86f14564..7ce6540b 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -25,7 +25,13 @@ from twisted.internet import defer
import synapse.types
from synapse import event_auth
from synapse.api.constants import EventTypes, JoinRules, Membership
-from synapse.api.errors import AuthError, Codes, ResourceLimitError
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ InvalidClientTokenError,
+ MissingClientTokenError,
+ ResourceLimitError,
+)
from synapse.config.server import is_threepid_reserved
from synapse.types import UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
@@ -63,7 +69,6 @@ class Auth(object):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
- self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
register_cache("cache", "token_cache", self.token_cache)
@@ -189,18 +194,17 @@ class Auth(object):
Returns:
defer.Deferred: resolves to a ``synapse.types.Requester`` object
Raises:
- AuthError if no user by that token exists or the token is invalid.
+ InvalidClientCredentialsError if no user by that token exists or the token
+ is invalid.
+ AuthError if access is denied for the user in the access token
"""
- # Can optionally look elsewhere in the request (e.g. headers)
try:
ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders(
b"User-Agent", default=[b""]
)[0].decode("ascii", "surrogateescape")
- access_token = self.get_access_token_from_request(
- request, self.TOKEN_NOT_FOUND_HTTP_STATUS
- )
+ access_token = self.get_access_token_from_request(request)
user_id, app_service = yield self._get_appservice_user_id(request)
if user_id:
@@ -264,18 +268,12 @@ class Auth(object):
)
)
except KeyError:
- raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS,
- "Missing access token.",
- errcode=Codes.MISSING_TOKEN,
- )
+ raise MissingClientTokenError()
@defer.inlineCallbacks
def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token(
- self.get_access_token_from_request(
- request, self.TOKEN_NOT_FOUND_HTTP_STATUS
- )
+ self.get_access_token_from_request(request)
)
if app_service is None:
defer.returnValue((None, None))
@@ -313,13 +311,25 @@ class Auth(object):
`token_id` (int|None): access token id. May be None if guest
`device_id` (str|None): device corresponding to access token
Raises:
- AuthError if no user by that token exists or the token is invalid.
+ InvalidClientCredentialsError if no user by that token exists or the token
+ is invalid.
"""
if rights == "access":
# first look in the database
r = yield self._look_up_user_by_access_token(token)
if r:
+ valid_until_ms = r["valid_until_ms"]
+ if (
+ valid_until_ms is not None
+ and valid_until_ms < self.clock.time_msec()
+ ):
+ # there was a valid access token, but it has expired.
+ # soft-logout the user.
+ raise InvalidClientTokenError(
+ msg="Access token has expired", soft_logout=True
+ )
+
defer.returnValue(r)
# otherwise it needs to be a valid macaroon
@@ -331,11 +341,7 @@ class Auth(object):
if not guest:
# non-guest access tokens must be in the database
logger.warning("Unrecognised access token - not in store.")
- raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS,
- "Unrecognised access token.",
- errcode=Codes.UNKNOWN_TOKEN,
- )
+ raise InvalidClientTokenError()
# Guest access tokens are not stored in the database (there can
# only be one access token per guest, anyway).
@@ -350,16 +356,10 @@ class Auth(object):
# guest tokens.
stored_user = yield self.store.get_user_by_id(user_id)
if not stored_user:
- raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS,
- "Unknown user_id %s" % user_id,
- errcode=Codes.UNKNOWN_TOKEN,
- )
+ raise InvalidClientTokenError("Unknown user_id %s" % user_id)
if not stored_user["is_guest"]:
- raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS,
- "Guest access token used for regular user",
- errcode=Codes.UNKNOWN_TOKEN,
+ raise InvalidClientTokenError(
+ "Guest access token used for regular user"
)
ret = {
"user": user,
@@ -386,11 +386,7 @@ class Auth(object):
ValueError,
) as e:
logger.warning("Invalid macaroon in auth: %s %s", type(e), e)
- raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS,
- "Invalid macaroon passed.",
- errcode=Codes.UNKNOWN_TOKEN,
- )
+ raise InvalidClientTokenError("Invalid macaroon passed.")
def _parse_and_validate_macaroon(self, token, rights="access"):
"""Takes a macaroon and tries to parse and validate it. This is cached
@@ -430,11 +426,7 @@ class Auth(object):
macaroon, rights, self.hs.config.expire_access_token, user_id=user_id
)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
- raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS,
- "Invalid macaroon passed.",
- errcode=Codes.UNKNOWN_TOKEN,
- )
+ raise InvalidClientTokenError("Invalid macaroon passed.")
if not has_expiry and rights == "access":
self.token_cache[token] = (user_id, guest)
@@ -453,17 +445,14 @@ class Auth(object):
(str) user id
Raises:
- AuthError if there is no user_id caveat in the macaroon
+ InvalidClientCredentialsError if there is no user_id caveat in the
+ macaroon
"""
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
return caveat.caveat_id[len(user_prefix) :]
- raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS,
- "No user caveat in macaroon",
- errcode=Codes.UNKNOWN_TOKEN,
- )
+ raise InvalidClientTokenError("No user caveat in macaroon")
def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
"""
@@ -527,26 +516,18 @@ class Auth(object):
"token_id": ret.get("token_id", None),
"is_guest": False,
"device_id": ret.get("device_id"),
+ "valid_until_ms": ret.get("valid_until_ms"),
}
defer.returnValue(user_info)
def get_appservice_by_req(self, request):
- try:
- token = self.get_access_token_from_request(
- request, self.TOKEN_NOT_FOUND_HTTP_STATUS
- )
- service = self.store.get_app_service_by_token(token)
- if not service:
- logger.warn("Unrecognised appservice access token.")
- raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS,
- "Unrecognised access token.",
- errcode=Codes.UNKNOWN_TOKEN,
- )
- request.authenticated_entity = service.sender
- return defer.succeed(service)
- except KeyError:
- raise AuthError(self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.")
+ token = self.get_access_token_from_request(request)
+ service = self.store.get_app_service_by_token(token)
+ if not service:
+ logger.warn("Unrecognised appservice access token.")
+ raise InvalidClientTokenError()
+ request.authenticated_entity = service.sender
+ return defer.succeed(service)
def is_server_admin(self, user):
""" Check if the given user is a local server admin.
@@ -625,21 +606,6 @@ class Auth(object):
defer.returnValue(auth_ids)
- def check_redaction(self, room_version, event, auth_events):
- """Check whether the event sender is allowed to redact the target event.
-
- Returns:
- True if the sender is allowed to redact the target event if the
- target event was created by them.
- False if the sender is allowed to redact the target event with no
- further checks.
-
- Raises:
- AuthError if the event sender is definitely not allowed to redact
- the target event.
- """
- return event_auth.check_redaction(room_version, event, auth_events)
-
@defer.inlineCallbacks
def check_can_change_room_list(self, room_id, user):
"""Check if the user is allowed to edit the room's entry in the
@@ -692,20 +658,16 @@ class Auth(object):
return bool(query_params) or bool(auth_headers)
@staticmethod
- def get_access_token_from_request(request, token_not_found_http_status=401):
+ def get_access_token_from_request(request):
"""Extracts the access_token from the request.
Args:
request: The http request.
- token_not_found_http_status(int): The HTTP status code to set in the
- AuthError if the token isn't found. This is used in some of the
- legacy APIs to change the status code to 403 from the default of
- 401 since some of the old clients depended on auth errors returning
- 403.
Returns:
unicode: The access_token
Raises:
- AuthError: If there isn't an access_token in the request.
+ MissingClientTokenError: If there isn't a single access_token in the
+ request
"""
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
@@ -714,34 +676,20 @@ class Auth(object):
# Try to get the access_token from an "Authorization: Bearer"
# header
if query_params is not None:
- raise AuthError(
- token_not_found_http_status,
- "Mixing Authorization headers and access_token query parameters.",
- errcode=Codes.MISSING_TOKEN,
+ raise MissingClientTokenError(
+ "Mixing Authorization headers and access_token query parameters."
)
if len(auth_headers) > 1:
- raise AuthError(
- token_not_found_http_status,
- "Too many Authorization headers.",
- errcode=Codes.MISSING_TOKEN,
- )
+ raise MissingClientTokenError("Too many Authorization headers.")
parts = auth_headers[0].split(b" ")
if parts[0] == b"Bearer" and len(parts) == 2:
return parts[1].decode("ascii")
else:
- raise AuthError(
- token_not_found_http_status,
- "Invalid Authorization header.",
- errcode=Codes.MISSING_TOKEN,
- )
+ raise MissingClientTokenError("Invalid Authorization header.")
else:
# Try to get the access_token from the query params.
if not query_params:
- raise AuthError(
- token_not_found_http_status,
- "Missing access token.",
- errcode=Codes.MISSING_TOKEN,
- )
+ raise MissingClientTokenError()
return query_params[0].decode("ascii")
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 28b5c2af..ad3e2620 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -139,6 +139,22 @@ class ConsentNotGivenError(SynapseError):
return cs_error(self.msg, self.errcode, consent_uri=self._consent_uri)
+class UserDeactivatedError(SynapseError):
+ """The error returned to the client when the user attempted to access an
+ authenticated endpoint, but the account has been deactivated.
+ """
+
+ def __init__(self, msg):
+ """Constructs a UserDeactivatedError
+
+ Args:
+ msg (str): The human-readable error message
+ """
+ super(UserDeactivatedError, self).__init__(
+ code=http_client.FORBIDDEN, msg=msg, errcode=Codes.UNKNOWN
+ )
+
+
class RegistrationError(SynapseError):
"""An error raised when a registration event fails."""
@@ -210,7 +226,9 @@ class NotFoundError(SynapseError):
class AuthError(SynapseError):
- """An error raised when there was a problem authorising an event."""
+ """An error raised when there was a problem authorising an event, and at various
+ other poorly-defined times.
+ """
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
@@ -218,6 +236,41 @@ class AuthError(SynapseError):
super(AuthError, self).__init__(*args, **kwargs)
+class InvalidClientCredentialsError(SynapseError):
+ """An error raised when there was a problem with the authorisation credentials
+ in a client request.
+
+ https://matrix.org/docs/spec/client_server/r0.5.0#using-access-tokens:
+
+ When credentials are required but missing or invalid, the HTTP call will
+ return with a status of 401 and the error code, M_MISSING_TOKEN or
+ M_UNKNOWN_TOKEN respectively.
+ """
+
+ def __init__(self, msg, errcode):
+ super().__init__(code=401, msg=msg, errcode=errcode)
+
+
+class MissingClientTokenError(InvalidClientCredentialsError):
+ """Raised when we couldn't find the access token in a request"""
+
+ def __init__(self, msg="Missing access token"):
+ super().__init__(msg=msg, errcode="M_MISSING_TOKEN")
+
+
+class InvalidClientTokenError(InvalidClientCredentialsError):
+ """Raised when we didn't understand the access token in a request"""
+
+ def __init__(self, msg="Unrecognised access token", soft_logout=False):
+ super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN")
+ self._soft_logout = soft_logout
+
+ def error_dict(self):
+ d = super().error_dict()
+ d["soft_logout"] = self._soft_logout
+ return d
+
+
class ResourceLimitError(SynapseError):
"""
Any error raised when there is a problem with resource usage.
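One practical consequence for client authors: the new soft_logout flag is surfaced directly in the 401 error body. A minimal sketch of what an expired-token response contains, assuming SynapseError.error_dict() serialises msg and errcode the same way as the rest of errors.py:

    # Sketch only: exercising the classes added above.
    err = InvalidClientTokenError(msg="Access token has expired", soft_logout=True)
    assert err.code == 401 and err.errcode == "M_UNKNOWN_TOKEN"
    print(err.error_dict())
    # expected: {'errcode': 'M_UNKNOWN_TOKEN',
    #            'error': 'Access token has expired',
    #            'soft_logout': True}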
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index d50a9840..540dbd92 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -27,7 +27,7 @@ from twisted.protocols.tls import TLSMemoryBIOFactory
import synapse
from synapse.app import check_bind_error
from synapse.crypto import context_factory
-from synapse.util import PreserveLoggingContext
+from synapse.logging.context import PreserveLoggingContext
from synapse.util.async_helpers import Linearizer
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
@@ -48,7 +48,7 @@ def register_sighup(func):
_sighup_callbacks.append(func)
-def start_worker_reactor(appname, config):
+def start_worker_reactor(appname, config, run_command=reactor.run):
""" Run the reactor in the main process
Daemonizes if necessary, and then configures some resources, before starting
@@ -57,6 +57,7 @@ def start_worker_reactor(appname, config):
Args:
appname (str): application name which will be sent to syslog
config (synapse.config.Config): config object
+ run_command (Callable[]): callable that actually runs the reactor
"""
logger = logging.getLogger(config.worker_app)
@@ -69,11 +70,19 @@ def start_worker_reactor(appname, config):
daemonize=config.worker_daemonize,
print_pidfile=config.print_pidfile,
logger=logger,
+ run_command=run_command,
)
def start_reactor(
- appname, soft_file_limit, gc_thresholds, pid_file, daemonize, print_pidfile, logger
+ appname,
+ soft_file_limit,
+ gc_thresholds,
+ pid_file,
+ daemonize,
+ print_pidfile,
+ logger,
+ run_command=reactor.run,
):
""" Run the reactor in the main process
@@ -88,38 +97,42 @@ def start_reactor(
daemonize (bool): true to run the reactor in a background process
print_pidfile (bool): whether to print the pid file, if daemonize is True
logger (logging.Logger): logger instance to pass to Daemonize
+ run_command (Callable[]): callable that actually runs the reactor
"""
install_dns_limiter(reactor)
def run():
- # make sure that we run the reactor with the sentinel log context,
- # otherwise other PreserveLoggingContext instances will get confused
- # and complain when they see the logcontext arbitrarily swapping
- # between the sentinel and `run` logcontexts.
- with PreserveLoggingContext():
- logger.info("Running")
-
- change_resource_limit(soft_file_limit)
- if gc_thresholds:
- gc.set_threshold(*gc_thresholds)
- reactor.run()
-
- if daemonize:
- if print_pidfile:
- print(pid_file)
-
- daemon = Daemonize(
- app=appname,
- pid=pid_file,
- action=run,
- auto_close_fds=False,
- verbose=True,
- logger=logger,
- )
- daemon.start()
- else:
- run()
+ logger.info("Running")
+ change_resource_limit(soft_file_limit)
+ if gc_thresholds:
+ gc.set_threshold(*gc_thresholds)
+ run_command()
+
+ # make sure that we run the reactor with the sentinel log context,
+ # otherwise other PreserveLoggingContext instances will get confused
+ # and complain when they see the logcontext arbitrarily swapping
+ # between the sentinel and `run` logcontexts.
+ #
+ # We also need to drop the logcontext before forking if we're daemonizing,
+ # otherwise the cputime metrics get confused about the per-thread resource usage
+ # appearing to go backwards.
+ with PreserveLoggingContext():
+ if daemonize:
+ if print_pidfile:
+ print(pid_file)
+
+ daemon = Daemonize(
+ app=appname,
+ pid=pid_file,
+ action=run,
+ auto_close_fds=False,
+ verbose=True,
+ logger=logger,
+ )
+ daemon.start()
+ else:
+ run()
def quit_with_error(error_string):
@@ -136,8 +149,7 @@ def listen_metrics(bind_addresses, port):
"""
Start Prometheus metrics server.
"""
- from synapse.metrics import RegistryProxy
- from prometheus_client import start_http_server
+ from synapse.metrics import RegistryProxy, start_http_server
for host in bind_addresses:
logger.info("Starting metrics listener on %s:%d", host, port)
@@ -240,6 +252,9 @@ def start(hs, listeners=None):
# Load the certificate from disk.
refresh_certificate(hs)
+ # Start the tracer
+ synapse.logging.opentracing.init_tracer(hs.config)
+
# It is now safe to start your Synapse.
hs.start_listening(listeners)
hs.get_datastore().start_profiling()
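The injectable run_command is what lets the new admin_cmd app (below) drive the reactor with twisted.internet.task.react instead of reactor.run, so the process exits cleanly once the command's Deferred fires. A hedged sketch of that wiring; do_work and config are hypothetical stand-ins:

    from twisted.internet import defer, task

    @defer.inlineCallbacks
    def main(_reactor):
        # hypothetical command body; the reactor stops when this completes
        yield do_work()

    # run_command defaults to reactor.run; a task.react wrapper turns the
    # long-running worker entry point into a run-to-completion command.
    start_worker_reactor("example-app", config, run_command=lambda: task.react(main))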
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
new file mode 100644
index 00000000..1fd52a55
--- /dev/null
+++ b/synapse/app/admin_cmd.py
@@ -0,0 +1,264 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2019 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import logging
+import os
+import sys
+import tempfile
+
+from canonicaljson import json
+
+from twisted.internet import defer, task
+
+import synapse
+from synapse.app import _base
+from synapse.config._base import ConfigError
+from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
+from synapse.handlers.admin import ExfiltrationWriter
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
+from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
+from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
+from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
+from synapse.replication.slave.storage.devices import SlavedDeviceStore
+from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.replication.slave.storage.filtering import SlavedFilteringStore
+from synapse.replication.slave.storage.groups import SlavedGroupServerStore
+from synapse.replication.slave.storage.presence import SlavedPresenceStore
+from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
+from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
+from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.slave.storage.room import RoomStore
+from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.server import HomeServer
+from synapse.storage.engines import create_engine
+from synapse.util.logcontext import LoggingContext
+from synapse.util.versionstring import get_version_string
+
+logger = logging.getLogger("synapse.app.admin_cmd")
+
+
+class AdminCmdSlavedStore(
+ SlavedReceiptsStore,
+ SlavedAccountDataStore,
+ SlavedApplicationServiceStore,
+ SlavedRegistrationStore,
+ SlavedFilteringStore,
+ SlavedPresenceStore,
+ SlavedGroupServerStore,
+ SlavedDeviceInboxStore,
+ SlavedDeviceStore,
+ SlavedPushRuleStore,
+ SlavedEventStore,
+ SlavedClientIpStore,
+ RoomStore,
+ BaseSlavedStore,
+):
+ pass
+
+
+class AdminCmdServer(HomeServer):
+ DATASTORE_CLASS = AdminCmdSlavedStore
+
+ def _listen_http(self, listener_config):
+ pass
+
+ def start_listening(self, listeners):
+ pass
+
+ def build_tcp_replication(self):
+ return AdminCmdReplicationHandler(self)
+
+
+class AdminCmdReplicationHandler(ReplicationClientHandler):
+ @defer.inlineCallbacks
+ def on_rdata(self, stream_name, token, rows):
+ pass
+
+ def get_streams_to_replicate(self):
+ return {}
+
+
+@defer.inlineCallbacks
+def export_data_command(hs, args):
+ """Export data for a user.
+
+ Args:
+ hs (HomeServer)
+ args (argparse.Namespace)
+ """
+
+ user_id = args.user_id
+ directory = args.output_directory
+
+ res = yield hs.get_handlers().admin_handler.export_user_data(
+ user_id, FileExfiltrationWriter(user_id, directory=directory)
+ )
+ print(res)
+
+
+class FileExfiltrationWriter(ExfiltrationWriter):
+ """An ExfiltrationWriter that writes the users data to a directory.
+ Returns the directory location on completion.
+
+ Note: This writes to disk on the main reactor thread.
+
+ Args:
+ user_id (str): The user whose data is being exfiltrated.
+ directory (str|None): The directory to write the data to, if None then
+ will write to a temporary directory.
+ """
+
+ def __init__(self, user_id, directory=None):
+ self.user_id = user_id
+
+ if directory:
+ self.base_directory = directory
+ else:
+ self.base_directory = tempfile.mkdtemp(
+ prefix="synapse-exfiltrate__%s__" % (user_id,)
+ )
+
+ os.makedirs(self.base_directory, exist_ok=True)
+ if list(os.listdir(self.base_directory)):
+ raise Exception("Directory must be empty")
+
+ def write_events(self, room_id, events):
+ room_directory = os.path.join(self.base_directory, "rooms", room_id)
+ os.makedirs(room_directory, exist_ok=True)
+ events_file = os.path.join(room_directory, "events")
+
+ with open(events_file, "a") as f:
+ for event in events:
+ print(json.dumps(event.get_pdu_json()), file=f)
+
+ def write_state(self, room_id, event_id, state):
+ room_directory = os.path.join(self.base_directory, "rooms", room_id)
+ state_directory = os.path.join(room_directory, "state")
+ os.makedirs(state_directory, exist_ok=True)
+
+ event_file = os.path.join(state_directory, event_id)
+
+ with open(event_file, "a") as f:
+ for event in state.values():
+ print(json.dumps(event.get_pdu_json()), file=f)
+
+ def write_invite(self, room_id, event, state):
+ self.write_events(room_id, [event])
+
+ # We write the invite state somewhere else as they aren't full events
+ # and are only a subset of the state at the event.
+ room_directory = os.path.join(self.base_directory, "rooms", room_id)
+ os.makedirs(room_directory, exist_ok=True)
+
+ invite_state = os.path.join(room_directory, "invite_state")
+
+ with open(invite_state, "a") as f:
+ for event in state.values():
+ print(json.dumps(event), file=f)
+
+ def finished(self):
+ return self.base_directory
+
+
+def start(config_options):
+ parser = argparse.ArgumentParser(description="Synapse Admin Command")
+ HomeServerConfig.add_arguments_to_parser(parser)
+
+ subparser = parser.add_subparsers(
+ title="Admin Commands",
+ required=True,
+ dest="command",
+ metavar="<admin_command>",
+ help="The admin command to perform.",
+ )
+ export_data_parser = subparser.add_parser(
+ "export-data", help="Export all data for a user"
+ )
+ export_data_parser.add_argument("user_id", help="User to extra data from")
+ export_data_parser.add_argument(
+ "--output-directory",
+ action="store",
+ metavar="DIRECTORY",
+ required=False,
+ help="The directory to store the exported data in. Must be empty. Defaults"
+ " to creating a temp directory.",
+ )
+ export_data_parser.set_defaults(func=export_data_command)
+
+ try:
+ config, args = HomeServerConfig.load_config_with_parser(parser, config_options)
+ except ConfigError as e:
+ sys.stderr.write("\n" + str(e) + "\n")
+ sys.exit(1)
+
+ if config.worker_app is not None:
+ assert config.worker_app == "synapse.app.admin_cmd"
+
+ # Update the config with some basic overrides so that we don't have to specify
+ # a full worker config.
+ config.worker_app = "synapse.app.admin_cmd"
+
+ if (
+ not config.worker_daemonize
+ and not config.worker_log_file
+ and not config.worker_log_config
+ ):
+ # Since we're meant to be run as a "command" let's not redirect stdio
+ # unless we've actually set log config.
+ config.no_redirect_stdio = True
+
+ # Explicitly disable background processes
+ config.update_user_directory = False
+ config.start_pushers = False
+ config.send_federation = False
+
+ setup_logging(config, use_worker_options=True)
+
+ synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
+
+ database_engine = create_engine(config.database_config)
+
+ ss = AdminCmdServer(
+ config.server_name,
+ db_config=config.database_config,
+ config=config,
+ version_string="Synapse/" + get_version_string(synapse),
+ database_engine=database_engine,
+ )
+
+ ss.setup()
+
+ # We use task.react as the basic run command as it correctly handles tearing
+ # down the reactor when the deferreds resolve and setting the return value.
+ # We also make sure that `_base.start` gets run before we actually run the
+ # command.
+
+ @defer.inlineCallbacks
+ def run(_reactor):
+ with LoggingContext("command"):
+ yield _base.start(ss, [])
+ yield args.func(ss, args)
+
+ _base.start_worker_reactor(
+ "synapse-admin-cmd", config, run_command=lambda: task.react(run)
+ )
+
+
+if __name__ == "__main__":
+ with LoggingContext("main"):
+ start(sys.argv[1:])
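To get a feel for the directory layout FileExfiltrationWriter produces, here is a small usage sketch; FakeEvent is a made-up stand-in for a real event object exposing get_pdu_json():

    # Sketch, not part of the commit: drive FileExfiltrationWriter directly.
    class FakeEvent:
        def __init__(self, body):
            self._body = body

        def get_pdu_json(self):
            return {"type": "m.room.message", "content": {"body": self._body}}

    writer = FileExfiltrationWriter("@someone:example.com")
    writer.write_events("!room:example.com", [FakeEvent("hello")])
    print(writer.finished())
    # prints the temp export directory, which now holds
    # rooms/!room:example.com/events with one JSON line per event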
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index 9120bdb1..e01f3e5f 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -26,8 +26,8 @@ from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
-from synapse.metrics import RegistryProxy
-from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.logging.context import LoggingContext, run_in_background
+from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
@@ -36,7 +36,6 @@ from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 90bc79cd..29bddc48 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -27,8 +27,8 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.server import JsonResource
from synapse.http.site import SynapseSite
-from synapse.metrics import RegistryProxy
-from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.logging.context import LoggingContext
+from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
@@ -64,7 +64,6 @@ from synapse.rest.client.versions import VersionsRestServlet
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
index ff522e44..042cfd04 100644
--- a/synapse/app/event_creator.py
+++ b/synapse/app/event_creator.py
@@ -27,8 +27,8 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.server import JsonResource
from synapse.http.site import SynapseSite
-from synapse.metrics import RegistryProxy
-from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.logging.context import LoggingContext
+from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
@@ -59,7 +59,6 @@ from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.storage.user_directory import UserDirectoryStore
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index 94214209..76a97f8f 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -28,8 +28,8 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.federation.transport.server import TransportLayerServer
from synapse.http.site import SynapseSite
-from synapse.metrics import RegistryProxy
-from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.logging.context import LoggingContext
+from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
@@ -48,7 +48,6 @@ from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 969be58d..fec49d50 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -27,9 +27,9 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.federation import send_queue
from synapse.http.site import SynapseSite
-from synapse.metrics import RegistryProxy
+from synapse.logging.context import LoggingContext, run_in_background
+from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.events import SlavedEventStore
@@ -44,7 +44,6 @@ from synapse.storage.engines import create_engine
from synapse.types import ReadReceipt
from synapse.util.async_helpers import Linearizer
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index 2fd7d57e..1f1f1df7 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -29,8 +29,8 @@ from synapse.config.logger import setup_logging
from synapse.http.server import JsonResource
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseSite
-from synapse.metrics import RegistryProxy
-from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.logging.context import LoggingContext
+from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
@@ -41,7 +41,6 @@ from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 49da105c..0c075cb3 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -54,9 +54,9 @@ from synapse.federation.transport.server import TransportLayerServer
from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import RootRedirect
from synapse.http.site import SynapseSite
-from synapse.metrics import RegistryProxy
+from synapse.logging.context import LoggingContext
+from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.module_api import ModuleApi
from synapse.python_dependencies import check_requirements
from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
@@ -72,7 +72,6 @@ from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.module_loader import load_module
from synapse.util.rlimit import change_resource_limit
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index cf0e2036..d70780e9 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -27,8 +27,8 @@ from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
-from synapse.metrics import RegistryProxy
-from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.logging.context import LoggingContext
+from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
@@ -40,7 +40,6 @@ from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.storage.media_repository import MediaRepositoryStore
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index df29ea5e..070de7d0 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -26,8 +26,8 @@ from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
-from synapse.metrics import RegistryProxy
-from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.logging.context import LoggingContext, run_in_background
+from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
from synapse.replication.slave.storage._base import __func__
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.events import SlavedEventStore
@@ -38,7 +38,6 @@ from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 85894991..315c0306 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -31,8 +31,8 @@ from synapse.config.logger import setup_logging
from synapse.handlers.presence import PresenceHandler, get_interested_parties
from synapse.http.server import JsonResource
from synapse.http.site import SynapseSite
-from synapse.metrics import RegistryProxy
-from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.logging.context import LoggingContext, run_in_background
+from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
from synapse.replication.slave.storage._base import BaseSlavedStore, __func__
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
@@ -57,7 +57,6 @@ from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.storage.presence import UserPresenceState
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
from synapse.util.stringutils import random_string
from synapse.util.versionstring import get_version_string
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index 2d9d2e1b..03ef21bd 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -28,8 +28,8 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.server import JsonResource
from synapse.http.site import SynapseSite
-from synapse.metrics import RegistryProxy
-from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.logging.context import LoggingContext, run_in_background
+from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
@@ -46,7 +46,6 @@ from synapse.storage.engines import create_engine
from synapse.storage.user_directory import UserDirectoryStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index b54bf541..e5b36494 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -53,8 +53,8 @@ import logging
from twisted.internet import defer
from synapse.appservice import ApplicationServiceState
+from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util.logcontext import run_in_background
logger = logging.getLogger(__name__)
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 965478d8..6ce5cd07 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -137,12 +137,42 @@ class Config(object):
return file_stream.read()
def invoke_all(self, name, *args, **kargs):
+ """Invoke all instance methods with the given name and arguments in the
+ class's MRO.
+
+ Args:
+ name (str): Name of function to invoke
+ *args
+ **kwargs
+
+ Returns:
+ list: The list of the return values from each method called
+ """
results = []
for cls in type(self).mro():
if name in cls.__dict__:
results.append(getattr(cls, name)(self, *args, **kargs))
return results
+ @classmethod
+ def invoke_all_static(cls, name, *args, **kargs):
+ """Invoke all static methods with the given name and arguments in the
+ class's MRO.
+
+ Args:
+ name (str): Name of function to invoke
+ *args
+ **kwargs
+
+ Returns:
+ list: The list of the return values from each method called
+ """
+ results = []
+ for c in cls.mro():
+ if name in c.__dict__:
+ results.append(getattr(c, name)(*args, **kargs))
+ return results
+
def generate_config(
self,
config_dir_path,
@@ -202,6 +232,23 @@ class Config(object):
Returns: Config object.
"""
config_parser = argparse.ArgumentParser(description=description)
+ cls.add_arguments_to_parser(config_parser)
+ obj, _ = cls.load_config_with_parser(config_parser, argv)
+
+ return obj
+
+ @classmethod
+ def add_arguments_to_parser(cls, config_parser):
+ """Adds all the config flags to an ArgumentParser.
+
+ Doesn't support config-file-generation: used by the worker apps.
+
+ Used for workers where we want to add extra flags/subcommands.
+
+ Args:
+ config_parser (ArgumentParser): parser to add the config flags to
+ """
+
config_parser.add_argument(
"-c",
"--config-path",
@@ -219,16 +266,34 @@ class Config(object):
" Defaults to the directory containing the last config file",
)
- obj = cls()
+ cls.invoke_all_static("add_arguments", config_parser)
- obj.invoke_all("add_arguments", config_parser)
+ @classmethod
+ def load_config_with_parser(cls, parser, argv):
+ """Parse the commandline and config files with the given parser
+
+ Doesn't support config-file-generation: used by the worker apps.
- config_args = config_parser.parse_args(argv)
+ Used for workers where we want to add extra flags/subcommands.
+
+ Args:
+ parser (ArgumentParser)
+ argv (list[str])
+
+ Returns:
+ tuple[HomeServerConfig, argparse.Namespace]: Returns the parsed
+ config object and the parsed argparse.Namespace object from
+ `parser.parse_args(..)`
+ """
+
+ obj = cls()
+
+ config_args = parser.parse_args(argv)
config_files = find_config_files(search_paths=config_args.config_path)
if not config_files:
- config_parser.error("Must supply a config file.")
+ parser.error("Must supply a config file.")
if config_args.keys_directory:
config_dir_path = config_args.keys_directory
@@ -244,7 +309,7 @@ class Config(object):
obj.invoke_all("read_arguments", config_args)
- return obj
+ return obj, config_args
@classmethod
def load_or_generate_config(cls, description, argv):
@@ -401,7 +466,7 @@ class Config(object):
formatter_class=argparse.RawDescriptionHelpFormatter,
)
- obj.invoke_all("add_arguments", parser)
+ obj.invoke_all_static("add_arguments", parser)
args = parser.parse_args(remaining_args)
config_dict = read_config_files(config_files)
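The switch from invoke_all to invoke_all_static for add_arguments is what makes add_arguments_to_parser callable before any config object exists. A minimal sketch of the MRO walk it performs, with hypothetical classes standing in for the real config hierarchy:

    import argparse

    class Base:
        @staticmethod
        def add_arguments(parser):
            parser.add_argument("--base-flag")

    class Worker(Base):
        @staticmethod
        def add_arguments(parser):
            parser.add_argument("--worker-flag")

    parser = argparse.ArgumentParser()
    # mirrors invoke_all_static: call every definition found along the MRO
    for c in Worker.mro():
        if "add_arguments" in c.__dict__:
            getattr(c, "add_arguments")(parser)

    print(parser.parse_args(["--worker-flag", "x", "--base-flag", "y"]))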
diff --git a/synapse/config/database.py b/synapse/config/database.py
index bcb2089d..746a6cd1 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -69,7 +69,8 @@ class DatabaseConfig(Config):
if database_path is not None:
self.database_config["args"]["database"] = database_path
- def add_arguments(self, parser):
+ @staticmethod
+ def add_arguments(parser):
db_group = parser.add_argument_group("database")
db_group.add_argument(
"-d",
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index fcd55d3e..8381b8eb 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -112,13 +112,17 @@ class EmailConfig(Config):
missing = []
for k in required:
if k not in email_config:
- missing.append(k)
+ missing.append("email." + k)
+
+ if config.get("public_baseurl") is None:
+ missing.append("public_base_url")
if len(missing) > 0:
raise RuntimeError(
- "email.password_reset_behaviour is set to 'local' "
- "but required keys are missing: %s"
- % (", ".join(["email." + k for k in missing]),)
+ "Password resets emails are configured to be sent from "
+ "this homeserver due to a partial 'email' block. "
+ "However, the following required keys are missing: %s"
+ % (", ".join(missing),)
)
# Templates for password reset emails
@@ -156,13 +160,6 @@ class EmailConfig(Config):
filepath, "email.password_reset_template_success_html"
)
- if config.get("public_baseurl") is None:
- raise RuntimeError(
- "email.password_reset_behaviour is set to 'local' but no "
- "public_baseurl is set. This is necessary to generate password "
- "reset links"
- )
-
if self.email_enable_notifs:
required = [
"smtp_host",
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index acadef4f..72acad4f 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -40,6 +40,7 @@ from .spam_checker import SpamCheckerConfig
from .stats import StatsConfig
from .third_party_event_rules import ThirdPartyRulesConfig
from .tls import TlsConfig
+from .tracer import TracerConfig
from .user_directory import UserDirectoryConfig
from .voip import VoipConfig
from .workers import WorkerConfig
@@ -75,5 +76,6 @@ class HomeServerConfig(
ServerNoticesConfig,
RoomDirectoryConfig,
ThirdPartyRulesConfig,
+ TracerConfig,
):
pass
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 931aec41..40502a57 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -24,7 +24,7 @@ from twisted.logger import STDLibLogObserver, globalLogBeginner
import synapse
from synapse.app import _base as appbase
-from synapse.util.logcontext import LoggingContextFilter
+from synapse.logging.context import LoggingContextFilter
from synapse.util.versionstring import get_version_string
from ._base import Config
@@ -40,7 +40,7 @@ formatters:
filters:
context:
- (): synapse.util.logcontext.LoggingContextFilter
+ (): synapse.logging.context.LoggingContextFilter
request: ""
handlers:
@@ -103,7 +103,8 @@ class LoggingConfig(Config):
if args.log_file is not None:
self.log_file = args.log_file
- def add_arguments(cls, parser):
+ @staticmethod
+ def add_arguments(parser):
logging_group = parser.add_argument_group("logging")
logging_group.add_argument(
"-v",
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 8c587f3f..33f31cf2 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -23,7 +23,7 @@ class RateLimitConfig(object):
class FederationRateLimitConfig(object):
_items_and_default = {
- "window_size": 10000,
+ "window_size": 1000,
"sleep_limit": 10,
"sleep_delay": 500,
"reject_limit": 50,
@@ -54,7 +54,7 @@ class RatelimitConfig(Config):
# Load the new-style federation config, if it exists. Otherwise, fall
# back to the old method.
- if "federation_rc" in config:
+ if "rc_federation" in config:
self.rc_federation = FederationRateLimitConfig(**config["rc_federation"])
else:
self.rc_federation = FederationRateLimitConfig(
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 4a59e6ec..c3de7a4e 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -71,9 +71,8 @@ class RegistrationConfig(Config):
self.default_identity_server = config.get("default_identity_server")
self.allow_guest_access = config.get("allow_guest_access", False)
- self.invite_3pid_guest = self.allow_guest_access and config.get(
- "invite_3pid_guest", False
- )
+ if config.get("invite_3pid_guest", False):
+ raise ConfigError("invite_3pid_guest is no longer supported")
self.auto_join_rooms = config.get("auto_join_rooms", [])
for room_alias in self.auto_join_rooms:
@@ -85,6 +84,11 @@ class RegistrationConfig(Config):
"disable_msisdn_registration", False
)
+ session_lifetime = config.get("session_lifetime")
+ if session_lifetime is not None:
+ session_lifetime = self.parse_duration(session_lifetime)
+ self.session_lifetime = session_lifetime
+
def generate_config_section(self, generate_secrets=False, **kwargs):
if generate_secrets:
registration_shared_secret = 'registration_shared_secret: "%s"' % (
@@ -142,6 +146,17 @@ class RegistrationConfig(Config):
# renew_at: 1w
# renew_email_subject: "Renew your %%(app)s account"
+ # Time that a user's session remains valid for, after they log in.
+ #
+ # Note that this is not currently compatible with guest logins.
+ #
+ # Note also that this is calculated at login time: changes are not applied
+ # retrospectively to users who have already logged in.
+ #
+ # By default, this is infinite.
+ #
+ #session_lifetime: 24h
+
# The user must provide all of the below types of 3PID when registering.
#
#registrations_require_3pid:
@@ -222,7 +237,8 @@ class RegistrationConfig(Config):
% locals()
)
- def add_arguments(self, parser):
+ @staticmethod
+ def add_arguments(parser):
reg_group = parser.add_argument_group("registration")
reg_group.add_argument(
"--enable-registration",
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 2a74dea2..00170f13 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -136,7 +136,7 @@ class ServerConfig(Config):
# Whether to enable experimental MSC1849 (aka relations) support
self.experimental_msc1849_support_enabled = config.get(
- "experimental_msc1849_support_enabled", False
+ "experimental_msc1849_support_enabled", True
)
# Options to control access by tracking MAU
@@ -639,7 +639,8 @@ class ServerConfig(Config):
if args.print_pidfile is not None:
self.print_pidfile = args.print_pidfile
- def add_arguments(self, parser):
+ @staticmethod
+ def add_arguments(parser):
server_group = parser.add_argument_group("server")
server_group.add_argument(
"-D",
diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py
new file mode 100644
index 00000000..44794544
--- /dev/null
+++ b/synapse/config/tracer.py
@@ -0,0 +1,59 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config, ConfigError
+
+
+class TracerConfig(Config):
+ def read_config(self, config, **kwargs):
+ opentracing_config = config.get("opentracing")
+ if opentracing_config is None:
+ opentracing_config = {}
+
+ self.opentracer_enabled = opentracing_config.get("enabled", False)
+ if not self.opentracer_enabled:
+ return
+
+ # The tracer is enabled so sanitize the config
+
+ self.opentracer_whitelist = opentracing_config.get("homeserver_whitelist", [])
+ if not isinstance(self.opentracer_whitelist, list):
+ raise ConfigError("Tracer homeserver_whitelist config is malformed")
+
+ def generate_config_section(cls, **kwargs):
+ return """\
+ ## Opentracing ##
+
+ # These settings enable opentracing, which implements distributed tracing.
+        # This allows you to observe the causal chains of events, including
+        # requests and key lookups, across any server running synapse or any
+        # other service which supports opentracing (specifically those
+        # implemented with Jaeger).
+ #
+ opentracing:
+ # tracing is disabled by default. Uncomment the following line to enable it.
+ #
+ #enabled: true
+
+          # The list of homeservers with which we wish to send and receive span contexts and span baggage.
+ # See docs/opentracing.rst
+ # This is a list of regexes which are matched against the server_name of the
+ # homeserver.
+ #
+          # By default, it is empty, so no servers are matched.
+ #
+ #homeserver_whitelist:
+ # - ".*"
+ """
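A hedged sketch of how a homeserver_whitelist of regexes might be applied; the helper name is hypothetical, but the matching rule follows the config comment above (each entry is a regex matched against the server_name):

    import re

    # Hypothetical helper: True if server_name matches any whitelist entry.
    def whitelisted_homeserver(server_name, homeserver_whitelist):
        if not homeserver_whitelist:
            return False  # the default empty list matches no servers
        pattern = re.compile("|".join(homeserver_whitelist))
        return pattern.match(server_name) is not None

With homeserver_whitelist set to [".*"], every remote homeserver would receive span contexts; narrower patterns limit tracing data to trusted peers.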
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 10c2eb7f..341c8631 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -44,15 +44,16 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
-from synapse.storage.keys import FetchKeyResult
-from synapse.util import logcontext, unwrapFirstError
-from synapse.util.async_helpers import yieldable_gather_results
-from synapse.util.logcontext import (
+from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
+ make_deferred_yieldable,
preserve_fn,
run_in_background,
)
+from synapse.storage.keys import FetchKeyResult
+from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.metrics import Measure
from synapse.util.retryutils import NotRetryingDestination
@@ -140,7 +141,7 @@ class Keyring(object):
"""
req = VerifyJsonRequest(server_name, json_object, validity_time, request_name)
requests = (req,)
- return logcontext.make_deferred_yieldable(self._verify_objects(requests)[0])
+ return make_deferred_yieldable(self._verify_objects(requests)[0])
def verify_json_objects_for_server(self, server_and_json):
"""Bulk verifies signatures of json objects, bulk fetching keys as
@@ -557,7 +558,7 @@ class BaseV2KeyFetcher(object):
signed_key_json_bytes = encode_canonical_json(signed_key_json)
- yield logcontext.make_deferred_yieldable(
+ yield make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
@@ -612,7 +613,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
defer.returnValue({})
- results = yield logcontext.make_deferred_yieldable(
+ results = yield make_deferred_yieldable(
defer.gatherResults(
[run_in_background(get_key, server) for server in self.key_servers],
consumeErrors=True,
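The keyring changes above repeatedly use the same logcontext-safe fan-out pattern, now imported from synapse.logging.context rather than synapse.util.logcontext. A condensed sketch of that pattern (fetch_one is a stand-in for the per-item call):

    from twisted.internet import defer

    from synapse.logging.context import make_deferred_yieldable, run_in_background

    def fetch_all(fetch_one, items):
        # run_in_background() launches each call without losing the current
        # logging context; make_deferred_yieldable() makes the gathered
        # Deferred safe to yield on from inlineCallbacks code.
        return make_deferred_yieldable(
            defer.gatherResults(
                [run_in_background(fetch_one, item) for item in items],
                consumeErrors=True,
            )
        )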
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index d3de70e6..88ed6d76 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -104,6 +104,17 @@ class _EventInternalMetadata(object):
"""
return getattr(self, "proactively_send", True)
+ def is_redacted(self):
+ """Whether the event has been redacted.
+
+ This is used for efficiently checking whether an event has been
+ marked as redacted without needing to make another database call.
+
+ Returns:
+ bool
+ """
+ return getattr(self, "redacted", False)
+
def _event_dict_property(key):
# We want to be able to use hasattr with the event dict properties.
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index a96cdada..a9545e6c 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -19,7 +19,7 @@ from frozendict import frozendict
from twisted.internet import defer
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+from synapse.logging.context import make_deferred_yieldable, run_in_background
class EventContext(object):
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index f24f0c16..9487a886 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -52,10 +52,15 @@ def prune_event(event):
from . import event_type_from_format_version
- return event_type_from_format_version(event.format_version)(
+ pruned_event = event_type_from_format_version(event.format_version)(
pruned_event_dict, event.internal_metadata.get_dict()
)
+ # Mark the event as redacted
+ pruned_event.internal_metadata.redacted = True
+
+ return pruned_event
+
def prune_event_dict(event_dict):
"""Redacts the event_dict in the same way as `prune_event`, except it
@@ -360,9 +365,12 @@ class EventClientSerializer(object):
event_id = event.event_id
serialized_event = serialize_event(event, time_now, **kwargs)
- # If MSC1849 is enabled then we need to look if thre are any relations
- # we need to bundle in with the event
- if self.experimental_msc1849_support_enabled and bundle_aggregations:
+ # If MSC1849 is enabled then we need to look if there are any relations
+ # we need to bundle in with the event.
+ # Do not bundle relations if the event has been redacted
+ if not event.internal_metadata.is_redacted() and (
+ self.experimental_msc1849_support_enabled and bundle_aggregations
+ ):
annotations = yield self.store.get_aggregation_groups_for_event(event_id)
references = yield self.store.get_relations_for_event(
event_id, RelationTypes.REFERENCE, direction="f"
@@ -392,7 +400,11 @@ class EventClientSerializer(object):
serialized_event["content"].pop("m.relates_to", None)
r = serialized_event["unsigned"].setdefault("m.relations", {})
- r[RelationTypes.REPLACE] = {"event_id": edit.event_id}
+ r[RelationTypes.REPLACE] = {
+ "event_id": edit.event_id,
+ "origin_server_ts": edit.origin_server_ts,
+ "sender": edit.sender,
+ }
defer.returnValue(serialized_event)
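Taken together, the two changes above mean a redacted event is flagged once at prune time and the serializer can then skip relation bundling without a database call. A hedged sketch of the consuming side (maybe_bundle_replace is illustrative, not a real Synapse helper; "m.replace" is the value of RelationTypes.REPLACE):

    def maybe_bundle_replace(event, edit, serialized_event):
        # Skip bundling entirely for redacted events, per the check added
        # to EventClientSerializer above.
        if event.internal_metadata.is_redacted():
            return serialized_event
        relations = serialized_event["unsigned"].setdefault("m.relations", {})
        # An edit now carries its timestamp and sender alongside its event_id.
        relations["m.replace"] = {
            "event_id": edit.event_id,
            "origin_server_ts": edit.origin_server_ts,
            "sender": edit.sender,
        }
        return serialized_event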
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 1e925b19..f7bb806a 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -27,8 +27,14 @@ from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import event_type_from_format_version
from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_dict
+from synapse.logging.context import (
+ LoggingContext,
+ PreserveLoggingContext,
+ make_deferred_yieldable,
+ preserve_fn,
+)
from synapse.types import get_domain_from_id
-from synapse.util import logcontext, unwrapFirstError
+from synapse.util import unwrapFirstError
logger = logging.getLogger(__name__)
@@ -73,7 +79,7 @@ class FederationBase(object):
@defer.inlineCallbacks
def handle_check_result(pdu, deferred):
try:
- res = yield logcontext.make_deferred_yieldable(deferred)
+ res = yield make_deferred_yieldable(deferred)
except SynapseError:
res = None
@@ -102,10 +108,10 @@ class FederationBase(object):
defer.returnValue(res)
- handle = logcontext.preserve_fn(handle_check_result)
+ handle = preserve_fn(handle_check_result)
deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
- valid_pdus = yield logcontext.make_deferred_yieldable(
+ valid_pdus = yield make_deferred_yieldable(
defer.gatherResults(deferreds2, consumeErrors=True)
).addErrback(unwrapFirstError)
@@ -115,7 +121,7 @@ class FederationBase(object):
defer.returnValue([p for p in valid_pdus if p])
def _check_sigs_and_hash(self, room_version, pdu):
- return logcontext.make_deferred_yieldable(
+ return make_deferred_yieldable(
self._check_sigs_and_hashes(room_version, [pdu])[0]
)
@@ -133,14 +139,14 @@ class FederationBase(object):
* returns a redacted version of the event (if the signature
matched but the hash did not)
* throws a SynapseError if the signature check failed.
- The deferreds run their callbacks in the sentinel logcontext.
+        The deferreds run their callbacks in the sentinel logging context.
"""
deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
- ctx = logcontext.LoggingContext.current_context()
+ ctx = LoggingContext.current_context()
def callback(_, pdu):
- with logcontext.PreserveLoggingContext(ctx):
+ with PreserveLoggingContext(ctx):
if not check_event_content_hash(pdu):
# let's try to distinguish between failures because the event was
# redacted (which are somewhat expected) vs actual ball-tampering
@@ -178,7 +184,7 @@ class FederationBase(object):
def errback(failure, pdu):
failure.trap(SynapseError)
- with logcontext.PreserveLoggingContext(ctx):
+ with PreserveLoggingContext(ctx):
logger.warn(
"Signature check failed for %s: %s",
pdu.event_id,
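A condensed sketch of the callback pattern used in _check_sigs_and_hashes above: the Deferreds fire in the sentinel context, so each callback re-enters the logging context captured before the Deferreds were created (make_logging_callback is an illustrative name, not a Synapse helper):

    from synapse.logging.context import LoggingContext, PreserveLoggingContext

    def make_logging_callback(fn):
        # Capture the caller's logging context now...
        ctx = LoggingContext.current_context()

        def callback(result):
            # ...and restore it when the Deferred eventually fires.
            with PreserveLoggingContext(ctx):
                return fn(result)

        return callback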
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 3883eb52..3cb4b944 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -39,10 +39,10 @@ from synapse.api.room_versions import (
)
from synapse.events import builder, room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
-from synapse.util import logcontext, unwrapFirstError
+from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.utils import log_function
+from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
-from synapse.util.logutils import log_function
from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
@@ -207,7 +207,7 @@ class FederationClient(FederationBase):
]
# FIXME: We should handle signature failures more gracefully.
- pdus[:] = yield logcontext.make_deferred_yieldable(
+ pdus[:] = yield make_deferred_yieldable(
defer.gatherResults(
self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True
).addErrback(unwrapFirstError)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 2e0cebb6..ed2b6d5e 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -42,6 +42,8 @@ from synapse.federation.federation_base import FederationBase, event_from_pdu_js
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
from synapse.http.endpoint import parse_server_name
+from synapse.logging.context import nested_logging_context
+from synapse.logging.utils import log_function
from synapse.replication.http.federation import (
ReplicationFederationSendEduRestServlet,
ReplicationGetQueryRestServlet,
@@ -50,8 +52,6 @@ from synapse.types import get_domain_from_id
from synapse.util import glob_to_regex
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
-from synapse.util.logcontext import nested_logging_context
-from synapse.util.logutils import log_function
# when processing incoming transactions, we try to handle multiple rooms in
# parallel, up to this limit.
@@ -369,7 +369,7 @@ class FederationServer(FederationBase):
logger.warn("Room version %s not in %s", room_version, supported_versions)
raise IncompatibleRoomVersionError(room_version=room_version)
- pdu = yield self.handler.on_make_join_request(room_id, user_id)
+ pdu = yield self.handler.on_make_join_request(origin, room_id, user_id)
time_now = self._clock.time_msec()
defer.returnValue(
{"event": pdu.get_pdu_json(time_now), "room_version": room_version}
@@ -423,7 +423,7 @@ class FederationServer(FederationBase):
def on_make_leave_request(self, origin, room_id, user_id):
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, room_id)
- pdu = yield self.handler.on_make_leave_request(room_id, user_id)
+ pdu = yield self.handler.on_make_leave_request(origin, room_id, user_id)
room_version = yield self.store.get_room_version(room_id)
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 7535f792..44edcabe 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -21,9 +21,7 @@ These actions are mostly only used by the :py:mod:`.replication` module.
import logging
-from twisted.internet import defer
-
-from synapse.util.logutils import log_function
+from synapse.logging.utils import log_function
logger = logging.getLogger(__name__)
@@ -63,33 +61,3 @@ class TransactionActions(object):
return self.store.set_received_txn_response(
transaction.transaction_id, origin, code, response
)
-
- @defer.inlineCallbacks
- @log_function
- def prepare_to_send(self, transaction):
- """ Persists the `Transaction` we are about to send and works out the
- correct value for the `prev_ids` key.
-
- Returns:
- Deferred
- """
- transaction.prev_ids = yield self.store.prep_send_transaction(
- transaction.transaction_id,
- transaction.destination,
- transaction.origin_server_ts,
- )
-
- @log_function
- def delivered(self, transaction, response_code, response_dict):
- """ Marks the given `Transaction` as having been successfully
- delivered to the remote homeserver, and what the response was.
-
- Returns:
- Deferred
- """
- return self.store.delivered_txn(
- transaction.transaction_id,
- transaction.destination,
- response_code,
- response_dict,
- )
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 766c5a37..d46f4aae 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -26,6 +26,11 @@ from synapse.federation.sender.per_destination_queue import PerDestinationQueue
from synapse.federation.sender.transaction_manager import TransactionManager
from synapse.federation.units import Edu
from synapse.handlers.presence import get_interested_remotes
+from synapse.logging.context import (
+ make_deferred_yieldable,
+ preserve_fn,
+ run_in_background,
+)
from synapse.metrics import (
LaterGauge,
event_processing_loop_counter,
@@ -33,7 +38,6 @@ from synapse.metrics import (
events_processed_counter,
)
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util import logcontext
from synapse.util.metrics import measure_func
logger = logging.getLogger(__name__)
@@ -210,10 +214,10 @@ class FederationSender(object):
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)
- yield logcontext.make_deferred_yieldable(
+ yield make_deferred_yieldable(
defer.gatherResults(
[
- logcontext.run_in_background(handle_room_events, evs)
+ run_in_background(handle_room_events, evs)
for evs in itervalues(events_by_room)
],
consumeErrors=True,
@@ -360,7 +364,7 @@ class FederationSender(object):
for queue in queues:
queue.flush_read_receipts_for_room(room_id)
- @logcontext.preserve_fn # the caller should not yield on this
+ @preserve_fn # the caller should not yield on this
@defer.inlineCallbacks
def send_presence(self, states):
"""Send the new presence states to the appropriate destinations.
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index c987bb9a..0460a8c4 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -63,8 +63,6 @@ class TransactionManager(object):
len(edus),
)
- logger.debug("TX [%s] Persisting transaction...", destination)
-
transaction = Transaction.create_new(
origin_server_ts=int(self.clock.time_msec()),
transaction_id=txn_id,
@@ -76,9 +74,6 @@ class TransactionManager(object):
self._next_txn_id += 1
- yield self._transaction_actions.prepare_to_send(transaction)
-
- logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
"TX [%s] {%s} Sending transaction [%s]," " (PDUs: %d, EDUs: %d)",
destination,
@@ -118,10 +113,6 @@ class TransactionManager(object):
logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
- yield self._transaction_actions.delivered(transaction, code, response)
-
- logger.debug("TX [%s] {%s} Marked as delivered", destination, txn_id)
-
if code == 200:
for e_id, r in response.get("pdus", {}).items():
if "error" in r:
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index aecd1423..1aae9ec9 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -22,7 +22,7 @@ from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX
-from synapse.util.logutils import log_function
+from synapse.logging.utils import log_function
logger = logging.getLogger(__name__)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 955f0f43..ea4e1b6d 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,9 +19,8 @@ import functools
import logging
import re
-from twisted.internet import defer
-
import synapse
+import synapse.logging.opentracing as opentracing
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.api.urls import (
@@ -36,8 +36,8 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string_from_args,
)
+from synapse.logging.context import run_in_background
from synapse.types import ThirdPartyInstanceID, get_domain_from_id
-from synapse.util.logcontext import run_in_background
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string
@@ -102,8 +102,7 @@ class Authenticator(object):
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
# A method just so we can pass 'self' as the authenticator to the Servlets
- @defer.inlineCallbacks
- def authenticate_request(self, request, content):
+ async def authenticate_request(self, request, content):
now = self._clock.time_msec()
json_request = {
"method": request.method.decode("ascii"),
@@ -141,7 +140,7 @@ class Authenticator(object):
401, "Missing Authorization headers", Codes.UNAUTHORIZED
)
- yield self.keyring.verify_json_for_server(
+ await self.keyring.verify_json_for_server(
origin, json_request, now, "Incoming request"
)
@@ -150,17 +149,16 @@ class Authenticator(object):
        # If we get a valid signed request from the other side, it's probably
# alive
- retry_timings = yield self.store.get_destination_retry_timings(origin)
+ retry_timings = await self.store.get_destination_retry_timings(origin)
if retry_timings and retry_timings["retry_last_ts"]:
run_in_background(self._reset_retry_timings, origin)
- defer.returnValue(origin)
+ return origin
- @defer.inlineCallbacks
- def _reset_retry_timings(self, origin):
+ async def _reset_retry_timings(self, origin):
try:
logger.info("Marking origin %r as up", origin)
- yield self.store.set_destination_retry_timings(origin, 0, 0)
+ await self.store.set_destination_retry_timings(origin, 0, 0)
except Exception:
logger.exception("Error resetting retry timings on %s", origin)
@@ -214,7 +212,8 @@ class BaseFederationServlet(object):
match against the request path (excluding the /federation/v1 prefix).
The servlet should also implement one or more of on_GET, on_POST, on_PUT, to match
- the appropriate HTTP method. These methods have the signature:
+ the appropriate HTTP method. These methods must be *asynchronous* and have the
+ signature:
on_<METHOD>(self, origin, content, query, **kwargs)
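Under the new convention, a servlet might look like the following hedged sketch (the path and handler method are invented for illustration): the handler is declared async, awaits its work directly, and returns the (code, response) tuple without defer.returnValue():

    class ExampleFederationServlet(BaseFederationServlet):
        PATH = "/example/(?P<room_id>[^/]*)"

        async def on_GET(self, origin, content, query, room_id):
            # self.handler.do_example() is a hypothetical coroutine.
            result = await self.handler.do_example(origin, room_id)
            return 200, result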
@@ -234,7 +233,7 @@ class BaseFederationServlet(object):
components as specified in the path match regexp.
Returns:
- Deferred[(int, object)|None]: either (response code, response object) to
+ Optional[Tuple[int, object]]: either (response code, response object) to
return a JSON response, or None if the request has already been handled.
Raises:
@@ -257,10 +256,9 @@ class BaseFederationServlet(object):
authenticator = self.authenticator
ratelimiter = self.ratelimiter
- @defer.inlineCallbacks
@functools.wraps(func)
- def new_func(request, *args, **kwargs):
- """ A callback which can be passed to HttpServer.RegisterPaths
+ async def new_func(request, *args, **kwargs):
+ """A callback which can be passed to HttpServer.RegisterPaths
Args:
request (twisted.web.http.Request):
@@ -269,8 +267,8 @@ class BaseFederationServlet(object):
components as specified in the path match regexp.
Returns:
- Deferred[(int, object)|None]: (response code, response object) as returned
- by the callback method. None if the request has already been handled.
+ Tuple[int, object]|None: (response code, response object) as returned by
+ the callback method. None if the request has already been handled.
"""
content = None
if request.method in [b"PUT", b"POST"]:
@@ -278,7 +276,7 @@ class BaseFederationServlet(object):
content = parse_json_object_from_request(request)
try:
- origin = yield authenticator.authenticate_request(request, content)
+ origin = await authenticator.authenticate_request(request, content)
except NoAuthenticationError:
origin = None
if self.REQUIRE_AUTH:
@@ -288,16 +286,31 @@ class BaseFederationServlet(object):
logger.warn("authenticate_request failed: %s", e)
raise
- if origin:
- with ratelimiter.ratelimit(origin) as d:
- yield d
- response = yield func(
+ # Start an opentracing span
+ with opentracing.start_active_span_from_context(
+ request.requestHeaders,
+ "incoming-federation-request",
+ tags={
+ "request_id": request.get_request_id(),
+ opentracing.tags.SPAN_KIND: opentracing.tags.SPAN_KIND_RPC_SERVER,
+ opentracing.tags.HTTP_METHOD: request.get_method(),
+ opentracing.tags.HTTP_URL: request.get_redacted_uri(),
+ opentracing.tags.PEER_HOST_IPV6: request.getClientIP(),
+ "authenticated_entity": origin,
+ },
+ ):
+ if origin:
+ with ratelimiter.ratelimit(origin) as d:
+ await d
+ response = await func(
+ origin, content, request.args, *args, **kwargs
+ )
+ else:
+ response = await func(
origin, content, request.args, *args, **kwargs
)
- else:
- response = yield func(origin, content, request.args, *args, **kwargs)
- defer.returnValue(response)
+ return response
# Extra logic that functools.wraps() doesn't finish
new_func.__self__ = func.__self__
@@ -312,7 +325,9 @@ class BaseFederationServlet(object):
if code is None:
continue
- server.register_paths(method, (pattern,), self._wrap(code))
+ server.register_paths(
+ method, (pattern,), self._wrap(code), self.__class__.__name__
+ )
class FederationSendServlet(BaseFederationServlet):
@@ -325,8 +340,7 @@ class FederationSendServlet(BaseFederationServlet):
self.server_name = server_name
# This is when someone is trying to send us a bunch of data.
- @defer.inlineCallbacks
- def on_PUT(self, origin, content, query, transaction_id):
+ async def on_PUT(self, origin, content, query, transaction_id):
""" Called on PUT /send/<transaction_id>/
Args:
@@ -335,7 +349,7 @@ class FederationSendServlet(BaseFederationServlet):
request. This is *not* None.
Returns:
- Deferred: Results in a tuple of `(code, response)`, where
+ Tuple of `(code, response)`, where
`response` is a python dict to be converted into JSON that is
used as the response body.
"""
@@ -364,34 +378,33 @@ class FederationSendServlet(BaseFederationServlet):
except Exception as e:
logger.exception(e)
- defer.returnValue((400, {"error": "Invalid transaction"}))
- return
+ return 400, {"error": "Invalid transaction"}
try:
- code, response = yield self.handler.on_incoming_transaction(
+ code, response = await self.handler.on_incoming_transaction(
origin, transaction_data
)
except Exception:
logger.exception("on_incoming_transaction failed")
raise
- defer.returnValue((code, response))
+ return code, response
class FederationEventServlet(BaseFederationServlet):
PATH = "/event/(?P<event_id>[^/]*)/?"
# This is when someone asks for a data item for a given server data_id pair.
- def on_GET(self, origin, content, query, event_id):
- return self.handler.on_pdu_request(origin, event_id)
+ async def on_GET(self, origin, content, query, event_id):
+ return await self.handler.on_pdu_request(origin, event_id)
class FederationStateServlet(BaseFederationServlet):
PATH = "/state/(?P<context>[^/]*)/?"
# This is when someone asks for all data for a given context.
- def on_GET(self, origin, content, query, context):
- return self.handler.on_context_state_request(
+ async def on_GET(self, origin, content, query, context):
+ return await self.handler.on_context_state_request(
origin,
context,
parse_string_from_args(query, "event_id", None, required=True),
@@ -401,8 +414,8 @@ class FederationStateServlet(BaseFederationServlet):
class FederationStateIdsServlet(BaseFederationServlet):
PATH = "/state_ids/(?P<room_id>[^/]*)/?"
- def on_GET(self, origin, content, query, room_id):
- return self.handler.on_state_ids_request(
+ async def on_GET(self, origin, content, query, room_id):
+ return await self.handler.on_state_ids_request(
origin,
room_id,
parse_string_from_args(query, "event_id", None, required=True),
@@ -412,22 +425,22 @@ class FederationStateIdsServlet(BaseFederationServlet):
class FederationBackfillServlet(BaseFederationServlet):
PATH = "/backfill/(?P<context>[^/]*)/?"
- def on_GET(self, origin, content, query, context):
+ async def on_GET(self, origin, content, query, context):
versions = [x.decode("ascii") for x in query[b"v"]]
limit = parse_integer_from_args(query, "limit", None)
if not limit:
- return defer.succeed((400, {"error": "Did not include limit param"}))
+ return 400, {"error": "Did not include limit param"}
- return self.handler.on_backfill_request(origin, context, versions, limit)
+ return await self.handler.on_backfill_request(origin, context, versions, limit)
class FederationQueryServlet(BaseFederationServlet):
PATH = "/query/(?P<query_type>[^/]*)"
# This is when we receive a server-server Query
- def on_GET(self, origin, content, query, query_type):
- return self.handler.on_query_request(
+ async def on_GET(self, origin, content, query, query_type):
+ return await self.handler.on_query_request(
query_type,
{k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()},
)
@@ -436,8 +449,7 @@ class FederationQueryServlet(BaseFederationServlet):
class FederationMakeJoinServlet(BaseFederationServlet):
PATH = "/make_join/(?P<context>[^/]*)/(?P<user_id>[^/]*)"
- @defer.inlineCallbacks
- def on_GET(self, origin, _content, query, context, user_id):
+ async def on_GET(self, origin, _content, query, context, user_id):
"""
Args:
origin (unicode): The authenticated server_name of the calling server
@@ -450,8 +462,7 @@ class FederationMakeJoinServlet(BaseFederationServlet):
components as specified in the path match regexp.
Returns:
- Deferred[(int, object)|None]: either (response code, response object) to
- return a JSON response, or None if the request has already been handled.
+ Tuple[int, object]: (response code, response object)
"""
versions = query.get(b"ver")
if versions is not None:
@@ -459,64 +470,60 @@ class FederationMakeJoinServlet(BaseFederationServlet):
else:
supported_versions = ["1"]
- content = yield self.handler.on_make_join_request(
+ content = await self.handler.on_make_join_request(
origin, context, user_id, supported_versions=supported_versions
)
- defer.returnValue((200, content))
+ return 200, content
class FederationMakeLeaveServlet(BaseFederationServlet):
PATH = "/make_leave/(?P<context>[^/]*)/(?P<user_id>[^/]*)"
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query, context, user_id):
- content = yield self.handler.on_make_leave_request(origin, context, user_id)
- defer.returnValue((200, content))
+ async def on_GET(self, origin, content, query, context, user_id):
+ content = await self.handler.on_make_leave_request(origin, context, user_id)
+ return 200, content
class FederationSendLeaveServlet(BaseFederationServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
- @defer.inlineCallbacks
- def on_PUT(self, origin, content, query, room_id, event_id):
- content = yield self.handler.on_send_leave_request(origin, content, room_id)
- defer.returnValue((200, content))
+ async def on_PUT(self, origin, content, query, room_id, event_id):
+ content = await self.handler.on_send_leave_request(origin, content, room_id)
+ return 200, content
class FederationEventAuthServlet(BaseFederationServlet):
PATH = "/event_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
- def on_GET(self, origin, content, query, context, event_id):
- return self.handler.on_event_auth(origin, context, event_id)
+ async def on_GET(self, origin, content, query, context, event_id):
+ return await self.handler.on_event_auth(origin, context, event_id)
class FederationSendJoinServlet(BaseFederationServlet):
PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
- @defer.inlineCallbacks
- def on_PUT(self, origin, content, query, context, event_id):
+ async def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
# match those given in content
- content = yield self.handler.on_send_join_request(origin, content, context)
- defer.returnValue((200, content))
+ content = await self.handler.on_send_join_request(origin, content, context)
+ return 200, content
class FederationV1InviteServlet(BaseFederationServlet):
PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
- @defer.inlineCallbacks
- def on_PUT(self, origin, content, query, context, event_id):
+ async def on_PUT(self, origin, content, query, context, event_id):
# We don't get a room version, so we have to assume its EITHER v1 or
# v2. This is "fine" as the only difference between V1 and V2 is the
# state resolution algorithm, and we don't use that for processing
# invites
- content = yield self.handler.on_invite_request(
+ content = await self.handler.on_invite_request(
origin, content, room_version=RoomVersions.V1.identifier
)
# V1 federation API is defined to return a content of `[200, {...}]`
# due to a historical bug.
- defer.returnValue((200, (200, content)))
+ return 200, (200, content)
class FederationV2InviteServlet(BaseFederationServlet):
@@ -524,8 +531,7 @@ class FederationV2InviteServlet(BaseFederationServlet):
PREFIX = FEDERATION_V2_PREFIX
- @defer.inlineCallbacks
- def on_PUT(self, origin, content, query, context, event_id):
+ async def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
# match those given in content
@@ -538,69 +544,65 @@ class FederationV2InviteServlet(BaseFederationServlet):
event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state
- content = yield self.handler.on_invite_request(
+ content = await self.handler.on_invite_request(
origin, event, room_version=room_version
)
- defer.returnValue((200, content))
+ return 200, content
class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
- @defer.inlineCallbacks
- def on_PUT(self, origin, content, query, room_id):
- content = yield self.handler.on_exchange_third_party_invite_request(
+ async def on_PUT(self, origin, content, query, room_id):
+ content = await self.handler.on_exchange_third_party_invite_request(
origin, room_id, content
)
- defer.returnValue((200, content))
+ return 200, content
class FederationClientKeysQueryServlet(BaseFederationServlet):
PATH = "/user/keys/query"
- def on_POST(self, origin, content, query):
- return self.handler.on_query_client_keys(origin, content)
+ async def on_POST(self, origin, content, query):
+ return await self.handler.on_query_client_keys(origin, content)
class FederationUserDevicesQueryServlet(BaseFederationServlet):
PATH = "/user/devices/(?P<user_id>[^/]*)"
- def on_GET(self, origin, content, query, user_id):
- return self.handler.on_query_user_devices(origin, user_id)
+ async def on_GET(self, origin, content, query, user_id):
+ return await self.handler.on_query_user_devices(origin, user_id)
class FederationClientKeysClaimServlet(BaseFederationServlet):
PATH = "/user/keys/claim"
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query):
- response = yield self.handler.on_claim_client_keys(origin, content)
- defer.returnValue((200, response))
+ async def on_POST(self, origin, content, query):
+ response = await self.handler.on_claim_client_keys(origin, content)
+ return 200, response
class FederationQueryAuthServlet(BaseFederationServlet):
PATH = "/query_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, context, event_id):
- new_content = yield self.handler.on_query_auth_request(
+ async def on_POST(self, origin, content, query, context, event_id):
+ new_content = await self.handler.on_query_auth_request(
origin, content, context, event_id
)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGetMissingEventsServlet(BaseFederationServlet):
# TODO(paul): Why does this path alone end with "/?" optional?
PATH = "/get_missing_events/(?P<room_id>[^/]*)/?"
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, room_id):
+ async def on_POST(self, origin, content, query, room_id):
limit = int(content.get("limit", 10))
earliest_events = content.get("earliest_events", [])
latest_events = content.get("latest_events", [])
- content = yield self.handler.on_get_missing_events(
+ content = await self.handler.on_get_missing_events(
origin,
room_id=room_id,
earliest_events=earliest_events,
@@ -608,7 +610,7 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
limit=limit,
)
- defer.returnValue((200, content))
+ return 200, content
class On3pidBindServlet(BaseFederationServlet):
@@ -616,8 +618,7 @@ class On3pidBindServlet(BaseFederationServlet):
REQUIRE_AUTH = False
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query):
+ async def on_POST(self, origin, content, query):
if "invites" in content:
last_exception = None
for invite in content["invites"]:
@@ -629,7 +630,7 @@ class On3pidBindServlet(BaseFederationServlet):
)
logger.info(message)
raise SynapseError(400, message)
- yield self.handler.exchange_third_party_invite(
+ await self.handler.exchange_third_party_invite(
invite["sender"],
invite["mxid"],
invite["room_id"],
@@ -639,7 +640,7 @@ class On3pidBindServlet(BaseFederationServlet):
last_exception = e
if last_exception:
raise last_exception
- defer.returnValue((200, {}))
+ return 200, {}
class OpenIdUserInfo(BaseFederationServlet):
@@ -663,29 +664,26 @@ class OpenIdUserInfo(BaseFederationServlet):
REQUIRE_AUTH = False
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query):
+ async def on_GET(self, origin, content, query):
token = query.get(b"access_token", [None])[0]
if token is None:
- defer.returnValue(
- (401, {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"})
+ return (
+ 401,
+ {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"},
)
- return
- user_id = yield self.handler.on_openid_userinfo(token.decode("ascii"))
+ user_id = await self.handler.on_openid_userinfo(token.decode("ascii"))
if user_id is None:
- defer.returnValue(
- (
- 401,
- {
- "errcode": "M_UNKNOWN_TOKEN",
- "error": "Access Token unknown or expired",
- },
- )
+ return (
+ 401,
+ {
+ "errcode": "M_UNKNOWN_TOKEN",
+ "error": "Access Token unknown or expired",
+ },
)
- defer.returnValue((200, {"sub": user_id}))
+ return 200, {"sub": user_id}
class PublicRoomList(BaseFederationServlet):
@@ -727,8 +725,7 @@ class PublicRoomList(BaseFederationServlet):
)
self.allow_access = allow_access
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query):
+ async def on_GET(self, origin, content, query):
if not self.allow_access:
raise FederationDeniedError(origin)
@@ -748,10 +745,10 @@ class PublicRoomList(BaseFederationServlet):
else:
network_tuple = ThirdPartyInstanceID(None, None)
- data = yield self.handler.get_local_public_room_list(
+ data = await self.handler.get_local_public_room_list(
limit, since_token, network_tuple=network_tuple, from_federation=True
)
- defer.returnValue((200, data))
+ return 200, data
class FederationVersionServlet(BaseFederationServlet):
@@ -759,12 +756,10 @@ class FederationVersionServlet(BaseFederationServlet):
REQUIRE_AUTH = False
- def on_GET(self, origin, content, query):
- return defer.succeed(
- (
- 200,
- {"server": {"name": "Synapse", "version": get_version_string(synapse)}},
- )
+ async def on_GET(self, origin, content, query):
+ return (
+ 200,
+ {"server": {"name": "Synapse", "version": get_version_string(synapse)}},
)
@@ -774,41 +769,38 @@ class FederationGroupsProfileServlet(BaseFederationServlet):
PATH = "/groups/(?P<group_id>[^/]*)/profile"
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query, group_id):
+ async def on_GET(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.get_group_profile(group_id, requester_user_id)
+ new_content = await self.handler.get_group_profile(group_id, requester_user_id)
- defer.returnValue((200, new_content))
+ return 200, new_content
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id):
+ async def on_POST(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.update_group_profile(
+ new_content = await self.handler.update_group_profile(
group_id, requester_user_id, content
)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsSummaryServlet(BaseFederationServlet):
PATH = "/groups/(?P<group_id>[^/]*)/summary"
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query, group_id):
+ async def on_GET(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.get_group_summary(group_id, requester_user_id)
+ new_content = await self.handler.get_group_summary(group_id, requester_user_id)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsRoomsServlet(BaseFederationServlet):
@@ -817,15 +809,14 @@ class FederationGroupsRoomsServlet(BaseFederationServlet):
PATH = "/groups/(?P<group_id>[^/]*)/rooms"
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query, group_id):
+ async def on_GET(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.get_rooms_in_group(group_id, requester_user_id)
+ new_content = await self.handler.get_rooms_in_group(group_id, requester_user_id)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsAddRoomsServlet(BaseFederationServlet):
@@ -834,29 +825,27 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet):
PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, room_id):
+ async def on_POST(self, origin, content, query, group_id, room_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.add_room_to_group(
+ new_content = await self.handler.add_room_to_group(
group_id, requester_user_id, room_id, content
)
- defer.returnValue((200, new_content))
+ return 200, new_content
- @defer.inlineCallbacks
- def on_DELETE(self, origin, content, query, group_id, room_id):
+ async def on_DELETE(self, origin, content, query, group_id, room_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.remove_room_from_group(
+ new_content = await self.handler.remove_room_from_group(
group_id, requester_user_id, room_id
)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
@@ -868,17 +857,16 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
"/config/(?P<config_key>[^/]*)"
)
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, room_id, config_key):
+ async def on_POST(self, origin, content, query, group_id, room_id, config_key):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- result = yield self.groups_handler.update_room_in_group(
+ result = await self.groups_handler.update_room_in_group(
group_id, requester_user_id, room_id, config_key, content
)
- defer.returnValue((200, result))
+ return 200, result
class FederationGroupsUsersServlet(BaseFederationServlet):
@@ -887,15 +875,14 @@ class FederationGroupsUsersServlet(BaseFederationServlet):
PATH = "/groups/(?P<group_id>[^/]*)/users"
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query, group_id):
+ async def on_GET(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.get_users_in_group(group_id, requester_user_id)
+ new_content = await self.handler.get_users_in_group(group_id, requester_user_id)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
@@ -904,17 +891,16 @@ class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
PATH = "/groups/(?P<group_id>[^/]*)/invited_users"
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query, group_id):
+ async def on_GET(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.get_invited_users_in_group(
+ new_content = await self.handler.get_invited_users_in_group(
group_id, requester_user_id
)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsInviteServlet(BaseFederationServlet):
@@ -923,17 +909,16 @@ class FederationGroupsInviteServlet(BaseFederationServlet):
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(self, origin, content, query, group_id, user_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.invite_to_group(
+ new_content = await self.handler.invite_to_group(
group_id, user_id, requester_user_id, content
)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
@@ -942,14 +927,13 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite"
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(self, origin, content, query, group_id, user_id):
if get_domain_from_id(user_id) != origin:
raise SynapseError(403, "user_id doesn't match origin")
- new_content = yield self.handler.accept_invite(group_id, user_id, content)
+ new_content = await self.handler.accept_invite(group_id, user_id, content)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsJoinServlet(BaseFederationServlet):
@@ -958,14 +942,13 @@ class FederationGroupsJoinServlet(BaseFederationServlet):
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join"
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(self, origin, content, query, group_id, user_id):
if get_domain_from_id(user_id) != origin:
raise SynapseError(403, "user_id doesn't match origin")
- new_content = yield self.handler.join_group(group_id, user_id, content)
+ new_content = await self.handler.join_group(group_id, user_id, content)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsRemoveUserServlet(BaseFederationServlet):
@@ -974,17 +957,16 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet):
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(self, origin, content, query, group_id, user_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.remove_user_from_group(
+ new_content = await self.handler.remove_user_from_group(
group_id, user_id, requester_user_id, content
)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsLocalInviteServlet(BaseFederationServlet):
@@ -993,14 +975,13 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet):
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(self, origin, content, query, group_id, user_id):
if get_domain_from_id(group_id) != origin:
raise SynapseError(403, "group_id doesn't match origin")
- new_content = yield self.handler.on_invite(group_id, user_id, content)
+ new_content = await self.handler.on_invite(group_id, user_id, content)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
@@ -1009,16 +990,15 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(self, origin, content, query, group_id, user_id):
if get_domain_from_id(group_id) != origin:
            raise SynapseError(403, "group_id doesn't match origin")
- new_content = yield self.handler.user_removed_from_group(
+ new_content = await self.handler.user_removed_from_group(
group_id, user_id, content
)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
@@ -1027,15 +1007,14 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)"
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(self, origin, content, query, group_id, user_id):
# We don't need to check auth here as we check the attestation signatures
- new_content = yield self.handler.on_renew_attestation(
+ new_content = await self.handler.on_renew_attestation(
group_id, user_id, content
)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
@@ -1052,8 +1031,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
"/rooms/(?P<room_id>[^/]*)"
)
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, category_id, room_id):
+ async def on_POST(self, origin, content, query, group_id, category_id, room_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1061,7 +1039,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
if category_id == "":
raise SynapseError(400, "category_id cannot be empty string")
- resp = yield self.handler.update_group_summary_room(
+ resp = await self.handler.update_group_summary_room(
group_id,
requester_user_id,
room_id=room_id,
@@ -1069,10 +1047,9 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
content=content,
)
- defer.returnValue((200, resp))
+ return 200, resp
- @defer.inlineCallbacks
- def on_DELETE(self, origin, content, query, group_id, category_id, room_id):
+ async def on_DELETE(self, origin, content, query, group_id, category_id, room_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1080,11 +1057,11 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
if category_id == "":
raise SynapseError(400, "category_id cannot be empty string")
- resp = yield self.handler.delete_group_summary_room(
+ resp = await self.handler.delete_group_summary_room(
group_id, requester_user_id, room_id=room_id, category_id=category_id
)
- defer.returnValue((200, resp))
+ return 200, resp
class FederationGroupsCategoriesServlet(BaseFederationServlet):
@@ -1093,15 +1070,14 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet):
PATH = "/groups/(?P<group_id>[^/]*)/categories/?"
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query, group_id):
+ async def on_GET(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- resp = yield self.handler.get_group_categories(group_id, requester_user_id)
+ resp = await self.handler.get_group_categories(group_id, requester_user_id)
- defer.returnValue((200, resp))
+ return 200, resp
class FederationGroupsCategoryServlet(BaseFederationServlet):
@@ -1110,20 +1086,18 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query, group_id, category_id):
+ async def on_GET(self, origin, content, query, group_id, category_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- resp = yield self.handler.get_group_category(
+ resp = await self.handler.get_group_category(
group_id, requester_user_id, category_id
)
- defer.returnValue((200, resp))
+ return 200, resp
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, category_id):
+ async def on_POST(self, origin, content, query, group_id, category_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1131,14 +1105,13 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
if category_id == "":
raise SynapseError(400, "category_id cannot be empty string")
- resp = yield self.handler.upsert_group_category(
+ resp = await self.handler.upsert_group_category(
group_id, requester_user_id, category_id, content
)
- defer.returnValue((200, resp))
+ return 200, resp
- @defer.inlineCallbacks
- def on_DELETE(self, origin, content, query, group_id, category_id):
+ async def on_DELETE(self, origin, content, query, group_id, category_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1146,11 +1119,11 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
if category_id == "":
raise SynapseError(400, "category_id cannot be empty string")
- resp = yield self.handler.delete_group_category(
+ resp = await self.handler.delete_group_category(
group_id, requester_user_id, category_id
)
- defer.returnValue((200, resp))
+ return 200, resp
class FederationGroupsRolesServlet(BaseFederationServlet):
@@ -1159,15 +1132,14 @@ class FederationGroupsRolesServlet(BaseFederationServlet):
PATH = "/groups/(?P<group_id>[^/]*)/roles/?"
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query, group_id):
+ async def on_GET(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- resp = yield self.handler.get_group_roles(group_id, requester_user_id)
+ resp = await self.handler.get_group_roles(group_id, requester_user_id)
- defer.returnValue((200, resp))
+ return 200, resp
class FederationGroupsRoleServlet(BaseFederationServlet):
@@ -1176,18 +1148,16 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query, group_id, role_id):
+ async def on_GET(self, origin, content, query, group_id, role_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- resp = yield self.handler.get_group_role(group_id, requester_user_id, role_id)
+ resp = await self.handler.get_group_role(group_id, requester_user_id, role_id)
- defer.returnValue((200, resp))
+ return 200, resp
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, role_id):
+ async def on_POST(self, origin, content, query, group_id, role_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1195,14 +1165,13 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
if role_id == "":
raise SynapseError(400, "role_id cannot be empty string")
- resp = yield self.handler.update_group_role(
+ resp = await self.handler.update_group_role(
group_id, requester_user_id, role_id, content
)
- defer.returnValue((200, resp))
+ return 200, resp
- @defer.inlineCallbacks
- def on_DELETE(self, origin, content, query, group_id, role_id):
+ async def on_DELETE(self, origin, content, query, group_id, role_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1210,11 +1179,11 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
if role_id == "":
raise SynapseError(400, "role_id cannot be empty string")
- resp = yield self.handler.delete_group_role(
+ resp = await self.handler.delete_group_role(
group_id, requester_user_id, role_id
)
- defer.returnValue((200, resp))
+ return 200, resp
class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
@@ -1231,8 +1200,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
"/users/(?P<user_id>[^/]*)"
)
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, role_id, user_id):
+ async def on_POST(self, origin, content, query, group_id, role_id, user_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1240,7 +1208,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
if role_id == "":
raise SynapseError(400, "role_id cannot be empty string")
- resp = yield self.handler.update_group_summary_user(
+ resp = await self.handler.update_group_summary_user(
group_id,
requester_user_id,
user_id=user_id,
@@ -1248,10 +1216,9 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
content=content,
)
- defer.returnValue((200, resp))
+ return 200, resp
- @defer.inlineCallbacks
- def on_DELETE(self, origin, content, query, group_id, role_id, user_id):
+ async def on_DELETE(self, origin, content, query, group_id, role_id, user_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1259,11 +1226,11 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
if role_id == "":
raise SynapseError(400, "role_id cannot be empty string")
- resp = yield self.handler.delete_group_summary_user(
+ resp = await self.handler.delete_group_summary_user(
group_id, requester_user_id, user_id=user_id, role_id=role_id
)
- defer.returnValue((200, resp))
+ return 200, resp
class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
@@ -1272,13 +1239,12 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
PATH = "/get_groups_publicised"
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query):
- resp = yield self.handler.bulk_get_publicised_groups(
+ async def on_POST(self, origin, content, query):
+ resp = await self.handler.bulk_get_publicised_groups(
content["user_ids"], proxy=False
)
- defer.returnValue((200, resp))
+ return 200, resp
class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
@@ -1287,17 +1253,16 @@ class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy"
- @defer.inlineCallbacks
- def on_PUT(self, origin, content, query, group_id):
+ async def on_PUT(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.set_group_join_policy(
+ new_content = await self.handler.set_group_join_policy(
group_id, requester_user_id, content
)
- defer.returnValue((200, new_content))
+ return 200, new_content
class RoomComplexityServlet(BaseFederationServlet):
@@ -1309,18 +1274,17 @@ class RoomComplexityServlet(BaseFederationServlet):
PATH = "/rooms/(?P<room_id>[^/]*)/complexity"
PREFIX = FEDERATION_UNSTABLE_PREFIX
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query, room_id):
+ async def on_GET(self, origin, content, query, room_id):
store = self.handler.hs.get_datastore()
- is_public = yield store.is_room_world_readable_or_publicly_joinable(room_id)
+ is_public = await store.is_room_world_readable_or_publicly_joinable(room_id)
if not is_public:
raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM)
- complexity = yield store.get_room_complexity(room_id)
- defer.returnValue((200, complexity))
+ complexity = await store.get_room_complexity(room_id)
+ return 200, complexity
FEDERATION_SERVLET_CLASSES = (
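
All of the servlet hunks above apply one mechanical transformation: `@defer.inlineCallbacks` generators that `yield` Deferreds and exit through `defer.returnValue(...)` become native coroutines that `await` and `return` directly. A minimal before/after sketch of the pattern (the handler body is illustrative, not taken from this diff):

    from twisted.internet import defer

    class ExampleServlet(object):
        # Before: a Twisted inlineCallbacks generator.
        @defer.inlineCallbacks
        def on_GET_old(self, origin, content, query, group_id):
            resp = yield self.handler.get_group_profile(group_id, origin)
            defer.returnValue((200, resp))

        # After: a native coroutine. Twisted can still treat it as a
        # Deferred via defer.ensureDeferred() at the call site.
        async def on_GET(self, origin, content, query, group_id):
            resp = await self.handler.get_group_profile(group_id, origin)
            return 200, resp
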
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index e7375757..f4977111 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -43,9 +43,9 @@ from signedjson.sign import sign_json
from twisted.internet import defer
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
+from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import get_domain_from_id
-from synapse.util.logcontext import run_in_background
logger = logging.getLogger(__name__)
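
This hunk is the first of many identical ones in this release: the log-context helpers move from `synapse.util.logcontext` to the new `synapse.logging.context` package (and `synapse.util.logutils` becomes `synapse.logging.utils`). Only the import path changes; the call sites are untouched:

    # Old location, removed throughout this diff:
    #   from synapse.util.logcontext import run_in_background
    # New canonical location:
    from synapse.logging.context import run_in_background
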
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index edb48054..1f1708ba 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -22,10 +22,10 @@ from email.mime.text import MIMEText
from twisted.internet import defer
from synapse.api.errors import StoreError
+from synapse.logging.context import make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID
from synapse.util import stringutils
-from synapse.util.logcontext import make_deferred_yieldable
try:
from synapse.push.mailer import load_jinja2_templates
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 941ebfa1..e8a651e2 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -17,6 +17,10 @@ import logging
from twisted.internet import defer
+from synapse.api.constants import Membership
+from synapse.types import RoomStreamToken
+from synapse.visibility import filter_events_for_client
+
from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -89,3 +93,182 @@ class AdminHandler(BaseHandler):
ret = yield self.store.search_users(term)
defer.returnValue(ret)
+
+ @defer.inlineCallbacks
+ def export_user_data(self, user_id, writer):
+ """Write all data we have on the user to the given writer.
+
+ Args:
+ user_id (str)
+ writer (ExfiltrationWriter)
+
+ Returns:
+ defer.Deferred: Resolves when all data for a user has been written.
+ The returned value is that returned by `writer.finished()`.
+ """
+ # Get all rooms the user is in or has been in
+ rooms = yield self.store.get_rooms_for_user_where_membership_is(
+ user_id,
+ membership_list=(
+ Membership.JOIN,
+ Membership.LEAVE,
+ Membership.BAN,
+ Membership.INVITE,
+ ),
+ )
+
+ # We only try and fetch events for rooms the user has been in. If
+ # they've been e.g. invited to a room without joining then we handle
+ # those separately.
+ rooms_user_has_been_in = yield self.store.get_rooms_user_has_been_in(user_id)
+
+ for index, room in enumerate(rooms):
+ room_id = room.room_id
+
+ logger.info(
+ "[%s] Handling room %s, %d/%d", user_id, room_id, index + 1, len(rooms)
+ )
+
+ forgotten = yield self.store.did_forget(user_id, room_id)
+ if forgotten:
+ logger.info("[%s] User forgot room %d, ignoring", user_id, room_id)
+ continue
+
+ if room_id not in rooms_user_has_been_in:
+ # If we haven't been in the rooms then the filtering code below
+ # won't return anything, so we need to handle these cases
+ # explicitly.
+
+ if room.membership == Membership.INVITE:
+ event_id = room.event_id
+ invite = yield self.store.get_event(event_id, allow_none=True)
+ if invite:
+ invited_state = invite.unsigned["invite_room_state"]
+ writer.write_invite(room_id, invite, invited_state)
+
+ continue
+
+ # We only want to bother fetching events up to the last time they
+ # were joined. We estimate that point by looking at the
+ # stream_ordering of the last membership if it wasn't a join.
+ if room.membership == Membership.JOIN:
+ stream_ordering = yield self.store.get_room_max_stream_ordering()
+ else:
+ stream_ordering = room.stream_ordering
+
+ from_key = str(RoomStreamToken(0, 0))
+ to_key = str(RoomStreamToken(None, stream_ordering))
+
+ written_events = set() # Events that we've processed in this room
+
+ # We need to track gaps in the events stream so that we can then
+ # write out the state at those events. We do this by keeping track
+ # of events whose prev events we haven't seen.
+
+ # Map from event ID to prev events that haven't been processed,
+ # dict[str, set[str]].
+ event_to_unseen_prevs = {}
+
+ # The reverse mapping to above, i.e. map from unseen event to events
+ # that have the unseen event in their prev_events, i.e. the unseen
+ # events "children". dict[str, set[str]]
+ unseen_to_child_events = {}
+
+ # We fetch events in the room the user could see by fetching *all*
+ # events that we have and then filtering; this isn't the most
+ # efficient method, but it does guarantee we get everything.
+ while True:
+ events, _ = yield self.store.paginate_room_events(
+ room_id, from_key, to_key, limit=100, direction="f"
+ )
+ if not events:
+ break
+
+ from_key = events[-1].internal_metadata.after
+
+ events = yield filter_events_for_client(self.store, user_id, events)
+
+ writer.write_events(room_id, events)
+
+ # Update the extremity tracking dicts
+ for event in events:
+ # Check if we have any prev events that haven't been
+ # processed yet, and add those to the appropriate dicts.
+ unseen_events = set(event.prev_event_ids()) - written_events
+ if unseen_events:
+ event_to_unseen_prevs[event.event_id] = unseen_events
+ for unseen in unseen_events:
+ unseen_to_child_events.setdefault(unseen, set()).add(
+ event.event_id
+ )
+
+ # Now check if this event is an unseen prev event, if so
+ # then we remove this event from the appropriate dicts.
+ for child_id in unseen_to_child_events.pop(event.event_id, []):
+ event_to_unseen_prevs[child_id].discard(event.event_id)
+
+ written_events.add(event.event_id)
+
+ logger.info(
+ "Written %d events in room %s", len(written_events), room_id
+ )
+
+ # Extremities are the events that have at least one unseen prev event.
+ extremities = (
+ event_id
+ for event_id, unseen_prevs in event_to_unseen_prevs.items()
+ if unseen_prevs
+ )
+ for event_id in extremities:
+ if not event_to_unseen_prevs[event_id]:
+ continue
+ state = yield self.store.get_state_for_event(event_id)
+ writer.write_state(room_id, event_id, state)
+
+ defer.returnValue(writer.finished())
+
+
+class ExfiltrationWriter(object):
+ """Interface used to specify how to write exported data.
+ """
+
+ def write_events(self, room_id, events):
+ """Write a batch of events for a room.
+
+ Args:
+ room_id (str)
+ events (list[FrozenEvent])
+ """
+ pass
+
+ def write_state(self, room_id, event_id, state):
+ """Write the state at the given event in the room.
+
+ This only gets called for backward extremities rather than for each
+ event.
+
+ Args:
+ room_id (str)
+ event_id (str)
+ state (dict[tuple[str, str], FrozenEvent])
+ """
+ pass
+
+ def write_invite(self, room_id, event, state):
+ """Write an invite for the room, with associated invite state.
+
+ Args:
+ room_id (str)
+ event (FrozenEvent)
+ state (dict[tuple[str, str], dict]): A subset of the state at the
+ invite, with a subset of the event keys (type, state_key,
+ content and sender)
+ """
+
+ def finished(self):
+ """Called when all data has succesfully been exported and written.
+
+ This function's return value is passed to the caller of
+ `export_user_data`.
+ """
+ pass
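
`ExfiltrationWriter` is deliberately abstract; the admin command added elsewhere in this release supplies the concrete implementation. As a rough sketch of what a writer can look like, here is a hypothetical implementation that appends JSON lines under a per-room directory (the file layout and the use of `get_pdu_json()` are assumptions for illustration, not part of this diff):

    import json
    import os

    class JsonDirectoryWriter(ExfiltrationWriter):
        def __init__(self, base_dir):
            self.base_dir = base_dir

        def _append(self, room_id, name, data):
            # One directory per room, one JSON object per line.
            room_dir = os.path.join(self.base_dir, room_id)
            os.makedirs(room_dir, exist_ok=True)
            with open(os.path.join(room_dir, name), "a") as f:
                f.write(json.dumps(data) + "\n")

        def write_events(self, room_id, events):
            for event in events:
                self._append(room_id, "events.jsonl", event.get_pdu_json())

        def write_state(self, room_id, event_id, state):
            self._append(
                room_id,
                "state.jsonl",
                {
                    "event_id": event_id,
                    "state": [e.get_pdu_json() for e in state.values()],
                },
            )

        def write_invite(self, room_id, event, state):
            self._append(
                room_id,
                "invites.jsonl",
                {"event": event.get_pdu_json(), "state": list(state.values())},
            )

        def finished(self):
            # Returned to the caller of export_user_data().
            return self.base_dir
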
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 5cc89d43..8f089f0e 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -23,13 +23,13 @@ from twisted.internet import defer
import synapse
from synapse.api.constants import EventTypes
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import (
event_processing_loop_counter,
event_processing_loop_room_count,
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import log_failure
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c8c1ed32..d4d65749 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+import time
import unicodedata
import attr
@@ -34,11 +35,12 @@ from synapse.api.errors import (
LoginError,
StoreError,
SynapseError,
+ UserDeactivatedError,
)
from synapse.api.ratelimiting import Ratelimiter
+from synapse.logging.context import defer_to_thread
from synapse.module_api import ModuleApi
from synapse.types import UserID
-from synapse.util import logcontext
from synapse.util.caches.expiringcache import ExpiringCache
from ._base import BaseHandler
@@ -558,7 +560,7 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id]
@defer.inlineCallbacks
- def get_access_token_for_user_id(self, user_id, device_id=None):
+ def get_access_token_for_user_id(self, user_id, device_id, valid_until_ms):
"""
Creates a new access token for the user with the given user ID.
@@ -572,15 +574,27 @@ class AuthHandler(BaseHandler):
device_id (str|None): the device ID to associate with the tokens.
None to leave the tokens unassociated with a device (deprecated:
we should always have a device ID)
+ valid_until_ms (int|None): when the token is valid until. None for
+ no expiry.
Returns:
The access token for the user's session.
Raises:
StoreError if there was a problem storing the token.
"""
- logger.info("Logging in user %s on device %s", user_id, device_id)
- access_token = yield self.issue_access_token(user_id, device_id)
+ fmt_expiry = ""
+ if valid_until_ms is not None:
+ fmt_expiry = time.strftime(
+ " until %Y-%m-%d %H:%M:%S", time.localtime(valid_until_ms / 1000.0)
+ )
+ logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry)
+
yield self.auth.check_auth_blocking(user_id)
+ access_token = self.macaroon_gen.generate_access_token(user_id)
+ yield self.store.add_access_token_to_user(
+ user_id, access_token, device_id, valid_until_ms
+ )
+
# the device *should* have been registered before we got here; however,
# it's possible we raced against a DELETE operation. The thing we
# really don't want is active access_tokens without a record of the
@@ -610,6 +624,7 @@ class AuthHandler(BaseHandler):
Raises:
LimitExceededError if the ratelimiter's login requests count for this
user is too high to proceed.
+ UserDeactivatedError if a user is found but is deactivated.
"""
self.ratelimit_login_per_account(user_id)
res = yield self._find_user_id_and_pwd_hash(user_id)
@@ -825,6 +840,13 @@ class AuthHandler(BaseHandler):
if not lookupres:
defer.returnValue(None)
(user_id, password_hash) = lookupres
+
+ # If the password hash is None, the account has likely been deactivated
+ if not password_hash:
+ deactivated = yield self.store.get_user_deactivated_status(user_id)
+ if deactivated:
+ raise UserDeactivatedError("This account has been deactivated")
+
result = yield self.validate_hash(password, password_hash)
if not result:
logger.warn("Failed password login for user %s", user_id)
@@ -832,12 +854,6 @@ class AuthHandler(BaseHandler):
defer.returnValue(user_id)
@defer.inlineCallbacks
- def issue_access_token(self, user_id, device_id=None):
- access_token = self.macaroon_gen.generate_access_token(user_id)
- yield self.store.add_access_token_to_user(user_id, access_token, device_id)
- defer.returnValue(access_token)
-
- @defer.inlineCallbacks
def validate_short_term_login_token_and_get_user_id(self, login_token):
auth_api = self.hs.get_auth()
user_id = None
@@ -987,7 +1003,7 @@ class AuthHandler(BaseHandler):
bcrypt.gensalt(self.bcrypt_rounds),
).decode("ascii")
- return logcontext.defer_to_thread(self.hs.get_reactor(), _do_hash)
+ return defer_to_thread(self.hs.get_reactor(), _do_hash)
def validate_hash(self, password, stored_hash):
"""Validates that self.hash(password) == stored_hash.
@@ -1013,7 +1029,7 @@ class AuthHandler(BaseHandler):
if not isinstance(stored_hash, bytes):
stored_hash = stored_hash.encode("ascii")
- return logcontext.defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
+ return defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
else:
return defer.succeed(False)
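
`get_access_token_for_user_id` now takes a mandatory `valid_until_ms`, and the old `issue_access_token` helper is gone. A sketch of issuing a token that expires after 24 hours, assuming `hs` and the auth handler are in scope as they are inside Synapse handlers (the 24-hour figure is illustrative):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def issue_expiring_token(hs, auth_handler, user_id, device_id):
        # Absolute expiry time in milliseconds, or None for no expiry.
        valid_until_ms = hs.get_clock().time_msec() + 24 * 60 * 60 * 1000
        token = yield auth_handler.get_access_token_for_user_id(
            user_id, device_id=device_id, valid_until_ms=valid_until_ms
        )
        defer.returnValue(token)
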
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 807900fe..fdfe8611 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -22,9 +22,9 @@ from canonicaljson import encode_canonical_json, json
from twisted.internet import defer
-from synapse.api.errors import CodeMessageException, FederationDeniedError, SynapseError
+from synapse.api.errors import CodeMessageException, SynapseError
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import UserID, get_domain_from_id
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
@@ -350,9 +350,6 @@ def _exception_to_failure(e):
if isinstance(e, NotRetryingDestination):
return {"status": 503, "message": "Not ready for retry"}
- if isinstance(e, FederationDeniedError):
- return {"status": 403, "message": "Federation Denied"}
-
# include ConnectionRefused and other errors
#
# Note that some Exceptions (notably twisted's ResponseFailed etc) don't
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 5836d3c6..6a38328a 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -21,8 +21,8 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, SynapseError
from synapse.events import EventBase
+from synapse.logging.utils import log_function
from synapse.types import UserID
-from synapse.util.logutils import log_function
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 02d397c4..30b69af8 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -45,6 +45,13 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.crypto.event_signing import compute_event_signature
from synapse.event_auth import auth_types_for_event
from synapse.events.validator import EventValidator
+from synapse.logging.context import (
+ make_deferred_yieldable,
+ nested_logging_context,
+ preserve_fn,
+ run_in_background,
+)
+from synapse.logging.utils import log_function
from synapse.replication.http.federation import (
ReplicationCleanRoomRestServlet,
ReplicationFederationSendEventsRestServlet,
@@ -52,10 +59,9 @@ from synapse.replication.http.federation import (
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store
from synapse.types import UserID, get_domain_from_id
-from synapse.util import logcontext, unwrapFirstError
+from synapse.util import unwrapFirstError
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room
-from synapse.util.logutils import log_function
from synapse.util.retryutils import NotRetryingDestination
from synapse.visibility import filter_events_for_server
@@ -338,7 +344,7 @@ class FederationHandler(BaseHandler):
room_version = yield self.store.get_room_version(room_id)
- with logcontext.nested_logging_context(p):
+ with nested_logging_context(p):
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
@@ -532,7 +538,7 @@ class FederationHandler(BaseHandler):
event_id,
ev.event_id,
)
- with logcontext.nested_logging_context(ev.event_id):
+ with nested_logging_context(ev.event_id):
try:
yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
except FederationError as e:
@@ -725,10 +731,10 @@ class FederationHandler(BaseHandler):
missing_auth - failed_to_fetch,
)
- results = yield logcontext.make_deferred_yieldable(
+ results = yield make_deferred_yieldable(
defer.gatherResults(
[
- logcontext.run_in_background(
+ run_in_background(
self.federation_client.get_pdu,
[dest],
event_id,
@@ -994,10 +1000,8 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys())
logger.debug("calling resolve_state_groups in _maybe_backfill")
- resolve = logcontext.preserve_fn(
- self.state_handler.resolve_state_groups_for_events
- )
- states = yield logcontext.make_deferred_yieldable(
+ resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
+ states = yield make_deferred_yieldable(
defer.gatherResults(
[resolve(room_id, [e]) for e in event_ids], consumeErrors=True
)
@@ -1171,7 +1175,7 @@ class FederationHandler(BaseHandler):
# lots of requests for missing prev_events which we do actually
# have. Hence we fire off the deferred, but don't wait for it.
- logcontext.run_in_background(self._handle_queued_pdus, room_queue)
+ run_in_background(self._handle_queued_pdus, room_queue)
defer.returnValue(True)
@@ -1191,7 +1195,7 @@ class FederationHandler(BaseHandler):
p.event_id,
p.room_id,
)
- with logcontext.nested_logging_context(p.event_id):
+ with nested_logging_context(p.event_id):
yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
except Exception as e:
logger.warn(
@@ -1200,11 +1204,28 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
- def on_make_join_request(self, room_id, user_id):
+ def on_make_join_request(self, origin, room_id, user_id):
""" We've received a /make_join/ request, so we create a partial
join event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back.
+
+ Args:
+ origin (str): The (verified) server name of the requesting server.
+ room_id (str): Room to create join event in
+ user_id (str): The user to create the join for
+
+ Returns:
+ Deferred[FrozenEvent]
"""
+
+ if get_domain_from_id(user_id) != origin:
+ logger.info(
+ "Got /make_join request for user %r from different origin %s, ignoring",
+ user_id,
+ origin,
+ )
+ raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
+
event_content = {"membership": Membership.JOIN}
room_version = yield self.store.get_room_version(room_id)
@@ -1407,11 +1428,27 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
- def on_make_leave_request(self, room_id, user_id):
+ def on_make_leave_request(self, origin, room_id, user_id):
""" We've received a /make_leave/ request, so we create a partial
leave event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back.
+
+ Args:
+ origin (str): The (verified) server name of the requesting server.
+ room_id (str): Room to create leave event in
+ user_id (str): The user to create the leave for
+
+ Returns:
+ Deferred[FrozenEvent]
"""
+ if get_domain_from_id(user_id) != origin:
+ logger.info(
+ "Got /make_leave request for user %r from different origin %s, ignoring",
+ user_id,
+ origin,
+ )
+ raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
+
room_version = yield self.store.get_room_version(room_id)
builder = self.event_builder_factory.new(
room_version,
@@ -1610,7 +1647,7 @@ class FederationHandler(BaseHandler):
success = True
finally:
if not success:
- logcontext.run_in_background(
+ run_in_background(
self.store.remove_push_actions_from_staging, event.event_id
)
@@ -1629,7 +1666,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def prep(ev_info):
event = ev_info["event"]
- with logcontext.nested_logging_context(suffix=event.event_id):
+ with nested_logging_context(suffix=event.event_id):
res = yield self._prep_event(
origin,
event,
@@ -1639,12 +1676,9 @@ class FederationHandler(BaseHandler):
)
defer.returnValue(res)
- contexts = yield logcontext.make_deferred_yieldable(
+ contexts = yield make_deferred_yieldable(
defer.gatherResults(
- [
- logcontext.run_in_background(prep, ev_info)
- for ev_info in event_infos
- ],
+ [run_in_background(prep, ev_info) for ev_info in event_infos],
consumeErrors=True,
)
)
@@ -2106,10 +2140,10 @@ class FederationHandler(BaseHandler):
room_version = yield self.store.get_room_version(event.room_id)
- different_events = yield logcontext.make_deferred_yieldable(
+ different_events = yield make_deferred_yieldable(
defer.gatherResults(
[
- logcontext.run_in_background(
+ run_in_background(
self.store.get_event, d, allow_none=True, allow_rejected=False
)
for d in different_auth
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index c82b1933..546d6169 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -118,7 +118,7 @@ class IdentityHandler(BaseHandler):
raise SynapseError(400, "No client_secret in creds")
try:
- data = yield self.http_client.post_urlencoded_get_json(
+ data = yield self.http_client.post_json_get_json(
"https://%s%s" % (id_server, "/_matrix/identity/api/v1/3pid/bind"),
{"sid": creds["sid"], "client_secret": client_secret, "mxid": mxid},
)
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index a1fe9d11..54c966c8 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -21,12 +21,12 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.events.validator import EventValidator
from synapse.handlers.presence import format_user_presence_state
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.streams.config import PaginationConfig
from synapse.types import StreamToken, UserID
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.snapshot_cache import SnapshotCache
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 683da6bf..6d7a987f 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -23,6 +23,7 @@ from canonicaljson import encode_canonical_json, json
from twisted.internet import defer
from twisted.internet.defer import succeed
+from synapse import event_auth
from synapse.api.constants import EventTypes, Membership, RelationTypes
from synapse.api.errors import (
AuthError,
@@ -34,13 +35,13 @@ from synapse.api.errors import (
from synapse.api.room_versions import RoomVersions
from synapse.api.urls import ConsentURIBuilder
from synapse.events.validator import EventValidator
+from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester
from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder
-from synapse.util.logcontext import run_in_background
from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client
@@ -784,6 +785,20 @@ class EventCreationHandler(object):
event.signatures.update(returned_invite.signatures)
if event.type == EventTypes.Redaction:
+ original_event = yield self.store.get_event(
+ event.redacts,
+ check_redacted=False,
+ get_prev_content=False,
+ allow_rejected=False,
+ allow_none=True,
+ check_room_id=event.room_id,
+ )
+
+ # we can make some additional checks now if we have the original event.
+ if original_event:
+ if original_event.type == EventTypes.Create:
+ raise AuthError(403, "Redacting create events is not permitted")
+
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
@@ -791,18 +806,18 @@ class EventCreationHandler(object):
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
room_version = yield self.store.get_room_version(event.room_id)
- if self.auth.check_redaction(room_version, event, auth_events=auth_events):
- original_event = yield self.store.get_event(
- event.redacts,
- check_redacted=False,
- get_prev_content=False,
- allow_rejected=False,
- allow_none=False,
- )
+
+ if event_auth.check_redaction(room_version, event, auth_events=auth_events):
+ # this user doesn't have 'redact' rights, so we need to do some more
+ # checks on the original event. Let's start by checking the original
+ # event exists.
+ if not original_event:
+ raise NotFoundError("Could not find event %s" % (event.redacts,))
+
if event.user_id != original_event.user_id:
raise AuthError(403, "You don't have permission to redact events")
- # We've already checked.
+ # all the checks are done.
event.internal_metadata.recheck_redaction = False
if event.type == EventTypes.Create:
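
The reordered redaction hunk above fetches the original event first so that two checks can happen up front: redacting the room's `m.room.create` event is always refused, and a sender without `redact` power may only redact their own events. A condensed sketch of that policy (names illustrative):

    from synapse.api.errors import AuthError, NotFoundError

    def check_redaction_policy(original_event, needs_sender_check, sender):
        # Redacting the create event is never allowed.
        if original_event is not None and original_event.type == "m.room.create":
            raise AuthError(403, "Redacting create events is not permitted")

        # event_auth.check_redaction returning True means we must verify
        # that the sender is redacting one of their own events.
        if needs_sender_check:
            if original_event is None:
                raise NotFoundError("Could not find the redacted event")
            if sender != original_event.user_id:
                raise AuthError(403, "You don't have permission to redact events")
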
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 76ee97dd..20bcfed3 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -20,10 +20,10 @@ from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
+from synapse.logging.context import run_in_background
from synapse.storage.state import StateFilter
from synapse.types import RoomStreamToken
from synapse.util.async_helpers import ReadWriteLock
-from synapse.util.logcontext import run_in_background
from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index c80dc2eb..6f3537e4 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -34,14 +34,14 @@ from twisted.internet import defer
import synapse.metrics
from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.errors import SynapseError
+from synapse.logging.context import run_in_background
+from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import cachedInlineCallbacks
-from synapse.util.logcontext import run_in_background
-from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index d8462b75..a2388a70 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -303,6 +303,10 @@ class BaseProfileHandler(BaseHandler):
if not self.hs.config.require_auth_for_profile_requests or not requester:
return
+ # Always allow the user to query their own profile.
+ if target_user.to_string() == requester.to_string():
+ return
+
try:
requester_rooms = yield self.store.get_rooms_for_user(requester.to_string())
target_user_rooms = yield self.store.get_rooms_for_user(
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index a85dd8cd..e58bf7e3 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -17,7 +17,7 @@ import logging
from twisted.internet import defer
from synapse.handlers._base import BaseHandler
-from synapse.types import ReadReceipt
+from synapse.types import ReadReceipt, get_domain_from_id
logger = logging.getLogger(__name__)
@@ -40,18 +40,27 @@ class ReceiptsHandler(BaseHandler):
def _received_remote_receipt(self, origin, content):
"""Called when we receive an EDU of type m.receipt from a remote HS.
"""
- receipts = [
- ReadReceipt(
- room_id=room_id,
- receipt_type=receipt_type,
- user_id=user_id,
- event_ids=user_values["event_ids"],
- data=user_values.get("data", {}),
- )
- for room_id, room_values in content.items()
- for receipt_type, users in room_values.items()
- for user_id, user_values in users.items()
- ]
+ receipts = []
+ for room_id, room_values in content.items():
+ for receipt_type, users in room_values.items():
+ for user_id, user_values in users.items():
+ if get_domain_from_id(user_id) != origin:
+ logger.info(
+ "Received receipt for user %r from server %s, ignoring",
+ user_id,
+ origin,
+ )
+ continue
+
+ receipts.append(
+ ReadReceipt(
+ room_id=room_id,
+ receipt_type=receipt_type,
+ user_id=user_id,
+ event_ids=user_values["event_ids"],
+ data=user_values.get("data", {}),
+ )
+ )
yield self._handle_new_receipts(receipts)
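
The receipts handler now drops any receipt whose user does not belong to the sending server, closing a spoofing hole. The check itself is a one-liner worth calling out:

    from synapse.types import get_domain_from_id

    def receipt_is_trustworthy(user_id, origin):
        # "@alice:example.com" may only have receipts asserted by
        # the example.com homeserver itself.
        return get_domain_from_id(user_id) == origin

    # receipt_is_trustworthy("@alice:example.com", "example.com")  -> True
    # receipt_is_trustworthy("@alice:example.com", "evil.com")     -> False
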
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index e487b90c..bb7cfd71 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -84,6 +84,8 @@ class RegistrationHandler(BaseHandler):
self.device_handler = hs.get_device_handler()
self.pusher_pool = hs.get_pusherpool()
+ self.session_lifetime = hs.config.session_lifetime
+
@defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None, assigned_user_id=None):
if types.contains_invalid_mxid_characters(localpart):
@@ -138,11 +140,10 @@ class RegistrationHandler(BaseHandler):
)
@defer.inlineCallbacks
- def register(
+ def register_user(
self,
localpart=None,
password=None,
- generate_token=True,
guest_access_token=None,
make_guest=False,
admin=False,
@@ -160,11 +161,6 @@ class RegistrationHandler(BaseHandler):
password (unicode) : The password to assign to this user so they can
login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
- generate_token (bool): Whether a new access token should be
- generated. Having this be True should be considered deprecated,
- since it offers no means of associating a device_id with the
- access_token. Instead you should call auth_handler.issue_access_token
- after registration.
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
default_display_name (unicode|None): if set, the new user's displayname
@@ -172,7 +168,7 @@ class RegistrationHandler(BaseHandler):
address (str|None): the IP address used to perform the registration.
bind_emails (List[str]): list of emails to bind to this account.
Returns:
- A tuple of (user_id, access_token).
+ Deferred[str]: user_id
Raises:
RegistrationError if there was a problem registering.
"""
@@ -206,12 +202,8 @@ class RegistrationHandler(BaseHandler):
elif default_display_name is None:
default_display_name = localpart
- token = None
- if generate_token:
- token = self.macaroon_gen.generate_access_token(user_id)
yield self.register_with_store(
user_id=user_id,
- token=token,
password_hash=password_hash,
was_guest=was_guest,
make_guest=make_guest,
@@ -230,21 +222,17 @@ class RegistrationHandler(BaseHandler):
else:
# autogen a sequential user ID
attempts = 0
- token = None
user = None
while not user:
localpart = yield self._generate_user_id(attempts > 0)
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
yield self.check_user_id_not_appservice_exclusive(user_id)
- if generate_token:
- token = self.macaroon_gen.generate_access_token(user_id)
if default_display_name is None:
default_display_name = localpart
try:
yield self.register_with_store(
user_id=user_id,
- token=token,
password_hash=password_hash,
make_guest=make_guest,
create_profile_with_displayname=default_display_name,
@@ -254,10 +242,15 @@ class RegistrationHandler(BaseHandler):
# if user id is taken, just generate another
user = None
user_id = None
- token = None
attempts += 1
+
if not self.hs.config.user_consent_at_registration:
yield self._auto_join_rooms(user_id)
+ else:
+ logger.info(
+ "Skipping auto-join for %s because consent is required at registration",
+ user_id,
+ )
# Bind any specified emails to this account
current_time = self.hs.get_clock().time_msec()
@@ -272,7 +265,7 @@ class RegistrationHandler(BaseHandler):
# Bind email to new account
yield self._register_email_threepid(user_id, threepid_dict, None, False)
- defer.returnValue((user_id, token))
+ defer.returnValue(user_id)
@defer.inlineCallbacks
def _auto_join_rooms(self, user_id):
@@ -298,6 +291,7 @@ class RegistrationHandler(BaseHandler):
count = yield self.store.count_all_users()
should_auto_create_rooms = count == 1
for r in self.hs.config.auto_join_rooms:
+ logger.info("Auto-joining %s to %s", user_id, r)
try:
if should_auto_create_rooms:
room_alias = RoomAlias.from_string(r)
@@ -506,87 +500,6 @@ class RegistrationHandler(BaseHandler):
defer.returnValue(data)
@defer.inlineCallbacks
- def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
- """Creates a new user if the user does not exist,
- else revokes all previous access tokens and generates a new one.
-
- Args:
- localpart : The local part of the user ID to register. If None,
- one will be randomly generated.
- Returns:
- A tuple of (user_id, access_token).
- Raises:
- RegistrationError if there was a problem registering.
-
- NB this is only used in tests. TODO: move it to the test package!
- """
- if localpart is None:
- raise SynapseError(400, "Request must include user id")
- yield self.auth.check_auth_blocking()
- need_register = True
-
- try:
- yield self.check_username(localpart)
- except SynapseError as e:
- if e.errcode == Codes.USER_IN_USE:
- need_register = False
- else:
- raise
-
- user = UserID(localpart, self.hs.hostname)
- user_id = user.to_string()
- token = self.macaroon_gen.generate_access_token(user_id)
-
- if need_register:
- yield self.register_with_store(
- user_id=user_id,
- token=token,
- password_hash=password_hash,
- create_profile_with_displayname=user.localpart,
- )
- else:
- yield self._auth_handler.delete_access_tokens_for_user(user_id)
- yield self.store.add_access_token_to_user(user_id=user_id, token=token)
-
- if displayname is not None:
- logger.info("setting user display name: %s -> %s", user_id, displayname)
- yield self.profile_handler.set_displayname(
- user, requester, displayname, by_admin=True
- )
-
- defer.returnValue((user_id, token))
-
- @defer.inlineCallbacks
- def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
- """Get a guest access token for a 3PID, creating a guest account if
- one doesn't already exist.
-
- Args:
- medium (str)
- address (str)
- inviter_user_id (str): The user ID who is trying to invite the
- 3PID
-
- Returns:
- Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
- 3PID guest account.
- """
- access_token = yield self.store.get_3pid_guest_access_token(medium, address)
- if access_token:
- user_info = yield self.auth.get_user_by_access_token(access_token)
-
- defer.returnValue((user_info["user"].to_string(), access_token))
-
- user_id, access_token = yield self.register(
- generate_token=True, make_guest=True
- )
- access_token = yield self.store.save_or_get_3pid_guest_access_token(
- medium, address, access_token, inviter_user_id
- )
-
- defer.returnValue((user_id, access_token))
-
- @defer.inlineCallbacks
def _join_user_to_room(self, requester, room_identifier):
room_id = None
room_member_handler = self.hs.get_room_member_handler()
@@ -615,7 +528,6 @@ class RegistrationHandler(BaseHandler):
def register_with_store(
self,
user_id,
- token=None,
password_hash=None,
was_guest=False,
make_guest=False,
@@ -629,9 +541,6 @@ class RegistrationHandler(BaseHandler):
Args:
user_id (str): The desired user ID to register.
- token (str): The desired access token to use for this user. If this
- is not None, the given access token is associated with the user
- id.
password_hash (str|None): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
@@ -667,7 +576,6 @@ class RegistrationHandler(BaseHandler):
if self.hs.config.worker_app:
return self._register_client(
user_id=user_id,
- token=token,
password_hash=password_hash,
was_guest=was_guest,
make_guest=make_guest,
@@ -678,9 +586,8 @@ class RegistrationHandler(BaseHandler):
address=address,
)
else:
- return self.store.register(
+ return self.store.register_user(
user_id=user_id,
- token=token,
password_hash=password_hash,
was_guest=was_guest,
make_guest=make_guest,
@@ -694,6 +601,8 @@ class RegistrationHandler(BaseHandler):
def register_device(self, user_id, device_id, initial_display_name, is_guest=False):
"""Register a device for a user and generate an access token.
+ The access token will be limited by the homeserver's session_lifetime config.
+
Args:
user_id (str): full canonical @user:id
device_id (str|None): The device ID to check, or None to generate
@@ -714,20 +623,29 @@ class RegistrationHandler(BaseHandler):
is_guest=is_guest,
)
defer.returnValue((r["device_id"], r["access_token"]))
- else:
- device_id = yield self.device_handler.check_device_registered(
- user_id, device_id, initial_display_name
- )
+
+ valid_until_ms = None
+ if self.session_lifetime is not None:
if is_guest:
- access_token = self.macaroon_gen.generate_access_token(
- user_id, ["guest = true"]
- )
- else:
- access_token = yield self._auth_handler.get_access_token_for_user_id(
- user_id, device_id=device_id
+ raise Exception(
+ "session_lifetime is not currently implemented for guest access"
)
+ valid_until_ms = self.clock.time_msec() + self.session_lifetime
+
+ device_id = yield self.device_handler.check_device_registered(
+ user_id, device_id, initial_display_name
+ )
+ if is_guest:
+ assert valid_until_ms is None
+ access_token = self.macaroon_gen.generate_access_token(
+ user_id, ["guest = true"]
+ )
+ else:
+ access_token = yield self._auth_handler.get_access_token_for_user_id(
+ user_id, device_id=device_id, valid_until_ms=valid_until_ms
+ )
- defer.returnValue((device_id, access_token))
+ defer.returnValue((device_id, access_token))
@defer.inlineCallbacks
def post_registration_actions(
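
`register_device` now derives an absolute token expiry from the homeserver's `session_lifetime` setting (and refuses the combination with guest access). A sketch of the arithmetic, assuming `session_lifetime_ms` has already been parsed from the config into milliseconds:

    def compute_token_expiry(clock, session_lifetime_ms):
        # Mirrors register_device: no configured lifetime means the
        # token never expires (valid_until_ms=None).
        if session_lifetime_ms is None:
            return None
        return clock.time_msec() + session_lifetime_ms
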
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 66b05b47..e0196ef8 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -29,7 +29,7 @@ from twisted.internet import defer
import synapse.server
import synapse.types
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.errors import AuthError, Codes, HttpResponseException, SynapseError
from synapse.types import RoomID, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
@@ -119,24 +119,6 @@ class RoomMemberHandler(object):
raise NotImplementedError()
@abc.abstractmethod
- def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
- """Get a guest access token for a 3PID, creating a guest account if
- one doesn't already exist.
-
- Args:
- requester (Requester)
- medium (str)
- address (str)
- inviter_user_id (str): The user ID who is trying to invite the
- 3PID
-
- Returns:
- Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
- 3PID guest account.
- """
- raise NotImplementedError()
-
- @abc.abstractmethod
def _user_joined_room(self, target, room_id):
"""Notifies distributor on master process that the user has joined the
room.
@@ -890,24 +872,23 @@ class RoomMemberHandler(object):
"sender_avatar_url": inviter_avatar_url,
}
- if self.config.invite_3pid_guest:
- guest_user_id, guest_access_token = yield self.get_or_register_3pid_guest(
- requester=requester,
- medium=medium,
- address=address,
- inviter_user_id=inviter_user_id,
+ try:
+ data = yield self.simple_http_client.post_json_get_json(
+ is_url, invite_config
)
-
- invite_config.update(
- {
- "guest_access_token": guest_access_token,
- "guest_user_id": guest_user_id,
- }
+ except HttpResponseException as e:
+ # Some identity servers may only support application/x-www-form-urlencoded
+ # content types. This is especially true with old instances of Sydent, see
+ # https://github.com/matrix-org/sydent/pull/170
+ logger.info(
+ "Failed to POST %s with JSON, falling back to urlencoded form: %s",
+ is_url,
+ e,
+ )
+ data = yield self.simple_http_client.post_urlencoded_get_json(
+ is_url, invite_config
)
- data = yield self.simple_http_client.post_urlencoded_get_json(
- is_url, invite_config
- )
# TODO: Check for success
token = data["token"]
public_keys = data.get("public_keys", [])
@@ -1010,12 +991,6 @@ class RoomMemberMasterHandler(RoomMemberHandler):
yield self.store.locally_reject_invite(target.to_string(), room_id)
defer.returnValue({})
- def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
- """Implements RoomMemberHandler.get_or_register_3pid_guest
- """
- rg = self.registration_handler
- return rg.get_or_register_3pid_guest(medium, address, inviter_user_id)
-
def _user_joined_room(self, target, room_id):
"""Implements RoomMemberHandler._user_joined_room
"""
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index da501f38..fc873a3b 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -20,7 +20,6 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.handlers.room_member import RoomMemberHandler
from synapse.replication.http.membership import (
- ReplicationRegister3PIDGuestRestServlet as Repl3PID,
ReplicationRemoteJoinRestServlet as ReplRemoteJoin,
ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite,
ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft,
@@ -33,7 +32,6 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
def __init__(self, hs):
super(RoomMemberWorkerHandler, self).__init__(hs)
- self._get_register_3pid_client = Repl3PID.make_client(hs)
self._remote_join_client = ReplRemoteJoin.make_client(hs)
self._remote_reject_client = ReplRejectInvite.make_client(hs)
self._notify_change_client = ReplJoinedLeft.make_client(hs)
@@ -80,13 +78,3 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
return self._notify_change_client(
user_id=target.to_string(), room_id=room_id, change="left"
)
-
- def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
- """Implements RoomMemberHandler.get_or_register_3pid_guest
- """
- return self._get_register_3pid_client(
- requester=requester,
- medium=medium,
- address=address,
- inviter_user_id=inviter_user_id,
- )
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index a3f55055..cd1ac0a2 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -25,6 +25,7 @@ from prometheus_client import Counter
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
+from synapse.logging.context import LoggingContext
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
@@ -33,7 +34,6 @@ from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.response_cache import ResponseCache
-from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure, measure_func
from synapse.visibility import filter_events_for_client
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index f8062c86..c3e0c8fc 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -19,9 +19,9 @@ from collections import namedtuple
from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError
+from synapse.logging.context import run_in_background
from synapse.types import UserID, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
-from synapse.util.logcontext import run_in_background
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 9bc7035c..45d50109 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -45,9 +45,9 @@ from synapse.http import (
cancelled_to_request_timed_out_error,
redact_uri,
)
+from synapse.logging.context import make_deferred_yieldable
from synapse.util.async_helpers import timeout_deferred
from synapse.util.caches import CACHE_SIZE_FACTOR
-from synapse.util.logcontext import make_deferred_yieldable
logger = logging.getLogger(__name__)
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 414cde07..054c321a 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -30,9 +30,9 @@ from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent
from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
+from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock
from synapse.util.caches.ttlcache import TTLCache
-from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.metrics import Measure
# period to cache .well-known results for by default
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index 1f22f78a..ecc88f9b 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -25,7 +25,7 @@ from twisted.internet.error import ConnectError
from twisted.names import client, dns
from twisted.names.error import DNSNameError, DomainError
-from synapse.util.logcontext import make_deferred_yieldable
+from synapse.logging.context import make_deferred_yieldable
logger = logging.getLogger(__name__)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 5ef8bb60..e6033454 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -36,6 +36,7 @@ from twisted.internet.task import _EPSILON, Cooperator
from twisted.web._newclient import ResponseDone
from twisted.web.http_headers import Headers
+import synapse.logging.opentracing as opentracing
import synapse.metrics
import synapse.util.retryutils
from synapse.api.errors import (
@@ -48,8 +49,8 @@ from synapse.api.errors import (
from synapse.http import QuieterFileBodyProducer
from synapse.http.client import BlacklistingAgentWrapper, IPBlacklistingResolver
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
+from synapse.logging.context import make_deferred_yieldable
from synapse.util.async_helpers import timeout_deferred
-from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -339,9 +340,25 @@ class MatrixFederationHttpClient(object):
else:
query_bytes = b""
- headers_dict = {b"User-Agent": [self.version_string_bytes]}
+ # Start a span for this outgoing federation request
+ scope = opentracing.start_active_span(
+ "outgoing-federation-request",
+ tags={
+ opentracing.tags.SPAN_KIND: opentracing.tags.SPAN_KIND_RPC_CLIENT,
+ opentracing.tags.PEER_ADDRESS: request.destination,
+ opentracing.tags.HTTP_METHOD: request.method,
+ opentracing.tags.HTTP_URL: request.path,
+ },
+ finish_on_close=True,
+ )
+
+ # Inject the span into the headers
+ headers_dict = {}
+ opentracing.inject_active_span_byte_dict(headers_dict, request.destination)
- with limiter:
+ headers_dict[b"User-Agent"] = [self.version_string_bytes]
+
+ with limiter, scope:
# XXX: Would be much nicer to retry only at the transaction-layer
# (once we have reliable transactions in place)
if long_retries:
@@ -419,6 +436,10 @@ class MatrixFederationHttpClient(object):
response.phrase.decode("ascii", errors="replace"),
)
+ opentracing.set_tag(
+ opentracing.tags.HTTP_STATUS_CODE, response.code
+ )
+
if 200 <= response.code < 300:
pass
else:
@@ -499,8 +520,7 @@ class MatrixFederationHttpClient(object):
_flatten_response_never_received(e),
)
raise
-
- defer.returnValue(response)
+ defer.returnValue(response)
def build_auth_headers(
self, destination, method, url_bytes, content=None, destination_is=None
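
Every outgoing federation request is now wrapped in an opentracing span whose context is injected into the request headers, so the receiving homeserver can join the trace. The shape of the pattern, extracted from the hunk above (`destination`, `method` and `path` stand in for the request object's fields):

    import synapse.logging.opentracing as opentracing

    def start_federation_span(destination, method, path):
        scope = opentracing.start_active_span(
            "outgoing-federation-request",
            tags={
                opentracing.tags.SPAN_KIND: opentracing.tags.SPAN_KIND_RPC_CLIENT,
                opentracing.tags.PEER_ADDRESS: destination,
                opentracing.tags.HTTP_METHOD: method,
                opentracing.tags.HTTP_URL: path,
            },
            finish_on_close=True,  # span closes when the scope exits
        )
        headers = {}
        # Serialise the active span context into the header dict.
        opentracing.inject_active_span_byte_dict(headers, destination)
        return scope, headers
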
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 62045a91..46af27c8 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -19,8 +19,8 @@ import threading
from prometheus_client.core import Counter, Histogram
+from synapse.logging.context import LoggingContext
from synapse.metrics import LaterGauge
-from synapse.util.logcontext import LoggingContext
logger = logging.getLogger(__name__)
diff --git a/synapse/http/server.py b/synapse/http/server.py
index d993161a..e6f351ba 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -39,8 +39,8 @@ from synapse.api.errors import (
SynapseError,
UnrecognizedRequestError,
)
+from synapse.logging.context import preserve_fn
from synapse.util.caches import intern_dict
-from synapse.util.logcontext import preserve_fn
logger = logging.getLogger(__name__)
@@ -245,7 +245,9 @@ class JsonResource(HttpServer, resource.Resource):
isLeaf = True
- _PathEntry = collections.namedtuple("_PathEntry", ["pattern", "callback"])
+ _PathEntry = collections.namedtuple(
+ "_PathEntry", ["pattern", "callback", "servlet_classname"]
+ )
def __init__(self, hs, canonical_json=True):
resource.Resource.__init__(self)
@@ -255,12 +257,28 @@ class JsonResource(HttpServer, resource.Resource):
self.path_regexs = {}
self.hs = hs
- def register_paths(self, method, path_patterns, callback):
+ def register_paths(self, method, path_patterns, callback, servlet_classname):
+ """
+ Registers a request handler against a regular expression. Later request URLs are
+ checked against these regular expressions in order to identify an appropriate
+ handler for that request.
+
+ Args:
+ method (str): GET, POST etc
+
+ path_patterns (Iterable[str]): A list of regular expressions to which
+ the request URLs are compared.
+
+ callback (function): The handler for the request. Usually a Servlet
+
+ servlet_classname (str): The name of the handler to be used in prometheus
+ and opentracing logs.
+ """
method = method.encode("utf-8") # method is bytes on py3
for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append(
- self._PathEntry(path_pattern, callback)
+ self._PathEntry(path_pattern, callback, servlet_classname)
)
def render(self, request):
@@ -275,13 +293,9 @@ class JsonResource(HttpServer, resource.Resource):
This checks if anyone has registered a callback for that method and
path.
"""
- callback, group_dict = self._get_handler_for_request(request)
+ callback, servlet_classname, group_dict = self._get_handler_for_request(request)
- servlet_instance = getattr(callback, "__self__", None)
- if servlet_instance is not None:
- servlet_classname = servlet_instance.__class__.__name__
- else:
- servlet_classname = "%r" % callback
+ # Make sure we have a name for this handler in prometheus.
request.request_metrics.name = servlet_classname
# Now trigger the callback. If it returns a response, we send it
@@ -311,7 +325,8 @@ class JsonResource(HttpServer, resource.Resource):
request (twisted.web.http.Request):
Returns:
- Tuple[Callable, dict[unicode, unicode]]: callback method, and the
+ Tuple[Callable, str, dict[unicode, unicode]]: callback method, the
+ label to use for that method in prometheus metrics, and the
dict mapping keys to path components as specified in the
handler's path match regexp.
@@ -320,7 +335,7 @@ class JsonResource(HttpServer, resource.Resource):
None, or a tuple of (http code, response body).
"""
if request.method == b"OPTIONS":
- return _options_handler, {}
+ return _options_handler, "options_request_handler", {}
# Loop through all the registered callbacks to check if the method
# and path regex match
@@ -328,10 +343,10 @@ class JsonResource(HttpServer, resource.Resource):
m = path_entry.pattern.match(request.path.decode("ascii"))
if m:
# We found a match!
- return path_entry.callback, m.groupdict()
+ return path_entry.callback, path_entry.servlet_classname, m.groupdict()
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
- return _unrecognised_request_handler, {}
+ return _unrecognised_request_handler, "unrecognised_request_handler", {}
def _send_response(
self, request, code, response_json_object, response_code_message=None
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index cd8415ac..f0ca7d9a 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -20,6 +20,7 @@ import logging
from canonicaljson import json
from synapse.api.errors import Codes, SynapseError
+from synapse.logging.opentracing import trace_servlet
logger = logging.getLogger(__name__)
@@ -289,8 +290,14 @@ class RestServlet(object):
for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
if hasattr(self, "on_%s" % (method,)):
+ servlet_classname = self.__class__.__name__
method_handler = getattr(self, "on_%s" % (method,))
- http_server.register_paths(method, patterns, method_handler)
+ http_server.register_paths(
+ method,
+ patterns,
+ trace_servlet(servlet_classname, method_handler),
+ servlet_classname,
+ )
else:
raise NotImplementedError("RestServlet must register something.")
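
Taken together, the `http/server.py` and `http/servlet.py` hunks mean every handler is registered under its servlet's class name and pre-wrapped in `trace_servlet`, so prometheus metrics and opentracing spans share one label. With an illustrative servlet:

    import re

    from synapse.http.servlet import RestServlet

    class PingServlet(RestServlet):
        PATTERNS = [re.compile("^/_matrix/client/r0/ping$")]

        async def on_GET(self, request):
            return 200, {"pong": True}

    # PingServlet().register(json_resource) now effectively calls:
    #   json_resource.register_paths(
    #       "GET",
    #       PingServlet.PATTERNS,
    #       trace_servlet("PingServlet", handler),
    #       "PingServlet",
    #   )
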
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 93f679ea..df5274c1 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -19,7 +19,7 @@ from twisted.web.server import Request, Site
from synapse.http import redact_uri
from synapse.http.request_metrics import RequestMetrics, requests_counter
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import LoggingContext, PreserveLoggingContext
logger = logging.getLogger(__name__)
diff --git a/synapse/logging/__init__.py b/synapse/logging/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/synapse/logging/__init__.py
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
new file mode 100644
index 00000000..b456c31f
--- /dev/null
+++ b/synapse/logging/context.py
@@ -0,0 +1,697 @@
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" Thread-local-alike tracking of log contexts within synapse
+
+This module provides objects and utilities for tracking contexts through
+synapse code, so that log lines can include a request identifier, and so that
+CPU and database activity can be accounted for against the request that caused
+them.
+
+See doc/log_contexts.rst for details on how this works.
+"""
+
+import logging
+import threading
+import types
+
+from twisted.internet import defer, threads
+
+logger = logging.getLogger(__name__)
+
+try:
+ import resource
+
+ # Python doesn't ship with a definition of RUSAGE_THREAD but it's defined
+ # to be 1 on Linux so we hard-code it.
+ RUSAGE_THREAD = 1
+
+ # If the system doesn't support RUSAGE_THREAD then this should throw an
+ # exception.
+ resource.getrusage(RUSAGE_THREAD)
+
+ def get_thread_resource_usage():
+ return resource.getrusage(RUSAGE_THREAD)
+
+
+except Exception:
+ # If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we
+ # won't track resource usage by returning None.
+ def get_thread_resource_usage():
+ return None
+
+
+# get an id for the current thread.
+#
+# threading.get_ident doesn't actually return an OS-level tid, and annoyingly,
+# on Linux it returns the same value either side of a fork() call. However
+# we only fork in one place, so it's not worth the hoop-jumping to get a real tid.
+#
+get_thread_id = threading.get_ident
+
+
+class ContextResourceUsage(object):
+ """Object for tracking the resources used by a log context
+
+ Attributes:
+ ru_utime (float): user CPU time (in seconds)
+ ru_stime (float): system CPU time (in seconds)
+ db_txn_count (int): number of database transactions done
+ db_sched_duration_sec (float): amount of time spent waiting for a
+ database connection
+ db_txn_duration_sec (float): amount of time spent doing database
+ transactions (excluding scheduling time)
+ evt_db_fetch_count (int): number of events requested from the database
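+
+ A sketch of typical use (``logcontext`` is an assumed LoggingContext):
+
+ usage = logcontext.get_resource_usage() # returns a copy
+ total_cpu_sec = usage.ru_utime + usage.ru_stime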
+ """
+
+ __slots__ = [
+ "ru_stime",
+ "ru_utime",
+ "db_txn_count",
+ "db_txn_duration_sec",
+ "db_sched_duration_sec",
+ "evt_db_fetch_count",
+ ]
+
+ def __init__(self, copy_from=None):
+ """Create a new ContextResourceUsage
+
+ Args:
+ copy_from (ContextResourceUsage|None): if not None, an object to
+ copy stats from
+ """
+ if copy_from is None:
+ self.reset()
+ else:
+ self.ru_utime = copy_from.ru_utime
+ self.ru_stime = copy_from.ru_stime
+ self.db_txn_count = copy_from.db_txn_count
+
+ self.db_txn_duration_sec = copy_from.db_txn_duration_sec
+ self.db_sched_duration_sec = copy_from.db_sched_duration_sec
+ self.evt_db_fetch_count = copy_from.evt_db_fetch_count
+
+ def copy(self):
+ return ContextResourceUsage(copy_from=self)
+
+ def reset(self):
+ self.ru_stime = 0.0
+ self.ru_utime = 0.0
+ self.db_txn_count = 0
+
+ self.db_txn_duration_sec = 0
+ self.db_sched_duration_sec = 0
+ self.evt_db_fetch_count = 0
+
+ def __repr__(self):
+ return (
+ "<ContextResourceUsage ru_stime='%r', ru_utime='%r', "
+ "db_txn_count='%r', db_txn_duration_sec='%r', "
+ "db_sched_duration_sec='%r', evt_db_fetch_count='%r'>"
+ ) % (
+ self.ru_stime,
+ self.ru_utime,
+ self.db_txn_count,
+ self.db_txn_duration_sec,
+ self.db_sched_duration_sec,
+ self.evt_db_fetch_count,
+ )
+
+ def __iadd__(self, other):
+ """Add another ContextResourceUsage's stats to this one's.
+
+ Args:
+ other (ContextResourceUsage): the other resource usage object
+ """
+ self.ru_utime += other.ru_utime
+ self.ru_stime += other.ru_stime
+ self.db_txn_count += other.db_txn_count
+ self.db_txn_duration_sec += other.db_txn_duration_sec
+ self.db_sched_duration_sec += other.db_sched_duration_sec
+ self.evt_db_fetch_count += other.evt_db_fetch_count
+ return self
+
+ def __isub__(self, other):
+ self.ru_utime -= other.ru_utime
+ self.ru_stime -= other.ru_stime
+ self.db_txn_count -= other.db_txn_count
+ self.db_txn_duration_sec -= other.db_txn_duration_sec
+ self.db_sched_duration_sec -= other.db_sched_duration_sec
+ self.evt_db_fetch_count -= other.evt_db_fetch_count
+ return self
+
+ def __add__(self, other):
+ res = ContextResourceUsage(copy_from=self)
+ res += other
+ return res
+
+ def __sub__(self, other):
+ res = ContextResourceUsage(copy_from=self)
+ res -= other
+ return res
+
+
+class LoggingContext(object):
+ """Additional context for log formatting. Contexts are scoped within a
+ "with" block.
+
+ If a parent is given when creating a new context, then:
+ - logging fields are copied from the parent to the new context on entry
+ - when the new context exits, the cpu usage stats are copied from the
+ child to the parent
+
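+ A minimal usage sketch (the names are illustrative):
+
+ with LoggingContext(name="my-operation"):
+ do_something() # CPU/db usage in here is attributed to this context
+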
+ Args:
+ name (str): Name for the context for debugging.
+ parent_context (LoggingContext|None): The parent of the new context
+ """
+
+ __slots__ = [
+ "previous_context",
+ "name",
+ "parent_context",
+ "_resource_usage",
+ "usage_start",
+ "main_thread",
+ "alive",
+ "request",
+ "tag",
+ "scope",
+ ]
+
+ thread_local = threading.local()
+
+ class Sentinel(object):
+ """Sentinel to represent the root context"""
+
+ __slots__ = []
+
+ def __str__(self):
+ return "sentinel"
+
+ def copy_to(self, record):
+ pass
+
+ def start(self):
+ pass
+
+ def stop(self):
+ pass
+
+ def add_database_transaction(self, duration_sec):
+ pass
+
+ def add_database_scheduled(self, sched_sec):
+ pass
+
+ def record_event_fetch(self, event_count):
+ pass
+
+ def __nonzero__(self):
+ return False
+
+ __bool__ = __nonzero__ # python3
+
+ sentinel = Sentinel()
+
+ def __init__(self, name=None, parent_context=None, request=None):
+ self.previous_context = LoggingContext.current_context()
+ self.name = name
+
+ # track the resources used by this context so far
+ self._resource_usage = ContextResourceUsage()
+
+ # If alive, this holds the thread resource usage recorded when the
+ # logcontext last became active.
+ self.usage_start = None
+
+ self.main_thread = get_thread_id()
+ self.request = None
+ self.tag = ""
+ self.alive = True
+ self.scope = None
+
+ self.parent_context = parent_context
+
+ if self.parent_context is not None:
+ self.parent_context.copy_to(self)
+
+ if request is not None:
+ # the request param overrides the request from the parent context
+ self.request = request
+
+ def __str__(self):
+ if self.request:
+ return str(self.request)
+ return "%s@%x" % (self.name, id(self))
+
+ @classmethod
+ def current_context(cls):
+ """Get the current logging context from thread local storage
+
+ Returns:
+ LoggingContext: the current logging context
+ """
+ return getattr(cls.thread_local, "current_context", cls.sentinel)
+
+ @classmethod
+ def set_current_context(cls, context):
+ """Set the current logging context in thread local storage
+ Args:
+ context (LoggingContext): The context to activate.
+ Returns:
+ The context that was previously active
+ """
+ current = cls.current_context()
+
+ if current is not context:
+ current.stop()
+ cls.thread_local.current_context = context
+ context.start()
+ return current
+
+ def __enter__(self):
+ """Enters this logging context into thread local storage"""
+ old_context = self.set_current_context(self)
+ if self.previous_context != old_context:
+ logger.warning(
+ "Expected previous context %r, found %r",
+ self.previous_context,
+ old_context,
+ )
+ self.alive = True
+
+ return self
+
+ def __exit__(self, type, value, traceback):
+ """Restore the logging context in thread local storage to the state it
+ was before this context was entered.
+ Returns:
+ None to avoid suppressing any exceptions that were thrown.
+ """
+ current = self.set_current_context(self.previous_context)
+ if current is not self:
+ if current is self.sentinel:
+ logger.warning("Expected logging context %s was lost", self)
+ else:
+ logger.warning(
+ "Expected logging context %s but found %s", self, current
+ )
+ self.previous_context = None
+ self.alive = False
+
+ # if we have a parent, pass our CPU usage stats on
+ if self.parent_context is not None and hasattr(
+ self.parent_context, "_resource_usage"
+ ):
+ self.parent_context._resource_usage += self._resource_usage
+
+ # reset them in case we get entered again
+ self._resource_usage.reset()
+
+ def copy_to(self, record):
+ """Copy logging fields from this context to a log record or
+ another LoggingContext
+ """
+
+ # we track the current request
+ record.request = self.request
+
+ # we also track the current scope:
+ record.scope = self.scope
+
+ def start(self):
+ if get_thread_id() != self.main_thread:
+ logger.warning("Started logcontext %s on different thread", self)
+ return
+
+ # If we haven't already started, record the thread resource usage so
+ # far
+ if not self.usage_start:
+ self.usage_start = get_thread_resource_usage()
+
+ def stop(self):
+ if get_thread_id() != self.main_thread:
+ logger.warning("Stopped logcontext %s on different thread", self)
+ return
+
+ # When we stop, let's record the cpu used since we started
+ if not self.usage_start:
+ logger.warning("Called stop on logcontext %s without calling start", self)
+ return
+
+ utime_delta, stime_delta = self._get_cputime()
+ self._resource_usage.ru_utime += utime_delta
+ self._resource_usage.ru_stime += stime_delta
+
+ self.usage_start = None
+
+ def get_resource_usage(self):
+ """Get resources used by this logcontext so far.
+
+ Returns:
+ ContextResourceUsage: a *copy* of the object tracking resource
+ usage so far
+ """
+ # we always return a copy, for consistency
+ res = self._resource_usage.copy()
+
+ # If we are on the correct thread and we're currently running then we
+ # can include resource usage so far.
+ is_main_thread = get_thread_id() == self.main_thread
+ if self.alive and self.usage_start and is_main_thread:
+ utime_delta, stime_delta = self._get_cputime()
+ res.ru_utime += utime_delta
+ res.ru_stime += stime_delta
+
+ return res
+
+ def _get_cputime(self):
+ """Get the cpu usage time so far
+
+ Returns: Tuple[float, float]: seconds in user mode, seconds in system mode
+ """
+ current = get_thread_resource_usage()
+
+ utime_delta = current.ru_utime - self.usage_start.ru_utime
+ stime_delta = current.ru_stime - self.usage_start.ru_stime
+
+ # sanity check
+ if utime_delta < 0:
+ logger.error(
+ "utime went backwards! %f < %f",
+ current.ru_utime,
+ self.usage_start.ru_utime,
+ )
+ utime_delta = 0
+
+ if stime_delta < 0:
+ logger.error(
+ "stime went backwards! %f < %f",
+ current.ru_stime,
+ self.usage_start.ru_stime,
+ )
+ stime_delta = 0
+
+ return utime_delta, stime_delta
+
+ def add_database_transaction(self, duration_sec):
+ if duration_sec < 0:
+ raise ValueError("DB txn time can only be non-negative")
+ self._resource_usage.db_txn_count += 1
+ self._resource_usage.db_txn_duration_sec += duration_sec
+
+ def add_database_scheduled(self, sched_sec):
+ """Record a use of the database pool
+
+ Args:
+ sched_sec (float): number of seconds it took us to get a
+ connection
+ """
+ if sched_sec < 0:
+ raise ValueError("DB scheduling time can only be non-negative")
+ self._resource_usage.db_sched_duration_sec += sched_sec
+
+ def record_event_fetch(self, event_count):
+ """Record a number of events being fetched from the db
+
+ Args:
+ event_count (int): number of events being fetched
+ """
+ self._resource_usage.evt_db_fetch_count += event_count
+
+
+class LoggingContextFilter(logging.Filter):
+ """Logging filter that adds values from the current logging context to each
+ record.
+ Args:
+ **defaults: Default values to avoid formatters complaining about
+ missing fields
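+
+ A wiring sketch (``handler`` is an assumed logging.Handler):
+
+ handler.addFilter(LoggingContextFilter(request=""))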
+ """
+
+ def __init__(self, **defaults):
+ self.defaults = defaults
+
+ def filter(self, record):
+ """Add each fields from the logging contexts to the record.
+ Returns:
+ True to include the record in the log output.
+ """
+ context = LoggingContext.current_context()
+ for key, value in self.defaults.items():
+ setattr(record, key, value)
+
+ # context should never be None, but if it somehow ends up being, then
+ # we end up in a death spiral of infinite loops, so let's check, for
+ # robustness' sake.
+ if context is not None:
+ context.copy_to(record)
+
+ return True
+
+
+class PreserveLoggingContext(object):
+ """Captures the current logging context and restores it when the scope is
+ exited. Used to restore the context after a function using
+ @defer.inlineCallbacks is resumed by a callback from the reactor."""
+
+ __slots__ = ["current_context", "new_context", "has_parent"]
+
+ def __init__(self, new_context=None):
+ if new_context is None:
+ new_context = LoggingContext.sentinel
+ self.new_context = new_context
+
+ def __enter__(self):
+ """Captures the current logging context"""
+ self.current_context = LoggingContext.set_current_context(self.new_context)
+
+ if self.current_context:
+ self.has_parent = self.current_context.previous_context is not None
+ if not self.current_context.alive:
+ logger.debug("Entering dead context: %s", self.current_context)
+
+ def __exit__(self, type, value, traceback):
+ """Restores the current logging context"""
+ context = LoggingContext.set_current_context(self.current_context)
+
+ if context != self.new_context:
+ if context is LoggingContext.sentinel:
+ logger.warning("Expected logging context %s was lost", self.new_context)
+ else:
+ logger.warning(
+ "Expected logging context %s but found %s",
+ self.new_context,
+ context,
+ )
+
+ if self.current_context is not LoggingContext.sentinel:
+ if not self.current_context.alive:
+ logger.debug("Restoring dead context: %s", self.current_context)
+
+
+def nested_logging_context(suffix, parent_context=None):
+ """Creates a new logging context as a child of another.
+
+ The nested logging context will have a 'request' made up of the parent context's
+ request, plus the given suffix.
+
+ CPU/db usage stats will be added to the parent context's on exit.
+
+ Normal usage looks like:
+
+ with nested_logging_context(suffix):
+ # ... do stuff
+
+ Args:
+ suffix (str): suffix to add to the parent context's 'request'.
+ parent_context (LoggingContext|None): parent context. Will use the current context
+ if None.
+
+ Returns:
+ LoggingContext: new logging context.
+ """
+ if parent_context is None:
+ parent_context = LoggingContext.current_context()
+ return LoggingContext(
+ parent_context=parent_context, request=parent_context.request + "-" + suffix
+ )
+
+
+def preserve_fn(f):
+ """Function decorator which wraps the function with run_in_background"""
+
+ def g(*args, **kwargs):
+ return run_in_background(f, *args, **kwargs)
+
+ return g
+
+
+def run_in_background(f, *args, **kwargs):
+ """Calls a function, ensuring that the current context is restored after
+ return from the function, and that the sentinel context is set once the
+ deferred returned by the function completes.
+
+ Useful for wrapping functions that return a deferred or coroutine, which you don't
+ yield or await on (for instance because you want to pass it to
+ deferred.gatherResults()).
+
+ Note that if you completely discard the result, you should make sure that
+ `f` doesn't raise any deferred exceptions, otherwise a scary-looking
+ CRITICAL error about an unhandled error will be logged without much
+ indication about where it came from.
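+
+ A usage sketch (``do_work`` is a hypothetical function returning a
+ Deferred; this runs inside an ``@defer.inlineCallbacks`` function):
+
+ d = run_in_background(do_work, some_arg)
+ # ... do other work, then wait for the result:
+ result = yield make_deferred_yieldable(d)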
+ """
+ current = LoggingContext.current_context()
+ try:
+ res = f(*args, **kwargs)
+ except: # noqa: E722
+ # the assumption here is that the caller doesn't want to be disturbed
+ # by synchronous exceptions, so let's turn them into Failures.
+ return defer.fail()
+
+ if isinstance(res, types.CoroutineType):
+ res = defer.ensureDeferred(res)
+
+ if not isinstance(res, defer.Deferred):
+ return res
+
+ if res.called and not res.paused:
+ # The function should have maintained the logcontext, so we can
+ # optimise out the messing about
+ return res
+
+ # The function may have reset the context before returning, so
+ # we need to restore it now.
+ ctx = LoggingContext.set_current_context(current)
+
+ # The original context will be restored when the deferred
+ # completes, but there is nothing waiting for it, so it will
+ # get leaked into the reactor or some other function which
+ # wasn't expecting it. We therefore need to reset the context
+ # here.
+ #
+ # (If this feels asymmetric, consider it this way: we are
+ # effectively forking a new thread of execution. We are
+ # probably currently within a ``with LoggingContext()`` block,
+ # which is supposed to have a single entry and exit point. But
+ # by spawning off another deferred, we are effectively
+ # adding a new exit point.)
+ res.addBoth(_set_context_cb, ctx)
+ return res
+
+
+def make_deferred_yieldable(deferred):
+ """Given a deferred, make it follow the Synapse logcontext rules:
+
+ If the deferred has completed (or is not actually a Deferred), essentially
+ does nothing (just returns another completed deferred with the
+ result/failure).
+
+ If the deferred has not yet completed, resets the logcontext before
+ returning a deferred. Then, when the deferred completes, restores the
+ current logcontext before running callbacks/errbacks.
+
+ (This is more-or-less the opposite operation to run_in_background.)
+ """
+ if not isinstance(deferred, defer.Deferred):
+ return deferred
+
+ if deferred.called and not deferred.paused:
+ # it looks like this deferred is ready to run any callbacks we give it
+ # immediately. We may as well optimise out the logcontext faffery.
+ return deferred
+
+ # ok, we can't be sure that a yield won't block, so let's reset the
+ # logcontext, and add a callback to the deferred to restore it.
+ prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
+ deferred.addBoth(_set_context_cb, prev_context)
+ return deferred
+
+
+def _set_context_cb(result, context):
+ """A callback function which just sets the logging context"""
+ LoggingContext.set_current_context(context)
+ return result
+
+
+def defer_to_thread(reactor, f, *args, **kwargs):
+ """
+ Calls the function `f` using a thread from the reactor's default threadpool and
+ returns the result as a Deferred.
+
+ Creates a new logcontext for `f`, which is created as a child of the current
+ logcontext (so its CPU usage metrics will get attributed to the current
+ logcontext). `f` should preserve the logcontext it is given.
+
+ The result deferred follows the Synapse logcontext rules: you should `yield`
+ on it.
+
+ Args:
+ reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread
+ the Deferred will be invoked, and whose threadpool we should use for the
+ function.
+
+ Normally this will be hs.get_reactor().
+
+ f (callable): The function to call.
+
+ args: positional arguments to pass to f.
+
+ kwargs: keyword arguments to pass to f.
+
+ Returns:
+ Deferred: A Deferred which fires a callback with the result of `f`, or an
+ errback if `f` throws an exception.
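+
+ A usage sketch (``sync_fn`` is a hypothetical blocking function, called
+ from an ``@defer.inlineCallbacks`` function):
+
+ result = yield defer_to_thread(hs.get_reactor(), sync_fn, arg)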
+ """
+ return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs)
+
+
+def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
+ """
+ A wrapper for twisted.internet.threads.deferToThreadpool, which handles
+ logcontexts correctly.
+
+ Calls the function `f` using a thread from the given threadpool and returns
+ the result as a Deferred.
+
+ Creates a new logcontext for `f`, which is created as a child of the current
+ logcontext (so its CPU usage metrics will get attributed to the current
+ logcontext). `f` should preserve the logcontext it is given.
+
+ The result deferred follows the Synapse logcontext rules: you should `yield`
+ on it.
+
+ Args:
+ reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread
+ the Deferred will be invoked. Normally this will be hs.get_reactor().
+
+ threadpool (twisted.python.threadpool.ThreadPool): The threadpool to use for
+ running `f`. Normally this will be hs.get_reactor().getThreadPool().
+
+ f (callable): The function to call.
+
+ args: positional arguments to pass to f.
+
+ kwargs: keyword arguments to pass to f.
+
+ Returns:
+ Deferred: A Deferred which fires a callback with the result of `f`, or an
+ errback if `f` throws an exception.
+ """
+ logcontext = LoggingContext.current_context()
+
+ def g():
+ with LoggingContext(parent_context=logcontext):
+ return f(*args, **kwargs)
+
+ return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g))
diff --git a/synapse/logging/formatter.py b/synapse/logging/formatter.py
new file mode 100644
index 00000000..fbf570c7
--- /dev/null
+++ b/synapse/logging/formatter.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import traceback
+
+from six import StringIO
+
+
+class LogFormatter(logging.Formatter):
+ """Log formatter which gives more detail for exceptions
+
+ This is the same as the standard log formatter, except that when logging
+ exceptions [typically via log.foo("msg", exc_info=1)], it prints the
+ sequence that led up to the point at which the exception was caught.
+ (Normally only stack frames between the point the exception was raised and
+ where it was caught are logged).
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(LogFormatter, self).__init__(*args, **kwargs)
+
+ def formatException(self, ei):
+ sio = StringIO()
+ (typ, val, tb) = ei
+
+ # log the stack above the exception capture point if possible, but
+ # check that we actually have an f_back attribute to work around
+ # https://twistedmatrix.com/trac/ticket/9305
+
+ if tb and hasattr(tb.tb_frame, "f_back"):
+ sio.write("Capture point (most recent call last):\n")
+ traceback.print_stack(tb.tb_frame.f_back, None, sio)
+
+ traceback.print_exception(typ, val, tb, None, sio)
+ s = sio.getvalue()
+ sio.close()
+ if s[-1:] == "\n":
+ s = s[:-1]
+ return s
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
new file mode 100644
index 00000000..3da33d78
--- /dev/null
+++ b/synapse/logging/opentracing.py
@@ -0,0 +1,483 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# NOTE
+# This is a small wrapper around opentracing because opentracing is not currently
+# packaged downstream (specifically debian). Since opentracing instrumentation is
+# fairly invasive it was awkward to make it optional. As a result we opted to encapsulate
+# all opentracing state in these methods which effectively noop if opentracing is
+# not present. We should strongly consider encouraging the downstream distributors
+# to package opentracing and making opentracing a full dependency. In order to
+# facilitate this move, the methods work very similarly to opentracing's, and it
+# should only be a matter of a few regexes to move over to opentracing's access
+# patterns proper.
+
+"""
+============================
+Using OpenTracing in Synapse
+============================
+
+Python-specific tracing concepts are at https://opentracing.io/guides/python/.
+Note that Synapse wraps OpenTracing in a small module (this one) in order to make the
+OpenTracing dependency optional. That means that the access patterns are
+different to those demonstrated in the OpenTracing guides. However, it is
+still useful to know, especially if OpenTracing is included as a full dependency
+in the future or if you are modifying this module.
+
+
+OpenTracing is encapsulated so that
+no span objects from OpenTracing are exposed in Synapse's code. This allows
+OpenTracing to be easily disabled in Synapse, and thereby kept as
+an optional dependency. This does, however, limit the number of modifiable
+spans at any point in the code to one. From here on, references to
+``opentracing`` in the code snippets refer to Synapse's module.
+
+Tracing
+-------
+
+In Synapse it is not possible to start a non-active span. Spans can be started
+using the ``start_active_span`` method. This returns a scope (see
+OpenTracing docs) which is a context manager that needs to be entered and
+exited. This is usually done by using ``with``.
+
+.. code-block:: python
+
+ from synapse.logging.opentracing import start_active_span
+
+ with start_active_span("operation name"):
+ # Do something we want to trace
+
+Forgetting to enter or exit a scope will result in some mysterious and grievous log
+context errors.
+
+At any time when there is an active span, ``opentracing.set_tag`` can be used
+to set a tag on the current active span.
+
+Tracing functions
+-----------------
+
+Functions can be easily traced using decorators. There is a decorator for
+'normal' functions and one for functions which return deferreds. The name of
+the function becomes the operation name for the span.
+
+.. code-block:: python
+
+ from synapse.logging.opentracing import trace, trace_deferred
+
+ # Start a span using 'normal_function' as the operation name
+ @trace
+ def normal_function(*args, **kwargs):
+ # Does all kinds of cool and expected things
+ return something_usual_and_useful
+
+ # Start a span using 'deferred_function' as the operation name
+ @trace_deferred
+ @defer.inlineCallbacks
+ def deferred_function(*args, **kwargs):
+ # We start
+ yield we_wait
+ # we finish
+ defer.returnValue(something_usual_and_useful)
+
+Operation names can be explicitly set for functions by using
+``trace_using_operation_name`` and
+``trace_deferred_using_operation_name``.
+
+.. code-block:: python
+
+ from synapse.logging.opentracing import (
+ trace_using_operation_name,
+ trace_deferred_using_operation_name
+ )
+
+ @trace_using_operation_name("A *much* better operation name")
+ def normal_function(*args, **kwargs):
+ # Does all kinds of cool and expected things
+ return something_usual_and_useful
+
+ @trace_deferred_using_operation_name("Another exciting operation name!")
+ @defer.inlineCallbacks
+ def deferred_function(*args, **kwargs):
+ # We start
+ yield we_wait
+ # we finish
+ defer.returnValue(something_usual_and_useful)
+
+Contexts and carriers
+---------------------
+
+A selection of wrappers is provided for injecting contexts into, and
+extracting them from, carriers. Unfortunately OpenTracing's three context
+injection techniques are not adequate for injecting OpenTracing span contexts
+into Twisted's http headers, EDU contents and our database tables. Also note
+that the binary encoding format mandated by OpenTracing is not actually
+implemented by jaeger_client v4.0.0 - it will silently noop.
+Please refer to the end of ``logging/opentracing.py`` for the available
+injection and extraction methods.
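+
+As a hedged sketch (assuming an outgoing request, where ``headers`` is a
+``twisted.web.http_headers.Headers`` instance and ``destination`` is the
+remote server's name), injection looks like:
+
+.. code-block:: python
+
+ from synapse.logging.opentracing import inject_active_span_twisted_headers
+
+ # No-ops if opentracing is disabled or `destination` isn't whitelisted.
+ inject_active_span_twisted_headers(headers, destination)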
+
+Homeserver whitelisting
+-----------------------
+
+Most of the whitelist checks are encapsulated in the module's injection
+and extraction methods, but be aware that using custom carriers or crossing
+uncharted waters will require you to enforce the whitelist yourself.
+``logging/opentracing.py`` has a ``whitelisted_homeserver`` method which takes
+in a destination and compares it to the whitelist.
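+
+For illustration (``destination`` is assumed to be a remote server name):
+
+.. code-block:: python
+
+ from synapse.logging.opentracing import whitelisted_homeserver
+
+ if whitelisted_homeserver(destination):
+ # safe to propagate the span context to this destination
+ ...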
+
+=======
+Gotchas
+=======
+
+- Checking whitelists on span propagation
+- Inserting PII
+- Forgetting to enter or exit a scope
+- Span source: make sure that the span you expect to be active across a
+ function call really will be that one. Does the current function have more
+ than one caller? Will all of those calling functions be in a context
+ with an active span?
+"""
+
+import contextlib
+import logging
+import re
+from functools import wraps
+
+from twisted.internet import defer
+
+from synapse.config import ConfigError
+
+try:
+ import opentracing
+except ImportError:
+ opentracing = None
+try:
+ from jaeger_client import Config as JaegerConfig
+ from synapse.logging.scopecontextmanager import LogContextScopeManager
+except ImportError:
+ JaegerConfig = None
+ LogContextScopeManager = None
+
+
+logger = logging.getLogger(__name__)
+
+
+class _DumTagNames(object):
+ """wrapper of opentracings tags. We need to have them if we
+ want to reference them without opentracing around. Clearly they
+ should never actually show up in a trace. `set_tags` overwrites
+ these with the correct ones."""
+
+ INVALID_TAG = "invalid-tag"
+ COMPONENT = INVALID_TAG
+ DATABASE_INSTANCE = INVALID_TAG
+ DATABASE_STATEMENT = INVALID_TAG
+ DATABASE_TYPE = INVALID_TAG
+ DATABASE_USER = INVALID_TAG
+ ERROR = INVALID_TAG
+ HTTP_METHOD = INVALID_TAG
+ HTTP_STATUS_CODE = INVALID_TAG
+ HTTP_URL = INVALID_TAG
+ MESSAGE_BUS_DESTINATION = INVALID_TAG
+ PEER_ADDRESS = INVALID_TAG
+ PEER_HOSTNAME = INVALID_TAG
+ PEER_HOST_IPV4 = INVALID_TAG
+ PEER_HOST_IPV6 = INVALID_TAG
+ PEER_PORT = INVALID_TAG
+ PEER_SERVICE = INVALID_TAG
+ SAMPLING_PRIORITY = INVALID_TAG
+ SERVICE = INVALID_TAG
+ SPAN_KIND = INVALID_TAG
+ SPAN_KIND_CONSUMER = INVALID_TAG
+ SPAN_KIND_PRODUCER = INVALID_TAG
+ SPAN_KIND_RPC_CLIENT = INVALID_TAG
+ SPAN_KIND_RPC_SERVER = INVALID_TAG
+
+
+def only_if_tracing(func):
+ """Executes the function only if we're tracing. Otherwise return.
+ Assumes the function wrapped may return None"""
+
+ @wraps(func)
+ def _only_if_tracing_inner(*args, **kwargs):
+ if opentracing:
+ return func(*args, **kwargs)
+ else:
+ return
+
+ return _only_if_tracing_inner
+
+
+# A regex which matches the server_names to expose traces for.
+# None means 'block everything'.
+_homeserver_whitelist = None
+
+tags = _DumTagNames
+
+
+def init_tracer(config):
+ """Set the whitelists and initialise the JaegerClient tracer
+
+ Args:
+ config (HomeserverConfig): The config used by the homeserver
+ """
+ global opentracing
+ if not config.opentracer_enabled:
+ # We don't have a tracer
+ opentracing = None
+ return
+
+ if not opentracing or not JaegerConfig:
+ raise ConfigError(
+ "The server has been configured to use opentracing but opentracing is not "
+ "installed."
+ )
+
+ # Include the worker name
+ name = config.worker_name if config.worker_name else "master"
+
+ set_homeserver_whitelist(config.opentracer_whitelist)
+ jaeger_config = JaegerConfig(
+ config={"sampler": {"type": "const", "param": 1}, "logging": True},
+ service_name="{} {}".format(config.server_name, name),
+ scope_manager=LogContextScopeManager(config),
+ )
+ jaeger_config.initialize_tracer()
+
+ # Set up tags to be opentracing's tags
+ global tags
+ tags = opentracing.tags
+
+
+@contextlib.contextmanager
+def _noop_context_manager(*args, **kwargs):
+ """Does absolutely nothing really well. Can be entered and exited arbitrarily.
+ Good substitute for an opentracing scope."""
+ yield
+
+
+# Could use kwargs but I want these to be explicit
+def start_active_span(
+ operation_name,
+ child_of=None,
+ references=None,
+ tags=None,
+ start_time=None,
+ ignore_active_span=False,
+ finish_on_close=True,
+):
+ """Starts an active opentracing span. Note, the scope doesn't become active
+ until it has been entered, however, the span starts from the time this
+ message is called.
+ Args:
+ See opentracing.tracer
+ Returns:
+ scope (Scope) or noop_context_manager
+ """
+ if opentracing is None:
+ return _noop_context_manager()
+ else:
+ # We need to enter the scope here for the logcontext to become active
+ return opentracing.tracer.start_active_span(
+ operation_name,
+ child_of=child_of,
+ references=references,
+ tags=tags,
+ start_time=start_time,
+ ignore_active_span=ignore_active_span,
+ finish_on_close=finish_on_close,
+ )
+
+
+@only_if_tracing
+def close_active_span():
+ """Closes the active span. This will close it's logcontext if the context
+ was made for the span"""
+ opentracing.tracer.scope_manager.active.__exit__(None, None, None)
+
+
+@only_if_tracing
+def set_tag(key, value):
+ """Set's a tag on the active span"""
+ opentracing.tracer.active_span.set_tag(key, value)
+
+
+@only_if_tracing
+def log_kv(key_values, timestamp=None):
+ """Log to the active span"""
+ opentracing.tracer.active_span.log_kv(key_values, timestamp)
+
+
+# Note: we don't have a get_baggage_item because we're trying to hide all
+# scope and span state from synapse. I think this method may also be useless
+# as a result.
+@only_if_tracing
+def set_baggage_item(key, value):
+ """Attach baggage to the active span"""
+ opentracing.tracer.active_span.set_baggage_item(key, value)
+
+
+@only_if_tracing
+def set_operation_name(operation_name):
+ """Sets the operation name of the active span"""
+ opentracing.tracer.active_span.set_operation_name(operation_name)
+
+
+@only_if_tracing
+def set_homeserver_whitelist(homeserver_whitelist):
+ """Sets the whitelist
+
+ Args:
+ homeserver_whitelist (iterable of strings): regex of whitelisted homeservers
+ """
+ global _homeserver_whitelist
+ if homeserver_whitelist:
+ # Makes a single regex which accepts all passed in regexes in the list
+ _homeserver_whitelist = re.compile(
+ "({})".format(")|(".join(homeserver_whitelist))
+ )
+
+
+@only_if_tracing
+def whitelisted_homeserver(destination):
+ """Checks if a destination matches the whitelist
+ Args:
+ destination (String)"""
+ if _homeserver_whitelist:
+ return _homeserver_whitelist.match(destination)
+ return False
+
+
+def start_active_span_from_context(
+ headers,
+ operation_name,
+ references=None,
+ tags=None,
+ start_time=None,
+ ignore_active_span=False,
+ finish_on_close=True,
+):
+ """
+ Extracts a span context from Twisted Headers.
+ args:
+ headers (twisted.web.http_headers.Headers)
+ returns:
+ span_context (opentracing.span.SpanContext)
+ """
+ # Twisted encodes the values as lists whereas opentracing doesn't.
+ # So, we take the first item in the list.
+ # Also, twisted uses byte arrays while opentracing expects strings.
+ if opentracing is None:
+ return _noop_context_manager()
+
+ header_dict = {k.decode(): v[0].decode() for k, v in headers.getAllRawHeaders()}
+ context = opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict)
+
+ return opentracing.tracer.start_active_span(
+ operation_name,
+ child_of=context,
+ references=references,
+ tags=tags,
+ start_time=start_time,
+ ignore_active_span=ignore_active_span,
+ finish_on_close=finish_on_close,
+ )
+
+
+@only_if_tracing
+def inject_active_span_twisted_headers(headers, destination):
+ """
+ Injects a span context into twisted headers in place
+
+ Args:
+ headers (twisted.web.http_headers.Headers)
+ destination (str): the name of the destination homeserver
+
+ Returns:
+ None; the headers are modified in place
+
+ Note:
+ The headers set by the tracer are custom to the tracer implementation which
+ should be unique enough that they don't interfere with any headers set by
+ synapse or twisted. If we're still using jaeger these headers would be those
+ here:
+ https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
+ """
+
+ if not whitelisted_homeserver(destination):
+ return
+
+ span = opentracing.tracer.active_span
+ carrier = {}
+ opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
+
+ for key, value in carrier.items():
+ headers.addRawHeaders(key, value)
+
+
+@only_if_tracing
+def inject_active_span_byte_dict(headers, destination):
+ """
+ Injects a span context into a dict where the headers are encoded as byte
+ strings
+
+ Args:
+ headers (dict)
+ destination (str): the name of the destination homeserver
+
+ Returns:
+ None; the headers are modified in place
+
+ Note:
+ The headers set by the tracer are custom to the tracer implementation which
+ should be unique enough that they don't interfere with any headers set by
+ synapse or twisted. If we're still using jaeger these headers would be those
+ here:
+ https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
+ """
+ if not whitelisted_homeserver(destination):
+ return
+
+ span = opentracing.tracer.active_span
+
+ carrier = {}
+ opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
+
+ for key, value in carrier.items():
+ headers[key.encode()] = [value.encode()]
+
+
+def trace_servlet(servlet_name, func):
+ """Decorator which traces a serlet. It starts a span with some servlet specific
+ tags such as the servlet_name and request information"""
+
+ @wraps(func)
+ @defer.inlineCallbacks
+ def _trace_servlet_inner(request, *args, **kwargs):
+ with start_active_span_from_context(
+ request.requestHeaders,
+ "incoming-client-request",
+ tags={
+ "request_id": request.get_request_id(),
+ tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
+ tags.HTTP_METHOD: request.get_method(),
+ tags.HTTP_URL: request.get_redacted_uri(),
+ tags.PEER_HOST_IPV6: request.getClientIP(),
+ "servlet_name": servlet_name,
+ },
+ ):
+ result = yield defer.maybeDeferred(func, request, *args, **kwargs)
+ defer.returnValue(result)
+
+ return _trace_servlet_inner
diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
new file mode 100644
index 00000000..8c661302
--- /dev/null
+++ b/synapse/logging/scopecontextmanager.py
@@ -0,0 +1,138 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from opentracing import Scope, ScopeManager
+
+import twisted
+
+from synapse.logging.context import LoggingContext, nested_logging_context
+
+logger = logging.getLogger(__name__)
+
+
+class LogContextScopeManager(ScopeManager):
+ """
+ The LogContextScopeManager tracks the active scope in opentracing
+ by using the log contexts which are native to synapse. This is so
+ that the basic opentracing api can be used across twisted deferreds.
+ (I would love to break logcontexts and this out into an open-source
+ package, but let's wait for twisted's contexts to be released.)
+ """
+
+ def __init__(self, config):
+ pass
+
+ @property
+ def active(self):
+ """
+ Returns the currently active Scope which can be used to access the
+ currently active Scope.span.
+ If there is a non-null Scope, its wrapped Span
+ becomes an implicit parent of any newly-created Span at
+ Tracer.start_active_span() time.
+
+ Returns:
+ (Scope): the Scope that is active, or None if not
+ available.
+ """
+ ctx = LoggingContext.current_context()
+ if ctx is LoggingContext.sentinel:
+ return None
+ else:
+ return ctx.scope
+
+ def activate(self, span, finish_on_close):
+ """
+ Makes a Span active.
+ Args:
+ span (Span): the span that should become active.
+ finish_on_close (Boolean): whether Span should be automatically
+ finished when Scope.close() is called.
+
+ Returns:
+ Scope to control the end of the active period for
+ *span*. It is a programming error to neglect to call
+ Scope.close() on the returned instance.
+ """
+
+ enter_logcontext = False
+ ctx = LoggingContext.current_context()
+
+ if ctx is LoggingContext.sentinel:
+ # There is no logcontext to attach this scope to.
+ logger.error("Tried to activate scope outside of loggingcontext")
+ return Scope(None, span)
+ elif ctx.scope is not None:
+ # We want the logging scope to look exactly the same so we give it
+ # a blank suffix
+ ctx = nested_logging_context("")
+ enter_logcontext = True
+
+ scope = _LogContextScope(self, span, ctx, enter_logcontext, finish_on_close)
+ ctx.scope = scope
+ return scope
+
+
+class _LogContextScope(Scope):
+ """
+ A custom opentracing scope. The only significant difference is that it will
+ close the log context it's related to if the logcontext was created specifically
+ for this scope.
+ """
+
+ def __init__(self, manager, span, logcontext, enter_logcontext, finish_on_close):
+ """
+ Args:
+ manager (LogContextScopeManager):
+ the manager that is responsible for this scope.
+ span (Span):
+ the opentracing span which this scope represents the local
+ lifetime for.
+ logcontext (LogContext):
+ the logcontext to which this scope is attached.
+ enter_logcontext (Boolean):
+ if True the logcontext will be entered and exited when the scope
+ is entered and exited respectively
+ finish_on_close (Boolean):
+ if True finish the span when the scope is closed
+ """
+ super(_LogContextScope, self).__init__(manager, span)
+ self.logcontext = logcontext
+ self._finish_on_close = finish_on_close
+ self._enter_logcontext = enter_logcontext
+
+ def __enter__(self):
+ if self._enter_logcontext:
+ self.logcontext.__enter__()
+
+ def __exit__(self, type, value, traceback):
+ if type == twisted.internet.defer._DefGen_Return:
+ super(_LogContextScope, self).__exit__(None, None, None)
+ else:
+ super(_LogContextScope, self).__exit__(type, value, traceback)
+ if self._enter_logcontext:
+ self.logcontext.__exit__(type, value, traceback)
+ else: # the logcontext existed before the creation of the scope
+ self.logcontext.scope = None
+
+ def close(self):
+ if self.manager.active is not self:
+ logger.error("Tried to close a none active scope!")
+ return
+
+ if self._finish_on_close:
+ self.span.finish()
diff --git a/synapse/util/logutils.py b/synapse/logging/utils.py
index 7df0fa60..7df0fa60 100644
--- a/synapse/util/logutils.py
+++ b/synapse/logging/utils.py
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index eaf0aaa8..488280b4 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -29,8 +29,16 @@ from prometheus_client.core import REGISTRY, GaugeMetricFamily, HistogramMetricF
from twisted.internet import reactor
+from synapse.metrics._exposition import (
+ MetricsResource,
+ generate_latest,
+ start_http_server,
+)
+
logger = logging.getLogger(__name__)
+METRICS_PREFIX = "/_synapse/metrics"
+
running_on_pypy = platform.python_implementation() == "PyPy"
all_metrics = []
all_collectors = []
@@ -470,3 +478,12 @@ try:
gc.disable()
except AttributeError:
pass
+
+__all__ = [
+ "MetricsResource",
+ "generate_latest",
+ "start_http_server",
+ "LaterGauge",
+ "InFlightGauge",
+ "BucketCollector",
+]
diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py
new file mode 100644
index 00000000..1933ecd3
--- /dev/null
+++ b/synapse/metrics/_exposition.py
@@ -0,0 +1,258 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015-2019 Prometheus Python Client Developers
+# Copyright 2019 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This code is based off `prometheus_client/exposition.py` from version 0.7.1.
+
+Due to the renaming of metrics in prometheus_client 0.4.0, this customised
+vendoring of the code will emit both the old versions that Synapse dashboards
+expect, and the newer "best practice" version of the up-to-date official client.
+"""
+
+import math
+import threading
+from collections import namedtuple
+from http.server import BaseHTTPRequestHandler, HTTPServer
+from socketserver import ThreadingMixIn
+from urllib.parse import parse_qs, urlparse
+
+from prometheus_client import REGISTRY
+
+from twisted.web.resource import Resource
+
+try:
+ from prometheus_client.samples import Sample
+except ImportError:
+ Sample = namedtuple("Sample", ["name", "labels", "value", "timestamp", "exemplar"])
+
+
+CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8")
+
+
+INF = float("inf")
+MINUS_INF = float("-inf")
+
+
+def floatToGoString(d):
+ d = float(d)
+ if d == INF:
+ return "+Inf"
+ elif d == MINUS_INF:
+ return "-Inf"
+ elif math.isnan(d):
+ return "NaN"
+ else:
+ s = repr(d)
+ dot = s.find(".")
+ # Go switches to exponents sooner than Python.
+ # We only need to care about positive values for le/quantile.
+ if d > 0 and dot > 6:
+ mantissa = "{0}.{1}{2}".format(s[0], s[1:dot], s[dot + 1 :]).rstrip("0.")
+ return "{0}e+0{1}".format(mantissa, dot - 1)
+ return s
+
+
+def sample_line(line, name):
+ if line.labels:
+ labelstr = "{{{0}}}".format(
+ ",".join(
+ [
+ '{0}="{1}"'.format(
+ k,
+ v.replace("\\", r"\\").replace("\n", r"\n").replace('"', r"\""),
+ )
+ for k, v in sorted(line.labels.items())
+ ]
+ )
+ )
+ else:
+ labelstr = ""
+ timestamp = ""
+ if line.timestamp is not None:
+ # Convert to milliseconds.
+ timestamp = " {0:d}".format(int(float(line.timestamp) * 1000))
+ return "{0}{1} {2}{3}\n".format(
+ name, labelstr, floatToGoString(line.value), timestamp
+ )
+
+
+def nameify_sample(sample):
+ """
+ If we get a prometheus_client<0.4.0 sample as a tuple, transform it into a
+ namedtuple which has the names we expect.
+ """
+ if not isinstance(sample, Sample):
+ sample = Sample(*sample, None, None)
+
+ return sample
+
+
+def generate_latest(registry, emit_help=False):
+ output = []
+
+ for metric in registry.collect():
+
+ if metric.name.startswith("__unused"):
+ continue
+
+ if not metric.samples:
+ # No samples, don't bother.
+ continue
+
+ mname = metric.name
+ mnewname = metric.name
+ mtype = metric.type
+
+ # OpenMetrics -> Prometheus
+ if mtype == "counter":
+ mnewname = mnewname + "_total"
+ elif mtype == "info":
+ mtype = "gauge"
+ mnewname = mnewname + "_info"
+ elif mtype == "stateset":
+ mtype = "gauge"
+ elif mtype == "gaugehistogram":
+ mtype = "histogram"
+ elif mtype == "unknown":
+ mtype = "untyped"
+
+ # Output in the old format for compatibility.
+ if emit_help:
+ output.append(
+ "# HELP {0} {1}\n".format(
+ mname,
+ metric.documentation.replace("\\", r"\\").replace("\n", r"\n"),
+ )
+ )
+ output.append("# TYPE {0} {1}\n".format(mname, mtype))
+ for sample in map(nameify_sample, metric.samples):
+ # Get rid of the OpenMetrics specific samples
+ for suffix in ["_created", "_gsum", "_gcount"]:
+ if sample.name.endswith(suffix):
+ break
+ else:
+ newname = sample.name.replace(mnewname, mname)
+ if ":" in newname and newname.endswith("_total"):
+ newname = newname[: -len("_total")]
+ output.append(sample_line(sample, newname))
+
+ # Get rid of the weird colon things while we're at it
+ if mtype == "counter":
+ mnewname = mnewname.replace(":total", "")
+ mnewname = mnewname.replace(":", "_")
+
+ if mname == mnewname:
+ continue
+
+ # Also output in the new format, if it's different.
+ if emit_help:
+ output.append(
+ "# HELP {0} {1}\n".format(
+ mnewname,
+ metric.documentation.replace("\\", r"\\").replace("\n", r"\n"),
+ )
+ )
+ output.append("# TYPE {0} {1}\n".format(mnewname, mtype))
+ for sample in map(nameify_sample, metric.samples):
+ # Get rid of the OpenMetrics specific samples
+ for suffix in ["_created", "_gsum", "_gcount"]:
+ if sample.name.endswith(suffix):
+ break
+ else:
+ output.append(
+ sample_line(
+ sample, sample.name.replace(":total", "").replace(":", "_")
+ )
+ )
+
+ return "".join(output).encode("utf-8")
+
+
+class MetricsHandler(BaseHTTPRequestHandler):
+ """HTTP handler that gives metrics from ``REGISTRY``."""
+
+ registry = REGISTRY
+
+ def do_GET(self):
+ registry = self.registry
+ params = parse_qs(urlparse(self.path).query)
+
+ if "help" in params:
+ emit_help = True
+ else:
+ emit_help = False
+
+ try:
+ output = generate_latest(registry, emit_help=emit_help)
+ except Exception:
+ self.send_error(500, "error generating metric output")
+ raise
+ self.send_response(200)
+ self.send_header("Content-Type", CONTENT_TYPE_LATEST)
+ self.end_headers()
+ self.wfile.write(output)
+
+ def log_message(self, format, *args):
+ """Log nothing."""
+
+ @classmethod
+ def factory(cls, registry):
+ """Returns a dynamic MetricsHandler class tied
+ to the passed registry.
+ """
+ # This implementation relies on MetricsHandler.registry
+ # (defined above and defaulted to REGISTRY).
+
+ # As we have unicode_literals, we need to create a str()
+ # object for type().
+ cls_name = str(cls.__name__)
+ MyMetricsHandler = type(cls_name, (cls, object), {"registry": registry})
+ return MyMetricsHandler
+
+
+class _ThreadingSimpleServer(ThreadingMixIn, HTTPServer):
+ """Thread per request HTTP server."""
+
+ # Make worker threads "fire and forget". Beginning with Python 3.7 this
+ # prevents a memory leak because ``ThreadingMixIn`` starts to gather all
+ # non-daemon threads in a list in order to join on them at server close.
+ # Enabling daemon threads virtually makes ``_ThreadingSimpleServer`` the
+ # same as Python 3.7's ``ThreadingHTTPServer``.
+ daemon_threads = True
+
+
+def start_http_server(port, addr="", registry=REGISTRY):
+ """Starts an HTTP server for prometheus metrics as a daemon thread"""
+ CustomMetricsHandler = MetricsHandler.factory(registry)
+ httpd = _ThreadingSimpleServer((addr, port), CustomMetricsHandler)
+ t = threading.Thread(target=httpd.serve_forever)
+ t.daemon = True
+ t.start()
+
+
+class MetricsResource(Resource):
+ """
+ Twisted ``Resource`` that serves prometheus metrics.
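+
+ A wiring sketch (``root`` is an assumed parent twisted Resource):
+
+ root.putChild(b"metrics", MetricsResource())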
+ """
+
+ isLeaf = True
+
+ def __init__(self, registry=REGISTRY):
+ self.registry = registry
+
+ def render_GET(self, request):
+ request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii"))
+ return generate_latest(self.registry)
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 167e2c06..edd6b42d 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -22,7 +22,7 @@ from prometheus_client.core import REGISTRY, Counter, GaugeMetricFamily
from twisted.internet import defer
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import LoggingContext, PreserveLoggingContext
logger = logging.getLogger(__name__)
diff --git a/synapse/metrics/resource.py b/synapse/metrics/resource.py
deleted file mode 100644
index 97893590..00000000
--- a/synapse/metrics/resource.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from prometheus_client.twisted import MetricsResource
-
-METRICS_PREFIX = "/_synapse/metrics"
-
-__all__ = ["MetricsResource", "METRICS_PREFIX"]
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index bf43ca09..7bb020cb 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -12,10 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
from synapse.types import UserID
+logger = logging.getLogger(__name__)
+
class ModuleApi(object):
"""A proxy object that gets passed to password auth providers so they
@@ -76,8 +80,31 @@ class ModuleApi(object):
@defer.inlineCallbacks
def register(self, localpart, displayname=None, emails=[]):
- """Registers a new user with given localpart and optional
- displayname, emails.
+ """Registers a new user with given localpart and optional displayname, emails.
+
+ Also returns an access token for the new user.
+
+ Deprecated: avoid this, as it generates a new device with no way to
+ return that device to the user. Prefer separate calls to register_user and
+ register_device.
+
+ Args:
+ localpart (str): The localpart of the new user.
+ displayname (str|None): The displayname of the new user.
+ emails (List[str]): Emails to bind to the new user.
+
+ Returns:
+ Deferred[tuple[str, str]]: a 2-tuple of (user_id, access_token)
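+
+ A sketch of the preferred replacement pattern (the localpart is
+ illustrative):
+
+ user_id = yield api.register_user("alice")
+ device_id, access_token = yield api.register_device(user_id)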
+ """
+ logger.warning(
+ "Using deprecated ModuleApi.register which creates a dummy user device."
+ )
+ user_id = yield self.register_user(localpart, displayname, emails)
+ _, access_token = yield self.register_device(user_id)
+ defer.returnValue((user_id, access_token))
+
+ def register_user(self, localpart, displayname=None, emails=[]):
+ """Registers a new user with given localpart and optional displayname, emails.
Args:
localpart (str): The localpart of the new user.
@@ -85,15 +112,30 @@ class ModuleApi(object):
emails (List[str]): Emails to bind to the new user.
Returns:
- Deferred: a 2-tuple of (user_id, access_token)
+ Deferred[str]: user_id
"""
- # Register the user
- reg = self.hs.get_registration_handler()
- user_id, access_token = yield reg.register(
+ return self.hs.get_registration_handler().register_user(
localpart=localpart, default_display_name=displayname, bind_emails=emails
)
- defer.returnValue((user_id, access_token))
+ def register_device(self, user_id, device_id=None, initial_display_name=None):
+ """Register a device for a user and generate an access token.
+
+ Args:
+ user_id (str): full canonical @user:id
+ device_id (str|None): The device ID to check, or None to generate
+ a new one.
+ initial_display_name (str|None): An optional display name for the
+ device.
+
+ Returns:
+ defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
+ """
+ return self.hs.get_registration_handler().register_device(
+ user_id=user_id,
+ device_id=device_id,
+ initial_display_name=initial_display_name,
+ )
@defer.inlineCallbacks
def invalidate_access_token(self, access_token):
diff --git a/synapse/notifier.py b/synapse/notifier.py
index d398078e..918ef648 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -23,12 +23,12 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError
from synapse.handlers.presence import format_user_presence_state
+from synapse.logging.context import PreserveLoggingContext
+from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import StreamToken
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
-from synapse.util.logcontext import PreserveLoggingContext
-from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
from synapse.visibility import filter_events_for_client
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 96d087de..134bf805 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -1,5 +1,6 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -248,6 +249,18 @@ BASE_APPEND_OVERRIDE_RULES = [
],
"actions": ["notify", {"set_tweak": "highlight", "value": True}],
},
+ {
+ "rule_id": "global/override/.m.rule.reaction",
+ "conditions": [
+ {
+ "kind": "event_match",
+ "key": "type",
+ "pattern": "m.reaction",
+ "_id": "_reaction",
+ }
+ ],
+ "actions": ["dont_notify"],
+ },
]
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 809199fe..521c6e2c 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -29,6 +29,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
+from synapse.logging.context import make_deferred_yieldable
from synapse.push.presentable_names import (
calculate_room_name,
descriptor_from_member_events,
@@ -36,7 +37,6 @@ from synapse.push.presentable_names import (
)
from synapse.types import UserID
from synapse.util.async_helpers import concurrently_execute
-from synapse.util.logcontext import make_deferred_yieldable
from synapse.visibility import filter_events_for_client
logger = logging.getLogger(__name__)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 6324c00e..c6465c03 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -65,9 +65,7 @@ REQUIREMENTS = [
"msgpack>=0.5.2",
"phonenumbers>=8.2.0",
"six>=1.10",
- # prometheus_client 0.4.0 changed the format of counter metrics
- # (cf https://github.com/matrix-org/synapse/issues/4001)
- "prometheus_client>=0.0.18,<0.4.0",
+ "prometheus_client>=0.0.18,<0.8.0",
# we use attr.s(slots), which arrived in 16.0.0
# Twisted 18.7.0 requires attrs>=17.4.0
"attrs>=17.4.0",
@@ -95,6 +93,7 @@ CONDITIONAL_REQUIREMENTS = {
"url_preview": ["lxml>=3.5.0"],
"test": ["mock>=2.0", "parameterized"],
"sentry": ["sentry-sdk>=0.7.2"],
+ "opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"],
"jwt": ["pyjwt>=1.6.4"],
}
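
Because jaeger-client and opentracing are optional extras, code that uses them has to degrade gracefully when they are absent. A minimal sketch of the guarded-import pattern (a hypothetical helper, not the literal code in synapse.logging.opentracing):

    try:
        import opentracing
    except ImportError:
        opentracing = None

    def trace(operation_name):
        # Returns a live span when the optional dependency is installed,
        # otherwise None so callers can no-op.
        if opentracing is None:
            return None
        return opentracing.global_tracer().start_span(operation_name)
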
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index fe482e27..43c89e36 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -205,7 +205,7 @@ class ReplicationEndpoint(object):
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
- http_server.register_paths(method, [pattern], handler)
+ http_server.register_paths(method, [pattern], handler, self.__class__.__name__)
def _cached_handler(self, request, txn_id, **kwargs):
"""Called on new incoming requests when caching is enabled. Checks
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 0a76a376..2d9cbbae 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -156,70 +156,6 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
defer.returnValue((200, ret))
-class ReplicationRegister3PIDGuestRestServlet(ReplicationEndpoint):
- """Gets/creates a guest account for given 3PID.
-
- Request format:
-
- POST /_synapse/replication/get_or_register_3pid_guest/
-
- {
- "requester": ...,
- "medium": ...,
- "address": ...,
- "inviter_user_id": ...
- }
- """
-
- NAME = "get_or_register_3pid_guest"
- PATH_ARGS = ()
-
- def __init__(self, hs):
- super(ReplicationRegister3PIDGuestRestServlet, self).__init__(hs)
-
- self.registeration_handler = hs.get_registration_handler()
- self.store = hs.get_datastore()
- self.clock = hs.get_clock()
-
- @staticmethod
- def _serialize_payload(requester, medium, address, inviter_user_id):
- """
- Args:
- requester(Requester)
- medium (str)
- address (str)
- inviter_user_id (str): The user ID who is trying to invite the
- 3PID
- """
- return {
- "requester": requester.serialize(),
- "medium": medium,
- "address": address,
- "inviter_user_id": inviter_user_id,
- }
-
- @defer.inlineCallbacks
- def _handle_request(self, request):
- content = parse_json_object_from_request(request)
-
- medium = content["medium"]
- address = content["address"]
- inviter_user_id = content["inviter_user_id"]
-
- requester = Requester.deserialize(self.store, content["requester"])
-
- if requester.user:
- request.authenticated_entity = requester.user.to_string()
-
- logger.info("get_or_register_3pid_guest: %r", content)
-
- ret = yield self.registeration_handler.get_or_register_3pid_guest(
- medium, address, inviter_user_id
- )
-
- defer.returnValue((200, ret))
-
-
class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
"""Notifies that a user has joined or left the room
@@ -272,5 +208,4 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
def register_servlets(hs, http_server):
ReplicationRemoteJoinRestServlet(hs).register(http_server)
ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
- ReplicationRegister3PIDGuestRestServlet(hs).register(http_server)
ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index f81a0f1b..2bf21738 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -38,7 +38,6 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
@staticmethod
def _serialize_payload(
user_id,
- token,
password_hash,
was_guest,
make_guest,
@@ -51,9 +50,6 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
"""
Args:
user_id (str): The desired user ID to register.
- token (str): The desired access token to use for this user. If this
- is not None, the given access token is associated with the user
- id.
password_hash (str|None): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
@@ -68,7 +64,6 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
address (str|None): the IP address used to perform the registration.
"""
return {
- "token": token,
"password_hash": password_hash,
"was_guest": was_guest,
"make_guest": make_guest,
@@ -85,7 +80,6 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
yield self.registration_handler.register_with_store(
user_id=user_id,
- token=content["token"],
password_hash=content["password_hash"],
was_guest=content["was_guest"],
make_guest=content["make_guest"],
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 97efb835..5ffdf267 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -62,9 +62,9 @@ from twisted.internet import defer
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.stringutils import random_string
from .commands import (
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 9843a902..6888ae55 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -219,11 +219,10 @@ class UserRegisterServlet(RestServlet):
register = RegisterRestServlet(self.hs)
- (user_id, _) = yield register.registration_handler.register(
+ user_id = yield register.registration_handler.register_user(
localpart=body["username"].lower(),
password=body["password"],
admin=bool(admin),
- generate_token=False,
user_type=user_type,
)
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index ee66838a..d9c71261 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -59,9 +59,14 @@ class SendServerNoticeServlet(RestServlet):
def register(self, json_resource):
PATTERN = "^/_synapse/admin/v1/send_server_notice"
- json_resource.register_paths("POST", (re.compile(PATTERN + "$"),), self.on_POST)
json_resource.register_paths(
- "PUT", (re.compile(PATTERN + "/(?P<txn_id>[^/]*)$"),), self.on_PUT
+ "POST", (re.compile(PATTERN + "$"),), self.on_POST, self.__class__.__name__
+ )
+ json_resource.register_paths(
+ "PUT",
+ (re.compile(PATTERN + "/(?P<txn_id>[^/]*)$"),),
+ self.on_PUT,
+ self.__class__.__name__,
)
@defer.inlineCallbacks
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 36404b79..6da71dc4 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -17,8 +17,8 @@
to ensure idempotency when performing PUTs using the REST API."""
import logging
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index dd0d38ea..57542c2b 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -18,7 +18,13 @@ import logging
from twisted.internet import defer
-from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ InvalidClientCredentialsError,
+ NotFoundError,
+ SynapseError,
+)
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import RoomAlias
@@ -97,7 +103,7 @@ class ClientDirectoryServer(RestServlet):
room_alias.to_string(),
)
defer.returnValue((200, {}))
- except AuthError:
+ except InvalidClientCredentialsError:
# fallback to default user behaviour if they aren't an AS
pass
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index f9611782..0d05945f 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -283,19 +283,7 @@ class LoginRestServlet(RestServlet):
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
- device_id = login_submission.get("device_id")
- initial_display_name = login_submission.get("initial_device_display_name")
- device_id, access_token = yield self.registration_handler.register_device(
- user_id, device_id, initial_display_name
- )
-
- result = {
- "user_id": user_id, # may have changed
- "access_token": access_token,
- "home_server": self.hs.hostname,
- "device_id": device_id,
- }
-
+ result = yield self._register_device_with_callback(user_id, login_submission)
defer.returnValue(result)
@defer.inlineCallbacks
@@ -323,35 +311,16 @@ class LoginRestServlet(RestServlet):
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
user_id = UserID(user, self.hs.hostname).to_string()
- device_id = login_submission.get("device_id")
- initial_display_name = login_submission.get("initial_device_display_name")
-
- auth_handler = self.auth_handler
- registered_user_id = yield auth_handler.check_user_exists(user_id)
- if registered_user_id:
- device_id, access_token = yield self.registration_handler.register_device(
- registered_user_id, device_id, initial_display_name
- )
- result = {
- "user_id": registered_user_id,
- "access_token": access_token,
- "home_server": self.hs.hostname,
- }
- else:
- user_id, access_token = (
- yield self.registration_handler.register(localpart=user)
- )
- device_id, access_token = yield self.registration_handler.register_device(
- user_id, device_id, initial_display_name
+ registered_user_id = yield self.auth_handler.check_user_exists(user_id)
+ if not registered_user_id:
+ registered_user_id = yield self.registration_handler.register_user(
+ localpart=user
)
- result = {
- "user_id": user_id, # may have changed
- "access_token": access_token,
- "home_server": self.hs.hostname,
- }
-
+ result = yield self._register_device_with_callback(
+ registered_user_id, login_submission
+ )
defer.returnValue(result)
@@ -534,12 +503,8 @@ class SSOAuthHandler(object):
user_id = UserID(localpart, self._hostname).to_string()
registered_user_id = yield self._auth_handler.check_user_exists(user_id)
if not registered_user_id:
- registered_user_id, _ = (
- yield self._registration_handler.register(
- localpart=localpart,
- generate_token=False,
- default_display_name=user_display_name,
- )
+ registered_user_id = yield self._registration_handler.register_user(
+ localpart=localpart, default_display_name=user_display_name
)
login_token = self._macaroon_gen.generate_short_term_login_token(
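
Both login paths now funnel through a shared device-registration helper. A rough reconstruction of what such a helper does, pieced together from the code removed above (assumes twisted.internet.defer; not the literal implementation):

    @defer.inlineCallbacks
    def _register_device_with_callback(self, user_id, login_submission):
        device_id = login_submission.get("device_id")
        initial_display_name = login_submission.get("initial_device_display_name")
        device_id, access_token = yield self.registration_handler.register_device(
            user_id, device_id, initial_display_name
        )
        defer.returnValue(
            {
                "user_id": user_id,  # may have changed during registration
                "access_token": access_token,
                "home_server": self.hs.hostname,
                "device_id": device_id,
            }
        )
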
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index cca7e45d..6276e97f 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -24,7 +24,12 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ InvalidClientCredentialsError,
+ SynapseError,
+)
from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2
from synapse.http.servlet import (
@@ -62,11 +67,17 @@ class RoomCreateRestServlet(TransactionRestServlet):
register_txn_path(self, PATTERNS, http_server)
# define CORS for all of /rooms in RoomCreateRestServlet for simplicity
http_server.register_paths(
- "OPTIONS", client_patterns("/rooms(?:/.*)?$", v1=True), self.on_OPTIONS
+ "OPTIONS",
+ client_patterns("/rooms(?:/.*)?$", v1=True),
+ self.on_OPTIONS,
+ self.__class__.__name__,
)
# define CORS for /createRoom[/txnid]
http_server.register_paths(
- "OPTIONS", client_patterns("/createRoom(?:/.*)?$", v1=True), self.on_OPTIONS
+ "OPTIONS",
+ client_patterns("/createRoom(?:/.*)?$", v1=True),
+ self.on_OPTIONS,
+ self.__class__.__name__,
)
def on_PUT(self, request, txn_id):
@@ -111,16 +122,28 @@ class RoomStateEventRestServlet(TransactionRestServlet):
)
http_server.register_paths(
- "GET", client_patterns(state_key, v1=True), self.on_GET
+ "GET",
+ client_patterns(state_key, v1=True),
+ self.on_GET,
+ self.__class__.__name__,
)
http_server.register_paths(
- "PUT", client_patterns(state_key, v1=True), self.on_PUT
+ "PUT",
+ client_patterns(state_key, v1=True),
+ self.on_PUT,
+ self.__class__.__name__,
)
http_server.register_paths(
- "GET", client_patterns(no_state_key, v1=True), self.on_GET_no_state_key
+ "GET",
+ client_patterns(no_state_key, v1=True),
+ self.on_GET_no_state_key,
+ self.__class__.__name__,
)
http_server.register_paths(
- "PUT", client_patterns(no_state_key, v1=True), self.on_PUT_no_state_key
+ "PUT",
+ client_patterns(no_state_key, v1=True),
+ self.on_PUT_no_state_key,
+ self.__class__.__name__,
)
def on_GET_no_state_key(self, request, room_id, event_type):
@@ -307,7 +330,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
try:
yield self.auth.get_user_by_req(request, allow_guest=True)
- except AuthError as e:
+ except InvalidClientCredentialsError as e:
# Option to allow servers to require auth when accessing
# /publicRooms via CS API. This is especially helpful in private
# federations.
@@ -840,18 +863,23 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
with_get: True to also register respective GET paths for the PUTs.
"""
http_server.register_paths(
- "POST", client_patterns(regex_string + "$", v1=True), servlet.on_POST
+ "POST",
+ client_patterns(regex_string + "$", v1=True),
+ servlet.on_POST,
+ servlet.__class__.__name__,
)
http_server.register_paths(
"PUT",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
servlet.on_PUT,
+ servlet.__class__.__name__,
)
if with_get:
http_server.register_paths(
"GET",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
servlet.on_GET,
+ servlet.__class__.__name__,
)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 5c120e4d..f327999e 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -464,11 +464,10 @@ class RegisterRestServlet(RestServlet):
Codes.THREEPID_IN_USE,
)
- (registered_user_id, _) = yield self.registration_handler.register(
+ registered_user_id = yield self.registration_handler.register_user(
localpart=desired_username,
password=new_password,
guest_access_token=guest_access_token,
- generate_token=False,
threepid=threepid,
address=client_addr,
)
@@ -542,8 +541,8 @@ class RegisterRestServlet(RestServlet):
if not compare_digest(want_mac, got_mac):
raise SynapseError(403, "HMAC incorrect")
- (user_id, _) = yield self.registration_handler.register(
- localpart=username, password=password, generate_token=False
+ user_id = yield self.registration_handler.register_user(
+ localpart=username, password=password
)
result = yield self._create_registration_details(user_id, body)
@@ -577,8 +576,8 @@ class RegisterRestServlet(RestServlet):
def _do_guest_registration(self, params, address=None):
if not self.hs.config.allow_guest_access:
raise SynapseError(403, "Guest access is disabled")
- user_id, _ = yield self.registration_handler.register(
- generate_token=False, make_guest=True, address=address
+ user_id = yield self.registration_handler.register_user(
+ make_guest=True, address=address
)
# we don't allow guests to specify their own device_id, because
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index 8e362782..9e9a6390 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -34,6 +34,7 @@ from synapse.http.servlet import (
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.storage.relations import (
AggregationPaginationToken,
+ PaginationChunk,
RelationPaginationToken,
)
@@ -71,11 +72,13 @@ class RelationSendServlet(RestServlet):
"POST",
client_patterns(self.PATTERN + "$", releases=()),
self.on_PUT_or_POST,
+ self.__class__.__name__,
)
http_server.register_paths(
"PUT",
client_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()),
self.on_PUT,
+ self.__class__.__name__,
)
def on_PUT(self, request, *args, **kwargs):
@@ -145,38 +148,55 @@ class RelationPaginationServlet(RestServlet):
room_id, requester.user.to_string()
)
- # This checks that a) the event exists and b) the user is allowed to
- # view it.
- yield self.event_handler.get_event(requester.user, room_id, parent_id)
+ # This gets the original event and checks that a) the event exists and
+ # b) the user is allowed to view it.
+ event = yield self.event_handler.get_event(requester.user, room_id, parent_id)
limit = parse_integer(request, "limit", default=5)
from_token = parse_string(request, "from")
to_token = parse_string(request, "to")
- if from_token:
- from_token = RelationPaginationToken.from_string(from_token)
-
- if to_token:
- to_token = RelationPaginationToken.from_string(to_token)
-
- result = yield self.store.get_relations_for_event(
- event_id=parent_id,
- relation_type=relation_type,
- event_type=event_type,
- limit=limit,
- from_token=from_token,
- to_token=to_token,
- )
+ if event.internal_metadata.is_redacted():
+ # If the event is redacted, return an empty list of relations
+ pagination_chunk = PaginationChunk(chunk=[])
+ else:
+ # Return the relations
+ if from_token:
+ from_token = RelationPaginationToken.from_string(from_token)
+
+ if to_token:
+ to_token = RelationPaginationToken.from_string(to_token)
+
+ pagination_chunk = yield self.store.get_relations_for_event(
+ event_id=parent_id,
+ relation_type=relation_type,
+ event_type=event_type,
+ limit=limit,
+ from_token=from_token,
+ to_token=to_token,
+ )
events = yield self.store.get_events_as_list(
- [c["event_id"] for c in result.chunk]
+ [c["event_id"] for c in pagination_chunk.chunk]
)
now = self.clock.time_msec()
- events = yield self._event_serializer.serialize_events(events, now)
+ # We set bundle_aggregations to False when retrieving the original
+ # event because we want the content before relations were applied to
+ # it.
+ original_event = yield self._event_serializer.serialize_event(
+ event, now, bundle_aggregations=False
+ )
+ # Similarly, we don't allow relations to be applied to relations, so we
+ # return the original relations without any aggregations on top of them
+ # here.
+ events = yield self._event_serializer.serialize_events(
+ events, now, bundle_aggregations=False
+ )
- return_value = result.to_dict()
+ return_value = pagination_chunk.to_dict()
return_value["chunk"] = events
+ return_value["original_event"] = original_event
defer.returnValue((200, return_value))
@@ -222,7 +242,7 @@ class RelationAggregationPaginationServlet(RestServlet):
# This checks that a) the event exists and b) the user is allowed to
# view it.
- yield self.event_handler.get_event(requester.user, room_id, parent_id)
+ event = yield self.event_handler.get_event(requester.user, room_id, parent_id)
if relation_type not in (RelationTypes.ANNOTATION, None):
raise SynapseError(400, "Relation type must be 'annotation'")
@@ -231,21 +251,26 @@ class RelationAggregationPaginationServlet(RestServlet):
from_token = parse_string(request, "from")
to_token = parse_string(request, "to")
- if from_token:
- from_token = AggregationPaginationToken.from_string(from_token)
-
- if to_token:
- to_token = AggregationPaginationToken.from_string(to_token)
-
- res = yield self.store.get_aggregation_groups_for_event(
- event_id=parent_id,
- event_type=event_type,
- limit=limit,
- from_token=from_token,
- to_token=to_token,
- )
-
- defer.returnValue((200, res.to_dict()))
+ if event.internal_metadata.is_redacted():
+ # If the event is redacted, return an empty list of relations
+ pagination_chunk = PaginationChunk(chunk=[])
+ else:
+ # Return the relations
+ if from_token:
+ from_token = AggregationPaginationToken.from_string(from_token)
+
+ if to_token:
+ to_token = AggregationPaginationToken.from_string(to_token)
+
+ pagination_chunk = yield self.store.get_aggregation_groups_for_event(
+ event_id=parent_id,
+ event_type=event_type,
+ limit=limit,
+ from_token=from_token,
+ to_token=to_token,
+ )
+
+ defer.returnValue((200, pagination_chunk.to_dict()))
class RelationAggregationGroupPaginationServlet(RestServlet):
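
With this change, a redacted parent yields an empty chunk, and the parent event itself is echoed back without bundled aggregations. An illustrative sketch of the resulting relations-pagination response shape (placeholder values, not a captured API response):

    example_response = {
        "chunk": [
            # relation events, serialized with bundle_aggregations=False
            {"type": "m.reaction", "content": {}},
        ],
        "original_event": {
            # the parent event, serialized before relations were applied to it
            "type": "m.room.message",
            "content": {"body": "hello"},
        },
    }
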
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 3318638d..5fefee4d 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -25,7 +25,7 @@ from twisted.protocols.basic import FileSender
from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.http.server import finish_request, respond_with_json
-from synapse.util import logcontext
+from synapse.logging.context import make_deferred_yieldable
from synapse.util.stringutils import is_ascii
logger = logging.getLogger(__name__)
@@ -75,9 +75,7 @@ def respond_with_file(request, media_type, file_path, file_size=None, upload_nam
add_file_headers(request, media_type, file_size, upload_name)
with open(file_path, "rb") as f:
- yield logcontext.make_deferred_yieldable(
- FileSender().beginFileTransfer(f, request)
- )
+ yield make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
finish_request(request)
else:
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index df3d985a..65afffbb 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -33,8 +33,8 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
+from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util import logcontext
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import random_string
@@ -463,7 +463,7 @@ class MediaRepository(object):
)
thumbnailer = Thumbnailer(input_path)
- t_byte_source = yield logcontext.defer_to_thread(
+ t_byte_source = yield defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer,
@@ -511,7 +511,7 @@ class MediaRepository(object):
)
thumbnailer = Thumbnailer(input_path)
- t_byte_source = yield logcontext.defer_to_thread(
+ t_byte_source = yield defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer,
@@ -596,7 +596,7 @@ class MediaRepository(object):
return
if thumbnailer.transpose_method is not None:
- m_width, m_height = yield logcontext.defer_to_thread(
+ m_width, m_height = yield defer_to_thread(
self.hs.get_reactor(), thumbnailer.transpose
)
@@ -616,11 +616,11 @@ class MediaRepository(object):
for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
# Generate the thumbnail
if t_method == "crop":
- t_byte_source = yield logcontext.defer_to_thread(
+ t_byte_source = yield defer_to_thread(
self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type
)
elif t_method == "scale":
- t_byte_source = yield logcontext.defer_to_thread(
+ t_byte_source = yield defer_to_thread(
self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type
)
else:
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index eff86836..25e5ac28 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -24,9 +24,8 @@ import six
from twisted.internet import defer
from twisted.protocols.basic import FileSender
-from synapse.util import logcontext
+from synapse.logging.context import defer_to_thread, make_deferred_yieldable
from synapse.util.file_consumer import BackgroundFileConsumer
-from synapse.util.logcontext import make_deferred_yieldable
from ._base import Responder
@@ -65,7 +64,7 @@ class MediaStorage(object):
with self.store_into_file(file_info) as (f, fname, finish_cb):
# Write to the main repository
- yield logcontext.defer_to_thread(
+ yield defer_to_thread(
self.hs.get_reactor(), _write_file_synchronously, source, f
)
yield finish_cb()
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 053346fb..5871737b 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -42,11 +42,11 @@ from synapse.http.server import (
wrap_json_request_handler,
)
from synapse.http.servlet import parse_integer, parse_string
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.stringutils import random_string
from ._base import FileInfo
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 359b45eb..37687ea7 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -20,8 +20,7 @@ import shutil
from twisted.internet import defer
from synapse.config._base import Config
-from synapse.util import logcontext
-from synapse.util.logcontext import run_in_background
+from synapse.logging.context import defer_to_thread, run_in_background
from .media_storage import FileResponder
@@ -68,7 +67,7 @@ class StorageProviderWrapper(StorageProvider):
backend (StorageProvider)
store_local (bool): Whether to store new local files or not.
store_synchronous (bool): Whether to wait for file to be successfully
- uploaded, or todo the upload in the backgroud.
+ uploaded, or to do the upload in the background.
store_remote (bool): Whether remote media should be uploaded
"""
@@ -125,7 +124,7 @@ class FileStorageProviderBackend(StorageProvider):
if not os.path.exists(dirname):
os.makedirs(dirname)
- return logcontext.defer_to_thread(
+ return defer_to_thread(
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
)
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 1b454a56..9f708fa2 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -28,11 +28,11 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
from synapse.events.snapshot import EventContext
+from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 29589853..2f940dba 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -30,12 +30,12 @@ from prometheus_client import Histogram
from twisted.internet import defer
from synapse.api.errors import StoreError
+from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import get_domain_from_id
from synapse.util import batch_iter
from synapse.util.caches.descriptors import Cache
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.stringutils import exception_to_unicode
# import a function which will return a monotonic time, in seconds
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 86f84857..b486ca50 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -33,6 +33,8 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.logging.utils import log_function
from synapse.metrics import BucketCollector
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateResolutionStore
@@ -45,8 +47,6 @@ from synapse.util import batch_iter
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.frozenutils import frozendict_json_encoder
-from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
-from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 6d680d40..06379281 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -29,14 +29,15 @@ from synapse.api.room_versions import EventFormatVersions
from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
from synapse.events.utils import prune_event
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import get_domain_from_id
-from synapse.util.logcontext import (
+from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
make_deferred_yieldable,
run_in_background,
)
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import get_domain_from_id
+from synapse.util import batch_iter
from synapse.util.metrics import Measure
from ._base import SQLBaseStore
@@ -218,9 +219,116 @@ class EventsWorkerStore(SQLBaseStore):
if not event_ids:
defer.returnValue([])
- event_id_list = event_ids
- event_ids = set(event_ids)
+ # there may be duplicates so we cast the list to a set
+ event_entry_map = yield self._get_events_from_cache_or_db(
+ set(event_ids), allow_rejected=allow_rejected
+ )
+ events = []
+ for event_id in event_ids:
+ entry = event_entry_map.get(event_id, None)
+ if not entry:
+ continue
+
+ if not allow_rejected:
+ assert not entry.event.rejected_reason, (
+ "rejected event returned from _get_events_from_cache_or_db despite "
+ "allow_rejected=False"
+ )
+
+ # We may not have had the original event when we received a redaction, so
+ # we have to recheck auth now.
+
+ if not allow_rejected and entry.event.type == EventTypes.Redaction:
+ redacted_event_id = entry.event.redacts
+ event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
+ original_event_entry = event_map.get(redacted_event_id)
+ if not original_event_entry:
+ # we don't have the redacted event (or it was rejected).
+ #
+ # We assume that the redaction isn't authorized for now; if the
+ # redacted event later turns up, the redaction will be re-checked,
+ # and if it is found valid, the original will get redacted before it
+ # is served to the client.
+ logger.debug(
+ "Withholding redaction event %s since we don't (yet) have the "
+ "original %s",
+ event_id,
+ redacted_event_id,
+ )
+ continue
+
+ original_event = original_event_entry.event
+ if original_event.type == EventTypes.Create:
+ # we never serve redactions of Creates to clients.
+ logger.info(
+ "Withholding redaction %s of create event %s",
+ event_id,
+ redacted_event_id,
+ )
+ continue
+
+ if original_event.room_id != entry.event.room_id:
+ logger.info(
+ "Withholding redaction %s of event %s from a different room",
+ event_id,
+ redacted_event_id,
+ )
+ continue
+
+ if entry.event.internal_metadata.need_to_check_redaction():
+ original_domain = get_domain_from_id(original_event.sender)
+ redaction_domain = get_domain_from_id(entry.event.sender)
+ if original_domain != redaction_domain:
+ # the senders don't match, so this is forbidden
+ logger.info(
+ "Withholding redaction %s whose sender domain %s doesn't "
+ "match that of redacted event %s %s",
+ event_id,
+ redaction_domain,
+ redacted_event_id,
+ original_domain,
+ )
+ continue
+
+ # Update the cache to save doing the checks again.
+ entry.event.internal_metadata.recheck_redaction = False
+
+ if check_redacted and entry.redacted_event:
+ event = entry.redacted_event
+ else:
+ event = entry.event
+
+ events.append(event)
+
+ if get_prev_content:
+ if "replaces_state" in event.unsigned:
+ prev = yield self.get_event(
+ event.unsigned["replaces_state"],
+ get_prev_content=False,
+ allow_none=True,
+ )
+ if prev:
+ event.unsigned = dict(event.unsigned)
+ event.unsigned["prev_content"] = prev.content
+ event.unsigned["prev_sender"] = prev.sender
+
+ defer.returnValue(events)
+
+ @defer.inlineCallbacks
+ def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
+ """Fetch a bunch of events from the cache or the database.
+
+ If events are pulled from the database, they will be cached for future lookups.
+
+ Args:
+ event_ids (Iterable[str]): The event_ids of the events to fetch
+ allow_rejected (bool): Whether to include rejected events
+
+ Returns:
+ Deferred[Dict[str, _EventCacheEntry]]:
+ map from event id to result
+ """
event_entry_map = self._get_events_from_cache(
event_ids, allow_rejected=allow_rejected
)
@@ -243,81 +351,7 @@ class EventsWorkerStore(SQLBaseStore):
event_entry_map.update(missing_events)
- events = []
- for event_id in event_id_list:
- entry = event_entry_map.get(event_id, None)
- if not entry:
- continue
-
- # Starting in room version v3, some redactions need to be rechecked if we
- # didn't have the redacted event at the time, so we recheck on read
- # instead.
- if not allow_rejected and entry.event.type == EventTypes.Redaction:
- if entry.event.internal_metadata.need_to_check_redaction():
- # XXX: we need to avoid calling get_event here.
- #
- # The problem is that we end up at this point when an event
- # which has been redacted is pulled out of the database by
- # _enqueue_events, because _enqueue_events needs to check
- # the redaction before it can cache the redacted event. So
- # obviously, calling get_event to get the redacted event out
- # of the database gives us an infinite loop.
- #
- # For now (quick hack to fix during 0.99 release cycle), we
- # just go and fetch the relevant row from the db, but it
- # would be nice to think about how we can cache this rather
- # than hit the db every time we access a redaction event.
- #
- # One thought on how to do this:
- # 1. split get_events_as_list up so that it is divided into
- # (a) get the rawish event from the db/cache, (b) do the
- # redaction/rejection filtering
- # 2. have _get_event_from_row just call the first half of
- # that
-
- orig_sender = yield self._simple_select_one_onecol(
- table="events",
- keyvalues={"event_id": entry.event.redacts},
- retcol="sender",
- allow_none=True,
- )
-
- expected_domain = get_domain_from_id(entry.event.sender)
- if (
- orig_sender
- and get_domain_from_id(orig_sender) == expected_domain
- ):
- # This redaction event is allowed. Mark as not needing a
- # recheck.
- entry.event.internal_metadata.recheck_redaction = False
- else:
- # We don't have the event that is being redacted, so we
- # assume that the event isn't authorized for now. (If we
- # later receive the event, then we will always redact
- # it anyway, since we have this redaction)
- continue
-
- if allow_rejected or not entry.event.rejected_reason:
- if check_redacted and entry.redacted_event:
- event = entry.redacted_event
- else:
- event = entry.event
-
- events.append(event)
-
- if get_prev_content:
- if "replaces_state" in event.unsigned:
- prev = yield self.get_event(
- event.unsigned["replaces_state"],
- get_prev_content=False,
- allow_none=True,
- )
- if prev:
- event.unsigned = dict(event.unsigned)
- event.unsigned["prev_content"] = prev.content
- event.unsigned["prev_sender"] = prev.sender
-
- defer.returnValue(events)
+ return event_entry_map
def _invalidate_get_event_cache(self, event_id):
self._get_event_cache.invalidate((event_id,))
@@ -326,8 +360,8 @@ class EventsWorkerStore(SQLBaseStore):
"""Fetch events from the caches
Args:
- events (list(str)): list of event_ids to fetch
- allow_rejected (bool): Whether to teturn events that were rejected
+ events (Iterable[str]): list of event_ids to fetch
+ allow_rejected (bool): Whether to return events that were rejected
update_metrics (bool): Whether to update the cache hit ratio metrics
Returns:
@@ -384,19 +418,16 @@ class EventsWorkerStore(SQLBaseStore):
The fetch requests. Each entry consists of a list of event
ids to be fetched, and a deferred to be completed once the
events have been fetched.
-
"""
with Measure(self._clock, "_fetch_event_list"):
try:
event_id_lists = list(zip(*event_list))[0]
event_ids = [item for sublist in event_id_lists for item in sublist]
- rows = self._new_transaction(
+ row_dict = self._new_transaction(
conn, "do_fetch", [], [], self._fetch_event_rows, event_ids
)
- row_dict = {r["event_id"]: r for r in rows}
-
# We only want to resolve deferreds from the main thread
def fire(lst, res):
for ids, d in lst:
@@ -454,7 +485,7 @@ class EventsWorkerStore(SQLBaseStore):
logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
if not allow_rejected:
- rows[:] = [r for r in rows if not r["rejects"]]
+ rows[:] = [r for r in rows if r["rejected_reason"] is None]
res = yield make_deferred_yieldable(
defer.gatherResults(
@@ -463,8 +494,8 @@ class EventsWorkerStore(SQLBaseStore):
self._get_event_from_row,
row["internal_metadata"],
row["json"],
- row["redacts"],
- rejected_reason=row["rejects"],
+ row["redactions"],
+ rejected_reason=row["rejected_reason"],
format_version=row["format_version"],
)
for row in rows
@@ -475,49 +506,98 @@ class EventsWorkerStore(SQLBaseStore):
defer.returnValue({e.event.event_id: e for e in res if e})
- def _fetch_event_rows(self, txn, events):
- rows = []
- N = 200
- for i in range(1 + len(events) // N):
- evs = events[i * N : (i + 1) * N]
- if not evs:
- break
+ def _fetch_event_rows(self, txn, event_ids):
+ """Fetch event rows from the database
+
+ Events which are not found are omitted from the result.
+
+ The returned per-event dicts contain the following keys:
+
+ * event_id (str)
+
+ * json (str): json-encoded event structure
+
+ * internal_metadata (str): json-encoded internal metadata dict
+
+ * format_version (int|None): The format of the event. Hopefully one
+ of EventFormatVersions. 'None' means the event predates
+ EventFormatVersions (so the event is format V1).
+
+ * rejected_reason (str|None): if the event was rejected, the reason
+ why.
+ * redactions (List[str]): a list of event-ids which (claim to) redact
+ this event.
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection):
+ event_ids (Iterable[str]): event IDs to fetch
+
+ Returns:
+ Dict[str, Dict]: a map from event id to event info.
+ """
+ event_dict = {}
+ for evs in batch_iter(event_ids, 200):
sql = (
"SELECT "
- " e.event_id as event_id, "
+ " e.event_id, "
" e.internal_metadata,"
" e.json,"
" e.format_version, "
- " r.redacts as redacts,"
- " rej.event_id as rejects "
+ " rej.reason "
" FROM event_json as e"
" LEFT JOIN rejections as rej USING (event_id)"
- " LEFT JOIN redactions as r ON e.event_id = r.redacts"
" WHERE e.event_id IN (%s)"
) % (",".join(["?"] * len(evs)),)
txn.execute(sql, evs)
- rows.extend(self.cursor_to_dict(txn))
- return rows
+ for row in txn:
+ event_id = row[0]
+ event_dict[event_id] = {
+ "event_id": event_id,
+ "internal_metadata": row[1],
+ "json": row[2],
+ "format_version": row[3],
+ "rejected_reason": row[4],
+ "redactions": [],
+ }
+
+ # check for redactions
+ redactions_sql = (
+ "SELECT event_id, redacts FROM redactions WHERE redacts IN (%s)"
+ ) % (",".join(["?"] * len(evs)),)
+
+ txn.execute(redactions_sql, evs)
+
+ for (redacter, redacted) in txn:
+ d = event_dict.get(redacted)
+ if d:
+ d["redactions"].append(redacter)
+
+ return event_dict
@defer.inlineCallbacks
def _get_event_from_row(
- self, internal_metadata, js, redacted, format_version, rejected_reason=None
+ self, internal_metadata, js, redactions, format_version, rejected_reason=None
):
+ """Parse an event row which has been read from the database
+
+ Args:
+ internal_metadata (str): json-encoded internal_metadata column
+ js (str): json-encoded event body from event_json
+ redactions (list[str]): a list of the events which claim to have redacted
+ this event, from the redactions table
+ format_version (int|None): the 'format_version' column
+ rejected_reason (str|None): the reason this event was rejected, if any
+
+ Returns:
+ _EventCacheEntry
+ """
with Measure(self._clock, "_get_event_from_row"):
d = json.loads(js)
internal_metadata = json.loads(internal_metadata)
- if rejected_reason:
- rejected_reason = yield self._simple_select_one_onecol(
- table="rejections",
- keyvalues={"event_id": rejected_reason},
- retcol="reason",
- desc="_get_event_from_row_rejected_reason",
- )
-
if format_version is None:
# This means that we stored the event before we had the concept
# of a event format version, so it must be a V1 event.
@@ -529,41 +609,7 @@ class EventsWorkerStore(SQLBaseStore):
rejected_reason=rejected_reason,
)
- redacted_event = None
- if redacted:
- redacted_event = prune_event(original_ev)
-
- redaction_id = yield self._simple_select_one_onecol(
- table="redactions",
- keyvalues={"redacts": redacted_event.event_id},
- retcol="event_id",
- desc="_get_event_from_row_redactions",
- )
-
- redacted_event.unsigned["redacted_by"] = redaction_id
- # Get the redaction event.
-
- because = yield self.get_event(
- redaction_id, check_redacted=False, allow_none=True
- )
-
- if because:
- # It's fine to do add the event directly, since get_pdu_json
- # will serialise this field correctly
- redacted_event.unsigned["redacted_because"] = because
-
- # Starting in room version v3, some redactions need to be
- # rechecked if we didn't have the redacted event at the
- # time, so we recheck on read instead.
- if because.internal_metadata.need_to_check_redaction():
- expected_domain = get_domain_from_id(original_ev.sender)
- if get_domain_from_id(because.sender) == expected_domain:
- # This redaction event is allowed. Mark as not needing a
- # recheck.
- because.internal_metadata.recheck_redaction = False
- else:
- # Senders don't match, so the event isn't actually redacted
- redacted_event = None
+ redacted_event = yield self._maybe_redact_event_row(original_ev, redactions)
cache_entry = _EventCacheEntry(
event=original_ev, redacted_event=redacted_event
@@ -574,6 +620,83 @@ class EventsWorkerStore(SQLBaseStore):
defer.returnValue(cache_entry)
@defer.inlineCallbacks
+ def _maybe_redact_event_row(self, original_ev, redactions):
+ """Given an event object and a list of possible redacting event ids,
+ determine whether to honour any of those redactions and if so return a redacted
+ event.
+
+ Args:
+ original_ev (EventBase):
+ redactions (iterable[str]): list of event ids of potential redaction events
+
+ Returns:
+ Deferred[EventBase|None]: if the event should be redacted, a pruned
+ event object. Otherwise, None.
+ """
+ if original_ev.type == "m.room.create":
+ # we choose to ignore redactions of m.room.create events.
+ return None
+
+ if original_ev.type == "m.room.redaction":
+ # ... and redaction events
+ return None
+
+ redaction_map = yield self._get_events_from_cache_or_db(redactions)
+
+ for redaction_id in redactions:
+ redaction_entry = redaction_map.get(redaction_id)
+ if not redaction_entry:
+ # we don't have the redaction event, or the redaction event was not
+ # authorized.
+ logger.debug(
+ "%s was redacted by %s but redaction not found/authed",
+ original_ev.event_id,
+ redaction_id,
+ )
+ continue
+
+ redaction_event = redaction_entry.event
+ if redaction_event.room_id != original_ev.room_id:
+ logger.debug(
+ "%s was redacted by %s but redaction was in a different room!",
+ original_ev.event_id,
+ redaction_id,
+ )
+ continue
+
+ # Starting in room version v3, some redactions need to be
+ # rechecked if we didn't have the redacted event at the
+ # time, so we recheck on read instead.
+ if redaction_event.internal_metadata.need_to_check_redaction():
+ expected_domain = get_domain_from_id(original_ev.sender)
+ if get_domain_from_id(redaction_event.sender) == expected_domain:
+ # This redaction event is allowed. Mark as not needing a recheck.
+ redaction_event.internal_metadata.recheck_redaction = False
+ else:
+ # Senders don't match, so the event isn't actually redacted
+ logger.debug(
+ "%s was redacted by %s but the senders don't match",
+ original_ev.event_id,
+ redaction_id,
+ )
+ continue
+
+ logger.debug("Redacting %s due to %s", original_ev.event_id, redaction_id)
+
+ # we found a good redaction event. Redact!
+ redacted_event = prune_event(original_ev)
+ redacted_event.unsigned["redacted_by"] = redaction_id
+
+ # It's fine to add the event directly, since get_pdu_json
+ # will serialise this field correctly
+ redacted_event.unsigned["redacted_because"] = redaction_event
+
+ return redacted_event
+
+ # no valid redaction found for this event
+ return None
+
+ @defer.inlineCallbacks
def have_events_in_timeline(self, event_ids):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
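
The rewritten _fetch_event_rows batches its queries via synapse.util.batch_iter. A minimal sketch of what that helper does (equivalent behaviour, not the actual source):

    from itertools import islice

    def batch_iter(iterable, size):
        # Yield successive tuples of at most `size` items from `iterable`.
        it = iter(iterable)
        batch = tuple(islice(it, size))
        while batch:
            yield batch
            batch = tuple(islice(it, size))
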
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 13a3d520..8b2c2a97 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -90,7 +90,8 @@ class RegistrationWorkerStore(SQLBaseStore):
token (str): The access token of a user.
Returns:
defer.Deferred: None, if the token did not match, otherwise dict
- including the keys `name`, `is_guest`, `device_id`, `token_id`.
+ including the keys `name`, `is_guest`, `device_id`, `token_id`,
+ `valid_until_ms`.
"""
return self.runInteraction(
"get_user_by_access_token", self._query_for_auth, token
@@ -284,7 +285,7 @@ class RegistrationWorkerStore(SQLBaseStore):
def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
- " access_tokens.device_id"
+ " access_tokens.device_id, access_tokens.valid_until_ms"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"
@@ -433,19 +434,6 @@ class RegistrationWorkerStore(SQLBaseStore):
)
@defer.inlineCallbacks
- def get_3pid_guest_access_token(self, medium, address):
- ret = yield self._simple_select_one(
- "threepid_guest_access_tokens",
- {"medium": medium, "address": address},
- ["guest_access_token"],
- True,
- "get_3pid_guest_access_token",
- )
- if ret:
- defer.returnValue(ret["guest_access_token"])
- defer.returnValue(None)
-
- @defer.inlineCallbacks
def get_user_id_by_threepid(self, medium, address, require_verified=False):
"""Returns user id from threepid
@@ -616,7 +604,7 @@ class RegistrationStore(
)
self.register_background_update_handler(
- "users_set_deactivated_flag", self._backgroud_update_set_deactivated_flag
+ "users_set_deactivated_flag", self._background_update_set_deactivated_flag
)
# Create a background job for culling expired 3PID validity tokens
@@ -631,14 +619,14 @@ class RegistrationStore(
hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
@defer.inlineCallbacks
- def _backgroud_update_set_deactivated_flag(self, progress, batch_size):
+ def _background_update_set_deactivated_flag(self, progress, batch_size):
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
for each of them.
"""
last_user = progress.get("user_id", "")
- def _backgroud_update_set_deactivated_flag_txn(txn):
+ def _background_update_set_deactivated_flag_txn(txn):
txn.execute(
"""
SELECT
@@ -683,7 +671,7 @@ class RegistrationStore(
return False
end = yield self.runInteraction(
- "users_set_deactivated_flag", _backgroud_update_set_deactivated_flag_txn
+ "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
)
if end:
@@ -692,14 +680,16 @@ class RegistrationStore(
defer.returnValue(batch_size)
@defer.inlineCallbacks
- def add_access_token_to_user(self, user_id, token, device_id=None):
+ def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
"""Adds an access token for the given user.
Args:
user_id (str): The user ID.
token (str): The new access token to add.
device_id (str): ID of the device to associate with the access
- token
+ token
+ valid_until_ms (int|None): when the token is valid until. None for
+ no expiry.
Raises:
StoreError if there was a problem adding this.
"""
@@ -707,14 +697,19 @@ class RegistrationStore(
yield self._simple_insert(
"access_tokens",
- {"id": next_id, "user_id": user_id, "token": token, "device_id": device_id},
+ {
+ "id": next_id,
+ "user_id": user_id,
+ "token": token,
+ "device_id": device_id,
+ "valid_until_ms": valid_until_ms,
+ },
desc="add_access_token_to_user",
)
- def register(
+ def register_user(
self,
user_id,
- token=None,
password_hash=None,
was_guest=False,
make_guest=False,
@@ -727,9 +722,6 @@ class RegistrationStore(
Args:
user_id (str): The desired user ID to register.
- token (str): The desired access token to use for this user. If this
- is not None, the given access token is associated with the user
- id.
password_hash (str): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
@@ -746,10 +738,9 @@ class RegistrationStore(
StoreError if the user_id could not be registered.
"""
return self.runInteraction(
- "register",
- self._register,
+ "register_user",
+ self._register_user,
user_id,
- token,
password_hash,
was_guest,
make_guest,
@@ -759,11 +750,10 @@ class RegistrationStore(
user_type,
)
- def _register(
+ def _register_user(
self,
txn,
user_id,
- token,
password_hash,
was_guest,
make_guest,
@@ -776,8 +766,6 @@ class RegistrationStore(
now = int(self.clock.time())
- next_id = self._access_tokens_id_gen.get_next()
-
try:
if was_guest:
# Ensure that the guest user actually exists
@@ -825,14 +813,6 @@ class RegistrationStore(
if self._account_validity.enabled:
self.set_expiration_date_for_user_txn(txn, user_id)
- if token:
- # it's possible for this to get a conflict, but only for a single user
- # since tokens are namespaced based on their user ID
- txn.execute(
- "INSERT INTO access_tokens(id, user_id, token)" " VALUES (?,?,?)",
- (next_id, user_id, token),
- )
-
if create_profile_with_displayname:
# set a default displayname serverside to avoid ugly race
# between auto-joins and clients trying to set displaynames
@@ -979,40 +959,6 @@ class RegistrationStore(
defer.returnValue(res if res else False)
- @defer.inlineCallbacks
- def save_or_get_3pid_guest_access_token(
- self, medium, address, access_token, inviter_user_id
- ):
- """
- Gets the 3pid's guest access token if exists, else saves access_token.
-
- Args:
- medium (str): Medium of the 3pid. Must be "email".
- address (str): 3pid address.
- access_token (str): The access token to persist if none is
- already persisted.
- inviter_user_id (str): User ID of the inviter.
-
- Returns:
- deferred str: Whichever access token is persisted at the end
- of this function call.
- """
-
- def insert(txn):
- txn.execute(
- "INSERT INTO threepid_guest_access_tokens "
- "(medium, address, guest_access_token, first_inviter) "
- "VALUES (?, ?, ?, ?)",
- (medium, address, access_token, inviter_user_id),
- )
-
- try:
- yield self.runInteraction("save_3pid_guest_access_token", insert)
- defer.returnValue(access_token)
- except self.database_engine.module.IntegrityError:
- ret = yield self.get_3pid_guest_access_token(medium, address)
- defer.returnValue(ret)
-
def add_user_pending_deactivation(self, user_id):
"""
Adds a user to the table of users who need to be parted from all the rooms they're
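
add_access_token_to_user now takes a mandatory valid_until_ms argument (None for no expiry). A hedged sketch of the new call shape, wrapped in a hypothetical helper; the store/clock parameters and values are illustrative:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def issue_access_token(store, clock, user_id, token, device_id, lifetime_ms=None):
        # None means the token never expires; otherwise compute an absolute expiry.
        valid_until_ms = None
        if lifetime_ms is not None:
            valid_until_ms = clock.time_msec() + lifetime_ms
        yield store.add_access_token_to_user(
            user_id=user_id,
            token=token,
            device_id=device_id,
            valid_until_ms=valid_until_ms,
        )
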
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index 1b01934c..9954bc09 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -60,7 +60,7 @@ class PaginationChunk(object):
class RelationPaginationToken(object):
"""Pagination token for relation pagination API.
- As the results are order by topological ordering, we can use the
+ As the results are in topological order, we can use the
`topological_ordering` and `stream_ordering` fields of the events at the
boundaries of the chunk as pagination tokens.
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 8004aeb9..32cfd010 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -575,6 +575,26 @@ class RoomMemberWorkerStore(EventsWorkerStore):
count = yield self.runInteraction("did_forget_membership", f)
defer.returnValue(count == 0)
+ @defer.inlineCallbacks
+ def get_rooms_user_has_been_in(self, user_id):
+ """Get all rooms that the user has ever been in.
+
+ Args:
+ user_id (str)
+
+ Returns:
+ Deferred[set[str]]: Set of room IDs.
+ """
+
+ room_ids = yield self._simple_select_onecol(
+ table="room_memberships",
+ keyvalues={"membership": Membership.JOIN, "user_id": user_id},
+ retcol="room_id",
+ desc="get_rooms_user_has_been_in",
+ )
+
+ return set(room_ids)
+
class RoomMemberStore(RoomMemberWorkerStore):
def __init__(self, db_conn, hs):
diff --git a/synapse/storage/schema/delta/55/access_token_expiry.sql b/synapse/storage/schema/delta/55/access_token_expiry.sql
new file mode 100644
index 00000000..4590604b
--- /dev/null
+++ b/synapse/storage/schema/delta/55/access_token_expiry.sql
@@ -0,0 +1,18 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- when this access token can be used until, in ms since the epoch. NULL means the token
+-- never expires.
+ALTER TABLE access_tokens ADD COLUMN valid_until_ms BIGINT;
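
A token is then considered expired once the current time passes valid_until_ms, with NULL (None) meaning no expiry. A one-line sketch of that check (hypothetical helper name):

    def token_expired(valid_until_ms, now_ms):
        # NULL (None) means the token never expires.
        return valid_until_ms is not None and now_ms >= valid_until_ms
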
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index d9482a38..a0465484 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -41,12 +41,12 @@ from six.moves import range
from twisted.internet import defer
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
@@ -833,7 +833,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
Deferred[tuple[list[_EventDictReturn], str]]: Returns the results
as a list of _EventDictReturn and a token that points to the end
- of the result set.
+ of the result set. If no events are returned then the end of the
+ stream has been reached (i.e. there are no events between
+ `from_token` and `to_token`), or `limit` is zero.
"""
assert int(limit) >= 0
@@ -905,15 +907,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
only those before
direction(char): Either 'b' or 'f' to indicate whether we are
paginating forwards or backwards from `from_key`.
- limit (int): The maximum number of events to return. Zero or less
- means no limit.
+ limit (int): The maximum number of events to return.
event_filter (Filter|None): If provided filters the events to
those that match the filter.
Returns:
- tuple[list[dict], str]: Returns the results as a list of dicts and
- a token that points to the end of the result set. The dicts have
- the keys "event_id", "topological_ordering" and "stream_orderign".
+ tuple[list[FrozenEvent], str]: Returns the results as a list of
+ events and a token that points to the end of the result set. If no
+ events are returned then the end of the stream has been reached
+ (i.e. there are no events between `from_key` and `to_key`).
"""
from_key = RoomStreamToken.parse(from_key)
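The clarified docstrings make the termination condition explicit: an empty chunk now unambiguously means the end of the stream (or a zero limit), so callers can page until nothing comes back. A hedged sketch of that loop (paginate_room_events is the method these docstrings belong to; the loop body is illustrative):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def fetch_all_events(store, room_id, from_key, to_key):
        events = []
        while True:
            # An empty chunk means there are no more events between the
            # keys, so the pagination loop can stop.
            chunk, next_token = yield store.paginate_room_events(
                room_id, from_key, to_key, direction="b", limit=100
            )
            if not chunk:
                break
            events.extend(chunk)
            from_key = next_token
        defer.returnValue(events)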
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index b1188f6b..fd186191 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -133,34 +133,6 @@ class TransactionStore(SQLBaseStore):
desc="set_received_txn_response",
)
- def prep_send_transaction(self, transaction_id, destination, origin_server_ts):
- """Persists an outgoing transaction and calculates the values for the
- previous transaction id list.
-
- This should be called before sending the transaction so that it has the
- correct value for the `prev_ids` key.
-
- Args:
- transaction_id (str)
- destination (str)
- origin_server_ts (int)
-
- Returns:
- list: A list of previous transaction ids.
- """
- return defer.succeed([])
-
- def delivered_txn(self, transaction_id, destination, code, response_dict):
- """Persists the response for an outgoing transaction.
-
- Args:
- transaction_id (str)
- destination (str)
- code (int)
- response_json (str)
- """
- pass
-
@defer.inlineCallbacks
def get_destination_retry_timings(self, destination):
"""Gets the current retry timings (if any) for a given destination.
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 954e32fb..f506b2a6 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -21,7 +21,7 @@ import attr
from twisted.internet import defer, task
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.logging import context
logger = logging.getLogger(__name__)
@@ -46,7 +46,7 @@ class Clock(object):
@defer.inlineCallbacks
def sleep(self, seconds):
d = defer.Deferred()
- with PreserveLoggingContext():
+ with context.PreserveLoggingContext():
self._reactor.callLater(seconds, d.callback, seconds)
res = yield d
defer.returnValue(res)
@@ -91,10 +91,10 @@ class Clock(object):
"""
def wrapped_callback(*args, **kwargs):
- with PreserveLoggingContext():
+ with context.PreserveLoggingContext():
callback(*args, **kwargs)
- with PreserveLoggingContext():
+ with context.PreserveLoggingContext():
return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs)
def cancel_call_later(self, timer, ignore_errs=False):
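The Clock changes are a pure module move, but the pattern is worth spelling out: reactor callbacks fire in the sentinel logcontext, so both scheduling the call and running the callback are wrapped. A minimal sketch mirroring the code above (assumes a running Twisted reactor):

    from twisted.internet import reactor
    from synapse.logging.context import PreserveLoggingContext

    def schedule(callback, delay):
        def wrapped(*args, **kwargs):
            # Run the callback in the sentinel context so it does not
            # inherit whatever request context happens to be active.
            with PreserveLoggingContext():
                callback(*args, **kwargs)

        # Reset to the sentinel context before handing off to the
        # reactor, per the synapse logcontext rules.
        with PreserveLoggingContext():
            return reactor.callLater(delay, wrapped)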
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 7757b870..58a6b876 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -23,13 +23,12 @@ from twisted.internet import defer
from twisted.internet.defer import CancelledError
from twisted.python import failure
-from synapse.util import Clock, logcontext, unwrapFirstError
-
-from .logcontext import (
+from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
run_in_background,
)
+from synapse.util import Clock, unwrapFirstError
logger = logging.getLogger(__name__)
@@ -153,7 +152,7 @@ def concurrently_execute(func, args, limit):
except StopIteration:
pass
- return logcontext.make_deferred_yieldable(
+ return make_deferred_yieldable(
defer.gatherResults(
[run_in_background(_concurrently_execute_inner) for _ in range(limit)],
consumeErrors=True,
@@ -174,7 +173,7 @@ def yieldable_gather_results(func, iter, *args, **kwargs):
Deferred[list]: Resolved when all functions have been invoked, or errors if
one of the function calls fails.
"""
- return logcontext.make_deferred_yieldable(
+ return make_deferred_yieldable(
defer.gatherResults(
[run_in_background(func, item, *args, **kwargs) for item in iter],
consumeErrors=True,
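A usage sketch for yieldable_gather_results after the import cleanup (handler.get_profile is a hypothetical per-item function):

    from twisted.internet import defer
    from synapse.util.async_helpers import yieldable_gather_results

    @defer.inlineCallbacks
    def fetch_profiles(handler, user_ids):
        # Runs handler.get_profile once per user via run_in_background
        # and gathers the results; a failure in any call fails the batch.
        results = yield yieldable_gather_results(handler.get_profile, user_ids)
        defer.returnValue(dict(zip(user_ids, results)))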
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index d2f25063..675db2f4 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -24,7 +24,8 @@ from six import itervalues, string_types
from twisted.internet import defer
-from synapse.util import logcontext, unwrapFirstError
+from synapse.logging.context import make_deferred_yieldable, preserve_fn
+from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.lrucache import LruCache
@@ -388,7 +389,7 @@ class CacheDescriptor(_CacheDescriptorBase):
except KeyError:
ret = defer.maybeDeferred(
- logcontext.preserve_fn(self.function_to_call), obj, *args, **kwargs
+ preserve_fn(self.function_to_call), obj, *args, **kwargs
)
def onErr(f):
@@ -408,7 +409,7 @@ class CacheDescriptor(_CacheDescriptorBase):
observer = result_d.observe()
if isinstance(observer, defer.Deferred):
- return logcontext.make_deferred_yieldable(observer)
+ return make_deferred_yieldable(observer)
else:
return observer
@@ -563,7 +564,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
cached_defers.append(
defer.maybeDeferred(
- logcontext.preserve_fn(self.function_to_call), **args_to_call
+ preserve_fn(self.function_to_call), **args_to_call
).addCallbacks(complete_all, errback)
)
@@ -571,7 +572,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
lambda _: results, unwrapFirstError
)
- return logcontext.make_deferred_yieldable(d)
+ return make_deferred_yieldable(d)
else:
return results
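preserve_fn, now imported from synapse.logging.context, is just the decorator form of run_in_background. A minimal sketch of the cache-miss pattern these hunks touch, with the ObservableDeferred machinery stripped out:

    from twisted.internet import defer
    from synapse.logging.context import make_deferred_yieldable, preserve_fn

    def call_following_logcontext_rules(fn, *args, **kwargs):
        # preserve_fn wraps fn with run_in_background, so synchronous
        # exceptions become Failures and the caller's context is
        # restored after fn returns.
        d = defer.maybeDeferred(preserve_fn(fn), *args, **kwargs)
        # make_deferred_yieldable resets the context until d completes,
        # making it safe to yield on from an inlineCallbacks function.
        return make_deferred_yieldable(d)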
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index cbe54d45..d6908e16 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -16,9 +16,9 @@ import logging
from twisted.internet import defer
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import register_cache
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
@@ -78,7 +78,7 @@ class ResponseCache(object):
*deferred* should run its callbacks in the sentinel logcontext (ie,
you should wrap normal synapse deferreds with
- logcontext.run_in_background).
+ synapse.logging.context.run_in_background).
Can return either a new Deferred (which also doesn't follow the synapse
logcontext rules), or, if *deferred* was already complete, the actual
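For most callers the higher-level wrap helper handles these sentinel-context requirements itself. A sketch under that assumption (_do_expensive_lookup is hypothetical):

    def on_request(cache, room_id, hs):
        # wrap() looks up room_id in the cache and, on a miss, calls the
        # supplied function via run_in_background, so the callback need
        # not worry about the sentinel-context rule described above.
        return cache.wrap(room_id, _do_expensive_lookup, hs, room_id)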
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 5a79db82..45af8d3e 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -17,8 +17,8 @@ import logging
from twisted.internet import defer
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
index 629ed441..8b17d1c8 100644
--- a/synapse/util/file_consumer.py
+++ b/synapse/util/file_consumer.py
@@ -17,7 +17,7 @@ from six.moves import queue
from twisted.internet import threads
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+from synapse.logging.context import make_deferred_yieldable, run_in_background
class BackgroundFileConsumer(object):
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 9e1b5378..40e5c10a 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -1,4 +1,4 @@
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,673 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-""" Thread-local-alike tracking of log contexts within synapse
-
-This module provides objects and utilities for tracking contexts through
-synapse code, so that log lines can include a request identifier, and so that
-CPU and database activity can be accounted for against the request that caused
-them.
-
-See doc/log_contexts.rst for details on how this works.
+"""
+Backwards compatibility re-exports of ``synapse.logging.context`` functionality.
"""
-import logging
-import threading
-import types
-
-from twisted.internet import defer, threads
-
-logger = logging.getLogger(__name__)
-
-try:
- import resource
-
- # Python doesn't ship with a definition of RUSAGE_THREAD but it's defined
- # to be 1 on linux so we hard code it.
- RUSAGE_THREAD = 1
-
- # If the system doesn't support RUSAGE_THREAD then this should throw an
- # exception.
- resource.getrusage(RUSAGE_THREAD)
-
- def get_thread_resource_usage():
- return resource.getrusage(RUSAGE_THREAD)
-
-
-except Exception:
- # If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we
- # won't track resource usage by returning None.
- def get_thread_resource_usage():
- return None
-
-
-class ContextResourceUsage(object):
- """Object for tracking the resources used by a log context
-
- Attributes:
- ru_utime (float): user CPU time (in seconds)
- ru_stime (float): system CPU time (in seconds)
- db_txn_count (int): number of database transactions done
- db_sched_duration_sec (float): amount of time spent waiting for a
- database connection
- db_txn_duration_sec (float): amount of time spent doing database
- transactions (excluding scheduling time)
- evt_db_fetch_count (int): number of events requested from the database
- """
-
- __slots__ = [
- "ru_stime",
- "ru_utime",
- "db_txn_count",
- "db_txn_duration_sec",
- "db_sched_duration_sec",
- "evt_db_fetch_count",
- ]
-
- def __init__(self, copy_from=None):
- """Create a new ContextResourceUsage
-
- Args:
- copy_from (ContextResourceUsage|None): if not None, an object to
- copy stats from
- """
- if copy_from is None:
- self.reset()
- else:
- self.ru_utime = copy_from.ru_utime
- self.ru_stime = copy_from.ru_stime
- self.db_txn_count = copy_from.db_txn_count
-
- self.db_txn_duration_sec = copy_from.db_txn_duration_sec
- self.db_sched_duration_sec = copy_from.db_sched_duration_sec
- self.evt_db_fetch_count = copy_from.evt_db_fetch_count
-
- def copy(self):
- return ContextResourceUsage(copy_from=self)
-
- def reset(self):
- self.ru_stime = 0.0
- self.ru_utime = 0.0
- self.db_txn_count = 0
-
- self.db_txn_duration_sec = 0
- self.db_sched_duration_sec = 0
- self.evt_db_fetch_count = 0
-
- def __repr__(self):
- return (
- "<ContextResourceUsage ru_stime='%r', ru_utime='%r', "
- "db_txn_count='%r', db_txn_duration_sec='%r', "
- "db_sched_duration_sec='%r', evt_db_fetch_count='%r'>"
- ) % (
- self.ru_stime,
- self.ru_utime,
- self.db_txn_count,
- self.db_txn_duration_sec,
- self.db_sched_duration_sec,
- self.evt_db_fetch_count,
- )
-
- def __iadd__(self, other):
- """Add another ContextResourceUsage's stats to this one's.
-
- Args:
- other (ContextResourceUsage): the other resource usage object
- """
- self.ru_utime += other.ru_utime
- self.ru_stime += other.ru_stime
- self.db_txn_count += other.db_txn_count
- self.db_txn_duration_sec += other.db_txn_duration_sec
- self.db_sched_duration_sec += other.db_sched_duration_sec
- self.evt_db_fetch_count += other.evt_db_fetch_count
- return self
-
- def __isub__(self, other):
- self.ru_utime -= other.ru_utime
- self.ru_stime -= other.ru_stime
- self.db_txn_count -= other.db_txn_count
- self.db_txn_duration_sec -= other.db_txn_duration_sec
- self.db_sched_duration_sec -= other.db_sched_duration_sec
- self.evt_db_fetch_count -= other.evt_db_fetch_count
- return self
-
- def __add__(self, other):
- res = ContextResourceUsage(copy_from=self)
- res += other
- return res
-
- def __sub__(self, other):
- res = ContextResourceUsage(copy_from=self)
- res -= other
- return res
-
-
-class LoggingContext(object):
- """Additional context for log formatting. Contexts are scoped within a
- "with" block.
-
- If a parent is given when creating a new context, then:
- - logging fields are copied from the parent to the new context on entry
- - when the new context exits, the cpu usage stats are copied from the
- child to the parent
-
- Args:
- name (str): Name for the context for debugging.
- parent_context (LoggingContext|None): The parent of the new context
- """
-
- __slots__ = [
- "previous_context",
- "name",
- "parent_context",
- "_resource_usage",
- "usage_start",
- "main_thread",
- "alive",
- "request",
- "tag",
- ]
-
- thread_local = threading.local()
-
- class Sentinel(object):
- """Sentinel to represent the root context"""
-
- __slots__ = []
-
- def __str__(self):
- return "sentinel"
-
- def copy_to(self, record):
- pass
-
- def start(self):
- pass
-
- def stop(self):
- pass
-
- def add_database_transaction(self, duration_sec):
- pass
-
- def add_database_scheduled(self, sched_sec):
- pass
-
- def record_event_fetch(self, event_count):
- pass
-
- def __nonzero__(self):
- return False
-
- __bool__ = __nonzero__ # python3
-
- sentinel = Sentinel()
-
- def __init__(self, name=None, parent_context=None, request=None):
- self.previous_context = LoggingContext.current_context()
- self.name = name
-
- # track the resources used by this context so far
- self._resource_usage = ContextResourceUsage()
-
- # If alive has the thread resource usage when the logcontext last
- # became active.
- self.usage_start = None
-
- self.main_thread = threading.current_thread()
- self.request = None
- self.tag = ""
- self.alive = True
-
- self.parent_context = parent_context
-
- if self.parent_context is not None:
- self.parent_context.copy_to(self)
-
- if request is not None:
- # the request param overrides the request from the parent context
- self.request = request
-
- def __str__(self):
- if self.request:
- return str(self.request)
- return "%s@%x" % (self.name, id(self))
-
- @classmethod
- def current_context(cls):
- """Get the current logging context from thread local storage
-
- Returns:
- LoggingContext: the current logging context
- """
- return getattr(cls.thread_local, "current_context", cls.sentinel)
-
- @classmethod
- def set_current_context(cls, context):
- """Set the current logging context in thread local storage
- Args:
- context(LoggingContext): The context to activate.
- Returns:
- The context that was previously active
- """
- current = cls.current_context()
-
- if current is not context:
- current.stop()
- cls.thread_local.current_context = context
- context.start()
- return current
-
- def __enter__(self):
- """Enters this logging context into thread local storage"""
- old_context = self.set_current_context(self)
- if self.previous_context != old_context:
- logger.warn(
- "Expected previous context %r, found %r",
- self.previous_context,
- old_context,
- )
- self.alive = True
-
- return self
-
- def __exit__(self, type, value, traceback):
- """Restore the logging context in thread local storage to the state it
- was before this context was entered.
- Returns:
- None to avoid suppressing any exceptions that were thrown.
- """
- current = self.set_current_context(self.previous_context)
- if current is not self:
- if current is self.sentinel:
- logger.warning("Expected logging context %s was lost", self)
- else:
- logger.warning(
- "Expected logging context %s but found %s", self, current
- )
- self.previous_context = None
- self.alive = False
-
- # if we have a parent, pass our CPU usage stats on
- if self.parent_context is not None and hasattr(
- self.parent_context, "_resource_usage"
- ):
- self.parent_context._resource_usage += self._resource_usage
-
- # reset them in case we get entered again
- self._resource_usage.reset()
-
- def copy_to(self, record):
- """Copy logging fields from this context to a log record or
- another LoggingContext
- """
-
- # 'request' is the only field we currently use in the logger, so that's
- # all we need to copy
- record.request = self.request
-
- def start(self):
- if threading.current_thread() is not self.main_thread:
- logger.warning("Started logcontext %s on different thread", self)
- return
-
- # If we haven't already started record the thread resource usage so
- # far
- if not self.usage_start:
- self.usage_start = get_thread_resource_usage()
-
- def stop(self):
- if threading.current_thread() is not self.main_thread:
- logger.warning("Stopped logcontext %s on different thread", self)
- return
-
- # When we stop, let's record the cpu used since we started
- if not self.usage_start:
- logger.warning("Called stop on logcontext %s without calling start", self)
- return
-
- utime_delta, stime_delta = self._get_cputime()
- self._resource_usage.ru_utime += utime_delta
- self._resource_usage.ru_stime += stime_delta
-
- self.usage_start = None
-
- def get_resource_usage(self):
- """Get resources used by this logcontext so far.
-
- Returns:
- ContextResourceUsage: a *copy* of the object tracking resource
- usage so far
- """
- # we always return a copy, for consistency
- res = self._resource_usage.copy()
-
- # If we are on the correct thread and we're currently running then we
- # can include resource usage so far.
- is_main_thread = threading.current_thread() is self.main_thread
- if self.alive and self.usage_start and is_main_thread:
- utime_delta, stime_delta = self._get_cputime()
- res.ru_utime += utime_delta
- res.ru_stime += stime_delta
-
- return res
-
- def _get_cputime(self):
- """Get the cpu usage time so far
-
- Returns: Tuple[float, float]: seconds in user mode, seconds in system mode
- """
- current = get_thread_resource_usage()
-
- utime_delta = current.ru_utime - self.usage_start.ru_utime
- stime_delta = current.ru_stime - self.usage_start.ru_stime
-
- # sanity check
- if utime_delta < 0:
- logger.error(
- "utime went backwards! %f < %f",
- current.ru_utime,
- self.usage_start.ru_utime,
- )
- utime_delta = 0
-
- if stime_delta < 0:
- logger.error(
- "stime went backwards! %f < %f",
- current.ru_stime,
- self.usage_start.ru_stime,
- )
- stime_delta = 0
-
- return utime_delta, stime_delta
-
- def add_database_transaction(self, duration_sec):
- if duration_sec < 0:
- raise ValueError("DB txn time can only be non-negative")
- self._resource_usage.db_txn_count += 1
- self._resource_usage.db_txn_duration_sec += duration_sec
-
- def add_database_scheduled(self, sched_sec):
- """Record a use of the database pool
-
- Args:
- sched_sec (float): number of seconds it took us to get a
- connection
- """
- if sched_sec < 0:
- raise ValueError("DB scheduling time can only be non-negative")
- self._resource_usage.db_sched_duration_sec += sched_sec
-
- def record_event_fetch(self, event_count):
- """Record a number of events being fetched from the db
-
- Args:
- event_count (int): number of events being fetched
- """
- self._resource_usage.evt_db_fetch_count += event_count
-
-
-class LoggingContextFilter(logging.Filter):
- """Logging filter that adds values from the current logging context to each
- record.
- Args:
- **defaults: Default values to avoid formatters complaining about
- missing fields
- """
-
- def __init__(self, **defaults):
- self.defaults = defaults
-
- def filter(self, record):
- """Add each fields from the logging contexts to the record.
- Returns:
- True to include the record in the log output.
- """
- context = LoggingContext.current_context()
- for key, value in self.defaults.items():
- setattr(record, key, value)
-
- # context should never be None, but if it somehow ends up being, then
- # we end up in a death spiral of infinite loops, so let's check, for
- # robustness' sake.
- if context is not None:
- context.copy_to(record)
-
- return True
-
-
-class PreserveLoggingContext(object):
- """Captures the current logging context and restores it when the scope is
- exited. Used to restore the context after a function using
- @defer.inlineCallbacks is resumed by a callback from the reactor."""
-
- __slots__ = ["current_context", "new_context", "has_parent"]
-
- def __init__(self, new_context=None):
- if new_context is None:
- new_context = LoggingContext.sentinel
- self.new_context = new_context
-
- def __enter__(self):
- """Captures the current logging context"""
- self.current_context = LoggingContext.set_current_context(self.new_context)
-
- if self.current_context:
- self.has_parent = self.current_context.previous_context is not None
- if not self.current_context.alive:
- logger.debug("Entering dead context: %s", self.current_context)
-
- def __exit__(self, type, value, traceback):
- """Restores the current logging context"""
- context = LoggingContext.set_current_context(self.current_context)
-
- if context != self.new_context:
- if context is LoggingContext.sentinel:
- logger.warning("Expected logging context %s was lost", self.new_context)
- else:
- logger.warning(
- "Expected logging context %s but found %s",
- self.new_context,
- context,
- )
-
- if self.current_context is not LoggingContext.sentinel:
- if not self.current_context.alive:
- logger.debug("Restoring dead context: %s", self.current_context)
-
-
-def nested_logging_context(suffix, parent_context=None):
- """Creates a new logging context as a child of another.
-
- The nested logging context will have a 'request' made up of the parent context's
- request, plus the given suffix.
-
- CPU/db usage stats will be added to the parent context's on exit.
-
- Normal usage looks like:
-
- with nested_logging_context(suffix):
- # ... do stuff
-
- Args:
- suffix (str): suffix to add to the parent context's 'request'.
- parent_context (LoggingContext|None): parent context. Will use the current context
- if None.
-
- Returns:
- LoggingContext: new logging context.
- """
- if parent_context is None:
- parent_context = LoggingContext.current_context()
- return LoggingContext(
- parent_context=parent_context, request=parent_context.request + "-" + suffix
- )
-
-
-def preserve_fn(f):
- """Function decorator which wraps the function with run_in_background"""
-
- def g(*args, **kwargs):
- return run_in_background(f, *args, **kwargs)
-
- return g
-
-
-def run_in_background(f, *args, **kwargs):
- """Calls a function, ensuring that the current context is restored after
- return from the function, and that the sentinel context is set once the
- deferred returned by the function completes.
-
- Useful for wrapping functions that return a deferred or coroutine, which you don't
- yield or await on (for instance because you want to pass it to
- deferred.gatherResults()).
-
- Note that if you completely discard the result, you should make sure that
- `f` doesn't raise any deferred exceptions, otherwise a scary-looking
- CRITICAL error about an unhandled error will be logged without much
- indication about where it came from.
- """
- current = LoggingContext.current_context()
- try:
- res = f(*args, **kwargs)
- except: # noqa: E722
- # the assumption here is that the caller doesn't want to be disturbed
- # by synchronous exceptions, so let's turn them into Failures.
- return defer.fail()
-
- if isinstance(res, types.CoroutineType):
- res = defer.ensureDeferred(res)
-
- if not isinstance(res, defer.Deferred):
- return res
-
- if res.called and not res.paused:
- # The function should have maintained the logcontext, so we can
- # optimise out the messing about
- return res
-
- # The function may have reset the context before returning, so
- # we need to restore it now.
- ctx = LoggingContext.set_current_context(current)
-
- # The original context will be restored when the deferred
- # completes, but there is nothing waiting for it, so it will
- # get leaked into the reactor or some other function which
- # wasn't expecting it. We therefore need to reset the context
- # here.
- #
- # (If this feels asymmetric, consider it this way: we are
- # effectively forking a new thread of execution. We are
- # probably currently within a ``with LoggingContext()`` block,
- # which is supposed to have a single entry and exit point. But
- # by spawning off another deferred, we are effectively
- # adding a new exit point.)
- res.addBoth(_set_context_cb, ctx)
- return res
-
-
-def make_deferred_yieldable(deferred):
- """Given a deferred, make it follow the Synapse logcontext rules:
-
- If the deferred has completed (or is not actually a Deferred), essentially
- does nothing (just returns another completed deferred with the
- result/failure).
-
- If the deferred has not yet completed, resets the logcontext before
- returning a deferred. Then, when the deferred completes, restores the
- current logcontext before running callbacks/errbacks.
-
- (This is more-or-less the opposite operation to run_in_background.)
- """
- if not isinstance(deferred, defer.Deferred):
- return deferred
-
- if deferred.called and not deferred.paused:
- # it looks like this deferred is ready to run any callbacks we give it
- # immediately. We may as well optimise out the logcontext faffery.
- return deferred
-
- # ok, we can't be sure that a yield won't block, so let's reset the
- # logcontext, and add a callback to the deferred to restore it.
- prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
- deferred.addBoth(_set_context_cb, prev_context)
- return deferred
-
-
-def _set_context_cb(result, context):
- """A callback function which just sets the logging context"""
- LoggingContext.set_current_context(context)
- return result
-
-
-def defer_to_thread(reactor, f, *args, **kwargs):
- """
- Calls the function `f` using a thread from the reactor's default threadpool and
- returns the result as a Deferred.
-
- Creates a new logcontext for `f`, which is created as a child of the current
- logcontext (so its CPU usage metrics will get attributed to the current
- logcontext). `f` should preserve the logcontext it is given.
-
- The result deferred follows the Synapse logcontext rules: you should `yield`
- on it.
-
- Args:
- reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread
- the Deferred will be invoked, and whose threadpool we should use for the
- function.
-
- Normally this will be hs.get_reactor().
-
- f (callable): The function to call.
-
- args: positional arguments to pass to f.
-
- kwargs: keyword arguments to pass to f.
-
- Returns:
- Deferred: A Deferred which fires a callback with the result of `f`, or an
- errback if `f` throws an exception.
- """
- return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs)
-
-
-def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
- """
- A wrapper for twisted.internet.threads.deferToThreadpool, which handles
- logcontexts correctly.
-
- Calls the function `f` using a thread from the given threadpool and returns
- the result as a Deferred.
-
- Creates a new logcontext for `f`, which is created as a child of the current
- logcontext (so its CPU usage metrics will get attributed to the current
- logcontext). `f` should preserve the logcontext it is given.
-
- The result deferred follows the Synapse logcontext rules: you should `yield`
- on it.
-
- Args:
- reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread
- the Deferred will be invoked. Normally this will be hs.get_reactor().
-
- threadpool (twisted.python.threadpool.ThreadPool): The threadpool to use for
- running `f`. Normally this will be hs.get_reactor().getThreadPool().
-
- f (callable): The function to call.
-
- args: positional arguments to pass to f.
-
- kwargs: keyword arguments to pass to f.
-
- Returns:
- Deferred: A Deferred which fires a callback with the result of `f`, or an
- errback if `f` throws an exception.
- """
- logcontext = LoggingContext.current_context()
-
- def g():
- with LoggingContext(parent_context=logcontext):
- return f(*args, **kwargs)
-
- return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g))
+from synapse.logging.context import (
+ LoggingContext,
+ LoggingContextFilter,
+ PreserveLoggingContext,
+ defer_to_thread,
+ make_deferred_yieldable,
+ nested_logging_context,
+ preserve_fn,
+ run_in_background,
+)
+
+__all__ = [
+ "defer_to_thread",
+ "LoggingContext",
+ "LoggingContextFilter",
+ "make_deferred_yieldable",
+ "nested_logging_context",
+ "preserve_fn",
+ "PreserveLoggingContext",
+ "run_in_background",
+]
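The net effect of this file is that the old and new import paths resolve to the very same objects, so downstream code importing from synapse.util.logcontext keeps working unchanged. A quick check, assuming synapse is importable:

    from synapse.logging.context import LoggingContext as new_cls
    from synapse.util.logcontext import LoggingContext as old_cls

    # The shim re-exports the real class rather than copying it.
    assert old_cls is new_cls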
diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py
index fbf570c7..320e8f81 100644
--- a/synapse/util/logformatter.py
+++ b/synapse/util/logformatter.py
@@ -1,5 +1,4 @@
-# -*- coding: utf-8 -*-
-# Copyright 2017 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,41 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+Backwards compatibility re-exports of ``synapse.logging.formatter`` functionality.
+"""
-import logging
-import traceback
+from synapse.logging.formatter import LogFormatter
-from six import StringIO
-
-
-class LogFormatter(logging.Formatter):
- """Log formatter which gives more detail for exceptions
-
- This is the same as the standard log formatter, except that when logging
- exceptions [typically via log.foo("msg", exc_info=1)], it prints the
- sequence that led up to the point at which the exception was caught.
- (Normally only stack frames between the point the exception was raised and
- where it was caught are logged).
- """
-
- def __init__(self, *args, **kwargs):
- super(LogFormatter, self).__init__(*args, **kwargs)
-
- def formatException(self, ei):
- sio = StringIO()
- (typ, val, tb) = ei
-
- # log the stack above the exception capture point if possible, but
- # check that we actually have an f_back attribute to work around
- # https://twistedmatrix.com/trac/ticket/9305
-
- if tb and hasattr(tb.tb_frame, "f_back"):
- sio.write("Capture point (most recent call last):\n")
- traceback.print_stack(tb.tb_frame.f_back, None, sio)
-
- traceback.print_exception(typ, val, tb, None, sio)
- s = sio.getvalue()
- sio.close()
- if s[-1:] == "\n":
- s = s[:-1]
- return s
+__all__ = ["LogFormatter"]
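The formatter's behaviour (printing the stack above the exception capture point) is unchanged by the move; only its home differs. A minimal sketch of wiring it to a handler, using only the standard logging API:

    import logging

    from synapse.logging.formatter import LogFormatter

    handler = logging.StreamHandler()
    handler.setFormatter(LogFormatter("%(asctime)s - %(levelname)s - %(message)s"))
    logging.getLogger("synapse").addHandler(handler)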
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 01284d3c..c30b6de1 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -20,8 +20,8 @@ from prometheus_client import Counter
from twisted.internet import defer
+from synapse.logging.context import LoggingContext
from synapse.metrics import InFlightGauge
-from synapse.util.logcontext import LoggingContext
logger = logging.getLogger(__name__)
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 06defa81..5ca4521c 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -20,7 +20,7 @@ import logging
from twisted.internet import defer
from synapse.api.errors import LimitExceededError
-from synapse.util.logcontext import (
+from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
run_in_background,
@@ -36,9 +36,11 @@ class FederationRateLimiter(object):
clock (Clock)
config (FederationRateLimitConfig)
"""
- self.clock = clock
- self._config = config
- self.ratelimiters = {}
+
+ def new_limiter():
+ return _PerHostRatelimiter(clock=clock, config=config)
+
+ self.ratelimiters = collections.defaultdict(new_limiter)
def ratelimit(self, host):
"""Used to ratelimit an incoming request from given host
@@ -53,11 +55,9 @@ class FederationRateLimiter(object):
host (str): Origin of incoming request.
Returns:
- _PerHostRatelimiter
+ context manager which returns a deferred.
"""
- return self.ratelimiters.setdefault(
- host, _PerHostRatelimiter(clock=self.clock, config=self._config)
- ).ratelimit()
+ return self.ratelimiters[host].ratelimit()
class _PerHostRatelimiter(object):
@@ -122,7 +122,7 @@ class _PerHostRatelimiter(object):
self.request_times.append(time_now)
def queue_request():
- if len(self.current_processing) > self.concurrent_requests:
+ if len(self.current_processing) >= self.concurrent_requests:
queue_defer = defer.Deferred()
self.ready_request_queue[request_id] = queue_defer
logger.info(
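Two behavioural changes ride along with the import move here: ratelimiters becomes a defaultdict (which relies on an import of collections not visible in these hunks), and the queueing check becomes >=, so concurrent_requests is no longer off by one. The calling convention is unchanged: ratelimit() returns a context manager that yields a deferred, per the updated docstring. A sketch of the caller pattern:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def handle_incoming(limiter, origin):
        # Wait our turn before doing any work for this origin; the
        # deferred fires once we are within the configured limits.
        with limiter.ratelimit(origin) as wait_deferred:
            yield wait_deferred
            # ... process the request ...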
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 1a774564..d8d0ceae 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -17,7 +17,7 @@ import random
from twisted.internet import defer
-import synapse.util.logcontext
+import synapse.logging.context
from synapse.api.errors import CodeMessageException
logger = logging.getLogger(__name__)
@@ -225,4 +225,4 @@ class RetryDestinationLimiter(object):
logger.exception("Failed to store destination_retry_timings")
# we deliberately do this in the background.
- synapse.util.logcontext.run_in_background(store_retry_timings)
+ synapse.logging.context.run_in_background(store_retry_timings)