summaryrefslogtreecommitdiff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erikj@matrix.org>2016-12-16 10:43:39 +0000
committerErik Johnston <erikj@matrix.org>2016-12-16 10:43:39 +0000
commite25685dbda5a91922a7ecee0dce89db1ea1662a9 (patch)
treefce7147d9b4422f76b5cec8b53312bb34932d84f /synapse
parent4217a384316b04d4c2eed7814ca4defb62af8703 (diff)
Imported Upstream version 0.18.5
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/auth.py100
-rw-r--r--synapse/api/errors.py1
-rw-r--r--synapse/api/filtering.py45
-rw-r--r--synapse/app/federation_sender.py331
-rw-r--r--synapse/appservice/__init__.py3
-rw-r--r--synapse/appservice/api.py8
-rw-r--r--synapse/config/logger.py1
-rw-r--r--synapse/config/password_auth_providers.py20
-rw-r--r--synapse/config/registration.py6
-rw-r--r--synapse/config/server.py5
-rw-r--r--synapse/events/utils.py102
-rw-r--r--synapse/federation/__init__.py7
-rw-r--r--synapse/federation/federation_client.py69
-rw-r--r--synapse/federation/replication.py5
-rw-r--r--synapse/federation/send_queue.py298
-rw-r--r--synapse/federation/transaction_queue.py102
-rw-r--r--synapse/federation/transport/client.py9
-rw-r--r--synapse/federation/transport/server.py19
-rw-r--r--synapse/handlers/__init__.py2
-rw-r--r--synapse/handlers/auth.py53
-rw-r--r--synapse/handlers/devicemessage.py4
-rw-r--r--synapse/handlers/directory.py19
-rw-r--r--synapse/handlers/e2e_keys.py10
-rw-r--r--synapse/handlers/federation.py60
-rw-r--r--synapse/handlers/initial_sync.py7
-rw-r--r--synapse/handlers/message.py73
-rw-r--r--synapse/handlers/presence.py11
-rw-r--r--synapse/handlers/receipts.py10
-rw-r--r--synapse/handlers/register.py7
-rw-r--r--synapse/handlers/room.py10
-rw-r--r--synapse/handlers/room_list.py64
-rw-r--r--synapse/handlers/sync.py44
-rw-r--r--synapse/handlers/typing.py4
-rw-r--r--synapse/http/matrixfederationclient.py48
-rw-r--r--synapse/http/servlet.py8
-rw-r--r--synapse/notifier.py49
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py12
-rw-r--r--synapse/python_dependencies.py6
-rw-r--r--synapse/replication/expire_cache.py60
-rw-r--r--synapse/replication/resource.py37
-rw-r--r--synapse/replication/slave/storage/_base.py19
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py20
-rw-r--r--synapse/replication/slave/storage/events.py14
-rw-r--r--synapse/replication/slave/storage/room.py3
-rw-r--r--synapse/replication/slave/storage/transactions.py10
-rw-r--r--synapse/rest/client/v1/directory.py34
-rw-r--r--synapse/rest/client/v1/login.py28
-rw-r--r--synapse/rest/client/v1/register.py12
-rw-r--r--synapse/rest/client/v1/room.py55
-rw-r--r--synapse/rest/client/v2_alpha/devices.py6
-rw-r--r--synapse/rest/client/v2_alpha/keys.py49
-rw-r--r--synapse/rest/client/v2_alpha/receipts.py2
-rw-r--r--synapse/rest/client/v2_alpha/register.py29
-rw-r--r--synapse/rest/client/v2_alpha/sendtodevice.py2
-rw-r--r--synapse/rest/client/v2_alpha/sync.py23
-rw-r--r--synapse/rest/client/v2_alpha/tokenrefresh.py26
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py7
-rw-r--r--synapse/server.py28
-rw-r--r--synapse/storage/__init__.py2
-rw-r--r--synapse/storage/_base.py18
-rw-r--r--synapse/storage/appservice.py8
-rw-r--r--synapse/storage/deviceinbox.py29
-rw-r--r--synapse/storage/event_push_actions.py41
-rw-r--r--synapse/storage/events.py1
-rw-r--r--synapse/storage/filtering.py8
-rw-r--r--synapse/storage/prepare_database.py2
-rw-r--r--synapse/storage/presence.py7
-rw-r--r--synapse/storage/push_rule.py12
-rw-r--r--synapse/storage/receipts.py2
-rw-r--r--synapse/storage/registration.py66
-rw-r--r--synapse/storage/room.py187
-rw-r--r--synapse/storage/roommember.py102
-rw-r--r--synapse/storage/schema/delta/39/appservice_room_list.sql29
-rw-r--r--synapse/storage/schema/delta/39/device_federation_stream_idx.sql16
-rw-r--r--synapse/storage/schema/delta/39/event_push_index.sql17
-rw-r--r--synapse/storage/schema/delta/39/federation_out_position.sql22
-rw-r--r--synapse/storage/schema/delta/39/membership_profile.sql20
-rw-r--r--synapse/storage/state.py5
-rw-r--r--synapse/storage/stream.py50
-rw-r--r--synapse/storage/transactions.py47
-rw-r--r--synapse/types.py34
-rw-r--r--synapse/util/__init__.py7
-rw-r--r--synapse/util/async.py58
-rw-r--r--synapse/util/jsonobject.py17
-rw-r--r--synapse/util/ldap_auth_provider.py369
-rw-r--r--synapse/util/retryutils.py23
87 files changed, 2283 insertions, 1014 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 432567a1..f006e10d 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.18.4"
+__version__ = "0.18.5"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 69b33927..a9998671 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -39,6 +39,9 @@ AuthEventTypes = (
EventTypes.ThirdPartyInvite,
)
+# guests always get this device id.
+GUEST_DEVICE_ID = "guest_device"
+
class Auth(object):
"""
@@ -51,17 +54,6 @@ class Auth(object):
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
- # Docs for these currently lives at
- # github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
- # In addition, we have type == delete_pusher which grants access only to
- # delete pushers.
- self._KNOWN_CAVEAT_PREFIXES = set([
- "gen = ",
- "guest = ",
- "type = ",
- "time < ",
- "user_id = ",
- ])
@defer.inlineCallbacks
def check_from_context(self, event, context, do_sig_check=True):
@@ -685,31 +677,28 @@ class Auth(object):
@defer.inlineCallbacks
def get_user_by_access_token(self, token, rights="access"):
- """ Get a registered user's ID.
+ """ Validate access token and get user_id from it
Args:
token (str): The access token to get the user by.
+ rights (str): The operation being performed; the access token must
+ allow this.
Returns:
dict : dict that includes the user and the ID of their access token.
Raises:
AuthError if no user by that token exists or the token is invalid.
"""
try:
- ret = yield self.get_user_from_macaroon(token, rights)
- except AuthError:
- # TODO(daniel): Remove this fallback when all existing access tokens
- # have been re-issued as macaroons.
- if self.hs.config.expire_access_token:
- raise
- ret = yield self._look_up_user_by_access_token(token)
-
- defer.returnValue(ret)
+ macaroon = pymacaroons.Macaroon.deserialize(token)
+ except Exception: # deserialize can throw more-or-less anything
+ # doesn't look like a macaroon: treat it as an opaque token which
+ # must be in the database.
+ # TODO: it would be nice to get rid of this, but apparently some
+ # people use access tokens which aren't macaroons
+ r = yield self._look_up_user_by_access_token(token)
+ defer.returnValue(r)
- @defer.inlineCallbacks
- def get_user_from_macaroon(self, macaroon_str, rights="access"):
try:
- macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
-
user_id = self.get_user_id_from_macaroon(macaroon)
user = UserID.from_string(user_id)
@@ -724,11 +713,36 @@ class Auth(object):
guest = True
if guest:
+ # Guest access tokens are not stored in the database (there can
+ # only be one access token per guest, anyway).
+ #
+ # In order to prevent guest access tokens being used as regular
+ # user access tokens (and hence getting around the invalidation
+ # process), we look up the user id and check that it is indeed
+ # a guest user.
+ #
+ # It would of course be much easier to store guest access
+ # tokens in the database as well, but that would break existing
+ # guest tokens.
+ stored_user = yield self.store.get_user_by_id(user_id)
+ if not stored_user:
+ raise AuthError(
+ self.TOKEN_NOT_FOUND_HTTP_STATUS,
+ "Unknown user_id %s" % user_id,
+ errcode=Codes.UNKNOWN_TOKEN
+ )
+ if not stored_user["is_guest"]:
+ raise AuthError(
+ self.TOKEN_NOT_FOUND_HTTP_STATUS,
+ "Guest access token used for regular user",
+ errcode=Codes.UNKNOWN_TOKEN
+ )
ret = {
"user": user,
"is_guest": True,
"token_id": None,
- "device_id": None,
+ # all guests get the same device id
+ "device_id": GUEST_DEVICE_ID,
}
elif rights == "delete_pusher":
# We don't store these tokens in the database
@@ -750,7 +764,7 @@ class Auth(object):
# macaroon. They probably should be.
# TODO: build the dictionary from the macaroon once the
# above are fixed
- ret = yield self._look_up_user_by_access_token(macaroon_str)
+ ret = yield self._look_up_user_by_access_token(token)
if ret["user"] != user:
logger.error(
"Macaroon user (%s) != DB user (%s)",
@@ -798,27 +812,38 @@ class Auth(object):
Args:
macaroon(pymacaroons.Macaroon): The macaroon to validate
- type_string(str): The kind of token required (e.g. "access", "refresh",
+ type_string(str): The kind of token required (e.g. "access",
"delete_pusher")
verify_expiry(bool): Whether to verify whether the macaroon has expired.
- This should really always be True, but no clients currently implement
- token refresh, so we can't enforce expiry yet.
user_id (str): The user_id required
"""
v = pymacaroons.Verifier()
+
+ # the verifier runs a test for every caveat on the macaroon, to check
+ # that it is met for the current request. Each caveat must match at
+ # least one of the predicates specified by satisfy_exact or
+ # specify_general.
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = " + type_string)
v.satisfy_exact("user_id = %s" % user_id)
v.satisfy_exact("guest = true")
+
+ # verify_expiry should really always be True, but there exist access
+ # tokens in the wild which expire when they should not, so we can't
+ # enforce expiry yet (so we have to allow any caveat starting with
+ # 'time < ' in access tokens).
+ #
+ # On the other hand, short-term login tokens (as used by CAS login, for
+ # example) have an expiry time which we do want to enforce.
+
if verify_expiry:
v.satisfy_general(self._verify_expiry)
else:
v.satisfy_general(lambda c: c.startswith("time < "))
- v.verify(macaroon, self.hs.config.macaroon_secret_key)
+ # access_tokens include a nonce for uniqueness: any value is acceptable
+ v.satisfy_general(lambda c: c.startswith("nonce = "))
- v = pymacaroons.Verifier()
- v.satisfy_general(self._verify_recognizes_caveats)
v.verify(macaroon, self.hs.config.macaroon_secret_key)
def _verify_expiry(self, caveat):
@@ -829,15 +854,6 @@ class Auth(object):
now = self.hs.get_clock().time_msec()
return now < expiry
- def _verify_recognizes_caveats(self, caveat):
- first_space = caveat.find(" ")
- if first_space < 0:
- return False
- second_space = caveat.find(" ", first_space + 1)
- if second_space < 0:
- return False
- return caveat[:second_space + 1] in self._KNOWN_CAVEAT_PREFIXES
-
@defer.inlineCallbacks
def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token)
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 00416468..921c4577 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -39,6 +39,7 @@ class Codes(object):
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
MISSING_PARAM = "M_MISSING_PARAM"
+ INVALID_PARAM = "M_INVALID_PARAM"
TOO_LARGE = "M_TOO_LARGE"
EXCLUSIVE = "M_EXCLUSIVE"
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 3b3ef707..fb291d7f 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -71,6 +71,21 @@ class Filtering(object):
if key in user_filter_json["room"]:
self._check_definition(user_filter_json["room"][key])
+ if "event_fields" in user_filter_json:
+ if type(user_filter_json["event_fields"]) != list:
+ raise SynapseError(400, "event_fields must be a list of strings")
+ for field in user_filter_json["event_fields"]:
+ if not isinstance(field, basestring):
+ raise SynapseError(400, "Event field must be a string")
+ # Don't allow '\\' in event field filters. This makes matching
+ # events a lot easier as we can then use a negative lookbehind
+ # assertion to split '\.' If we allowed \\ then it would
+ # incorrectly split '\\.' See synapse.events.utils.serialize_event
+ if r'\\' in field:
+ raise SynapseError(
+ 400, r'The escape character \ cannot itself be escaped'
+ )
+
def _check_definition_room_lists(self, definition):
"""Check that "rooms" and "not_rooms" are lists of room ids if they
are present
@@ -152,6 +167,7 @@ class FilterCollection(object):
self.include_leave = filter_json.get("room", {}).get(
"include_leave", False
)
+ self.event_fields = filter_json.get("event_fields", [])
def __repr__(self):
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
@@ -186,6 +202,26 @@ class FilterCollection(object):
def filter_room_account_data(self, events):
return self._room_account_data.filter(self._room_filter.filter(events))
+ def blocks_all_presence(self):
+ return (
+ self._presence_filter.filters_all_types() or
+ self._presence_filter.filters_all_senders()
+ )
+
+ def blocks_all_room_ephemeral(self):
+ return (
+ self._room_ephemeral_filter.filters_all_types() or
+ self._room_ephemeral_filter.filters_all_senders() or
+ self._room_ephemeral_filter.filters_all_rooms()
+ )
+
+ def blocks_all_room_timeline(self):
+ return (
+ self._room_timeline_filter.filters_all_types() or
+ self._room_timeline_filter.filters_all_senders() or
+ self._room_timeline_filter.filters_all_rooms()
+ )
+
class Filter(object):
def __init__(self, filter_json):
@@ -202,6 +238,15 @@ class Filter(object):
self.contains_url = self.filter_json.get("contains_url", None)
+ def filters_all_types(self):
+ return "*" in self.not_types
+
+ def filters_all_senders(self):
+ return "*" in self.not_senders
+
+ def filters_all_rooms(self):
+ return "*" in self.not_rooms
+
def check(self, event):
"""Checks whether the filter matches the given event.
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
new file mode 100644
index 00000000..80ea4c80
--- /dev/null
+++ b/synapse/app/federation_sender.py
@@ -0,0 +1,331 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import synapse
+
+from synapse.server import HomeServer
+from synapse.config._base import ConfigError
+from synapse.config.logger import setup_logging
+from synapse.config.homeserver import HomeServerConfig
+from synapse.crypto import context_factory
+from synapse.http.site import SynapseSite
+from synapse.federation import send_queue
+from synapse.federation.units import Edu
+from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
+from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
+from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.slave.storage.transactions import TransactionStore
+from synapse.storage.engines import create_engine
+from synapse.storage.presence import UserPresenceState
+from synapse.util.async import sleep
+from synapse.util.httpresourcetree import create_resource_tree
+from synapse.util.logcontext import LoggingContext
+from synapse.util.manhole import manhole
+from synapse.util.rlimit import change_resource_limit
+from synapse.util.versionstring import get_version_string
+
+from synapse import events
+
+from twisted.internet import reactor, defer
+from twisted.web.resource import Resource
+
+from daemonize import Daemonize
+
+import sys
+import logging
+import gc
+import ujson as json
+
+logger = logging.getLogger("synapse.app.appservice")
+
+
+class FederationSenderSlaveStore(
+ SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore,
+ SlavedRegistrationStore,
+):
+ pass
+
+
+class FederationSenderServer(HomeServer):
+ def get_db_conn(self, run_new_connection=True):
+ # Any param beginning with cp_ is a parameter for adbapi, and should
+ # not be passed to the database engine.
+ db_params = {
+ k: v for k, v in self.db_config.get("args", {}).items()
+ if not k.startswith("cp_")
+ }
+ db_conn = self.database_engine.module.connect(**db_params)
+
+ if run_new_connection:
+ self.database_engine.on_new_connection(db_conn)
+ return db_conn
+
+ def setup(self):
+ logger.info("Setting up.")
+ self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)
+ logger.info("Finished setting up.")
+
+ def _listen_http(self, listener_config):
+ port = listener_config["port"]
+ bind_address = listener_config.get("bind_address", "")
+ site_tag = listener_config.get("tag", port)
+ resources = {}
+ for res in listener_config["resources"]:
+ for name in res["names"]:
+ if name == "metrics":
+ resources[METRICS_PREFIX] = MetricsResource(self)
+
+ root_resource = create_resource_tree(resources, Resource())
+ reactor.listenTCP(
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
+ ),
+ interface=bind_address
+ )
+ logger.info("Synapse federation_sender now listening on port %d", port)
+
+ def start_listening(self, listeners):
+ for listener in listeners:
+ if listener["type"] == "http":
+ self._listen_http(listener)
+ elif listener["type"] == "manhole":
+ reactor.listenTCP(
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
+ ),
+ interface=listener.get("bind_address", '127.0.0.1')
+ )
+ else:
+ logger.warn("Unrecognized listener type: %s", listener["type"])
+
+ @defer.inlineCallbacks
+ def replicate(self):
+ http_client = self.get_simple_http_client()
+ store = self.get_datastore()
+ replication_url = self.config.worker_replication_url
+ send_handler = FederationSenderHandler(self)
+
+ send_handler.on_start()
+
+ while True:
+ try:
+ args = store.stream_positions()
+ args.update((yield send_handler.stream_positions()))
+ args["timeout"] = 30000
+ result = yield http_client.get_json(replication_url, args=args)
+ yield store.process_replication(result)
+ yield send_handler.process_replication(result)
+ except:
+ logger.exception("Error replicating from %r", replication_url)
+ yield sleep(30)
+
+
+def start(config_options):
+ try:
+ config = HomeServerConfig.load_config(
+ "Synapse federation sender", config_options
+ )
+ except ConfigError as e:
+ sys.stderr.write("\n" + e.message + "\n")
+ sys.exit(1)
+
+ assert config.worker_app == "synapse.app.federation_sender"
+
+ setup_logging(config.worker_log_config, config.worker_log_file)
+
+ events.USE_FROZEN_DICTS = config.use_frozen_dicts
+
+ database_engine = create_engine(config.database_config)
+
+ if config.send_federation:
+ sys.stderr.write(
+ "\nThe send_federation must be disabled in the main synapse process"
+ "\nbefore they can be run in a separate worker."
+ "\nPlease add ``send_federation: false`` to the main config"
+ "\n"
+ )
+ sys.exit(1)
+
+ # Force the pushers to start since they will be disabled in the main config
+ config.send_federation = True
+
+ tls_server_context_factory = context_factory.ServerContextFactory(config)
+
+ ps = FederationSenderServer(
+ config.server_name,
+ db_config=config.database_config,
+ tls_server_context_factory=tls_server_context_factory,
+ config=config,
+ version_string="Synapse/" + get_version_string(synapse),
+ database_engine=database_engine,
+ )
+
+ ps.setup()
+ ps.start_listening(config.worker_listeners)
+
+ def run():
+ with LoggingContext("run"):
+ logger.info("Running")
+ change_resource_limit(config.soft_file_limit)
+ if config.gc_thresholds:
+ gc.set_threshold(*config.gc_thresholds)
+ reactor.run()
+
+ def start():
+ ps.replicate()
+ ps.get_datastore().start_profiling()
+ ps.get_state_handler().start_caching()
+
+ reactor.callWhenRunning(start)
+
+ if config.worker_daemonize:
+ daemon = Daemonize(
+ app="synapse-federation-sender",
+ pid=config.worker_pid_file,
+ action=run,
+ auto_close_fds=False,
+ verbose=True,
+ logger=logger,
+ )
+ daemon.start()
+ else:
+ run()
+
+
+class FederationSenderHandler(object):
+ """Processes the replication stream and forwards the appropriate entries
+ to the federation sender.
+ """
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.federation_sender = hs.get_federation_sender()
+
+ self._room_serials = {}
+ self._room_typing = {}
+
+ def on_start(self):
+ # There may be some events that are persisted but haven't been sent,
+ # so send them now.
+ self.federation_sender.notify_new_events(
+ self.store.get_room_max_stream_ordering()
+ )
+
+ @defer.inlineCallbacks
+ def stream_positions(self):
+ stream_id = yield self.store.get_federation_out_pos("federation")
+ defer.returnValue({
+ "federation": stream_id,
+
+ # Ack stuff we've "processed", this should only be called from
+ # one process.
+ "federation_ack": stream_id,
+ })
+
+ @defer.inlineCallbacks
+ def process_replication(self, result):
+ # The federation stream contains things that we want to send out, e.g.
+ # presence, typing, etc.
+ fed_stream = result.get("federation")
+ if fed_stream:
+ latest_id = int(fed_stream["position"])
+
+ # The federation stream containis a bunch of different types of
+ # rows that need to be handled differently. We parse the rows, put
+ # them into the appropriate collection and then send them off.
+ presence_to_send = {}
+ keyed_edus = {}
+ edus = {}
+ failures = {}
+ device_destinations = set()
+
+ # Parse the rows in the stream
+ for row in fed_stream["rows"]:
+ position, typ, content_js = row
+ content = json.loads(content_js)
+
+ if typ == send_queue.PRESENCE_TYPE:
+ destination = content["destination"]
+ state = UserPresenceState.from_dict(content["state"])
+
+ presence_to_send.setdefault(destination, []).append(state)
+ elif typ == send_queue.KEYED_EDU_TYPE:
+ key = content["key"]
+ edu = Edu(**content["edu"])
+
+ keyed_edus.setdefault(
+ edu.destination, {}
+ )[(edu.destination, tuple(key))] = edu
+ elif typ == send_queue.EDU_TYPE:
+ edu = Edu(**content)
+
+ edus.setdefault(edu.destination, []).append(edu)
+ elif typ == send_queue.FAILURE_TYPE:
+ destination = content["destination"]
+ failure = content["failure"]
+
+ failures.setdefault(destination, []).append(failure)
+ elif typ == send_queue.DEVICE_MESSAGE_TYPE:
+ device_destinations.add(content["destination"])
+ else:
+ raise Exception("Unrecognised federation type: %r", typ)
+
+ # We've finished collecting, send everything off
+ for destination, states in presence_to_send.items():
+ self.federation_sender.send_presence(destination, states)
+
+ for destination, edu_map in keyed_edus.items():
+ for key, edu in edu_map.items():
+ self.federation_sender.send_edu(
+ edu.destination, edu.edu_type, edu.content, key=key,
+ )
+
+ for destination, edu_list in edus.items():
+ for edu in edu_list:
+ self.federation_sender.send_edu(
+ edu.destination, edu.edu_type, edu.content, key=None,
+ )
+
+ for destination, failure_list in failures.items():
+ for failure in failure_list:
+ self.federation_sender.send_failure(destination, failure)
+
+ for destination in device_destinations:
+ self.federation_sender.send_device_messages(destination)
+
+ # Record where we are in the stream.
+ yield self.store.update_federation_out_pos(
+ "federation", latest_id
+ )
+
+ # We also need to poke the federation sender when new events happen
+ event_stream = result.get("events")
+ if event_stream:
+ latest_pos = event_stream["position"]
+ self.federation_sender.notify_new_events(latest_pos)
+
+
+if __name__ == '__main__':
+ with LoggingContext("main"):
+ start(sys.argv[1:])
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 91471f7e..b0106a35 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -89,6 +89,9 @@ class ApplicationService(object):
self.namespaces = self._check_namespaces(namespaces)
self.id = id
+ if "|" in self.id:
+ raise Exception("application service ID cannot contain '|' character")
+
# .protocols is a publicly visible field
if protocols:
self.protocols = set(protocols)
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index b0eb0c6d..6893610e 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -19,6 +19,7 @@ from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event
from synapse.util.caches.response_cache import ResponseCache
+from synapse.types import ThirdPartyInstanceID
import logging
import urllib
@@ -177,6 +178,13 @@ class ApplicationServiceApi(SimpleHttpClient):
" valid result", uri)
defer.returnValue(None)
+ for instance in info.get("instances", []):
+ network_id = instance.get("network_id", None)
+ if network_id is not None:
+ instance["instance_id"] = ThirdPartyInstanceID(
+ service.id, network_id,
+ ).to_string()
+
defer.returnValue(info)
except Exception as ex:
logger.warning("query_3pe_protocol to %s threw exception %s",
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index dc68683f..ec72c954 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -50,6 +50,7 @@ handlers:
console:
class: logging.StreamHandler
formatter: precise
+ filters: [context]
loggers:
synapse:
diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index 1f438d2b..83762d08 100644
--- a/synapse/config/password_auth_providers.py
+++ b/synapse/config/password_auth_providers.py
@@ -27,17 +27,23 @@ class PasswordAuthProviderConfig(Config):
ldap_config = config.get("ldap_config", {})
self.ldap_enabled = ldap_config.get("enabled", False)
if self.ldap_enabled:
- from synapse.util.ldap_auth_provider import LdapAuthProvider
+ from ldap_auth_provider import LdapAuthProvider
parsed_config = LdapAuthProvider.parse_config(ldap_config)
self.password_providers.append((LdapAuthProvider, parsed_config))
providers = config.get("password_providers", [])
for provider in providers:
- # We need to import the module, and then pick the class out of
- # that, so we split based on the last dot.
- module, clz = provider['module'].rsplit(".", 1)
- module = importlib.import_module(module)
- provider_class = getattr(module, clz)
+ # This is for backwards compat when the ldap auth provider resided
+ # in this package.
+ if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider":
+ from ldap_auth_provider import LdapAuthProvider
+ provider_class = LdapAuthProvider
+ else:
+ # We need to import the module, and then pick the class out of
+ # that, so we split based on the last dot.
+ module, clz = provider['module'].rsplit(".", 1)
+ module = importlib.import_module(module)
+ provider_class = getattr(module, clz)
try:
provider_config = provider_class.parse_config(provider["config"])
@@ -50,7 +56,7 @@ class PasswordAuthProviderConfig(Config):
def default_config(self, **kwargs):
return """\
# password_providers:
- # - module: "synapse.util.ldap_auth_provider.LdapAuthProvider"
+ # - module: "ldap_auth_provider.LdapAuthProvider"
# config:
# enabled: true
# uri: "ldap://ldap.example.com:389"
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index cc3f8798..87e500c9 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -32,7 +32,6 @@ class RegistrationConfig(Config):
)
self.registration_shared_secret = config.get("registration_shared_secret")
- self.user_creation_max_duration = int(config["user_creation_max_duration"])
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
@@ -55,11 +54,6 @@ class RegistrationConfig(Config):
# secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s"
- # Sets the expiry for the short term user creation in
- # milliseconds. For instance the bellow duration is two weeks
- # in milliseconds.
- user_creation_max_duration: 1209600000
-
# Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash.
# The default number of rounds is 12.
diff --git a/synapse/config/server.py b/synapse/config/server.py
index ed5417d0..634d8e6f 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -30,6 +30,11 @@ class ServerConfig(Config):
self.use_frozen_dicts = config.get("use_frozen_dicts", False)
self.public_baseurl = config.get("public_baseurl")
+ # Whether to send federation traffic out in this process. This only
+ # applies to some federation traffic, and so shouldn't be used to
+ # "disable" federation
+ self.send_federation = config.get("send_federation", True)
+
if self.public_baseurl is not None:
if self.public_baseurl[-1] != '/':
self.public_baseurl += '/'
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 0e9fd902..5bbaef81 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -16,6 +16,17 @@
from synapse.api.constants import EventTypes
from . import EventBase
+from frozendict import frozendict
+
+import re
+
+# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
+# (?<!stuff) matches if the current position in the string is not preceded
+# by a match for 'stuff'.
+# TODO: This is fast, but fails to handle "foo\\.bar" which should be treated as
+# the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar"
+SPLIT_FIELD_REGEX = re.compile(r'(?<!\\)\.')
+
def prune_event(event):
""" Returns a pruned version of the given event, which removes all keys we
@@ -97,6 +108,83 @@ def prune_event(event):
)
+def _copy_field(src, dst, field):
+ """Copy the field in 'src' to 'dst'.
+
+ For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"]
+ then dst={"foo":{"bar":5}}.
+
+ Args:
+ src(dict): The dict to read from.
+ dst(dict): The dict to modify.
+ field(list<str>): List of keys to drill down to in 'src'.
+ """
+ if len(field) == 0: # this should be impossible
+ return
+ if len(field) == 1: # common case e.g. 'origin_server_ts'
+ if field[0] in src:
+ dst[field[0]] = src[field[0]]
+ return
+
+ # Else is a nested field e.g. 'content.body'
+ # Pop the last field as that's the key to move across and we need the
+ # parent dict in order to access the data. Drill down to the right dict.
+ key_to_move = field.pop(-1)
+ sub_dict = src
+ for sub_field in field: # e.g. sub_field => "content"
+ if sub_field in sub_dict and type(sub_dict[sub_field]) in [dict, frozendict]:
+ sub_dict = sub_dict[sub_field]
+ else:
+ return
+
+ if key_to_move not in sub_dict:
+ return
+
+ # Insert the key into the output dictionary, creating nested objects
+ # as required. We couldn't do this any earlier or else we'd need to delete
+ # the empty objects if the key didn't exist.
+ sub_out_dict = dst
+ for sub_field in field:
+ sub_out_dict = sub_out_dict.setdefault(sub_field, {})
+ sub_out_dict[key_to_move] = sub_dict[key_to_move]
+
+
+def only_fields(dictionary, fields):
+ """Return a new dict with only the fields in 'dictionary' which are present
+ in 'fields'.
+
+ If there are no event fields specified then all fields are included.
+ The entries may include '.' charaters to indicate sub-fields.
+ So ['content.body'] will include the 'body' field of the 'content' object.
+ A literal '.' character in a field name may be escaped using a '\'.
+
+ Args:
+ dictionary(dict): The dictionary to read from.
+ fields(list<str>): A list of fields to copy over. Only shallow refs are
+ taken.
+ Returns:
+ dict: A new dictionary with only the given fields. If fields was empty,
+ the same dictionary is returned.
+ """
+ if len(fields) == 0:
+ return dictionary
+
+ # for each field, convert it:
+ # ["content.body.thing\.with\.dots"] => [["content", "body", "thing\.with\.dots"]]
+ split_fields = [SPLIT_FIELD_REGEX.split(f) for f in fields]
+
+ # for each element of the output array of arrays:
+ # remove escaping so we can use the right key names.
+ split_fields[:] = [
+ [f.replace(r'\.', r'.') for f in field_array] for field_array in split_fields
+ ]
+
+ output = {}
+ for field_array in split_fields:
+ _copy_field(dictionary, output, field_array)
+ return output
+
+
def format_event_raw(d):
return d
@@ -137,7 +225,7 @@ def format_event_for_client_v2_without_room_id(d):
def serialize_event(e, time_now_ms, as_client_event=True,
event_format=format_event_for_client_v1,
- token_id=None):
+ token_id=None, only_event_fields=None):
# FIXME(erikj): To handle the case of presence events and the like
if not isinstance(e, EventBase):
return e
@@ -164,6 +252,12 @@ def serialize_event(e, time_now_ms, as_client_event=True,
d["unsigned"]["transaction_id"] = txn_id
if as_client_event:
- return event_format(d)
- else:
- return d
+ d = event_format(d)
+
+ if only_event_fields:
+ if (not isinstance(only_event_fields, list) or
+ not all(isinstance(f, basestring) for f in only_event_fields)):
+ raise TypeError("only_event_fields must be a list of strings")
+ d = only_fields(d, only_event_fields)
+
+ return d
diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py
index 979fdf24..2e32d245 100644
--- a/synapse/federation/__init__.py
+++ b/synapse/federation/__init__.py
@@ -17,10 +17,9 @@
"""
from .replication import ReplicationLayer
-from .transport.client import TransportLayerClient
-def initialize_http_replication(homeserver):
- transport = TransportLayerClient(homeserver)
+def initialize_http_replication(hs):
+ transport = hs.get_federation_transport_client()
- return ReplicationLayer(homeserver, transport)
+ return ReplicationLayer(hs, transport)
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 94e76b19..6e23c207 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -18,7 +18,6 @@ from twisted.internet import defer
from .federation_base import FederationBase
from synapse.api.constants import Membership
-from .units import Edu
from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError,
@@ -45,10 +44,6 @@ logger = logging.getLogger(__name__)
# synapse.federation.federation_client is a silly name
metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
-sent_pdus_destination_dist = metrics.register_distribution("sent_pdu_destinations")
-
-sent_edus_counter = metrics.register_counter("sent_edus")
-
sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
@@ -93,63 +88,6 @@ class FederationClient(FederationBase):
self._get_pdu_cache.start()
@log_function
- def send_pdu(self, pdu, destinations):
- """Informs the replication layer about a new PDU generated within the
- home server that should be transmitted to others.
-
- TODO: Figure out when we should actually resolve the deferred.
-
- Args:
- pdu (Pdu): The new Pdu.
-
- Returns:
- Deferred: Completes when we have successfully processed the PDU
- and replicated it to any interested remote home servers.
- """
- order = self._order
- self._order += 1
-
- sent_pdus_destination_dist.inc_by(len(destinations))
-
- logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
-
- # TODO, add errback, etc.
- self._transaction_queue.enqueue_pdu(pdu, destinations, order)
-
- logger.debug(
- "[%s] transaction_layer.enqueue_pdu... done",
- pdu.event_id
- )
-
- def send_presence(self, destination, states):
- if destination != self.server_name:
- self._transaction_queue.enqueue_presence(destination, states)
-
- @log_function
- def send_edu(self, destination, edu_type, content, key=None):
- edu = Edu(
- origin=self.server_name,
- destination=destination,
- edu_type=edu_type,
- content=content,
- )
-
- sent_edus_counter.inc()
-
- self._transaction_queue.enqueue_edu(edu, key=key)
-
- @log_function
- def send_device_messages(self, destination):
- """Sends the device messages in the local database to the remote
- destination"""
- self._transaction_queue.enqueue_device_messages(destination)
-
- @log_function
- def send_failure(self, failure, destination):
- self._transaction_queue.enqueue_failure(failure, destination)
- return defer.succeed(None)
-
- @log_function
def make_query(self, destination, query_type, args,
retry_on_dns_fail=False):
"""Sends a federation Query to a remote homeserver of the given type
@@ -717,12 +655,15 @@ class FederationClient(FederationBase):
raise RuntimeError("Failed to send to any server.")
def get_public_rooms(self, destination, limit=None, since_token=None,
- search_filter=None):
+ search_filter=None, include_all_networks=False,
+ third_party_instance_id=None):
if destination == self.server_name:
return
return self.transport_layer.get_public_rooms(
- destination, limit, since_token, search_filter
+ destination, limit, since_token, search_filter,
+ include_all_networks=include_all_networks,
+ third_party_instance_id=third_party_instance_id,
)
@defer.inlineCallbacks
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index ea66a5dc..62d865ec 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -20,8 +20,6 @@ a given transport.
from .federation_client import FederationClient
from .federation_server import FederationServer
-from .transaction_queue import TransactionQueue
-
from .persistence import TransactionActions
import logging
@@ -66,9 +64,6 @@ class ReplicationLayer(FederationClient, FederationServer):
self._clock = hs.get_clock()
self.transaction_actions = TransactionActions(self.store)
- self._transaction_queue = TransactionQueue(hs, transport_layer)
-
- self._order = 0
self.hs = hs
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
new file mode 100644
index 00000000..5c9f7a86
--- /dev/null
+++ b/synapse/federation/send_queue.py
@@ -0,0 +1,298 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A federation sender that forwards things to be sent across replication to
+a worker process.
+
+It assumes there is a single worker process feeding off of it.
+
+Each row in the replication stream consists of a type and some json, where the
+types indicate whether they are presence, or edus, etc.
+
+Ephemeral or non-event data are queued up in-memory. When the worker requests
+updates since a particular point, all in-memory data since before that point is
+dropped. We also expire things in the queue after 5 minutes, to ensure that a
+dead worker doesn't cause the queues to grow limitlessly.
+
+Events are replicated via a separate events stream.
+"""
+
+from .units import Edu
+
+from synapse.util.metrics import Measure
+import synapse.metrics
+
+from blist import sorteddict
+import ujson
+
+
+metrics = synapse.metrics.get_metrics_for(__name__)
+
+
+PRESENCE_TYPE = "p"
+KEYED_EDU_TYPE = "k"
+EDU_TYPE = "e"
+FAILURE_TYPE = "f"
+DEVICE_MESSAGE_TYPE = "d"
+
+
+class FederationRemoteSendQueue(object):
+ """A drop in replacement for TransactionQueue"""
+
+ def __init__(self, hs):
+ self.server_name = hs.hostname
+ self.clock = hs.get_clock()
+
+ self.presence_map = {}
+ self.presence_changed = sorteddict()
+
+ self.keyed_edu = {}
+ self.keyed_edu_changed = sorteddict()
+
+ self.edus = sorteddict()
+
+ self.failures = sorteddict()
+
+ self.device_messages = sorteddict()
+
+ self.pos = 1
+ self.pos_time = sorteddict()
+
+ # EVERYTHING IS SAD. In particular, python only makes new scopes when
+ # we make a new function, so we need to make a new function so the inner
+ # lambda binds to the queue rather than to the name of the queue which
+ # changes. ARGH.
+ def register(name, queue):
+ metrics.register_callback(
+ queue_name + "_size",
+ lambda: len(queue),
+ )
+
+ for queue_name in [
+ "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
+ "edus", "failures", "device_messages", "pos_time",
+ ]:
+ register(queue_name, getattr(self, queue_name))
+
+ self.clock.looping_call(self._clear_queue, 30 * 1000)
+
+ def _next_pos(self):
+ pos = self.pos
+ self.pos += 1
+ self.pos_time[self.clock.time_msec()] = pos
+ return pos
+
+ def _clear_queue(self):
+ """Clear the queues for anything older than N minutes"""
+
+ FIVE_MINUTES_AGO = 5 * 60 * 1000
+ now = self.clock.time_msec()
+
+ keys = self.pos_time.keys()
+ time = keys.bisect_left(now - FIVE_MINUTES_AGO)
+ if not keys[:time]:
+ return
+
+ position_to_delete = max(keys[:time])
+ for key in keys[:time]:
+ del self.pos_time[key]
+
+ self._clear_queue_before_pos(position_to_delete)
+
+ def _clear_queue_before_pos(self, position_to_delete):
+ """Clear all the queues from before a given position"""
+ with Measure(self.clock, "send_queue._clear"):
+ # Delete things out of presence maps
+ keys = self.presence_changed.keys()
+ i = keys.bisect_left(position_to_delete)
+ for key in keys[:i]:
+ del self.presence_changed[key]
+
+ user_ids = set(
+ user_id for uids in self.presence_changed.values() for _, user_id in uids
+ )
+
+ to_del = [
+ user_id for user_id in self.presence_map if user_id not in user_ids
+ ]
+ for user_id in to_del:
+ del self.presence_map[user_id]
+
+ # Delete things out of keyed edus
+ keys = self.keyed_edu_changed.keys()
+ i = keys.bisect_left(position_to_delete)
+ for key in keys[:i]:
+ del self.keyed_edu_changed[key]
+
+ live_keys = set()
+ for edu_key in self.keyed_edu_changed.values():
+ live_keys.add(edu_key)
+
+ to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys]
+ for edu_key in to_del:
+ del self.keyed_edu[edu_key]
+
+ # Delete things out of edu map
+ keys = self.edus.keys()
+ i = keys.bisect_left(position_to_delete)
+ for key in keys[:i]:
+ del self.edus[key]
+
+ # Delete things out of failure map
+ keys = self.failures.keys()
+ i = keys.bisect_left(position_to_delete)
+ for key in keys[:i]:
+ del self.failures[key]
+
+ # Delete things out of device map
+ keys = self.device_messages.keys()
+ i = keys.bisect_left(position_to_delete)
+ for key in keys[:i]:
+ del self.device_messages[key]
+
+ def notify_new_events(self, current_id):
+ """As per TransactionQueue"""
+ # We don't need to replicate this as it gets sent down a different
+ # stream.
+ pass
+
+ def send_edu(self, destination, edu_type, content, key=None):
+ """As per TransactionQueue"""
+ pos = self._next_pos()
+
+ edu = Edu(
+ origin=self.server_name,
+ destination=destination,
+ edu_type=edu_type,
+ content=content,
+ )
+
+ if key:
+ assert isinstance(key, tuple)
+ self.keyed_edu[(destination, key)] = edu
+ self.keyed_edu_changed[pos] = (destination, key)
+ else:
+ self.edus[pos] = edu
+
+ def send_presence(self, destination, states):
+ """As per TransactionQueue"""
+ pos = self._next_pos()
+
+ self.presence_map.update({
+ state.user_id: state
+ for state in states
+ })
+
+ self.presence_changed[pos] = [
+ (destination, state.user_id) for state in states
+ ]
+
+ def send_failure(self, failure, destination):
+ """As per TransactionQueue"""
+ pos = self._next_pos()
+
+ self.failures[pos] = (destination, str(failure))
+
+ def send_device_messages(self, destination):
+ """As per TransactionQueue"""
+ pos = self._next_pos()
+ self.device_messages[pos] = destination
+
+ def get_current_token(self):
+ return self.pos - 1
+
+ def get_replication_rows(self, token, limit, federation_ack=None):
+ """
+ Args:
+ token (int)
+ limit (int)
+ federation_ack (int): Optional. The position where the worker is
+ explicitly acknowledged it has handled. Allows us to drop
+ data from before that point
+ """
+ # TODO: Handle limit.
+
+ # To handle restarts where we wrap around
+ if token > self.pos:
+ token = -1
+
+ rows = []
+
+ # There should be only one reader, so lets delete everything its
+ # acknowledged its seen.
+ if federation_ack:
+ self._clear_queue_before_pos(federation_ack)
+
+ # Fetch changed presence
+ keys = self.presence_changed.keys()
+ i = keys.bisect_right(token)
+ dest_user_ids = set(
+ (pos, dest_user_id)
+ for pos in keys[i:]
+ for dest_user_id in self.presence_changed[pos]
+ )
+
+ for (key, (dest, user_id)) in dest_user_ids:
+ rows.append((key, PRESENCE_TYPE, ujson.dumps({
+ "destination": dest,
+ "state": self.presence_map[user_id].as_dict(),
+ })))
+
+ # Fetch changes keyed edus
+ keys = self.keyed_edu_changed.keys()
+ i = keys.bisect_right(token)
+ keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:])
+
+ for (pos, (destination, edu_key)) in keyed_edus:
+ rows.append(
+ (pos, KEYED_EDU_TYPE, ujson.dumps({
+ "key": edu_key,
+ "edu": self.keyed_edu[(destination, edu_key)].get_internal_dict(),
+ }))
+ )
+
+ # Fetch changed edus
+ keys = self.edus.keys()
+ i = keys.bisect_right(token)
+ edus = set((k, self.edus[k]) for k in keys[i:])
+
+ for (pos, edu) in edus:
+ rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_internal_dict())))
+
+ # Fetch changed failures
+ keys = self.failures.keys()
+ i = keys.bisect_right(token)
+ failures = set((k, self.failures[k]) for k in keys[i:])
+
+ for (pos, (destination, failure)) in failures:
+ rows.append((pos, FAILURE_TYPE, ujson.dumps({
+ "destination": destination,
+ "failure": failure,
+ })))
+
+ # Fetch changed device messages
+ keys = self.device_messages.keys()
+ i = keys.bisect_right(token)
+ device_messages = set((k, self.device_messages[k]) for k in keys[i:])
+
+ for (pos, destination) in device_messages:
+ rows.append((pos, DEVICE_MESSAGE_TYPE, ujson.dumps({
+ "destination": destination,
+ })))
+
+ # Sort rows based on pos
+ rows.sort()
+
+ return rows
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index f8ca93e4..51b656d7 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
from .persistence import TransactionActions
from .units import Transaction, Edu
+from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor
from synapse.util.logcontext import preserve_context_over_fn
@@ -26,6 +27,7 @@ from synapse.util.retryutils import (
get_retry_limiter, NotRetryingDestination,
)
from synapse.util.metrics import measure_func
+from synapse.types import get_domain_from_id
from synapse.handlers.presence import format_user_presence_state
import synapse.metrics
@@ -36,6 +38,12 @@ logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
+client_metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
+sent_pdus_destination_dist = client_metrics.register_distribution(
+ "sent_pdu_destinations"
+)
+sent_edus_counter = client_metrics.register_counter("sent_edus")
+
class TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at
@@ -44,13 +52,14 @@ class TransactionQueue(object):
It batches pending PDUs into single transactions.
"""
- def __init__(self, hs, transport_layer):
+ def __init__(self, hs):
self.server_name = hs.hostname
self.store = hs.get_datastore()
+ self.state = hs.get_state_handler()
self.transaction_actions = TransactionActions(self.store)
- self.transport_layer = transport_layer
+ self.transport_layer = hs.get_federation_transport_client()
self.clock = hs.get_clock()
@@ -95,6 +104,11 @@ class TransactionQueue(object):
# HACK to get unique tx id
self._next_txn_id = int(self.clock.time_msec())
+ self._order = 1
+
+ self._is_processing = False
+ self._last_poked_id = -1
+
def can_send_to(self, destination):
"""Can we send messages to the given server?
@@ -115,11 +129,61 @@ class TransactionQueue(object):
else:
return not destination.startswith("localhost")
- def enqueue_pdu(self, pdu, destinations, order):
+ @defer.inlineCallbacks
+ def notify_new_events(self, current_id):
+ """This gets called when we have some new events we might want to
+ send out to other servers.
+ """
+ self._last_poked_id = max(current_id, self._last_poked_id)
+
+ if self._is_processing:
+ return
+
+ try:
+ self._is_processing = True
+ while True:
+ last_token = yield self.store.get_federation_out_pos("events")
+ next_token, events = yield self.store.get_all_new_events_stream(
+ last_token, self._last_poked_id, limit=20,
+ )
+
+ logger.debug("Handling %s -> %s", last_token, next_token)
+
+ if not events and next_token >= self._last_poked_id:
+ break
+
+ for event in events:
+ users_in_room = yield self.state.get_current_user_in_room(
+ event.room_id, latest_event_ids=[event.event_id],
+ )
+
+ destinations = set(
+ get_domain_from_id(user_id) for user_id in users_in_room
+ )
+
+ if event.type == EventTypes.Member:
+ if event.content["membership"] == Membership.JOIN:
+ destinations.add(get_domain_from_id(event.state_key))
+
+ logger.debug("Sending %s to %r", event, destinations)
+
+ self._send_pdu(event, destinations)
+
+ yield self.store.update_federation_out_pos(
+ "events", next_token
+ )
+
+ finally:
+ self._is_processing = False
+
+ def _send_pdu(self, pdu, destinations):
# We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus
# table and we'll get back to it later.
+ order = self._order
+ self._order += 1
+
destinations = set(destinations)
destinations = set(
dest for dest in destinations if self.can_send_to(dest)
@@ -130,6 +194,8 @@ class TransactionQueue(object):
if not destinations:
return
+ sent_pdus_destination_dist.inc_by(len(destinations))
+
for destination in destinations:
self.pending_pdus_by_dest.setdefault(destination, []).append(
(pdu, order)
@@ -139,7 +205,10 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination
)
- def enqueue_presence(self, destination, states):
+ def send_presence(self, destination, states):
+ if not self.can_send_to(destination):
+ return
+
self.pending_presence_by_dest.setdefault(destination, {}).update({
state.user_id: state for state in states
})
@@ -148,12 +217,19 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination
)
- def enqueue_edu(self, edu, key=None):
- destination = edu.destination
+ def send_edu(self, destination, edu_type, content, key=None):
+ edu = Edu(
+ origin=self.server_name,
+ destination=destination,
+ edu_type=edu_type,
+ content=content,
+ )
if not self.can_send_to(destination):
return
+ sent_edus_counter.inc()
+
if key:
self.pending_edus_keyed_by_dest.setdefault(
destination, {}
@@ -165,7 +241,7 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination
)
- def enqueue_failure(self, failure, destination):
+ def send_failure(self, failure, destination):
if destination == self.server_name or destination == "localhost":
return
@@ -180,7 +256,7 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination
)
- def enqueue_device_messages(self, destination):
+ def send_device_messages(self, destination):
if destination == self.server_name or destination == "localhost":
return
@@ -191,6 +267,9 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination
)
+ def get_current_token(self):
+ return 0
+
@defer.inlineCallbacks
def _attempt_new_transaction(self, destination):
# list of (pending_pdu, deferred, order)
@@ -383,6 +462,13 @@ class TransactionQueue(object):
code = e.code
response = e.response
+ if e.code == 429 or 500 <= e.code:
+ logger.info(
+ "TX [%s] {%s} got %d response",
+ destination, txn_id, code
+ )
+ raise e
+
logger.info(
"TX [%s] {%s} got %d response",
destination, txn_id, code
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index db45c782..491cdc29 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -249,10 +249,15 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def get_public_rooms(self, remote_server, limit, since_token,
- search_filter=None):
+ search_filter=None, include_all_networks=False,
+ third_party_instance_id=None):
path = PREFIX + "/publicRooms"
- args = {}
+ args = {
+ "include_all_networks": "true" if include_all_networks else "false",
+ }
+ if third_party_instance_id:
+ args["third_party_instance_id"] = third_party_instance_id,
if limit:
args["limit"] = [str(limit)]
if since_token:
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index fec337be..159dbd17 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -20,9 +20,11 @@ from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource
from synapse.http.servlet import (
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
+ parse_boolean_from_args,
)
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string
+from synapse.types import ThirdPartyInstanceID
import functools
import logging
@@ -558,8 +560,23 @@ class PublicRoomList(BaseFederationServlet):
def on_GET(self, origin, content, query):
limit = parse_integer_from_args(query, "limit", 0)
since_token = parse_string_from_args(query, "since", None)
+ include_all_networks = parse_boolean_from_args(
+ query, "include_all_networks", False
+ )
+ third_party_instance_id = parse_string_from_args(
+ query, "third_party_instance_id", None
+ )
+
+ if include_all_networks:
+ network_tuple = None
+ elif third_party_instance_id:
+ network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
+ else:
+ network_tuple = ThirdPartyInstanceID(None, None)
+
data = yield self.room_list_handler.get_local_public_room_list(
- limit, since_token
+ limit, since_token,
+ network_tuple=network_tuple
)
defer.returnValue((200, data))
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index 63d05f25..5ad408f5 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -24,7 +24,6 @@ from .profile import ProfileHandler
from .directory import DirectoryHandler
from .admin import AdminHandler
from .identity import IdentityHandler
-from .receipts import ReceiptsHandler
from .search import SearchHandler
@@ -56,7 +55,6 @@ class Handlers(object):
self.profile_handler = ProfileHandler(hs)
self.directory_handler = DirectoryHandler(hs)
self.admin_handler = AdminHandler(hs)
- self.receipts_handler = ReceiptsHandler(hs)
self.identity_handler = IdentityHandler(hs)
self.search_handler = SearchHandler(hs)
self.room_context_handler = RoomContextHandler(hs)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 3851b358..3b146f09 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -61,6 +61,8 @@ class AuthHandler(BaseHandler):
for module, config in hs.config.password_providers
]
+ logger.info("Extra password_providers: %r", self.password_providers)
+
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler()
@@ -160,7 +162,15 @@ class AuthHandler(BaseHandler):
for f in flows:
if len(set(f) - set(creds.keys())) == 0:
- logger.info("Auth completed with creds: %r", creds)
+ # it's very useful to know what args are stored, but this can
+ # include the password in the case of registering, so only log
+ # the keys (confusingly, clientdict may contain a password
+ # param, creds is just what the user authed as for UI auth
+ # and is not sensitive).
+ logger.info(
+ "Auth completed with creds: %r. Client dict has keys: %r",
+ creds, clientdict.keys()
+ )
defer.returnValue((True, creds, clientdict, session['id']))
ret = self._auth_dict_for_flows(flows, session)
@@ -378,12 +388,10 @@ class AuthHandler(BaseHandler):
return self._check_password(user_id, password)
@defer.inlineCallbacks
- def get_login_tuple_for_user_id(self, user_id, device_id=None,
- initial_display_name=None):
+ def get_access_token_for_user_id(self, user_id, device_id=None,
+ initial_display_name=None):
"""
- Gets login tuple for the user with the given user ID.
-
- Creates a new access/refresh token for the user.
+ Creates a new access token for the user with the given user ID.
The user is assumed to have been authenticated by some other
machanism (e.g. CAS), and the user_id converted to the canonical case.
@@ -398,16 +406,13 @@ class AuthHandler(BaseHandler):
initial_display_name (str): display name to associate with the
device if it needs re-registering
Returns:
- A tuple of:
The access token for the user's session.
- The refresh token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id)
- refresh_token = yield self.issue_refresh_token(user_id, device_id)
# the device *should* have been registered before we got here; however,
# it's possible we raced against a DELETE operation. The thing we
@@ -418,7 +423,7 @@ class AuthHandler(BaseHandler):
user_id, device_id, initial_display_name
)
- defer.returnValue((access_token, refresh_token))
+ defer.returnValue(access_token)
@defer.inlineCallbacks
def check_user_exists(self, user_id):
@@ -529,35 +534,19 @@ class AuthHandler(BaseHandler):
device_id)
defer.returnValue(access_token)
- @defer.inlineCallbacks
- def issue_refresh_token(self, user_id, device_id=None):
- refresh_token = self.generate_refresh_token(user_id)
- yield self.store.add_refresh_token_to_user(user_id, refresh_token,
- device_id)
- defer.returnValue(refresh_token)
-
- def generate_access_token(self, user_id, extra_caveats=None,
- duration_in_ms=(60 * 60 * 1000)):
+ def generate_access_token(self, user_id, extra_caveats=None):
extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access")
- now = self.hs.get_clock().time_msec()
- expiry = now + duration_in_ms
- macaroon.add_first_party_caveat("time < %d" % (expiry,))
+ # Include a nonce, to make sure that each login gets a different
+ # access token.
+ macaroon.add_first_party_caveat("nonce = %s" % (
+ stringutils.random_string_with_symbols(16),
+ ))
for caveat in extra_caveats:
macaroon.add_first_party_caveat(caveat)
return macaroon.serialize()
- def generate_refresh_token(self, user_id):
- m = self._generate_base_macaroon(user_id)
- m.add_first_party_caveat("type = refresh")
- # Important to add a nonce, because otherwise every refresh token for a
- # user will be the same.
- m.add_first_party_caveat("nonce = %s" % (
- stringutils.random_string_with_symbols(16),
- ))
- return m.serialize()
-
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index c5368e5d..f7fad15c 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -34,9 +34,9 @@ class DeviceMessageHandler(object):
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
- self.federation = hs.get_replication_layer()
+ self.federation = hs.get_federation_sender()
- self.federation.register_edu_handler(
+ hs.get_replication_layer().register_edu_handler(
"m.direct_to_device", self.on_direct_to_device_edu
)
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index c00274af..1b5317ed 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -339,3 +339,22 @@ class DirectoryHandler(BaseHandler):
yield self.auth.check_can_change_room_list(room_id, requester.user)
yield self.store.set_room_is_public(room_id, visibility == "public")
+
+ @defer.inlineCallbacks
+ def edit_published_appservice_room_list(self, appservice_id, network_id,
+ room_id, visibility):
+ """Add or remove a room from the appservice/network specific public
+ room list.
+
+ Args:
+ appservice_id (str): ID of the appservice that owns the list
+ network_id (str): The ID of the network the list is associated with
+ room_id (str)
+ visibility (str): either "public" or "private"
+ """
+ if visibility not in ["public", "private"]:
+ raise SynapseError(400, "Invalid visibility setting")
+
+ yield self.store.set_room_is_public_appservice(
+ room_id, appservice_id, network_id, visibility == "public"
+ )
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index fd11935b..b63a660c 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -111,6 +111,11 @@ class E2eKeysHandler(object):
failures[destination] = {
"status": 503, "message": "Not ready for retry",
}
+ except Exception as e:
+ # include ConnectionRefused and other errors
+ failures[destination] = {
+ "status": 503, "message": e.message
+ }
yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(do_remote_query)(destination)
@@ -222,6 +227,11 @@ class E2eKeysHandler(object):
failures[destination] = {
"status": 503, "message": "Not ready for retry",
}
+ except Exception as e:
+ # include ConnectionRefused and other errors
+ failures[destination] = {
+ "status": 503, "message": e.message
+ }
yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(claim_client_keys)(destination)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 2d801bad..1d07e4d0 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -80,22 +80,6 @@ class FederationHandler(BaseHandler):
# When joining a room we need to queue any events for that room up
self.room_queues = {}
- def handle_new_event(self, event, destinations):
- """ Takes in an event from the client to server side, that has already
- been authed and handled by the state module, and sends it to any
- remote home servers that may be interested.
-
- Args:
- event: The event to send
- destinations: A list of destinations to send it to
-
- Returns:
- Deferred: Resolved when it has successfully been queued for
- processing.
- """
-
- return self.replication_layer.send_pdu(event, destinations)
-
@log_function
@defer.inlineCallbacks
def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
@@ -268,9 +252,12 @@ class FederationHandler(BaseHandler):
except:
return False
+ # Parses mapping `event_id -> (type, state_key) -> state event_id`
+ # to get all state ids that we're interested in.
event_map = yield self.store.get_events([
- e_id for key_to_eid in event_to_state_ids.values()
- for key, e_id in key_to_eid
+ e_id
+ for key_to_eid in event_to_state_ids.values()
+ for key, e_id in key_to_eid.items()
if key[0] != EventTypes.Member or check_match(key[1])
])
@@ -830,25 +817,6 @@ class FederationHandler(BaseHandler):
user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id)
- new_pdu = event
-
- users_in_room = yield self.store.get_joined_users_from_context(event, context)
-
- destinations = set(
- get_domain_from_id(user_id) for user_id in users_in_room
- if not self.hs.is_mine_id(user_id)
- )
-
- destinations.discard(origin)
-
- logger.debug(
- "on_send_join_request: Sending event: %s, signatures: %s",
- event.event_id,
- event.signatures,
- )
-
- self.replication_layer.send_pdu(new_pdu, destinations)
-
state_ids = context.prev_state_ids.values()
auth_chain = yield self.store.get_auth_chain(set(
[event.event_id] + state_ids
@@ -1055,24 +1023,6 @@ class FederationHandler(BaseHandler):
event, event_stream_id, max_stream_id, extra_users=extra_users
)
- new_pdu = event
-
- users_in_room = yield self.store.get_joined_users_from_context(event, context)
-
- destinations = set(
- get_domain_from_id(user_id) for user_id in users_in_room
- if not self.hs.is_mine_id(user_id)
- )
- destinations.discard(origin)
-
- logger.debug(
- "on_send_leave_request: Sending event: %s, signatures: %s",
- event.event_id,
- event.signatures,
- )
-
- self.replication_layer.send_pdu(new_pdu, destinations)
-
defer.returnValue(None)
@defer.inlineCallbacks
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index fbfa5a02..e0ade4c1 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -372,11 +372,12 @@ class InitialSyncHandler(BaseHandler):
@defer.inlineCallbacks
def get_receipts():
- receipts_handler = self.hs.get_handlers().receipts_handler
- receipts = yield receipts_handler.get_receipts_for_room(
+ receipts = yield self.store.get_linearized_receipts_for_room(
room_id,
- now_token.receipt_key
+ to_key=now_token.receipt_key,
)
+ if not receipts:
+ receipts = []
defer.returnValue(receipts)
presence, receipts, (messages, token) = yield defer.gatherResults(
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 81df4517..7a57a69b 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -22,9 +22,9 @@ from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.push.action_generator import ActionGenerator
from synapse.types import (
- UserID, RoomAlias, RoomStreamToken, get_domain_from_id
+ UserID, RoomAlias, RoomStreamToken,
)
-from synapse.util.async import run_on_reactor, ReadWriteLock
+from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
from synapse.util.logcontext import preserve_fn
from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client
@@ -50,6 +50,10 @@ class MessageHandler(BaseHandler):
self.pagination_lock = ReadWriteLock()
+ # We arbitrarily limit concurrent event creation for a room to 5.
+ # This is to stop us from diverging history *too* much.
+ self.limiter = Limiter(max_count=5)
+
@defer.inlineCallbacks
def purge_history(self, room_id, event_id):
event = yield self.store.get_event(event_id)
@@ -191,36 +195,38 @@ class MessageHandler(BaseHandler):
"""
builder = self.event_builder_factory.new(event_dict)
- self.validator.validate_new(builder)
-
- if builder.type == EventTypes.Member:
- membership = builder.content.get("membership", None)
- target = UserID.from_string(builder.state_key)
-
- if membership in {Membership.JOIN, Membership.INVITE}:
- # If event doesn't include a display name, add one.
- profile = self.hs.get_handlers().profile_handler
- content = builder.content
+ with (yield self.limiter.queue(builder.room_id)):
+ self.validator.validate_new(builder)
+
+ if builder.type == EventTypes.Member:
+ membership = builder.content.get("membership", None)
+ target = UserID.from_string(builder.state_key)
+
+ if membership in {Membership.JOIN, Membership.INVITE}:
+ # If event doesn't include a display name, add one.
+ profile = self.hs.get_handlers().profile_handler
+ content = builder.content
+
+ try:
+ content["displayname"] = yield profile.get_displayname(target)
+ content["avatar_url"] = yield profile.get_avatar_url(target)
+ except Exception as e:
+ logger.info(
+ "Failed to get profile information for %r: %s",
+ target, e
+ )
- try:
- content["displayname"] = yield profile.get_displayname(target)
- content["avatar_url"] = yield profile.get_avatar_url(target)
- except Exception as e:
- logger.info(
- "Failed to get profile information for %r: %s",
- target, e
- )
+ if token_id is not None:
+ builder.internal_metadata.token_id = token_id
- if token_id is not None:
- builder.internal_metadata.token_id = token_id
+ if txn_id is not None:
+ builder.internal_metadata.txn_id = txn_id
- if txn_id is not None:
- builder.internal_metadata.txn_id = txn_id
+ event, context = yield self._create_new_client_event(
+ builder=builder,
+ prev_event_ids=prev_event_ids,
+ )
- event, context = yield self._create_new_client_event(
- builder=builder,
- prev_event_ids=prev_event_ids,
- )
defer.returnValue((event, context))
@defer.inlineCallbacks
@@ -599,13 +605,6 @@ class MessageHandler(BaseHandler):
event_stream_id, max_stream_id
)
- users_in_room = yield self.store.get_joined_users_from_context(event, context)
-
- destinations = [
- get_domain_from_id(user_id) for user_id in users_in_room
- if not self.hs.is_mine_id(user_id)
- ]
-
@defer.inlineCallbacks
def _notify():
yield run_on_reactor()
@@ -618,7 +617,3 @@ class MessageHandler(BaseHandler):
# If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None)
-
- preserve_fn(federation_handler.handle_new_event)(
- event, destinations=destinations,
- )
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index b047ae22..1b89dc62 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -91,28 +91,29 @@ class PresenceHandler(object):
self.store = hs.get_datastore()
self.wheel_timer = WheelTimer()
self.notifier = hs.get_notifier()
- self.federation = hs.get_replication_layer()
+ self.replication = hs.get_replication_layer()
+ self.federation = hs.get_federation_sender()
self.state = hs.get_state_handler()
- self.federation.register_edu_handler(
+ self.replication.register_edu_handler(
"m.presence", self.incoming_presence
)
- self.federation.register_edu_handler(
+ self.replication.register_edu_handler(
"m.presence_invite",
lambda origin, content: self.invite_presence(
observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]),
)
)
- self.federation.register_edu_handler(
+ self.replication.register_edu_handler(
"m.presence_accept",
lambda origin, content: self.accept_presence(
observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]),
)
)
- self.federation.register_edu_handler(
+ self.replication.register_edu_handler(
"m.presence_deny",
lambda origin, content: self.deny_presence(
observed_user=UserID.from_string(content["observed_user"]),
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index e536a909..50aa5139 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -33,8 +33,8 @@ class ReceiptsHandler(BaseHandler):
self.server_name = hs.config.server_name
self.store = hs.get_datastore()
self.hs = hs
- self.federation = hs.get_replication_layer()
- self.federation.register_edu_handler(
+ self.federation = hs.get_federation_sender()
+ hs.get_replication_layer().register_edu_handler(
"m.receipt", self._received_remote_receipt
)
self.clock = self.hs.get_clock()
@@ -100,7 +100,7 @@ class ReceiptsHandler(BaseHandler):
if not res:
# res will be None if this read receipt is 'old'
- defer.returnValue(False)
+ continue
stream_id, max_persisted_id = res
@@ -109,6 +109,10 @@ class ReceiptsHandler(BaseHandler):
if max_batch_id is None or max_persisted_id > max_batch_id:
max_batch_id = max_persisted_id
+ if min_batch_id is None:
+ # no new receipts
+ defer.returnValue(False)
+
affected_room_ids = list(set([r["room_id"] for r in receipts]))
with PreserveLoggingContext():
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 7e119f13..286f0cef 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -81,7 +81,7 @@ class RegistrationHandler(BaseHandler):
"User ID already taken.",
errcode=Codes.USER_IN_USE,
)
- user_data = yield self.auth.get_user_from_macaroon(guest_access_token)
+ user_data = yield self.auth.get_user_by_access_token(guest_access_token)
if not user_data["is_guest"] or user_data["user"].localpart != localpart:
raise AuthError(
403,
@@ -369,7 +369,7 @@ class RegistrationHandler(BaseHandler):
defer.returnValue(data)
@defer.inlineCallbacks
- def get_or_create_user(self, requester, localpart, displayname, duration_in_ms,
+ def get_or_create_user(self, requester, localpart, displayname,
password_hash=None):
"""Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one.
@@ -399,8 +399,7 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
- token = self.auth_handler().generate_access_token(
- user_id, None, duration_in_ms)
+ token = self.auth_handler().generate_access_token(user_id)
if need_register:
yield self.store.register(
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 59e4d1cd..5f18007e 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -44,16 +44,19 @@ class RoomCreationHandler(BaseHandler):
"join_rules": JoinRules.INVITE,
"history_visibility": "shared",
"original_invitees_have_ops": False,
+ "guest_can_join": True,
},
RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE,
"history_visibility": "shared",
"original_invitees_have_ops": True,
+ "guest_can_join": True,
},
RoomCreationPreset.PUBLIC_CHAT: {
"join_rules": JoinRules.PUBLIC,
"history_visibility": "shared",
"original_invitees_have_ops": False,
+ "guest_can_join": False,
},
}
@@ -336,6 +339,13 @@ class RoomCreationHandler(BaseHandler):
content={"history_visibility": config["history_visibility"]}
)
+ if config["guest_can_join"]:
+ if (EventTypes.GuestAccess, '') not in initial_state:
+ yield send(
+ etype=EventTypes.GuestAccess,
+ content={"guest_access": "can_join"}
+ )
+
for (etype, state_key), content in initial_state.items():
yield send(
etype=etype,
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index b04aea01..667223df 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -22,6 +22,7 @@ from synapse.api.constants import (
)
from synapse.util.async import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
+from synapse.types import ThirdPartyInstanceID
from collections import namedtuple
from unpaddedbase64 import encode_base64, decode_base64
@@ -34,6 +35,10 @@ logger = logging.getLogger(__name__)
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
+# This is used to indicate we should only return rooms published to the main list.
+EMTPY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
+
+
class RoomListHandler(BaseHandler):
def __init__(self, hs):
super(RoomListHandler, self).__init__(hs)
@@ -41,22 +46,43 @@ class RoomListHandler(BaseHandler):
self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000)
def get_local_public_room_list(self, limit=None, since_token=None,
- search_filter=None):
- if search_filter:
- # We explicitly don't bother caching searches.
- return self._get_public_room_list(limit, since_token, search_filter)
+ search_filter=None,
+ network_tuple=EMTPY_THIRD_PARTY_ID,):
+ """Generate a local public room list.
+
+ There are multiple different lists: the main one plus one per third
+ party network. A client can ask for a specific list or to return all.
+
+ Args:
+ limit (int)
+ since_token (str)
+ search_filter (dict)
+ network_tuple (ThirdPartyInstanceID): Which public list to use.
+ This can be (None, None) to indicate the main list, or a particular
+ appservice and network id to use an appservice specific one.
+ Setting to None returns all public rooms across all lists.
+ """
+ if search_filter or (network_tuple and network_tuple.appservice_id is not None):
+ # We explicitly don't bother caching searches or requests for
+ # appservice specific lists.
+ return self._get_public_room_list(
+ limit, since_token, search_filter, network_tuple=network_tuple,
+ )
result = self.response_cache.get((limit, since_token))
if not result:
result = self.response_cache.set(
(limit, since_token),
- self._get_public_room_list(limit, since_token)
+ self._get_public_room_list(
+ limit, since_token, network_tuple=network_tuple
+ )
)
return result
@defer.inlineCallbacks
def _get_public_room_list(self, limit=None, since_token=None,
- search_filter=None):
+ search_filter=None,
+ network_tuple=EMTPY_THIRD_PARTY_ID,):
if since_token and since_token != "END":
since_token = RoomListNextBatch.from_token(since_token)
else:
@@ -73,14 +99,15 @@ class RoomListHandler(BaseHandler):
current_public_id = yield self.store.get_current_public_room_stream_id()
public_room_stream_id = since_token.public_room_stream_id
newly_visible, newly_unpublished = yield self.store.get_public_room_changes(
- public_room_stream_id, current_public_id
+ public_room_stream_id, current_public_id,
+ network_tuple=network_tuple,
)
else:
stream_token = yield self.store.get_room_max_stream_ordering()
public_room_stream_id = yield self.store.get_current_public_room_stream_id()
room_ids = yield self.store.get_public_room_ids_at_stream_id(
- public_room_stream_id
+ public_room_stream_id, network_tuple=network_tuple,
)
# We want to return rooms in a particular order: the number of joined
@@ -311,7 +338,8 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks
def get_remote_public_room_list(self, server_name, limit=None, since_token=None,
- search_filter=None):
+ search_filter=None, include_all_networks=False,
+ third_party_instance_id=None,):
if search_filter:
# We currently don't support searching across federation, so we have
# to do it manually without pagination
@@ -320,6 +348,8 @@ class RoomListHandler(BaseHandler):
res = yield self._get_remote_list_cached(
server_name, limit=limit, since_token=since_token,
+ include_all_networks=include_all_networks,
+ third_party_instance_id=third_party_instance_id,
)
if search_filter:
@@ -332,22 +362,30 @@ class RoomListHandler(BaseHandler):
defer.returnValue(res)
def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
- search_filter=None):
+ search_filter=None, include_all_networks=False,
+ third_party_instance_id=None,):
repl_layer = self.hs.get_replication_layer()
if search_filter:
# We can't cache when asking for search
return repl_layer.get_public_rooms(
server_name, limit=limit, since_token=since_token,
- search_filter=search_filter,
+ search_filter=search_filter, include_all_networks=include_all_networks,
+ third_party_instance_id=third_party_instance_id,
)
- result = self.remote_response_cache.get((server_name, limit, since_token))
+ key = (
+ server_name, limit, since_token, include_all_networks,
+ third_party_instance_id,
+ )
+ result = self.remote_response_cache.get(key)
if not result:
result = self.remote_response_cache.set(
- (server_name, limit, since_token),
+ key,
repl_layer.get_public_rooms(
server_name, limit=limit, since_token=since_token,
search_filter=search_filter,
+ include_all_networks=include_all_networks,
+ third_party_instance_id=third_party_instance_id,
)
)
return result
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 1f910ff8..c880f616 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -277,6 +277,7 @@ class SyncHandler(object):
"""
with Measure(self.clock, "load_filtered_recents"):
timeline_limit = sync_config.filter_collection.timeline_limit()
+ block_all_timeline = sync_config.filter_collection.blocks_all_room_timeline()
if recents is None or newly_joined_room or timeline_limit < len(recents):
limited = True
@@ -293,7 +294,7 @@ class SyncHandler(object):
else:
recents = []
- if not limited:
+ if not limited or block_all_timeline:
defer.returnValue(TimelineBatch(
events=recents,
prev_batch=now_token,
@@ -509,6 +510,7 @@ class SyncHandler(object):
Returns:
Deferred(SyncResult)
"""
+ logger.info("Calculating sync response for %r", sync_config.user)
# NB: The now_token gets changed by some of the generate_sync_* methods,
# this is due to some of the underlying streams not supporting the ability
@@ -531,9 +533,14 @@ class SyncHandler(object):
)
newly_joined_rooms, newly_joined_users = res
- yield self._generate_sync_entry_for_presence(
- sync_result_builder, newly_joined_rooms, newly_joined_users
+ block_all_presence_data = (
+ since_token is None and
+ sync_config.filter_collection.blocks_all_presence()
)
+ if not block_all_presence_data:
+ yield self._generate_sync_entry_for_presence(
+ sync_result_builder, newly_joined_rooms, newly_joined_users
+ )
yield self._generate_sync_entry_for_to_device(sync_result_builder)
@@ -569,16 +576,20 @@ class SyncHandler(object):
# We only delete messages when a new message comes in, but that's
# fine so long as we delete them at some point.
- logger.debug("Deleting messages up to %d", since_stream_id)
- yield self.store.delete_messages_for_device(
+ deleted = yield self.store.delete_messages_for_device(
user_id, device_id, since_stream_id
)
+ logger.info("Deleted %d to-device messages up to %d",
+ deleted, since_stream_id)
- logger.debug("Getting messages up to %d", now_token.to_device_key)
messages, stream_id = yield self.store.get_new_messages_for_device(
user_id, device_id, since_stream_id, now_token.to_device_key
)
- logger.debug("Got messages up to %d: %r", stream_id, messages)
+
+ logger.info(
+ "Returning %d to-device messages between %d and %d (current token: %d)",
+ len(messages), since_stream_id, stream_id, now_token.to_device_key
+ )
sync_result_builder.now_token = now_token.copy_and_replace(
"to_device_key", stream_id
)
@@ -709,13 +720,20 @@ class SyncHandler(object):
`(newly_joined_rooms, newly_joined_users)`
"""
user_id = sync_result_builder.sync_config.user.to_string()
-
- now_token, ephemeral_by_room = yield self.ephemeral_by_room(
- sync_result_builder.sync_config,
- now_token=sync_result_builder.now_token,
- since_token=sync_result_builder.since_token,
+ block_all_room_ephemeral = (
+ sync_result_builder.since_token is None and
+ sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral()
)
- sync_result_builder.now_token = now_token
+
+ if block_all_room_ephemeral:
+ ephemeral_by_room = {}
+ else:
+ now_token, ephemeral_by_room = yield self.ephemeral_by_room(
+ sync_result_builder.sync_config,
+ now_token=sync_result_builder.now_token,
+ since_token=sync_result_builder.since_token,
+ )
+ sync_result_builder.now_token = now_token
ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
"m.ignored_user_list", user_id=user_id,
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 27ee715f..0eea7f8f 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -55,9 +55,9 @@ class TypingHandler(object):
self.clock = hs.get_clock()
self.wheel_timer = WheelTimer(bucket_size=5000)
- self.federation = hs.get_replication_layer()
+ self.federation = hs.get_federation_sender()
- self.federation.register_edu_handler("m.typing", self._recv_edu)
+ hs.get_replication_layer().register_edu_handler("m.typing", self._recv_edu)
hs.get_distributor().observe("user_left_room", self.user_left_room)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index d0556ae3..d5970c05 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -33,6 +33,7 @@ from synapse.api.errors import (
from signedjson.sign import sign_json
+import cgi
import simplejson as json
import logging
import random
@@ -292,12 +293,7 @@ class MatrixFederationHttpClient(object):
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
- c_type = response.headers.getRawHeaders("Content-Type")
-
- if "application/json" not in c_type:
- raise RuntimeError(
- "Content-Type not application/json"
- )
+ check_content_type_is_json(response.headers)
body = yield preserve_context_over_fn(readBody, response)
defer.returnValue(json.loads(body))
@@ -342,12 +338,7 @@ class MatrixFederationHttpClient(object):
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
- c_type = response.headers.getRawHeaders("Content-Type")
-
- if "application/json" not in c_type:
- raise RuntimeError(
- "Content-Type not application/json"
- )
+ check_content_type_is_json(response.headers)
body = yield preserve_context_over_fn(readBody, response)
@@ -400,12 +391,7 @@ class MatrixFederationHttpClient(object):
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
- c_type = response.headers.getRawHeaders("Content-Type")
-
- if "application/json" not in c_type:
- raise RuntimeError(
- "Content-Type not application/json"
- )
+ check_content_type_is_json(response.headers)
body = yield preserve_context_over_fn(readBody, response)
@@ -525,3 +511,29 @@ def _flatten_response_never_received(e):
)
else:
return "%s: %s" % (type(e).__name__, e.message,)
+
+
+def check_content_type_is_json(headers):
+ """
+ Check that a set of HTTP headers have a Content-Type header, and that it
+ is application/json.
+
+ Args:
+ headers (twisted.web.http_headers.Headers): headers to check
+
+ Raises:
+ RuntimeError if the
+
+ """
+ c_type = headers.getRawHeaders("Content-Type")
+ if c_type is None:
+ raise RuntimeError(
+ "No Content-Type header"
+ )
+
+ c_type = c_type[0] # only the first header
+ val, options = cgi.parse_header(c_type)
+ if val != "application/json":
+ raise RuntimeError(
+ "Content-Type not application/json: was '%s'" % c_type
+ )
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 93463862..8c22d6f0 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -78,12 +78,16 @@ def parse_boolean(request, name, default=None, required=False):
parameter is present and not one of "true" or "false".
"""
- if name in request.args:
+ return parse_boolean_from_args(request.args, name, default, required)
+
+
+def parse_boolean_from_args(args, name, default=None, required=False):
+ if name in args:
try:
return {
"true": True,
"false": False,
- }[request.args[name][0]]
+ }[args[name][0]]
except:
message = (
"Boolean query parameter %r must be one of"
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 48653ae8..acbd4bb5 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError
+from synapse.util import DeferredTimedOutError
from synapse.util.logutils import log_function
from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
@@ -143,6 +144,12 @@ class Notifier(object):
self.clock = hs.get_clock()
self.appservice_handler = hs.get_application_service_handler()
+
+ if hs.should_send_federation():
+ self.federation_sender = hs.get_federation_sender()
+ else:
+ self.federation_sender = None
+
self.state_handler = hs.get_state_handler()
self.clock.looping_call(
@@ -220,6 +227,9 @@ class Notifier(object):
# poke any interested application service.
self.appservice_handler.notify_interested_services(room_stream_id)
+ if self.federation_sender:
+ self.federation_sender.notify_new_events(room_stream_id)
+
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
self._user_joined_room(event.state_key, event.room_id)
@@ -285,14 +295,7 @@ class Notifier(object):
result = None
if timeout:
- # Will be set to a _NotificationListener that we'll be waiting on.
- # Allows us to cancel it.
- listener = None
-
- def timed_out():
- if listener:
- listener.deferred.cancel()
- timer = self.clock.call_later(timeout / 1000., timed_out)
+ end_time = self.clock.time_msec() + timeout
prev_token = from_token
while not result:
@@ -303,6 +306,10 @@ class Notifier(object):
if result:
break
+ now = self.clock.time_msec()
+ if end_time <= now:
+ break
+
# Now we wait for the _NotifierUserStream to be told there
# is a new token.
# We need to supply the token we supplied to callback so
@@ -310,11 +317,14 @@ class Notifier(object):
prev_token = current_token
listener = user_stream.new_listener(prev_token)
with PreserveLoggingContext():
- yield listener.deferred
+ yield self.clock.time_bound_deferred(
+ listener.deferred,
+ time_out=(end_time - now) / 1000.
+ )
+ except DeferredTimedOutError:
+ break
except defer.CancelledError:
break
-
- self.clock.cancel_call_later(timer, ignore_errs=True)
else:
current_token = user_stream.current_token
result = yield callback(from_token, current_token)
@@ -483,22 +493,27 @@ class Notifier(object):
"""
listener = _NotificationListener(None)
- def timed_out():
- listener.deferred.cancel()
+ end_time = self.clock.time_msec() + timeout
- timer = self.clock.call_later(timeout / 1000., timed_out)
while True:
listener.deferred = self.replication_deferred.observe()
result = yield callback()
if result:
break
+ now = self.clock.time_msec()
+ if end_time <= now:
+ break
+
try:
with PreserveLoggingContext():
- yield listener.deferred
+ yield self.clock.time_bound_deferred(
+ listener.deferred,
+ time_out=(end_time - now) / 1000.
+ )
+ except DeferredTimedOutError:
+ break
except defer.CancelledError:
break
- self.clock.cancel_call_later(timer, ignore_errs=True)
-
defer.returnValue(result)
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index be55598c..78b095c9 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -87,12 +87,12 @@ class BulkPushRuleEvaluator:
condition_cache = {}
for uid, rules in self.rules_by_user.items():
- display_name = None
- member_ev_id = context.current_state_ids.get((EventTypes.Member, uid))
- if member_ev_id:
- member_ev = yield self.store.get_event(member_ev_id, allow_none=True)
- if member_ev:
- display_name = member_ev.content.get("displayname", None)
+ display_name = room_members.get(uid, {}).get("display_name", None)
+ if not display_name:
+ # Handle the case where we are pushing a membership event to
+ # that user, as they might not be already joined.
+ if event.type == EventTypes.Member and event.state_key == uid:
+ display_name = event.content.get("displayname", None)
filtered = filtered_by_user[uid]
if len(filtered) == 0:
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index b9e41770..3742a25b 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -49,8 +49,8 @@ CONDITIONAL_REQUIREMENTS = {
"Jinja2>=2.8": ["Jinja2>=2.8"],
"bleach>=1.4.2": ["bleach>=1.4.2"],
},
- "ldap": {
- "ldap3>=1.0": ["ldap3>=1.0"],
+ "matrix-synapse-ldap3": {
+ "matrix-synapse-ldap3>=0.1": ["ldap_auth_provider"],
},
"psutil": {
"psutil>=2.0.0": ["psutil>=2.0.0"],
@@ -69,6 +69,7 @@ def requirements(config=None, include_conditional=False):
def github_link(project, version, egg):
return "https://github.com/%s/tarball/%s/#egg=%s" % (project, version, egg)
+
DEPENDENCY_LINKS = {
}
@@ -156,6 +157,7 @@ def list_requirements():
result.append(requirement)
return result
+
if __name__ == "__main__":
import sys
sys.stdout.writelines(req + "\n" for req in list_requirements())
diff --git a/synapse/replication/expire_cache.py b/synapse/replication/expire_cache.py
new file mode 100644
index 00000000..c05a50d7
--- /dev/null
+++ b/synapse/replication/expire_cache.py
@@ -0,0 +1,60 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.http.server import respond_with_json_bytes, request_handler
+from synapse.http.servlet import parse_json_object_from_request
+
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+
+
+class ExpireCacheResource(Resource):
+ """
+ HTTP endpoint for expiring storage caches.
+
+ POST /_synapse/replication/expire_cache HTTP/1.1
+ Content-Type: application/json
+
+ {
+ "invalidate": [
+ {
+ "name": "func_name",
+ "keys": ["key1", "key2"]
+ }
+ ]
+ }
+ """
+
+ def __init__(self, hs):
+ Resource.__init__(self) # Resource is old-style, so no super()
+
+ self.store = hs.get_datastore()
+ self.version_string = hs.version_string
+ self.clock = hs.get_clock()
+
+ def render_POST(self, request):
+ self._async_render_POST(request)
+ return NOT_DONE_YET
+
+ @request_handler()
+ def _async_render_POST(self, request):
+ content = parse_json_object_from_request(request)
+
+ for row in content["invalidate"]:
+ name = row["name"]
+ keys = tuple(row["keys"])
+
+ getattr(self.store, name).invalidate(keys)
+
+ respond_with_json_bytes(request, 200, "{}")
diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py
index 5a14c51d..4616e9b3 100644
--- a/synapse/replication/resource.py
+++ b/synapse/replication/resource.py
@@ -17,6 +17,7 @@ from synapse.http.servlet import parse_integer, parse_string
from synapse.http.server import request_handler, finish_request
from synapse.replication.pusher_resource import PusherResource
from synapse.replication.presence_resource import PresenceResource
+from synapse.replication.expire_cache import ExpireCacheResource
from synapse.api.errors import SynapseError
from twisted.web.resource import Resource
@@ -44,6 +45,7 @@ STREAM_NAMES = (
("caches",),
("to_device",),
("public_rooms",),
+ ("federation",),
)
@@ -116,11 +118,14 @@ class ReplicationResource(Resource):
self.sources = hs.get_event_sources()
self.presence_handler = hs.get_presence_handler()
self.typing_handler = hs.get_typing_handler()
+ self.federation_sender = hs.get_federation_sender()
self.notifier = hs.notifier
self.clock = hs.get_clock()
+ self.config = hs.get_config()
self.putChild("remove_pushers", PusherResource(hs))
self.putChild("syncing_users", PresenceResource(hs))
+ self.putChild("expire_cache", ExpireCacheResource(hs))
def render_GET(self, request):
self._async_render_GET(request)
@@ -134,6 +139,7 @@ class ReplicationResource(Resource):
pushers_token = self.store.get_pushers_stream_token()
caches_token = self.store.get_cache_stream_token()
public_rooms_token = self.store.get_current_public_room_stream_id()
+ federation_token = self.federation_sender.get_current_token()
defer.returnValue(_ReplicationToken(
room_stream_token,
@@ -148,6 +154,7 @@ class ReplicationResource(Resource):
caches_token,
int(stream_token.to_device_key),
int(public_rooms_token),
+ int(federation_token),
))
@request_handler()
@@ -164,8 +171,13 @@ class ReplicationResource(Resource):
}
request_streams["streams"] = parse_string(request, "streams")
+ federation_ack = parse_integer(request, "federation_ack", None)
+
def replicate():
- return self.replicate(request_streams, limit)
+ return self.replicate(
+ request_streams, limit,
+ federation_ack=federation_ack
+ )
writer = yield self.notifier.wait_for_replication(replicate, timeout)
result = writer.finish()
@@ -183,7 +195,7 @@ class ReplicationResource(Resource):
finish_request(request)
@defer.inlineCallbacks
- def replicate(self, request_streams, limit):
+ def replicate(self, request_streams, limit, federation_ack=None):
writer = _Writer()
current_token = yield self.current_replication_token()
logger.debug("Replicating up to %r", current_token)
@@ -202,6 +214,7 @@ class ReplicationResource(Resource):
yield self.caches(writer, current_token, limit, request_streams)
yield self.to_device(writer, current_token, limit, request_streams)
yield self.public_rooms(writer, current_token, limit, request_streams)
+ self.federation(writer, current_token, limit, request_streams, federation_ack)
self.streams(writer, current_token, request_streams)
logger.debug("Replicated %d rows", writer.total)
@@ -462,7 +475,24 @@ class ReplicationResource(Resource):
)
upto_token = _position_from_rows(public_rooms_rows, current_position)
writer.write_header_and_rows("public_rooms", public_rooms_rows, (
- "position", "room_id", "visibility"
+ "position", "room_id", "visibility", "appservice_id", "network_id",
+ ), position=upto_token)
+
+ def federation(self, writer, current_token, limit, request_streams, federation_ack):
+ if self.config.send_federation:
+ return
+
+ current_position = current_token.federation
+
+ federation = request_streams.get("federation")
+
+ if federation is not None and federation != current_position:
+ federation_rows = self.federation_sender.get_replication_rows(
+ federation, limit, federation_ack=federation_ack,
+ )
+ upto_token = _position_from_rows(federation_rows, current_position)
+ writer.write_header_and_rows("federation", federation_rows, (
+ "position", "type", "content",
), position=upto_token)
@@ -497,6 +527,7 @@ class _Writer(object):
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules", "pushers", "state", "caches", "to_device", "public_rooms",
+ "federation",
))):
__slots__ = []
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index f19540d6..18076e0f 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -34,6 +34,9 @@ class BaseSlavedStore(SQLBaseStore):
else:
self._cache_id_gen = None
+ self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache"
+ self.http_client = hs.get_simple_http_client()
+
def stream_positions(self):
pos = {}
if self._cache_id_gen:
@@ -54,3 +57,19 @@ class BaseSlavedStore(SQLBaseStore):
logger.info("Got unexpected cache_func: %r", cache_func)
self._cache_id_gen.advance(int(stream["position"]))
return defer.succeed(None)
+
+ def _invalidate_cache_and_stream(self, txn, cache_func, keys):
+ txn.call_after(cache_func.invalidate, keys)
+ txn.call_after(self._send_invalidation_poke, cache_func, keys)
+
+ @defer.inlineCallbacks
+ def _send_invalidation_poke(self, cache_func, keys):
+ try:
+ yield self.http_client.post_json_get_json(self.expire_cache_url, {
+ "invalidate": [{
+ "name": cache_func.__name__,
+ "keys": list(keys),
+ }]
+ })
+ except:
+ logger.exception("Failed to poke on expire_cache")
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 3bfd5e82..cc860f9f 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -29,10 +29,16 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
"DeviceInboxStreamChangeCache",
self._device_inbox_id_gen.get_current_token()
)
+ self._device_federation_outbox_stream_cache = StreamChangeCache(
+ "DeviceFederationOutboxStreamChangeCache",
+ self._device_inbox_id_gen.get_current_token()
+ )
get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
+ get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__
delete_messages_for_device = DataStore.delete_messages_for_device.__func__
+ delete_device_msgs_for_remote = DataStore.delete_device_msgs_for_remote.__func__
def stream_positions(self):
result = super(SlavedDeviceInboxStore, self).stream_positions()
@@ -45,9 +51,15 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
self._device_inbox_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
stream_id = row[0]
- user_id = row[1]
- self._device_inbox_stream_cache.entity_has_changed(
- user_id, stream_id
- )
+ entity = row[1]
+
+ if entity.startswith("@"):
+ self._device_inbox_stream_cache.entity_has_changed(
+ entity, stream_id
+ )
+ else:
+ self._device_federation_outbox_stream_cache.entity_has_changed(
+ entity, stream_id
+ )
return super(SlavedDeviceInboxStore, self).process_replication(result)
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 0c26e96e..64f18bbb 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -26,6 +26,11 @@ from synapse.storage.stream import StreamStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
import ujson as json
+import logging
+
+
+logger = logging.getLogger(__name__)
+
# So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the
@@ -180,6 +185,11 @@ class SlavedEventStore(BaseSlavedStore):
EventFederationStore.__dict__["_get_forward_extremeties_for_room"]
)
+ get_all_new_events_stream = DataStore.get_all_new_events_stream.__func__
+
+ get_federation_out_pos = DataStore.get_federation_out_pos.__func__
+ update_federation_out_pos = DataStore.update_federation_out_pos.__func__
+
def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token()
@@ -194,6 +204,10 @@ class SlavedEventStore(BaseSlavedStore):
stream = result.get("events")
if stream:
self._stream_id_gen.advance(int(stream["position"]))
+
+ if stream["rows"]:
+ logger.info("Got %d event rows", len(stream["rows"]))
+
for row in stream["rows"]:
self._process_replication_row(
row, backfilled=False, state_resets=state_resets
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 23c61386..6df9a25e 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -15,6 +15,7 @@
from ._base import BaseSlavedStore
from synapse.storage import DataStore
+from synapse.storage.room import RoomStore
from ._slaved_id_tracker import SlavedIdTracker
@@ -30,7 +31,7 @@ class RoomStore(BaseSlavedStore):
DataStore.get_current_public_room_stream_id.__func__
)
get_public_room_ids_at_stream_id = (
- DataStore.get_public_room_ids_at_stream_id.__func__
+ RoomStore.__dict__["get_public_room_ids_at_stream_id"]
)
get_public_room_ids_at_stream_id_txn = (
DataStore.get_public_room_ids_at_stream_id_txn.__func__
diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py
index 6f2ba98a..fbb58f35 100644
--- a/synapse/replication/slave/storage/transactions.py
+++ b/synapse/replication/slave/storage/transactions.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
from ._base import BaseSlavedStore
from synapse.storage import DataStore
from synapse.storage.transactions import TransactionStore
@@ -22,9 +21,10 @@ from synapse.storage.transactions import TransactionStore
class TransactionStore(BaseSlavedStore):
get_destination_retry_timings = TransactionStore.__dict__[
"get_destination_retry_timings"
- ].orig
+ ]
_get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__
+ set_destination_retry_timings = DataStore.set_destination_retry_timings.__func__
+ _set_destination_retry_timings = DataStore._set_destination_retry_timings.__func__
- # For now, don't record the destination rety timings
- def set_destination_retry_timings(*args, **kwargs):
- return defer.succeed(None)
+ prep_send_transaction = DataStore.prep_send_transaction.__func__
+ delivered_txn = DataStore.delivered_txn.__func__
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 09d08315..8930f182 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
def register_servlets(hs, http_server):
ClientDirectoryServer(hs).register(http_server)
ClientDirectoryListServer(hs).register(http_server)
+ ClientAppserviceDirectoryListServer(hs).register(http_server)
class ClientDirectoryServer(ClientV1RestServlet):
@@ -184,3 +185,36 @@ class ClientDirectoryListServer(ClientV1RestServlet):
)
defer.returnValue((200, {}))
+
+
+class ClientAppserviceDirectoryListServer(ClientV1RestServlet):
+ PATTERNS = client_path_patterns(
+ "/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(ClientAppserviceDirectoryListServer, self).__init__(hs)
+ self.store = hs.get_datastore()
+ self.handlers = hs.get_handlers()
+
+ def on_PUT(self, request, network_id, room_id):
+ content = parse_json_object_from_request(request)
+ visibility = content.get("visibility", "public")
+ return self._edit(request, network_id, room_id, visibility)
+
+ def on_DELETE(self, request, network_id, room_id):
+ return self._edit(request, network_id, room_id, "private")
+
+ @defer.inlineCallbacks
+ def _edit(self, request, network_id, room_id, visibility):
+ requester = yield self.auth.get_user_by_req(request)
+ if not requester.app_service:
+ raise AuthError(
+ 403, "Only appservices can edit the appservice published room list"
+ )
+
+ yield self.handlers.directory_handler.edit_published_appservice_room_list(
+ requester.app_service.id, network_id, room_id, visibility,
+ )
+
+ defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 345018a8..093bc072 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -137,16 +137,13 @@ class LoginRestServlet(ClientV1RestServlet):
password=login_submission["password"],
)
device_id = yield self._register_device(user_id, login_submission)
- access_token, refresh_token = (
- yield auth_handler.get_login_tuple_for_user_id(
- user_id, device_id,
- login_submission.get("initial_device_display_name")
- )
+ access_token = yield auth_handler.get_access_token_for_user_id(
+ user_id, device_id,
+ login_submission.get("initial_device_display_name"),
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
- "refresh_token": refresh_token,
"home_server": self.hs.hostname,
"device_id": device_id,
}
@@ -161,16 +158,13 @@ class LoginRestServlet(ClientV1RestServlet):
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
device_id = yield self._register_device(user_id, login_submission)
- access_token, refresh_token = (
- yield auth_handler.get_login_tuple_for_user_id(
- user_id, device_id,
- login_submission.get("initial_device_display_name")
- )
+ access_token = yield auth_handler.get_access_token_for_user_id(
+ user_id, device_id,
+ login_submission.get("initial_device_display_name"),
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
- "refresh_token": refresh_token,
"home_server": self.hs.hostname,
"device_id": device_id,
}
@@ -207,16 +201,14 @@ class LoginRestServlet(ClientV1RestServlet):
device_id = yield self._register_device(
registered_user_id, login_submission
)
- access_token, refresh_token = (
- yield auth_handler.get_login_tuple_for_user_id(
- registered_user_id, device_id,
- login_submission.get("initial_device_display_name")
- )
+ access_token = yield auth_handler.get_access_token_for_user_id(
+ registered_user_id, device_id,
+ login_submission.get("initial_device_display_name"),
)
+
result = {
"user_id": registered_user_id,
"access_token": access_token,
- "refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
else:
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index b5a76fef..ecf7e311 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -384,7 +384,6 @@ class CreateUserRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(CreateUserRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
- self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
@@ -418,18 +417,8 @@ class CreateUserRestServlet(ClientV1RestServlet):
if "displayname" not in user_json:
raise SynapseError(400, "Expected 'displayname' key.")
- if "duration_seconds" not in user_json:
- raise SynapseError(400, "Expected 'duration_seconds' key.")
-
localpart = user_json["localpart"].encode("utf-8")
displayname = user_json["displayname"].encode("utf-8")
- duration_seconds = 0
- try:
- duration_seconds = int(user_json["duration_seconds"])
- except ValueError:
- raise SynapseError(400, "Failed to parse 'duration_seconds'")
- if duration_seconds > self.direct_user_creation_max_duration:
- duration_seconds = self.direct_user_creation_max_duration
password_hash = user_json["password_hash"].encode("utf-8") \
if user_json.get("password_hash") else None
@@ -438,7 +427,6 @@ class CreateUserRestServlet(ClientV1RestServlet):
requester=requester,
localpart=localpart,
displayname=displayname,
- duration_in_ms=(duration_seconds * 1000),
password_hash=password_hash
)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 3fb1f2de..eead435b 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -21,7 +21,7 @@ from synapse.api.errors import SynapseError, Codes, AuthError
from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import Filter
-from synapse.types import UserID, RoomID, RoomAlias
+from synapse.types import UserID, RoomID, RoomAlias, ThirdPartyInstanceID
from synapse.events.utils import serialize_event, format_event_for_client_v2
from synapse.http.servlet import (
parse_json_object_from_request, parse_string, parse_integer
@@ -321,6 +321,20 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
since_token = content.get("since", None)
search_filter = content.get("filter", None)
+ include_all_networks = content.get("include_all_networks", False)
+ third_party_instance_id = content.get("third_party_instance_id", None)
+
+ if include_all_networks:
+ network_tuple = None
+ if third_party_instance_id is not None:
+ raise SynapseError(
+ 400, "Can't use include_all_networks with an explicit network"
+ )
+ elif third_party_instance_id is None:
+ network_tuple = ThirdPartyInstanceID(None, None)
+ else:
+ network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
+
handler = self.hs.get_room_list_handler()
if server:
data = yield handler.get_remote_public_room_list(
@@ -328,12 +342,15 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
limit=limit,
since_token=since_token,
search_filter=search_filter,
+ include_all_networks=include_all_networks,
+ third_party_instance_id=third_party_instance_id,
)
else:
data = yield handler.get_local_public_room_list(
limit=limit,
since_token=since_token,
search_filter=search_filter,
+ network_tuple=network_tuple,
)
defer.returnValue((200, data))
@@ -369,6 +386,24 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
}))
+class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
+ PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$")
+
+ def __init__(self, hs):
+ super(JoinedRoomMemberListRestServlet, self).__init__(hs)
+ self.state = hs.get_state_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id):
+ yield self.auth.get_user_by_req(request)
+
+ users_with_profile = yield self.state.get_current_user_in_room(room_id)
+
+ defer.returnValue((200, {
+ "joined": users_with_profile
+ }))
+
+
# TODO: Needs better unit testing
class RoomMessageListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
@@ -692,6 +727,22 @@ class SearchRestServlet(ClientV1RestServlet):
defer.returnValue((200, results))
+class JoinedRoomsRestServlet(ClientV1RestServlet):
+ PATTERNS = client_path_patterns("/joined_rooms$")
+
+ def __init__(self, hs):
+ super(JoinedRoomsRestServlet, self).__init__(hs)
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+
+ rooms = yield self.store.get_rooms_for_user(requester.user.to_string())
+ room_ids = set(r.room_id for r in rooms) # Ensure they're unique.
+ defer.returnValue((200, {"joined_rooms": list(room_ids)}))
+
+
def register_txn_path(servlet, regex_string, http_server, with_get=False):
"""Registers a transaction-based path.
@@ -727,6 +778,7 @@ def register_servlets(hs, http_server):
RoomStateEventRestServlet(hs).register(http_server)
RoomCreateRestServlet(hs).register(http_server)
RoomMemberListRestServlet(hs).register(http_server)
+ JoinedRoomMemberListRestServlet(hs).register(http_server)
RoomMessageListRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
RoomForgetRestServlet(hs).register(http_server)
@@ -738,4 +790,5 @@ def register_servlets(hs, http_server):
RoomRedactEventRestServlet(hs).register(http_server)
RoomTypingRestServlet(hs).register(http_server)
SearchRestServlet(hs).register(http_server)
+ JoinedRoomsRestServlet(hs).register(http_server)
RoomEventContext(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index 3ba0b0fc..a1feaf3d 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -39,7 +39,7 @@ class DevicesRestServlet(servlet.RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
devices = yield self.device_handler.get_devices_by_user(
requester.user.to_string()
)
@@ -63,7 +63,7 @@ class DeviceRestServlet(servlet.RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, device_id):
- requester = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
device = yield self.device_handler.get_device(
requester.user.to_string(),
device_id,
@@ -99,7 +99,7 @@ class DeviceRestServlet(servlet.RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, device_id):
- requester = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
body = servlet.parse_json_object_from_request(request)
yield self.device_handler.update_device(
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index f185f9a7..46789775 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -65,7 +65,7 @@ class KeyUploadServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, device_id):
- requester = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -94,10 +94,6 @@ class KeyUploadServlet(RestServlet):
class KeyQueryServlet(RestServlet):
"""
- GET /keys/query/<user_id> HTTP/1.1
-
- GET /keys/query/<user_id>/<device_id> HTTP/1.1
-
POST /keys/query HTTP/1.1
Content-Type: application/json
{
@@ -131,11 +127,7 @@ class KeyQueryServlet(RestServlet):
"""
PATTERNS = client_v2_patterns(
- "/keys/query(?:"
- "/(?P<user_id>[^/]*)(?:"
- "/(?P<device_id>[^/]*)"
- ")?"
- ")?",
+ "/keys/query$",
releases=()
)
@@ -149,31 +141,16 @@ class KeyQueryServlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler()
@defer.inlineCallbacks
- def on_POST(self, request, user_id, device_id):
- yield self.auth.get_user_by_req(request)
+ def on_POST(self, request):
+ yield self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
result = yield self.e2e_keys_handler.query_devices(body, timeout)
defer.returnValue((200, result))
- @defer.inlineCallbacks
- def on_GET(self, request, user_id, device_id):
- requester = yield self.auth.get_user_by_req(request)
- timeout = parse_integer(request, "timeout", 10 * 1000)
- auth_user_id = requester.user.to_string()
- user_id = user_id if user_id else auth_user_id
- device_ids = [device_id] if device_id else []
- result = yield self.e2e_keys_handler.query_devices(
- {"device_keys": {user_id: device_ids}},
- timeout,
- )
- defer.returnValue((200, result))
-
class OneTimeKeyServlet(RestServlet):
"""
- GET /keys/claim/<user-id>/<device-id>/<algorithm> HTTP/1.1
-
POST /keys/claim HTTP/1.1
{
"one_time_keys": {
@@ -191,9 +168,7 @@ class OneTimeKeyServlet(RestServlet):
"""
PATTERNS = client_v2_patterns(
- "/keys/claim(?:/?|(?:/"
- "(?P<user_id>[^/]*)/(?P<device_id>[^/]*)/(?P<algorithm>[^/]*)"
- ")?)",
+ "/keys/claim$",
releases=()
)
@@ -203,18 +178,8 @@ class OneTimeKeyServlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler()
@defer.inlineCallbacks
- def on_GET(self, request, user_id, device_id, algorithm):
- yield self.auth.get_user_by_req(request)
- timeout = parse_integer(request, "timeout", 10 * 1000)
- result = yield self.e2e_keys_handler.claim_one_time_keys(
- {"one_time_keys": {user_id: {device_id: algorithm}}},
- timeout,
- )
- defer.returnValue((200, result))
-
- @defer.inlineCallbacks
- def on_POST(self, request, user_id, device_id, algorithm):
- yield self.auth.get_user_by_req(request)
+ def on_POST(self, request):
+ yield self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
result = yield self.e2e_keys_handler.claim_one_time_keys(
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
index 891cef99..1fbff2ed 100644
--- a/synapse/rest/client/v2_alpha/receipts.py
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -36,7 +36,7 @@ class ReceiptRestServlet(RestServlet):
super(ReceiptRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
- self.receipts_handler = hs.get_handlers().receipts_handler
+ self.receipts_handler = hs.get_receipts_handler()
self.presence_handler = hs.get_presence_handler()
@defer.inlineCallbacks
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 6cfb2086..3e7a285e 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -15,6 +15,7 @@
from twisted.internet import defer
+import synapse
from synapse.api.auth import get_access_token_from_request, has_access_token
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
@@ -100,12 +101,14 @@ class RegisterRestServlet(RestServlet):
def on_POST(self, request):
yield run_on_reactor()
+ body = parse_json_object_from_request(request)
+
kind = "user"
if "kind" in request.args:
kind = request.args["kind"][0]
if kind == "guest":
- ret = yield self._do_guest_registration()
+ ret = yield self._do_guest_registration(body)
defer.returnValue(ret)
return
elif kind != "user":
@@ -113,8 +116,6 @@ class RegisterRestServlet(RestServlet):
"Do not understand membership kind: %s" % (kind,)
)
- body = parse_json_object_from_request(request)
-
# we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the username/password provided to us.
desired_password = None
@@ -373,8 +374,7 @@ class RegisterRestServlet(RestServlet):
def _create_registration_details(self, user_id, params):
"""Complete registration of newly-registered user
- Allocates device_id if one was not given; also creates access_token
- and refresh_token.
+ Allocates device_id if one was not given; also creates access_token.
Args:
(str) user_id: full canonical @user:id
@@ -385,8 +385,8 @@ class RegisterRestServlet(RestServlet):
"""
device_id = yield self._register_device(user_id, params)
- access_token, refresh_token = (
- yield self.auth_handler.get_login_tuple_for_user_id(
+ access_token = (
+ yield self.auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id,
initial_display_name=params.get("initial_device_display_name")
)
@@ -396,7 +396,6 @@ class RegisterRestServlet(RestServlet):
"user_id": user_id,
"access_token": access_token,
"home_server": self.hs.hostname,
- "refresh_token": refresh_token,
"device_id": device_id,
})
@@ -421,20 +420,28 @@ class RegisterRestServlet(RestServlet):
)
@defer.inlineCallbacks
- def _do_guest_registration(self):
+ def _do_guest_registration(self, params):
if not self.hs.config.allow_guest_access:
defer.returnValue((403, "Guest access is disabled"))
user_id, _ = yield self.registration_handler.register(
generate_token=False,
make_guest=True
)
+
+ # we don't allow guests to specify their own device_id, because
+ # we have nowhere to store it.
+ device_id = synapse.api.auth.GUEST_DEVICE_ID
+ initial_display_name = params.get("initial_device_display_name")
+ self.device_handler.check_device_registered(
+ user_id, device_id, initial_display_name
+ )
+
access_token = self.auth_handler.generate_access_token(
user_id, ["guest = true"]
)
- # XXX the "guest" caveat is not copied by /tokenrefresh. That's ok
- # so long as we don't return a refresh_token here.
defer.returnValue((200, {
"user_id": user_id,
+ "device_id": device_id,
"access_token": access_token,
"home_server": self.hs.hostname,
}))
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index ac660669..d607bd29 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -50,7 +50,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
@defer.inlineCallbacks
def _put(self, request, message_type, txn_id):
- requester = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 6fc63715..7199ec88 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -162,7 +162,7 @@ class SyncRestServlet(RestServlet):
time_now = self.clock.time_msec()
joined = self.encode_joined(
- sync_result.joined, time_now, requester.access_token_id
+ sync_result.joined, time_now, requester.access_token_id, filter.event_fields
)
invited = self.encode_invited(
@@ -170,7 +170,7 @@ class SyncRestServlet(RestServlet):
)
archived = self.encode_archived(
- sync_result.archived, time_now, requester.access_token_id
+ sync_result.archived, time_now, requester.access_token_id, filter.event_fields
)
response_content = {
@@ -197,7 +197,7 @@ class SyncRestServlet(RestServlet):
formatted.append(event)
return {"events": formatted}
- def encode_joined(self, rooms, time_now, token_id):
+ def encode_joined(self, rooms, time_now, token_id, event_fields):
"""
Encode the joined rooms in a sync result
@@ -208,7 +208,8 @@ class SyncRestServlet(RestServlet):
calculations
token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs
-
+ event_fields(list<str>): List of event fields to include. If empty,
+ all fields will be returned.
Returns:
dict[str, dict[str, object]]: the joined rooms list, in our
response format
@@ -216,7 +217,7 @@ class SyncRestServlet(RestServlet):
joined = {}
for room in rooms:
joined[room.room_id] = self.encode_room(
- room, time_now, token_id
+ room, time_now, token_id, only_fields=event_fields
)
return joined
@@ -253,7 +254,7 @@ class SyncRestServlet(RestServlet):
return invited
- def encode_archived(self, rooms, time_now, token_id):
+ def encode_archived(self, rooms, time_now, token_id, event_fields):
"""
Encode the archived rooms in a sync result
@@ -264,7 +265,8 @@ class SyncRestServlet(RestServlet):
calculations
token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs
-
+ event_fields(list<str>): List of event fields to include. If empty,
+ all fields will be returned.
Returns:
dict[str, dict[str, object]]: The invited rooms list, in our
response format
@@ -272,13 +274,13 @@ class SyncRestServlet(RestServlet):
joined = {}
for room in rooms:
joined[room.room_id] = self.encode_room(
- room, time_now, token_id, joined=False
+ room, time_now, token_id, joined=False, only_fields=event_fields
)
return joined
@staticmethod
- def encode_room(room, time_now, token_id, joined=True):
+ def encode_room(room, time_now, token_id, joined=True, only_fields=None):
"""
Args:
room (JoinedSyncResult|ArchivedSyncResult): sync result for a
@@ -289,7 +291,7 @@ class SyncRestServlet(RestServlet):
of transaction IDs
joined (bool): True if the user is joined to this room - will mean
we handle ephemeral events
-
+ only_fields(list<str>): Optional. The list of event fields to include.
Returns:
dict[str, object]: the room, encoded in our response format
"""
@@ -298,6 +300,7 @@ class SyncRestServlet(RestServlet):
return serialize_event(
event, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_room_id,
+ only_event_fields=only_fields,
)
state_dict = room.state
diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py
index 0d312c91..6e76b9e9 100644
--- a/synapse/rest/client/v2_alpha/tokenrefresh.py
+++ b/synapse/rest/client/v2_alpha/tokenrefresh.py
@@ -15,8 +15,8 @@
from twisted.internet import defer
-from synapse.api.errors import AuthError, StoreError, SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.api.errors import AuthError
+from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns
@@ -30,30 +30,10 @@ class TokenRefreshRestServlet(RestServlet):
def __init__(self, hs):
super(TokenRefreshRestServlet, self).__init__()
- self.hs = hs
- self.store = hs.get_datastore()
@defer.inlineCallbacks
def on_POST(self, request):
- body = parse_json_object_from_request(request)
- try:
- old_refresh_token = body["refresh_token"]
- auth_handler = self.hs.get_auth_handler()
- refresh_result = yield self.store.exchange_refresh_token(
- old_refresh_token, auth_handler.generate_refresh_token
- )
- (user_id, new_refresh_token, device_id) = refresh_result
- new_access_token = yield auth_handler.issue_access_token(
- user_id, device_id
- )
- defer.returnValue((200, {
- "access_token": new_access_token,
- "refresh_token": new_refresh_token,
- }))
- except KeyError:
- raise SynapseError(400, "Missing required key 'refresh_token'.")
- except StoreError:
- raise AuthError(403, "Did not recognize refresh token")
+ raise AuthError(403, "tokenrefresh is no longer supported.")
def register_servlets(hs, http_server):
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 33f35fb4..99760d62 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -381,7 +381,10 @@ def _calc_og(tree, media_uri):
if 'og:title' not in og:
# do some basic spidering of the HTML
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
- og['og:title'] = title[0].text.strip() if title else None
+ if title and title[0].text is not None:
+ og['og:title'] = title[0].text.strip()
+ else:
+ og['og:title'] = None
if 'og:image' not in og:
# TODO: extract a favicon failing all else
@@ -543,5 +546,5 @@ def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
# We always add an ellipsis because at the very least
# we chopped mid paragraph.
- description = new_desc.strip() + "…"
+ description = new_desc.strip() + u"…"
return description if description else None
diff --git a/synapse/server.py b/synapse/server.py
index 374124a1..0bfb4112 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -32,6 +32,9 @@ from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory
from synapse.federation import initialize_http_replication
+from synapse.federation.send_queue import FederationRemoteSendQueue
+from synapse.federation.transport.client import TransportLayerClient
+from synapse.federation.transaction_queue import TransactionQueue
from synapse.handlers import Handlers
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler
@@ -44,6 +47,7 @@ from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.handlers.initial_sync import InitialSyncHandler
+from synapse.handlers.receipts import ReceiptsHandler
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier
@@ -124,6 +128,9 @@ class HomeServer(object):
'http_client_context_factory',
'simple_http_client',
'media_repository',
+ 'federation_transport_client',
+ 'federation_sender',
+ 'receipts_handler',
]
def __init__(self, hostname, **kwargs):
@@ -265,9 +272,30 @@ class HomeServer(object):
def build_media_repository(self):
return MediaRepository(self)
+ def build_federation_transport_client(self):
+ return TransportLayerClient(self)
+
+ def build_federation_sender(self):
+ if self.should_send_federation():
+ return TransactionQueue(self)
+ elif not self.config.worker_app:
+ return FederationRemoteSendQueue(self)
+ else:
+ raise Exception("Workers cannot send federation traffic")
+
+ def build_receipts_handler(self):
+ return ReceiptsHandler(self)
+
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
+ def should_send_federation(self):
+ "Should this server be sending federation traffic directly?"
+ return self.config.send_federation and (
+ not self.config.worker_app
+ or self.config.worker_app == "synapse.app.federation_sender"
+ )
+
def _make_dependency_method(depname):
def _get(hs):
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 9996f195..fe936b3e 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -120,7 +120,6 @@ class DataStore(RoomMemberStore, RoomStore,
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
- self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
@@ -223,6 +222,7 @@ class DataStore(RoomMemberStore, RoomStore,
)
self._stream_order_on_start = self.get_room_max_stream_ordering()
+ self._min_stream_order_on_start = self.get_room_min_stream_ordering()
super(DataStore, self).__init__(hs)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index d828d6ee..b62c459d 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -561,12 +561,17 @@ class SQLBaseStore(object):
@staticmethod
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
+ if keyvalues:
+ where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
+ else:
+ where = ""
+
sql = (
- "SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
+ "SELECT %(retcol)s FROM %(table)s %(where)s"
) % {
"retcol": retcol,
"table": table,
- "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
+ "where": where,
}
txn.execute(sql, keyvalues.values())
@@ -744,10 +749,15 @@ class SQLBaseStore(object):
@staticmethod
def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
- update_sql = "UPDATE %s SET %s WHERE %s" % (
+ if keyvalues:
+ where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
+ else:
+ where = ""
+
+ update_sql = "UPDATE %s SET %s %s" % (
table,
", ".join("%s = ?" % (k,) for k in updatevalues),
- " AND ".join("%s = ?" % (k,) for k in keyvalues)
+ where,
)
txn.execute(
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index 3d5994a5..51457056 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -39,6 +39,14 @@ class ApplicationServiceStore(SQLBaseStore):
def get_app_services(self):
return self.services_cache
+ def get_if_app_services_interested_in_user(self, user_id):
+ """Check if the user is one associated with an app service
+ """
+ for service in self.services_cache:
+ if service.is_interested_in_user(user_id):
+ return True
+ return False
+
def get_app_service_by_user_id(self, user_id):
"""Retrieve an application service from their user ID.
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index f640e737..2821eb89 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -242,7 +242,7 @@ class DeviceInboxStore(SQLBaseStore):
device_id(str): The recipient device_id.
up_to_stream_id(int): Where to delete messages up to.
Returns:
- A deferred that resolves when the messages have been deleted.
+ A deferred that resolves to the number of messages deleted.
"""
def delete_messages_for_device_txn(txn):
sql = (
@@ -251,6 +251,7 @@ class DeviceInboxStore(SQLBaseStore):
" AND stream_id <= ?"
)
txn.execute(sql, (user_id, device_id, up_to_stream_id))
+ return txn.rowcount
return self.runInteraction(
"delete_messages_for_device", delete_messages_for_device_txn
@@ -269,27 +270,29 @@ class DeviceInboxStore(SQLBaseStore):
return defer.succeed([])
def get_all_new_device_messages_txn(txn):
+ # We limit like this as we might have multiple rows per stream_id, and
+ # we want to make sure we always get all entries for any stream_id
+ # we return.
+ upper_pos = min(current_pos, last_pos + limit)
sql = (
- "SELECT stream_id FROM device_inbox"
+ "SELECT stream_id, user_id"
+ " FROM device_inbox"
" WHERE ? < stream_id AND stream_id <= ?"
- " GROUP BY stream_id"
" ORDER BY stream_id ASC"
- " LIMIT ?"
)
- txn.execute(sql, (last_pos, current_pos, limit))
- stream_ids = txn.fetchall()
- if not stream_ids:
- return []
- max_stream_id_in_limit = stream_ids[-1]
+ txn.execute(sql, (last_pos, upper_pos))
+ rows = txn.fetchall()
sql = (
- "SELECT stream_id, user_id, device_id, message_json"
- " FROM device_inbox"
+ "SELECT stream_id, destination"
+ " FROM device_federation_outbox"
" WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
)
- txn.execute(sql, (last_pos, max_stream_id_in_limit))
- return txn.fetchall()
+ txn.execute(sql, (last_pos, upper_pos))
+ rows.extend(txn.fetchall())
+
+ return rows
return self.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 9cd923eb..7de3e8c5 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -39,6 +39,14 @@ class EventPushActionsStore(SQLBaseStore):
columns=["user_id", "stream_ordering"],
)
+ self.register_background_index_update(
+ "event_push_actions_highlights_index",
+ index_name="event_push_actions_highlights_index",
+ table="event_push_actions",
+ columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
+ where_clause="highlight=1"
+ )
+
def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
"""
Args:
@@ -88,8 +96,11 @@ class EventPushActionsStore(SQLBaseStore):
topological_ordering, stream_ordering
)
+ # First get number of notifications.
+ # We don't need to put a notif=1 clause as all rows always have
+ # notif=1
sql = (
- "SELECT sum(notif), sum(highlight)"
+ "SELECT count(*)"
" FROM event_push_actions ea"
" WHERE"
" user_id = ?"
@@ -99,13 +110,27 @@ class EventPushActionsStore(SQLBaseStore):
txn.execute(sql, (user_id, room_id))
row = txn.fetchone()
- if row:
- return {
- "notify_count": row[0] or 0,
- "highlight_count": row[1] or 0,
- }
- else:
- return {"notify_count": 0, "highlight_count": 0}
+ notify_count = row[0] if row else 0
+
+ # Now get the number of highlights
+ sql = (
+ "SELECT count(*)"
+ " FROM event_push_actions ea"
+ " WHERE"
+ " highlight = 1"
+ " AND user_id = ?"
+ " AND room_id = ?"
+ " AND %s"
+ ) % (lower_bound(token, self.database_engine, inclusive=False),)
+
+ txn.execute(sql, (user_id, room_id))
+ row = txn.fetchone()
+ highlight_count = row[0] if row else 0
+
+ return {
+ "notify_count": notify_count,
+ "highlight_count": highlight_count,
+ }
ret = yield self.runInteraction(
"get_unread_event_push_actions_by_room",
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 49aeb953..ecb79c07 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -54,6 +54,7 @@ def encode_json(json_object):
else:
return json.dumps(json_object, ensure_ascii=False)
+
# These values are used in the `enqueus_event` and `_do_fetch` methods to
# control how we batch/bulk fetch events from the database.
# The values are plucked out of thing air to make initial sync run faster
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index 52487368..a2ccc66e 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -16,6 +16,7 @@
from twisted.internet import defer
from ._base import SQLBaseStore
+from synapse.api.errors import SynapseError, Codes
from synapse.util.caches.descriptors import cachedInlineCallbacks
import simplejson as json
@@ -24,6 +25,13 @@ import simplejson as json
class FilteringStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=2)
def get_user_filter(self, user_localpart, filter_id):
+ # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
+ # with a coherent error message rather than 500 M_UNKNOWN.
+ try:
+ int(filter_id)
+ except ValueError:
+ raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
+
def_json = yield self._simple_select_one_onecol(
table="user_filters",
keyvalues={
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 6576a300..e46ae650 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 38
+SCHEMA_VERSION = 39
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 21d06966..7460f98a 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -37,6 +37,13 @@ class UserPresenceState(namedtuple("UserPresenceState",
status_msg (str): User set status message.
"""
+ def as_dict(self):
+ return dict(self._asdict())
+
+ @staticmethod
+ def from_dict(d):
+ return UserPresenceState(**d)
+
def copy_and_replace(self, **kwargs):
return self._replace(**kwargs)
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 49721656..cbec2559 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -156,12 +156,20 @@ class PushRuleStore(SQLBaseStore):
event=event,
)
- local_users_in_room = set(u for u in users_in_room if self.hs.is_mine_id(u))
+ # We ignore app service users for now. This is so that we don't fill
+ # up the `get_if_users_have_pushers` cache with AS entries that we
+ # know don't have pushers, nor even read receipts.
+ local_users_in_room = set(
+ u for u in users_in_room
+ if self.hs.is_mine_id(u)
+ and not self.get_if_app_services_interested_in_user(u)
+ )
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
if_users_with_pushers = yield self.get_if_users_have_pushers(
- local_users_in_room, on_invalidate=cache_context.invalidate,
+ local_users_in_room,
+ on_invalidate=cache_context.invalidate,
)
user_ids = set(
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 9747a04a..f72d15f5 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -405,7 +405,7 @@ class ReceiptsStore(SQLBaseStore):
room_id, receipt_type, user_id, event_ids, data
)
- max_persisted_id = self._stream_id_gen.get_current_token()
+ max_persisted_id = self._receipts_id_gen.get_current_token()
defer.returnValue((stream_id, max_persisted_id))
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index e404fa72..983a8ec5 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -68,31 +68,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
desc="add_access_token_to_user",
)
- @defer.inlineCallbacks
- def add_refresh_token_to_user(self, user_id, token, device_id=None):
- """Adds a refresh token for the given user.
-
- Args:
- user_id (str): The user ID.
- token (str): The new refresh token to add.
- device_id (str): ID of the device to associate with the access
- token
- Raises:
- StoreError if there was a problem adding this.
- """
- next_id = self._refresh_tokens_id_gen.get_next()
-
- yield self._simple_insert(
- "refresh_tokens",
- {
- "id": next_id,
- "user_id": user_id,
- "token": token,
- "device_id": device_id,
- },
- desc="add_refresh_token_to_user",
- )
-
def register(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_localpart=None, admin=False):
@@ -353,47 +328,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
token
)
- def exchange_refresh_token(self, refresh_token, token_generator):
- """Exchange a refresh token for a new one.
-
- Doing so invalidates the old refresh token - refresh tokens are single
- use.
-
- Args:
- refresh_token (str): The refresh token of a user.
- token_generator (fn: str -> str): Function which, when given a
- user ID, returns a unique refresh token for that user. This
- function must never return the same value twice.
- Returns:
- tuple of (user_id, new_refresh_token, device_id)
- Raises:
- StoreError if no user was found with that refresh token.
- """
- return self.runInteraction(
- "exchange_refresh_token",
- self._exchange_refresh_token,
- refresh_token,
- token_generator
- )
-
- def _exchange_refresh_token(self, txn, old_token, token_generator):
- sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?"
- txn.execute(sql, (old_token,))
- rows = self.cursor_to_dict(txn)
- if not rows:
- raise StoreError(403, "Did not recognize refresh token")
- user_id = rows[0]["user_id"]
- device_id = rows[0]["device_id"]
-
- # TODO(danielwh): Maybe perform a validation on the macaroon that
- # macaroon.user_id == user_id.
-
- new_token = token_generator(user_id)
- sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?"
- txn.execute(sql, (new_token, old_token,))
-
- return user_id, new_token, device_id
-
@defer.inlineCallbacks
def is_server_admin(self, user):
res = yield self._simple_select_one_onecol(
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 11813b44..8a2fe2fd 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -16,6 +16,7 @@
from twisted.internet import defer
from synapse.api.errors import StoreError
+from synapse.util.caches.descriptors import cached
from ._base import SQLBaseStore
from .engines import PostgresEngine, Sqlite3Engine
@@ -106,7 +107,11 @@ class RoomStore(SQLBaseStore):
entries = self._simple_select_list_txn(
txn,
table="public_room_list_stream",
- keyvalues={"room_id": room_id},
+ keyvalues={
+ "room_id": room_id,
+ "appservice_id": None,
+ "network_id": None,
+ },
retcols=("stream_id", "visibility"),
)
@@ -124,6 +129,8 @@ class RoomStore(SQLBaseStore):
"stream_id": next_id,
"room_id": room_id,
"visibility": is_public,
+ "appservice_id": None,
+ "network_id": None,
}
)
@@ -132,6 +139,87 @@ class RoomStore(SQLBaseStore):
"set_room_is_public",
set_room_is_public_txn, next_id,
)
+ self.hs.get_notifier().on_new_replication_data()
+
+ @defer.inlineCallbacks
+ def set_room_is_public_appservice(self, room_id, appservice_id, network_id,
+ is_public):
+ """Edit the appservice/network specific public room list.
+
+ Each appservice can have a number of published room lists associated
+ with them, keyed off of an appservice defined `network_id`, which
+ basically represents a single instance of a bridge to a third party
+ network.
+
+ Args:
+ room_id (str)
+ appservice_id (str)
+ network_id (str)
+ is_public (bool): Whether to publish or unpublish the room from the
+ list.
+ """
+ def set_room_is_public_appservice_txn(txn, next_id):
+ if is_public:
+ try:
+ self._simple_insert_txn(
+ txn,
+ table="appservice_room_list",
+ values={
+ "appservice_id": appservice_id,
+ "network_id": network_id,
+ "room_id": room_id
+ },
+ )
+ except self.database_engine.module.IntegrityError:
+ # We've already inserted, nothing to do.
+ return
+ else:
+ self._simple_delete_txn(
+ txn,
+ table="appservice_room_list",
+ keyvalues={
+ "appservice_id": appservice_id,
+ "network_id": network_id,
+ "room_id": room_id
+ },
+ )
+
+ entries = self._simple_select_list_txn(
+ txn,
+ table="public_room_list_stream",
+ keyvalues={
+ "room_id": room_id,
+ "appservice_id": appservice_id,
+ "network_id": network_id,
+ },
+ retcols=("stream_id", "visibility"),
+ )
+
+ entries.sort(key=lambda r: r["stream_id"])
+
+ add_to_stream = True
+ if entries:
+ add_to_stream = bool(entries[-1]["visibility"]) != is_public
+
+ if add_to_stream:
+ self._simple_insert_txn(
+ txn,
+ table="public_room_list_stream",
+ values={
+ "stream_id": next_id,
+ "room_id": room_id,
+ "visibility": is_public,
+ "appservice_id": appservice_id,
+ "network_id": network_id,
+ }
+ )
+
+ with self._public_room_id_gen.get_next() as next_id:
+ yield self.runInteraction(
+ "set_room_is_public_appservice",
+ set_room_is_public_appservice_txn, next_id,
+ )
+ self.hs.get_notifier().on_new_replication_data()
def get_public_room_ids(self):
return self._simple_select_onecol(
@@ -259,38 +347,96 @@ class RoomStore(SQLBaseStore):
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
- def get_public_room_ids_at_stream_id(self, stream_id):
+ @cached(num_args=2, max_entries=100)
+ def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
+ """Get pulbic rooms for a particular list, or across all lists.
+
+ Args:
+ stream_id (int)
+ network_tuple (ThirdPartyInstanceID): The list to use (None, None)
+ means the main list, None means all lsits.
+ """
return self.runInteraction(
"get_public_room_ids_at_stream_id",
- self.get_public_room_ids_at_stream_id_txn, stream_id
+ self.get_public_room_ids_at_stream_id_txn,
+ stream_id, network_tuple=network_tuple
)
- def get_public_room_ids_at_stream_id_txn(self, txn, stream_id):
+ def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
+ network_tuple):
return {
rm
- for rm, vis in self.get_published_at_stream_id_txn(txn, stream_id).items()
+ for rm, vis in self.get_published_at_stream_id_txn(
+ txn, stream_id, network_tuple=network_tuple
+ ).items()
if vis
}
- def get_published_at_stream_id_txn(self, txn, stream_id):
- sql = ("""
- SELECT room_id, visibility FROM public_room_list_stream
- INNER JOIN (
- SELECT room_id, max(stream_id) AS stream_id
+ def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
+ if network_tuple:
+ # We want to get from a particular list. No aggregation required.
+
+ sql = ("""
+ SELECT room_id, visibility FROM public_room_list_stream
+ INNER JOIN (
+ SELECT room_id, max(stream_id) AS stream_id
+ FROM public_room_list_stream
+ WHERE stream_id <= ? %s
+ GROUP BY room_id
+ ) grouped USING (room_id, stream_id)
+ """)
+
+ if network_tuple.appservice_id is not None:
+ txn.execute(
+ sql % ("AND appservice_id = ? AND network_id = ?",),
+ (stream_id, network_tuple.appservice_id, network_tuple.network_id,)
+ )
+ else:
+ txn.execute(
+ sql % ("AND appservice_id IS NULL",),
+ (stream_id,)
+ )
+ return dict(txn.fetchall())
+ else:
+ # We want to get from all lists, so we need to aggregate the results
+
+ logger.info("Executing full list")
+
+ sql = ("""
+ SELECT room_id, visibility
FROM public_room_list_stream
- WHERE stream_id <= ?
- GROUP BY room_id
- ) grouped USING (room_id, stream_id)
- """)
+ INNER JOIN (
+ SELECT
+ room_id, max(stream_id) AS stream_id, appservice_id,
+ network_id
+ FROM public_room_list_stream
+ WHERE stream_id <= ?
+ GROUP BY room_id, appservice_id, network_id
+ ) grouped USING (room_id, stream_id)
+ """)
- txn.execute(sql, (stream_id,))
- return dict(txn.fetchall())
+ txn.execute(
+ sql,
+ (stream_id,)
+ )
+
+ results = {}
+ # A room is visible if its visible on any list.
+ for room_id, visibility in txn.fetchall():
+ results[room_id] = bool(visibility) or results.get(room_id, False)
- def get_public_room_changes(self, prev_stream_id, new_stream_id):
+ return results
+
+ def get_public_room_changes(self, prev_stream_id, new_stream_id,
+ network_tuple):
def get_public_room_changes_txn(txn):
- then_rooms = self.get_public_room_ids_at_stream_id_txn(txn, prev_stream_id)
+ then_rooms = self.get_public_room_ids_at_stream_id_txn(
+ txn, prev_stream_id, network_tuple
+ )
- now_rooms_dict = self.get_published_at_stream_id_txn(txn, new_stream_id)
+ now_rooms_dict = self.get_published_at_stream_id_txn(
+ txn, new_stream_id, network_tuple
+ )
now_rooms_visible = set(
rm for rm, vis in now_rooms_dict.items() if vis
@@ -311,7 +457,8 @@ class RoomStore(SQLBaseStore):
def get_all_new_public_rooms(self, prev_id, current_id, limit):
def get_all_new_public_rooms(txn):
sql = ("""
- SELECT stream_id, room_id, visibility FROM public_room_list_stream
+ SELECT stream_id, room_id, visibility, appservice_id, network_id
+ FROM public_room_list_stream
WHERE stream_id > ? AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 866d64e6..946d5a81 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -24,6 +24,7 @@ from synapse.api.constants import Membership, EventTypes
from synapse.types import get_domain_from_id
import logging
+import ujson as json
logger = logging.getLogger(__name__)
@@ -34,7 +35,15 @@ RoomsForUser = namedtuple(
)
+_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
+
+
class RoomMemberStore(SQLBaseStore):
+ def __init__(self, hs):
+ super(RoomMemberStore, self).__init__(hs)
+ self.register_background_update_handler(
+ _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
+ )
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
@@ -49,6 +58,8 @@ class RoomMemberStore(SQLBaseStore):
"sender": event.user_id,
"room_id": event.room_id,
"membership": event.membership,
+ "display_name": event.content.get("displayname", None),
+ "avatar_url": event.content.get("avatar_url", None),
}
for event in events
]
@@ -398,7 +409,7 @@ class RoomMemberStore(SQLBaseStore):
table="room_memberships",
column="event_id",
iterable=member_event_ids,
- retcols=['user_id'],
+ retcols=['user_id', 'display_name', 'avatar_url'],
keyvalues={
"membership": Membership.JOIN,
},
@@ -406,11 +417,21 @@ class RoomMemberStore(SQLBaseStore):
desc="_get_joined_users_from_context",
)
- users_in_room = set(row["user_id"] for row in rows)
+ users_in_room = {
+ row["user_id"]: {
+ "display_name": row["display_name"],
+ "avatar_url": row["avatar_url"],
+ }
+ for row in rows
+ }
+
if event is not None and event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
if event.event_id in member_event_ids:
- users_in_room.add(event.state_key)
+ users_in_room[event.state_key] = {
+ "display_name": event.content.get("displayname", None),
+ "avatar_url": event.content.get("avatar_url", None),
+ }
defer.returnValue(users_in_room)
@@ -448,3 +469,78 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(True)
defer.returnValue(False)
+
+ @defer.inlineCallbacks
+ def _background_add_membership_profile(self, progress, batch_size):
+ target_min_stream_id = progress.get(
+ "target_min_stream_id_inclusive", self._min_stream_order_on_start
+ )
+ max_stream_id = progress.get(
+ "max_stream_id_exclusive", self._stream_order_on_start + 1
+ )
+
+ INSERT_CLUMP_SIZE = 1000
+
+ def add_membership_profile_txn(txn):
+ sql = ("""
+ SELECT stream_ordering, event_id, events.room_id, content
+ FROM events
+ INNER JOIN room_memberships USING (event_id)
+ WHERE ? <= stream_ordering AND stream_ordering < ?
+ AND type = 'm.room.member'
+ ORDER BY stream_ordering DESC
+ LIMIT ?
+ """)
+
+ txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
+
+ rows = self.cursor_to_dict(txn)
+ if not rows:
+ return 0
+
+ min_stream_id = rows[-1]["stream_ordering"]
+
+ to_update = []
+ for row in rows:
+ event_id = row["event_id"]
+ room_id = row["room_id"]
+ try:
+ content = json.loads(row["content"])
+ except:
+ continue
+
+ display_name = content.get("displayname", None)
+ avatar_url = content.get("avatar_url", None)
+
+ if display_name or avatar_url:
+ to_update.append((
+ display_name, avatar_url, event_id, room_id
+ ))
+
+ to_update_sql = ("""
+ UPDATE room_memberships SET display_name = ?, avatar_url = ?
+ WHERE event_id = ? AND room_id = ?
+ """)
+ for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
+ clump = to_update[index:index + INSERT_CLUMP_SIZE]
+ txn.executemany(to_update_sql, clump)
+
+ progress = {
+ "target_min_stream_id_inclusive": target_min_stream_id,
+ "max_stream_id_exclusive": min_stream_id,
+ }
+
+ self._background_update_progress_txn(
+ txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress
+ )
+
+ return len(rows)
+
+ result = yield self.runInteraction(
+ _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
+ )
+
+ if not result:
+ yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
+
+ defer.returnValue(result)
diff --git a/synapse/storage/schema/delta/39/appservice_room_list.sql b/synapse/storage/schema/delta/39/appservice_room_list.sql
new file mode 100644
index 00000000..74bdc490
--- /dev/null
+++ b/synapse/storage/schema/delta/39/appservice_room_list.sql
@@ -0,0 +1,29 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE appservice_room_list(
+ appservice_id TEXT NOT NULL,
+ network_id TEXT NOT NULL,
+ room_id TEXT NOT NULL
+);
+
+-- Each appservice can have multiple published room lists associated with them,
+-- keyed of a particular network_id
+CREATE UNIQUE INDEX appservice_room_list_idx ON appservice_room_list(
+ appservice_id, network_id, room_id
+);
+
+ALTER TABLE public_room_list_stream ADD COLUMN appservice_id TEXT;
+ALTER TABLE public_room_list_stream ADD COLUMN network_id TEXT;
diff --git a/synapse/storage/schema/delta/39/device_federation_stream_idx.sql b/synapse/storage/schema/delta/39/device_federation_stream_idx.sql
new file mode 100644
index 00000000..00be801e
--- /dev/null
+++ b/synapse/storage/schema/delta/39/device_federation_stream_idx.sql
@@ -0,0 +1,16 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE INDEX device_federation_outbox_id ON device_federation_outbox(stream_id);
diff --git a/synapse/storage/schema/delta/39/event_push_index.sql b/synapse/storage/schema/delta/39/event_push_index.sql
new file mode 100644
index 00000000..de2ad93e
--- /dev/null
+++ b/synapse/storage/schema/delta/39/event_push_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('event_push_actions_highlights_index', '{}');
diff --git a/synapse/storage/schema/delta/39/federation_out_position.sql b/synapse/storage/schema/delta/39/federation_out_position.sql
new file mode 100644
index 00000000..5af81429
--- /dev/null
+++ b/synapse/storage/schema/delta/39/federation_out_position.sql
@@ -0,0 +1,22 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ CREATE TABLE federation_stream_position(
+ type TEXT NOT NULL,
+ stream_id INTEGER NOT NULL
+ );
+
+ INSERT INTO federation_stream_position (type, stream_id) VALUES ('federation', -1);
+ INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coalesce(max(stream_ordering), -1) FROM events;
diff --git a/synapse/storage/schema/delta/39/membership_profile.sql b/synapse/storage/schema/delta/39/membership_profile.sql
new file mode 100644
index 00000000..1bf911c8
--- /dev/null
+++ b/synapse/storage/schema/delta/39/membership_profile.sql
@@ -0,0 +1,20 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE room_memberships ADD COLUMN display_name TEXT;
+ALTER TABLE room_memberships ADD COLUMN avatar_url TEXT;
+
+INSERT into background_updates (update_name, progress_json)
+ VALUES ('room_membership_profile_update', '{}');
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 49abf0ac..23e7ad99 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -653,7 +653,10 @@ class StateStore(SQLBaseStore):
else:
state_dict = results[group]
- state_dict.update(group_state_dict)
+ state_dict.update({
+ (intern_string(k[0]), intern_string(k[1])): v
+ for k, v in group_state_dict.items()
+ })
self._state_group_cache.update(
cache_seq_num,
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 888b1cb3..2dc24951 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -541,6 +541,9 @@ class StreamStore(SQLBaseStore):
def get_room_max_stream_ordering(self):
return self._stream_id_gen.get_current_token()
+ def get_room_min_stream_ordering(self):
+ return self._backfill_id_gen.get_current_token()
+
def get_stream_token_for_event(self, event_id):
"""The stream token for an event
Args:
@@ -765,3 +768,50 @@ class StreamStore(SQLBaseStore):
"token": end_token,
},
}
+
+ @defer.inlineCallbacks
+ def get_all_new_events_stream(self, from_id, current_id, limit):
+ """Get all new events"""
+
+ def get_all_new_events_stream_txn(txn):
+ sql = (
+ "SELECT e.stream_ordering, e.event_id"
+ " FROM events AS e"
+ " WHERE"
+ " ? < e.stream_ordering AND e.stream_ordering <= ?"
+ " ORDER BY e.stream_ordering ASC"
+ " LIMIT ?"
+ )
+
+ txn.execute(sql, (from_id, current_id, limit))
+ rows = txn.fetchall()
+
+ upper_bound = current_id
+ if len(rows) == limit:
+ upper_bound = rows[-1][0]
+
+ return upper_bound, [row[1] for row in rows]
+
+ upper_bound, event_ids = yield self.runInteraction(
+ "get_all_new_events_stream", get_all_new_events_stream_txn,
+ )
+
+ events = yield self._get_events(event_ids)
+
+ defer.returnValue((upper_bound, events))
+
+ def get_federation_out_pos(self, typ):
+ return self._simple_select_one_onecol(
+ table="federation_stream_position",
+ retcol="stream_id",
+ keyvalues={"type": typ},
+ desc="get_federation_out_pos"
+ )
+
+ def update_federation_out_pos(self, typ, stream_id):
+ return self._simple_update_one(
+ table="federation_stream_position",
+ keyvalues={"type": typ},
+ updatevalues={"stream_id": stream_id},
+ desc="update_federation_out_pos",
+ )
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index adab520c..809fdd31 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -200,25 +200,48 @@ class TransactionStore(SQLBaseStore):
def _set_destination_retry_timings(self, txn, destination,
retry_last_ts, retry_interval):
- txn.call_after(self.get_destination_retry_timings.invalidate, (destination,))
+ self.database_engine.lock_table(txn, "destinations")
- self._simple_upsert_txn(
+ self._invalidate_cache_and_stream(
+ txn, self.get_destination_retry_timings, (destination,)
+ )
+
+ # We need to be careful here as the data may have changed from under us
+ # due to a worker setting the timings.
+
+ prev_row = self._simple_select_one_txn(
txn,
- "destinations",
+ table="destinations",
keyvalues={
"destination": destination,
},
- values={
- "retry_last_ts": retry_last_ts,
- "retry_interval": retry_interval,
- },
- insertion_values={
- "destination": destination,
- "retry_last_ts": retry_last_ts,
- "retry_interval": retry_interval,
- }
+ retcols=("retry_last_ts", "retry_interval"),
+ allow_none=True,
)
+ if not prev_row:
+ self._simple_insert_txn(
+ txn,
+ table="destinations",
+ values={
+ "destination": destination,
+ "retry_last_ts": retry_last_ts,
+ "retry_interval": retry_interval,
+ }
+ )
+ elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
+ self._simple_update_one_txn(
+ txn,
+ "destinations",
+ keyvalues={
+ "destination": destination,
+ },
+ updatevalues={
+ "retry_last_ts": retry_last_ts,
+ "retry_interval": retry_interval,
+ },
+ )
+
def get_destinations_needing_retry(self):
"""Get all destinations which are due a retry for sending a transaction.
diff --git a/synapse/types.py b/synapse/types.py
index ffab12df..3a3ab21d 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -274,3 +274,37 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
return "t%d-%d" % (self.topological, self.stream)
else:
return "s%d" % (self.stream,)
+
+
+class ThirdPartyInstanceID(
+ namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id"))
+):
+ # Deny iteration because it will bite you if you try to create a singleton
+ # set by:
+ # users = set(user)
+ def __iter__(self):
+ raise ValueError("Attempted to iterate a %s" % (type(self).__name__,))
+
+ # Because this class is a namedtuple of strings, it is deeply immutable.
+ def __copy__(self):
+ return self
+
+ def __deepcopy__(self, memo):
+ return self
+
+ @classmethod
+ def from_string(cls, s):
+ bits = s.split("|", 2)
+ if len(bits) != 2:
+ raise SynapseError(400, "Invalid ID %r" % (s,))
+
+ return cls(appservice_id=bits[0], network_id=bits[1])
+
+ def to_string(self):
+ return "%s|%s" % (self.appservice_id, self.network_id,)
+
+ __str__ = to_string
+
+ @classmethod
+ def create(cls, appservice_id, network_id,):
+ return cls(appservice_id=appservice_id, network_id=network_id)
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index c05b9450..30fc4801 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -24,6 +24,11 @@ import logging
logger = logging.getLogger(__name__)
+class DeferredTimedOutError(SynapseError):
+ def __init__(self):
+ super(SynapseError).__init__(504, "Timed out")
+
+
def unwrapFirstError(failure):
# defer.gatherResults and DeferredLists wrap failures.
failure.trap(defer.FirstError)
@@ -89,7 +94,7 @@ class Clock(object):
def timed_out_fn():
try:
- ret_deferred.errback(SynapseError(504, "Timed out"))
+ ret_deferred.errback(DeferredTimedOutError())
except:
pass
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 347fb1e3..16ed183d 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -197,6 +197,64 @@ class Linearizer(object):
defer.returnValue(_ctx_manager())
+class Limiter(object):
+ """Limits concurrent access to resources based on a key. Useful to ensure
+ only a few thing happen at a time on a given resource.
+
+ Example:
+
+ with (yield limiter.queue("test_key")):
+ # do some work.
+
+ """
+ def __init__(self, max_count):
+ """
+ Args:
+ max_count(int): The maximum number of concurrent access
+ """
+ self.max_count = max_count
+
+ # key_to_defer is a map from the key to a 2 element list where
+ # the first element is the number of things executing
+ # the second element is a list of deferreds for the things blocked from
+ # executing.
+ self.key_to_defer = {}
+
+ @defer.inlineCallbacks
+ def queue(self, key):
+ entry = self.key_to_defer.setdefault(key, [0, []])
+
+ # If the number of things executing is greater than the maximum
+ # then add a deferred to the list of blocked items
+ # When on of the things currently executing finishes it will callback
+ # this item so that it can continue executing.
+ if entry[0] >= self.max_count:
+ new_defer = defer.Deferred()
+ entry[1].append(new_defer)
+ with PreserveLoggingContext():
+ yield new_defer
+
+ entry[0] += 1
+
+ @contextmanager
+ def _ctx_manager():
+ try:
+ yield
+ finally:
+ # We've finished executing so check if there are any things
+ # blocked waiting to execute and start one of them
+ entry[0] -= 1
+ try:
+ entry[1].pop(0).callback(None)
+ except IndexError:
+ # If nothing else is executing for this key then remove it
+ # from the map
+ if entry[0] == 0:
+ self.key_to_defer.pop(key, None)
+
+ defer.returnValue(_ctx_manager())
+
+
class ReadWriteLock(object):
"""A deferred style read write lock.
diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py
index 3fd5c3d9..d668e5a6 100644
--- a/synapse/util/jsonobject.py
+++ b/synapse/util/jsonobject.py
@@ -76,15 +76,26 @@ class JsonEncodedObject(object):
d.update(self.unrecognized_keys)
return d
+ def get_internal_dict(self):
+ d = {
+ k: _encode(v, internal=True) for (k, v) in self.__dict__.items()
+ if k in self.valid_keys
+ }
+ d.update(self.unrecognized_keys)
+ return d
+
def __str__(self):
return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))
-def _encode(obj):
+def _encode(obj, internal=False):
if type(obj) is list:
- return [_encode(o) for o in obj]
+ return [_encode(o, internal=internal) for o in obj]
if isinstance(obj, JsonEncodedObject):
- return obj.get_dict()
+ if internal:
+ return obj.get_internal_dict()
+ else:
+ return obj.get_dict()
return obj
diff --git a/synapse/util/ldap_auth_provider.py b/synapse/util/ldap_auth_provider.py
deleted file mode 100644
index 1b989248..00000000
--- a/synapse/util/ldap_auth_provider.py
+++ /dev/null
@@ -1,369 +0,0 @@
-
-from twisted.internet import defer
-
-from synapse.config._base import ConfigError
-from synapse.types import UserID
-
-import ldap3
-import ldap3.core.exceptions
-
-import logging
-
-try:
- import ldap3
- import ldap3.core.exceptions
-except ImportError:
- ldap3 = None
- pass
-
-
-logger = logging.getLogger(__name__)
-
-
-class LDAPMode(object):
- SIMPLE = "simple",
- SEARCH = "search",
-
- LIST = (SIMPLE, SEARCH)
-
-
-class LdapAuthProvider(object):
- __version__ = "0.1"
-
- def __init__(self, config, account_handler):
- self.account_handler = account_handler
-
- if not ldap3:
- raise RuntimeError(
- 'Missing ldap3 library. This is required for LDAP Authentication.'
- )
-
- self.ldap_mode = config.mode
- self.ldap_uri = config.uri
- self.ldap_start_tls = config.start_tls
- self.ldap_base = config.base
- self.ldap_attributes = config.attributes
- if self.ldap_mode == LDAPMode.SEARCH:
- self.ldap_bind_dn = config.bind_dn
- self.ldap_bind_password = config.bind_password
- self.ldap_filter = config.filter
-
- @defer.inlineCallbacks
- def check_password(self, user_id, password):
- """ Attempt to authenticate a user against an LDAP Server
- and register an account if none exists.
-
- Returns:
- True if authentication against LDAP was successful
- """
- localpart = UserID.from_string(user_id).localpart
-
- try:
- server = ldap3.Server(self.ldap_uri)
- logger.debug(
- "Attempting LDAP connection with %s",
- self.ldap_uri
- )
-
- if self.ldap_mode == LDAPMode.SIMPLE:
- result, conn = self._ldap_simple_bind(
- server=server, localpart=localpart, password=password
- )
- logger.debug(
- 'LDAP authentication method simple bind returned: %s (conn: %s)',
- result,
- conn
- )
- if not result:
- defer.returnValue(False)
- elif self.ldap_mode == LDAPMode.SEARCH:
- result, conn = self._ldap_authenticated_search(
- server=server, localpart=localpart, password=password
- )
- logger.debug(
- 'LDAP auth method authenticated search returned: %s (conn: %s)',
- result,
- conn
- )
- if not result:
- defer.returnValue(False)
- else:
- raise RuntimeError(
- 'Invalid LDAP mode specified: {mode}'.format(
- mode=self.ldap_mode
- )
- )
-
- try:
- logger.info(
- "User authenticated against LDAP server: %s",
- conn
- )
- except NameError:
- logger.warn(
- "Authentication method yielded no LDAP connection, aborting!"
- )
- defer.returnValue(False)
-
- # check if user with user_id exists
- if (yield self.account_handler.check_user_exists(user_id)):
- # exists, authentication complete
- conn.unbind()
- defer.returnValue(True)
-
- else:
- # does not exist, fetch metadata for account creation from
- # existing ldap connection
- query = "({prop}={value})".format(
- prop=self.ldap_attributes['uid'],
- value=localpart
- )
-
- if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter:
- query = "(&{filter}{user_filter})".format(
- filter=query,
- user_filter=self.ldap_filter
- )
- logger.debug(
- "ldap registration filter: %s",
- query
- )
-
- conn.search(
- search_base=self.ldap_base,
- search_filter=query,
- attributes=[
- self.ldap_attributes['name'],
- self.ldap_attributes['mail']
- ]
- )
-
- if len(conn.response) == 1:
- attrs = conn.response[0]['attributes']
- mail = attrs[self.ldap_attributes['mail']][0]
- name = attrs[self.ldap_attributes['name']][0]
-
- # create account
- user_id, access_token = (
- yield self.account_handler.register(localpart=localpart)
- )
-
- # TODO: bind email, set displayname with data from ldap directory
-
- logger.info(
- "Registration based on LDAP data was successful: %d: %s (%s, %)",
- user_id,
- localpart,
- name,
- mail
- )
-
- defer.returnValue(True)
- else:
- if len(conn.response) == 0:
- logger.warn("LDAP registration failed, no result.")
- else:
- logger.warn(
- "LDAP registration failed, too many results (%s)",
- len(conn.response)
- )
-
- defer.returnValue(False)
-
- defer.returnValue(False)
-
- except ldap3.core.exceptions.LDAPException as e:
- logger.warn("Error during ldap authentication: %s", e)
- defer.returnValue(False)
-
- @staticmethod
- def parse_config(config):
- class _LdapConfig(object):
- pass
-
- ldap_config = _LdapConfig()
-
- ldap_config.enabled = config.get("enabled", False)
-
- ldap_config.mode = LDAPMode.SIMPLE
-
- # verify config sanity
- _require_keys(config, [
- "uri",
- "base",
- "attributes",
- ])
-
- ldap_config.uri = config["uri"]
- ldap_config.start_tls = config.get("start_tls", False)
- ldap_config.base = config["base"]
- ldap_config.attributes = config["attributes"]
-
- if "bind_dn" in config:
- ldap_config.mode = LDAPMode.SEARCH
- _require_keys(config, [
- "bind_dn",
- "bind_password",
- ])
-
- ldap_config.bind_dn = config["bind_dn"]
- ldap_config.bind_password = config["bind_password"]
- ldap_config.filter = config.get("filter", None)
-
- # verify attribute lookup
- _require_keys(config['attributes'], [
- "uid",
- "name",
- "mail",
- ])
-
- return ldap_config
-
- def _ldap_simple_bind(self, server, localpart, password):
- """ Attempt a simple bind with the credentials
- given by the user against the LDAP server.
-
- Returns True, LDAP3Connection
- if the bind was successful
- Returns False, None
- if an error occured
- """
-
- try:
- # bind with the the local users ldap credentials
- bind_dn = "{prop}={value},{base}".format(
- prop=self.ldap_attributes['uid'],
- value=localpart,
- base=self.ldap_base
- )
- conn = ldap3.Connection(server, bind_dn, password,
- authentication=ldap3.AUTH_SIMPLE)
- logger.debug(
- "Established LDAP connection in simple bind mode: %s",
- conn
- )
-
- if self.ldap_start_tls:
- conn.start_tls()
- logger.debug(
- "Upgraded LDAP connection in simple bind mode through StartTLS: %s",
- conn
- )
-
- if conn.bind():
- # GOOD: bind okay
- logger.debug("LDAP Bind successful in simple bind mode.")
- return True, conn
-
- # BAD: bind failed
- logger.info(
- "Binding against LDAP failed for '%s' failed: %s",
- localpart, conn.result['description']
- )
- conn.unbind()
- return False, None
-
- except ldap3.core.exceptions.LDAPException as e:
- logger.warn("Error during LDAP authentication: %s", e)
- return False, None
-
- def _ldap_authenticated_search(self, server, localpart, password):
- """ Attempt to login with the preconfigured bind_dn
- and then continue searching and filtering within
- the base_dn
-
- Returns (True, LDAP3Connection)
- if a single matching DN within the base was found
- that matched the filter expression, and with which
- a successful bind was achieved
-
- The LDAP3Connection returned is the instance that was used to
- verify the password not the one using the configured bind_dn.
- Returns (False, None)
- if an error occured
- """
-
- try:
- conn = ldap3.Connection(
- server,
- self.ldap_bind_dn,
- self.ldap_bind_password
- )
- logger.debug(
- "Established LDAP connection in search mode: %s",
- conn
- )
-
- if self.ldap_start_tls:
- conn.start_tls()
- logger.debug(
- "Upgraded LDAP connection in search mode through StartTLS: %s",
- conn
- )
-
- if not conn.bind():
- logger.warn(
- "Binding against LDAP with `bind_dn` failed: %s",
- conn.result['description']
- )
- conn.unbind()
- return False, None
-
- # construct search_filter like (uid=localpart)
- query = "({prop}={value})".format(
- prop=self.ldap_attributes['uid'],
- value=localpart
- )
- if self.ldap_filter:
- # combine with the AND expression
- query = "(&{query}{filter})".format(
- query=query,
- filter=self.ldap_filter
- )
- logger.debug(
- "LDAP search filter: %s",
- query
- )
- conn.search(
- search_base=self.ldap_base,
- search_filter=query
- )
-
- if len(conn.response) == 1:
- # GOOD: found exactly one result
- user_dn = conn.response[0]['dn']
- logger.debug('LDAP search found dn: %s', user_dn)
-
- # unbind and simple bind with user_dn to verify the password
- # Note: do not use rebind(), for some reason it did not verify
- # the password for me!
- conn.unbind()
- return self._ldap_simple_bind(server, localpart, password)
- else:
- # BAD: found 0 or > 1 results, abort!
- if len(conn.response) == 0:
- logger.info(
- "LDAP search returned no results for '%s'",
- localpart
- )
- else:
- logger.info(
- "LDAP search returned too many (%s) results for '%s'",
- len(conn.response), localpart
- )
- conn.unbind()
- return False, None
-
- except ldap3.core.exceptions.LDAPException as e:
- logger.warn("Error during LDAP authentication: %s", e)
- return False, None
-
-
-def _require_keys(config, required):
- missing = [key for key in required if key not in config]
- if missing:
- raise ConfigError(
- "LDAP enabled but missing required config values: {}".format(
- ", ".join(missing)
- )
- )
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 49527f4d..e2de7fce 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -121,15 +121,9 @@ class RetryDestinationLimiter(object):
pass
def __exit__(self, exc_type, exc_val, exc_tb):
- def err(failure):
- logger.exception(
- "Failed to store set_destination_retry_timings",
- failure.value
- )
-
valid_err_code = False
if exc_type is not None and issubclass(exc_type, CodeMessageException):
- valid_err_code = 0 <= exc_val.code < 500
+ valid_err_code = exc_val.code != 429 and 0 <= exc_val.code < 500
if exc_type is None or valid_err_code:
# We connected successfully.
@@ -151,6 +145,15 @@ class RetryDestinationLimiter(object):
retry_last_ts = int(self.clock.time_msec())
- self.store.set_destination_retry_timings(
- self.destination, retry_last_ts, self.retry_interval
- ).addErrback(err)
+ @defer.inlineCallbacks
+ def store_retry_timings():
+ try:
+ yield self.store.set_destination_retry_timings(
+ self.destination, retry_last_ts, self.retry_interval
+ )
+ except:
+ logger.exception(
+ "Failed to store set_destination_retry_timings",
+ )
+
+ store_retry_timings()