summaryrefslogtreecommitdiff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erikj@matrix.org>2017-04-11 11:15:22 +0100
committerErik Johnston <erikj@matrix.org>2017-04-11 11:15:22 +0100
commit6396247906b041b8a88e7de17808b2a5ca5e9064 (patch)
treeee78b86f62119a7605ab2fdd1bc8afd3953e7978 /synapse
parent336f3ddcbcae0249cf6bcbb64a3054d1bb609515 (diff)
Imported Upstream version 0.20.0
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/auth.py5
-rw-r--r--synapse/api/constants.py2
-rw-r--r--synapse/api/errors.py74
-rw-r--r--synapse/api/filtering.py298
-rw-r--r--synapse/app/appservice.py10
-rw-r--r--synapse/app/client_reader.py12
-rw-r--r--synapse/app/federation_reader.py10
-rw-r--r--synapse/app/federation_sender.py10
-rwxr-xr-xsynapse/app/homeserver.py13
-rw-r--r--synapse/app/media_repository.py12
-rw-r--r--synapse/app/pusher.py11
-rw-r--r--synapse/app/synchrotron.py26
-rwxr-xr-xsynapse/app/synctl.py47
-rw-r--r--synapse/config/logger.py75
-rw-r--r--synapse/crypto/keyring.py71
-rw-r--r--synapse/events/snapshot.py26
-rw-r--r--synapse/federation/federation_client.py40
-rw-r--r--synapse/federation/federation_server.py176
-rw-r--r--synapse/federation/send_queue.py7
-rw-r--r--synapse/federation/transaction_queue.py228
-rw-r--r--synapse/federation/transport/client.py7
-rw-r--r--synapse/handlers/auth.py32
-rw-r--r--synapse/handlers/device.py50
-rw-r--r--synapse/handlers/directory.py1
-rw-r--r--synapse/handlers/e2e_keys.py34
-rw-r--r--synapse/handlers/federation.py291
-rw-r--r--synapse/handlers/identity.py37
-rw-r--r--synapse/handlers/initial_sync.py11
-rw-r--r--synapse/handlers/presence.py98
-rw-r--r--synapse/handlers/profile.py14
-rw-r--r--synapse/handlers/receipts.py5
-rw-r--r--synapse/handlers/room_list.py66
-rw-r--r--synapse/handlers/sync.py82
-rw-r--r--synapse/http/matrixfederationclient.py327
-rw-r--r--synapse/http/servlet.py10
-rw-r--r--synapse/notifier.py66
-rw-r--r--synapse/push/mailer.py2
-rw-r--r--synapse/push/push_rule_evaluator.py104
-rw-r--r--synapse/push/push_tools.py8
-rw-r--r--synapse/python_dependencies.py3
-rw-r--r--synapse/replication/resource.py4
-rw-r--r--synapse/replication/slave/storage/_slaved_id_tracker.py5
-rw-r--r--synapse/replication/slave/storage/events.py49
-rw-r--r--synapse/replication/slave/storage/presence.py1
-rw-r--r--synapse/rest/client/v1/login.py97
-rw-r--r--synapse/rest/client/v1/presence.py3
-rw-r--r--synapse/rest/client/v1/room.py3
-rw-r--r--synapse/rest/client/v2_alpha/account.py114
-rw-r--r--synapse/rest/client/v2_alpha/devices.py47
-rw-r--r--synapse/rest/client/v2_alpha/register.py139
-rw-r--r--synapse/rest/client/v2_alpha/sync.py20
-rw-r--r--synapse/rest/media/v1/download_resource.py12
-rw-r--r--synapse/rest/media/v1/media_repository.py44
-rw-r--r--synapse/state.py17
-rw-r--r--synapse/storage/_base.py64
-rw-r--r--synapse/storage/account_data.py4
-rw-r--r--synapse/storage/background_updates.py15
-rw-r--r--synapse/storage/deviceinbox.py16
-rw-r--r--synapse/storage/devices.py30
-rw-r--r--synapse/storage/end_to_end_keys.py63
-rw-r--r--synapse/storage/event_federation.py35
-rw-r--r--synapse/storage/event_push_actions.py2
-rw-r--r--synapse/storage/events.py402
-rw-r--r--synapse/storage/keys.py5
-rw-r--r--synapse/storage/prepare_database.py2
-rw-r--r--synapse/storage/presence.py4
-rw-r--r--synapse/storage/receipts.py5
-rw-r--r--synapse/storage/registration.py2
-rw-r--r--synapse/storage/room.py4
-rw-r--r--synapse/storage/roommember.py66
-rw-r--r--synapse/storage/signatures.py2
-rw-r--r--synapse/storage/state.py92
-rw-r--r--synapse/storage/stream.py3
-rw-r--r--synapse/storage/tags.py4
-rw-r--r--synapse/storage/util/id_generators.py14
-rw-r--r--synapse/util/__init__.py10
-rw-r--r--synapse/util/caches/descriptors.py165
-rw-r--r--synapse/util/caches/stream_change_cache.py2
-rw-r--r--synapse/util/logcontext.py77
-rw-r--r--synapse/util/msisdn.py40
-rw-r--r--synapse/util/retryutils.py34
-rw-r--r--synapse/visibility.py7
83 files changed, 2797 insertions, 1330 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 7628e7c5..2e5f4e0e 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.19.3"
+__version__ = "0.20.0"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 03a215ab..9dbc7993 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -23,7 +23,7 @@ from synapse import event_auth
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes
from synapse.types import UserID
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util import logcontext
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -209,8 +209,7 @@ class Auth(object):
default=[""]
)[0]
if user and access_token and ip_addr:
- preserve_context_over_fn(
- self.store.insert_client_ip,
+ logcontext.preserve_fn(self.store.insert_client_ip)(
user=user,
access_token=access_token,
ip=ip_addr,
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index ca23c9c4..489efb7f 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -44,6 +45,7 @@ class JoinRules(object):
class LoginType(object):
PASSWORD = u"m.login.password"
EMAIL_IDENTITY = u"m.login.email.identity"
+ MSISDN = u"m.login.msisdn"
RECAPTCHA = u"m.login.recaptcha"
DUMMY = u"m.login.dummy"
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 921c4577..6fbd5d68 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -15,6 +15,7 @@
"""Contains exceptions and error codes."""
+import json
import logging
logger = logging.getLogger(__name__)
@@ -50,27 +51,35 @@ class Codes(object):
class CodeMessageException(RuntimeError):
- """An exception with integer code and message string attributes."""
+ """An exception with integer code and message string attributes.
+ Attributes:
+ code (int): HTTP error code
+ msg (str): string describing the error
+ """
def __init__(self, code, msg):
super(CodeMessageException, self).__init__("%d: %s" % (code, msg))
self.code = code
self.msg = msg
- self.response_code_message = None
def error_dict(self):
return cs_error(self.msg)
class SynapseError(CodeMessageException):
- """A base error which can be caught for all synapse events."""
+ """A base exception type for matrix errors which have an errcode and error
+ message (as well as an HTTP status code).
+
+ Attributes:
+ errcode (str): Matrix error code e.g 'M_FORBIDDEN'
+ """
def __init__(self, code, msg, errcode=Codes.UNKNOWN):
"""Constructs a synapse error.
Args:
code (int): The integer error code (an HTTP response code)
msg (str): The human-readable error message.
- err (str): The error code e.g 'M_FORBIDDEN'
+ errcode (str): The matrix error code e.g 'M_FORBIDDEN'
"""
super(SynapseError, self).__init__(code, msg)
self.errcode = errcode
@@ -81,6 +90,39 @@ class SynapseError(CodeMessageException):
self.errcode,
)
+ @classmethod
+ def from_http_response_exception(cls, err):
+ """Make a SynapseError based on an HTTPResponseException
+
+ This is useful when a proxied request has failed, and we need to
+ decide how to map the failure onto a matrix error to send back to the
+ client.
+
+ An attempt is made to parse the body of the http response as a matrix
+ error. If that succeeds, the errcode and error message from the body
+ are used as the errcode and error message in the new synapse error.
+
+ Otherwise, the errcode is set to M_UNKNOWN, and the error message is
+ set to the reason code from the HTTP response.
+
+ Args:
+ err (HttpResponseException):
+
+ Returns:
+ SynapseError:
+ """
+ # try to parse the body as json, to get better errcode/msg, but
+ # default to M_UNKNOWN with the HTTP status as the error text
+ try:
+ j = json.loads(err.response)
+ except ValueError:
+ j = {}
+ errcode = j.get('errcode', Codes.UNKNOWN)
+ errmsg = j.get('error', err.msg)
+
+ res = SynapseError(err.code, errmsg, errcode)
+ return res
+
class RegistrationError(SynapseError):
"""An error raised when a registration event fails."""
@@ -106,13 +148,11 @@ class UnrecognizedRequestError(SynapseError):
class NotFoundError(SynapseError):
"""An error indicating we can't find the thing you asked for"""
- def __init__(self, *args, **kwargs):
- if "errcode" not in kwargs:
- kwargs["errcode"] = Codes.NOT_FOUND
+ def __init__(self, msg="Not found", errcode=Codes.NOT_FOUND):
super(NotFoundError, self).__init__(
404,
- "Not found",
- **kwargs
+ msg,
+ errcode=errcode
)
@@ -173,7 +213,6 @@ class LimitExceededError(SynapseError):
errcode=Codes.LIMIT_EXCEEDED):
super(LimitExceededError, self).__init__(code, msg, errcode)
self.retry_after_ms = retry_after_ms
- self.response_code_message = "Too Many Requests"
def error_dict(self):
return cs_error(
@@ -243,6 +282,19 @@ class FederationError(RuntimeError):
class HttpResponseException(CodeMessageException):
+ """
+ Represents an HTTP-level failure of an outbound request
+
+ Attributes:
+ response (str): body of response
+ """
def __init__(self, code, msg, response):
- self.response = response
+ """
+
+ Args:
+ code (int): HTTP status code
+ msg (str): reason phrase from HTTP response status line
+ response (str): body of response
+ """
super(HttpResponseException, self).__init__(code, msg)
+ self.response = response
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index fb291d7f..83206348 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -13,11 +13,174 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.errors import SynapseError
+from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, RoomID
-
from twisted.internet import defer
import ujson as json
+import jsonschema
+from jsonschema import FormatChecker
+
+FILTER_SCHEMA = {
+ "additionalProperties": False,
+ "type": "object",
+ "properties": {
+ "limit": {
+ "type": "number"
+ },
+ "senders": {
+ "$ref": "#/definitions/user_id_array"
+ },
+ "not_senders": {
+ "$ref": "#/definitions/user_id_array"
+ },
+ # TODO: We don't limit event type values but we probably should...
+ # check types are valid event types
+ "types": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
+ "not_types": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ }
+ }
+}
+
+ROOM_FILTER_SCHEMA = {
+ "additionalProperties": False,
+ "type": "object",
+ "properties": {
+ "not_rooms": {
+ "$ref": "#/definitions/room_id_array"
+ },
+ "rooms": {
+ "$ref": "#/definitions/room_id_array"
+ },
+ "ephemeral": {
+ "$ref": "#/definitions/room_event_filter"
+ },
+ "include_leave": {
+ "type": "boolean"
+ },
+ "state": {
+ "$ref": "#/definitions/room_event_filter"
+ },
+ "timeline": {
+ "$ref": "#/definitions/room_event_filter"
+ },
+ "account_data": {
+ "$ref": "#/definitions/room_event_filter"
+ },
+ }
+}
+
+ROOM_EVENT_FILTER_SCHEMA = {
+ "additionalProperties": False,
+ "type": "object",
+ "properties": {
+ "limit": {
+ "type": "number"
+ },
+ "senders": {
+ "$ref": "#/definitions/user_id_array"
+ },
+ "not_senders": {
+ "$ref": "#/definitions/user_id_array"
+ },
+ "types": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
+ "not_types": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
+ "rooms": {
+ "$ref": "#/definitions/room_id_array"
+ },
+ "not_rooms": {
+ "$ref": "#/definitions/room_id_array"
+ },
+ "contains_url": {
+ "type": "boolean"
+ }
+ }
+}
+
+USER_ID_ARRAY_SCHEMA = {
+ "type": "array",
+ "items": {
+ "type": "string",
+ "format": "matrix_user_id"
+ }
+}
+
+ROOM_ID_ARRAY_SCHEMA = {
+ "type": "array",
+ "items": {
+ "type": "string",
+ "format": "matrix_room_id"
+ }
+}
+
+USER_FILTER_SCHEMA = {
+ "$schema": "http://json-schema.org/draft-04/schema#",
+ "description": "schema for a Sync filter",
+ "type": "object",
+ "definitions": {
+ "room_id_array": ROOM_ID_ARRAY_SCHEMA,
+ "user_id_array": USER_ID_ARRAY_SCHEMA,
+ "filter": FILTER_SCHEMA,
+ "room_filter": ROOM_FILTER_SCHEMA,
+ "room_event_filter": ROOM_EVENT_FILTER_SCHEMA
+ },
+ "properties": {
+ "presence": {
+ "$ref": "#/definitions/filter"
+ },
+ "account_data": {
+ "$ref": "#/definitions/filter"
+ },
+ "room": {
+ "$ref": "#/definitions/room_filter"
+ },
+ "event_format": {
+ "type": "string",
+ "enum": ["client", "federation"]
+ },
+ "event_fields": {
+ "type": "array",
+ "items": {
+ "type": "string",
+ # Don't allow '\\' in event field filters. This makes matching
+ # events a lot easier as we can then use a negative lookbehind
+ # assertion to split '\.' If we allowed \\ then it would
+ # incorrectly split '\\.' See synapse.events.utils.serialize_event
+ "pattern": "^((?!\\\).)*$"
+ }
+ }
+ },
+ "additionalProperties": False
+}
+
+
+@FormatChecker.cls_checks('matrix_room_id')
+def matrix_room_id_validator(room_id_str):
+ return RoomID.from_string(room_id_str)
+
+
+@FormatChecker.cls_checks('matrix_user_id')
+def matrix_user_id_validator(user_id_str):
+ return UserID.from_string(user_id_str)
class Filtering(object):
@@ -52,98 +215,11 @@ class Filtering(object):
# NB: Filters are the complete json blobs. "Definitions" are an
# individual top-level key e.g. public_user_data. Filters are made of
# many definitions.
-
- top_level_definitions = [
- "presence", "account_data"
- ]
-
- room_level_definitions = [
- "state", "timeline", "ephemeral", "account_data"
- ]
-
- for key in top_level_definitions:
- if key in user_filter_json:
- self._check_definition(user_filter_json[key])
-
- if "room" in user_filter_json:
- self._check_definition_room_lists(user_filter_json["room"])
- for key in room_level_definitions:
- if key in user_filter_json["room"]:
- self._check_definition(user_filter_json["room"][key])
-
- if "event_fields" in user_filter_json:
- if type(user_filter_json["event_fields"]) != list:
- raise SynapseError(400, "event_fields must be a list of strings")
- for field in user_filter_json["event_fields"]:
- if not isinstance(field, basestring):
- raise SynapseError(400, "Event field must be a string")
- # Don't allow '\\' in event field filters. This makes matching
- # events a lot easier as we can then use a negative lookbehind
- # assertion to split '\.' If we allowed \\ then it would
- # incorrectly split '\\.' See synapse.events.utils.serialize_event
- if r'\\' in field:
- raise SynapseError(
- 400, r'The escape character \ cannot itself be escaped'
- )
-
- def _check_definition_room_lists(self, definition):
- """Check that "rooms" and "not_rooms" are lists of room ids if they
- are present
-
- Args:
- definition(dict): The filter definition
- Raises:
- SynapseError: If there was a problem with this definition.
- """
- # check rooms are valid room IDs
- room_id_keys = ["rooms", "not_rooms"]
- for key in room_id_keys:
- if key in definition:
- if type(definition[key]) != list:
- raise SynapseError(400, "Expected %s to be a list." % key)
- for room_id in definition[key]:
- RoomID.from_string(room_id)
-
- def _check_definition(self, definition):
- """Check if the provided definition is valid.
-
- This inspects not only the types but also the values to make sure they
- make sense.
-
- Args:
- definition(dict): The filter definition
- Raises:
- SynapseError: If there was a problem with this definition.
- """
- # NB: Filters are the complete json blobs. "Definitions" are an
- # individual top-level key e.g. public_user_data. Filters are made of
- # many definitions.
- if type(definition) != dict:
- raise SynapseError(
- 400, "Expected JSON object, not %s" % (definition,)
- )
-
- self._check_definition_room_lists(definition)
-
- # check senders are valid user IDs
- user_id_keys = ["senders", "not_senders"]
- for key in user_id_keys:
- if key in definition:
- if type(definition[key]) != list:
- raise SynapseError(400, "Expected %s to be a list." % key)
- for user_id in definition[key]:
- UserID.from_string(user_id)
-
- # TODO: We don't limit event type values but we probably should...
- # check types are valid event types
- event_keys = ["types", "not_types"]
- for key in event_keys:
- if key in definition:
- if type(definition[key]) != list:
- raise SynapseError(400, "Expected %s to be a list." % key)
- for event_type in definition[key]:
- if not isinstance(event_type, basestring):
- raise SynapseError(400, "Event type should be a string")
+ try:
+ jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA,
+ format_checker=FormatChecker())
+ except jsonschema.ValidationError as e:
+ raise SynapseError(400, e.message)
class FilterCollection(object):
@@ -253,19 +329,35 @@ class Filter(object):
Returns:
bool: True if the event matches
"""
- sender = event.get("sender", None)
- if not sender:
- # Presence events have their 'sender' in content.user_id
- content = event.get("content")
- # account_data has been allowed to have non-dict content, so check type first
- if isinstance(content, dict):
- sender = content.get("user_id")
+ # We usually get the full "events" as dictionaries coming through,
+ # except for presence which actually gets passed around as its own
+ # namedtuple type.
+ if isinstance(event, UserPresenceState):
+ sender = event.user_id
+ room_id = None
+ ev_type = "m.presence"
+ is_url = False
+ else:
+ sender = event.get("sender", None)
+ if not sender:
+ # Presence events had their 'sender' in content.user_id, but are
+ # now handled above. We don't know if anything else uses this
+ # form. TODO: Check this and probably remove it.
+ content = event.get("content")
+ # account_data has been allowed to have non-dict content, so
+ # check type first
+ if isinstance(content, dict):
+ sender = content.get("user_id")
+
+ room_id = event.get("room_id", None)
+ ev_type = event.get("type", None)
+ is_url = "url" in event.get("content", {})
return self.check_fields(
- event.get("room_id", None),
+ room_id,
sender,
- event.get("type", None),
- "url" in event.get("content", {})
+ ev_type,
+ is_url,
)
def check_fields(self, room_id, sender, event_type, contains_url):
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index 19009300..a6f1e759 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -29,7 +29,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
@@ -157,7 +157,7 @@ def start(config_options):
assert config.worker_app == "synapse.app.appservice"
- setup_logging(config.worker_log_config, config.worker_log_file)
+ setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
@@ -187,7 +187,11 @@ def start(config_options):
ps.start_listening(config.worker_listeners)
def run():
- with LoggingContext("run"):
+ # make sure that we run the reactor with the sentinel log context,
+ # otherwise other PreserveLoggingContext instances will get confused
+ # and complain when they see the logcontext arbitrarily swapping
+ # between the sentinel and `run` logcontexts.
+ with PreserveLoggingContext():
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 4d081ecc..e4ea3ab9 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -29,13 +29,14 @@ from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.rest.client.v1.room import PublicRoomListRestServlet
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
@@ -63,6 +64,7 @@ class ClientReaderSlavedStore(
DirectoryStore,
SlavedApplicationServiceStore,
SlavedRegistrationStore,
+ TransactionStore,
BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different
):
@@ -171,7 +173,7 @@ def start(config_options):
assert config.worker_app == "synapse.app.client_reader"
- setup_logging(config.worker_log_config, config.worker_log_file)
+ setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
@@ -193,7 +195,11 @@ def start(config_options):
ss.start_listening(config.worker_listeners)
def run():
- with LoggingContext("run"):
+ # make sure that we run the reactor with the sentinel log context,
+ # otherwise other PreserveLoggingContext instances will get confused
+ # and complain when they see the logcontext arbitrarily swapping
+ # between the sentinel and `run` logcontexts.
+ with PreserveLoggingContext():
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index 90a48167..e52b0f24 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -31,7 +31,7 @@ from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
@@ -162,7 +162,7 @@ def start(config_options):
assert config.worker_app == "synapse.app.federation_reader"
- setup_logging(config.worker_log_config, config.worker_log_file)
+ setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
@@ -184,7 +184,11 @@ def start(config_options):
ss.start_listening(config.worker_listeners)
def run():
- with LoggingContext("run"):
+ # make sure that we run the reactor with the sentinel log context,
+ # otherwise other PreserveLoggingContext instances will get confused
+ # and complain when they see the logcontext arbitrarily swapping
+ # between the sentinel and `run` logcontexts.
+ with PreserveLoggingContext():
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 411e47d9..76c4cc54 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -35,7 +35,7 @@ from synapse.storage.engines import create_engine
from synapse.storage.presence import UserPresenceState
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
@@ -160,7 +160,7 @@ def start(config_options):
assert config.worker_app == "synapse.app.federation_sender"
- setup_logging(config.worker_log_config, config.worker_log_file)
+ setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
@@ -193,7 +193,11 @@ def start(config_options):
ps.start_listening(config.worker_listeners)
def run():
- with LoggingContext("run"):
+ # make sure that we run the reactor with the sentinel log context,
+ # otherwise other PreserveLoggingContext instances will get confused
+ # and complain when they see the logcontext arbitrarily swapping
+ # between the sentinel and `run` logcontexts.
+ with PreserveLoggingContext():
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index e0b87468..2cdd2d39 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -20,6 +20,8 @@ import gc
import logging
import os
import sys
+
+import synapse.config.logger
from synapse.config._base import ConfigError
from synapse.python_dependencies import (
@@ -50,7 +52,7 @@ from synapse.api.urls import (
)
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.metrics import register_memory_metrics, get_metrics_for
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
@@ -286,7 +288,7 @@ def setup(config_options):
# generating config files and shouldn't try to continue.
sys.exit(0)
- config.setup_logging()
+ synapse.config.logger.setup_logging(config, use_worker_options=False)
# check any extra requirements we have now we have a config
check_requirements(config)
@@ -454,7 +456,12 @@ def run(hs):
def in_thread():
# Uncomment to enable tracing of log context changes.
# sys.settrace(logcontext_tracer)
- with LoggingContext("run"):
+
+ # make sure that we run the reactor with the sentinel log context,
+ # otherwise other PreserveLoggingContext instances will get confused
+ # and complain when they see the logcontext arbitrarily swapping
+ # between the sentinel and `run` logcontexts.
+ with PreserveLoggingContext():
change_resource_limit(hs.config.soft_file_limit)
if hs.config.gc_thresholds:
gc.set_threshold(*hs.config.gc_thresholds)
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index ef17b158..1444e69a 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -24,6 +24,7 @@ from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.server import HomeServer
@@ -32,7 +33,7 @@ from synapse.storage.engines import create_engine
from synapse.storage.media_repository import MediaRepositoryStore
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
@@ -59,6 +60,7 @@ logger = logging.getLogger("synapse.app.media_repository")
class MediaRepositorySlavedStore(
SlavedApplicationServiceStore,
SlavedRegistrationStore,
+ TransactionStore,
BaseSlavedStore,
MediaRepositoryStore,
ClientIpStore,
@@ -168,7 +170,7 @@ def start(config_options):
assert config.worker_app == "synapse.app.media_repository"
- setup_logging(config.worker_log_config, config.worker_log_file)
+ setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
@@ -190,7 +192,11 @@ def start(config_options):
ss.start_listening(config.worker_listeners)
def run():
- with LoggingContext("run"):
+ # make sure that we run the reactor with the sentinel log context,
+ # otherwise other PreserveLoggingContext instances will get confused
+ # and complain when they see the logcontext arbitrarily swapping
+ # between the sentinel and `run` logcontexts.
+ with PreserveLoggingContext():
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 073f2c24..ab682e52 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -31,7 +31,8 @@ from synapse.storage.engines import create_engine
from synapse.storage import DataStore
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, preserve_fn, \
+ PreserveLoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
@@ -245,7 +246,7 @@ def start(config_options):
assert config.worker_app == "synapse.app.pusher"
- setup_logging(config.worker_log_config, config.worker_log_file)
+ setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
@@ -275,7 +276,11 @@ def start(config_options):
ps.start_listening(config.worker_listeners)
def run():
- with LoggingContext("run"):
+ # make sure that we run the reactor with the sentinel log context,
+ # otherwise other PreserveLoggingContext instances will get confused
+ # and complain when they see the logcontext arbitrarily swapping
+ # between the sentinel and `run` logcontexts.
+ with PreserveLoggingContext():
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 3f295952..34e34e55 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -20,7 +20,6 @@ from synapse.api.constants import EventTypes, PresenceState
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
-from synapse.events import FrozenEvent
from synapse.handlers.presence import PresenceHandler
from synapse.http.site import SynapseSite
from synapse.http.server import JsonResource
@@ -48,7 +47,8 @@ from synapse.storage.presence import PresenceStore, UserPresenceState
from synapse.storage.roommember import RoomMemberStore
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, preserve_fn, \
+ PreserveLoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.stringutils import random_string
@@ -399,8 +399,7 @@ class SynchrotronServer(HomeServer):
position = row[position_index]
user_id = row[user_index]
- rooms = yield store.get_rooms_for_user(user_id)
- room_ids = [r.room_id for r in rooms]
+ room_ids = yield store.get_rooms_for_user(user_id)
notifier.on_new_event(
"device_list_key", position, rooms=room_ids,
@@ -411,11 +410,16 @@ class SynchrotronServer(HomeServer):
stream = result.get("events")
if stream:
max_position = stream["position"]
+
+ event_map = yield store.get_events([row[1] for row in stream["rows"]])
+
for row in stream["rows"]:
position = row[0]
- internal = json.loads(row[1])
- event_json = json.loads(row[2])
- event = FrozenEvent(event_json, internal_metadata_dict=internal)
+ event_id = row[1]
+ event = event_map.get(event_id, None)
+ if not event:
+ continue
+
extra_users = ()
if event.type == EventTypes.Member:
extra_users = (event.state_key,)
@@ -478,7 +482,7 @@ def start(config_options):
assert config.worker_app == "synapse.app.synchrotron"
- setup_logging(config.worker_log_config, config.worker_log_file)
+ setup_logging(config, use_worker_options=True)
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
@@ -497,7 +501,11 @@ def start(config_options):
ss.start_listening(config.worker_listeners)
def run():
- with LoggingContext("run"):
+ # make sure that we run the reactor with the sentinel log context,
+ # otherwise other PreserveLoggingContext instances will get confused
+ # and complain when they see the logcontext arbitrarily swapping
+ # between the sentinel and `run` logcontexts.
+ with PreserveLoggingContext():
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py
index c0455888..23eb6a1e 100755
--- a/synapse/app/synctl.py
+++ b/synapse/app/synctl.py
@@ -23,14 +23,27 @@ import signal
import subprocess
import sys
import yaml
+import errno
+import time
SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
GREEN = "\x1b[1;32m"
+YELLOW = "\x1b[1;33m"
RED = "\x1b[1;31m"
NORMAL = "\x1b[m"
+def pid_running(pid):
+ try:
+ os.kill(pid, 0)
+ return True
+ except OSError, err:
+ if err.errno == errno.EPERM:
+ return True
+ return False
+
+
def write(message, colour=NORMAL, stream=sys.stdout):
if colour == NORMAL:
stream.write(message + "\n")
@@ -38,6 +51,11 @@ def write(message, colour=NORMAL, stream=sys.stdout):
stream.write(colour + message + NORMAL + "\n")
+def abort(message, colour=RED, stream=sys.stderr):
+ write(message, colour, stream)
+ sys.exit(1)
+
+
def start(configfile):
write("Starting ...")
args = SYNAPSE
@@ -45,7 +63,8 @@ def start(configfile):
try:
subprocess.check_call(args)
- write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN)
+ write("started synapse.app.homeserver(%r)" %
+ (configfile,), colour=GREEN)
except subprocess.CalledProcessError as e:
write(
"error starting (exit code: %d); see above for logs" % e.returncode,
@@ -76,8 +95,16 @@ def start_worker(app, configfile, worker_configfile):
def stop(pidfile, app):
if os.path.exists(pidfile):
pid = int(open(pidfile).read())
- os.kill(pid, signal.SIGTERM)
- write("stopped %s" % (app,), colour=GREEN)
+ try:
+ os.kill(pid, signal.SIGTERM)
+ write("stopped %s" % (app,), colour=GREEN)
+ except OSError, err:
+ if err.errno == errno.ESRCH:
+ write("%s not running" % (app,), colour=YELLOW)
+ elif err.errno == errno.EPERM:
+ abort("Cannot stop %s: Operation not permitted" % (app,))
+ else:
+ abort("Cannot stop %s: Unknown error" % (app,))
Worker = collections.namedtuple("Worker", [
@@ -190,7 +217,19 @@ def main():
if start_stop_synapse:
stop(pidfile, "synapse.app.homeserver")
- # TODO: Wait for synapse to actually shutdown before starting it again
+ # Wait for synapse to actually shutdown before starting it again
+ if action == "restart":
+ running_pids = []
+ if start_stop_synapse and os.path.exists(pidfile):
+ running_pids.append(int(open(pidfile).read()))
+ for worker in workers:
+ if os.path.exists(worker.pidfile):
+ running_pids.append(int(open(worker.pidfile).read()))
+ if len(running_pids) > 0:
+ write("Waiting for process to exit before restarting...")
+ for running_pid in running_pids:
+ while pid_running(running_pid):
+ time.sleep(0.2)
if action == "start" or action == "restart":
if start_stop_synapse:
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 77ded0ad..2dbeafa9 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -45,7 +45,6 @@ handlers:
maxBytes: 104857600
backupCount: 10
filters: [context]
- level: INFO
console:
class: logging.StreamHandler
formatter: precise
@@ -56,6 +55,8 @@ loggers:
level: INFO
synapse.storage.SQL:
+ # beware: increasing this to DEBUG will make synapse log sensitive
+ # information such as access tokens.
level: INFO
root:
@@ -68,6 +69,7 @@ class LoggingConfig(Config):
def read_config(self, config):
self.verbosity = config.get("verbose", 0)
+ self.no_redirect_stdio = config.get("no_redirect_stdio", False)
self.log_config = self.abspath(config.get("log_config"))
self.log_file = self.abspath(config.get("log_file"))
@@ -77,10 +79,10 @@ class LoggingConfig(Config):
os.path.join(config_dir_path, server_name + ".log.config")
)
return """
- # Logging verbosity level.
+ # Logging verbosity level. Ignored if log_config is specified.
verbose: 0
- # File to write logging to
+ # File to write logging to. Ignored if log_config is specified.
log_file: "%(log_file)s"
# A yaml python logging config file
@@ -90,6 +92,8 @@ class LoggingConfig(Config):
def read_arguments(self, args):
if args.verbose is not None:
self.verbosity = args.verbose
+ if args.no_redirect_stdio is not None:
+ self.no_redirect_stdio = args.no_redirect_stdio
if args.log_config is not None:
self.log_config = args.log_config
if args.log_file is not None:
@@ -99,16 +103,22 @@ class LoggingConfig(Config):
logging_group = parser.add_argument_group("logging")
logging_group.add_argument(
'-v', '--verbose', dest="verbose", action='count',
- help="The verbosity level."
+ help="The verbosity level. Specify multiple times to increase "
+ "verbosity. (Ignored if --log-config is specified.)"
)
logging_group.add_argument(
'-f', '--log-file', dest="log_file",
- help="File to log to."
+ help="File to log to. (Ignored if --log-config is specified.)"
)
logging_group.add_argument(
'--log-config', dest="log_config", default=None,
help="Python logging config file"
)
+ logging_group.add_argument(
+ '-n', '--no-redirect-stdio',
+ action='store_true', default=None,
+ help="Do not redirect stdout/stderr to the log"
+ )
def generate_files(self, config):
log_config = config.get("log_config")
@@ -118,11 +128,22 @@ class LoggingConfig(Config):
DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"])
)
- def setup_logging(self):
- setup_logging(self.log_config, self.log_file, self.verbosity)
+def setup_logging(config, use_worker_options=False):
+ """ Set up python logging
+
+ Args:
+ config (LoggingConfig | synapse.config.workers.WorkerConfig):
+ configuration data
+
+ use_worker_options (bool): True to use 'worker_log_config' and
+ 'worker_log_file' options instead of 'log_config' and 'log_file'.
+ """
+ log_config = (config.worker_log_config if use_worker_options
+ else config.log_config)
+ log_file = (config.worker_log_file if use_worker_options
+ else config.log_file)
-def setup_logging(log_config=None, log_file=None, verbosity=None):
log_format = (
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
" - %(message)s"
@@ -131,9 +152,9 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
level = logging.INFO
level_for_storage = logging.INFO
- if verbosity:
+ if config.verbosity:
level = logging.DEBUG
- if verbosity > 1:
+ if config.verbosity > 1:
level_for_storage = logging.DEBUG
# FIXME: we need a logging.WARN for a -q quiet option
@@ -153,14 +174,6 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
logger.info("Closing log file due to SIGHUP")
handler.doRollover()
logger.info("Opened new log file due to SIGHUP")
-
- # TODO(paul): obviously this is a terrible mechanism for
- # stealing SIGHUP, because it means no other part of synapse
- # can use it instead. If we want to catch SIGHUP anywhere
- # else as well, I'd suggest we find a nicer way to broadcast
- # it around.
- if getattr(signal, "SIGHUP"):
- signal.signal(signal.SIGHUP, sighup)
else:
handler = logging.StreamHandler()
handler.setFormatter(formatter)
@@ -169,8 +182,25 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
logger.addHandler(handler)
else:
- with open(log_config, 'r') as f:
- logging.config.dictConfig(yaml.load(f))
+ def load_log_config():
+ with open(log_config, 'r') as f:
+ logging.config.dictConfig(yaml.load(f))
+
+ def sighup(signum, stack):
+ # it might be better to use a file watcher or something for this.
+ logging.info("Reloading log config from %s due to SIGHUP",
+ log_config)
+ load_log_config()
+
+ load_log_config()
+
+ # TODO(paul): obviously this is a terrible mechanism for
+ # stealing SIGHUP, because it means no other part of synapse
+ # can use it instead. If we want to catch SIGHUP anywhere
+ # else as well, I'd suggest we find a nicer way to broadcast
+ # it around.
+ if getattr(signal, "SIGHUP"):
+ signal.signal(signal.SIGHUP, sighup)
# It's critical to point twisted's internal logging somewhere, otherwise it
# stacks up and leaks kup to 64K object;
@@ -183,4 +213,7 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
#
# However this may not be too much of a problem if we are just writing to a file.
observer = STDLibLogObserver()
- globalLogBeginner.beginLoggingTo([observer])
+ globalLogBeginner.beginLoggingTo(
+ [observer],
+ redirectStandardIO=not config.no_redirect_stdio,
+ )
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index d7211ee9..1bb27edc 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -15,7 +15,6 @@
from synapse.crypto.keyclient import fetch_server_key
from synapse.api.errors import SynapseError, Codes
-from synapse.util.retryutils import get_retry_limiter
from synapse.util import unwrapFirstError
from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import (
@@ -96,10 +95,11 @@ class Keyring(object):
verify_requests = []
for server_name, json_object in server_and_json:
- logger.debug("Verifying for %s", server_name)
key_ids = signature_ids(json_object, server_name)
if not key_ids:
+ logger.warn("Request from %s: no supported signature keys",
+ server_name)
deferred = defer.fail(SynapseError(
400,
"Not signed with a supported algorithm",
@@ -108,6 +108,9 @@ class Keyring(object):
else:
deferred = defer.Deferred()
+ logger.debug("Verifying for %s with key_ids %s",
+ server_name, key_ids)
+
verify_request = VerifyKeyRequest(
server_name, key_ids, json_object, deferred
)
@@ -142,6 +145,9 @@ class Keyring(object):
json_object = verify_request.json_object
+ logger.debug("Got key %s %s:%s for server %s, verifying" % (
+ key_id, verify_key.alg, verify_key.version, server_name,
+ ))
try:
verify_signed_json(json_object, server_name, verify_key)
except:
@@ -231,8 +237,14 @@ class Keyring(object):
d.addBoth(rm, server_name)
def get_server_verify_keys(self, verify_requests):
- """Takes a dict of KeyGroups and tries to find at least one key for
- each group.
+ """Tries to find at least one key for each verify request
+
+ For each verify_request, verify_request.deferred is called back with
+ params (server_name, key_id, VerifyKey) if a key is found, or errbacked
+ with a SynapseError if none of the keys are found.
+
+ Args:
+ verify_requests (list[VerifyKeyRequest]): list of verify requests
"""
# These are functions that produce keys given a list of key ids
@@ -245,8 +257,11 @@ class Keyring(object):
@defer.inlineCallbacks
def do_iterations():
with Measure(self.clock, "get_server_verify_keys"):
+ # dict[str, dict[str, VerifyKey]]: results so far.
+ # map server_name -> key_id -> VerifyKey
merged_results = {}
+ # dict[str, set(str)]: keys to fetch for each server
missing_keys = {}
for verify_request in verify_requests:
missing_keys.setdefault(verify_request.server_name, set()).update(
@@ -308,6 +323,16 @@ class Keyring(object):
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
+ """
+
+ Args:
+ server_name_and_key_ids (list[(str, iterable[str])]):
+ list of (server_name, iterable[key_id]) tuples to fetch keys for
+
+ Returns:
+ Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from
+ server_name -> key_id -> VerifyKey
+ """
res = yield preserve_context_over_deferred(defer.gatherResults(
[
preserve_fn(self.store.get_server_verify_keys)(
@@ -356,30 +381,24 @@ class Keyring(object):
def get_keys_from_server(self, server_name_and_key_ids):
@defer.inlineCallbacks
def get_key(server_name, key_ids):
- limiter = yield get_retry_limiter(
- server_name,
- self.clock,
- self.store,
- )
- with limiter:
- keys = None
- try:
- keys = yield self.get_server_verify_key_v2_direct(
- server_name, key_ids
- )
- except Exception as e:
- logger.info(
- "Unable to get key %r for %r directly: %s %s",
- key_ids, server_name,
- type(e).__name__, str(e.message),
- )
+ keys = None
+ try:
+ keys = yield self.get_server_verify_key_v2_direct(
+ server_name, key_ids
+ )
+ except Exception as e:
+ logger.info(
+ "Unable to get key %r for %r directly: %s %s",
+ key_ids, server_name,
+ type(e).__name__, str(e.message),
+ )
- if not keys:
- keys = yield self.get_server_verify_key_v1_direct(
- server_name, key_ids
- )
+ if not keys:
+ keys = yield self.get_server_verify_key_v1_direct(
+ server_name, key_ids
+ )
- keys = {server_name: keys}
+ keys = {server_name: keys}
defer.returnValue(keys)
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 11605b34..6be18880 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -15,6 +15,32 @@
class EventContext(object):
+ """
+ Attributes:
+ current_state_ids (dict[(str, str), str]):
+ The current state map including the current event.
+ (type, state_key) -> event_id
+
+ prev_state_ids (dict[(str, str), str]):
+ The current state map excluding the current event.
+ (type, state_key) -> event_id
+
+ state_group (int): state group id
+ rejected (bool|str): A rejection reason if the event was rejected, else
+ False
+
+ push_actions (list[(str, list[object])]): list of (user_id, actions)
+ tuples
+
+ prev_group (int): Previously persisted state group. ``None`` for an
+ outlier.
+ delta_ids (dict[(str, str), str]): Delta from ``prev_group``.
+ (type, state_key) -> event_id. ``None`` for an outlier.
+
+ prev_state_events (?): XXX: is this ever set to anything other than
+ the empty list?
+ """
+
__slots__ = [
"current_state_ids",
"prev_state_ids",
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 5dcd4eec..deee0f49 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -29,7 +29,7 @@ from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.events import FrozenEvent, builder
import synapse.metrics
-from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
+from synapse.util.retryutils import NotRetryingDestination
import copy
import itertools
@@ -88,7 +88,7 @@ class FederationClient(FederationBase):
@log_function
def make_query(self, destination, query_type, args,
- retry_on_dns_fail=False):
+ retry_on_dns_fail=False, ignore_backoff=False):
"""Sends a federation Query to a remote homeserver of the given type
and arguments.
@@ -98,6 +98,8 @@ class FederationClient(FederationBase):
handler name used in register_query_handler().
args (dict): Mapping of strings to strings containing the details
of the query request.
+ ignore_backoff (bool): true to ignore the historical backoff data
+ and try the request anyway.
Returns:
a Deferred which will eventually yield a JSON object from the
@@ -106,7 +108,8 @@ class FederationClient(FederationBase):
sent_queries_counter.inc(query_type)
return self.transport_layer.make_query(
- destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
+ destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail,
+ ignore_backoff=ignore_backoff,
)
@log_function
@@ -234,31 +237,24 @@ class FederationClient(FederationBase):
continue
try:
- limiter = yield get_retry_limiter(
- destination,
- self._clock,
- self.store,
+ transaction_data = yield self.transport_layer.get_event(
+ destination, event_id, timeout=timeout,
)
- with limiter:
- transaction_data = yield self.transport_layer.get_event(
- destination, event_id, timeout=timeout,
- )
-
- logger.debug("transaction_data %r", transaction_data)
+ logger.debug("transaction_data %r", transaction_data)
- pdu_list = [
- self.event_from_pdu_json(p, outlier=outlier)
- for p in transaction_data["pdus"]
- ]
+ pdu_list = [
+ self.event_from_pdu_json(p, outlier=outlier)
+ for p in transaction_data["pdus"]
+ ]
- if pdu_list and pdu_list[0]:
- pdu = pdu_list[0]
+ if pdu_list and pdu_list[0]:
+ pdu = pdu_list[0]
- # Check signatures are correct.
- signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
+ # Check signatures are correct.
+ signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
- break
+ break
pdu_attempts[destination] = now
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index e922b7ff..bc20b9c2 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -52,7 +52,6 @@ class FederationServer(FederationBase):
self.auth = hs.get_auth()
- self._room_pdu_linearizer = Linearizer("fed_room_pdu")
self._server_linearizer = Linearizer("fed_server")
# We cache responses to state queries, as they take a while and often
@@ -147,11 +146,15 @@ class FederationServer(FederationBase):
# check that it's actually being sent from a valid destination to
# workaround bug #1753 in 0.18.5 and 0.18.6
if transaction.origin != get_domain_from_id(pdu.event_id):
+ # We continue to accept join events from any server; this is
+ # necessary for the federation join dance to work correctly.
+ # (When we join over federation, the "helper" server is
+ # responsible for sending out the join event, rather than the
+ # origin. See bug #1893).
if not (
pdu.type == 'm.room.member' and
pdu.content and
- pdu.content.get("membership", None) == 'join' and
- self.hs.is_mine_id(pdu.state_key)
+ pdu.content.get("membership", None) == 'join'
):
logger.info(
"Discarding PDU %s from invalid origin %s",
@@ -165,7 +168,7 @@ class FederationServer(FederationBase):
)
try:
- yield self._handle_new_pdu(transaction.origin, pdu)
+ yield self._handle_received_pdu(transaction.origin, pdu)
results.append({})
except FederationError as e:
self.send_failure(e, transaction.origin)
@@ -497,27 +500,16 @@ class FederationServer(FederationBase):
)
@defer.inlineCallbacks
- @log_function
- def _handle_new_pdu(self, origin, pdu, get_missing=True):
-
- # We reprocess pdus when we have seen them only as outliers
- existing = yield self._get_persisted_pdu(
- origin, pdu.event_id, do_auth=False
- )
-
- # FIXME: Currently we fetch an event again when we already have it
- # if it has been marked as an outlier.
+ def _handle_received_pdu(self, origin, pdu):
+ """ Process a PDU received in a federation /send/ transaction.
- already_seen = (
- existing and (
- not existing.internal_metadata.is_outlier()
- or pdu.internal_metadata.is_outlier()
- )
- )
- if already_seen:
- logger.debug("Already seen pdu %s", pdu.event_id)
- return
+ Args:
+ origin (str): server which sent the pdu
+ pdu (FrozenEvent): received pdu
+ Returns (Deferred): completes with None
+ Raises: FederationError if the signatures / hash do not match
+ """
# Check signature.
try:
pdu = yield self._check_sigs_and_hash(pdu)
@@ -529,143 +521,7 @@ class FederationServer(FederationBase):
affected=pdu.event_id,
)
- state = None
-
- auth_chain = []
-
- have_seen = yield self.store.have_events(
- [ev for ev, _ in pdu.prev_events]
- )
-
- fetch_state = False
-
- # Get missing pdus if necessary.
- if not pdu.internal_metadata.is_outlier():
- # We only backfill backwards to the min depth.
- min_depth = yield self.handler.get_min_depth_for_context(
- pdu.room_id
- )
-
- logger.debug(
- "_handle_new_pdu min_depth for %s: %d",
- pdu.room_id, min_depth
- )
-
- prevs = {e_id for e_id, _ in pdu.prev_events}
- seen = set(have_seen.keys())
-
- if min_depth and pdu.depth < min_depth:
- # This is so that we don't notify the user about this
- # message, to work around the fact that some events will
- # reference really really old events we really don't want to
- # send to the clients.
- pdu.internal_metadata.outlier = True
- elif min_depth and pdu.depth > min_depth:
- if get_missing and prevs - seen:
- # If we're missing stuff, ensure we only fetch stuff one
- # at a time.
- logger.info(
- "Acquiring lock for room %r to fetch %d missing events: %r...",
- pdu.room_id, len(prevs - seen), list(prevs - seen)[:5],
- )
- with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
- logger.info(
- "Acquired lock for room %r to fetch %d missing events",
- pdu.room_id, len(prevs - seen),
- )
-
- # We recalculate seen, since it may have changed.
- have_seen = yield self.store.have_events(prevs)
- seen = set(have_seen.keys())
-
- if prevs - seen:
- latest = yield self.store.get_latest_event_ids_in_room(
- pdu.room_id
- )
-
- # We add the prev events that we have seen to the latest
- # list to ensure the remote server doesn't give them to us
- latest = set(latest)
- latest |= seen
-
- logger.info(
- "Missing %d events for room %r: %r...",
- len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
- )
-
- # XXX: we set timeout to 10s to help workaround
- # https://github.com/matrix-org/synapse/issues/1733.
- # The reason is to avoid holding the linearizer lock
- # whilst processing inbound /send transactions, causing
- # FDs to stack up and block other inbound transactions
- # which empirically can currently take up to 30 minutes.
- #
- # N.B. this explicitly disables retry attempts.
- #
- # N.B. this also increases our chances of falling back to
- # fetching fresh state for the room if the missing event
- # can't be found, which slightly reduces our security.
- # it may also increase our DAG extremity count for the room,
- # causing additional state resolution? See #1760.
- # However, fetching state doesn't hold the linearizer lock
- # apparently.
- #
- # see https://github.com/matrix-org/synapse/pull/1744
-
- missing_events = yield self.get_missing_events(
- origin,
- pdu.room_id,
- earliest_events_ids=list(latest),
- latest_events=[pdu],
- limit=10,
- min_depth=min_depth,
- timeout=10000,
- )
-
- # We want to sort these by depth so we process them and
- # tell clients about them in order.
- missing_events.sort(key=lambda x: x.depth)
-
- for e in missing_events:
- yield self._handle_new_pdu(
- origin,
- e,
- get_missing=False
- )
-
- have_seen = yield self.store.have_events(
- [ev for ev, _ in pdu.prev_events]
- )
-
- prevs = {e_id for e_id, _ in pdu.prev_events}
- seen = set(have_seen.keys())
- if prevs - seen:
- logger.info(
- "Still missing %d events for room %r: %r...",
- len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
- )
- fetch_state = True
-
- if fetch_state:
- # We need to get the state at this event, since we haven't
- # processed all the prev events.
- logger.debug(
- "_handle_new_pdu getting state for %s",
- pdu.room_id
- )
- try:
- state, auth_chain = yield self.get_state_for_room(
- origin, pdu.room_id, pdu.event_id,
- )
- except:
- logger.exception("Failed to get state for event: %s", pdu.event_id)
-
- yield self.handler.on_receive_pdu(
- origin,
- pdu,
- state=state,
- auth_chain=auth_chain,
- )
+ yield self.handler.on_receive_pdu(origin, pdu, get_missing=True)
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 5c9f7a86..bbb01952 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -54,6 +54,7 @@ class FederationRemoteSendQueue(object):
def __init__(self, hs):
self.server_name = hs.hostname
self.clock = hs.get_clock()
+ self.notifier = hs.get_notifier()
self.presence_map = {}
self.presence_changed = sorteddict()
@@ -186,6 +187,8 @@ class FederationRemoteSendQueue(object):
else:
self.edus[pos] = edu
+ self.notifier.on_new_replication_data()
+
def send_presence(self, destination, states):
"""As per TransactionQueue"""
pos = self._next_pos()
@@ -199,16 +202,20 @@ class FederationRemoteSendQueue(object):
(destination, state.user_id) for state in states
]
+ self.notifier.on_new_replication_data()
+
def send_failure(self, failure, destination):
"""As per TransactionQueue"""
pos = self._next_pos()
self.failures[pos] = (destination, str(failure))
+ self.notifier.on_new_replication_data()
def send_device_messages(self, destination):
"""As per TransactionQueue"""
pos = self._next_pos()
self.device_messages[pos] = destination
+ self.notifier.on_new_replication_data()
def get_current_token(self):
return self.pos - 1
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 90235ff0..c27ce7c5 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import datetime
from twisted.internet import defer
@@ -22,9 +22,7 @@ from .units import Transaction, Edu
from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor
from synapse.util.logcontext import preserve_context_over_fn
-from synapse.util.retryutils import (
- get_retry_limiter, NotRetryingDestination,
-)
+from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
from synapse.util.metrics import measure_func
from synapse.types import get_domain_from_id
from synapse.handlers.presence import format_user_presence_state
@@ -99,7 +97,12 @@ class TransactionQueue(object):
# destination -> list of tuple(failure, deferred)
self.pending_failures_by_dest = {}
+ # destination -> stream_id of last successfully sent to-device message.
+ # NB: may be a long or an int.
self.last_device_stream_id_by_dest = {}
+
+ # destination -> stream_id of last successfully sent device list
+ # update.
self.last_device_list_stream_id_by_dest = {}
# HACK to get unique tx id
@@ -300,20 +303,20 @@ class TransactionQueue(object):
)
return
+ pending_pdus = []
try:
self.pending_transactions[destination] = 1
+ # This will throw if we wouldn't retry. We do this here so we fail
+ # quickly, but we will later check this again in the http client,
+ # hence why we throw the result away.
+ yield get_retry_limiter(destination, self.clock, self.store)
+
# XXX: what's this for?
yield run_on_reactor()
+ pending_pdus = []
while True:
- limiter = yield get_retry_limiter(
- destination,
- self.clock,
- self.store,
- backoff_on_404=True, # If we get a 404 the other side has gone
- )
-
device_message_edus, device_stream_id, dev_list_id = (
yield self._get_new_device_messages(destination)
)
@@ -369,7 +372,6 @@ class TransactionQueue(object):
success = yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures,
- limiter=limiter,
)
if success:
# Remove the acknowledged device messages from the database
@@ -387,12 +389,24 @@ class TransactionQueue(object):
self.last_device_list_stream_id_by_dest[destination] = dev_list_id
else:
break
- except NotRetryingDestination:
+ except NotRetryingDestination as e:
logger.debug(
- "TX [%s] not ready for retry yet - "
+ "TX [%s] not ready for retry yet (next retry at %s) - "
"dropping transaction for now",
destination,
+ datetime.datetime.fromtimestamp(
+ (e.retry_last_ts + e.retry_interval) / 1000.0
+ ),
)
+ except Exception as e:
+ logger.warn(
+ "TX [%s] Failed to send transaction: %s",
+ destination,
+ e,
+ )
+ for p, _ in pending_pdus:
+ logger.info("Failed to send event %s to %s", p.event_id,
+ destination)
finally:
# We want to be *very* sure we delete this after we stop processing
self.pending_transactions.pop(destination, None)
@@ -432,7 +446,7 @@ class TransactionQueue(object):
@measure_func("_send_new_transaction")
@defer.inlineCallbacks
def _send_new_transaction(self, destination, pending_pdus, pending_edus,
- pending_failures, limiter):
+ pending_failures):
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[1])
@@ -442,132 +456,104 @@ class TransactionQueue(object):
success = True
- try:
- logger.debug("TX [%s] _attempt_new_transaction", destination)
+ logger.debug("TX [%s] _attempt_new_transaction", destination)
- txn_id = str(self._next_txn_id)
+ txn_id = str(self._next_txn_id)
- logger.debug(
- "TX [%s] {%s} Attempting new transaction"
- " (pdus: %d, edus: %d, failures: %d)",
- destination, txn_id,
- len(pdus),
- len(edus),
- len(failures)
- )
+ logger.debug(
+ "TX [%s] {%s} Attempting new transaction"
+ " (pdus: %d, edus: %d, failures: %d)",
+ destination, txn_id,
+ len(pdus),
+ len(edus),
+ len(failures)
+ )
- logger.debug("TX [%s] Persisting transaction...", destination)
+ logger.debug("TX [%s] Persisting transaction...", destination)
- transaction = Transaction.create_new(
- origin_server_ts=int(self.clock.time_msec()),
- transaction_id=txn_id,
- origin=self.server_name,
- destination=destination,
- pdus=pdus,
- edus=edus,
- pdu_failures=failures,
- )
+ transaction = Transaction.create_new(
+ origin_server_ts=int(self.clock.time_msec()),
+ transaction_id=txn_id,
+ origin=self.server_name,
+ destination=destination,
+ pdus=pdus,
+ edus=edus,
+ pdu_failures=failures,
+ )
- self._next_txn_id += 1
+ self._next_txn_id += 1
- yield self.transaction_actions.prepare_to_send(transaction)
+ yield self.transaction_actions.prepare_to_send(transaction)
- logger.debug("TX [%s] Persisted transaction", destination)
- logger.info(
- "TX [%s] {%s} Sending transaction [%s],"
- " (PDUs: %d, EDUs: %d, failures: %d)",
- destination, txn_id,
- transaction.transaction_id,
- len(pdus),
- len(edus),
- len(failures),
- )
+ logger.debug("TX [%s] Persisted transaction", destination)
+ logger.info(
+ "TX [%s] {%s} Sending transaction [%s],"
+ " (PDUs: %d, EDUs: %d, failures: %d)",
+ destination, txn_id,
+ transaction.transaction_id,
+ len(pdus),
+ len(edus),
+ len(failures),
+ )
- with limiter:
- # Actually send the transaction
-
- # FIXME (erikj): This is a bit of a hack to make the Pdu age
- # keys work
- def json_data_cb():
- data = transaction.get_dict()
- now = int(self.clock.time_msec())
- if "pdus" in data:
- for p in data["pdus"]:
- if "age_ts" in p:
- unsigned = p.setdefault("unsigned", {})
- unsigned["age"] = now - int(p["age_ts"])
- del p["age_ts"]
- return data
-
- try:
- response = yield self.transport_layer.send_transaction(
- transaction, json_data_cb
- )
- code = 200
-
- if response:
- for e_id, r in response.get("pdus", {}).items():
- if "error" in r:
- logger.warn(
- "Transaction returned error for %s: %s",
- e_id, r,
- )
- except HttpResponseException as e:
- code = e.code
- response = e.response
-
- if e.code in (401, 404, 429) or 500 <= e.code:
- logger.info(
- "TX [%s] {%s} got %d response",
- destination, txn_id, code
+ # Actually send the transaction
+
+ # FIXME (erikj): This is a bit of a hack to make the Pdu age
+ # keys work
+ def json_data_cb():
+ data = transaction.get_dict()
+ now = int(self.clock.time_msec())
+ if "pdus" in data:
+ for p in data["pdus"]:
+ if "age_ts" in p:
+ unsigned = p.setdefault("unsigned", {})
+ unsigned["age"] = now - int(p["age_ts"])
+ del p["age_ts"]
+ return data
+
+ try:
+ response = yield self.transport_layer.send_transaction(
+ transaction, json_data_cb
+ )
+ code = 200
+
+ if response:
+ for e_id, r in response.get("pdus", {}).items():
+ if "error" in r:
+ logger.warn(
+ "Transaction returned error for %s: %s",
+ e_id, r,
)
- raise e
+ except HttpResponseException as e:
+ code = e.code
+ response = e.response
+ if e.code in (401, 404, 429) or 500 <= e.code:
logger.info(
"TX [%s] {%s} got %d response",
destination, txn_id, code
)
+ raise e
- logger.debug("TX [%s] Sent transaction", destination)
- logger.debug("TX [%s] Marking as delivered...", destination)
+ logger.info(
+ "TX [%s] {%s} got %d response",
+ destination, txn_id, code
+ )
- yield self.transaction_actions.delivered(
- transaction, code, response
- )
+ logger.debug("TX [%s] Sent transaction", destination)
+ logger.debug("TX [%s] Marking as delivered...", destination)
- logger.debug("TX [%s] Marked as delivered", destination)
+ yield self.transaction_actions.delivered(
+ transaction, code, response
+ )
- if code != 200:
- for p in pdus:
- logger.info(
- "Failed to send event %s to %s", p.event_id, destination
- )
- success = False
- except RuntimeError as e:
- # We capture this here as there as nothing actually listens
- # for this finishing functions deferred.
- logger.warn(
- "TX [%s] Problem in _attempt_transaction: %s",
- destination,
- e,
- )
-
- success = False
+ logger.debug("TX [%s] Marked as delivered", destination)
+ if code != 200:
for p in pdus:
- logger.info("Failed to send event %s to %s", p.event_id, destination)
- except Exception as e:
- # We capture this here as there as nothing actually listens
- # for this finishing functions deferred.
- logger.warn(
- "TX [%s] Problem in _attempt_transaction: %s",
- destination,
- e,
- )
-
+ logger.info(
+ "Failed to send event %s to %s", p.event_id, destination
+ )
success = False
- for p in pdus:
- logger.info("Failed to send event %s to %s", p.event_id, destination)
-
defer.returnValue(success)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index f49e8a2c..15a03378 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -163,6 +163,7 @@ class TransportLayerClient(object):
data=json_data,
json_data_callback=json_data_callback,
long_retries=True,
+ backoff_on_404=True, # If we get a 404 the other side has gone
)
logger.debug(
@@ -174,7 +175,8 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
- def make_query(self, destination, query_type, args, retry_on_dns_fail):
+ def make_query(self, destination, query_type, args, retry_on_dns_fail,
+ ignore_backoff=False):
path = PREFIX + "/query/%s" % query_type
content = yield self.client.get_json(
@@ -183,6 +185,7 @@ class TransportLayerClient(object):
args=args,
retry_on_dns_fail=retry_on_dns_fail,
timeout=10000,
+ ignore_backoff=ignore_backoff,
)
defer.returnValue(content)
@@ -242,6 +245,7 @@ class TransportLayerClient(object):
destination=destination,
path=path,
data=content,
+ ignore_backoff=True,
)
defer.returnValue(response)
@@ -269,6 +273,7 @@ class TransportLayerClient(object):
destination=remote_server,
path=path,
args=args,
+ ignore_backoff=True,
)
defer.returnValue(response)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index fffba343..e7a1bb72 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -47,6 +48,7 @@ class AuthHandler(BaseHandler):
LoginType.PASSWORD: self._check_password_auth,
LoginType.RECAPTCHA: self._check_recaptcha,
LoginType.EMAIL_IDENTITY: self._check_email_identity,
+ LoginType.MSISDN: self._check_msisdn,
LoginType.DUMMY: self._check_dummy_auth,
}
self.bcrypt_rounds = hs.config.bcrypt_rounds
@@ -307,31 +309,47 @@ class AuthHandler(BaseHandler):
defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
- @defer.inlineCallbacks
def _check_email_identity(self, authdict, _):
+ return self._check_threepid('email', authdict)
+
+ def _check_msisdn(self, authdict, _):
+ return self._check_threepid('msisdn', authdict)
+
+ @defer.inlineCallbacks
+ def _check_dummy_auth(self, authdict, _):
+ yield run_on_reactor()
+ defer.returnValue(True)
+
+ @defer.inlineCallbacks
+ def _check_threepid(self, medium, authdict):
yield run_on_reactor()
if 'threepid_creds' not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
threepid_creds = authdict['threepid_creds']
+
identity_handler = self.hs.get_handlers().identity_handler
- logger.info("Getting validated threepid. threepidcreds: %r" % (threepid_creds,))
+ logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
threepid = yield identity_handler.threepid_from_creds(threepid_creds)
if not threepid:
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
+ if threepid['medium'] != medium:
+ raise LoginError(
+ 401,
+ "Expecting threepid of type '%s', got '%s'" % (
+ medium, threepid['medium'],
+ ),
+ errcode=Codes.UNAUTHORIZED
+ )
+
threepid['threepid_creds'] = authdict['threepid_creds']
defer.returnValue(threepid)
- @defer.inlineCallbacks
- def _check_dummy_auth(self, authdict, _):
- yield run_on_reactor()
- defer.returnValue(True)
-
def _get_params_recaptcha(self):
return {"public_key": self.hs.config.recaptcha_public_key}
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index e859b316..c22f65ce 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -170,6 +170,40 @@ class DeviceHandler(BaseHandler):
yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks
+ def delete_devices(self, user_id, device_ids):
+ """ Delete several devices
+
+ Args:
+ user_id (str):
+ device_ids (str): The list of device IDs to delete
+
+ Returns:
+ defer.Deferred:
+ """
+
+ try:
+ yield self.store.delete_devices(user_id, device_ids)
+ except errors.StoreError, e:
+ if e.code == 404:
+ # no match
+ pass
+ else:
+ raise
+
+ # Delete access tokens and e2e keys for each device. Not optimised as it is not
+ # considered as part of a critical path.
+ for device_id in device_ids:
+ yield self.store.user_delete_access_tokens(
+ user_id, device_id=device_id,
+ delete_refresh_tokens=True,
+ )
+ yield self.store.delete_e2e_keys_by_device(
+ user_id=user_id, device_id=device_id
+ )
+
+ yield self.notify_device_update(user_id, device_ids)
+
+ @defer.inlineCallbacks
def update_device(self, user_id, device_id, content):
""" Update the given device
@@ -214,8 +248,7 @@ class DeviceHandler(BaseHandler):
user_id, device_ids, list(hosts)
)
- rooms = yield self.store.get_rooms_for_user(user_id)
- room_ids = [r.room_id for r in rooms]
+ room_ids = yield self.store.get_rooms_for_user(user_id)
yield self.notifier.on_new_event(
"device_list_key", position, rooms=room_ids,
@@ -236,8 +269,7 @@ class DeviceHandler(BaseHandler):
user_id (str)
from_token (StreamToken)
"""
- rooms = yield self.store.get_rooms_for_user(user_id)
- room_ids = set(r.room_id for r in rooms)
+ room_ids = yield self.store.get_rooms_for_user(user_id)
# First we check if any devices have changed
changed = yield self.store.get_user_whose_devices_changed(
@@ -262,7 +294,7 @@ class DeviceHandler(BaseHandler):
# ordering: treat it the same as a new room
event_ids = []
- current_state_ids = yield self.state.get_current_state_ids(room_id)
+ current_state_ids = yield self.store.get_current_state_ids(room_id)
# special-case for an empty prev state: include all members
# in the changed list
@@ -313,8 +345,8 @@ class DeviceHandler(BaseHandler):
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
user_id = user.to_string()
- rooms = yield self.store.get_rooms_for_user(user_id)
- if not rooms:
+ room_ids = yield self.store.get_rooms_for_user(user_id)
+ if not room_ids:
# We no longer share rooms with this user, so we'll no longer
# receive device updates. Mark this in DB.
yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
@@ -370,8 +402,8 @@ class DeviceListEduUpdater(object):
logger.warning("Got device list update edu for %r from %r", user_id, origin)
return
- rooms = yield self.store.get_rooms_for_user(user_id)
- if not rooms:
+ room_ids = yield self.store.get_rooms_for_user(user_id)
+ if not room_ids:
# We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates.
return
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 1b5317ed..943554ce 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -175,6 +175,7 @@ class DirectoryHandler(BaseHandler):
"room_alias": room_alias.to_string(),
},
retry_on_dns_fail=False,
+ ignore_backoff=True,
)
except CodeMessageException as e:
logging.warn("Error retrieving alias")
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index e40495d1..c2b38d72 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -22,7 +22,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, CodeMessageException
from synapse.types import get_domain_from_id
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
+from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
@@ -121,15 +121,11 @@ class E2eKeysHandler(object):
def do_remote_query(destination):
destination_query = remote_queries_not_in_cache[destination]
try:
- limiter = yield get_retry_limiter(
- destination, self.clock, self.store
+ remote_result = yield self.federation.query_client_keys(
+ destination,
+ {"device_keys": destination_query},
+ timeout=timeout
)
- with limiter:
- remote_result = yield self.federation.query_client_keys(
- destination,
- {"device_keys": destination_query},
- timeout=timeout
- )
for user_id, keys in remote_result["device_keys"].items():
if user_id in destination_query:
@@ -239,18 +235,14 @@ class E2eKeysHandler(object):
def claim_client_keys(destination):
device_keys = remote_queries[destination]
try:
- limiter = yield get_retry_limiter(
- destination, self.clock, self.store
+ remote_result = yield self.federation.claim_client_keys(
+ destination,
+ {"one_time_keys": device_keys},
+ timeout=timeout
)
- with limiter:
- remote_result = yield self.federation.claim_client_keys(
- destination,
- {"one_time_keys": device_keys},
- timeout=timeout
- )
- for user_id, keys in remote_result["one_time_keys"].items():
- if user_id in device_keys:
- json_result[user_id] = keys
+ for user_id, keys in remote_result["one_time_keys"].items():
+ if user_id in device_keys:
+ json_result[user_id] = keys
except CodeMessageException as e:
failures[destination] = {
"status": e.code, "message": e.message
@@ -316,7 +308,7 @@ class E2eKeysHandler(object):
# old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with
# keys without a corresponding device.
- self.device_handler.check_device_registered(user_id, device_id)
+ yield self.device_handler.check_device_registered(user_id, device_id)
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index ed0fa51e..53f92963 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""Contains handlers for federation events."""
+import synapse.util.logcontext
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
@@ -31,7 +32,7 @@ from synapse.util.logcontext import (
)
from synapse.util.metrics import measure_func
from synapse.util.logutils import log_function
-from synapse.util.async import run_on_reactor
+from synapse.util.async import run_on_reactor, Linearizer
from synapse.util.frozenutils import unfreeze
from synapse.crypto.event_signing import (
compute_event_signature, add_hashes_and_signatures,
@@ -79,29 +80,216 @@ class FederationHandler(BaseHandler):
# When joining a room we need to queue any events for that room up
self.room_queues = {}
+ self._room_pdu_linearizer = Linearizer("fed_room_pdu")
- @log_function
@defer.inlineCallbacks
- def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
- """ Called by the ReplicationLayer when we have a new pdu. We need to
- do auth checks and put it through the StateHandler.
+ @log_function
+ def on_receive_pdu(self, origin, pdu, get_missing=True):
+ """ Process a PDU received via a federation /send/ transaction, or
+ via backfill of missing prev_events
+
+ Args:
+ origin (str): server which initiated the /send/ transaction. Will
+ be used to fetch missing events or state.
+ pdu (FrozenEvent): received PDU
+ get_missing (bool): True if we should fetch missing prev_events
- auth_chain and state are None if we already have the necessary state
- and prev_events in the db
+ Returns (Deferred): completes with None
"""
- event = pdu
- logger.debug("Got event: %s", event.event_id)
+ # We reprocess pdus when we have seen them only as outliers
+ existing = yield self.get_persisted_pdu(
+ origin, pdu.event_id, do_auth=False
+ )
+
+ # FIXME: Currently we fetch an event again when we already have it
+ # if it has been marked as an outlier.
+
+ already_seen = (
+ existing and (
+ not existing.internal_metadata.is_outlier()
+ or pdu.internal_metadata.is_outlier()
+ )
+ )
+ if already_seen:
+ logger.debug("Already seen pdu %s", pdu.event_id)
+ return
# If we are currently in the process of joining this room, then we
# queue up events for later processing.
- if event.room_id in self.room_queues:
- self.room_queues[event.room_id].append((pdu, origin))
+ if pdu.room_id in self.room_queues:
+ logger.info("Ignoring PDU %s for room %s from %s for now; join "
+ "in progress", pdu.event_id, pdu.room_id, origin)
+ self.room_queues[pdu.room_id].append((pdu, origin))
return
- logger.debug("Processing event: %s", event.event_id)
+ state = None
+
+ auth_chain = []
+
+ have_seen = yield self.store.have_events(
+ [ev for ev, _ in pdu.prev_events]
+ )
+
+ fetch_state = False
+
+ # Get missing pdus if necessary.
+ if not pdu.internal_metadata.is_outlier():
+ # We only backfill backwards to the min depth.
+ min_depth = yield self.get_min_depth_for_context(
+ pdu.room_id
+ )
+
+ logger.debug(
+ "_handle_new_pdu min_depth for %s: %d",
+ pdu.room_id, min_depth
+ )
+
+ prevs = {e_id for e_id, _ in pdu.prev_events}
+ seen = set(have_seen.keys())
+
+ if min_depth and pdu.depth < min_depth:
+ # This is so that we don't notify the user about this
+ # message, to work around the fact that some events will
+ # reference really really old events we really don't want to
+ # send to the clients.
+ pdu.internal_metadata.outlier = True
+ elif min_depth and pdu.depth > min_depth:
+ if get_missing and prevs - seen:
+ # If we're missing stuff, ensure we only fetch stuff one
+ # at a time.
+ logger.info(
+ "Acquiring lock for room %r to fetch %d missing events: %r...",
+ pdu.room_id, len(prevs - seen), list(prevs - seen)[:5],
+ )
+ with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
+ logger.info(
+ "Acquired lock for room %r to fetch %d missing events",
+ pdu.room_id, len(prevs - seen),
+ )
+
+ yield self._get_missing_events_for_pdu(
+ origin, pdu, prevs, min_depth
+ )
+
+ prevs = {e_id for e_id, _ in pdu.prev_events}
+ seen = set(have_seen.keys())
+ if prevs - seen:
+ logger.info(
+ "Still missing %d events for room %r: %r...",
+ len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
+ )
+ fetch_state = True
+
+ if fetch_state:
+ # We need to get the state at this event, since we haven't
+ # processed all the prev events.
+ logger.debug(
+ "_handle_new_pdu getting state for %s",
+ pdu.room_id
+ )
+ try:
+ state, auth_chain = yield self.replication_layer.get_state_for_room(
+ origin, pdu.room_id, pdu.event_id,
+ )
+ except:
+ logger.exception("Failed to get state for event: %s", pdu.event_id)
+
+ yield self._process_received_pdu(
+ origin,
+ pdu,
+ state=state,
+ auth_chain=auth_chain,
+ )
+
+ @defer.inlineCallbacks
+ def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+ """
+ Args:
+ origin (str): Origin of the pdu. Will be called to get the missing events
+ pdu: received pdu
+ prevs (str[]): List of event ids which we are missing
+ min_depth (int): Minimum depth of events to return.
+
+ Returns:
+ Deferred<dict(str, str?)>: updated have_seen dictionary
+ """
+ # We recalculate seen, since it may have changed.
+ have_seen = yield self.store.have_events(prevs)
+ seen = set(have_seen.keys())
- logger.debug("Event: %s", event)
+ if not prevs - seen:
+ # nothing left to do
+ defer.returnValue(have_seen)
+
+ latest = yield self.store.get_latest_event_ids_in_room(
+ pdu.room_id
+ )
+
+ # We add the prev events that we have seen to the latest
+ # list to ensure the remote server doesn't give them to us
+ latest = set(latest)
+ latest |= seen
+
+ logger.info(
+ "Missing %d events for room %r: %r...",
+ len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
+ )
+
+ # XXX: we set timeout to 10s to help workaround
+ # https://github.com/matrix-org/synapse/issues/1733.
+ # The reason is to avoid holding the linearizer lock
+ # whilst processing inbound /send transactions, causing
+ # FDs to stack up and block other inbound transactions
+ # which empirically can currently take up to 30 minutes.
+ #
+ # N.B. this explicitly disables retry attempts.
+ #
+ # N.B. this also increases our chances of falling back to
+ # fetching fresh state for the room if the missing event
+ # can't be found, which slightly reduces our security.
+ # it may also increase our DAG extremity count for the room,
+ # causing additional state resolution? See #1760.
+ # However, fetching state doesn't hold the linearizer lock
+ # apparently.
+ #
+ # see https://github.com/matrix-org/synapse/pull/1744
+
+ missing_events = yield self.replication_layer.get_missing_events(
+ origin,
+ pdu.room_id,
+ earliest_events_ids=list(latest),
+ latest_events=[pdu],
+ limit=10,
+ min_depth=min_depth,
+ timeout=10000,
+ )
+
+ # We want to sort these by depth so we process them and
+ # tell clients about them in order.
+ missing_events.sort(key=lambda x: x.depth)
+
+ for e in missing_events:
+ yield self.on_receive_pdu(
+ origin,
+ e,
+ get_missing=False
+ )
+
+ have_seen = yield self.store.have_events(
+ [ev for ev, _ in pdu.prev_events]
+ )
+ defer.returnValue(have_seen)
+
+ @log_function
+ @defer.inlineCallbacks
+ def _process_received_pdu(self, origin, pdu, state, auth_chain):
+ """ Called when we have a new pdu. We need to do auth checks and put it
+ through the StateHandler.
+ """
+ event = pdu
+
+ logger.debug("Processing event: %s", event)
# FIXME (erikj): Awful hack to make the case where we are not currently
# in the room work
@@ -670,8 +858,6 @@ class FederationHandler(BaseHandler):
"""
logger.debug("Joining %s to %s", joinee, room_id)
- yield self.store.clean_room_for_join(room_id)
-
origin, event = yield self._make_and_verify_event(
target_hosts,
room_id,
@@ -680,7 +866,15 @@ class FederationHandler(BaseHandler):
content,
)
+ # This shouldn't happen, because the RoomMemberHandler has a
+ # linearizer lock which only allows one operation per user per room
+ # at a time - so this is just paranoia.
+ assert (room_id not in self.room_queues)
+
self.room_queues[room_id] = []
+
+ yield self.store.clean_room_for_join(room_id)
+
handled_events = set()
try:
@@ -733,18 +927,37 @@ class FederationHandler(BaseHandler):
room_queue = self.room_queues[room_id]
del self.room_queues[room_id]
- for p, origin in room_queue:
- if p.event_id in handled_events:
- continue
+ # we don't need to wait for the queued events to be processed -
+ # it's just a best-effort thing at this point. We do want to do
+ # them roughly in order, though, otherwise we'll end up making
+ # lots of requests for missing prev_events which we do actually
+ # have. Hence we fire off the deferred, but don't wait for it.
- try:
- self.on_receive_pdu(origin, p)
- except:
- logger.exception("Couldn't handle pdu")
+ synapse.util.logcontext.preserve_fn(self._handle_queued_pdus)(
+ room_queue
+ )
defer.returnValue(True)
@defer.inlineCallbacks
+ def _handle_queued_pdus(self, room_queue):
+ """Process PDUs which got queued up while we were busy send_joining.
+
+ Args:
+ room_queue (list[FrozenEvent, str]): list of PDUs to be processed
+ and the servers that sent them
+ """
+ for p, origin in room_queue:
+ try:
+ logger.info("Processing queued PDU %s which was received "
+ "while we were joining %s", p.event_id, p.room_id)
+ yield self.on_receive_pdu(origin, p)
+ except Exception as e:
+ logger.warn(
+ "Error handling queued PDU %s from %s: %s",
+ p.event_id, origin, e)
+
+ @defer.inlineCallbacks
@log_function
def on_make_join_request(self, room_id, user_id):
""" We've received a /make_join/ request, so we create a partial
@@ -791,9 +1004,19 @@ class FederationHandler(BaseHandler):
)
event.internal_metadata.outlier = False
- # Send this event on behalf of the origin server since they may not
- # have an up to data view of the state of the room at this event so
- # will not know which servers to send the event to.
+ # Send this event on behalf of the origin server.
+ #
+ # The reasons we have the destination server rather than the origin
+ # server send it are slightly mysterious: the origin server should have
+ # all the neccessary state once it gets the response to the send_join,
+ # so it could send the event itself if it wanted to. It may be that
+ # doing it this way reduces failure modes, or avoids certain attacks
+ # where a new server selectively tells a subset of the federation that
+ # it has joined.
+ #
+ # The fact is that, as of the current writing, Synapse doesn't send out
+ # the join event over federation after joining, and changing it now
+ # would introduce the danger of backwards-compatibility problems.
event.internal_metadata.send_on_behalf_of = origin
context, event_stream_id, max_stream_id = yield self._handle_new_event(
@@ -878,15 +1101,15 @@ class FederationHandler(BaseHandler):
user_id,
"leave"
)
- signed_event = self._sign_event(event)
+ event = self._sign_event(event)
except SynapseError:
raise
except CodeMessageException as e:
logger.warn("Failed to reject invite: %s", e)
raise SynapseError(500, "Failed to reject invite")
- # Try the host we successfully got a response to /make_join/
- # request first.
+ # Try the host that we succesfully called /make_leave/ on first for
+ # the /send_leave/ request.
try:
target_hosts.remove(origin)
target_hosts.insert(0, origin)
@@ -896,7 +1119,7 @@ class FederationHandler(BaseHandler):
try:
yield self.replication_layer.send_leave(
target_hosts,
- signed_event
+ event
)
except SynapseError:
raise
@@ -1325,7 +1548,17 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def _prep_event(self, origin, event, state=None, auth_events=None):
+ """
+
+ Args:
+ origin:
+ event:
+ state:
+ auth_events:
+ Returns:
+ Deferred, which resolves to synapse.events.snapshot.EventContext
+ """
context = yield self.state_handler.compute_event_context(
event, old_state=state,
)
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 559e5d5a..6a53c5eb 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -150,7 +151,7 @@ class IdentityHandler(BaseHandler):
params.update(kwargs)
try:
- data = yield self.http_client.post_urlencoded_get_json(
+ data = yield self.http_client.post_json_get_json(
"https://%s%s" % (
id_server,
"/_matrix/identity/api/v1/validate/email/requestToken"
@@ -161,3 +162,37 @@ class IdentityHandler(BaseHandler):
except CodeMessageException as e:
logger.info("Proxied requestToken failed: %r", e)
raise e
+
+ @defer.inlineCallbacks
+ def requestMsisdnToken(
+ self, id_server, country, phone_number,
+ client_secret, send_attempt, **kwargs
+ ):
+ yield run_on_reactor()
+
+ if not self._should_trust_id_server(id_server):
+ raise SynapseError(
+ 400, "Untrusted ID server '%s'" % id_server,
+ Codes.SERVER_NOT_TRUSTED
+ )
+
+ params = {
+ 'country': country,
+ 'phone_number': phone_number,
+ 'client_secret': client_secret,
+ 'send_attempt': send_attempt,
+ }
+ params.update(kwargs)
+
+ try:
+ data = yield self.http_client.post_json_get_json(
+ "https://%s%s" % (
+ id_server,
+ "/_matrix/identity/api/v1/validate/msisdn/requestToken"
+ ),
+ params
+ )
+ defer.returnValue(data)
+ except CodeMessageException as e:
+ logger.info("Proxied requestToken failed: %r", e)
+ raise e
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index e0ade4c1..10f5f35a 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -19,6 +19,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
+from synapse.handlers.presence import format_user_presence_state
from synapse.streams.config import PaginationConfig
from synapse.types import (
UserID, StreamToken,
@@ -225,9 +226,17 @@ class InitialSyncHandler(BaseHandler):
"content": content,
})
+ now = self.clock.time_msec()
+
ret = {
"rooms": rooms_ret,
- "presence": presence,
+ "presence": [
+ {
+ "type": "m.presence",
+ "content": format_user_presence_state(event, now),
+ }
+ for event in presence
+ ],
"account_data": account_data_events,
"receipts": receipt,
"end": now_token.to_string(),
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index da610e43..1ede117c 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -29,6 +29,7 @@ from synapse.api.errors import SynapseError
from synapse.api.constants import PresenceState
from synapse.storage.presence import UserPresenceState
+from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.logcontext import preserve_fn
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
@@ -556,9 +557,9 @@ class PresenceHandler(object):
room_ids_to_states = {}
users_to_states = {}
for state in states:
- events = yield self.store.get_rooms_for_user(state.user_id)
- for e in events:
- room_ids_to_states.setdefault(e.room_id, []).append(state)
+ room_ids = yield self.store.get_rooms_for_user(state.user_id)
+ for room_id in room_ids:
+ room_ids_to_states.setdefault(room_id, []).append(state)
plist = yield self.store.get_presence_list_observers_accepted(state.user_id)
for u in plist:
@@ -574,8 +575,7 @@ class PresenceHandler(object):
if not local_states:
continue
- users = yield self.store.get_users_in_room(room_id)
- hosts = set(get_domain_from_id(u) for u in users)
+ hosts = yield self.store.get_hosts_in_room(room_id)
for host in hosts:
hosts_to_states.setdefault(host, []).extend(local_states)
@@ -719,9 +719,7 @@ class PresenceHandler(object):
for state in updates
])
else:
- defer.returnValue([
- format_user_presence_state(state, now) for state in updates
- ])
+ defer.returnValue(updates)
@defer.inlineCallbacks
def set_state(self, target_user, state, ignore_status_msg=False):
@@ -795,6 +793,9 @@ class PresenceHandler(object):
as_event=False,
)
+ now = self.clock.time_msec()
+ results[:] = [format_user_presence_state(r, now) for r in results]
+
is_accepted = {
row["observed_user_id"]: row["accepted"] for row in presence_list
}
@@ -847,6 +848,7 @@ class PresenceHandler(object):
)
state_dict = yield self.get_state(observed_user, as_event=False)
+ state_dict = format_user_presence_state(state_dict, self.clock.time_msec())
self.federation.send_edu(
destination=observer_user.domain,
@@ -910,11 +912,12 @@ class PresenceHandler(object):
def is_visible(self, observed_user, observer_user):
"""Returns whether a user can see another user's presence.
"""
- observer_rooms = yield self.store.get_rooms_for_user(observer_user.to_string())
- observed_rooms = yield self.store.get_rooms_for_user(observed_user.to_string())
-
- observer_room_ids = set(r.room_id for r in observer_rooms)
- observed_room_ids = set(r.room_id for r in observed_rooms)
+ observer_room_ids = yield self.store.get_rooms_for_user(
+ observer_user.to_string()
+ )
+ observed_room_ids = yield self.store.get_rooms_for_user(
+ observed_user.to_string()
+ )
if observer_room_ids & observed_room_ids:
defer.returnValue(True)
@@ -979,14 +982,18 @@ def should_notify(old_state, new_state):
return False
-def format_user_presence_state(state, now):
+def format_user_presence_state(state, now, include_user_id=True):
"""Convert UserPresenceState to a format that can be sent down to clients
and to other servers.
+
+ The "user_id" is optional so that this function can be used to format presence
+ updates for client /sync responses and for federation /send requests.
"""
content = {
"presence": state.state,
- "user_id": state.user_id,
}
+ if include_user_id:
+ content["user_id"] = state.user_id
if state.last_active_ts:
content["last_active_ago"] = now - state.last_active_ts
if state.status_msg and state.state != PresenceState.OFFLINE:
@@ -1025,7 +1032,6 @@ class PresenceEventSource(object):
# sending down the rare duplicate is not a concern.
with Measure(self.clock, "presence.get_new_events"):
- user_id = user.to_string()
if from_key is not None:
from_key = int(from_key)
@@ -1034,18 +1040,7 @@ class PresenceEventSource(object):
max_token = self.store.get_current_presence_token()
- plist = yield self.store.get_presence_list_accepted(user.localpart)
- users_interested_in = set(row["observed_user_id"] for row in plist)
- users_interested_in.add(user_id) # So that we receive our own presence
-
- users_who_share_room = yield self.store.get_users_who_share_room_with_user(
- user_id
- )
- users_interested_in.update(users_who_share_room)
-
- if explicit_room_id:
- user_ids = yield self.store.get_users_in_room(explicit_room_id)
- users_interested_in.update(user_ids)
+ users_interested_in = yield self._get_interested_in(user, explicit_room_id)
user_ids_changed = set()
changed = None
@@ -1073,16 +1068,13 @@ class PresenceEventSource(object):
updates = yield presence.current_state_for_users(user_ids_changed)
- now = self.clock.time_msec()
-
- defer.returnValue(([
- {
- "type": "m.presence",
- "content": format_user_presence_state(s, now),
- }
- for s in updates.values()
- if include_offline or s.state != PresenceState.OFFLINE
- ], max_token))
+ if include_offline:
+ defer.returnValue((updates.values(), max_token))
+ else:
+ defer.returnValue(([
+ s for s in updates.itervalues()
+ if s.state != PresenceState.OFFLINE
+ ], max_token))
def get_current_key(self):
return self.store.get_current_presence_token()
@@ -1090,6 +1082,31 @@ class PresenceEventSource(object):
def get_pagination_rows(self, user, pagination_config, key):
return self.get_new_events(user, from_key=None, include_offline=False)
+ @cachedInlineCallbacks(num_args=2, cache_context=True)
+ def _get_interested_in(self, user, explicit_room_id, cache_context):
+ """Returns the set of users that the given user should see presence
+ updates for
+ """
+ user_id = user.to_string()
+ plist = yield self.store.get_presence_list_accepted(
+ user.localpart, on_invalidate=cache_context.invalidate,
+ )
+ users_interested_in = set(row["observed_user_id"] for row in plist)
+ users_interested_in.add(user_id) # So that we receive our own presence
+
+ users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+ user_id, on_invalidate=cache_context.invalidate,
+ )
+ users_interested_in.update(users_who_share_room)
+
+ if explicit_room_id:
+ user_ids = yield self.store.get_users_in_room(
+ explicit_room_id, on_invalidate=cache_context.invalidate,
+ )
+ users_interested_in.update(user_ids)
+
+ defer.returnValue(users_interested_in)
+
def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
"""Checks the presence of users that have timed out and updates as
@@ -1157,7 +1174,10 @@ def handle_timeout(state, is_mine, syncing_user_ids, now):
# If there are have been no sync for a while (and none ongoing),
# set presence to offline
if user_id not in syncing_user_ids:
- if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT:
+ # If the user has done something recently but hasn't synced,
+ # don't set them as offline.
+ sync_or_active = max(state.last_user_sync_ts, state.last_active_ts)
+ if now - sync_or_active > SYNC_ONLINE_TIMEOUT:
state = state.copy_and_replace(
state=PresenceState.OFFLINE,
status_msg=None,
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 87f74dfb..9bf638f8 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -52,7 +52,8 @@ class ProfileHandler(BaseHandler):
args={
"user_id": target_user.to_string(),
"field": "displayname",
- }
+ },
+ ignore_backoff=True,
)
except CodeMessageException as e:
if e.code != 404:
@@ -99,7 +100,8 @@ class ProfileHandler(BaseHandler):
args={
"user_id": target_user.to_string(),
"field": "avatar_url",
- }
+ },
+ ignore_backoff=True,
)
except CodeMessageException as e:
if e.code != 404:
@@ -156,11 +158,11 @@ class ProfileHandler(BaseHandler):
self.ratelimit(requester)
- joins = yield self.store.get_rooms_for_user(
+ room_ids = yield self.store.get_rooms_for_user(
user.to_string(),
)
- for j in joins:
+ for room_id in room_ids:
handler = self.hs.get_handlers().room_member_handler
try:
# Assume the user isn't a guest because we don't let guests set
@@ -171,12 +173,12 @@ class ProfileHandler(BaseHandler):
yield handler.update_membership(
requester,
user,
- j.room_id,
+ room_id,
"join", # We treat a profile update like a join.
ratelimit=False, # Try to hide that these events aren't atomic.
)
except Exception as e:
logger.warn(
"Failed to update join event for room %s - %s",
- j.room_id, str(e.message)
+ room_id, str(e.message)
)
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 50aa5139..e1cd3a48 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -210,10 +210,9 @@ class ReceiptEventSource(object):
else:
from_key = None
- rooms = yield self.store.get_rooms_for_user(user.to_string())
- rooms = [room.room_id for room in rooms]
+ room_ids = yield self.store.get_rooms_for_user(user.to_string())
events = yield self.store.get_linearized_receipts_for_rooms(
- rooms,
+ room_ids,
from_key=from_key,
to_key=to_key,
)
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 19eebbd4..516cd9a6 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -21,6 +21,7 @@ from synapse.api.constants import (
EventTypes, JoinRules,
)
from synapse.util.async import concurrently_execute
+from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.response_cache import ResponseCache
from synapse.types import ThirdPartyInstanceID
@@ -62,6 +63,10 @@ class RoomListHandler(BaseHandler):
appservice and network id to use an appservice specific one.
Setting to None returns all public rooms across all lists.
"""
+ logger.info(
+ "Getting public room list: limit=%r, since=%r, search=%r, network=%r",
+ limit, since_token, bool(search_filter), network_tuple,
+ )
if search_filter:
# We explicitly don't bother caching searches or requests for
# appservice specific lists.
@@ -91,7 +96,6 @@ class RoomListHandler(BaseHandler):
rooms_to_order_value = {}
rooms_to_num_joined = {}
- rooms_to_latest_event_ids = {}
newly_visible = []
newly_unpublished = []
@@ -116,19 +120,26 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks
def get_order_for_room(room_id):
- latest_event_ids = rooms_to_latest_event_ids.get(room_id, None)
- if not latest_event_ids:
+ # Most of the rooms won't have changed between the since token and
+ # now (especially if the since token is "now"). So, we can ask what
+ # the current users are in a room (that will hit a cache) and then
+ # check if the room has changed since the since token. (We have to
+ # do it in that order to avoid races).
+ # If things have changed then fall back to getting the current state
+ # at the since token.
+ joined_users = yield self.store.get_users_in_room(room_id)
+ if self.store.has_room_changed_since(room_id, stream_token):
latest_event_ids = yield self.store.get_forward_extremeties_for_room(
room_id, stream_token
)
- rooms_to_latest_event_ids[room_id] = latest_event_ids
- if not latest_event_ids:
- return
+ if not latest_event_ids:
+ return
+
+ joined_users = yield self.state_handler.get_current_user_in_room(
+ room_id, latest_event_ids,
+ )
- joined_users = yield self.state_handler.get_current_user_in_room(
- room_id, latest_event_ids,
- )
num_joined_users = len(joined_users)
rooms_to_num_joined[room_id] = num_joined_users
@@ -165,19 +176,19 @@ class RoomListHandler(BaseHandler):
rooms_to_scan = rooms_to_scan[:since_token.current_limit]
rooms_to_scan.reverse()
- # Actually generate the entries. _generate_room_entry will append to
+ # Actually generate the entries. _append_room_entry_to_chunk will append to
# chunk but will stop if len(chunk) > limit
chunk = []
if limit and not search_filter:
step = limit + 1
for i in xrange(0, len(rooms_to_scan), step):
# We iterate here because the vast majority of cases we'll stop
- # at first iteration, but occaisonally _generate_room_entry
+ # at first iteration, but occaisonally _append_room_entry_to_chunk
# won't append to the chunk and so we need to loop again.
# We don't want to scan over the entire range either as that
# would potentially waste a lot of work.
yield concurrently_execute(
- lambda r: self._generate_room_entry(
+ lambda r: self._append_room_entry_to_chunk(
r, rooms_to_num_joined[r],
chunk, limit, search_filter
),
@@ -187,7 +198,7 @@ class RoomListHandler(BaseHandler):
break
else:
yield concurrently_execute(
- lambda r: self._generate_room_entry(
+ lambda r: self._append_room_entry_to_chunk(
r, rooms_to_num_joined[r],
chunk, limit, search_filter
),
@@ -256,21 +267,35 @@ class RoomListHandler(BaseHandler):
defer.returnValue(results)
@defer.inlineCallbacks
- def _generate_room_entry(self, room_id, num_joined_users, chunk, limit,
- search_filter):
+ def _append_room_entry_to_chunk(self, room_id, num_joined_users, chunk, limit,
+ search_filter):
+ """Generate the entry for a room in the public room list and append it
+ to the `chunk` if it matches the search filter
+ """
if limit and len(chunk) > limit + 1:
# We've already got enough, so lets just drop it.
return
+ result = yield self._generate_room_entry(room_id, num_joined_users)
+
+ if result and _matches_room_entry(result, search_filter):
+ chunk.append(result)
+
+ @cachedInlineCallbacks(num_args=1, cache_context=True)
+ def _generate_room_entry(self, room_id, num_joined_users, cache_context):
+ """Returns the entry for a room
+ """
result = {
"room_id": room_id,
"num_joined_members": num_joined_users,
}
- current_state_ids = yield self.state_handler.get_current_state_ids(room_id)
+ current_state_ids = yield self.store.get_current_state_ids(
+ room_id, on_invalidate=cache_context.invalidate,
+ )
event_map = yield self.store.get_events([
- event_id for key, event_id in current_state_ids.items()
+ event_id for key, event_id in current_state_ids.iteritems()
if key[0] in (
EventTypes.JoinRules,
EventTypes.Name,
@@ -294,7 +319,9 @@ class RoomListHandler(BaseHandler):
if join_rule and join_rule != JoinRules.PUBLIC:
defer.returnValue(None)
- aliases = yield self.store.get_aliases_for_room(room_id)
+ aliases = yield self.store.get_aliases_for_room(
+ room_id, on_invalidate=cache_context.invalidate
+ )
if aliases:
result["aliases"] = aliases
@@ -334,8 +361,7 @@ class RoomListHandler(BaseHandler):
if avatar_url:
result["avatar_url"] = avatar_url
- if _matches_room_entry(result, search_filter):
- chunk.append(result)
+ defer.returnValue(result)
@defer.inlineCallbacks
def get_remote_public_room_list(self, server_name, limit=None, since_token=None,
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 5572cb88..c0205da1 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -20,6 +20,7 @@ from synapse.util.metrics import Measure, measure_func
from synapse.util.caches.response_cache import ResponseCache
from synapse.push.clientformat import format_push_rules_for_user
from synapse.visibility import filter_events_for_client
+from synapse.types import RoomStreamToken
from twisted.internet import defer
@@ -225,8 +226,7 @@ class SyncHandler(object):
with Measure(self.clock, "ephemeral_by_room"):
typing_key = since_token.typing_key if since_token else "0"
- rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
- room_ids = [room.room_id for room in rooms]
+ room_ids = yield self.store.get_rooms_for_user(sync_config.user.to_string())
typing_source = self.event_sources.sources["typing"]
typing, typing_key = yield typing_source.get_new_events(
@@ -568,16 +568,15 @@ class SyncHandler(object):
since_token = sync_result_builder.since_token
if since_token and since_token.device_list_key:
- rooms = yield self.store.get_rooms_for_user(user_id)
- room_ids = set(r.room_id for r in rooms)
+ room_ids = yield self.store.get_rooms_for_user(user_id)
user_ids_changed = set()
changed = yield self.store.get_user_whose_devices_changed(
since_token.device_list_key
)
for other_user_id in changed:
- other_rooms = yield self.store.get_rooms_for_user(other_user_id)
- if room_ids.intersection(e.room_id for e in other_rooms):
+ other_room_ids = yield self.store.get_rooms_for_user(other_user_id)
+ if room_ids.intersection(other_room_ids):
user_ids_changed.add(other_user_id)
defer.returnValue(user_ids_changed)
@@ -721,14 +720,14 @@ class SyncHandler(object):
extra_users_ids.update(users)
extra_users_ids.discard(user.to_string())
- states = yield self.presence_handler.get_states(
- extra_users_ids,
- as_event=True,
- )
- presence.extend(states)
+ if extra_users_ids:
+ states = yield self.presence_handler.get_states(
+ extra_users_ids,
+ )
+ presence.extend(states)
- # Deduplicate the presence entries so that there's at most one per user
- presence = {p["content"]["user_id"]: p for p in presence}.values()
+ # Deduplicate the presence entries so that there's at most one per user
+ presence = {p.user_id: p for p in presence}.values()
presence = sync_config.filter_collection.filter_presence(
presence
@@ -765,6 +764,21 @@ class SyncHandler(object):
)
sync_result_builder.now_token = now_token
+ # We check up front if anything has changed, if it hasn't then there is
+ # no point in going futher.
+ since_token = sync_result_builder.since_token
+ if not sync_result_builder.full_state:
+ if since_token and not ephemeral_by_room and not account_data_by_room:
+ have_changed = yield self._have_rooms_changed(sync_result_builder)
+ if not have_changed:
+ tags_by_room = yield self.store.get_updated_tags(
+ user_id,
+ since_token.account_data_key,
+ )
+ if not tags_by_room:
+ logger.debug("no-oping sync")
+ defer.returnValue(([], []))
+
ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
"m.ignored_user_list", user_id=user_id,
)
@@ -774,13 +788,12 @@ class SyncHandler(object):
else:
ignored_users = frozenset()
- if sync_result_builder.since_token:
+ if since_token:
res = yield self._get_rooms_changed(sync_result_builder, ignored_users)
room_entries, invited, newly_joined_rooms = res
tags_by_room = yield self.store.get_updated_tags(
- user_id,
- sync_result_builder.since_token.account_data_key,
+ user_id, since_token.account_data_key,
)
else:
res = yield self._get_all_rooms(sync_result_builder, ignored_users)
@@ -805,7 +818,7 @@ class SyncHandler(object):
# Now we want to get any newly joined users
newly_joined_users = set()
- if sync_result_builder.since_token:
+ if since_token:
for joined_sync in sync_result_builder.joined:
it = itertools.chain(
joined_sync.timeline.events, joined_sync.state.values()
@@ -818,6 +831,38 @@ class SyncHandler(object):
defer.returnValue((newly_joined_rooms, newly_joined_users))
@defer.inlineCallbacks
+ def _have_rooms_changed(self, sync_result_builder):
+ """Returns whether there may be any new events that should be sent down
+ the sync. Returns True if there are.
+ """
+ user_id = sync_result_builder.sync_config.user.to_string()
+ since_token = sync_result_builder.since_token
+ now_token = sync_result_builder.now_token
+
+ assert since_token
+
+ # Get a list of membership change events that have happened.
+ rooms_changed = yield self.store.get_membership_changes_for_user(
+ user_id, since_token.room_key, now_token.room_key
+ )
+
+ if rooms_changed:
+ defer.returnValue(True)
+
+ app_service = self.store.get_app_service_by_user_id(user_id)
+ if app_service:
+ rooms = yield self.store.get_app_service_rooms(app_service)
+ joined_room_ids = set(r.room_id for r in rooms)
+ else:
+ joined_room_ids = yield self.store.get_rooms_for_user(user_id)
+
+ stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
+ for room_id in joined_room_ids:
+ if self.store.has_room_changed_since(room_id, stream_id):
+ defer.returnValue(True)
+ defer.returnValue(False)
+
+ @defer.inlineCallbacks
def _get_rooms_changed(self, sync_result_builder, ignored_users):
"""Gets the the changes that have happened since the last sync.
@@ -841,8 +886,7 @@ class SyncHandler(object):
rooms = yield self.store.get_app_service_rooms(app_service)
joined_room_ids = set(r.room_id for r in rooms)
else:
- rooms = yield self.store.get_rooms_for_user(user_id)
- joined_room_ids = set(r.room_id for r in rooms)
+ joined_room_ids = yield self.store.get_rooms_for_user(user_id)
# Get a list of membership change events that have happened.
rooms_changed = yield self.store.get_membership_changes_for_user(
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 78b92cef..62b4d7e9 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -12,8 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-
+import synapse.util.retryutils
from twisted.internet import defer, reactor, protocol
from twisted.internet.error import DNSLookupError
from twisted.web.client import readBody, HTTPConnectionPool, Agent
@@ -22,7 +21,7 @@ from twisted.web._newclient import ResponseDone
from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.async import sleep
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util import logcontext
import synapse.metrics
from canonicaljson import encode_canonical_json
@@ -94,6 +93,7 @@ class MatrixFederationHttpClient(object):
reactor, MatrixFederationEndpointFactory(hs), pool=pool
)
self.clock = hs.get_clock()
+ self._store = hs.get_datastore()
self.version_string = hs.version_string
self._next_id = 1
@@ -103,123 +103,152 @@ class MatrixFederationHttpClient(object):
)
@defer.inlineCallbacks
- def _create_request(self, destination, method, path_bytes,
- body_callback, headers_dict={}, param_bytes=b"",
- query_bytes=b"", retry_on_dns_fail=True,
- timeout=None, long_retries=False):
- """ Creates and sends a request to the given url
- """
- headers_dict[b"User-Agent"] = [self.version_string]
- headers_dict[b"Host"] = [destination]
+ def _request(self, destination, method, path,
+ body_callback, headers_dict={}, param_bytes=b"",
+ query_bytes=b"", retry_on_dns_fail=True,
+ timeout=None, long_retries=False,
+ ignore_backoff=False,
+ backoff_on_404=False):
+ """ Creates and sends a request to the given server
+ Args:
+ destination (str): The remote server to send the HTTP request to.
+ method (str): HTTP method
+ path (str): The HTTP path
+ ignore_backoff (bool): true to ignore the historical backoff data
+ and try the request anyway.
+ backoff_on_404 (bool): Back off if we get a 404
- url_bytes = self._create_url(
- destination, path_bytes, param_bytes, query_bytes
+ Returns:
+ Deferred: resolves with the http response object on success.
+
+ Fails with ``HTTPRequestException``: if we get an HTTP response
+ code >= 300.
+ Fails with ``NotRetryingDestination`` if we are not yet ready
+ to retry this server.
+ """
+ limiter = yield synapse.util.retryutils.get_retry_limiter(
+ destination,
+ self.clock,
+ self._store,
+ backoff_on_404=backoff_on_404,
+ ignore_backoff=ignore_backoff,
)
- txn_id = "%s-O-%s" % (method, self._next_id)
- self._next_id = (self._next_id + 1) % (sys.maxint - 1)
+ destination = destination.encode("ascii")
+ path_bytes = path.encode("ascii")
+ with limiter:
+ headers_dict[b"User-Agent"] = [self.version_string]
+ headers_dict[b"Host"] = [destination]
- outbound_logger.info(
- "{%s} [%s] Sending request: %s %s",
- txn_id, destination, method, url_bytes
- )
+ url_bytes = self._create_url(
+ destination, path_bytes, param_bytes, query_bytes
+ )
- # XXX: Would be much nicer to retry only at the transaction-layer
- # (once we have reliable transactions in place)
- if long_retries:
- retries_left = MAX_LONG_RETRIES
- else:
- retries_left = MAX_SHORT_RETRIES
+ txn_id = "%s-O-%s" % (method, self._next_id)
+ self._next_id = (self._next_id + 1) % (sys.maxint - 1)
- http_url_bytes = urlparse.urlunparse(
- ("", "", path_bytes, param_bytes, query_bytes, "")
- )
+ outbound_logger.info(
+ "{%s} [%s] Sending request: %s %s",
+ txn_id, destination, method, url_bytes
+ )
- log_result = None
- try:
- while True:
- producer = None
- if body_callback:
- producer = body_callback(method, http_url_bytes, headers_dict)
-
- try:
- def send_request():
- request_deferred = preserve_context_over_fn(
- self.agent.request,
- method,
- url_bytes,
- Headers(headers_dict),
- producer
- )
+ # XXX: Would be much nicer to retry only at the transaction-layer
+ # (once we have reliable transactions in place)
+ if long_retries:
+ retries_left = MAX_LONG_RETRIES
+ else:
+ retries_left = MAX_SHORT_RETRIES
- return self.clock.time_bound_deferred(
- request_deferred,
- time_out=timeout / 1000. if timeout else 60,
- )
+ http_url_bytes = urlparse.urlunparse(
+ ("", "", path_bytes, param_bytes, query_bytes, "")
+ )
- response = yield preserve_context_over_fn(send_request)
+ log_result = None
+ try:
+ while True:
+ producer = None
+ if body_callback:
+ producer = body_callback(method, http_url_bytes, headers_dict)
+
+ try:
+ def send_request():
+ request_deferred = self.agent.request(
+ method,
+ url_bytes,
+ Headers(headers_dict),
+ producer
+ )
+
+ return self.clock.time_bound_deferred(
+ request_deferred,
+ time_out=timeout / 1000. if timeout else 60,
+ )
+
+ with logcontext.PreserveLoggingContext():
+ response = yield send_request()
+
+ log_result = "%d %s" % (response.code, response.phrase,)
+ break
+ except Exception as e:
+ if not retry_on_dns_fail and isinstance(e, DNSLookupError):
+ logger.warn(
+ "DNS Lookup failed to %s with %s",
+ destination,
+ e
+ )
+ log_result = "DNS Lookup failed to %s with %s" % (
+ destination, e
+ )
+ raise
- log_result = "%d %s" % (response.code, response.phrase,)
- break
- except Exception as e:
- if not retry_on_dns_fail and isinstance(e, DNSLookupError):
logger.warn(
- "DNS Lookup failed to %s with %s",
+ "{%s} Sending request failed to %s: %s %s: %s - %s",
+ txn_id,
destination,
- e
+ method,
+ url_bytes,
+ type(e).__name__,
+ _flatten_response_never_received(e),
)
- log_result = "DNS Lookup failed to %s with %s" % (
- destination, e
+
+ log_result = "%s - %s" % (
+ type(e).__name__, _flatten_response_never_received(e),
)
- raise
-
- logger.warn(
- "{%s} Sending request failed to %s: %s %s: %s - %s",
- txn_id,
- destination,
- method,
- url_bytes,
- type(e).__name__,
- _flatten_response_never_received(e),
- )
-
- log_result = "%s - %s" % (
- type(e).__name__, _flatten_response_never_received(e),
- )
-
- if retries_left and not timeout:
- if long_retries:
- delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left)
- delay = min(delay, 60)
- delay *= random.uniform(0.8, 1.4)
+
+ if retries_left and not timeout:
+ if long_retries:
+ delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left)
+ delay = min(delay, 60)
+ delay *= random.uniform(0.8, 1.4)
+ else:
+ delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left)
+ delay = min(delay, 2)
+ delay *= random.uniform(0.8, 1.4)
+
+ yield sleep(delay)
+ retries_left -= 1
else:
- delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left)
- delay = min(delay, 2)
- delay *= random.uniform(0.8, 1.4)
-
- yield sleep(delay)
- retries_left -= 1
- else:
- raise
- finally:
- outbound_logger.info(
- "{%s} [%s] Result: %s",
- txn_id,
- destination,
- log_result,
- )
+ raise
+ finally:
+ outbound_logger.info(
+ "{%s} [%s] Result: %s",
+ txn_id,
+ destination,
+ log_result,
+ )
- if 200 <= response.code < 300:
- pass
- else:
- # :'(
- # Update transactions table?
- body = yield preserve_context_over_fn(readBody, response)
- raise HttpResponseException(
- response.code, response.phrase, body
- )
+ if 200 <= response.code < 300:
+ pass
+ else:
+ # :'(
+ # Update transactions table?
+ with logcontext.PreserveLoggingContext():
+ body = yield readBody(response)
+ raise HttpResponseException(
+ response.code, response.phrase, body
+ )
- defer.returnValue(response)
+ defer.returnValue(response)
def sign_request(self, destination, method, url_bytes, headers_dict,
content=None):
@@ -248,7 +277,9 @@ class MatrixFederationHttpClient(object):
@defer.inlineCallbacks
def put_json(self, destination, path, data={}, json_data_callback=None,
- long_retries=False, timeout=None):
+ long_retries=False, timeout=None,
+ ignore_backoff=False,
+ backoff_on_404=False):
""" Sends the specifed json data using PUT
Args:
@@ -263,11 +294,19 @@ class MatrixFederationHttpClient(object):
retry for a short or long time.
timeout(int): How long to try (in ms) the destination for before
giving up. None indicates no timeout.
+ ignore_backoff (bool): true to ignore the historical backoff data
+ and try the request anyway.
+ backoff_on_404 (bool): True if we should count a 404 response as
+ a failure of the server (and should therefore back off future
+ requests)
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body. On a 4xx or 5xx error response a
CodeMessageException is raised.
+
+ Fails with ``NotRetryingDestination`` if we are not yet ready
+ to retry this server.
"""
if not json_data_callback:
@@ -282,26 +321,29 @@ class MatrixFederationHttpClient(object):
producer = _JsonProducer(json_data)
return producer
- response = yield self._create_request(
- destination.encode("ascii"),
+ response = yield self._request(
+ destination,
"PUT",
- path.encode("ascii"),
+ path,
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
long_retries=long_retries,
timeout=timeout,
+ ignore_backoff=ignore_backoff,
+ backoff_on_404=backoff_on_404,
)
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
check_content_type_is_json(response.headers)
- body = yield preserve_context_over_fn(readBody, response)
+ with logcontext.PreserveLoggingContext():
+ body = yield readBody(response)
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def post_json(self, destination, path, data={}, long_retries=False,
- timeout=None):
+ timeout=None, ignore_backoff=False):
""" Sends the specifed json data using POST
Args:
@@ -314,11 +356,15 @@ class MatrixFederationHttpClient(object):
retry for a short or long time.
timeout(int): How long to try (in ms) the destination for before
giving up. None indicates no timeout.
-
+ ignore_backoff (bool): true to ignore the historical backoff data and
+ try the request anyway.
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body. On a 4xx or 5xx error response a
CodeMessageException is raised.
+
+ Fails with ``NotRetryingDestination`` if we are not yet ready
+ to retry this server.
"""
def body_callback(method, url_bytes, headers_dict):
@@ -327,27 +373,29 @@ class MatrixFederationHttpClient(object):
)
return _JsonProducer(data)
- response = yield self._create_request(
- destination.encode("ascii"),
+ response = yield self._request(
+ destination,
"POST",
- path.encode("ascii"),
+ path,
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
long_retries=long_retries,
timeout=timeout,
+ ignore_backoff=ignore_backoff,
)
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
check_content_type_is_json(response.headers)
- body = yield preserve_context_over_fn(readBody, response)
+ with logcontext.PreserveLoggingContext():
+ body = yield readBody(response)
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def get_json(self, destination, path, args={}, retry_on_dns_fail=True,
- timeout=None):
+ timeout=None, ignore_backoff=False):
""" GETs some json from the given host homeserver and path
Args:
@@ -359,11 +407,16 @@ class MatrixFederationHttpClient(object):
timeout (int): How long to try (in ms) the destination for before
giving up. None indicates no timeout and that the request will
be retried.
+ ignore_backoff (bool): true to ignore the historical backoff data
+ and try the request anyway.
Returns:
Deferred: Succeeds when we get *any* HTTP response.
The result of the deferred is a tuple of `(code, response)`,
where `response` is a dict representing the decoded JSON body.
+
+ Fails with ``NotRetryingDestination`` if we are not yet ready
+ to retry this server.
"""
logger.debug("get_json args: %s", args)
@@ -380,36 +433,47 @@ class MatrixFederationHttpClient(object):
self.sign_request(destination, method, url_bytes, headers_dict)
return None
- response = yield self._create_request(
- destination.encode("ascii"),
+ response = yield self._request(
+ destination,
"GET",
- path.encode("ascii"),
+ path,
query_bytes=query_bytes,
body_callback=body_callback,
retry_on_dns_fail=retry_on_dns_fail,
timeout=timeout,
+ ignore_backoff=ignore_backoff,
)
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
check_content_type_is_json(response.headers)
- body = yield preserve_context_over_fn(readBody, response)
+ with logcontext.PreserveLoggingContext():
+ body = yield readBody(response)
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def get_file(self, destination, path, output_stream, args={},
- retry_on_dns_fail=True, max_size=None):
+ retry_on_dns_fail=True, max_size=None,
+ ignore_backoff=False):
"""GETs a file from a given homeserver
Args:
destination (str): The remote server to send the HTTP request to.
path (str): The HTTP path to GET.
output_stream (file): File to write the response body to.
args (dict): Optional dictionary used to create the query string.
+ ignore_backoff (bool): true to ignore the historical backoff data
+ and try the request anyway.
Returns:
- A (int,dict) tuple of the file length and a dict of the response
- headers.
+ Deferred: resolves with an (int,dict) tuple of the file length and
+ a dict of the response headers.
+
+ Fails with ``HTTPRequestException`` if we get an HTTP response code
+ >= 300
+
+ Fails with ``NotRetryingDestination`` if we are not yet ready
+ to retry this server.
"""
encoded_args = {}
@@ -419,28 +483,29 @@ class MatrixFederationHttpClient(object):
encoded_args[k] = [v.encode("UTF-8") for v in vs]
query_bytes = urllib.urlencode(encoded_args, True)
- logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
+ logger.debug("Query bytes: %s Retry DNS: %s", query_bytes, retry_on_dns_fail)
def body_callback(method, url_bytes, headers_dict):
self.sign_request(destination, method, url_bytes, headers_dict)
return None
- response = yield self._create_request(
- destination.encode("ascii"),
+ response = yield self._request(
+ destination,
"GET",
- path.encode("ascii"),
+ path,
query_bytes=query_bytes,
body_callback=body_callback,
- retry_on_dns_fail=retry_on_dns_fail
+ retry_on_dns_fail=retry_on_dns_fail,
+ ignore_backoff=ignore_backoff,
)
headers = dict(response.headers.getAllRawHeaders())
try:
- length = yield preserve_context_over_fn(
- _readBodyToFile,
- response, output_stream, max_size
- )
+ with logcontext.PreserveLoggingContext():
+ length = yield _readBodyToFile(
+ response, output_stream, max_size
+ )
except:
logger.exception("Failed to download body")
raise
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 8c22d6f0..9a4c36ad 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -192,6 +192,16 @@ def parse_json_object_from_request(request):
return content
+def assert_params_in_request(body, required):
+ absent = []
+ for k in required:
+ if k not in body:
+ absent.append(k)
+
+ if len(absent) > 0:
+ raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+
+
class RestServlet(object):
""" A Synapse REST Servlet.
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 8051a7a8..7eeba6d2 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -16,6 +16,7 @@
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError
+from synapse.handlers.presence import format_user_presence_state
from synapse.util import DeferredTimedOutError
from synapse.util.logutils import log_function
@@ -37,6 +38,10 @@ metrics = synapse.metrics.get_metrics_for(__name__)
notified_events_counter = metrics.register_counter("notified_events")
+users_woken_by_stream_counter = metrics.register_counter(
+ "users_woken_by_stream", labels=["stream"]
+)
+
# TODO(paul): Should be shared somewhere
def count(func, l):
@@ -73,6 +78,13 @@ class _NotifierUserStream(object):
self.user_id = user_id
self.rooms = set(rooms)
self.current_token = current_token
+
+ # The last token for which we should wake up any streams that have a
+ # token that comes before it. This gets updated everytime we get poked.
+ # We start it at the current token since if we get any streams
+ # that have a token from before we have no idea whether they should be
+ # woken up or not, so lets just wake them up.
+ self.last_notified_token = current_token
self.last_notified_ms = time_now_ms
with PreserveLoggingContext():
@@ -89,9 +101,12 @@ class _NotifierUserStream(object):
self.current_token = self.current_token.copy_and_advance(
stream_key, stream_id
)
+ self.last_notified_token = self.current_token
self.last_notified_ms = time_now_ms
noify_deferred = self.notify_deferred
+ users_woken_by_stream_counter.inc(stream_key)
+
with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred())
noify_deferred.callback(self.current_token)
@@ -113,8 +128,14 @@ class _NotifierUserStream(object):
def new_listener(self, token):
"""Returns a deferred that is resolved when there is a new token
greater than the given token.
+
+ Args:
+ token: The token from which we are streaming from, i.e. we shouldn't
+ notify for things that happened before this.
"""
- if self.current_token.is_after(token):
+ # Immediately wake up stream if something has already since happened
+ # since their last token.
+ if self.last_notified_token.is_after(token):
return _NotificationListener(defer.succeed(self.current_token))
else:
return _NotificationListener(self.notify_deferred.observe())
@@ -283,8 +304,7 @@ class Notifier(object):
if user_stream is None:
current_token = yield self.event_sources.get_current_token()
if room_ids is None:
- rooms = yield self.store.get_rooms_for_user(user_id)
- room_ids = [room.room_id for room in rooms]
+ room_ids = yield self.store.get_rooms_for_user(user_id)
user_stream = _NotifierUserStream(
user_id=user_id,
rooms=room_ids,
@@ -294,40 +314,44 @@ class Notifier(object):
self._register_with_keys(user_stream)
result = None
+ prev_token = from_token
if timeout:
end_time = self.clock.time_msec() + timeout
- prev_token = from_token
while not result:
try:
- current_token = user_stream.current_token
-
- result = yield callback(prev_token, current_token)
- if result:
- break
-
now = self.clock.time_msec()
if end_time <= now:
break
# Now we wait for the _NotifierUserStream to be told there
# is a new token.
- # We need to supply the token we supplied to callback so
- # that we don't miss any current_token updates.
- prev_token = current_token
listener = user_stream.new_listener(prev_token)
with PreserveLoggingContext():
yield self.clock.time_bound_deferred(
listener.deferred,
time_out=(end_time - now) / 1000.
)
+
+ current_token = user_stream.current_token
+
+ result = yield callback(prev_token, current_token)
+ if result:
+ break
+
+ # Update the prev_token to the current_token since nothing
+ # has happened between the old prev_token and the current_token
+ prev_token = current_token
except DeferredTimedOutError:
break
except defer.CancelledError:
break
- else:
+
+ if result is None:
+ # This happened if there was no timeout or if the timeout had
+ # already expired.
current_token = user_stream.current_token
- result = yield callback(from_token, current_token)
+ result = yield callback(prev_token, current_token)
defer.returnValue(result)
@@ -388,6 +412,15 @@ class Notifier(object):
new_events,
is_peeking=is_peeking,
)
+ elif name == "presence":
+ now = self.clock.time_msec()
+ new_events[:] = [
+ {
+ "type": "m.presence",
+ "content": format_user_presence_state(event, now),
+ }
+ for event in new_events
+ ]
events.extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key)
@@ -420,8 +453,7 @@ class Notifier(object):
@defer.inlineCallbacks
def _get_room_ids(self, user, explicit_room_id):
- joined_rooms = yield self.store.get_rooms_for_user(user.to_string())
- joined_room_ids = map(lambda r: r.room_id, joined_rooms)
+ joined_room_ids = yield self.store.get_rooms_for_user(user.to_string())
if explicit_room_id:
if explicit_room_id in joined_room_ids:
defer.returnValue(([explicit_room_id], True))
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 62d794f2..3a50c72e 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -139,7 +139,7 @@ class Mailer(object):
@defer.inlineCallbacks
def _fetch_room_state(room_id):
- room_state = yield self.state_handler.get_current_state_ids(room_id)
+ room_state = yield self.store.get_current_state_ids(room_id)
state_by_room[room_id] = room_state
# Run at most 3 of these at once: sync does 10 at a time but email
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 4db76f18..4d880465 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -17,6 +17,7 @@ import logging
import re
from synapse.types import UserID
+from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
@@ -125,6 +126,11 @@ class PushRuleEvaluatorForEvent(object):
return self._value_cache.get(dotted_key, None)
+# Caches (glob, word_boundary) -> regex for push. See _glob_matches
+regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR)
+register_cache("regex_push_cache", regex_cache)
+
+
def _glob_matches(glob, value, word_boundary=False):
"""Tests if value matches glob.
@@ -137,46 +143,63 @@ def _glob_matches(glob, value, word_boundary=False):
Returns:
bool
"""
- try:
- if IS_GLOB.search(glob):
- r = re.escape(glob)
-
- r = r.replace(r'\*', '.*?')
- r = r.replace(r'\?', '.')
-
- # handle [abc], [a-z] and [!a-z] style ranges.
- r = GLOB_REGEX.sub(
- lambda x: (
- '[%s%s]' % (
- x.group(1) and '^' or '',
- x.group(2).replace(r'\\\-', '-')
- )
- ),
- r,
- )
- if word_boundary:
- r = r"\b%s\b" % (r,)
- r = _compile_regex(r)
-
- return r.search(value)
- else:
- r = r + "$"
- r = _compile_regex(r)
-
- return r.match(value)
- elif word_boundary:
- r = re.escape(glob)
- r = r"\b%s\b" % (r,)
- r = _compile_regex(r)
- return r.search(value)
- else:
- return value.lower() == glob.lower()
+ try:
+ r = regex_cache.get((glob, word_boundary), None)
+ if not r:
+ r = _glob_to_re(glob, word_boundary)
+ regex_cache[(glob, word_boundary)] = r
+ return r.search(value)
except re.error:
logger.warn("Failed to parse glob to regex: %r", glob)
return False
+def _glob_to_re(glob, word_boundary):
+ """Generates regex for a given glob.
+
+ Args:
+ glob (string)
+ word_boundary (bool): Whether to match against word boundaries or entire
+ string. Defaults to False.
+
+ Returns:
+ regex object
+ """
+ if IS_GLOB.search(glob):
+ r = re.escape(glob)
+
+ r = r.replace(r'\*', '.*?')
+ r = r.replace(r'\?', '.')
+
+ # handle [abc], [a-z] and [!a-z] style ranges.
+ r = GLOB_REGEX.sub(
+ lambda x: (
+ '[%s%s]' % (
+ x.group(1) and '^' or '',
+ x.group(2).replace(r'\\\-', '-')
+ )
+ ),
+ r,
+ )
+ if word_boundary:
+ r = r"\b%s\b" % (r,)
+
+ return re.compile(r, flags=re.IGNORECASE)
+ else:
+ r = "^" + r + "$"
+
+ return re.compile(r, flags=re.IGNORECASE)
+ elif word_boundary:
+ r = re.escape(glob)
+ r = r"\b%s\b" % (r,)
+
+ return re.compile(r, flags=re.IGNORECASE)
+ else:
+ r = "^" + re.escape(glob) + "$"
+ return re.compile(r, flags=re.IGNORECASE)
+
+
def _flatten_dict(d, prefix=[], result={}):
for key, value in d.items():
if isinstance(value, basestring):
@@ -185,16 +208,3 @@ def _flatten_dict(d, prefix=[], result={}):
_flatten_dict(value, prefix=(prefix + [key]), result=result)
return result
-
-
-regex_cache = LruCache(5000)
-
-
-def _compile_regex(regex_str):
- r = regex_cache.get(regex_str, None)
- if r:
- return r
-
- r = re.compile(regex_str, flags=re.IGNORECASE)
- regex_cache[regex_str] = r
- return r
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index a27476bb..287df94b 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -33,13 +33,13 @@ def get_badge_count(store, user_id):
badge = len(invites)
- for r in joins:
- if r.room_id in my_receipts_by_room:
- last_unread_event_id = my_receipts_by_room[r.room_id]
+ for room_id in joins:
+ if room_id in my_receipts_by_room:
+ last_unread_event_id = my_receipts_by_room[room_id]
notifs = yield (
store.get_unread_event_push_actions_by_room_for_user(
- r.room_id, user_id, last_unread_event_id
+ room_id, user_id, last_unread_event_id
)
)
# return one badge count per conversation, as count per
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 7817b0cd..ed7f1c89 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +19,7 @@ from distutils.version import LooseVersion
logger = logging.getLogger(__name__)
REQUIREMENTS = {
+ "jsonschema>=2.5.1": ["jsonschema>=2.5.1"],
"frozendict>=0.4": ["frozendict"],
"unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
"canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
@@ -37,6 +39,7 @@ REQUIREMENTS = {
"pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
"pymacaroons-pynacl": ["pymacaroons"],
"msgpack-python>=0.3.0": ["msgpack"],
+ "phonenumbers>=8.2.0": ["phonenumbers"],
}
CONDITIONAL_REQUIREMENTS = {
"web_client": {
diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py
index d8eb1459..03930fe9 100644
--- a/synapse/replication/resource.py
+++ b/synapse/replication/resource.py
@@ -283,12 +283,12 @@ class ReplicationResource(Resource):
if request_events != upto_events_token:
writer.write_header_and_rows("events", res.new_forward_events, (
- "position", "internal", "json", "state_group"
+ "position", "event_id", "room_id", "type", "state_key",
), position=upto_events_token)
if request_backfill != upto_backfill_token:
writer.write_header_and_rows("backfill", res.new_backfill_events, (
- "position", "internal", "json", "state_group",
+ "position", "event_id", "room_id", "type", "state_key", "redacts",
), position=upto_backfill_token)
writer.write_header_and_rows(
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 24b5c79d..9d1d173b 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -27,4 +27,9 @@ class SlavedIdTracker(object):
self._current = (max if self.step > 0 else min)(self._current, new_id)
def get_current_token(self):
+ """
+
+ Returns:
+ int
+ """
return self._current
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 622b2d85..d4db1e45 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -16,7 +16,6 @@ from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.api.constants import EventTypes
-from synapse.events import FrozenEvent
from synapse.storage import DataStore
from synapse.storage.roommember import RoomMemberStore
from synapse.storage.event_federation import EventFederationStore
@@ -25,7 +24,6 @@ from synapse.storage.state import StateStore
from synapse.storage.stream import StreamStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
-import ujson as json
import logging
@@ -109,6 +107,10 @@ class SlavedEventStore(BaseSlavedStore):
get_recent_event_ids_for_room = (
StreamStore.__dict__["get_recent_event_ids_for_room"]
)
+ get_current_state_ids = (
+ StateStore.__dict__["get_current_state_ids"]
+ )
+ has_room_changed_since = DataStore.has_room_changed_since.__func__
get_unread_push_actions_for_user_in_range_for_http = (
DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__
@@ -165,7 +167,6 @@ class SlavedEventStore(BaseSlavedStore):
_get_rooms_for_user_where_membership_is_txn = (
DataStore._get_rooms_for_user_where_membership_is_txn.__func__
)
- _get_members_rows_txn = DataStore._get_members_rows_txn.__func__
_get_state_for_groups = DataStore._get_state_for_groups.__func__
_get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__
_get_events_around_txn = DataStore._get_events_around_txn.__func__
@@ -238,46 +239,32 @@ class SlavedEventStore(BaseSlavedStore):
return super(SlavedEventStore, self).process_replication(result)
def _process_replication_row(self, row, backfilled):
- internal = json.loads(row[1])
- event_json = json.loads(row[2])
- event = FrozenEvent(event_json, internal_metadata_dict=internal)
+ stream_ordering = row[0] if not backfilled else -row[0]
self.invalidate_caches_for_event(
- event, backfilled,
+ stream_ordering, row[1], row[2], row[3], row[4], row[5],
+ backfilled=backfilled,
)
- def invalidate_caches_for_event(self, event, backfilled):
- self._invalidate_get_event_cache(event.event_id)
+ def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
+ etype, state_key, redacts, backfilled):
+ self._invalidate_get_event_cache(event_id)
- self.get_latest_event_ids_in_room.invalidate((event.room_id,))
+ self.get_latest_event_ids_in_room.invalidate((room_id,))
self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
- (event.room_id,)
+ (room_id,)
)
if not backfilled:
self._events_stream_cache.entity_has_changed(
- event.room_id, event.internal_metadata.stream_ordering
+ room_id, stream_ordering
)
- # self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
- # (event.room_id,)
- # )
-
- if event.type == EventTypes.Redaction:
- self._invalidate_get_event_cache(event.redacts)
+ if redacts:
+ self._invalidate_get_event_cache(redacts)
- if event.type == EventTypes.Member:
+ if etype == EventTypes.Member:
self._membership_stream_cache.entity_has_changed(
- event.state_key, event.internal_metadata.stream_ordering
+ state_key, stream_ordering
)
- self.get_invited_rooms_for_user.invalidate((event.state_key,))
-
- if not event.is_state():
- return
-
- if backfilled:
- return
-
- if (not event.internal_metadata.is_invite_from_remote()
- and event.internal_metadata.is_outlier()):
- return
+ self.get_invited_rooms_for_user.invalidate((state_key,))
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 40f6c9a3..e4a2414d 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -57,5 +57,6 @@ class SlavedPresenceStore(BaseSlavedStore):
self.presence_stream_cache.entity_has_changed(
user_id, position
)
+ self._get_presence_for_user.invalidate((user_id,))
return super(SlavedPresenceStore, self).process_replication(result)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 72057f1b..a43410fb 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -19,6 +19,7 @@ from synapse.api.errors import SynapseError, LoginError, Codes
from synapse.types import UserID
from synapse.http.server import finish_request
from synapse.http.servlet import parse_json_object_from_request
+from synapse.util.msisdn import phone_number_to_msisdn
from .base import ClientV1RestServlet, client_path_patterns
@@ -33,10 +34,55 @@ from saml2.client import Saml2Client
import xml.etree.ElementTree as ET
+from twisted.web.client import PartialDownloadError
+
logger = logging.getLogger(__name__)
+def login_submission_legacy_convert(submission):
+ """
+ If the input login submission is an old style object
+ (ie. with top-level user / medium / address) convert it
+ to a typed object.
+ """
+ if "user" in submission:
+ submission["identifier"] = {
+ "type": "m.id.user",
+ "user": submission["user"],
+ }
+ del submission["user"]
+
+ if "medium" in submission and "address" in submission:
+ submission["identifier"] = {
+ "type": "m.id.thirdparty",
+ "medium": submission["medium"],
+ "address": submission["address"],
+ }
+ del submission["medium"]
+ del submission["address"]
+
+
+def login_id_thirdparty_from_phone(identifier):
+ """
+ Convert a phone login identifier type to a generic threepid identifier
+ Args:
+ identifier(dict): Login identifier dict of type 'm.id.phone'
+
+ Returns: Login identifier dict of type 'm.id.threepid'
+ """
+ if "country" not in identifier or "number" not in identifier:
+ raise SynapseError(400, "Invalid phone-type identifier")
+
+ msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"])
+
+ return {
+ "type": "m.id.thirdparty",
+ "medium": "msisdn",
+ "address": msisdn,
+ }
+
+
class LoginRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login$")
PASS_TYPE = "m.login.password"
@@ -117,20 +163,52 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def do_password_login(self, login_submission):
- if 'medium' in login_submission and 'address' in login_submission:
- address = login_submission['address']
- if login_submission['medium'] == 'email':
+ if "password" not in login_submission:
+ raise SynapseError(400, "Missing parameter: password")
+
+ login_submission_legacy_convert(login_submission)
+
+ if "identifier" not in login_submission:
+ raise SynapseError(400, "Missing param: identifier")
+
+ identifier = login_submission["identifier"]
+ if "type" not in identifier:
+ raise SynapseError(400, "Login identifier has no type")
+
+ # convert phone type identifiers to generic threepids
+ if identifier["type"] == "m.id.phone":
+ identifier = login_id_thirdparty_from_phone(identifier)
+
+ # convert threepid identifiers to user IDs
+ if identifier["type"] == "m.id.thirdparty":
+ if 'medium' not in identifier or 'address' not in identifier:
+ raise SynapseError(400, "Invalid thirdparty identifier")
+
+ address = identifier['address']
+ if identifier['medium'] == 'email':
# For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
address = address.lower()
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
- login_submission['medium'], address
+ identifier['medium'], address
)
if not user_id:
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
- else:
- user_id = login_submission['user']
+
+ identifier = {
+ "type": "m.id.user",
+ "user": user_id,
+ }
+
+ # by this point, the identifier should be an m.id.user: if it's anything
+ # else, we haven't understood it.
+ if identifier["type"] != "m.id.user":
+ raise SynapseError(400, "Unknown login identifier type")
+ if "user" not in identifier:
+ raise SynapseError(400, "User identifier is missing 'user' key")
+
+ user_id = identifier["user"]
if not user_id.startswith('@'):
user_id = UserID.create(
@@ -341,7 +419,12 @@ class CasTicketServlet(ClientV1RestServlet):
"ticket": request.args["ticket"],
"service": self.cas_service_url
}
- body = yield http_client.get_raw(uri, args)
+ try:
+ body = yield http_client.get_raw(uri, args)
+ except PartialDownloadError as pde:
+ # Twisted raises this error if the connection is closed,
+ # even if that's being used old-http style to signal end-of-data
+ body = pde.response
result = yield self.handle_cas_response(request, body, client_redirect_url)
defer.returnValue(result)
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index eafdce86..47b2dc45 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError
from synapse.types import UserID
+from synapse.handlers.presence import format_user_presence_state
from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns
@@ -33,6 +34,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(PresenceStatusRestServlet, self).__init__(hs)
self.presence_handler = hs.get_presence_handler()
+ self.clock = hs.get_clock()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
@@ -48,6 +50,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
raise AuthError(403, "You are not allowed to see their presence.")
state = yield self.presence_handler.get_state(target_user=user)
+ state = format_user_presence_state(state, self.clock.time_msec())
defer.returnValue((200, state))
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 90242a6b..0bdd6b5b 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -748,8 +748,7 @@ class JoinedRoomsRestServlet(ClientV1RestServlet):
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- rooms = yield self.store.get_rooms_for_user(requester.user.to_string())
- room_ids = set(r.room_id for r in rooms) # Ensure they're unique.
+ room_ids = yield self.store.get_rooms_for_user(requester.user.to_string())
defer.returnValue((200, {"joined_rooms": list(room_ids)}))
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 398e7f5e..4990b22b 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,8 +18,11 @@ from twisted.internet import defer
from synapse.api.constants import LoginType
from synapse.api.errors import LoginError, SynapseError, Codes
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.servlet import (
+ RestServlet, parse_json_object_from_request, assert_params_in_request
+)
from synapse.util.async import run_on_reactor
+from synapse.util.msisdn import phone_number_to_msisdn
from ._base import client_v2_patterns
@@ -28,11 +32,11 @@ import logging
logger = logging.getLogger(__name__)
-class PasswordRequestTokenRestServlet(RestServlet):
+class EmailPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password/email/requestToken$")
def __init__(self, hs):
- super(PasswordRequestTokenRestServlet, self).__init__()
+ super(EmailPasswordRequestTokenRestServlet, self).__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
@@ -40,14 +44,9 @@ class PasswordRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
- required = ['id_server', 'client_secret', 'email', 'send_attempt']
- absent = []
- for k in required:
- if k not in body:
- absent.append(k)
-
- if absent:
- raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+ assert_params_in_request(body, [
+ 'id_server', 'client_secret', 'email', 'send_attempt'
+ ])
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
@@ -60,6 +59,37 @@ class PasswordRequestTokenRestServlet(RestServlet):
defer.returnValue((200, ret))
+class MsisdnPasswordRequestTokenRestServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/account/password/msisdn/requestToken$")
+
+ def __init__(self, hs):
+ super(MsisdnPasswordRequestTokenRestServlet, self).__init__()
+ self.hs = hs
+ self.datastore = self.hs.get_datastore()
+ self.identity_handler = hs.get_handlers().identity_handler
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ body = parse_json_object_from_request(request)
+
+ assert_params_in_request(body, [
+ 'id_server', 'client_secret',
+ 'country', 'phone_number', 'send_attempt',
+ ])
+
+ msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+
+ existingUid = yield self.datastore.get_user_id_by_threepid(
+ 'msisdn', msisdn
+ )
+
+ if existingUid is None:
+ raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND)
+
+ ret = yield self.identity_handler.requestMsisdnToken(**body)
+ defer.returnValue((200, ret))
+
+
class PasswordRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password$")
@@ -68,6 +98,7 @@ class PasswordRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
+ self.datastore = self.hs.get_datastore()
@defer.inlineCallbacks
def on_POST(self, request):
@@ -77,7 +108,8 @@ class PasswordRestServlet(RestServlet):
authed, result, params, _ = yield self.auth_handler.check_auth([
[LoginType.PASSWORD],
- [LoginType.EMAIL_IDENTITY]
+ [LoginType.EMAIL_IDENTITY],
+ [LoginType.MSISDN],
], body, self.hs.get_ip_from_request(request))
if not authed:
@@ -102,7 +134,7 @@ class PasswordRestServlet(RestServlet):
# (See add_threepid in synapse/handlers/auth.py)
threepid['address'] = threepid['address'].lower()
# if using email, we must know about the email they're authing with!
- threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+ threepid_user_id = yield self.datastore.get_user_id_by_threepid(
threepid['medium'], threepid['address']
)
if not threepid_user_id:
@@ -169,13 +201,14 @@ class DeactivateAccountRestServlet(RestServlet):
defer.returnValue((200, {}))
-class ThreepidRequestTokenRestServlet(RestServlet):
+class EmailThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$")
def __init__(self, hs):
self.hs = hs
- super(ThreepidRequestTokenRestServlet, self).__init__()
+ super(EmailThreepidRequestTokenRestServlet, self).__init__()
self.identity_handler = hs.get_handlers().identity_handler
+ self.datastore = self.hs.get_datastore()
@defer.inlineCallbacks
def on_POST(self, request):
@@ -190,7 +223,7 @@ class ThreepidRequestTokenRestServlet(RestServlet):
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
- existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
+ existingUid = yield self.datastore.get_user_id_by_threepid(
'email', body['email']
)
@@ -201,6 +234,44 @@ class ThreepidRequestTokenRestServlet(RestServlet):
defer.returnValue((200, ret))
+class MsisdnThreepidRequestTokenRestServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/account/3pid/msisdn/requestToken$")
+
+ def __init__(self, hs):
+ self.hs = hs
+ super(MsisdnThreepidRequestTokenRestServlet, self).__init__()
+ self.identity_handler = hs.get_handlers().identity_handler
+ self.datastore = self.hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ body = parse_json_object_from_request(request)
+
+ required = [
+ 'id_server', 'client_secret',
+ 'country', 'phone_number', 'send_attempt',
+ ]
+ absent = []
+ for k in required:
+ if k not in body:
+ absent.append(k)
+
+ if absent:
+ raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+
+ msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+
+ existingUid = yield self.datastore.get_user_id_by_threepid(
+ 'msisdn', msisdn
+ )
+
+ if existingUid is not None:
+ raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
+
+ ret = yield self.identity_handler.requestMsisdnToken(**body)
+ defer.returnValue((200, ret))
+
+
class ThreepidRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid$")
@@ -210,6 +281,7 @@ class ThreepidRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
+ self.datastore = self.hs.get_datastore()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -217,7 +289,7 @@ class ThreepidRestServlet(RestServlet):
requester = yield self.auth.get_user_by_req(request)
- threepids = yield self.hs.get_datastore().user_get_threepids(
+ threepids = yield self.datastore.user_get_threepids(
requester.user.to_string()
)
@@ -258,7 +330,7 @@ class ThreepidRestServlet(RestServlet):
if 'bind' in body and body['bind']:
logger.debug(
- "Binding emails %s to %s",
+ "Binding threepid %s to %s",
threepid, user_id
)
yield self.identity_handler.bind_threepid(
@@ -302,9 +374,11 @@ class ThreepidDeleteRestServlet(RestServlet):
def register_servlets(hs, http_server):
- PasswordRequestTokenRestServlet(hs).register(http_server)
+ EmailPasswordRequestTokenRestServlet(hs).register(http_server)
+ MsisdnPasswordRequestTokenRestServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
- ThreepidRequestTokenRestServlet(hs).register(http_server)
+ EmailThreepidRequestTokenRestServlet(hs).register(http_server)
+ MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index a1feaf3d..b57ba95d 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -46,6 +46,52 @@ class DevicesRestServlet(servlet.RestServlet):
defer.returnValue((200, {"devices": devices}))
+class DeleteDevicesRestServlet(servlet.RestServlet):
+ """
+ API for bulk deletion of devices. Accepts a JSON object with a devices
+ key which lists the device_ids to delete. Requires user interactive auth.
+ """
+ PATTERNS = client_v2_patterns("/delete_devices", releases=[], v2_alpha=False)
+
+ def __init__(self, hs):
+ super(DeleteDevicesRestServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.device_handler = hs.get_device_handler()
+ self.auth_handler = hs.get_auth_handler()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ try:
+ body = servlet.parse_json_object_from_request(request)
+ except errors.SynapseError as e:
+ if e.errcode == errors.Codes.NOT_JSON:
+ # deal with older clients which didn't pass a J*DELETESON dict
+ # the same as those that pass an empty dict
+ body = {}
+ else:
+ raise e
+
+ if 'devices' not in body:
+ raise errors.SynapseError(
+ 400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
+ )
+
+ authed, result, params, _ = yield self.auth_handler.check_auth([
+ [constants.LoginType.PASSWORD],
+ ], body, self.hs.get_ip_from_request(request))
+
+ if not authed:
+ defer.returnValue((401, result))
+
+ requester = yield self.auth.get_user_by_req(request)
+ yield self.device_handler.delete_devices(
+ requester.user.to_string(),
+ body['devices'],
+ )
+ defer.returnValue((200, {}))
+
+
class DeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
releases=[], v2_alpha=False)
@@ -111,5 +157,6 @@ class DeviceRestServlet(servlet.RestServlet):
def register_servlets(hs, http_server):
+ DeleteDevicesRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index ccca5a12..3acf4eac 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015 - 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,7 +20,10 @@ import synapse
from synapse.api.auth import get_access_token_from_request, has_access_token
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.servlet import (
+ RestServlet, parse_json_object_from_request, assert_params_in_request
+)
+from synapse.util.msisdn import phone_number_to_msisdn
from ._base import client_v2_patterns
@@ -43,7 +47,7 @@ else:
logger = logging.getLogger(__name__)
-class RegisterRequestTokenRestServlet(RestServlet):
+class EmailRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register/email/requestToken$")
def __init__(self, hs):
@@ -51,7 +55,7 @@ class RegisterRequestTokenRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(RegisterRequestTokenRestServlet, self).__init__()
+ super(EmailRegisterRequestTokenRestServlet, self).__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
@@ -59,14 +63,9 @@ class RegisterRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
- required = ['id_server', 'client_secret', 'email', 'send_attempt']
- absent = []
- for k in required:
- if k not in body:
- absent.append(k)
-
- if len(absent) > 0:
- raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+ assert_params_in_request(body, [
+ 'id_server', 'client_secret', 'email', 'send_attempt'
+ ])
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
@@ -79,6 +78,43 @@ class RegisterRequestTokenRestServlet(RestServlet):
defer.returnValue((200, ret))
+class MsisdnRegisterRequestTokenRestServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/register/msisdn/requestToken$")
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(MsisdnRegisterRequestTokenRestServlet, self).__init__()
+ self.hs = hs
+ self.identity_handler = hs.get_handlers().identity_handler
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ body = parse_json_object_from_request(request)
+
+ assert_params_in_request(body, [
+ 'id_server', 'client_secret',
+ 'country', 'phone_number',
+ 'send_attempt',
+ ])
+
+ msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+
+ existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
+ 'msisdn', msisdn
+ )
+
+ if existingUid is not None:
+ raise SynapseError(
+ 400, "Phone number is already in use", Codes.THREEPID_IN_USE
+ )
+
+ ret = yield self.identity_handler.requestMsisdnToken(**body)
+ defer.returnValue((200, ret))
+
+
class RegisterRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register$")
@@ -200,16 +236,37 @@ class RegisterRestServlet(RestServlet):
assigned_user_id=registered_user_id,
)
+ # Only give msisdn flows if the x_show_msisdn flag is given:
+ # this is a hack to work around the fact that clients were shipped
+ # that use fallback registration if they see any flows that they don't
+ # recognise, which means we break registration for these clients if we
+ # advertise msisdn flows. Once usage of Riot iOS <=0.3.9 and Riot
+ # Android <=0.6.9 have fallen below an acceptable threshold, this
+ # parameter should go away and we should always advertise msisdn flows.
+ show_msisdn = False
+ if 'x_show_msisdn' in body and body['x_show_msisdn']:
+ show_msisdn = True
+
if self.hs.config.enable_registration_captcha:
flows = [
[LoginType.RECAPTCHA],
- [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]
+ [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
]
+ if show_msisdn:
+ flows.extend([
+ [LoginType.MSISDN, LoginType.RECAPTCHA],
+ [LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
+ ])
else:
flows = [
[LoginType.DUMMY],
- [LoginType.EMAIL_IDENTITY]
+ [LoginType.EMAIL_IDENTITY],
]
+ if show_msisdn:
+ flows.extend([
+ [LoginType.MSISDN],
+ [LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
+ ])
authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
@@ -224,8 +281,9 @@ class RegisterRestServlet(RestServlet):
"Already registered user ID %r for this session",
registered_user_id
)
- # don't re-register the email address
+ # don't re-register the threepids
add_email = False
+ add_msisdn = False
else:
# NB: This may be from the auth handler and NOT from the POST
if 'password' not in params:
@@ -250,6 +308,7 @@ class RegisterRestServlet(RestServlet):
)
add_email = True
+ add_msisdn = True
return_dict = yield self._create_registration_details(
registered_user_id, params
@@ -262,6 +321,13 @@ class RegisterRestServlet(RestServlet):
params.get("bind_email")
)
+ if add_msisdn and auth_result and LoginType.MSISDN in auth_result:
+ threepid = auth_result[LoginType.MSISDN]
+ yield self._register_msisdn_threepid(
+ registered_user_id, threepid, return_dict["access_token"],
+ params.get("bind_msisdn")
+ )
+
defer.returnValue((200, return_dict))
def on_OPTIONS(self, _):
@@ -323,8 +389,9 @@ class RegisterRestServlet(RestServlet):
"""
reqd = ('medium', 'address', 'validated_at')
if any(x not in threepid for x in reqd):
+ # This will only happen if the ID server returns a malformed response
logger.info("Can't add incomplete 3pid")
- defer.returnValue()
+ return
yield self.auth_handler.add_threepid(
user_id,
@@ -372,6 +439,43 @@ class RegisterRestServlet(RestServlet):
logger.info("bind_email not specified: not binding email")
@defer.inlineCallbacks
+ def _register_msisdn_threepid(self, user_id, threepid, token, bind_msisdn):
+ """Add a phone number as a 3pid identifier
+
+ Also optionally binds msisdn to the given user_id on the identity server
+
+ Args:
+ user_id (str): id of user
+ threepid (object): m.login.msisdn auth response
+ token (str): access_token for the user
+ bind_email (bool): true if the client requested the email to be
+ bound at the identity server
+ Returns:
+ defer.Deferred:
+ """
+ reqd = ('medium', 'address', 'validated_at')
+ if any(x not in threepid for x in reqd):
+ # This will only happen if the ID server returns a malformed response
+ logger.info("Can't add incomplete 3pid")
+ defer.returnValue()
+
+ yield self.auth_handler.add_threepid(
+ user_id,
+ threepid['medium'],
+ threepid['address'],
+ threepid['validated_at'],
+ )
+
+ if bind_msisdn:
+ logger.info("bind_msisdn specified: binding")
+ logger.debug("Binding msisdn %s to %s", threepid, user_id)
+ yield self.identity_handler.bind_threepid(
+ threepid['threepid_creds'], user_id
+ )
+ else:
+ logger.info("bind_msisdn not specified: not binding msisdn")
+
+ @defer.inlineCallbacks
def _create_registration_details(self, user_id, params):
"""Complete registration of newly-registered user
@@ -433,7 +537,7 @@ class RegisterRestServlet(RestServlet):
# we have nowhere to store it.
device_id = synapse.api.auth.GUEST_DEVICE_ID
initial_display_name = params.get("initial_device_display_name")
- self.device_handler.check_device_registered(
+ yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
@@ -449,5 +553,6 @@ class RegisterRestServlet(RestServlet):
def register_servlets(hs, http_server):
- RegisterRequestTokenRestServlet(hs).register(http_server)
+ EmailRegisterRequestTokenRestServlet(hs).register(http_server)
+ MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
RegisterRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index b3d80016..a7a9e0a7 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.http.servlet import (
RestServlet, parse_string, parse_integer, parse_boolean
)
+from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.sync import SyncConfig
from synapse.types import StreamToken
from synapse.events.utils import (
@@ -28,7 +29,6 @@ from synapse.api.errors import SynapseError
from synapse.api.constants import PresenceState
from ._base import client_v2_patterns
-import copy
import itertools
import logging
@@ -194,12 +194,18 @@ class SyncRestServlet(RestServlet):
defer.returnValue((200, response_content))
def encode_presence(self, events, time_now):
- formatted = []
- for event in events:
- event = copy.deepcopy(event)
- event['sender'] = event['content'].pop('user_id')
- formatted.append(event)
- return {"events": formatted}
+ return {
+ "events": [
+ {
+ "type": "m.presence",
+ "sender": event.user_id,
+ "content": format_user_presence_state(
+ event, time_now, include_user_id=False
+ ),
+ }
+ for event in events
+ ]
+ }
def encode_joined(self, rooms, time_now, token_id, event_fields):
"""
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index dfb87ffd..6788375e 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import synapse.http.servlet
from ._base import parse_media_id, respond_with_file, respond_404
from twisted.web.resource import Resource
@@ -81,6 +82,17 @@ class DownloadResource(Resource):
@defer.inlineCallbacks
def _respond_remote_file(self, request, server_name, media_id, name):
+ # don't forward requests for remote media if allow_remote is false
+ allow_remote = synapse.http.servlet.parse_boolean(
+ request, "allow_remote", default=True)
+ if not allow_remote:
+ logger.info(
+ "Rejecting request for remote media %s/%s due to allow_remote",
+ server_name, media_id,
+ )
+ respond_404(request)
+ return
+
media_info = yield self.media_repo.get_remote_media(server_name, media_id)
media_type = media_info["media_type"]
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 481ffee2..c43b185e 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -13,22 +13,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer, threads
+import twisted.internet.error
+import twisted.web.http
+from twisted.web.resource import Resource
+
from .upload_resource import UploadResource
from .download_resource import DownloadResource
from .thumbnail_resource import ThumbnailResource
from .identicon_resource import IdenticonResource
from .preview_url_resource import PreviewUrlResource
from .filepath import MediaFilePaths
-
-from twisted.web.resource import Resource
-
from .thumbnailer import Thumbnailer
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.util.stringutils import random_string
-from synapse.api.errors import SynapseError
-
-from twisted.internet import defer, threads
+from synapse.api.errors import SynapseError, HttpResponseException, \
+ NotFoundError
from synapse.util.async import Linearizer
from synapse.util.stringutils import is_ascii
@@ -157,11 +158,34 @@ class MediaRepository(object):
try:
length, headers = yield self.client.get_file(
server_name, request_path, output_stream=f,
- max_size=self.max_upload_size,
+ max_size=self.max_upload_size, args={
+ # tell the remote server to 404 if it doesn't
+ # recognise the server_name, to make sure we don't
+ # end up with a routing loop.
+ "allow_remote": "false",
+ }
)
- except Exception as e:
- logger.warn("Failed to fetch remoted media %r", e)
- raise SynapseError(502, "Failed to fetch remoted media")
+ except twisted.internet.error.DNSLookupError as e:
+ logger.warn("HTTP error fetching remote media %s/%s: %r",
+ server_name, media_id, e)
+ raise NotFoundError()
+
+ except HttpResponseException as e:
+ logger.warn("HTTP error fetching remote media %s/%s: %s",
+ server_name, media_id, e.response)
+ if e.code == twisted.web.http.NOT_FOUND:
+ raise SynapseError.from_http_response_exception(e)
+ raise SynapseError(502, "Failed to fetch remote media")
+
+ except SynapseError:
+ logger.exception("Failed to fetch remote media %s/%s",
+ server_name, media_id)
+ raise
+
+ except Exception:
+ logger.exception("Failed to fetch remote media %s/%s",
+ server_name, media_id)
+ raise SynapseError(502, "Failed to fetch remote media")
media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec()
diff --git a/synapse/state.py b/synapse/state.py
index 383d32b1..f6b83d88 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -177,17 +177,12 @@ class StateHandler(object):
@defer.inlineCallbacks
def compute_event_context(self, event, old_state=None):
- """ Fills out the context with the `current state` of the graph. The
- `current state` here is defined to be the state of the event graph
- just before the event - i.e. it never includes `event`
-
- If `event` has `auth_events` then this will also fill out the
- `auth_events` field on `context` from the `current_state`.
+ """Build an EventContext structure for the event.
Args:
- event (EventBase)
+ event (synapse.events.EventBase):
Returns:
- an EventContext
+ synapse.events.snapshot.EventContext:
"""
context = EventContext()
@@ -200,11 +195,11 @@ class StateHandler(object):
(s.type, s.state_key): s.event_id for s in old_state
}
if event.is_state():
- context.current_state_events = dict(context.prev_state_ids)
+ context.current_state_ids = dict(context.prev_state_ids)
key = (event.type, event.state_key)
- context.current_state_events[key] = event.event_id
+ context.current_state_ids[key] = event.event_id
else:
- context.current_state_events = context.prev_state_ids
+ context.current_state_ids = context.prev_state_ids
else:
context.current_state_ids = {}
context.prev_state_ids = {}
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index a7a8ec9b..c659004e 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -73,6 +73,9 @@ class LoggingTransaction(object):
def __setattr__(self, name, value):
setattr(self.txn, name, value)
+ def __iter__(self):
+ return self.txn.__iter__()
+
def execute(self, sql, *args):
self._do_execute(self.txn.execute, sql, *args)
@@ -132,7 +135,7 @@ class PerformanceCounters(object):
def interval(self, interval_duration, limit=3):
counters = []
- for name, (count, cum_time) in self.current_counters.items():
+ for name, (count, cum_time) in self.current_counters.iteritems():
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
counters.append((
(cum_time - prev_time) / interval_duration,
@@ -357,7 +360,7 @@ class SQLBaseStore(object):
"""
col_headers = list(intern(column[0]) for column in cursor.description)
results = list(
- dict(zip(col_headers, row)) for row in cursor.fetchall()
+ dict(zip(col_headers, row)) for row in cursor
)
return results
@@ -565,7 +568,7 @@ class SQLBaseStore(object):
@staticmethod
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
if keyvalues:
- where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
+ where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
else:
where = ""
@@ -579,7 +582,7 @@ class SQLBaseStore(object):
txn.execute(sql, keyvalues.values())
- return [r[0] for r in txn.fetchall()]
+ return [r[0] for r in txn]
def _simple_select_onecol(self, table, keyvalues, retcol,
desc="_simple_select_onecol"):
@@ -712,7 +715,7 @@ class SQLBaseStore(object):
)
values.extend(iterable)
- for key, value in keyvalues.items():
+ for key, value in keyvalues.iteritems():
clauses.append("%s = ?" % (key,))
values.append(value)
@@ -753,7 +756,7 @@ class SQLBaseStore(object):
@staticmethod
def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
if keyvalues:
- where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
+ where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
else:
where = ""
@@ -840,6 +843,47 @@ class SQLBaseStore(object):
return txn.execute(sql, keyvalues.values())
+ def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
+ return self.runInteraction(
+ desc, self._simple_delete_many_txn, table, column, iterable, keyvalues
+ )
+
+ @staticmethod
+ def _simple_delete_many_txn(txn, table, column, iterable, keyvalues):
+ """Executes a DELETE query on the named table.
+
+ Filters rows by if value of `column` is in `iterable`.
+
+ Args:
+ txn : Transaction object
+ table : string giving the table name
+ column : column name to test for inclusion against `iterable`
+ iterable : list
+ keyvalues : dict of column names and values to select the rows with
+ """
+ if not iterable:
+ return
+
+ sql = "DELETE FROM %s" % table
+
+ clauses = []
+ values = []
+ clauses.append(
+ "%s IN (%s)" % (column, ",".join("?" for _ in iterable))
+ )
+ values.extend(iterable)
+
+ for key, value in keyvalues.iteritems():
+ clauses.append("%s = ?" % (key,))
+ values.append(value)
+
+ if clauses:
+ sql = "%s WHERE %s" % (
+ sql,
+ " AND ".join(clauses),
+ )
+ return txn.execute(sql, values)
+
def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
max_value, limit=100000):
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
@@ -860,16 +904,16 @@ class SQLBaseStore(object):
txn = db_conn.cursor()
txn.execute(sql, (int(max_value),))
- rows = txn.fetchall()
- txn.close()
cache = {
row[0]: int(row[1])
- for row in rows
+ for row in txn
}
+ txn.close()
+
if cache:
- min_val = min(cache.values())
+ min_val = min(cache.itervalues())
else:
min_val = max_value
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index 3fa226e9..aa84ffc2 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -182,7 +182,7 @@ class AccountDataStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id))
global_account_data = {
- row[0]: json.loads(row[1]) for row in txn.fetchall()
+ row[0]: json.loads(row[1]) for row in txn
}
sql = (
@@ -193,7 +193,7 @@ class AccountDataStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id))
account_data_by_room = {}
- for row in txn.fetchall():
+ for row in txn:
room_account_data = account_data_by_room.setdefault(row[0], {})
room_account_data[row[1]] = json.loads(row[2])
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 94b2bcc5..813ad59e 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import synapse.util.async
from ._base import SQLBaseStore
from . import engines
@@ -84,24 +85,14 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_performance = {}
self._background_update_queue = []
self._background_update_handlers = {}
- self._background_update_timer = None
@defer.inlineCallbacks
def start_doing_background_updates(self):
- assert self._background_update_timer is None, \
- "background updates already running"
-
logger.info("Starting background schema updates")
while True:
- sleep = defer.Deferred()
- self._background_update_timer = self._clock.call_later(
- self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None
- )
- try:
- yield sleep
- finally:
- self._background_update_timer = None
+ yield synapse.util.async.sleep(
+ self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)
try:
result = yield self.do_next_background_update(
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 5c7db5e5..2714519d 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -178,7 +178,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
)
txn.execute(sql, (user_id,))
message_json = ujson.dumps(messages_by_device["*"])
- for row in txn.fetchall():
+ for row in txn:
# Add the message for all devices for this user on this
# server.
device = row[0]
@@ -195,7 +195,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
# TODO: Maybe this needs to be done in batches if there are
# too many local devices for a given user.
txn.execute(sql, [user_id] + devices)
- for row in txn.fetchall():
+ for row in txn:
# Only insert into the local inbox if the device exists on
# this server
device = row[0]
@@ -251,7 +251,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
user_id, device_id, last_stream_id, current_stream_id, limit
))
messages = []
- for row in txn.fetchall():
+ for row in txn:
stream_pos = row[0]
messages.append(ujson.loads(row[1]))
if len(messages) < limit:
@@ -340,7 +340,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
" ORDER BY stream_id ASC"
)
txn.execute(sql, (last_pos, upper_pos))
- rows.extend(txn.fetchall())
+ rows.extend(txn)
return rows
@@ -357,12 +357,12 @@ class DeviceInboxStore(BackgroundUpdateStore):
"""
Args:
destination(str): The name of the remote server.
- last_stream_id(int): The last position of the device message stream
+ last_stream_id(int|long): The last position of the device message stream
that the server sent up to.
- current_stream_id(int): The current position of the device
+ current_stream_id(int|long): The current position of the device
message stream.
Returns:
- Deferred ([dict], int): List of messages for the device and where
+ Deferred ([dict], int|long): List of messages for the device and where
in the stream the messages got to.
"""
@@ -384,7 +384,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
destination, last_stream_id, current_stream_id, limit
))
messages = []
- for row in txn.fetchall():
+ for row in txn:
stream_pos = row[0]
messages.append(ujson.loads(row[1]))
if len(messages) < limit:
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index bd56ba25..53e36791 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -108,6 +108,23 @@ class DeviceStore(SQLBaseStore):
desc="delete_device",
)
+ def delete_devices(self, user_id, device_ids):
+ """Deletes several devices.
+
+ Args:
+ user_id (str): The ID of the user which owns the devices
+ device_ids (list): The IDs of the devices to delete
+ Returns:
+ defer.Deferred
+ """
+ return self._simple_delete_many(
+ table="devices",
+ column="device_id",
+ iterable=device_ids,
+ keyvalues={"user_id": user_id},
+ desc="delete_devices",
+ )
+
def update_device(self, user_id, device_id, new_display_name=None):
"""Update a device.
@@ -291,7 +308,7 @@ class DeviceStore(SQLBaseStore):
"""Get stream of updates to send to remote servers
Returns:
- (now_stream_id, [ { updates }, .. ])
+ (int, list[dict]): current stream id and list of updates
"""
now_stream_id = self._device_list_id_gen.get_current_token()
@@ -312,17 +329,20 @@ class DeviceStore(SQLBaseStore):
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
GROUP BY user_id, device_id
+ LIMIT 20
"""
txn.execute(
sql, (destination, from_stream_id, now_stream_id, False)
)
- rows = txn.fetchall()
- if not rows:
+ # maps (user_id, device_id) -> stream_id
+ query_map = {(r[0], r[1]): r[2] for r in txn}
+ if not query_map:
return (now_stream_id, [])
- # maps (user_id, device_id) -> stream_id
- query_map = {(r[0], r[1]): r[2] for r in rows}
+ if len(query_map) >= 20:
+ now_stream_id = max(stream_id for stream_id in query_map.itervalues())
+
devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True
)
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index b9f1365f..7cbc1470 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -14,6 +14,8 @@
# limitations under the License.
from twisted.internet import defer
+from synapse.api.errors import SynapseError
+
from canonicaljson import encode_canonical_json
import ujson as json
@@ -120,24 +122,63 @@ class EndToEndKeyStore(SQLBaseStore):
return result
+ @defer.inlineCallbacks
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
+ """Insert some new one time keys for a device.
+
+ Checks if any of the keys are already inserted, if they are then check
+ if they match. If they don't then we raise an error.
+ """
+
+ # First we check if we have already persisted any of the keys.
+ rows = yield self._simple_select_many_batch(
+ table="e2e_one_time_keys_json",
+ column="key_id",
+ iterable=[key_id for _, key_id, _ in key_list],
+ retcols=("algorithm", "key_id", "key_json",),
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ desc="add_e2e_one_time_keys_check",
+ )
+
+ existing_key_map = {
+ (row["algorithm"], row["key_id"]): row["key_json"] for row in rows
+ }
+
+ new_keys = [] # Keys that we need to insert
+ for algorithm, key_id, json_bytes in key_list:
+ ex_bytes = existing_key_map.get((algorithm, key_id), None)
+ if ex_bytes:
+ if json_bytes != ex_bytes:
+ raise SynapseError(
+ 400, "One time key with key_id %r already exists" % (key_id,)
+ )
+ else:
+ new_keys.append((algorithm, key_id, json_bytes))
+
def _add_e2e_one_time_keys(txn):
- for (algorithm, key_id, json_bytes) in key_list:
- self._simple_upsert_txn(
- txn, table="e2e_one_time_keys_json",
- keyvalues={
+ # We are protected from race between lookup and insertion due to
+ # a unique constraint. If there is a race of two calls to
+ # `add_e2e_one_time_keys` then they'll conflict and we will only
+ # insert one set.
+ self._simple_insert_many_txn(
+ txn, table="e2e_one_time_keys_json",
+ values=[
+ {
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
"key_id": key_id,
- },
- values={
"ts_added_ms": time_now,
"key_json": json_bytes,
}
- )
- return self.runInteraction(
- "add_e2e_one_time_keys", _add_e2e_one_time_keys
+ for algorithm, key_id, json_bytes in new_keys
+ ],
+ )
+ yield self.runInteraction(
+ "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
)
def count_e2e_one_time_keys(self, user_id, device_id):
@@ -153,7 +194,7 @@ class EndToEndKeyStore(SQLBaseStore):
)
txn.execute(sql, (user_id, device_id))
result = {}
- for algorithm, key_count in txn.fetchall():
+ for algorithm, key_count in txn:
result[algorithm] = key_count
return result
return self.runInteraction(
@@ -174,7 +215,7 @@ class EndToEndKeyStore(SQLBaseStore):
user_result = result.setdefault(user_id, {})
device_result = user_result.setdefault(device_id, {})
txn.execute(sql, (user_id, device_id, algorithm))
- for key_id, key_json in txn.fetchall():
+ for key_id, key_json in txn:
device_result[algorithm + ":" + key_id] = key_json
delete.append((user_id, device_id, algorithm, key_id))
sql = (
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 256e50dc..519059c3 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -74,7 +74,7 @@ class EventFederationStore(SQLBaseStore):
base_sql % (",".join(["?"] * len(chunk)),),
chunk
)
- new_front.update([r[0] for r in txn.fetchall()])
+ new_front.update([r[0] for r in txn])
new_front -= results
@@ -110,7 +110,7 @@ class EventFederationStore(SQLBaseStore):
txn.execute(sql, (room_id, False,))
- return dict(txn.fetchall())
+ return dict(txn)
def _get_oldest_events_in_room_txn(self, txn, room_id):
return self._simple_select_onecol_txn(
@@ -201,19 +201,19 @@ class EventFederationStore(SQLBaseStore):
def _update_min_depth_for_room_txn(self, txn, room_id, depth):
min_depth = self._get_min_depth_interaction(txn, room_id)
- do_insert = depth < min_depth if min_depth else True
+ if min_depth and depth >= min_depth:
+ return
- if do_insert:
- self._simple_upsert_txn(
- txn,
- table="room_depth",
- keyvalues={
- "room_id": room_id,
- },
- values={
- "min_depth": depth,
- },
- )
+ self._simple_upsert_txn(
+ txn,
+ table="room_depth",
+ keyvalues={
+ "room_id": room_id,
+ },
+ values={
+ "min_depth": depth,
+ },
+ )
def _handle_mult_prev_events(self, txn, events):
"""
@@ -334,8 +334,7 @@ class EventFederationStore(SQLBaseStore):
def get_forward_extremeties_for_room_txn(txn):
txn.execute(sql, (stream_ordering, room_id))
- rows = txn.fetchall()
- return [event_id for event_id, in rows]
+ return [event_id for event_id, in txn]
return self.runInteraction(
"get_forward_extremeties_for_room",
@@ -436,7 +435,7 @@ class EventFederationStore(SQLBaseStore):
(room_id, event_id, False, limit - len(event_results))
)
- for row in txn.fetchall():
+ for row in txn:
if row[1] not in event_results:
queue.put((-row[0], row[1]))
@@ -482,7 +481,7 @@ class EventFederationStore(SQLBaseStore):
(room_id, event_id, False, limit - len(event_results))
)
- for e_id, in txn.fetchall():
+ for e_id, in txn:
new_front.add(e_id)
new_front -= earliest_events
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 14543b42..d6d8723b 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -206,7 +206,7 @@ class EventPushActionsStore(SQLBaseStore):
" stream_ordering >= ? AND stream_ordering <= ?"
)
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
- return [r[0] for r in txn.fetchall()]
+ return [r[0] for r in txn]
ret = yield self.runInteraction("get_push_action_users_in_range", f)
defer.returnValue(ret)
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index db01eb6d..3f6833fa 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -34,14 +34,16 @@ from canonicaljson import encode_canonical_json
from collections import deque, namedtuple, OrderedDict
from functools import wraps
-import synapse
import synapse.metrics
-
import logging
import math
import ujson as json
+# these are only included to make the type annotations work
+from synapse.events import EventBase # noqa: F401
+from synapse.events.snapshot import EventContext # noqa: F401
+
logger = logging.getLogger(__name__)
@@ -82,6 +84,11 @@ class _EventPeristenceQueue(object):
def add_to_queue(self, room_id, events_and_contexts, backfilled):
"""Add events to the queue, with the given persist_event options.
+
+ Args:
+ room_id (str):
+ events_and_contexts (list[(EventBase, EventContext)]):
+ backfilled (bool):
"""
queue = self._event_persist_queues.setdefault(room_id, deque())
if queue:
@@ -210,14 +217,14 @@ class EventsStore(SQLBaseStore):
partitioned.setdefault(event.room_id, []).append((event, ctx))
deferreds = []
- for room_id, evs_ctxs in partitioned.items():
+ for room_id, evs_ctxs in partitioned.iteritems():
d = preserve_fn(self._event_persist_queue.add_to_queue)(
room_id, evs_ctxs,
backfilled=backfilled,
)
deferreds.append(d)
- for room_id in partitioned.keys():
+ for room_id in partitioned:
self._maybe_start_persisting(room_id)
return preserve_context_over_deferred(
@@ -227,6 +234,17 @@ class EventsStore(SQLBaseStore):
@defer.inlineCallbacks
@log_function
def persist_event(self, event, context, backfilled=False):
+ """
+
+ Args:
+ event (EventBase):
+ context (EventContext):
+ backfilled (bool):
+
+ Returns:
+ Deferred: resolves to (int, int): the stream ordering of ``event``,
+ and the stream ordering of the latest persisted event
+ """
deferred = self._event_persist_queue.add_to_queue(
event.room_id, [(event, context)],
backfilled=backfilled,
@@ -253,6 +271,16 @@ class EventsStore(SQLBaseStore):
@defer.inlineCallbacks
def _persist_events(self, events_and_contexts, backfilled=False,
delete_existing=False):
+ """Persist events to db
+
+ Args:
+ events_and_contexts (list[(EventBase, EventContext)]):
+ backfilled (bool):
+ delete_existing (bool):
+
+ Returns:
+ Deferred: resolves when the events have been persisted
+ """
if not events_and_contexts:
return
@@ -295,7 +323,7 @@ class EventsStore(SQLBaseStore):
(event, context)
)
- for room_id, ev_ctx_rm in events_by_room.items():
+ for room_id, ev_ctx_rm in events_by_room.iteritems():
# Work out new extremities by recursively adding and removing
# the new events.
latest_event_ids = yield self.get_latest_event_ids_in_room(
@@ -400,6 +428,7 @@ class EventsStore(SQLBaseStore):
# Now we need to work out the different state sets for
# each state extremities
state_sets = []
+ state_groups = set()
missing_event_ids = []
was_updated = False
for event_id in new_latest_event_ids:
@@ -409,9 +438,17 @@ class EventsStore(SQLBaseStore):
if event_id == ev.event_id:
if ctx.current_state_ids is None:
raise Exception("Unknown current state")
- state_sets.append(ctx.current_state_ids)
- if ctx.delta_ids or hasattr(ev, "state_key"):
- was_updated = True
+
+ # If we've already seen the state group don't bother adding
+ # it to the state sets again
+ if ctx.state_group not in state_groups:
+ state_sets.append(ctx.current_state_ids)
+ if ctx.delta_ids or hasattr(ev, "state_key"):
+ was_updated = True
+ if ctx.state_group:
+ # Add this as a seen state group (if it has a state
+ # group)
+ state_groups.add(ctx.state_group)
break
else:
# If we couldn't find it, then we'll need to pull
@@ -425,31 +462,57 @@ class EventsStore(SQLBaseStore):
missing_event_ids,
)
- groups = set(event_to_groups.values())
- group_to_state = yield self._get_state_for_groups(groups)
+ groups = set(event_to_groups.itervalues()) - state_groups
- state_sets.extend(group_to_state.values())
+ if groups:
+ group_to_state = yield self._get_state_for_groups(groups)
+ state_sets.extend(group_to_state.itervalues())
if not new_latest_event_ids:
current_state = {}
elif was_updated:
- current_state = yield resolve_events(
- state_sets,
- state_map_factory=lambda ev_ids: self.get_events(
- ev_ids, get_prev_content=False, check_redacted=False,
- ),
- )
+ if len(state_sets) == 1:
+ # If there is only one state set, then we know what the current
+ # state is.
+ current_state = state_sets[0]
+ else:
+ # We work out the current state by passing the state sets to the
+ # state resolution algorithm. It may ask for some events, including
+ # the events we have yet to persist, so we need a slightly more
+ # complicated event lookup function than simply looking the events
+ # up in the db.
+ events_map = {ev.event_id: ev for ev, _ in events_context}
+
+ @defer.inlineCallbacks
+ def get_events(ev_ids):
+ # We get the events by first looking at the list of events we
+ # are trying to persist, and then fetching the rest from the DB.
+ db = []
+ to_return = {}
+ for ev_id in ev_ids:
+ ev = events_map.get(ev_id, None)
+ if ev:
+ to_return[ev_id] = ev
+ else:
+ db.append(ev_id)
+
+ if db:
+ evs = yield self.get_events(
+ ev_ids, get_prev_content=False, check_redacted=False,
+ )
+ to_return.update(evs)
+ defer.returnValue(to_return)
+
+ current_state = yield resolve_events(
+ state_sets,
+ state_map_factory=get_events,
+ )
else:
return
- existing_state_rows = yield self._simple_select_list(
- table="current_state_events",
- keyvalues={"room_id": room_id},
- retcols=["event_id", "type", "state_key"],
- desc="_calculate_state_delta",
- )
+ existing_state = yield self.get_current_state_ids(room_id)
- existing_events = set(row["event_id"] for row in existing_state_rows)
+ existing_events = set(existing_state.itervalues())
new_events = set(ev_id for ev_id in current_state.itervalues())
changed_events = existing_events ^ new_events
@@ -457,9 +520,8 @@ class EventsStore(SQLBaseStore):
return
to_delete = {
- (row["type"], row["state_key"]): row["event_id"]
- for row in existing_state_rows
- if row["event_id"] in changed_events
+ key: ev_id for key, ev_id in existing_state.iteritems()
+ if ev_id in changed_events
}
events_to_insert = (new_events - existing_events)
to_insert = {
@@ -535,11 +597,91 @@ class EventsStore(SQLBaseStore):
and the rejections table. Things reading from those table will need to check
whether the event was rejected.
- If delete_existing is True then existing events will be purged from the
- database before insertion. This is useful when retrying due to IntegrityError.
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ events_and_contexts (list[(EventBase, EventContext)]):
+ events to persist
+ backfilled (bool): True if the events were backfilled
+ delete_existing (bool): True to purge existing table rows for the
+ events from the database. This is useful when retrying due to
+ IntegrityError.
+ current_state_for_room (dict[str, (list[str], list[str])]):
+ The current-state delta for each room. For each room, a tuple
+ (to_delete, to_insert), being a list of event ids to be removed
+ from the current state, and a list of event ids to be added to
+ the current state.
+ new_forward_extremeties (dict[str, list[str]]):
+ The new forward extremities for each room. For each room, a
+ list of the event ids which are the forward extremities.
+
"""
+ self._update_current_state_txn(txn, current_state_for_room)
+
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
- for room_id, current_state_tuple in current_state_for_room.iteritems():
+ self._update_forward_extremities_txn(
+ txn,
+ new_forward_extremities=new_forward_extremeties,
+ max_stream_order=max_stream_order,
+ )
+
+ # Ensure that we don't have the same event twice.
+ events_and_contexts = self._filter_events_and_contexts_for_duplicates(
+ events_and_contexts,
+ )
+
+ self._update_room_depths_txn(
+ txn,
+ events_and_contexts=events_and_contexts,
+ backfilled=backfilled,
+ )
+
+ # _update_outliers_txn filters out any events which have already been
+ # persisted, and returns the filtered list.
+ events_and_contexts = self._update_outliers_txn(
+ txn,
+ events_and_contexts=events_and_contexts,
+ )
+
+ # From this point onwards the events are only events that we haven't
+ # seen before.
+
+ if delete_existing:
+ # For paranoia reasons, we go and delete all the existing entries
+ # for these events so we can reinsert them.
+ # This gets around any problems with some tables already having
+ # entries.
+ self._delete_existing_rows_txn(
+ txn,
+ events_and_contexts=events_and_contexts,
+ )
+
+ self._store_event_txn(
+ txn,
+ events_and_contexts=events_and_contexts,
+ )
+
+ # Insert into the state_groups, state_groups_state, and
+ # event_to_state_groups tables.
+ self._store_mult_state_groups_txn(txn, events_and_contexts)
+
+ # _store_rejected_events_txn filters out any events which were
+ # rejected, and returns the filtered list.
+ events_and_contexts = self._store_rejected_events_txn(
+ txn,
+ events_and_contexts=events_and_contexts,
+ )
+
+ # From this point onwards the events are only ones that weren't
+ # rejected.
+
+ self._update_metadata_tables_txn(
+ txn,
+ events_and_contexts=events_and_contexts,
+ backfilled=backfilled,
+ )
+
+ def _update_current_state_txn(self, txn, state_delta_by_room):
+ for room_id, current_state_tuple in state_delta_by_room.iteritems():
to_delete, to_insert = current_state_tuple
txn.executemany(
"DELETE FROM current_state_events WHERE event_id = ?",
@@ -585,7 +727,13 @@ class EventsStore(SQLBaseStore):
txn, self.get_users_in_room, (room_id,)
)
- for room_id, new_extrem in new_forward_extremeties.items():
+ self._invalidate_cache_and_stream(
+ txn, self.get_current_state_ids, (room_id,)
+ )
+
+ def _update_forward_extremities_txn(self, txn, new_forward_extremities,
+ max_stream_order):
+ for room_id, new_extrem in new_forward_extremities.iteritems():
self._simple_delete_txn(
txn,
table="event_forward_extremities",
@@ -603,7 +751,7 @@ class EventsStore(SQLBaseStore):
"event_id": ev_id,
"room_id": room_id,
}
- for room_id, new_extrem in new_forward_extremeties.items()
+ for room_id, new_extrem in new_forward_extremities.iteritems()
for ev_id in new_extrem
],
)
@@ -620,13 +768,22 @@ class EventsStore(SQLBaseStore):
"event_id": event_id,
"stream_ordering": max_stream_order,
}
- for room_id, new_extrem in new_forward_extremeties.items()
+ for room_id, new_extrem in new_forward_extremities.iteritems()
for event_id in new_extrem
]
)
- # Ensure that we don't have the same event twice.
- # Pick the earliest non-outlier if there is one, else the earliest one.
+ @classmethod
+ def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
+ """Ensure that we don't have the same event twice.
+
+ Pick the earliest non-outlier if there is one, else the earliest one.
+
+ Args:
+ events_and_contexts (list[(EventBase, EventContext)]):
+ Returns:
+ list[(EventBase, EventContext)]: filtered list
+ """
new_events_and_contexts = OrderedDict()
for event, context in events_and_contexts:
prev_event_context = new_events_and_contexts.get(event.event_id)
@@ -639,9 +796,17 @@ class EventsStore(SQLBaseStore):
new_events_and_contexts[event.event_id] = (event, context)
else:
new_events_and_contexts[event.event_id] = (event, context)
+ return new_events_and_contexts.values()
- events_and_contexts = new_events_and_contexts.values()
+ def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
+ """Update min_depth for each room
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ events_and_contexts (list[(EventBase, EventContext)]): events
+ we are persisting
+ backfilled (bool): True if the events were backfilled
+ """
depth_updates = {}
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
@@ -657,9 +822,24 @@ class EventsStore(SQLBaseStore):
event.depth, depth_updates.get(event.room_id, event.depth)
)
- for room_id, depth in depth_updates.items():
+ for room_id, depth in depth_updates.iteritems():
self._update_min_depth_for_room_txn(txn, room_id, depth)
+ def _update_outliers_txn(self, txn, events_and_contexts):
+ """Update any outliers with new event info.
+
+ This turns outliers into ex-outliers (unless the new event was
+ rejected).
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ events_and_contexts (list[(EventBase, EventContext)]): events
+ we are persisting
+
+ Returns:
+ list[(EventBase, EventContext)] new list, without events which
+ are already in the events table.
+ """
txn.execute(
"SELECT event_id, outlier FROM events WHERE event_id in (%s)" % (
",".join(["?"] * len(events_and_contexts)),
@@ -669,24 +849,21 @@ class EventsStore(SQLBaseStore):
have_persisted = {
event_id: outlier
- for event_id, outlier in txn.fetchall()
+ for event_id, outlier in txn
}
to_remove = set()
for event, context in events_and_contexts:
- if context.rejected:
- # If the event is rejected then we don't care if the event
- # was an outlier or not.
- if event.event_id in have_persisted:
- # If we have already seen the event then ignore it.
- to_remove.add(event)
- continue
-
if event.event_id not in have_persisted:
continue
to_remove.add(event)
+ if context.rejected:
+ # If the event is rejected then we don't care if the event
+ # was an outlier or not.
+ continue
+
outlier_persisted = have_persisted[event.event_id]
if not event.internal_metadata.is_outlier() and outlier_persisted:
# We received a copy of an event that we had already stored as
@@ -741,37 +918,19 @@ class EventsStore(SQLBaseStore):
# event isn't an outlier any more.
self._update_backward_extremeties(txn, [event])
- events_and_contexts = [
+ return [
ec for ec in events_and_contexts if ec[0] not in to_remove
]
+ @classmethod
+ def _delete_existing_rows_txn(cls, txn, events_and_contexts):
if not events_and_contexts:
- # Make sure we don't pass an empty list to functions that expect to
- # be storing at least one element.
+ # nothing to do here
return
- # From this point onwards the events are only events that we haven't
- # seen before.
-
- def event_dict(event):
- return {
- k: v
- for k, v in event.get_dict().items()
- if k not in [
- "redacted",
- "redacted_because",
- ]
- }
-
- if delete_existing:
- # For paranoia reasons, we go and delete all the existing entries
- # for these events so we can reinsert them.
- # This gets around any problems with some tables already having
- # entries.
+ logger.info("Deleting existing")
- logger.info("Deleting existing")
-
- for table in (
+ for table in (
"events",
"event_auth",
"event_json",
@@ -794,11 +953,30 @@ class EventsStore(SQLBaseStore):
"redactions",
"room_memberships",
"topics"
- ):
- txn.executemany(
- "DELETE FROM %s WHERE event_id = ?" % (table,),
- [(ev.event_id,) for ev, _ in events_and_contexts]
- )
+ ):
+ txn.executemany(
+ "DELETE FROM %s WHERE event_id = ?" % (table,),
+ [(ev.event_id,) for ev, _ in events_and_contexts]
+ )
+
+ def _store_event_txn(self, txn, events_and_contexts):
+ """Insert new events into the event and event_json tables
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ events_and_contexts (list[(EventBase, EventContext)]): events
+ we are persisting
+ """
+
+ if not events_and_contexts:
+ # nothing to do here
+ return
+
+ def event_dict(event):
+ d = event.get_dict()
+ d.pop("redacted", None)
+ d.pop("redacted_because", None)
+ return d
self._simple_insert_many_txn(
txn,
@@ -842,6 +1020,19 @@ class EventsStore(SQLBaseStore):
],
)
+ def _store_rejected_events_txn(self, txn, events_and_contexts):
+ """Add rows to the 'rejections' table for received events which were
+ rejected
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ events_and_contexts (list[(EventBase, EventContext)]): events
+ we are persisting
+
+ Returns:
+ list[(EventBase, EventContext)] new list, without the rejected
+ events.
+ """
# Remove the rejected events from the list now that we've added them
# to the events table and the events_json table.
to_remove = set()
@@ -853,17 +1044,24 @@ class EventsStore(SQLBaseStore):
)
to_remove.add(event)
- events_and_contexts = [
+ return [
ec for ec in events_and_contexts if ec[0] not in to_remove
]
+ def _update_metadata_tables_txn(self, txn, events_and_contexts, backfilled):
+ """Update all the miscellaneous tables for new events
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ events_and_contexts (list[(EventBase, EventContext)]): events
+ we are persisting
+ backfilled (bool): True if the events were backfilled
+ """
+
if not events_and_contexts:
- # Make sure we don't pass an empty list to functions that expect to
- # be storing at least one element.
+ # nothing to do here
return
- # From this point onwards the events are only ones that weren't rejected.
-
for event, context in events_and_contexts:
# Insert all the push actions into the event_push_actions table.
if context.push_actions:
@@ -892,10 +1090,6 @@ class EventsStore(SQLBaseStore):
],
)
- # Insert into the state_groups, state_groups_state, and
- # event_to_state_groups tables.
- self._store_mult_state_groups_txn(txn, events_and_contexts)
-
# Update the event_forward_extremities, event_backward_extremities and
# event_edges tables.
self._handle_mult_prev_events(
@@ -982,13 +1176,6 @@ class EventsStore(SQLBaseStore):
# Prefill the event cache
self._add_to_cache(txn, events_and_contexts)
- if backfilled:
- # Backfilled events come before the current state so we don't need
- # to update the current state table
- return
-
- return
-
def _add_to_cache(self, txn, events_and_contexts):
to_prefill = []
@@ -1597,14 +1784,13 @@ class EventsStore(SQLBaseStore):
def get_all_new_events_txn(txn):
sql = (
- "SELECT e.stream_ordering, ej.internal_metadata, ej.json, eg.state_group"
- " FROM events as e"
- " JOIN event_json as ej"
- " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
- " LEFT JOIN event_to_state_groups as eg"
- " ON e.event_id = eg.event_id"
- " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?"
- " ORDER BY e.stream_ordering ASC"
+ "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts"
+ " FROM events AS e"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " WHERE ? < stream_ordering AND stream_ordering <= ?"
+ " ORDER BY stream_ordering ASC"
" LIMIT ?"
)
if have_forward_events:
@@ -1630,15 +1816,13 @@ class EventsStore(SQLBaseStore):
forward_ex_outliers = []
sql = (
- "SELECT -e.stream_ordering, ej.internal_metadata, ej.json,"
- " eg.state_group"
- " FROM events as e"
- " JOIN event_json as ej"
- " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
- " LEFT JOIN event_to_state_groups as eg"
- " ON e.event_id = eg.event_id"
- " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?"
- " ORDER BY e.stream_ordering DESC"
+ "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts"
+ " FROM events AS e"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " WHERE ? > stream_ordering AND stream_ordering >= ?"
+ " ORDER BY stream_ordering DESC"
" LIMIT ?"
)
if have_backfill_events:
@@ -1825,7 +2009,7 @@ class EventsStore(SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in curr_state.items()
+ for key, state_id in curr_state.iteritems()
],
)
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 86b37b9d..3b5e0a4f 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -101,9 +101,10 @@ class KeyStore(SQLBaseStore):
key_ids
Args:
server_name (str): The name of the server.
- key_ids (list of str): List of key_ids to try and look up.
+ key_ids (iterable[str]): key_ids to try and look up.
Returns:
- (list of VerifyKey): The verification keys.
+ Deferred: resolves to dict[str, VerifyKey]: map from
+ key_id to verification key.
"""
keys = {}
for key_id in key_ids:
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index ed84db6b..6e623843 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -356,7 +356,7 @@ def _get_or_create_schema_state(txn, database_engine):
),
(current_version,)
)
- applied_deltas = [d for d, in txn.fetchall()]
+ applied_deltas = [d for d, in txn]
return current_version, applied_deltas, upgraded
return None
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 4d1590d2..9e9d3c25 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -85,8 +85,8 @@ class PresenceStore(SQLBaseStore):
self.presence_stream_cache.entity_has_changed,
state.user_id, stream_id,
)
- self._invalidate_cache_and_stream(
- txn, self._get_presence_for_user, (state.user_id,)
+ txn.call_after(
+ self._get_presence_for_user.invalidate, (state.user_id,)
)
# Actually insert new rows
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 5cf41501..6b0f8c27 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -313,10 +313,9 @@ class ReceiptsStore(SQLBaseStore):
)
txn.execute(sql, (room_id, receipt_type, user_id))
- results = txn.fetchall()
- if results and topological_ordering:
- for to, so, _ in results:
+ if topological_ordering:
+ for to, so, _ in txn:
if int(to) > topological_ordering:
return False
elif int(to) == topological_ordering and int(so) >= stream_ordering:
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 26be6060..ec2c52ab 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -209,7 +209,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
" WHERE lower(name) = lower(?)"
)
txn.execute(sql, (user_id,))
- return dict(txn.fetchall())
+ return dict(txn)
return self.runInteraction("get_users_by_id_case_insensitive", f)
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 8a2fe2fd..e4c56cc1 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -396,7 +396,7 @@ class RoomStore(SQLBaseStore):
sql % ("AND appservice_id IS NULL",),
(stream_id,)
)
- return dict(txn.fetchall())
+ return dict(txn)
else:
# We want to get from all lists, so we need to aggregate the results
@@ -422,7 +422,7 @@ class RoomStore(SQLBaseStore):
results = {}
# A room is visible if its visible on any list.
- for room_id, visibility in txn.fetchall():
+ for room_id, visibility in txn:
results[room_id] = bool(visibility) or results.get(room_id, False)
return results
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 545d3d3a..367dbbbc 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -129,17 +129,30 @@ class RoomMemberStore(SQLBaseStore):
with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
+ @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
+ def get_hosts_in_room(self, room_id, cache_context):
+ """Returns the set of all hosts currently in the room
+ """
+ user_ids = yield self.get_users_in_room(
+ room_id, on_invalidate=cache_context.invalidate,
+ )
+ hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
+ defer.returnValue(hosts)
+
@cached(max_entries=500000, iterable=True)
def get_users_in_room(self, room_id):
def f(txn):
-
- rows = self._get_members_rows_txn(
- txn,
- room_id=room_id,
- membership=Membership.JOIN,
+ sql = (
+ "SELECT m.user_id FROM room_memberships as m"
+ " INNER JOIN current_state_events as c"
+ " ON m.event_id = c.event_id "
+ " AND m.room_id = c.room_id "
+ " AND m.user_id = c.state_key"
+ " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
)
- return [r["user_id"] for r in rows]
+ txn.execute(sql, (room_id, Membership.JOIN,))
+ return [r[0] for r in txn]
return self.runInteraction("get_users_in_room", f)
@cached()
@@ -246,52 +259,27 @@ class RoomMemberStore(SQLBaseStore):
return results
- def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
- where_clause = "c.room_id = ?"
- where_values = [room_id]
-
- if membership:
- where_clause += " AND m.membership = ?"
- where_values.append(membership)
-
- if user_id:
- where_clause += " AND m.user_id = ?"
- where_values.append(user_id)
-
- sql = (
- "SELECT m.* FROM room_memberships as m"
- " INNER JOIN current_state_events as c"
- " ON m.event_id = c.event_id "
- " AND m.room_id = c.room_id "
- " AND m.user_id = c.state_key"
- " WHERE c.type = 'm.room.member' AND %(where)s"
- ) % {
- "where": where_clause,
- }
-
- txn.execute(sql, where_values)
- rows = self.cursor_to_dict(txn)
-
- return rows
-
- @cached(max_entries=500000, iterable=True)
+ @cachedInlineCallbacks(max_entries=500000, iterable=True)
def get_rooms_for_user(self, user_id):
- return self.get_rooms_for_user_where_membership_is(
+ """Returns a set of room_ids the user is currently joined to
+ """
+ rooms = yield self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN],
)
+ defer.returnValue(frozenset(r.room_id for r in rooms))
@cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
def get_users_who_share_room_with_user(self, user_id, cache_context):
"""Returns the set of users who share a room with `user_id`
"""
- rooms = yield self.get_rooms_for_user(
+ room_ids = yield self.get_rooms_for_user(
user_id, on_invalidate=cache_context.invalidate,
)
user_who_share_room = set()
- for room in rooms:
+ for room_id in room_ids:
user_ids = yield self.get_users_in_room(
- room.room_id, on_invalidate=cache_context.invalidate,
+ room_id, on_invalidate=cache_context.invalidate,
)
user_who_share_room.update(user_ids)
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index e1dca927..67d5d996 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -72,7 +72,7 @@ class SignatureStore(SQLBaseStore):
" WHERE event_id = ?"
)
txn.execute(query, (event_id, ))
- return {k: v for k, v in txn.fetchall()}
+ return {k: v for k, v in txn}
def _store_event_reference_hashes_txn(self, txn, events):
"""Store a hash for a PDU
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 84482d82..fb23f6f4 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -14,7 +14,7 @@
# limitations under the License.
from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
from synapse.util.caches import intern_string
from synapse.storage.engines import PostgresEngine
@@ -69,6 +69,18 @@ class StateStore(SQLBaseStore):
where_clause="type='m.room.member'",
)
+ @cachedInlineCallbacks(max_entries=100000, iterable=True)
+ def get_current_state_ids(self, room_id):
+ rows = yield self._simple_select_list(
+ table="current_state_events",
+ keyvalues={"room_id": room_id},
+ retcols=["event_id", "type", "state_key"],
+ desc="_calculate_state_delta",
+ )
+ defer.returnValue({
+ (r["type"], r["state_key"]): r["event_id"] for r in rows
+ })
+
@defer.inlineCallbacks
def get_state_groups_ids(self, room_id, event_ids):
if not event_ids:
@@ -78,7 +90,7 @@ class StateStore(SQLBaseStore):
event_ids,
)
- groups = set(event_to_groups.values())
+ groups = set(event_to_groups.itervalues())
group_to_state = yield self._get_state_for_groups(groups)
defer.returnValue(group_to_state)
@@ -96,17 +108,18 @@ class StateStore(SQLBaseStore):
state_event_map = yield self.get_events(
[
- ev_id for group_ids in group_to_ids.values()
- for ev_id in group_ids.values()
+ ev_id for group_ids in group_to_ids.itervalues()
+ for ev_id in group_ids.itervalues()
],
get_prev_content=False
)
defer.returnValue({
group: [
- state_event_map[v] for v in event_id_map.values() if v in state_event_map
+ state_event_map[v] for v in event_id_map.itervalues()
+ if v in state_event_map
]
- for group, event_id_map in group_to_ids.items()
+ for group, event_id_map in group_to_ids.iteritems()
})
def _have_persisted_state_group_txn(self, txn, state_group):
@@ -124,6 +137,16 @@ class StateStore(SQLBaseStore):
continue
if context.current_state_ids is None:
+ # AFAIK, this can never happen
+ logger.error(
+ "Non-outlier event %s had current_state_ids==None",
+ event.event_id)
+ continue
+
+ # if the event was rejected, just give it the same state as its
+ # predecessor.
+ if context.rejected:
+ state_groups[event.event_id] = context.prev_group
continue
state_groups[event.event_id] = context.state_group
@@ -168,7 +191,7 @@ class StateStore(SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in context.delta_ids.items()
+ for key, state_id in context.delta_ids.iteritems()
],
)
else:
@@ -183,7 +206,7 @@ class StateStore(SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in context.current_state_ids.items()
+ for key, state_id in context.current_state_ids.iteritems()
],
)
@@ -195,7 +218,7 @@ class StateStore(SQLBaseStore):
"state_group": state_group_id,
"event_id": event_id,
}
- for event_id, state_group_id in state_groups.items()
+ for event_id, state_group_id in state_groups.iteritems()
],
)
@@ -319,10 +342,10 @@ class StateStore(SQLBaseStore):
args.extend(where_args)
txn.execute(sql % (where_clause,), args)
- rows = self.cursor_to_dict(txn)
- for row in rows:
- key = (row["type"], row["state_key"])
- results[group][key] = row["event_id"]
+ for row in txn:
+ typ, state_key, event_id = row
+ key = (typ, state_key)
+ results[group][key] = event_id
else:
if types is not None:
where_clause = "AND (%s)" % (
@@ -351,12 +374,11 @@ class StateStore(SQLBaseStore):
" WHERE state_group = ? %s" % (where_clause,),
args
)
- rows = txn.fetchall()
- results[group].update({
- (typ, state_key): event_id
- for typ, state_key, event_id in rows
+ results[group].update(
+ ((typ, state_key), event_id)
+ for typ, state_key, event_id in txn
if (typ, state_key) not in results[group]
- })
+ )
# If the lengths match then we must have all the types,
# so no need to go walk further down the tree.
@@ -393,21 +415,21 @@ class StateStore(SQLBaseStore):
event_ids,
)
- groups = set(event_to_groups.values())
+ groups = set(event_to_groups.itervalues())
group_to_state = yield self._get_state_for_groups(groups, types)
state_event_map = yield self.get_events(
- [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
+ [ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()],
get_prev_content=False
)
event_to_state = {
event_id: {
k: state_event_map[v]
- for k, v in group_to_state[group].items()
+ for k, v in group_to_state[group].iteritems()
if v in state_event_map
}
- for event_id, group in event_to_groups.items()
+ for event_id, group in event_to_groups.iteritems()
}
defer.returnValue({event: event_to_state[event] for event in event_ids})
@@ -430,12 +452,12 @@ class StateStore(SQLBaseStore):
event_ids,
)
- groups = set(event_to_groups.values())
+ groups = set(event_to_groups.itervalues())
group_to_state = yield self._get_state_for_groups(groups, types)
event_to_state = {
event_id: group_to_state[group]
- for event_id, group in event_to_groups.items()
+ for event_id, group in event_to_groups.iteritems()
}
defer.returnValue({event: event_to_state[event] for event in event_ids})
@@ -474,7 +496,7 @@ class StateStore(SQLBaseStore):
state_map = yield self.get_state_ids_for_events([event_id], types)
defer.returnValue(state_map[event_id])
- @cached(num_args=2, max_entries=10000)
+ @cached(num_args=2, max_entries=100000)
def _get_state_group_for_event(self, room_id, event_id):
return self._simple_select_one_onecol(
table="event_to_state_groups",
@@ -547,7 +569,7 @@ class StateStore(SQLBaseStore):
got_all = not (missing_types or types is None)
return {
- k: v for k, v in state_dict_ids.items()
+ k: v for k, v in state_dict_ids.iteritems()
if include(k[0], k[1])
}, missing_types, got_all
@@ -606,7 +628,7 @@ class StateStore(SQLBaseStore):
# Now we want to update the cache with all the things we fetched
# from the database.
- for group, group_state_dict in group_to_state_dict.items():
+ for group, group_state_dict in group_to_state_dict.iteritems():
if types:
# We delibrately put key -> None mappings into the cache to
# cache absence of the key, on the assumption that if we've
@@ -621,10 +643,10 @@ class StateStore(SQLBaseStore):
else:
state_dict = results[group]
- state_dict.update({
- (intern_string(k[0]), intern_string(k[1])): v
- for k, v in group_state_dict.items()
- })
+ state_dict.update(
+ ((intern_string(k[0]), intern_string(k[1])), v)
+ for k, v in group_state_dict.iteritems()
+ )
self._state_group_cache.update(
cache_seq_num,
@@ -635,10 +657,10 @@ class StateStore(SQLBaseStore):
# Remove all the entries with None values. The None values were just
# used for bookkeeping in the cache.
- for group, state_dict in results.items():
+ for group, state_dict in results.iteritems():
results[group] = {
key: event_id
- for key, event_id in state_dict.items()
+ for key, event_id in state_dict.iteritems()
if event_id
}
@@ -727,7 +749,7 @@ class StateStore(SQLBaseStore):
# of keys
delta_state = {
- key: value for key, value in curr_state.items()
+ key: value for key, value in curr_state.iteritems()
if prev_state.get(key, None) != value
}
@@ -767,7 +789,7 @@ class StateStore(SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in delta_state.items()
+ for key, state_id in delta_state.iteritems()
],
)
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 200d1246..dddd5fc0 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -829,3 +829,6 @@ class StreamStore(SQLBaseStore):
updatevalues={"stream_id": stream_id},
desc="update_federation_out_pos",
)
+
+ def has_room_changed_since(self, room_id, stream_id):
+ return self._events_stream_cache.has_entity_changed(room_id, stream_id)
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index 5a2c1aa5..bff73f3f 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -95,7 +95,7 @@ class TagsStore(SQLBaseStore):
for stream_id, user_id, room_id in tag_ids:
txn.execute(sql, (user_id, room_id))
tags = []
- for tag, content in txn.fetchall():
+ for tag, content in txn:
tags.append(json.dumps(tag) + ":" + content)
tag_json = "{" + ",".join(tags) + "}"
results.append((stream_id, user_id, room_id, tag_json))
@@ -132,7 +132,7 @@ class TagsStore(SQLBaseStore):
" WHERE user_id = ? AND stream_id > ?"
)
txn.execute(sql, (user_id, stream_id))
- room_ids = [row[0] for row in txn.fetchall()]
+ room_ids = [row[0] for row in txn]
return room_ids
changed = self._account_data_stream_cache.has_entity_changed(
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 46cf93ff..95031dc9 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -30,6 +30,17 @@ class IdGenerator(object):
def _load_current_id(db_conn, table, column, step=1):
+ """
+
+ Args:
+ db_conn (object):
+ table (str):
+ column (str):
+ step (int):
+
+ Returns:
+ int
+ """
cur = db_conn.cursor()
if step == 1:
cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
@@ -131,6 +142,9 @@ class StreamIdGenerator(object):
def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
+
+ Returns:
+ int
"""
with self._lock:
if self._unfinished_ids:
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 30fc4801..98a5a26a 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
class DeferredTimedOutError(SynapseError):
def __init__(self):
- super(SynapseError).__init__(504, "Timed out")
+ super(SynapseError, self).__init__(504, "Timed out")
def unwrapFirstError(failure):
@@ -93,8 +93,10 @@ class Clock(object):
ret_deferred = defer.Deferred()
def timed_out_fn():
+ e = DeferredTimedOutError()
+
try:
- ret_deferred.errback(DeferredTimedOutError())
+ ret_deferred.errback(e)
except:
pass
@@ -114,7 +116,7 @@ class Clock(object):
ret_deferred.addBoth(cancel)
- def sucess(res):
+ def success(res):
try:
ret_deferred.callback(res)
except:
@@ -128,7 +130,7 @@ class Clock(object):
except:
pass
- given_deferred.addCallbacks(callback=sucess, errback=err)
+ given_deferred.addCallbacks(callback=success, errback=err)
timer = self.call_later(time_out, timed_out_fn)
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 998de70d..5c30ed23 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -15,12 +15,9 @@
import logging
from synapse.util.async import ObservableDeferred
-from synapse.util import unwrapFirstError
+from synapse.util import unwrapFirstError, logcontext
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-from synapse.util.logcontext import (
- PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
-)
from . import DEBUG_CACHES, register_cache
@@ -189,7 +186,55 @@ class Cache(object):
self.cache.clear()
-class CacheDescriptor(object):
+class _CacheDescriptorBase(object):
+ def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
+ self.orig = orig
+
+ if inlineCallbacks:
+ self.function_to_call = defer.inlineCallbacks(orig)
+ else:
+ self.function_to_call = orig
+
+ arg_spec = inspect.getargspec(orig)
+ all_args = arg_spec.args
+
+ if "cache_context" in all_args:
+ if not cache_context:
+ raise ValueError(
+ "Cannot have a 'cache_context' arg without setting"
+ " cache_context=True"
+ )
+ elif cache_context:
+ raise ValueError(
+ "Cannot have cache_context=True without having an arg"
+ " named `cache_context`"
+ )
+
+ if num_args is None:
+ num_args = len(all_args) - 1
+ if cache_context:
+ num_args -= 1
+
+ if len(all_args) < num_args + 1:
+ raise Exception(
+ "Not enough explicit positional arguments to key off for %r: "
+ "got %i args, but wanted %i. (@cached cannot key off *args or "
+ "**kwargs)"
+ % (orig.__name__, len(all_args), num_args)
+ )
+
+ self.num_args = num_args
+ self.arg_names = all_args[1:num_args + 1]
+
+ if "cache_context" in self.arg_names:
+ raise Exception(
+ "cache_context arg cannot be included among the cache keys"
+ )
+
+ self.add_cache_context = cache_context
+
+
+class CacheDescriptor(_CacheDescriptorBase):
""" A method decorator that applies a memoizing cache around the function.
This caches deferreds, rather than the results themselves. Deferreds that
@@ -217,52 +262,24 @@ class CacheDescriptor(object):
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
defer.returnValue(r1 + r2)
+ Args:
+ num_args (int): number of positional arguments (excluding ``self`` and
+ ``cache_context``) to use as cache keys. Defaults to all named
+ args of the function.
"""
- def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
+ def __init__(self, orig, max_entries=1000, num_args=None, tree=False,
inlineCallbacks=False, cache_context=False, iterable=False):
- max_entries = int(max_entries * CACHE_SIZE_FACTOR)
- self.orig = orig
+ super(CacheDescriptor, self).__init__(
+ orig, num_args=num_args, inlineCallbacks=inlineCallbacks,
+ cache_context=cache_context)
- if inlineCallbacks:
- self.function_to_call = defer.inlineCallbacks(orig)
- else:
- self.function_to_call = orig
+ max_entries = int(max_entries * CACHE_SIZE_FACTOR)
self.max_entries = max_entries
- self.num_args = num_args
self.tree = tree
-
self.iterable = iterable
- all_args = inspect.getargspec(orig)
- self.arg_names = all_args.args[1:num_args + 1]
-
- if "cache_context" in all_args.args:
- if not cache_context:
- raise ValueError(
- "Cannot have a 'cache_context' arg without setting"
- " cache_context=True"
- )
- try:
- self.arg_names.remove("cache_context")
- except ValueError:
- pass
- elif cache_context:
- raise ValueError(
- "Cannot have cache_context=True without having an arg"
- " named `cache_context`"
- )
-
- self.add_cache_context = cache_context
-
- if len(self.arg_names) < self.num_args:
- raise Exception(
- "Not enough explicit positional arguments to key off of for %r."
- " (@cached cannot key off of *args or **kwargs)"
- % (orig.__name__,)
- )
-
def __get__(self, obj, objtype=None):
cache = Cache(
name=self.orig.__name__,
@@ -308,11 +325,9 @@ class CacheDescriptor(object):
defer.returnValue(cached_result)
observer.addCallback(check_result)
- return preserve_context_over_deferred(observer)
except KeyError:
ret = defer.maybeDeferred(
- preserve_context_over_fn,
- self.function_to_call,
+ logcontext.preserve_fn(self.function_to_call),
obj, *args, **kwargs
)
@@ -322,10 +337,11 @@ class CacheDescriptor(object):
ret.addErrback(onErr)
- ret = ObservableDeferred(ret, consumeErrors=True)
- cache.set(cache_key, ret, callback=invalidate_callback)
+ result_d = ObservableDeferred(ret, consumeErrors=True)
+ cache.set(cache_key, result_d, callback=invalidate_callback)
+ observer = result_d.observe()
- return preserve_context_over_deferred(ret.observe())
+ return logcontext.make_deferred_yieldable(observer)
wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
@@ -338,48 +354,40 @@ class CacheDescriptor(object):
return wrapped
-class CacheListDescriptor(object):
+class CacheListDescriptor(_CacheDescriptorBase):
"""Wraps an existing cache to support bulk fetching of keys.
Given a list of keys it looks in the cache to find any hits, then passes
- the list of missing keys to the wrapped fucntion.
+ the list of missing keys to the wrapped function.
+
+ Once wrapped, the function returns either a Deferred which resolves to
+ the list of results, or (if all results were cached), just the list of
+ results.
"""
- def __init__(self, orig, cached_method_name, list_name, num_args=1,
+ def __init__(self, orig, cached_method_name, list_name, num_args=None,
inlineCallbacks=False):
"""
Args:
orig (function)
- method_name (str); The name of the chached method.
+ cached_method_name (str): The name of the chached method.
list_name (str): Name of the argument which is the bulk lookup list
- num_args (int)
+ num_args (int): number of positional arguments (excluding ``self``,
+ but including list_name) to use as cache keys. Defaults to all
+ named args of the function.
inlineCallbacks (bool): Whether orig is a generator that should
be wrapped by defer.inlineCallbacks
"""
- self.orig = orig
+ super(CacheListDescriptor, self).__init__(
+ orig, num_args=num_args, inlineCallbacks=inlineCallbacks)
- if inlineCallbacks:
- self.function_to_call = defer.inlineCallbacks(orig)
- else:
- self.function_to_call = orig
-
- self.num_args = num_args
self.list_name = list_name
- self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
self.list_pos = self.arg_names.index(self.list_name)
-
self.cached_method_name = cached_method_name
self.sentinel = object()
- if len(self.arg_names) < self.num_args:
- raise Exception(
- "Not enough explicit positional arguments to key off of for %r."
- " (@cached cannot key off of *args or **kwars)"
- % (orig.__name__,)
- )
-
if self.list_name not in self.arg_names:
raise Exception(
"Couldn't see arguments %r for %r."
@@ -425,8 +433,7 @@ class CacheListDescriptor(object):
args_to_call[self.list_name] = missing
ret_d = defer.maybeDeferred(
- preserve_context_over_fn,
- self.function_to_call,
+ logcontext.preserve_fn(self.function_to_call),
**args_to_call
)
@@ -435,8 +442,7 @@ class CacheListDescriptor(object):
# We need to create deferreds for each arg in the list so that
# we can insert the new deferred into the cache.
for arg in missing:
- with PreserveLoggingContext():
- observer = ret_d.observe()
+ observer = ret_d.observe()
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
observer = ObservableDeferred(observer)
@@ -463,7 +469,7 @@ class CacheListDescriptor(object):
results.update(res)
return results
- return preserve_context_over_deferred(defer.gatherResults(
+ return logcontext.make_deferred_yieldable(defer.gatherResults(
cached_defers.values(),
consumeErrors=True,
).addCallback(update_results_dict).addErrback(
@@ -487,7 +493,7 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
self.cache.invalidate(self.key)
-def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
+def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
iterable=False):
return lambda orig: CacheDescriptor(
orig,
@@ -499,8 +505,8 @@ def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
)
-def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False,
- iterable=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False,
+ cache_context=False, iterable=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
@@ -512,7 +518,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
)
-def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False):
+def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False):
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
Used to do batch lookups for an already created cache. A single argument
@@ -525,7 +531,8 @@ def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False)
cache (Cache): The underlying cache to use.
list_name (str): The name of the argument that is the list to use to
do batch lookups in the cache.
- num_args (int): Number of arguments to use as the key in the cache.
+ num_args (int): Number of arguments to use as the key in the cache
+ (including list_name). Defaults to all named parameters.
inlineCallbacks (bool): Should the function be wrapped in an
`defer.inlineCallbacks`?
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index b72bb0ff..70fe00ce 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -50,7 +50,7 @@ class StreamChangeCache(object):
def has_entity_changed(self, entity, stream_pos):
"""Returns True if the entity may have been updated since stream_pos
"""
- assert type(stream_pos) is int
+ assert type(stream_pos) is int or type(stream_pos) is long
if stream_pos < self._earliest_known_stream_pos:
self.metrics.inc_misses()
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 6c83eb21..857afee7 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -12,6 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+""" Thread-local-alike tracking of log contexts within synapse
+
+This module provides objects and utilities for tracking contexts through
+synapse code, so that log lines can include a request identifier, and so that
+CPU and database activity can be accounted for against the request that caused
+them.
+
+See doc/log_contexts.rst for details on how this works.
+"""
+
from twisted.internet import defer
import threading
@@ -300,6 +310,10 @@ def preserve_context_over_fn(fn, *args, **kwargs):
def preserve_context_over_deferred(deferred, context=None):
"""Given a deferred wrap it such that any callbacks added later to it will
be invoked with the current context.
+
+ Deprecated: this almost certainly doesn't do want you want, ie make
+ the deferred follow the synapse logcontext rules: try
+ ``make_deferred_yieldable`` instead.
"""
if context is None:
context = LoggingContext.current_context()
@@ -309,24 +323,65 @@ def preserve_context_over_deferred(deferred, context=None):
def preserve_fn(f):
- """Ensures that function is called with correct context and that context is
- restored after return. Useful for wrapping functions that return a deferred
- which you don't yield on.
+ """Wraps a function, to ensure that the current context is restored after
+ return from the function, and that the sentinel context is set once the
+ deferred returned by the funtion completes.
+
+ Useful for wrapping functions that return a deferred which you don't yield
+ on.
"""
+ def reset_context(result):
+ LoggingContext.set_current_context(LoggingContext.sentinel)
+ return result
+
+ # XXX: why is this here rather than inside g? surely we want to preserve
+ # the context from the time the function was called, not when it was
+ # wrapped?
current = LoggingContext.current_context()
def g(*args, **kwargs):
- with PreserveLoggingContext(current):
- res = f(*args, **kwargs)
- if isinstance(res, defer.Deferred):
- return preserve_context_over_deferred(
- res, context=LoggingContext.sentinel
- )
- else:
- return res
+ res = f(*args, **kwargs)
+ if isinstance(res, defer.Deferred) and not res.called:
+ # The function will have reset the context before returning, so
+ # we need to restore it now.
+ LoggingContext.set_current_context(current)
+
+ # The original context will be restored when the deferred
+ # completes, but there is nothing waiting for it, so it will
+ # get leaked into the reactor or some other function which
+ # wasn't expecting it. We therefore need to reset the context
+ # here.
+ #
+ # (If this feels asymmetric, consider it this way: we are
+ # effectively forking a new thread of execution. We are
+ # probably currently within a ``with LoggingContext()`` block,
+ # which is supposed to have a single entry and exit point. But
+ # by spawning off another deferred, we are effectively
+ # adding a new exit point.)
+ res.addBoth(reset_context)
+ return res
return g
+@defer.inlineCallbacks
+def make_deferred_yieldable(deferred):
+ """Given a deferred, make it follow the Synapse logcontext rules:
+
+ If the deferred has completed (or is not actually a Deferred), essentially
+ does nothing (just returns another completed deferred with the
+ result/failure).
+
+ If the deferred has not yet completed, resets the logcontext before
+ returning a deferred. Then, when the deferred completes, restores the
+ current logcontext before running callbacks/errbacks.
+
+ (This is more-or-less the opposite operation to preserve_fn.)
+ """
+ with PreserveLoggingContext():
+ r = yield deferred
+ defer.returnValue(r)
+
+
# modules to ignore in `logcontext_tracer`
_to_ignore = [
"synapse.util.logcontext",
diff --git a/synapse/util/msisdn.py b/synapse/util/msisdn.py
new file mode 100644
index 00000000..607161e7
--- /dev/null
+++ b/synapse/util/msisdn.py
@@ -0,0 +1,40 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import phonenumbers
+from synapse.api.errors import SynapseError
+
+
+def phone_number_to_msisdn(country, number):
+ """
+ Takes an ISO-3166-1 2 letter country code and phone number and
+ returns an msisdn representing the canonical version of that
+ phone number.
+ Args:
+ country (str): ISO-3166-1 2 letter country code
+ number (str): Phone number in a national or international format
+
+ Returns:
+ (str) The canonical form of the phone number, as an msisdn
+ Raises:
+ SynapseError if the number could not be parsed.
+ """
+ try:
+ phoneNumber = phonenumbers.parse(number, country)
+ except phonenumbers.NumberParseException:
+ raise SynapseError(400, "Unable to parse phone number")
+ return phonenumbers.format_number(
+ phoneNumber, phonenumbers.PhoneNumberFormat.E164
+ )[1:]
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 153ef001..4fa9d1a0 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import synapse.util.logcontext
from twisted.internet import defer
from synapse.api.errors import CodeMessageException
@@ -35,7 +35,8 @@ class NotRetryingDestination(Exception):
@defer.inlineCallbacks
-def get_retry_limiter(destination, clock, store, **kwargs):
+def get_retry_limiter(destination, clock, store, ignore_backoff=False,
+ **kwargs):
"""For a given destination check if we have previously failed to
send a request there and are waiting before retrying the destination.
If we are not ready to retry the destination, this will raise a
@@ -43,6 +44,14 @@ def get_retry_limiter(destination, clock, store, **kwargs):
that will mark the destination as down if an exception is thrown (excluding
CodeMessageException with code < 500)
+ Args:
+ destination (str): name of homeserver
+ clock (synapse.util.clock): timing source
+ store (synapse.storage.transactions.TransactionStore): datastore
+ ignore_backoff (bool): true to ignore the historical backoff data and
+ try the request anyway. We will still update the next
+ retry_interval on success/failure.
+
Example usage:
try:
@@ -66,7 +75,7 @@ def get_retry_limiter(destination, clock, store, **kwargs):
now = int(clock.time_msec())
- if retry_last_ts + retry_interval > now:
+ if not ignore_backoff and retry_last_ts + retry_interval > now:
raise NotRetryingDestination(
retry_last_ts=retry_last_ts,
retry_interval=retry_interval,
@@ -124,7 +133,13 @@ class RetryDestinationLimiter(object):
def __exit__(self, exc_type, exc_val, exc_tb):
valid_err_code = False
- if exc_type is not None and issubclass(exc_type, CodeMessageException):
+ if exc_type is None:
+ valid_err_code = True
+ elif not issubclass(exc_type, Exception):
+ # avoid treating exceptions which don't derive from Exception as
+ # failures; this is mostly so as not to catch defer._DefGen.
+ valid_err_code = True
+ elif issubclass(exc_type, CodeMessageException):
# Some error codes are perfectly fine for some APIs, whereas other
# APIs may expect to never received e.g. a 404. It's important to
# handle 404 as some remote servers will return a 404 when the HS
@@ -142,11 +157,13 @@ class RetryDestinationLimiter(object):
else:
valid_err_code = False
- if exc_type is None or valid_err_code:
+ if valid_err_code:
# We connected successfully.
if not self.retry_interval:
return
+ logger.debug("Connection to %s was successful; clearing backoff",
+ self.destination)
retry_last_ts = 0
self.retry_interval = 0
else:
@@ -160,6 +177,10 @@ class RetryDestinationLimiter(object):
else:
self.retry_interval = self.min_retry_interval
+ logger.debug(
+ "Connection to %s was unsuccessful (%s(%s)); backoff now %i",
+ self.destination, exc_type, exc_val, self.retry_interval
+ )
retry_last_ts = int(self.clock.time_msec())
@defer.inlineCallbacks
@@ -173,4 +194,5 @@ class RetryDestinationLimiter(object):
"Failed to store set_destination_retry_timings",
)
- store_retry_timings()
+ # we deliberately do this in the background.
+ synapse.util.logcontext.preserve_fn(store_retry_timings)()
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 199b16d8..31659156 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -134,6 +134,13 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
if prev_membership not in MEMBERSHIP_PRIORITY:
prev_membership = "leave"
+ # Always allow the user to see their own leave events, otherwise
+ # they won't see the room disappear if they reject the invite
+ if membership == "leave" and (
+ prev_membership == "join" or prev_membership == "invite"
+ ):
+ return True
+
new_priority = MEMBERSHIP_PRIORITY.index(membership)
old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
if old_priority < new_priority: