summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES.rst47
-rw-r--r--README.rst2
-rw-r--r--docs/workers.rst97
-rwxr-xr-xjenkins-unittests.sh1
-rwxr-xr-xjenkins/prepare_synapse.sh1
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/auth.py45
-rw-r--r--synapse/app/appservice.py209
-rwxr-xr-xsynapse/app/homeserver.py28
-rw-r--r--synapse/app/media_repository.py212
-rw-r--r--synapse/app/pusher.py16
-rw-r--r--synapse/app/synchrotron.py54
-rw-r--r--synapse/appservice/__init__.py90
-rw-r--r--synapse/appservice/api.py60
-rw-r--r--synapse/appservice/scheduler.py67
-rw-r--r--synapse/config/appservice.py11
-rw-r--r--synapse/crypto/keyring.py166
-rw-r--r--synapse/events/utils.py2
-rw-r--r--synapse/federation/federation_base.py7
-rw-r--r--synapse/federation/federation_client.py64
-rw-r--r--synapse/federation/transaction_queue.py381
-rw-r--r--synapse/handlers/__init__.py3
-rw-r--r--synapse/handlers/appservice.py136
-rw-r--r--synapse/handlers/auth.py31
-rw-r--r--synapse/handlers/federation.py68
-rw-r--r--synapse/handlers/message.py35
-rw-r--r--synapse/handlers/presence.py36
-rw-r--r--synapse/handlers/room_member.py16
-rw-r--r--synapse/handlers/sync.py14
-rw-r--r--synapse/handlers/typing.py12
-rw-r--r--synapse/http/matrixfederationclient.py4
-rw-r--r--synapse/http/server.py138
-rw-r--r--synapse/notifier.py80
-rw-r--r--synapse/push/action_generator.py11
-rw-r--r--synapse/push/baserules.py36
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py41
-rw-r--r--synapse/push/push_tools.py9
-rw-r--r--synapse/push/pusherpool.py20
-rw-r--r--synapse/replication/resource.py21
-rw-r--r--synapse/replication/slave/storage/_base.py30
-rw-r--r--synapse/replication/slave/storage/appservice.py10
-rw-r--r--synapse/replication/slave/storage/directory.py2
-rw-r--r--synapse/replication/slave/storage/registration.py5
-rw-r--r--synapse/rest/__init__.py4
-rw-r--r--synapse/rest/client/v1/admin.py8
-rw-r--r--synapse/rest/client/v1/base.py1
-rw-r--r--synapse/rest/client/v1/directory.py5
-rw-r--r--synapse/rest/client/v1/events.py11
-rw-r--r--synapse/rest/client/v1/initial_sync.py4
-rw-r--r--synapse/rest/client/v1/login.py162
-rw-r--r--synapse/rest/client/v1/profile.py12
-rw-r--r--synapse/rest/client/v1/register.py2
-rw-r--r--synapse/rest/client/v1/room.py49
-rw-r--r--synapse/rest/client/v2_alpha/notifications.py99
-rw-r--r--synapse/rest/client/v2_alpha/register.py3
-rw-r--r--synapse/rest/client/v2_alpha/sync.py2
-rw-r--r--synapse/rest/client/v2_alpha/thirdparty.py78
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py4
-rw-r--r--synapse/rest/media/v1/download_resource.py1
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py354
-rw-r--r--synapse/server.py9
-rw-r--r--synapse/server.pyi4
-rw-r--r--synapse/storage/__init__.py8
-rw-r--r--synapse/storage/_base.py74
-rw-r--r--synapse/storage/appservice.py191
-rw-r--r--synapse/storage/directory.py37
-rw-r--r--synapse/storage/event_push_actions.py32
-rw-r--r--synapse/storage/events.py101
-rw-r--r--synapse/storage/prepare_database.py2
-rw-r--r--synapse/storage/presence.py32
-rw-r--r--synapse/storage/push_rule.py60
-rw-r--r--synapse/storage/pusher.py2
-rw-r--r--synapse/storage/receipts.py27
-rw-r--r--synapse/storage/registration.py103
-rw-r--r--synapse/storage/roommember.py12
-rw-r--r--synapse/storage/schema/delta/34/appservice_stream.sql23
-rw-r--r--synapse/storage/schema/delta/34/cache_stream.py46
-rw-r--r--synapse/storage/schema/delta/34/push_display_name_rename.sql20
-rw-r--r--synapse/storage/schema/delta/34/received_txn_purge.py32
-rw-r--r--synapse/storage/signatures.py2
-rw-r--r--synapse/storage/state.py4
-rw-r--r--synapse/storage/stream.py6
-rw-r--r--synapse/storage/transactions.py17
-rw-r--r--synapse/types.py7
-rw-r--r--synapse/util/async.py9
-rw-r--r--synapse/util/caches/descriptors.py117
-rw-r--r--synapse/util/caches/lrucache.py39
-rw-r--r--synapse/util/caches/treecache.py3
-rw-r--r--synapse/util/logcontext.py16
-rw-r--r--synapse/util/metrics.py19
-rw-r--r--synapse/visibility.py6
-rw-r--r--tests/appservice/test_appservice.py109
-rw-r--r--tests/appservice/test_scheduler.py2
-rw-r--r--tests/handlers/test_appservice.py24
-rw-r--r--tests/handlers/test_auth.py52
-rw-r--r--tests/storage/test__base.py116
-rw-r--r--tests/test_preview.py80
-rw-r--r--tests/util/test_lrucache.py153
98 files changed, 3281 insertions, 1434 deletions
diff --git a/CHANGES.rst b/CHANGES.rst
index 7ebb42b0..49673ccc 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -1,3 +1,50 @@
+Changes in synapse v0.17.1 (2016-08-24)
+=======================================
+
+Changes:
+
+* Delete old received_transactions rows (PR #1038)
+* Pass through user-supplied content in /join/$room_id (PR #1039)
+
+
+Bug fixes:
+
+* Fix bug with backfill (PR #1040)
+
+
+Changes in synapse v0.17.1-rc1 (2016-08-22)
+===========================================
+
+Features:
+
+* Add notification API (PR #1028)
+
+
+Changes:
+
+* Don't print stack traces when failing to get remote keys (PR #996)
+* Various federation /event/ perf improvements (PR #998)
+* Only process one local membership event per room at a time (PR #1005)
+* Move default display name push rule (PR #1011, #1023)
+* Fix up preview URL API. Add tests. (PR #1015)
+* Set ``Content-Security-Policy`` on media repo (PR #1021)
+* Make notify_interested_services faster (PR #1022)
+* Add usage stats to prometheus monitoring (PR #1037)
+
+
+Bug fixes:
+
+* Fix token login (PR #993)
+* Fix CAS login (PR #994, #995)
+* Fix /sync to not clobber status_msg (PR #997)
+* Fix redacted state events to include prev_content (PR #1003)
+* Fix some bugs in the auth/ldap handler (PR #1007)
+* Fix backfill request to limit URI length, so that remotes don't reject the
+ requests due to path length limits (PR #1012)
+* Fix AS push code to not send duplicate events (PR #1025)
+
+
+
Changes in synapse v0.17.0 (2016-08-08)
=======================================
diff --git a/README.rst b/README.rst
index d6586708..172dd4df 100644
--- a/README.rst
+++ b/README.rst
@@ -95,7 +95,7 @@ Synapse is the reference python/twisted Matrix homeserver implementation.
System requirements:
- POSIX-compliant system (tested on Linux & OS X)
- Python 2.7
-- At least 512 MB RAM.
+- At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org
Synapse is written in python but some of the libraries is uses are written in
C. So before we can install synapse itself we need a working C compiler and the
diff --git a/docs/workers.rst b/docs/workers.rst
new file mode 100644
index 00000000..4eb05b0e
--- /dev/null
+++ b/docs/workers.rst
@@ -0,0 +1,97 @@
+Scaling synapse via workers
+---------------------------
+
+Synapse has experimental support for splitting out functionality into
+multiple separate python processes, helping greatly with scalability. These
+processes are called 'workers', and are (eventually) intended to scale
+horizontally independently.
+
+All processes continue to share the same database instance, and as such, workers
+only work with postgres based synapse deployments (sharing a single sqlite
+across multiple processes is a recipe for disaster, plus you should be using
+postgres anyway if you care about scalability).
+
+The workers communicate with the master synapse process via a synapse-specific
+HTTP protocol called 'replication' - analogous to MySQL or Postgres style
+database replication; feeding a stream of relevant data to the workers so they
+can be kept in sync with the main synapse process and database state.
+
+To enable workers, you need to add a replication listener to the master synapse, e.g.::
+
+ listeners:
+ - port: 9092
+ bind_address: '127.0.0.1'
+ type: http
+ tls: false
+ x_forwarded: false
+ resources:
+ - names: [replication]
+ compress: false
+
+Under **no circumstances** should this replication API listener be exposed to the
+public internet; it currently implements no authentication whatsoever and is
+unencrypted HTTP.
+
+You then create a set of configs for the various worker processes. These should be
+worker configuration files should be stored in a dedicated subdirectory, to allow
+synctl to manipulate them.
+
+The current available worker applications are:
+ * synapse.app.pusher - handles sending push notifications to sygnal and email
+ * synapse.app.synchrotron - handles /sync endpoints. can scales horizontally through multiple instances.
+ * synapse.app.appservice - handles output traffic to Application Services
+ * synapse.app.federation_reader - handles receiving federation traffic (including public_rooms API)
+ * synapse.app.media_repository - handles the media repository.
+
+Each worker configuration file inherits the configuration of the main homeserver
+configuration file. You can then override configuration specific to that worker,
+e.g. the HTTP listener that it provides (if any); logging configuration; etc.
+You should minimise the number of overrides though to maintain a usable config.
+
+You must specify the type of worker application (worker_app) and the replication
+endpoint that it's talking to on the main synapse process (worker_replication_url).
+
+For instance::
+
+ worker_app: synapse.app.synchrotron
+
+ # The replication listener on the synapse to talk to.
+ worker_replication_url: http://127.0.0.1:9092/_synapse/replication
+
+ worker_listeners:
+ - type: http
+ port: 8083
+ resources:
+ - names:
+ - client
+
+ worker_daemonize: True
+ worker_pid_file: /home/matrix/synapse/synchrotron.pid
+ worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml
+
+...is a full configuration for a synchrotron worker instance, which will expose a
+plain HTTP /sync endpoint on port 8083 separately from the /sync endpoint provided
+by the main synapse.
+
+Obviously you should configure your loadbalancer to route the /sync endpoint to
+the synchrotron instance(s) in this instance.
+
+Finally, to actually run your worker-based synapse, you must pass synctl the -a
+commandline option to tell it to operate on all the worker configurations found
+in the given directory, e.g.::
+
+ synctl -a $CONFIG/workers start
+
+Currently one should always restart all workers when restarting or upgrading
+synapse, unless you explicitly know it's safe not to. For instance, restarting
+synapse without restarting all the synchrotrons may result in broken typing
+notifications.
+
+To manipulate a specific worker, you pass the -w option to synctl::
+
+ synctl -w $CONFIG/workers/synchrotron.yaml restart
+
+All of the above is highly experimental and subject to change as Synapse evolves,
+but documenting it here to help folks needing highly scalable Synapses similar
+to the one running matrix.org!
+
diff --git a/jenkins-unittests.sh b/jenkins-unittests.sh
index 6b0c296c..4c2f103e 100755
--- a/jenkins-unittests.sh
+++ b/jenkins-unittests.sh
@@ -25,5 +25,6 @@ rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
+$TOX_BIN/pip install lxml
tox -e py27
diff --git a/jenkins/prepare_synapse.sh b/jenkins/prepare_synapse.sh
index 237223c8..6c26c584 100755
--- a/jenkins/prepare_synapse.sh
+++ b/jenkins/prepare_synapse.sh
@@ -14,6 +14,7 @@ fi
tox -e py27 --notest -v
TOX_BIN=$TOX_DIR/py27/bin
+$TOX_BIN/pip install setuptools
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install lxml
$TOX_BIN/pip install psycopg2
diff --git a/synapse/__init__.py b/synapse/__init__.py
index a63ee565..43bf78f8 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.17.0"
+__version__ = "0.17.1"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 59db76de..0db26fcf 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -675,27 +675,18 @@ class Auth(object):
try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
- user_prefix = "user_id = "
- user = None
- user_id = None
- guest = False
- for caveat in macaroon.caveats:
- if caveat.caveat_id.startswith(user_prefix):
- user_id = caveat.caveat_id[len(user_prefix):]
- user = UserID.from_string(user_id)
- elif caveat.caveat_id == "guest = true":
- guest = True
+ user_id = self.get_user_id_from_macaroon(macaroon)
+ user = UserID.from_string(user_id)
self.validate_macaroon(
macaroon, rights, self.hs.config.expire_access_token,
user_id=user_id,
)
- if user is None:
- raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
- errcode=Codes.UNKNOWN_TOKEN
- )
+ guest = False
+ for caveat in macaroon.caveats:
+ if caveat.caveat_id == "guest = true":
+ guest = True
if guest:
ret = {
@@ -743,6 +734,29 @@ class Auth(object):
errcode=Codes.UNKNOWN_TOKEN
)
+ def get_user_id_from_macaroon(self, macaroon):
+ """Retrieve the user_id given by the caveats on the macaroon.
+
+ Does *not* validate the macaroon.
+
+ Args:
+ macaroon (pymacaroons.Macaroon): The macaroon to validate
+
+ Returns:
+ (str) user id
+
+ Raises:
+ AuthError if there is no user_id caveat in the macaroon
+ """
+ user_prefix = "user_id = "
+ for caveat in macaroon.caveats:
+ if caveat.caveat_id.startswith(user_prefix):
+ return caveat.caveat_id[len(user_prefix):]
+ raise AuthError(
+ self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
+ errcode=Codes.UNKNOWN_TOKEN
+ )
+
def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
"""
validate that a Macaroon is understood by and was signed by this server.
@@ -754,6 +768,7 @@ class Auth(object):
verify_expiry(bool): Whether to verify whether the macaroon has expired.
This should really always be True, but no clients currently implement
token refresh, so we can't enforce expiry yet.
+ user_id (str): The user_id required
"""
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
new file mode 100644
index 00000000..57587aed
--- /dev/null
+++ b/synapse/app/appservice.py
@@ -0,0 +1,209 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import synapse
+
+from synapse.server import HomeServer
+from synapse.config._base import ConfigError
+from synapse.config.logger import setup_logging
+from synapse.config.homeserver import HomeServerConfig
+from synapse.http.site import SynapseSite
+from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.replication.slave.storage.directory import DirectoryStore
+from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
+from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.storage.engines import create_engine
+from synapse.util.async import sleep
+from synapse.util.httpresourcetree import create_resource_tree
+from synapse.util.logcontext import LoggingContext
+from synapse.util.manhole import manhole
+from synapse.util.rlimit import change_resource_limit
+from synapse.util.versionstring import get_version_string
+
+from twisted.internet import reactor, defer
+from twisted.web.resource import Resource
+
+from daemonize import Daemonize
+
+import sys
+import logging
+import gc
+
+logger = logging.getLogger("synapse.app.appservice")
+
+
+class AppserviceSlaveStore(
+ DirectoryStore, SlavedEventStore, SlavedApplicationServiceStore,
+ SlavedRegistrationStore,
+):
+ pass
+
+
+class AppserviceServer(HomeServer):
+ def get_db_conn(self, run_new_connection=True):
+ # Any param beginning with cp_ is a parameter for adbapi, and should
+ # not be passed to the database engine.
+ db_params = {
+ k: v for k, v in self.db_config.get("args", {}).items()
+ if not k.startswith("cp_")
+ }
+ db_conn = self.database_engine.module.connect(**db_params)
+
+ if run_new_connection:
+ self.database_engine.on_new_connection(db_conn)
+ return db_conn
+
+ def setup(self):
+ logger.info("Setting up.")
+ self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
+ logger.info("Finished setting up.")
+
+ def _listen_http(self, listener_config):
+ port = listener_config["port"]
+ bind_address = listener_config.get("bind_address", "")
+ site_tag = listener_config.get("tag", port)
+ resources = {}
+ for res in listener_config["resources"]:
+ for name in res["names"]:
+ if name == "metrics":
+ resources[METRICS_PREFIX] = MetricsResource(self)
+
+ root_resource = create_resource_tree(resources, Resource())
+ reactor.listenTCP(
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
+ ),
+ interface=bind_address
+ )
+ logger.info("Synapse appservice now listening on port %d", port)
+
+ def start_listening(self, listeners):
+ for listener in listeners:
+ if listener["type"] == "http":
+ self._listen_http(listener)
+ elif listener["type"] == "manhole":
+ reactor.listenTCP(
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
+ ),
+ interface=listener.get("bind_address", '127.0.0.1')
+ )
+ else:
+ logger.warn("Unrecognized listener type: %s", listener["type"])
+
+ @defer.inlineCallbacks
+ def replicate(self):
+ http_client = self.get_simple_http_client()
+ store = self.get_datastore()
+ replication_url = self.config.worker_replication_url
+ appservice_handler = self.get_application_service_handler()
+
+ @defer.inlineCallbacks
+ def replicate(results):
+ stream = results.get("events")
+ if stream:
+ max_stream_id = stream["position"]
+ yield appservice_handler.notify_interested_services(max_stream_id)
+
+ while True:
+ try:
+ args = store.stream_positions()
+ args["timeout"] = 30000
+ result = yield http_client.get_json(replication_url, args=args)
+ yield store.process_replication(result)
+ replicate(result)
+ except:
+ logger.exception("Error replicating from %r", replication_url)
+ yield sleep(30)
+
+
+def start(config_options):
+ try:
+ config = HomeServerConfig.load_config(
+ "Synapse appservice", config_options
+ )
+ except ConfigError as e:
+ sys.stderr.write("\n" + e.message + "\n")
+ sys.exit(1)
+
+ assert config.worker_app == "synapse.app.appservice"
+
+ setup_logging(config.worker_log_config, config.worker_log_file)
+
+ database_engine = create_engine(config.database_config)
+
+ if config.notify_appservices:
+ sys.stderr.write(
+ "\nThe appservices must be disabled in the main synapse process"
+ "\nbefore they can be run in a separate worker."
+ "\nPlease add ``notify_appservices: false`` to the main config"
+ "\n"
+ )
+ sys.exit(1)
+
+ # Force the pushers to start since they will be disabled in the main config
+ config.notify_appservices = True
+
+ ps = AppserviceServer(
+ config.server_name,
+ db_config=config.database_config,
+ config=config,
+ version_string="Synapse/" + get_version_string(synapse),
+ database_engine=database_engine,
+ )
+
+ ps.setup()
+ ps.start_listening(config.worker_listeners)
+
+ def run():
+ with LoggingContext("run"):
+ logger.info("Running")
+ change_resource_limit(config.soft_file_limit)
+ if config.gc_thresholds:
+ gc.set_threshold(*config.gc_thresholds)
+ reactor.run()
+
+ def start():
+ ps.replicate()
+ ps.get_datastore().start_profiling()
+
+ reactor.callWhenRunning(start)
+
+ if config.worker_daemonize:
+ daemon = Daemonize(
+ app="synapse-appservice",
+ pid=config.worker_pid_file,
+ action=run,
+ auto_close_fds=False,
+ verbose=True,
+ logger=logger,
+ )
+ daemon.start()
+ else:
+ run()
+
+
+if __name__ == '__main__':
+ with LoggingContext("main"):
+ start(sys.argv[1:])
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 40e6f652..54f35900 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -51,7 +51,7 @@ from synapse.api.urls import (
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext
-from synapse.metrics import register_memory_metrics
+from synapse.metrics import register_memory_metrics, get_metrics_for
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
from synapse.federation.transport.server import TransportLayerServer
@@ -385,6 +385,8 @@ def run(hs):
start_time = hs.get_clock().time()
+ stats = {}
+
@defer.inlineCallbacks
def phone_stats_home():
logger.info("Gathering stats for reporting")
@@ -393,7 +395,10 @@ def run(hs):
if uptime < 0:
uptime = 0
- stats = {}
+ # If the stats directory is empty then this is the first time we've
+ # reported stats.
+ first_time = not stats
+
stats["homeserver"] = hs.config.server_name
stats["timestamp"] = now
stats["uptime_seconds"] = uptime
@@ -406,6 +411,25 @@ def run(hs):
daily_messages = yield hs.get_datastore().count_daily_messages()
if daily_messages is not None:
stats["daily_messages"] = daily_messages
+ else:
+ stats.pop("daily_messages", None)
+
+ if first_time:
+ # Add callbacks to report the synapse stats as metrics whenever
+ # prometheus requests them, typically every 30s.
+ # As some of the stats are expensive to calculate we only update
+ # them when synapse phones home to matrix.org every 24 hours.
+ metrics = get_metrics_for("synapse.usage")
+ metrics.add_callback("timestamp", lambda: stats["timestamp"])
+ metrics.add_callback("uptime_seconds", lambda: stats["uptime_seconds"])
+ metrics.add_callback("total_users", lambda: stats["total_users"])
+ metrics.add_callback("total_room_count", lambda: stats["total_room_count"])
+ metrics.add_callback(
+ "daily_active_users", lambda: stats["daily_active_users"]
+ )
+ metrics.add_callback(
+ "daily_messages", lambda: stats.get("daily_messages", 0)
+ )
logger.info("Reporting stats to matrix.org: %s" % (stats,))
try:
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
new file mode 100644
index 00000000..9d4c4a07
--- /dev/null
+++ b/synapse/app/media_repository.py
@@ -0,0 +1,212 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import synapse
+
+from synapse.config._base import ConfigError
+from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
+from synapse.http.site import SynapseSite
+from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
+from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.rest.media.v0.content_repository import ContentRepoResource
+from synapse.rest.media.v1.media_repository import MediaRepositoryResource
+from synapse.server import HomeServer
+from synapse.storage.client_ips import ClientIpStore
+from synapse.storage.engines import create_engine
+from synapse.storage.media_repository import MediaRepositoryStore
+from synapse.util.async import sleep
+from synapse.util.httpresourcetree import create_resource_tree
+from synapse.util.logcontext import LoggingContext
+from synapse.util.manhole import manhole
+from synapse.util.rlimit import change_resource_limit
+from synapse.util.versionstring import get_version_string
+from synapse.api.urls import (
+ CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX
+)
+from synapse.crypto import context_factory
+
+
+from twisted.internet import reactor, defer
+from twisted.web.resource import Resource
+
+from daemonize import Daemonize
+
+import sys
+import logging
+import gc
+
+logger = logging.getLogger("synapse.app.media_repository")
+
+
+class MediaRepositorySlavedStore(
+ SlavedApplicationServiceStore,
+ SlavedRegistrationStore,
+ BaseSlavedStore,
+ MediaRepositoryStore,
+ ClientIpStore,
+):
+ pass
+
+
+class MediaRepositoryServer(HomeServer):
+ def get_db_conn(self, run_new_connection=True):
+ # Any param beginning with cp_ is a parameter for adbapi, and should
+ # not be passed to the database engine.
+ db_params = {
+ k: v for k, v in self.db_config.get("args", {}).items()
+ if not k.startswith("cp_")
+ }
+ db_conn = self.database_engine.module.connect(**db_params)
+
+ if run_new_connection:
+ self.database_engine.on_new_connection(db_conn)
+ return db_conn
+
+ def setup(self):
+ logger.info("Setting up.")
+ self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
+ logger.info("Finished setting up.")
+
+ def _listen_http(self, listener_config):
+ port = listener_config["port"]
+ bind_address = listener_config.get("bind_address", "")
+ site_tag = listener_config.get("tag", port)
+ resources = {}
+ for res in listener_config["resources"]:
+ for name in res["names"]:
+ if name == "metrics":
+ resources[METRICS_PREFIX] = MetricsResource(self)
+ elif name == "media":
+ media_repo = MediaRepositoryResource(self)
+ resources.update({
+ MEDIA_PREFIX: media_repo,
+ LEGACY_MEDIA_PREFIX: media_repo,
+ CONTENT_REPO_PREFIX: ContentRepoResource(
+ self, self.config.uploads_path
+ ),
+ })
+
+ root_resource = create_resource_tree(resources, Resource())
+ reactor.listenTCP(
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
+ ),
+ interface=bind_address
+ )
+ logger.info("Synapse media repository now listening on port %d", port)
+
+ def start_listening(self, listeners):
+ for listener in listeners:
+ if listener["type"] == "http":
+ self._listen_http(listener)
+ elif listener["type"] == "manhole":
+ reactor.listenTCP(
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
+ ),
+ interface=listener.get("bind_address", '127.0.0.1')
+ )
+ else:
+ logger.warn("Unrecognized listener type: %s", listener["type"])
+
+ @defer.inlineCallbacks
+ def replicate(self):
+ http_client = self.get_simple_http_client()
+ store = self.get_datastore()
+ replication_url = self.config.worker_replication_url
+
+ while True:
+ try:
+ args = store.stream_positions()
+ args["timeout"] = 30000
+ result = yield http_client.get_json(replication_url, args=args)
+ yield store.process_replication(result)
+ except:
+ logger.exception("Error replicating from %r", replication_url)
+ yield sleep(5)
+
+
+def start(config_options):
+ try:
+ config = HomeServerConfig.load_config(
+ "Synapse media repository", config_options
+ )
+ except ConfigError as e:
+ sys.stderr.write("\n" + e.message + "\n")
+ sys.exit(1)
+
+ assert config.worker_app == "synapse.app.media_repository"
+
+ setup_logging(config.worker_log_config, config.worker_log_file)
+
+ database_engine = create_engine(config.database_config)
+
+ tls_server_context_factory = context_factory.ServerContextFactory(config)
+
+ ss = MediaRepositoryServer(
+ config.server_name,
+ db_config=config.database_config,
+ tls_server_context_factory=tls_server_context_factory,
+ config=config,
+ version_string="Synapse/" + get_version_string(synapse),
+ database_engine=database_engine,
+ )
+
+ ss.setup()
+ ss.get_handlers()
+ ss.start_listening(config.worker_listeners)
+
+ def run():
+ with LoggingContext("run"):
+ logger.info("Running")
+ change_resource_limit(config.soft_file_limit)
+ if config.gc_thresholds:
+ gc.set_threshold(*config.gc_thresholds)
+ reactor.run()
+
+ def start():
+ ss.get_datastore().start_profiling()
+ ss.replicate()
+
+ reactor.callWhenRunning(start)
+
+ if config.worker_daemonize:
+ daemon = Daemonize(
+ app="synapse-media-repository",
+ pid=config.worker_pid_file,
+ action=run,
+ auto_close_fds=False,
+ verbose=True,
+ logger=logger,
+ )
+ daemon.start()
+ else:
+ run()
+
+
+if __name__ == '__main__':
+ with LoggingContext("main"):
+ start(sys.argv[1:])
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index c8dde0fc..8d755a4b 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -80,11 +80,6 @@ class PusherSlaveStore(
DataStore.get_profile_displayname.__func__
)
- # XXX: This is a bit broken because we don't persist forgotten rooms
- # in a way that they can be streamed. This means that we don't have a
- # way to invalidate the forgotten rooms cache correctly.
- # For now we expire the cache every 10 minutes.
- BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"]
)
@@ -168,7 +163,6 @@ class PusherServer(HomeServer):
store = self.get_datastore()
replication_url = self.config.worker_replication_url
pusher_pool = self.get_pusherpool()
- clock = self.get_clock()
def stop_pusher(user_id, app_id, pushkey):
key = "%s:%s" % (app_id, pushkey)
@@ -220,21 +214,11 @@ class PusherServer(HomeServer):
min_stream_id, max_stream_id, affected_room_ids
)
- def expire_broken_caches():
- store.who_forgot_in_room.invalidate_all()
-
- next_expire_broken_caches_ms = 0
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
- now_ms = clock.time_msec()
- if now_ms > next_expire_broken_caches_ms:
- expire_broken_caches()
- next_expire_broken_caches_ms = (
- now_ms + store.BROKEN_CACHE_EXPIRY_MS
- )
yield store.process_replication(result)
poke_pushers(result)
except:
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 215ccfd5..e3173533 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -26,6 +26,7 @@ from synapse.http.site import SynapseSite
from synapse.http.server import JsonResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.rest.client.v2_alpha import sync
+from synapse.rest.client.v1 import events
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
@@ -74,11 +75,6 @@ class SynchrotronSlavedStore(
BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different
):
- # XXX: This is a bit broken because we don't persist forgotten rooms
- # in a way that they can be streamed. This means that we don't have a
- # way to invalidate the forgotten rooms cache correctly.
- # For now we expire the cache every 10 minutes.
- BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"]
)
@@ -89,17 +85,23 @@ class SynchrotronSlavedStore(
get_presence_list_accepted = PresenceStore.__dict__[
"get_presence_list_accepted"
]
+ get_presence_list_observers_accepted = PresenceStore.__dict__[
+ "get_presence_list_observers_accepted"
+ ]
+
UPDATE_SYNCING_USERS_MS = 10 * 1000
class SynchrotronPresence(object):
def __init__(self, hs):
+ self.is_mine_id = hs.is_mine_id
self.http_client = hs.get_simple_http_client()
self.store = hs.get_datastore()
self.user_to_num_current_syncs = {}
self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
self.clock = hs.get_clock()
+ self.notifier = hs.get_notifier()
active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = {
@@ -119,11 +121,13 @@ class SynchrotronPresence(object):
reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
- def set_state(self, user, state):
+ def set_state(self, user, state, ignore_status_msg=False):
# TODO Hows this supposed to work?
pass
get_states = PresenceHandler.get_states.__func__
+ get_state = PresenceHandler.get_state.__func__
+ _get_interested_parties = PresenceHandler._get_interested_parties.__func__
current_state_for_users = PresenceHandler.current_state_for_users.__func__
@defer.inlineCallbacks
@@ -194,19 +198,39 @@ class SynchrotronPresence(object):
self._need_to_send_sync = False
yield self._send_syncing_users_now()
+ @defer.inlineCallbacks
+ def notify_from_replication(self, states, stream_id):
+ parties = yield self._get_interested_parties(
+ states, calculate_remote_hosts=False
+ )
+ room_ids_to_states, users_to_states, _ = parties
+
+ self.notifier.on_new_event(
+ "presence_key", stream_id, rooms=room_ids_to_states.keys(),
+ users=users_to_states.keys()
+ )
+
+ @defer.inlineCallbacks
def process_replication(self, result):
stream = result.get("presence", {"rows": []})
+ states = []
for row in stream["rows"]:
(
position, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active
) = row
- self.user_to_current_state[user_id] = UserPresenceState(
+ state = UserPresenceState(
user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active
)
+ self.user_to_current_state[user_id] = state
+ states.append(state)
+
+ if states and "position" in stream:
+ stream_id = int(stream["position"])
+ yield self.notify_from_replication(states, stream_id)
class SynchrotronTyping(object):
@@ -266,10 +290,12 @@ class SynchrotronServer(HomeServer):
elif name == "client":
resource = JsonResource(self, canonical_json=False)
sync.register_servlets(self, resource)
+ events.register_servlets(self, resource)
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
+ "/_matrix/client/api/v1": resource,
})
root_resource = create_resource_tree(resources, Resource())
@@ -307,15 +333,10 @@ class SynchrotronServer(HomeServer):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
- clock = self.get_clock()
notifier = self.get_notifier()
presence_handler = self.get_presence_handler()
typing_handler = self.get_typing_handler()
- def expire_broken_caches():
- store.who_forgot_in_room.invalidate_all()
- store.get_presence_list_accepted.invalidate_all()
-
def notify_from_stream(
result, stream_name, stream_key, room=None, user=None
):
@@ -377,22 +398,15 @@ class SynchrotronServer(HomeServer):
result, "typing", "typing_key", room="room_id"
)
- next_expire_broken_caches_ms = 0
while True:
try:
args = store.stream_positions()
args.update(typing_handler.stream_positions())
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
- now_ms = clock.time_msec()
- if now_ms > next_expire_broken_caches_ms:
- expire_broken_caches()
- next_expire_broken_caches_ms = (
- now_ms + store.BROKEN_CACHE_EXPIRY_MS
- )
yield store.process_replication(result)
typing_handler.process_replication(result)
- presence_handler.process_replication(result)
+ yield presence_handler.process_replication(result)
notify(result)
except:
logger.exception("Error replicating from %r", replication_url)
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index f7178ea0..bde9b51b 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -14,6 +14,8 @@
# limitations under the License.
from synapse.api.constants import EventTypes
+from twisted.internet import defer
+
import logging
import re
@@ -79,13 +81,17 @@ class ApplicationService(object):
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
def __init__(self, token, url=None, namespaces=None, hs_token=None,
- sender=None, id=None):
+ sender=None, id=None, protocols=None):
self.token = token
self.url = url
self.hs_token = hs_token
self.sender = sender
self.namespaces = self._check_namespaces(namespaces)
self.id = id
+ if protocols:
+ self.protocols = set(protocols)
+ else:
+ self.protocols = set()
def _check_namespaces(self, namespaces):
# Sanity check that it is of the form:
@@ -138,65 +144,66 @@ class ApplicationService(object):
return regex_obj["exclusive"]
return False
- def _matches_user(self, event, member_list):
- if (hasattr(event, "sender") and
- self.is_interested_in_user(event.sender)):
- return True
+ @defer.inlineCallbacks
+ def _matches_user(self, event, store):
+ if not event:
+ defer.returnValue(False)
+
+ if self.is_interested_in_user(event.sender):
+ defer.returnValue(True)
# also check m.room.member state key
- if (hasattr(event, "type") and event.type == EventTypes.Member
- and hasattr(event, "state_key")
- and self.is_interested_in_user(event.state_key)):
- return True
+ if (event.type == EventTypes.Member and
+ self.is_interested_in_user(event.state_key)):
+ defer.returnValue(True)
+
+ if not store:
+ defer.returnValue(False)
+
+ member_list = yield store.get_users_in_room(event.room_id)
+
# check joined member events
for user_id in member_list:
if self.is_interested_in_user(user_id):
- return True
- return False
+ defer.returnValue(True)
+ defer.returnValue(False)
def _matches_room_id(self, event):
if hasattr(event, "room_id"):
return self.is_interested_in_room(event.room_id)
return False
- def _matches_aliases(self, event, alias_list):
+ @defer.inlineCallbacks
+ def _matches_aliases(self, event, store):
+ if not store or not event:
+ defer.returnValue(False)
+
+ alias_list = yield store.get_aliases_for_room(event.room_id)
for alias in alias_list:
if self.is_interested_in_alias(alias):
- return True
- return False
+ defer.returnValue(True)
+ defer.returnValue(False)
- def is_interested(self, event, restrict_to=None, aliases_for_event=None,
- member_list=None):
+ @defer.inlineCallbacks
+ def is_interested(self, event, store=None):
"""Check if this service is interested in this event.
Args:
event(Event): The event to check.
- restrict_to(str): The namespace to restrict regex tests to.
- aliases_for_event(list): A list of all the known room aliases for
- this event.
- member_list(list): A list of all joined user_ids in this room.
+ store(DataStore)
Returns:
bool: True if this service would like to know about this event.
"""
- if aliases_for_event is None:
- aliases_for_event = []
- if member_list is None:
- member_list = []
-
- if restrict_to and restrict_to not in ApplicationService.NS_LIST:
- # this is a programming error, so fail early and raise a general
- # exception
- raise Exception("Unexpected restrict_to value: %s". restrict_to)
-
- if not restrict_to:
- return (self._matches_user(event, member_list)
- or self._matches_aliases(event, aliases_for_event)
- or self._matches_room_id(event))
- elif restrict_to == ApplicationService.NS_ALIASES:
- return self._matches_aliases(event, aliases_for_event)
- elif restrict_to == ApplicationService.NS_ROOMS:
- return self._matches_room_id(event)
- elif restrict_to == ApplicationService.NS_USERS:
- return self._matches_user(event, member_list)
+ # Do cheap checks first
+ if self._matches_room_id(event):
+ defer.returnValue(True)
+
+ if (yield self._matches_aliases(event, store)):
+ defer.returnValue(True)
+
+ if (yield self._matches_user(event, store)):
+ defer.returnValue(True)
+
+ defer.returnValue(False)
def is_interested_in_user(self, user_id):
return (
@@ -216,6 +223,9 @@ class ApplicationService(object):
or user_id == self.sender
)
+ def is_interested_in_protocol(self, protocol):
+ return protocol in self.protocols
+
def is_exclusive_alias(self, alias):
return self._is_exclusive(ApplicationService.NS_ALIASES, alias)
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 6da6a1b6..066127b6 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event
+from synapse.types import ThirdPartyEntityKind
import logging
import urllib
@@ -24,6 +25,28 @@ import urllib
logger = logging.getLogger(__name__)
+def _is_valid_3pe_result(r, field):
+ if not isinstance(r, dict):
+ return False
+
+ for k in (field, "protocol"):
+ if k not in r:
+ return False
+ if not isinstance(r[k], str):
+ return False
+
+ if "fields" not in r:
+ return False
+ fields = r["fields"]
+ if not isinstance(fields, dict):
+ return False
+ for k in fields.keys():
+ if not isinstance(fields[k], str):
+ return False
+
+ return True
+
+
class ApplicationServiceApi(SimpleHttpClient):
"""This class manages HS -> AS communications, including querying and
pushing.
@@ -72,6 +95,43 @@ class ApplicationServiceApi(SimpleHttpClient):
defer.returnValue(False)
@defer.inlineCallbacks
+ def query_3pe(self, service, kind, protocol, fields):
+ if kind == ThirdPartyEntityKind.USER:
+ uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
+ required_field = "userid"
+ elif kind == ThirdPartyEntityKind.LOCATION:
+ uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
+ required_field = "alias"
+ else:
+ raise ValueError(
+ "Unrecognised 'kind' argument %r to query_3pe()", kind
+ )
+
+ try:
+ response = yield self.get_json(uri, fields)
+ if not isinstance(response, list):
+ logger.warning(
+ "query_3pe to %s returned an invalid response %r",
+ uri, response
+ )
+ defer.returnValue([])
+
+ ret = []
+ for r in response:
+ if _is_valid_3pe_result(r, field=required_field):
+ ret.append(r)
+ else:
+ logger.warning(
+ "query_3pe to %s returned an invalid result %r",
+ uri, r
+ )
+
+ defer.returnValue(ret)
+ except Exception as ex:
+ logger.warning("query_3pe to %s threw exception %s", uri, ex)
+ defer.returnValue([])
+
+ @defer.inlineCallbacks
def push_bulk(self, service, events, txn_id=None):
events = self._serialize(events)
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 9afc8fd7..68a9de17 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -48,9 +48,12 @@ UP & quit +---------- YES SUCCESS
This is all tied together by the AppServiceScheduler which DIs the required
components.
"""
+from twisted.internet import defer
from synapse.appservice import ApplicationServiceState
-from twisted.internet import defer
+from synapse.util.logcontext import preserve_fn
+from synapse.util.metrics import Measure
+
import logging
logger = logging.getLogger(__name__)
@@ -73,7 +76,7 @@ class ApplicationServiceScheduler(object):
self.txn_ctrl = _TransactionController(
self.clock, self.store, self.as_api, create_recoverer
)
- self.queuer = _ServiceQueuer(self.txn_ctrl)
+ self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
@defer.inlineCallbacks
def start(self):
@@ -94,38 +97,36 @@ class _ServiceQueuer(object):
this schedules any other events in the queue to run.
"""
- def __init__(self, txn_ctrl):
+ def __init__(self, txn_ctrl, clock):
self.queued_events = {} # dict of {service_id: [events]}
- self.pending_requests = {} # dict of {service_id: Deferred}
+ self.requests_in_flight = set()
self.txn_ctrl = txn_ctrl
+ self.clock = clock
def enqueue(self, service, event):
# if this service isn't being sent something
- if not self.pending_requests.get(service.id):
- self._send_request(service, [event])
- else:
- # add to queue for this service
- if service.id not in self.queued_events:
- self.queued_events[service.id] = []
- self.queued_events[service.id].append(event)
-
- def _send_request(self, service, events):
- # send request and add callbacks
- d = self.txn_ctrl.send(service, events)
- d.addBoth(self._on_request_finish)
- d.addErrback(self._on_request_fail)
- self.pending_requests[service.id] = d
-
- def _on_request_finish(self, service):
- self.pending_requests[service.id] = None
- # if there are queued events, then send them.
- if (service.id in self.queued_events
- and len(self.queued_events[service.id]) > 0):
- self._send_request(service, self.queued_events[service.id])
- self.queued_events[service.id] = []
-
- def _on_request_fail(self, err):
- logger.error("AS request failed: %s", err)
+ self.queued_events.setdefault(service.id, []).append(event)
+ preserve_fn(self._send_request)(service)
+
+ @defer.inlineCallbacks
+ def _send_request(self, service):
+ if service.id in self.requests_in_flight:
+ return
+
+ self.requests_in_flight.add(service.id)
+ try:
+ while True:
+ events = self.queued_events.pop(service.id, [])
+ if not events:
+ return
+
+ with Measure(self.clock, "servicequeuer.send"):
+ try:
+ yield self.txn_ctrl.send(service, events)
+ except:
+ logger.exception("AS request failed")
+ finally:
+ self.requests_in_flight.discard(service.id)
class _TransactionController(object):
@@ -149,14 +150,12 @@ class _TransactionController(object):
if service_is_up:
sent = yield txn.send(self.as_api)
if sent:
- txn.complete(self.store)
+ yield txn.complete(self.store)
else:
- self._start_recoverer(service)
+ preserve_fn(self._start_recoverer)(service)
except Exception as e:
logger.exception(e)
- self._start_recoverer(service)
- # request has finished
- defer.returnValue(service)
+ preserve_fn(self._start_recoverer)(service)
@defer.inlineCallbacks
def on_recovered(self, recoverer):
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index eade8039..dfe43b0b 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -28,6 +28,7 @@ class AppServiceConfig(Config):
def read_config(self, config):
self.app_service_config_files = config.get("app_service_config_files", [])
+ self.notify_appservices = config.get("notify_appservices", True)
def default_config(cls, **kwargs):
return """\
@@ -122,6 +123,15 @@ def _load_appservice(hostname, as_info, config_filename):
raise ValueError(
"Missing/bad type 'exclusive' key in %s", regex_obj
)
+ # protocols check
+ protocols = as_info.get("protocols")
+ if protocols:
+ # Because strings are lists in python
+ if isinstance(protocols, str) or not isinstance(protocols, list):
+ raise KeyError("Optional 'protocols' must be a list if present.")
+ for p in protocols:
+ if not isinstance(p, str):
+ raise KeyError("Bad value for 'protocols' item")
return ApplicationService(
token=as_info["as_token"],
url=as_info["url"],
@@ -129,4 +139,5 @@ def _load_appservice(hostname, as_info, config_filename):
hs_token=as_info["hs_token"],
sender=user_id,
id=as_info["id"],
+ protocols=protocols,
)
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 5012c10e..d7211ee9 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -22,6 +22,7 @@ from synapse.util.logcontext import (
preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
preserve_fn
)
+from synapse.util.metrics import Measure
from twisted.internet import defer
@@ -61,6 +62,10 @@ Attributes:
"""
+class KeyLookupError(ValueError):
+ pass
+
+
class Keyring(object):
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -239,59 +244,60 @@ class Keyring(object):
@defer.inlineCallbacks
def do_iterations():
- merged_results = {}
+ with Measure(self.clock, "get_server_verify_keys"):
+ merged_results = {}
- missing_keys = {}
- for verify_request in verify_requests:
- missing_keys.setdefault(verify_request.server_name, set()).update(
- verify_request.key_ids
- )
-
- for fn in key_fetch_fns:
- results = yield fn(missing_keys.items())
- merged_results.update(results)
-
- # We now need to figure out which verify requests we have keys
- # for and which we don't
missing_keys = {}
- requests_missing_keys = []
for verify_request in verify_requests:
- server_name = verify_request.server_name
- result_keys = merged_results[server_name]
-
- if verify_request.deferred.called:
- # We've already called this deferred, which probably
- # means that we've already found a key for it.
- continue
-
- for key_id in verify_request.key_ids:
- if key_id in result_keys:
- with PreserveLoggingContext():
- verify_request.deferred.callback((
- server_name,
- key_id,
- result_keys[key_id],
- ))
- break
- else:
- # The else block is only reached if the loop above
- # doesn't break.
- missing_keys.setdefault(server_name, set()).update(
- verify_request.key_ids
- )
- requests_missing_keys.append(verify_request)
-
- if not missing_keys:
- break
-
- for verify_request in requests_missing_keys.values():
- verify_request.deferred.errback(SynapseError(
- 401,
- "No key for %s with id %s" % (
- verify_request.server_name, verify_request.key_ids,
- ),
- Codes.UNAUTHORIZED,
- ))
+ missing_keys.setdefault(verify_request.server_name, set()).update(
+ verify_request.key_ids
+ )
+
+ for fn in key_fetch_fns:
+ results = yield fn(missing_keys.items())
+ merged_results.update(results)
+
+ # We now need to figure out which verify requests we have keys
+ # for and which we don't
+ missing_keys = {}
+ requests_missing_keys = []
+ for verify_request in verify_requests:
+ server_name = verify_request.server_name
+ result_keys = merged_results[server_name]
+
+ if verify_request.deferred.called:
+ # We've already called this deferred, which probably
+ # means that we've already found a key for it.
+ continue
+
+ for key_id in verify_request.key_ids:
+ if key_id in result_keys:
+ with PreserveLoggingContext():
+ verify_request.deferred.callback((
+ server_name,
+ key_id,
+ result_keys[key_id],
+ ))
+ break
+ else:
+ # The else block is only reached if the loop above
+ # doesn't break.
+ missing_keys.setdefault(server_name, set()).update(
+ verify_request.key_ids
+ )
+ requests_missing_keys.append(verify_request)
+
+ if not missing_keys:
+ break
+
+ for verify_request in requests_missing_keys.values():
+ verify_request.deferred.errback(SynapseError(
+ 401,
+ "No key for %s with id %s" % (
+ verify_request.server_name, verify_request.key_ids,
+ ),
+ Codes.UNAUTHORIZED,
+ ))
def on_err(err):
for verify_request in verify_requests:
@@ -302,15 +308,15 @@ class Keyring(object):
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
- res = yield defer.gatherResults(
+ res = yield preserve_context_over_deferred(defer.gatherResults(
[
- self.store.get_server_verify_keys(
+ preserve_fn(self.store.get_server_verify_keys)(
server_name, key_ids
).addCallback(lambda ks, server: (server, ks), server_name)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
- ).addErrback(unwrapFirstError)
+ )).addErrback(unwrapFirstError)
defer.returnValue(dict(res))
@@ -331,13 +337,13 @@ class Keyring(object):
)
defer.returnValue({})
- results = yield defer.gatherResults(
+ results = yield preserve_context_over_deferred(defer.gatherResults(
[
- get_key(p_name, p_keys)
+ preserve_fn(get_key)(p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items()
],
consumeErrors=True,
- ).addErrback(unwrapFirstError)
+ )).addErrback(unwrapFirstError)
union_of_keys = {}
for result in results:
@@ -363,7 +369,7 @@ class Keyring(object):
)
except Exception as e:
logger.info(
- "Unable to getting key %r for %r directly: %s %s",
+ "Unable to get key %r for %r directly: %s %s",
key_ids, server_name,
type(e).__name__, str(e.message),
)
@@ -377,13 +383,13 @@ class Keyring(object):
defer.returnValue(keys)
- results = yield defer.gatherResults(
+ results = yield preserve_context_over_deferred(defer.gatherResults(
[
- get_key(server_name, key_ids)
+ preserve_fn(get_key)(server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
- ).addErrback(unwrapFirstError)
+ )).addErrback(unwrapFirstError)
merged = {}
for result in results:
@@ -425,7 +431,7 @@ class Keyring(object):
for response in responses:
if (u"signatures" not in response
or perspective_name not in response[u"signatures"]):
- raise ValueError(
+ raise KeyLookupError(
"Key response not signed by perspective server"
" %r" % (perspective_name,)
)
@@ -448,7 +454,7 @@ class Keyring(object):
list(response[u"signatures"][perspective_name]),
list(perspective_keys)
)
- raise ValueError(
+ raise KeyLookupError(
"Response not signed with a known key for perspective"
" server %r" % (perspective_name,)
)
@@ -460,9 +466,9 @@ class Keyring(object):
for server_name, response_keys in processed_response.items():
keys.setdefault(server_name, {}).update(response_keys)
- yield defer.gatherResults(
+ yield preserve_context_over_deferred(defer.gatherResults(
[
- self.store_keys(
+ preserve_fn(self.store_keys)(
server_name=server_name,
from_server=perspective_name,
verify_keys=response_keys,
@@ -470,7 +476,7 @@ class Keyring(object):
for server_name, response_keys in keys.items()
],
consumeErrors=True
- ).addErrback(unwrapFirstError)
+ )).addErrback(unwrapFirstError)
defer.returnValue(keys)
@@ -491,10 +497,10 @@ class Keyring(object):
if (u"signatures" not in response
or server_name not in response[u"signatures"]):
- raise ValueError("Key response not signed by remote server")
+ raise KeyLookupError("Key response not signed by remote server")
if "tls_fingerprints" not in response:
- raise ValueError("Key response missing TLS fingerprints")
+ raise KeyLookupError("Key response missing TLS fingerprints")
certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1, tls_certificate
@@ -508,7 +514,7 @@ class Keyring(object):
response_sha256_fingerprints.add(fingerprint[u"sha256"])
if sha256_fingerprint_b64 not in response_sha256_fingerprints:
- raise ValueError("TLS certificate not allowed by fingerprints")
+ raise KeyLookupError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response(
from_server=server_name,
@@ -518,7 +524,7 @@ class Keyring(object):
keys.update(response_keys)
- yield defer.gatherResults(
+ yield preserve_context_over_deferred(defer.gatherResults(
[
preserve_fn(self.store_keys)(
server_name=key_server_name,
@@ -528,7 +534,7 @@ class Keyring(object):
for key_server_name, verify_keys in keys.items()
],
consumeErrors=True
- ).addErrback(unwrapFirstError)
+ )).addErrback(unwrapFirstError)
defer.returnValue(keys)
@@ -560,14 +566,14 @@ class Keyring(object):
server_name = response_json["server_name"]
if only_from_server:
if server_name != from_server:
- raise ValueError(
+ raise KeyLookupError(
"Expected a response for server %r not %r" % (
from_server, server_name
)
)
for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]:
- raise ValueError(
+ raise KeyLookupError(
"Key response must include verification keys for all"
" signatures"
)
@@ -594,7 +600,7 @@ class Keyring(object):
response_keys.update(verify_keys)
response_keys.update(old_verify_keys)
- yield defer.gatherResults(
+ yield preserve_context_over_deferred(defer.gatherResults(
[
preserve_fn(self.store.store_server_keys_json)(
server_name=server_name,
@@ -607,7 +613,7 @@ class Keyring(object):
for key_id in updated_key_ids
],
consumeErrors=True,
- ).addErrback(unwrapFirstError)
+ )).addErrback(unwrapFirstError)
results[server_name] = response_keys
@@ -635,15 +641,15 @@ class Keyring(object):
if ("signatures" not in response
or server_name not in response["signatures"]):
- raise ValueError("Key response not signed by remote server")
+ raise KeyLookupError("Key response not signed by remote server")
if "tls_certificate" not in response:
- raise ValueError("Key response missing TLS certificate")
+ raise KeyLookupError("Key response missing TLS certificate")
tls_certificate_b64 = response["tls_certificate"]
if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
- raise ValueError("TLS certificate doesn't match")
+ raise KeyLookupError("TLS certificate doesn't match")
# Cache the result in the datastore.
@@ -659,7 +665,7 @@ class Keyring(object):
for key_id in response["signatures"][server_name]:
if key_id not in response["verify_keys"]:
- raise ValueError(
+ raise KeyLookupError(
"Key response must include verification keys for all"
" signatures"
)
@@ -696,7 +702,7 @@ class Keyring(object):
A deferred that completes when the keys are stored.
"""
# TODO(markjh): Store whether the keys have expired.
- yield defer.gatherResults(
+ yield preserve_context_over_deferred(defer.gatherResults(
[
preserve_fn(self.store.store_server_verify_key)(
server_name, server_name, key.time_added, key
@@ -704,4 +710,4 @@ class Keyring(object):
for key_id, key in verify_keys.items()
],
consumeErrors=True,
- ).addErrback(unwrapFirstError)
+ )).addErrback(unwrapFirstError)
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index aab18d7f..0e9fd902 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -88,6 +88,8 @@ def prune_event(event):
if "age_ts" in event.unsigned:
allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
+ if "replaces_state" in event.unsigned:
+ allowed_fields["unsigned"]["replaces_state"] = event.unsigned["replaces_state"]
return type(event)(
allowed_fields,
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index da2f5e8c..2339cc90 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -23,6 +23,7 @@ from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import SynapseError
from synapse.util import unwrapFirstError
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
import logging
@@ -102,10 +103,10 @@ class FederationBase(object):
warn, pdu
)
- valid_pdus = yield defer.gatherResults(
+ valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
deferreds,
consumeErrors=True
- ).addErrback(unwrapFirstError)
+ )).addErrback(unwrapFirstError)
if include_none:
defer.returnValue(valid_pdus)
@@ -129,7 +130,7 @@ class FederationBase(object):
for pdu in pdus
]
- deferreds = self.keyring.verify_json_objects_for_server([
+ deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([
(p.origin, p.get_pdu_json())
for p in redacted_pdus
])
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index da95c2ad..f2b3aceb 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -27,6 +27,7 @@ from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.events import FrozenEvent
import synapse.metrics
@@ -51,10 +52,34 @@ sent_edus_counter = metrics.register_counter("sent_edus")
sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
+PDU_RETRY_TIME_MS = 1 * 60 * 1000
+
+
class FederationClient(FederationBase):
def __init__(self, hs):
super(FederationClient, self).__init__(hs)
+ self.pdu_destination_tried = {}
+ self._clock.looping_call(
+ self._clear_tried_cache, 60 * 1000,
+ )
+
+ def _clear_tried_cache(self):
+ """Clear pdu_destination_tried cache"""
+ now = self._clock.time_msec()
+
+ old_dict = self.pdu_destination_tried
+ self.pdu_destination_tried = {}
+
+ for event_id, destination_dict in old_dict.items():
+ destination_dict = {
+ dest: time
+ for dest, time in destination_dict.items()
+ if time + PDU_RETRY_TIME_MS > now
+ }
+ if destination_dict:
+ self.pdu_destination_tried[event_id] = destination_dict
+
def start_get_pdu_cache(self):
self._get_pdu_cache = ExpiringCache(
cache_name="get_pdu_cache",
@@ -201,10 +226,10 @@ class FederationClient(FederationBase):
]
# FIXME: We should handle signature failures more gracefully.
- pdus[:] = yield defer.gatherResults(
+ pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
self._check_sigs_and_hashes(pdus),
consumeErrors=True,
- ).addErrback(unwrapFirstError)
+ )).addErrback(unwrapFirstError)
defer.returnValue(pdus)
@@ -240,8 +265,15 @@ class FederationClient(FederationBase):
if ev:
defer.returnValue(ev)
+ pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
+
pdu = None
for destination in destinations:
+ now = self._clock.time_msec()
+ last_attempt = pdu_attempts.get(destination, 0)
+ if last_attempt + PDU_RETRY_TIME_MS > now:
+ continue
+
try:
limiter = yield get_retry_limiter(
destination,
@@ -269,25 +301,19 @@ class FederationClient(FederationBase):
break
- except SynapseError as e:
- logger.info(
- "Failed to get PDU %s from %s because %s",
- event_id, destination, e,
- )
- continue
- except CodeMessageException as e:
- if 400 <= e.code < 500:
- raise
+ pdu_attempts[destination] = now
+ except SynapseError as e:
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
)
- continue
except NotRetryingDestination as e:
logger.info(e.message)
continue
except Exception as e:
+ pdu_attempts[destination] = now
+
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
@@ -406,7 +432,7 @@ class FederationClient(FederationBase):
events and the second is a list of event ids that we failed to fetch.
"""
if return_local:
- seen_events = yield self.store.get_events(event_ids)
+ seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
signed_events = seen_events.values()
else:
seen_events = yield self.store.have_events(event_ids)
@@ -432,14 +458,16 @@ class FederationClient(FederationBase):
batch = set(missing_events[i:i + batch_size])
deferreds = [
- self.get_pdu(
+ preserve_fn(self.get_pdu)(
destinations=random_server_list(),
event_id=e_id,
)
for e_id in batch
]
- res = yield defer.DeferredList(deferreds, consumeErrors=True)
+ res = yield preserve_context_over_deferred(
+ defer.DeferredList(deferreds, consumeErrors=True)
+ )
for success, result in res:
if success:
signed_events.append(result)
@@ -828,14 +856,16 @@ class FederationClient(FederationBase):
return srvs
deferreds = [
- self.get_pdu(
+ preserve_fn(self.get_pdu)(
destinations=random_server_list(),
event_id=e_id,
)
for e_id, depth in ordered_missing[:limit - len(signed_events)]
]
- res = yield defer.DeferredList(deferreds, consumeErrors=True)
+ res = yield preserve_context_over_deferred(
+ defer.DeferredList(deferreds, consumeErrors=True)
+ )
for (result, val), (e_id, _) in zip(res, ordered_missing):
if result and val:
signed_events.append(val)
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 5787f854..cb2ef021 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -21,11 +21,11 @@ from .units import Transaction
from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor
-from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.retryutils import (
get_retry_limiter, NotRetryingDestination,
)
+from synapse.util.metrics import measure_func
import synapse.metrics
import logging
@@ -51,7 +51,7 @@ class TransactionQueue(object):
self.transport_layer = transport_layer
- self._clock = hs.get_clock()
+ self.clock = hs.get_clock()
# Is a mapping from destinations -> deferreds. Used to keep track
# of which destinations have transactions in flight and when they are
@@ -82,7 +82,7 @@ class TransactionQueue(object):
self.pending_failures_by_dest = {}
# HACK to get unique tx id
- self._next_txn_id = int(self._clock.time_msec())
+ self._next_txn_id = int(self.clock.time_msec())
def can_send_to(self, destination):
"""Can we send messages to the given server?
@@ -119,266 +119,215 @@ class TransactionQueue(object):
if not destinations:
return
- deferreds = []
-
for destination in destinations:
- deferred = defer.Deferred()
self.pending_pdus_by_dest.setdefault(destination, []).append(
- (pdu, deferred, order)
+ (pdu, order)
)
- def chain(failure):
- if not deferred.called:
- deferred.errback(failure)
-
- def log_failure(f):
- logger.warn("Failed to send pdu to %s: %s", destination, f.value)
-
- deferred.addErrback(log_failure)
-
- with PreserveLoggingContext():
- self._attempt_new_transaction(destination).addErrback(chain)
-
- deferreds.append(deferred)
+ preserve_context_over_fn(
+ self._attempt_new_transaction, destination
+ )
- # NO inlineCallbacks
def enqueue_edu(self, edu):
destination = edu.destination
if not self.can_send_to(destination):
return
- deferred = defer.Deferred()
- self.pending_edus_by_dest.setdefault(destination, []).append(
- (edu, deferred)
- )
+ self.pending_edus_by_dest.setdefault(destination, []).append(edu)
- def chain(failure):
- if not deferred.called:
- deferred.errback(failure)
-
- def log_failure(f):
- logger.warn("Failed to send edu to %s: %s", destination, f.value)
-
- deferred.addErrback(log_failure)
-
- with PreserveLoggingContext():
- self._attempt_new_transaction(destination).addErrback(chain)
-
- return deferred
+ preserve_context_over_fn(
+ self._attempt_new_transaction, destination
+ )
- @defer.inlineCallbacks
def enqueue_failure(self, failure, destination):
if destination == self.server_name or destination == "localhost":
return
- deferred = defer.Deferred()
-
if not self.can_send_to(destination):
return
self.pending_failures_by_dest.setdefault(
destination, []
- ).append(
- (failure, deferred)
- )
-
- def chain(f):
- if not deferred.called:
- deferred.errback(f)
-
- def log_failure(f):
- logger.warn("Failed to send failure to %s: %s", destination, f.value)
-
- deferred.addErrback(log_failure)
-
- with PreserveLoggingContext():
- self._attempt_new_transaction(destination).addErrback(chain)
+ ).append(failure)
- yield deferred
+ preserve_context_over_fn(
+ self._attempt_new_transaction, destination
+ )
@defer.inlineCallbacks
- @log_function
def _attempt_new_transaction(self, destination):
yield run_on_reactor()
+ while True:
+ # list of (pending_pdu, deferred, order)
+ if destination in self.pending_transactions:
+ # XXX: pending_transactions can get stuck on by a never-ending
+ # request at which point pending_pdus_by_dest just keeps growing.
+ # we need application-layer timeouts of some flavour of these
+ # requests
+ logger.debug(
+ "TX [%s] Transaction already in progress",
+ destination
+ )
+ return
- # list of (pending_pdu, deferred, order)
- if destination in self.pending_transactions:
- # XXX: pending_transactions can get stuck on by a never-ending
- # request at which point pending_pdus_by_dest just keeps growing.
- # we need application-layer timeouts of some flavour of these
- # requests
- logger.debug(
- "TX [%s] Transaction already in progress",
- destination
- )
- return
-
- pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
- pending_edus = self.pending_edus_by_dest.pop(destination, [])
- pending_failures = self.pending_failures_by_dest.pop(destination, [])
+ pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+ pending_edus = self.pending_edus_by_dest.pop(destination, [])
+ pending_failures = self.pending_failures_by_dest.pop(destination, [])
- if pending_pdus:
- logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
- destination, len(pending_pdus))
+ if pending_pdus:
+ logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+ destination, len(pending_pdus))
- if not pending_pdus and not pending_edus and not pending_failures:
- logger.debug("TX [%s] Nothing to send", destination)
- return
+ if not pending_pdus and not pending_edus and not pending_failures:
+ logger.debug("TX [%s] Nothing to send", destination)
+ return
- try:
- self.pending_transactions[destination] = 1
+ yield self._send_new_transaction(
+ destination, pending_pdus, pending_edus, pending_failures
+ )
- logger.debug("TX [%s] _attempt_new_transaction", destination)
+ @measure_func("_send_new_transaction")
+ @defer.inlineCallbacks
+ def _send_new_transaction(self, destination, pending_pdus, pending_edus,
+ pending_failures):
# Sort based on the order field
- pending_pdus.sort(key=lambda t: t[2])
-
+ pending_pdus.sort(key=lambda t: t[1])
pdus = [x[0] for x in pending_pdus]
- edus = [x[0] for x in pending_edus]
- failures = [x[0].get_dict() for x in pending_failures]
- deferreds = [
- x[1]
- for x in pending_pdus + pending_edus + pending_failures
- ]
-
- txn_id = str(self._next_txn_id)
-
- limiter = yield get_retry_limiter(
- destination,
- self._clock,
- self.store,
- )
+ edus = pending_edus
+ failures = [x.get_dict() for x in pending_failures]
- logger.debug(
- "TX [%s] {%s} Attempting new transaction"
- " (pdus: %d, edus: %d, failures: %d)",
- destination, txn_id,
- len(pending_pdus),
- len(pending_edus),
- len(pending_failures)
- )
+ try:
+ self.pending_transactions[destination] = 1
- logger.debug("TX [%s] Persisting transaction...", destination)
+ logger.debug("TX [%s] _attempt_new_transaction", destination)
- transaction = Transaction.create_new(
- origin_server_ts=int(self._clock.time_msec()),
- transaction_id=txn_id,
- origin=self.server_name,
- destination=destination,
- pdus=pdus,
- edus=edus,
- pdu_failures=failures,
- )
+ txn_id = str(self._next_txn_id)
- self._next_txn_id += 1
+ limiter = yield get_retry_limiter(
+ destination,
+ self.clock,
+ self.store,
+ )
- yield self.transaction_actions.prepare_to_send(transaction)
+ logger.debug(
+ "TX [%s] {%s} Attempting new transaction"
+ " (pdus: %d, edus: %d, failures: %d)",
+ destination, txn_id,
+ len(pending_pdus),
+ len(pending_edus),
+ len(pending_failures)
+ )
- logger.debug("TX [%s] Persisted transaction", destination)
- logger.info(
- "TX [%s] {%s} Sending transaction [%s],"
- " (PDUs: %d, EDUs: %d, failures: %d)",
- destination, txn_id,
- transaction.transaction_id,
- len(pending_pdus),
- len(pending_edus),
- len(pending_failures),
- )
+ logger.debug("TX [%s] Persisting transaction...", destination)
- with limiter:
- # Actually send the transaction
-
- # FIXME (erikj): This is a bit of a hack to make the Pdu age
- # keys work
- def json_data_cb():
- data = transaction.get_dict()
- now = int(self._clock.time_msec())
- if "pdus" in data:
- for p in data["pdus"]:
- if "age_ts" in p:
- unsigned = p.setdefault("unsigned", {})
- unsigned["age"] = now - int(p["age_ts"])
- del p["age_ts"]
- return data
-
- try:
- response = yield self.transport_layer.send_transaction(
- transaction, json_data_cb
- )
- code = 200
-
- if response:
- for e_id, r in response.get("pdus", {}).items():
- if "error" in r:
- logger.warn(
- "Transaction returned error for %s: %s",
- e_id, r,
- )
- except HttpResponseException as e:
- code = e.code
- response = e.response
+ transaction = Transaction.create_new(
+ origin_server_ts=int(self.clock.time_msec()),
+ transaction_id=txn_id,
+ origin=self.server_name,
+ destination=destination,
+ pdus=pdus,
+ edus=edus,
+ pdu_failures=failures,
+ )
+
+ self._next_txn_id += 1
+
+ yield self.transaction_actions.prepare_to_send(transaction)
+ logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
- "TX [%s] {%s} got %d response",
- destination, txn_id, code
+ "TX [%s] {%s} Sending transaction [%s],"
+ " (PDUs: %d, EDUs: %d, failures: %d)",
+ destination, txn_id,
+ transaction.transaction_id,
+ len(pending_pdus),
+ len(pending_edus),
+ len(pending_failures),
)
- logger.debug("TX [%s] Sent transaction", destination)
- logger.debug("TX [%s] Marking as delivered...", destination)
+ with limiter:
+ # Actually send the transaction
+
+ # FIXME (erikj): This is a bit of a hack to make the Pdu age
+ # keys work
+ def json_data_cb():
+ data = transaction.get_dict()
+ now = int(self.clock.time_msec())
+ if "pdus" in data:
+ for p in data["pdus"]:
+ if "age_ts" in p:
+ unsigned = p.setdefault("unsigned", {})
+ unsigned["age"] = now - int(p["age_ts"])
+ del p["age_ts"]
+ return data
+
+ try:
+ response = yield self.transport_layer.send_transaction(
+ transaction, json_data_cb
+ )
+ code = 200
+
+ if response:
+ for e_id, r in response.get("pdus", {}).items():
+ if "error" in r:
+ logger.warn(
+ "Transaction returned error for %s: %s",
+ e_id, r,
+ )
+ except HttpResponseException as e:
+ code = e.code
+ response = e.response
+
+ logger.info(
+ "TX [%s] {%s} got %d response",
+ destination, txn_id, code
+ )
- yield self.transaction_actions.delivered(
- transaction, code, response
- )
+ logger.debug("TX [%s] Sent transaction", destination)
+ logger.debug("TX [%s] Marking as delivered...", destination)
- logger.debug("TX [%s] Marked as delivered", destination)
-
- logger.debug("TX [%s] Yielding to callbacks...", destination)
-
- for deferred in deferreds:
- if code == 200:
- deferred.callback(None)
- else:
- deferred.errback(RuntimeError("Got status %d" % code))
-
- # Ensures we don't continue until all callbacks on that
- # deferred have fired
- try:
- yield deferred
- except:
- pass
-
- logger.debug("TX [%s] Yielded to callbacks", destination)
- except NotRetryingDestination:
- logger.info(
- "TX [%s] not ready for retry yet - "
- "dropping transaction for now",
- destination,
- )
- except RuntimeError as e:
- # We capture this here as there as nothing actually listens
- # for this finishing functions deferred.
- logger.warn(
- "TX [%s] Problem in _attempt_transaction: %s",
- destination,
- e,
- )
- except Exception as e:
- # We capture this here as there as nothing actually listens
- # for this finishing functions deferred.
- logger.warn(
- "TX [%s] Problem in _attempt_transaction: %s",
- destination,
- e,
- )
+ yield self.transaction_actions.delivered(
+ transaction, code, response
+ )
- for deferred in deferreds:
- if not deferred.called:
- deferred.errback(e)
+ logger.debug("TX [%s] Marked as delivered", destination)
+
+ if code != 200:
+ for p in pdus:
+ logger.info(
+ "Failed to send event %s to %s", p.event_id, destination
+ )
+ except NotRetryingDestination:
+ logger.info(
+ "TX [%s] not ready for retry yet - "
+ "dropping transaction for now",
+ destination,
+ )
+ except RuntimeError as e:
+ # We capture this here as there as nothing actually listens
+ # for this finishing functions deferred.
+ logger.warn(
+ "TX [%s] Problem in _attempt_transaction: %s",
+ destination,
+ e,
+ )
+
+ for p in pdus:
+ logger.info("Failed to send event %s to %s", p.event_id, destination)
+ except Exception as e:
+ # We capture this here as there as nothing actually listens
+ # for this finishing functions deferred.
+ logger.warn(
+ "TX [%s] Problem in _attempt_transaction: %s",
+ destination,
+ e,
+ )
- finally:
- # We want to be *very* sure we delete this after we stop processing
- self.pending_transactions.pop(destination, None)
+ for p in pdus:
+ logger.info("Failed to send event %s to %s", p.event_id, destination)
- # Check to see if there is anything else to send.
- self._attempt_new_transaction(destination)
+ finally:
+ # We want to be *very* sure we delete this after we stop processing
+ self.pending_transactions.pop(destination, None)
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index 1a50a2ec..63d05f25 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -19,7 +19,6 @@ from .room import (
)
from .room_member import RoomMemberHandler
from .message import MessageHandler
-from .events import EventStreamHandler, EventHandler
from .federation import FederationHandler
from .profile import ProfileHandler
from .directory import DirectoryHandler
@@ -53,8 +52,6 @@ class Handlers(object):
self.message_handler = MessageHandler(hs)
self.room_creation_handler = RoomCreationHandler(hs)
self.room_member_handler = RoomMemberHandler(hs)
- self.event_stream_handler = EventStreamHandler(hs)
- self.event_handler = EventHandler(hs)
self.federation_handler = FederationHandler(hs)
self.profile_handler = ProfileHandler(hs)
self.directory_handler = DirectoryHandler(hs)
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 051ccdb3..306686a3 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -16,7 +16,8 @@
from twisted.internet import defer
from synapse.api.constants import EventTypes
-from synapse.appservice import ApplicationService
+from synapse.util.metrics import Measure
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
import logging
@@ -42,36 +43,73 @@ class ApplicationServicesHandler(object):
self.appservice_api = hs.get_application_service_api()
self.scheduler = hs.get_application_service_scheduler()
self.started_scheduler = False
+ self.clock = hs.get_clock()
+ self.notify_appservices = hs.config.notify_appservices
+
+ self.current_max = 0
+ self.is_processing = False
@defer.inlineCallbacks
- def notify_interested_services(self, event):
+ def notify_interested_services(self, current_id):
"""Notifies (pushes) all application services interested in this event.
Pushing is done asynchronously, so this method won't block for any
prolonged length of time.
Args:
- event(Event): The event to push out to interested services.
+ current_id(int): The current maximum ID.
"""
- # Gather interested services
- services = yield self._get_services_for_event(event)
- if len(services) == 0:
- return # no services need notifying
-
- # Do we know this user exists? If not, poke the user query API for
- # all services which match that user regex. This needs to block as these
- # user queries need to be made BEFORE pushing the event.
- yield self._check_user_exists(event.sender)
- if event.type == EventTypes.Member:
- yield self._check_user_exists(event.state_key)
-
- if not self.started_scheduler:
- self.scheduler.start().addErrback(log_failure)
- self.started_scheduler = True
-
- # Fork off pushes to these services
- for service in services:
- self.scheduler.submit_event_for_as(service, event)
+ services = yield self.store.get_app_services()
+ if not services or not self.notify_appservices:
+ return
+
+ self.current_max = max(self.current_max, current_id)
+ if self.is_processing:
+ return
+
+ with Measure(self.clock, "notify_interested_services"):
+ self.is_processing = True
+ try:
+ upper_bound = self.current_max
+ limit = 100
+ while True:
+ upper_bound, events = yield self.store.get_new_events_for_appservice(
+ upper_bound, limit
+ )
+
+ if not events:
+ break
+
+ for event in events:
+ # Gather interested services
+ services = yield self._get_services_for_event(event)
+ if len(services) == 0:
+ continue # no services need notifying
+
+ # Do we know this user exists? If not, poke the user
+ # query API for all services which match that user regex.
+ # This needs to block as these user queries need to be
+ # made BEFORE pushing the event.
+ yield self._check_user_exists(event.sender)
+ if event.type == EventTypes.Member:
+ yield self._check_user_exists(event.state_key)
+
+ if not self.started_scheduler:
+ self.scheduler.start().addErrback(log_failure)
+ self.started_scheduler = True
+
+ # Fork off pushes to these services
+ for service in services:
+ preserve_fn(self.scheduler.submit_event_for_as)(
+ service, event
+ )
+
+ yield self.store.set_appservice_last_pos(upper_bound)
+
+ if len(events) < limit:
+ break
+ finally:
+ self.is_processing = False
@defer.inlineCallbacks
def query_user_exists(self, user_id):
@@ -104,11 +142,12 @@ class ApplicationServicesHandler(object):
association can be found.
"""
room_alias_str = room_alias.to_string()
- alias_query_services = yield self._get_services_for_event(
- event=None,
- restrict_to=ApplicationService.NS_ALIASES,
- alias_list=[room_alias_str]
- )
+ services = yield self.store.get_app_services()
+ alias_query_services = [
+ s for s in services if (
+ s.is_interested_in_alias(room_alias_str)
+ )
+ ]
for alias_service in alias_query_services:
is_known_alias = yield self.appservice_api.query_alias(
alias_service, room_alias_str
@@ -121,34 +160,35 @@ class ApplicationServicesHandler(object):
defer.returnValue(result)
@defer.inlineCallbacks
- def _get_services_for_event(self, event, restrict_to="", alias_list=None):
+ def query_3pe(self, kind, protocol, fields):
+ services = yield self._get_services_for_3pn(protocol)
+
+ results = yield preserve_context_over_deferred(defer.DeferredList([
+ preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields)
+ for service in services
+ ], consumeErrors=True))
+
+ ret = []
+ for (success, result) in results:
+ if success:
+ ret.extend(result)
+
+ defer.returnValue(ret)
+
+ @defer.inlineCallbacks
+ def _get_services_for_event(self, event):
"""Retrieve a list of application services interested in this event.
Args:
event(Event): The event to check. Can be None if alias_list is not.
- restrict_to(str): The namespace to restrict regex tests to.
- alias_list: A list of aliases to get services for. If None, this
- list is obtained from the database.
Returns:
list<ApplicationService>: A list of services interested in this
event based on the service regex.
"""
- member_list = None
- if hasattr(event, "room_id"):
- # We need to know the aliases associated with this event.room_id,
- # if any.
- if not alias_list:
- alias_list = yield self.store.get_aliases_for_room(
- event.room_id
- )
- # We need to know the members associated with this event.room_id,
- # if any.
- member_list = yield self.store.get_users_in_room(event.room_id)
-
services = yield self.store.get_app_services()
interested_list = [
s for s in services if (
- s.is_interested(event, restrict_to, alias_list, member_list)
+ yield s.is_interested(event, self.store)
)
]
defer.returnValue(interested_list)
@@ -164,6 +204,14 @@ class ApplicationServicesHandler(object):
defer.returnValue(interested_list)
@defer.inlineCallbacks
+ def _get_services_for_3pn(self, protocol):
+ services = yield self.store.get_app_services()
+ interested_list = [
+ s for s in services if s.is_interested_in_protocol(protocol)
+ ]
+ defer.returnValue(interested_list)
+
+ @defer.inlineCallbacks
def _is_unknown_user(self, user_id):
if not self.is_mine_id(user_id):
# we don't know if they are unknown or not since it isn't one of our
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 2e138f32..6986930c 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -70,11 +70,11 @@ class AuthHandler(BaseHandler):
self.ldap_uri = hs.config.ldap_uri
self.ldap_start_tls = hs.config.ldap_start_tls
self.ldap_base = hs.config.ldap_base
- self.ldap_filter = hs.config.ldap_filter
self.ldap_attributes = hs.config.ldap_attributes
if self.ldap_mode == LDAPMode.SEARCH:
self.ldap_bind_dn = hs.config.ldap_bind_dn
self.ldap_bind_password = hs.config.ldap_bind_password
+ self.ldap_filter = hs.config.ldap_filter
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler()
@@ -660,7 +660,7 @@ class AuthHandler(BaseHandler):
else:
logger.warn(
"ldap registration failed: unexpected (%d!=1) amount of results",
- len(result)
+ len(conn.response)
)
defer.returnValue(False)
@@ -719,13 +719,14 @@ class AuthHandler(BaseHandler):
return macaroon.serialize()
def validate_short_term_login_token_and_get_user_id(self, login_token):
+ auth_api = self.hs.get_auth()
try:
macaroon = pymacaroons.Macaroon.deserialize(login_token)
- auth_api = self.hs.get_auth()
- auth_api.validate_macaroon(macaroon, "login", True)
- return self.get_user_from_macaroon(macaroon)
- except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
- raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)
+ user_id = auth_api.get_user_id_from_macaroon(macaroon)
+ auth_api.validate_macaroon(macaroon, "login", True, user_id)
+ return user_id
+ except Exception:
+ raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon(
@@ -736,21 +737,11 @@ class AuthHandler(BaseHandler):
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
- def get_user_from_macaroon(self, macaroon):
- user_prefix = "user_id = "
- for caveat in macaroon.caveats:
- if caveat.caveat_id.startswith(user_prefix):
- return caveat.caveat_id[len(user_prefix):]
- raise AuthError(
- self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token",
- errcode=Codes.UNKNOWN_TOKEN
- )
-
@defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None):
password_hash = self.hash(newpassword)
- except_access_token_ids = [requester.access_token_id] if requester else []
+ except_access_token_id = requester.access_token_id if requester else None
try:
yield self.store.user_set_password_hash(user_id, password_hash)
@@ -759,10 +750,10 @@ class AuthHandler(BaseHandler):
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e
yield self.store.user_delete_access_tokens(
- user_id, except_access_token_ids
+ user_id, except_access_token_id
)
yield self.hs.get_pusherpool().remove_pushers_by_user(
- user_id, except_access_token_ids
+ user_id, except_access_token_id
)
@defer.inlineCallbacks
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 618cb536..01a76171 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -26,7 +26,9 @@ from synapse.api.errors import (
from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError
-from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
+from synapse.util.logcontext import (
+ PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
+)
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.util.frozenutils import unfreeze
@@ -249,7 +251,7 @@ class FederationHandler(BaseHandler):
if ev.type != EventTypes.Member:
continue
try:
- domain = UserID.from_string(ev.state_key).domain
+ domain = get_domain_from_id(ev.state_key)
except:
continue
@@ -274,7 +276,7 @@ class FederationHandler(BaseHandler):
@log_function
@defer.inlineCallbacks
- def backfill(self, dest, room_id, limit, extremities=[]):
+ def backfill(self, dest, room_id, limit, extremities):
""" Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. This may return
@@ -284,9 +286,6 @@ class FederationHandler(BaseHandler):
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
- if not extremities:
- extremities = yield self.store.get_oldest_events_in_room(room_id)
-
events = yield self.replication_layer.backfill(
dest,
room_id,
@@ -364,9 +363,9 @@ class FederationHandler(BaseHandler):
missing_auth - failed_to_fetch
)
- results = yield defer.gatherResults(
+ results = yield preserve_context_over_deferred(defer.gatherResults(
[
- self.replication_layer.get_pdu(
+ preserve_fn(self.replication_layer.get_pdu)(
[dest],
event_id,
outlier=True,
@@ -375,10 +374,10 @@ class FederationHandler(BaseHandler):
for event_id in missing_auth - failed_to_fetch
],
consumeErrors=True
- ).addErrback(unwrapFirstError)
- auth_events.update({a.event_id: a for a in results})
+ )).addErrback(unwrapFirstError)
+ auth_events.update({a.event_id: a for a in results if a})
required_auth.update(
- a_id for event in results for a_id, _ in event.auth_events
+ a_id for event in results for a_id, _ in event.auth_events if event
)
missing_auth = required_auth - set(auth_events)
@@ -455,6 +454,10 @@ class FederationHandler(BaseHandler):
)
max_depth = sorted_extremeties_tuple[0][1]
+ # We don't want to specify too many extremities as it causes the backfill
+ # request URI to be too long.
+ extremities = dict(sorted_extremeties_tuple[:5])
+
if current_depth > max_depth:
logger.debug(
"Not backfilling as we don't need to. %d < %d",
@@ -551,10 +554,10 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys())
- states = yield defer.gatherResults([
- self.state_handler.resolve_state_groups(room_id, [e])
+ states = yield preserve_context_over_deferred(defer.gatherResults([
+ preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
for e in event_ids
- ])
+ ]))
states = dict(zip(event_ids, [s[1] for s in states]))
for e_id, _ in sorted_extremeties_tuple:
@@ -1093,16 +1096,17 @@ class FederationHandler(BaseHandler):
)
if event:
- # FIXME: This is a temporary work around where we occasionally
- # return events slightly differently than when they were
- # originally signed
- event.signatures.update(
- compute_event_signature(
- event,
- self.hs.hostname,
- self.hs.config.signing_key[0]
+ if self.hs.is_mine_id(event.event_id):
+ # FIXME: This is a temporary work around where we occasionally
+ # return events slightly differently than when they were
+ # originally signed
+ event.signatures.update(
+ compute_event_signature(
+ event,
+ self.hs.hostname,
+ self.hs.config.signing_key[0]
+ )
)
- )
if do_auth:
in_room = yield self.auth.check_host_in_room(
@@ -1112,6 +1116,12 @@ class FederationHandler(BaseHandler):
if not in_room:
raise AuthError(403, "Host not in room.")
+ events = yield self._filter_events_for_server(
+ origin, event.room_id, [event]
+ )
+
+ event = events[0]
+
defer.returnValue(event)
else:
defer.returnValue(None)
@@ -1158,9 +1168,9 @@ class FederationHandler(BaseHandler):
a bunch of outliers, but not a chunk of individual events that depend
on each other for state calculations.
"""
- contexts = yield defer.gatherResults(
+ contexts = yield preserve_context_over_deferred(defer.gatherResults(
[
- self._prep_event(
+ preserve_fn(self._prep_event)(
origin,
ev_info["event"],
state=ev_info.get("state"),
@@ -1168,7 +1178,7 @@ class FederationHandler(BaseHandler):
)
for ev_info in event_infos
]
- )
+ ))
yield self.store.persist_events(
[
@@ -1452,9 +1462,9 @@ class FederationHandler(BaseHandler):
# Do auth conflict res.
logger.info("Different auth: %s", different_auth)
- different_events = yield defer.gatherResults(
+ different_events = yield preserve_context_over_deferred(defer.gatherResults(
[
- self.store.get_event(
+ preserve_fn(self.store.get_event)(
d,
allow_none=True,
allow_rejected=False,
@@ -1463,7 +1473,7 @@ class FederationHandler(BaseHandler):
if d in have_events and not have_events[d]
],
consumeErrors=True
- ).addErrback(unwrapFirstError)
+ )).addErrback(unwrapFirstError)
if different_events:
local_view = dict(auth_events)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index dc76d34a..4c3cd9d1 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -28,7 +28,8 @@ from synapse.types import (
from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock
from synapse.util.caches.snapshot_cache import SnapshotCache
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -502,15 +503,17 @@ class MessageHandler(BaseHandler):
lambda states: states[event.event_id]
)
- (messages, token), current_state = yield defer.gatherResults(
- [
- self.store.get_recent_events_for_room(
- event.room_id,
- limit=limit,
- end_token=room_end_token,
- ),
- deferred_room_state,
- ]
+ (messages, token), current_state = yield preserve_context_over_deferred(
+ defer.gatherResults(
+ [
+ preserve_fn(self.store.get_recent_events_for_room)(
+ event.room_id,
+ limit=limit,
+ end_token=room_end_token,
+ ),
+ deferred_room_state,
+ ]
+ )
).addErrback(unwrapFirstError)
messages = yield filter_events_for_client(
@@ -719,9 +722,9 @@ class MessageHandler(BaseHandler):
presence, receipts, (messages, token) = yield defer.gatherResults(
[
- get_presence(),
- get_receipts(),
- self.store.get_recent_events_for_room(
+ preserve_fn(get_presence)(),
+ preserve_fn(get_receipts)(),
+ preserve_fn(self.store.get_recent_events_for_room)(
room_id,
limit=limit,
end_token=now_token.room_key,
@@ -755,6 +758,7 @@ class MessageHandler(BaseHandler):
defer.returnValue(ret)
+ @measure_func("_create_new_client_event")
@defer.inlineCallbacks
def _create_new_client_event(self, builder, prev_event_ids=None):
if prev_event_ids:
@@ -806,6 +810,7 @@ class MessageHandler(BaseHandler):
(event, context,)
)
+ @measure_func("handle_new_client_event")
@defer.inlineCallbacks
def handle_new_client_event(
self,
@@ -934,7 +939,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks
def _notify():
yield run_on_reactor()
- self.notifier.on_new_room_event(
+ yield self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
@@ -944,6 +949,6 @@ class MessageHandler(BaseHandler):
# If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None)
- federation_handler.handle_new_event(
+ preserve_fn(federation_handler.handle_new_event)(
event, destinations=destinations,
)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 6b70fa38..6a1fe76c 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -503,7 +503,7 @@ class PresenceHandler(object):
defer.returnValue(states)
@defer.inlineCallbacks
- def _get_interested_parties(self, states):
+ def _get_interested_parties(self, states, calculate_remote_hosts=True):
"""Given a list of states return which entities (rooms, users, servers)
are interested in the given states.
@@ -526,14 +526,15 @@ class PresenceHandler(object):
users_to_states.setdefault(state.user_id, []).append(state)
hosts_to_states = {}
- for room_id, states in room_ids_to_states.items():
- local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
- if not local_states:
- continue
+ if calculate_remote_hosts:
+ for room_id, states in room_ids_to_states.items():
+ local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
+ if not local_states:
+ continue
- hosts = yield self.store.get_joined_hosts_for_room(room_id)
- for host in hosts:
- hosts_to_states.setdefault(host, []).extend(local_states)
+ hosts = yield self.store.get_joined_hosts_for_room(room_id)
+ for host in hosts:
+ hosts_to_states.setdefault(host, []).extend(local_states)
for user_id, states in users_to_states.items():
local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
@@ -565,6 +566,16 @@ class PresenceHandler(object):
self._push_to_remotes(hosts_to_states)
+ @defer.inlineCallbacks
+ def notify_for_states(self, state, stream_id):
+ parties = yield self._get_interested_parties([state])
+ room_ids_to_states, users_to_states, hosts_to_states = parties
+
+ self.notifier.on_new_event(
+ "presence_key", stream_id, rooms=room_ids_to_states.keys(),
+ users=[UserID.from_string(u) for u in users_to_states.keys()]
+ )
+
def _push_to_remotes(self, hosts_to_states):
"""Sends state updates to remote servers.
@@ -672,7 +683,7 @@ class PresenceHandler(object):
])
@defer.inlineCallbacks
- def set_state(self, target_user, state):
+ def set_state(self, target_user, state, ignore_status_msg=False):
"""Set the presence state of the user.
"""
status_msg = state.get("status_msg", None)
@@ -689,10 +700,13 @@ class PresenceHandler(object):
prev_state = yield self.current_state_for_user(user_id)
new_fields = {
- "state": presence,
- "status_msg": status_msg if presence != PresenceState.OFFLINE else None
+ "state": presence
}
+ if not ignore_status_msg:
+ msg = status_msg if presence != PresenceState.OFFLINE else None
+ new_fields["status_msg"] = msg
+
if presence == PresenceState.ONLINE:
new_fields["last_active_ts"] = self.clock.time_msec()
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 8cec8fc4..8b17632f 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -59,10 +59,13 @@ class RoomMemberHandler(BaseHandler):
prev_event_ids,
txn_id=None,
ratelimit=True,
+ content=None,
):
+ if content is None:
+ content = {}
msg_handler = self.hs.get_handlers().message_handler
- content = {"membership": membership}
+ content["membership"] = membership
if requester.is_guest:
content["kind"] = "guest"
@@ -140,8 +143,9 @@ class RoomMemberHandler(BaseHandler):
remote_room_hosts=None,
third_party_signed=None,
ratelimit=True,
+ content=None,
):
- key = (target, room_id,)
+ key = (room_id,)
with (yield self.member_linearizer.queue(key)):
result = yield self._update_membership(
@@ -153,6 +157,7 @@ class RoomMemberHandler(BaseHandler):
remote_room_hosts=remote_room_hosts,
third_party_signed=third_party_signed,
ratelimit=ratelimit,
+ content=content,
)
defer.returnValue(result)
@@ -168,7 +173,11 @@ class RoomMemberHandler(BaseHandler):
remote_room_hosts=None,
third_party_signed=None,
ratelimit=True,
+ content=None,
):
+ if content is None:
+ content = {}
+
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
@@ -218,7 +227,7 @@ class RoomMemberHandler(BaseHandler):
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
- content = {"membership": Membership.JOIN}
+ content["membership"] = Membership.JOIN
profile = self.hs.get_handlers().profile_handler
content["displayname"] = yield profile.get_displayname(target)
@@ -272,6 +281,7 @@ class RoomMemberHandler(BaseHandler):
txn_id=txn_id,
ratelimit=ratelimit,
prev_event_ids=latest_event_ids,
+ content=content,
)
@defer.inlineCallbacks
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 0ee4ebe5..c8dfd02e 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -464,10 +464,10 @@ class SyncHandler(object):
else:
state = {}
- defer.returnValue({
- (e.type, e.state_key): e
- for e in sync_config.filter_collection.filter_room_state(state.values())
- })
+ defer.returnValue({
+ (e.type, e.state_key): e
+ for e in sync_config.filter_collection.filter_room_state(state.values())
+ })
@defer.inlineCallbacks
def unread_notifs_for_room_id(self, room_id, sync_config):
@@ -485,9 +485,9 @@ class SyncHandler(object):
)
defer.returnValue(notifs)
- # There is no new information in this period, so your notification
- # count is whatever it was last time.
- defer.returnValue(None)
+ # There is no new information in this period, so your notification
+ # count is whatever it was last time.
+ defer.returnValue(None)
@defer.inlineCallbacks
def generate_sync_result(self, sync_config, since_token=None, full_state=False):
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 5589296c..46181984 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -16,7 +16,9 @@
from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logcontext import (
+ PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
+)
from synapse.util.metrics import Measure
from synapse.types import UserID
@@ -169,13 +171,13 @@ class TypingHandler(object):
deferreds = []
for domain in domains:
if domain == self.server_name:
- self._push_update_local(
+ preserve_fn(self._push_update_local)(
room_id=room_id,
user_id=user_id,
typing=typing
)
else:
- deferreds.append(self.federation.send_edu(
+ deferreds.append(preserve_fn(self.federation.send_edu)(
destination=domain,
edu_type="m.typing",
content={
@@ -185,7 +187,9 @@ class TypingHandler(object):
},
))
- yield defer.DeferredList(deferreds, consumeErrors=True)
+ yield preserve_context_over_deferred(
+ defer.DeferredList(deferreds, consumeErrors=True)
+ )
@defer.inlineCallbacks
def _recv_edu(self, origin, content):
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index c3589534..f93093dd 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -155,9 +155,7 @@ class MatrixFederationHttpClient(object):
time_out=timeout / 1000. if timeout else 60,
)
- response = yield preserve_context_over_fn(
- send_request,
- )
+ response = yield preserve_context_over_fn(send_request)
log_result = "%d %s" % (response.code, response.phrase,)
break
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 2b3c05a7..168e53ce 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -19,6 +19,7 @@ from synapse.api.errors import (
)
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches import intern_dict
+from synapse.util.metrics import Measure
import synapse.metrics
import synapse.events
@@ -74,12 +75,12 @@ response_db_txn_duration = metrics.register_distribution(
_next_request_id = 0
-def request_handler(report_metrics=True):
+def request_handler(include_metrics=False):
"""Decorator for ``wrap_request_handler``"""
- return lambda request_handler: wrap_request_handler(request_handler, report_metrics)
+ return lambda request_handler: wrap_request_handler(request_handler, include_metrics)
-def wrap_request_handler(request_handler, report_metrics):
+def wrap_request_handler(request_handler, include_metrics=False):
"""Wraps a method that acts as a request handler with the necessary logging
and exception handling.
@@ -103,54 +104,56 @@ def wrap_request_handler(request_handler, report_metrics):
_next_request_id += 1
with LoggingContext(request_id) as request_context:
- if report_metrics:
+ with Measure(self.clock, "wrapped_request_handler"):
request_metrics = RequestMetrics()
- request_metrics.start(self.clock)
-
- request_context.request = request_id
- with request.processing():
- try:
- with PreserveLoggingContext(request_context):
- yield request_handler(self, request)
- except CodeMessageException as e:
- code = e.code
- if isinstance(e, SynapseError):
- logger.info(
- "%s SynapseError: %s - %s", request, code, e.msg
- )
- else:
- logger.exception(e)
- outgoing_responses_counter.inc(request.method, str(code))
- respond_with_json(
- request, code, cs_exception(e), send_cors=True,
- pretty_print=_request_user_agent_is_curl(request),
- version_string=self.version_string,
- )
- except:
- logger.exception(
- "Failed handle request %s.%s on %r: %r",
- request_handler.__module__,
- request_handler.__name__,
- self,
- request
- )
- respond_with_json(
- request,
- 500,
- {
- "error": "Internal server error",
- "errcode": Codes.UNKNOWN,
- },
- send_cors=True
- )
- finally:
+ request_metrics.start(self.clock, name=self.__class__.__name__)
+
+ request_context.request = request_id
+ with request.processing():
try:
- if report_metrics:
- request_metrics.stop(
- self.clock, request, self.__class__.__name__
+ with PreserveLoggingContext(request_context):
+ if include_metrics:
+ yield request_handler(self, request, request_metrics)
+ else:
+ yield request_handler(self, request)
+ except CodeMessageException as e:
+ code = e.code
+ if isinstance(e, SynapseError):
+ logger.info(
+ "%s SynapseError: %s - %s", request, code, e.msg
)
+ else:
+ logger.exception(e)
+ outgoing_responses_counter.inc(request.method, str(code))
+ respond_with_json(
+ request, code, cs_exception(e), send_cors=True,
+ pretty_print=_request_user_agent_is_curl(request),
+ version_string=self.version_string,
+ )
except:
- pass
+ logger.exception(
+ "Failed handle request %s.%s on %r: %r",
+ request_handler.__module__,
+ request_handler.__name__,
+ self,
+ request
+ )
+ respond_with_json(
+ request,
+ 500,
+ {
+ "error": "Internal server error",
+ "errcode": Codes.UNKNOWN,
+ },
+ send_cors=True
+ )
+ finally:
+ try:
+ request_metrics.stop(
+ self.clock, request
+ )
+ except Exception as e:
+ logger.warn("Failed to stop metrics: %r", e)
return wrapped_request_handler
@@ -220,9 +223,9 @@ class JsonResource(HttpServer, resource.Resource):
# It does its own metric reporting because _async_render dispatches to
# a callback and it's the class name of that callback we want to report
# against rather than the JsonResource itself.
- @request_handler(report_metrics=False)
+ @request_handler(include_metrics=True)
@defer.inlineCallbacks
- def _async_render(self, request):
+ def _async_render(self, request, request_metrics):
""" This gets called from render() every time someone sends us a request.
This checks if anyone has registered a callback for that method and
path.
@@ -231,9 +234,6 @@ class JsonResource(HttpServer, resource.Resource):
self._send_response(request, 200, {})
return
- request_metrics = RequestMetrics()
- request_metrics.start(self.clock)
-
# Loop through all the registered callbacks to check if the method
# and path regex match
for path_entry in self.path_regexs.get(request.method, []):
@@ -247,12 +247,6 @@ class JsonResource(HttpServer, resource.Resource):
callback = path_entry.callback
- servlet_instance = getattr(callback, "__self__", None)
- if servlet_instance is not None:
- servlet_classname = servlet_instance.__class__.__name__
- else:
- servlet_classname = "%r" % callback
-
kwargs = intern_dict({
name: urllib.unquote(value).decode("UTF-8") if value else value
for name, value in m.groupdict().items()
@@ -263,10 +257,13 @@ class JsonResource(HttpServer, resource.Resource):
code, response = callback_return
self._send_response(request, code, response)
- try:
- request_metrics.stop(self.clock, request, servlet_classname)
- except:
- pass
+ servlet_instance = getattr(callback, "__self__", None)
+ if servlet_instance is not None:
+ servlet_classname = servlet_instance.__class__.__name__
+ else:
+ servlet_classname = "%r" % callback
+
+ request_metrics.name = servlet_classname
return
@@ -298,11 +295,12 @@ class JsonResource(HttpServer, resource.Resource):
class RequestMetrics(object):
- def start(self, clock):
+ def start(self, clock, name):
self.start = clock.time_msec()
self.start_context = LoggingContext.current_context()
+ self.name = name
- def stop(self, clock, request, servlet_classname):
+ def stop(self, clock, request):
context = LoggingContext.current_context()
tag = ""
@@ -316,26 +314,26 @@ class RequestMetrics(object):
)
return
- incoming_requests_counter.inc(request.method, servlet_classname, tag)
+ incoming_requests_counter.inc(request.method, self.name, tag)
response_timer.inc_by(
clock.time_msec() - self.start, request.method,
- servlet_classname, tag
+ self.name, tag
)
ru_utime, ru_stime = context.get_resource_usage()
response_ru_utime.inc_by(
- ru_utime, request.method, servlet_classname, tag
+ ru_utime, request.method, self.name, tag
)
response_ru_stime.inc_by(
- ru_stime, request.method, servlet_classname, tag
+ ru_stime, request.method, self.name, tag
)
response_db_txn_count.inc_by(
- context.db_txn_count, request.method, servlet_classname, tag
+ context.db_txn_count, request.method, self.name, tag
)
response_db_txn_duration.inc_by(
- context.db_txn_duration, request.method, servlet_classname, tag
+ context.db_txn_duration, request.method, self.name, tag
)
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 30883a06..b86648f5 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -19,7 +19,8 @@ from synapse.api.errors import AuthError
from synapse.util.logutils import log_function
from synapse.util.async import ObservableDeferred
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
+from synapse.util.metrics import Measure
from synapse.types import StreamToken
from synapse.visibility import filter_events_for_client
import synapse.metrics
@@ -67,10 +68,8 @@ class _NotifierUserStream(object):
so that it can remove itself from the indexes in the Notifier class.
"""
- def __init__(self, user_id, rooms, current_token, time_now_ms,
- appservice=None):
+ def __init__(self, user_id, rooms, current_token, time_now_ms):
self.user_id = user_id
- self.appservice = appservice
self.rooms = set(rooms)
self.current_token = current_token
self.last_notified_ms = time_now_ms
@@ -107,11 +106,6 @@ class _NotifierUserStream(object):
notifier.user_to_user_stream.pop(self.user_id)
- if self.appservice:
- notifier.appservice_to_user_streams.get(
- self.appservice, set()
- ).discard(self)
-
def count_listeners(self):
return len(self.notify_deferred.observers())
@@ -142,7 +136,6 @@ class Notifier(object):
def __init__(self, hs):
self.user_to_user_stream = {}
self.room_to_user_streams = {}
- self.appservice_to_user_streams = {}
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore()
@@ -168,8 +161,6 @@ class Notifier(object):
all_user_streams |= x
for x in self.user_to_user_stream.values():
all_user_streams.add(x)
- for x in self.appservice_to_user_streams.values():
- all_user_streams |= x
return sum(stream.count_listeners() for stream in all_user_streams)
metrics.register_callback("listeners", count_listeners)
@@ -182,11 +173,8 @@ class Notifier(object):
"users",
lambda: len(self.user_to_user_stream),
)
- metrics.register_callback(
- "appservices",
- lambda: count(bool, self.appservice_to_user_streams.values()),
- )
+ @preserve_fn
def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
extra_users=[]):
""" Used by handlers to inform the notifier something has happened
@@ -208,6 +196,7 @@ class Notifier(object):
self.notify_replication()
+ @preserve_fn
def _notify_pending_new_room_events(self, max_room_stream_id):
"""Notify for the room events that were queued waiting for a previous
event to be persisted.
@@ -225,24 +214,11 @@ class Notifier(object):
else:
self._on_new_room_event(event, room_stream_id, extra_users)
+ @preserve_fn
def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
"""Notify any user streams that are interested in this room event"""
# poke any interested application service.
- self.appservice_handler.notify_interested_services(event)
-
- app_streams = set()
-
- for appservice in self.appservice_to_user_streams:
- # TODO (kegan): Redundant appservice listener checks?
- # App services will already be in the room_to_user_streams set, but
- # that isn't enough. They need to be checked here in order to
- # receive *invites* for users they are interested in. Does this
- # make the room_to_user_streams check somewhat obselete?
- if appservice.is_interested(event):
- app_user_streams = self.appservice_to_user_streams.get(
- appservice, set()
- )
- app_streams |= app_user_streams
+ self.appservice_handler.notify_interested_services(room_stream_id)
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
self._user_joined_room(event.state_key, event.room_id)
@@ -251,35 +227,36 @@ class Notifier(object):
"room_key", room_stream_id,
users=extra_users,
rooms=[event.room_id],
- extra_streams=app_streams,
)
- def on_new_event(self, stream_key, new_token, users=[], rooms=[],
- extra_streams=set()):
+ @preserve_fn
+ def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
""" Used to inform listeners that something has happend event wise.
Will wake up all listeners for the given users and rooms.
"""
with PreserveLoggingContext():
- user_streams = set()
+ with Measure(self.clock, "on_new_event"):
+ user_streams = set()
- for user in users:
- user_stream = self.user_to_user_stream.get(str(user))
- if user_stream is not None:
- user_streams.add(user_stream)
+ for user in users:
+ user_stream = self.user_to_user_stream.get(str(user))
+ if user_stream is not None:
+ user_streams.add(user_stream)
- for room in rooms:
- user_streams |= self.room_to_user_streams.get(room, set())
+ for room in rooms:
+ user_streams |= self.room_to_user_streams.get(room, set())
- time_now_ms = self.clock.time_msec()
- for user_stream in user_streams:
- try:
- user_stream.notify(stream_key, new_token, time_now_ms)
- except:
- logger.exception("Failed to notify listener")
+ time_now_ms = self.clock.time_msec()
+ for user_stream in user_streams:
+ try:
+ user_stream.notify(stream_key, new_token, time_now_ms)
+ except:
+ logger.exception("Failed to notify listener")
- self.notify_replication()
+ self.notify_replication()
+ @preserve_fn
def on_new_replication_data(self):
"""Used to inform replication listeners that something has happend
without waking up any of the normal user event streams"""
@@ -294,7 +271,6 @@ class Notifier(object):
"""
user_stream = self.user_to_user_stream.get(user_id)
if user_stream is None:
- appservice = yield self.store.get_app_service_by_user_id(user_id)
current_token = yield self.event_sources.get_current_token()
if room_ids is None:
rooms = yield self.store.get_rooms_for_user(user_id)
@@ -302,7 +278,6 @@ class Notifier(object):
user_stream = _NotifierUserStream(
user_id=user_id,
rooms=room_ids,
- appservice=appservice,
current_token=current_token,
time_now_ms=self.clock.time_msec(),
)
@@ -477,11 +452,6 @@ class Notifier(object):
s = self.room_to_user_streams.setdefault(room, set())
s.add(user_stream)
- if user_stream.appservice:
- self.appservice_to_user_stream.setdefault(
- user_stream.appservice, set()
- ).add(user_stream)
-
def _user_joined_room(self, user_id, room_id):
new_user_stream = self.user_to_user_stream.get(user_id)
if new_user_stream is not None:
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index 46e768e3..ed2ccc4d 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -38,15 +38,16 @@ class ActionGenerator:
@defer.inlineCallbacks
def handle_push_actions_for_event(self, event, context):
- with Measure(self.clock, "handle_push_actions_for_event"):
+ with Measure(self.clock, "evaluator_for_event"):
bulk_evaluator = yield evaluator_for_event(
- event, self.hs, self.store, context.current_state
+ event, self.hs, self.store, context.state_group, context.current_state
)
+ with Measure(self.clock, "action_for_event_by_user"):
actions_by_user = yield bulk_evaluator.action_for_event_by_user(
event, context.current_state
)
- context.push_actions = [
- (uid, actions) for uid, actions in actions_by_user.items()
- ]
+ context.push_actions = [
+ (uid, actions) for uid, actions in actions_by_user.items()
+ ]
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 024c1490..edb00ed2 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -217,45 +217,49 @@ BASE_APPEND_OVERRIDE_RULES = [
'dont_notify'
]
},
-]
-
-
-BASE_APPEND_UNDERRIDE_RULES = [
+ # This was changed from underride to override so it's closer in priority
+ # to the content rules where the user name highlight rule lives. This
+ # way a room rule is lower priority than both but a custom override rule
+ # is higher priority than both.
{
- 'rule_id': 'global/underride/.m.rule.call',
+ 'rule_id': 'global/override/.m.rule.contains_display_name',
'conditions': [
{
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.call.invite',
- '_id': '_call',
+ 'kind': 'contains_display_name'
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
- 'value': 'ring'
+ 'value': 'default'
}, {
- 'set_tweak': 'highlight',
- 'value': False
+ 'set_tweak': 'highlight'
}
]
},
+]
+
+
+BASE_APPEND_UNDERRIDE_RULES = [
{
- 'rule_id': 'global/underride/.m.rule.contains_display_name',
+ 'rule_id': 'global/underride/.m.rule.call',
'conditions': [
{
- 'kind': 'contains_display_name'
+ 'kind': 'event_match',
+ 'key': 'type',
+ 'pattern': 'm.call.invite',
+ '_id': '_call',
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
- 'value': 'default'
+ 'value': 'ring'
}, {
- 'set_tweak': 'highlight'
+ 'set_tweak': 'highlight',
+ 'value': False
}
]
},
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 756e5da5..004eded6 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -36,35 +36,11 @@ def _get_rules(room_id, user_ids, store):
@defer.inlineCallbacks
-def evaluator_for_event(event, hs, store, current_state):
- room_id = event.room_id
- # We also will want to generate notifs for other people in the room so
- # their unread countss are correct in the event stream, but to avoid
- # generating them for bot / AS users etc, we only do so for people who've
- # sent a read receipt into the room.
-
- local_users_in_room = set(
- e.state_key for e in current_state.values()
- if e.type == EventTypes.Member and e.membership == Membership.JOIN
- and hs.is_mine_id(e.state_key)
+def evaluator_for_event(event, hs, store, state_group, current_state):
+ rules_by_user = yield store.bulk_get_push_rules_for_room(
+ event.room_id, state_group, current_state
)
- # users in the room who have pushers need to get push rules run because
- # that's how their pushers work
- if_users_with_pushers = yield store.get_if_users_have_pushers(
- local_users_in_room
- )
- user_ids = set(
- uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
- )
-
- users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id)
-
- # any users with pushers must be ours: they have pushers
- for uid in users_with_receipts:
- if uid in local_users_in_room:
- user_ids.add(uid)
-
# if this event is an invite event, we may need to run rules for the user
# who's been invited, otherwise they won't get told they've been invited
if event.type == 'm.room.member' and event.content['membership'] == 'invite':
@@ -72,12 +48,12 @@ def evaluator_for_event(event, hs, store, current_state):
if invited_user and hs.is_mine_id(invited_user):
has_pusher = yield store.user_has_pusher(invited_user)
if has_pusher:
- user_ids.add(invited_user)
-
- rules_by_user = yield _get_rules(room_id, user_ids, store)
+ rules_by_user[invited_user] = yield store.get_push_rules_for_user(
+ invited_user
+ )
defer.returnValue(BulkPushRuleEvaluator(
- room_id, rules_by_user, user_ids, store
+ event.room_id, rules_by_user, store
))
@@ -90,10 +66,9 @@ class BulkPushRuleEvaluator:
the same logic to run the actual rules, but could be optimised further
(see https://matrix.org/jira/browse/SYN-562)
"""
- def __init__(self, room_id, rules_by_user, users_in_room, store):
+ def __init__(self, room_id, rules_by_user, store):
self.room_id = room_id
self.rules_by_user = rules_by_user
- self.users_in_room = users_in_room
self.store = store
@defer.inlineCallbacks
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index d555a33e..becb8ef1 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -17,14 +17,15 @@ from twisted.internet import defer
from synapse.util.presentable_names import (
calculate_room_name, name_from_member_event
)
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
@defer.inlineCallbacks
def get_badge_count(store, user_id):
- invites, joins = yield defer.gatherResults([
- store.get_invited_rooms_for_user(user_id),
- store.get_rooms_for_user(user_id),
- ], consumeErrors=True)
+ invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
+ preserve_fn(store.get_invited_rooms_for_user)(user_id),
+ preserve_fn(store.get_rooms_for_user)(user_id),
+ ], consumeErrors=True))
my_receipts_by_room = yield store.get_receipts_for_user(
user_id, "m.read",
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 5853ec36..3837be52 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -17,7 +17,7 @@
from twisted.internet import defer
import pusher
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.async import run_on_reactor
import logging
@@ -102,14 +102,14 @@ class PusherPool:
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
- def remove_pushers_by_user(self, user_id, except_token_ids=[]):
+ def remove_pushers_by_user(self, user_id, except_access_token_id=None):
all = yield self.store.get_all_pushers()
logger.info(
- "Removing all pushers for user %s except access tokens ids %r",
- user_id, except_token_ids
+ "Removing all pushers for user %s except access tokens id %r",
+ user_id, except_access_token_id
)
for p in all:
- if p['user_name'] == user_id and p['access_token'] not in except_token_ids:
+ if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name']
@@ -130,10 +130,12 @@ class PusherPool:
if u in self.pushers:
for p in self.pushers[u].values():
deferreds.append(
- p.on_new_notifications(min_stream_id, max_stream_id)
+ preserve_fn(p.on_new_notifications)(
+ min_stream_id, max_stream_id
+ )
)
- yield defer.gatherResults(deferreds)
+ yield preserve_context_over_deferred(defer.gatherResults(deferreds))
except:
logger.exception("Exception in pusher on_new_notifications")
@@ -155,10 +157,10 @@ class PusherPool:
if u in self.pushers:
for p in self.pushers[u].values():
deferreds.append(
- p.on_new_receipts(min_stream_id, max_stream_id)
+ preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
)
- yield defer.gatherResults(deferreds)
+ yield preserve_context_over_deferred(defer.gatherResults(deferreds))
except:
logger.exception("Exception in pusher on_new_receipts")
diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py
index 8c2d487f..84993b33 100644
--- a/synapse/replication/resource.py
+++ b/synapse/replication/resource.py
@@ -41,6 +41,7 @@ STREAM_NAMES = (
("push_rules",),
("pushers",),
("state",),
+ ("caches",),
)
@@ -70,6 +71,7 @@ class ReplicationResource(Resource):
* "backfill": Old events that have been backfilled from other servers.
* "push_rules": Per user changes to push rules.
* "pushers": Per user changes to their pushers.
+ * "caches": Cache invalidations.
The API takes two additional query parameters:
@@ -129,6 +131,7 @@ class ReplicationResource(Resource):
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
pushers_token = self.store.get_pushers_stream_token()
state_token = self.store.get_state_stream_token()
+ caches_token = self.store.get_cache_stream_token()
defer.returnValue(_ReplicationToken(
room_stream_token,
@@ -140,6 +143,7 @@ class ReplicationResource(Resource):
push_rules_token,
pushers_token,
state_token,
+ caches_token,
))
@request_handler()
@@ -188,6 +192,7 @@ class ReplicationResource(Resource):
yield self.push_rules(writer, current_token, limit, request_streams)
yield self.pushers(writer, current_token, limit, request_streams)
yield self.state(writer, current_token, limit, request_streams)
+ yield self.caches(writer, current_token, limit, request_streams)
self.streams(writer, current_token, request_streams)
logger.info("Replicated %d rows", writer.total)
@@ -379,6 +384,20 @@ class ReplicationResource(Resource):
"position", "type", "state_key", "event_id"
))
+ @defer.inlineCallbacks
+ def caches(self, writer, current_token, limit, request_streams):
+ current_position = current_token.caches
+
+ caches = request_streams.get("caches")
+
+ if caches is not None:
+ updated_caches = yield self.store.get_all_updated_caches(
+ caches, current_position, limit
+ )
+ writer.write_header_and_rows("caches", updated_caches, (
+ "position", "cache_func", "keys", "invalidation_ts"
+ ))
+
class _Writer(object):
"""Writes the streams as a JSON object as the response to the request"""
@@ -407,7 +426,7 @@ class _Writer(object):
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill",
- "push_rules", "pushers", "state"
+ "push_rules", "pushers", "state", "caches",
))):
__slots__ = []
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 46e43ce1..f19540d6 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -14,15 +14,43 @@
# limitations under the License.
from synapse.storage._base import SQLBaseStore
+from synapse.storage.engines import PostgresEngine
from twisted.internet import defer
+from ._slaved_id_tracker import SlavedIdTracker
+
+import logging
+
+logger = logging.getLogger(__name__)
+
class BaseSlavedStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(BaseSlavedStore, self).__init__(hs)
+ if isinstance(self.database_engine, PostgresEngine):
+ self._cache_id_gen = SlavedIdTracker(
+ db_conn, "cache_invalidation_stream", "stream_id",
+ )
+ else:
+ self._cache_id_gen = None
def stream_positions(self):
- return {}
+ pos = {}
+ if self._cache_id_gen:
+ pos["caches"] = self._cache_id_gen.get_current_token()
+ return pos
def process_replication(self, result):
+ stream = result.get("caches")
+ if stream:
+ for row in stream["rows"]:
+ (
+ position, cache_func, keys, invalidation_ts,
+ ) = row
+
+ try:
+ getattr(self, cache_func).invalidate(tuple(keys))
+ except AttributeError:
+ logger.info("Got unexpected cache_func: %r", cache_func)
+ self._cache_id_gen.advance(int(stream["position"]))
return defer.succeed(None)
diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py
index 25792d94..a374f2f1 100644
--- a/synapse/replication/slave/storage/appservice.py
+++ b/synapse/replication/slave/storage/appservice.py
@@ -28,3 +28,13 @@ class SlavedApplicationServiceStore(BaseSlavedStore):
get_app_service_by_token = DataStore.get_app_service_by_token.__func__
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
+ get_app_services = DataStore.get_app_services.__func__
+ get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__
+ create_appservice_txn = DataStore.create_appservice_txn.__func__
+ get_appservices_by_state = DataStore.get_appservices_by_state.__func__
+ get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__
+ _get_last_txn = DataStore._get_last_txn.__func__
+ complete_appservice_txn = DataStore.complete_appservice_txn.__func__
+ get_appservice_state = DataStore.get_appservice_state.__func__
+ set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__
+ set_appservice_state = DataStore.set_appservice_state.__func__
diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py
index 5fbe3a30..7301d885 100644
--- a/synapse/replication/slave/storage/directory.py
+++ b/synapse/replication/slave/storage/directory.py
@@ -20,4 +20,4 @@ from synapse.storage.directory import DirectoryStore
class DirectoryStore(BaseSlavedStore):
get_aliases_for_room = DirectoryStore.__dict__[
"get_aliases_for_room"
- ].orig
+ ]
diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py
index 307833f9..e27c7332 100644
--- a/synapse/replication/slave/storage/registration.py
+++ b/synapse/replication/slave/storage/registration.py
@@ -25,6 +25,9 @@ class SlavedRegistrationStore(BaseSlavedStore):
# TODO: use the cached version and invalidate deleted tokens
get_user_by_access_token = RegistrationStore.__dict__[
"get_user_by_access_token"
- ].orig
+ ]
_query_for_auth = DataStore._query_for_auth.__func__
+ get_user_by_id = RegistrationStore.__dict__[
+ "get_user_by_id"
+ ]
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 14227f1c..32678040 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -46,7 +46,9 @@ from synapse.rest.client.v2_alpha import (
account_data,
report_event,
openid,
+ notifications,
devices,
+ thirdparty,
)
from synapse.http.server import JsonResource
@@ -91,4 +93,6 @@ class ClientRestResource(JsonResource):
account_data.register_servlets(hs, client_resource)
report_event.register_servlets(hs, client_resource)
openid.register_servlets(hs, client_resource)
+ notifications.register_servlets(hs, client_resource)
devices.register_servlets(hs, client_resource)
+ thirdparty.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index b0cb31a4..af21661d 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -28,6 +28,10 @@ logger = logging.getLogger(__name__)
class WhoisRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)")
+ def __init__(self, hs):
+ super(WhoisRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
@defer.inlineCallbacks
def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id)
@@ -82,6 +86,10 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
"/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
)
+ def __init__(self, hs):
+ super(PurgeHistoryRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request)
diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py
index 96b49b01..c2a84478 100644
--- a/synapse/rest/client/v1/base.py
+++ b/synapse/rest/client/v1/base.py
@@ -57,7 +57,6 @@ class ClientV1RestServlet(RestServlet):
hs (synapse.server.HomeServer):
"""
self.hs = hs
- self.handlers = hs.get_handlers()
self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_v1auth()
self.txns = HttpTransactionStore()
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 8ac09419..09d08315 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -36,6 +36,10 @@ def register_servlets(hs, http_server):
class ClientDirectoryServer(ClientV1RestServlet):
PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$")
+ def __init__(self, hs):
+ super(ClientDirectoryServer, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
@defer.inlineCallbacks
def on_GET(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias)
@@ -146,6 +150,7 @@ class ClientDirectoryListServer(ClientV1RestServlet):
def __init__(self, hs):
super(ClientDirectoryListServer, self).__init__(hs)
self.store = hs.get_datastore()
+ self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index 498bb9e1..701b6f54 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -32,6 +32,10 @@ class EventStreamRestServlet(ClientV1RestServlet):
DEFAULT_LONGPOLL_TIME_MS = 30000
+ def __init__(self, hs):
+ super(EventStreamRestServlet, self).__init__(hs)
+ self.event_stream_handler = hs.get_event_stream_handler()
+
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(
@@ -46,7 +50,6 @@ class EventStreamRestServlet(ClientV1RestServlet):
if "room_id" in request.args:
room_id = request.args["room_id"][0]
- handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if "timeout" in request.args:
@@ -57,7 +60,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
as_client_event = "raw" not in request.args
- chunk = yield handler.get_stream(
+ chunk = yield self.event_stream_handler.get_stream(
requester.user.to_string(),
pagin_config,
timeout=timeout,
@@ -80,12 +83,12 @@ class EventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(EventRestServlet, self).__init__(hs)
self.clock = hs.get_clock()
+ self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks
def on_GET(self, request, event_id):
requester = yield self.auth.get_user_by_req(request)
- handler = self.handlers.event_handler
- event = yield handler.get_event(requester.user, event_id)
+ event = yield self.event_handler.get_event(requester.user, event_id)
time_now = self.clock.time_msec()
if event:
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index 36c35205..113a49e5 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -23,6 +23,10 @@ from .base import ClientV1RestServlet, client_path_patterns
class InitialSyncRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/initialSync$")
+ def __init__(self, hs):
+ super(InitialSyncRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 92fcae67..6c0eec8f 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -54,12 +54,9 @@ class LoginRestServlet(ClientV1RestServlet):
self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm
self.cas_enabled = hs.config.cas_enabled
- self.cas_server_url = hs.config.cas_server_url
- self.cas_required_attributes = hs.config.cas_required_attributes
- self.servername = hs.config.server_name
- self.http_client = hs.get_simple_http_client()
self.auth_handler = self.hs.get_auth_handler()
self.device_handler = self.hs.get_device_handler()
+ self.handlers = hs.get_handlers()
def on_GET(self, request):
flows = []
@@ -110,17 +107,6 @@ class LoginRestServlet(ClientV1RestServlet):
LoginRestServlet.JWT_TYPE):
result = yield self.do_jwt_login(login_submission)
defer.returnValue(result)
- # TODO Delete this after all CAS clients switch to token login instead
- elif self.cas_enabled and (login_submission["type"] ==
- LoginRestServlet.CAS_TYPE):
- uri = "%s/proxyValidate" % (self.cas_server_url,)
- args = {
- "ticket": login_submission["ticket"],
- "service": login_submission["service"]
- }
- body = yield self.http_client.get_raw(uri, args)
- result = yield self.do_cas_login(body)
- defer.returnValue(result)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
result = yield self.do_token_login(login_submission)
defer.returnValue(result)
@@ -191,51 +177,6 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result))
- # TODO Delete this after all CAS clients switch to token login instead
- @defer.inlineCallbacks
- def do_cas_login(self, cas_response_body):
- user, attributes = self.parse_cas_response(cas_response_body)
-
- for required_attribute, required_value in self.cas_required_attributes.items():
- # If required attribute was not in CAS Response - Forbidden
- if required_attribute not in attributes:
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
- # Also need to check value
- if required_value is not None:
- actual_value = attributes[required_attribute]
- # If required attribute value does not match expected - Forbidden
- if required_value != actual_value:
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
- user_id = UserID.create(user, self.hs.hostname).to_string()
- auth_handler = self.auth_handler
- registered_user_id = yield auth_handler.check_user_exists(user_id)
- if registered_user_id:
- access_token, refresh_token = (
- yield auth_handler.get_login_tuple_for_user_id(
- registered_user_id
- )
- )
- result = {
- "user_id": registered_user_id, # may have changed
- "access_token": access_token,
- "refresh_token": refresh_token,
- "home_server": self.hs.hostname,
- }
-
- else:
- user_id, access_token = (
- yield self.handlers.registration_handler.register(localpart=user)
- )
- result = {
- "user_id": user_id, # may have changed
- "access_token": access_token,
- "home_server": self.hs.hostname,
- }
-
- defer.returnValue((200, result))
-
@defer.inlineCallbacks
def do_jwt_login(self, login_submission):
token = login_submission.get("token", None)
@@ -293,33 +234,6 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result))
- # TODO Delete this after all CAS clients switch to token login instead
- def parse_cas_response(self, cas_response_body):
- root = ET.fromstring(cas_response_body)
- if not root.tag.endswith("serviceResponse"):
- raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
- if not root[0].tag.endswith("authenticationSuccess"):
- raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
- for child in root[0]:
- if child.tag.endswith("user"):
- user = child.text
- if child.tag.endswith("attributes"):
- attributes = {}
- for attribute in child:
- # ElementTree library expands the namespace in attribute tags
- # to the full URL of the namespace.
- # See (https://docs.python.org/2/library/xml.etree.elementtree.html)
- # We don't care about namespace here and it will always be encased in
- # curly braces, so we remove them.
- if "}" in attribute.tag:
- attributes[attribute.tag.split("}")[1]] = attribute.text
- else:
- attributes[attribute.tag] = attribute.text
- if user is None or attributes is None:
- raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
-
- return (user, attributes)
-
def _register_device(self, user_id, login_submission):
"""Register a device for a user.
@@ -347,6 +261,7 @@ class SAML2RestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(SAML2RestServlet, self).__init__(hs)
self.sp_config = hs.config.saml2_config_path
+ self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_POST(self, request):
@@ -384,18 +299,6 @@ class SAML2RestServlet(ClientV1RestServlet):
defer.returnValue((200, {"status": "not_authenticated"}))
-# TODO Delete this after all CAS clients switch to token login instead
-class CasRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/login/cas", releases=())
-
- def __init__(self, hs):
- super(CasRestServlet, self).__init__(hs)
- self.cas_server_url = hs.config.cas_server_url
-
- def on_GET(self, request):
- return (200, {"serverUrl": self.cas_server_url})
-
-
class CasRedirectServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login/cas/redirect", releases=())
@@ -427,6 +330,8 @@ class CasTicketServlet(ClientV1RestServlet):
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes
+ self.auth_handler = hs.get_auth_handler()
+ self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -479,30 +384,39 @@ class CasTicketServlet(ClientV1RestServlet):
return urlparse.urlunparse(url_parts)
def parse_cas_response(self, cas_response_body):
- root = ET.fromstring(cas_response_body)
- if not root.tag.endswith("serviceResponse"):
- raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
- if not root[0].tag.endswith("authenticationSuccess"):
- raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
- for child in root[0]:
- if child.tag.endswith("user"):
- user = child.text
- if child.tag.endswith("attributes"):
- attributes = {}
- for attribute in child:
- # ElementTree library expands the namespace in attribute tags
- # to the full URL of the namespace.
- # See (https://docs.python.org/2/library/xml.etree.elementtree.html)
- # We don't care about namespace here and it will always be encased in
- # curly braces, so we remove them.
- if "}" in attribute.tag:
- attributes[attribute.tag.split("}")[1]] = attribute.text
- else:
- attributes[attribute.tag] = attribute.text
- if user is None or attributes is None:
- raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
-
- return (user, attributes)
+ user = None
+ attributes = None
+ try:
+ root = ET.fromstring(cas_response_body)
+ if not root.tag.endswith("serviceResponse"):
+ raise Exception("root of CAS response is not serviceResponse")
+ success = (root[0].tag.endswith("authenticationSuccess"))
+ for child in root[0]:
+ if child.tag.endswith("user"):
+ user = child.text
+ if child.tag.endswith("attributes"):
+ attributes = {}
+ for attribute in child:
+ # ElementTree library expands the namespace in
+ # attribute tags to the full URL of the namespace.
+ # We don't care about namespace here and it will always
+ # be encased in curly braces, so we remove them.
+ tag = attribute.tag
+ if "}" in tag:
+ tag = tag.split("}")[1]
+ attributes[tag] = attribute.text
+ if user is None:
+ raise Exception("CAS response does not contain user")
+ if attributes is None:
+ raise Exception("CAS response does not contain attributes")
+ except Exception:
+ logger.error("Error parsing CAS response", exc_info=1)
+ raise LoginError(401, "Invalid CAS response",
+ errcode=Codes.UNAUTHORIZED)
+ if not success:
+ raise LoginError(401, "Unsuccessful CAS response",
+ errcode=Codes.UNAUTHORIZED)
+ return user, attributes
def register_servlets(hs, http_server):
@@ -512,5 +426,3 @@ def register_servlets(hs, http_server):
if hs.config.cas_enabled:
CasRedirectServlet(hs).register(http_server)
CasTicketServlet(hs).register(http_server)
- CasRestServlet(hs).register(http_server)
- # TODO PasswordResetRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 65c4e2eb..355e8247 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -24,6 +24,10 @@ from synapse.http.servlet import parse_json_object_from_request
class ProfileDisplaynameRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname")
+ def __init__(self, hs):
+ super(ProfileDisplaynameRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
@@ -62,6 +66,10 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
class ProfileAvatarURLRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url")
+ def __init__(self, hs):
+ super(ProfileAvatarURLRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
@@ -99,6 +107,10 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
class ProfileRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)")
+ def __init__(self, hs):
+ super(ProfileRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index 2383b9df..71d58c8e 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -65,6 +65,7 @@ class RegisterRestServlet(ClientV1RestServlet):
self.sessions = {}
self.enable_registration = hs.config.enable_registration
self.auth_handler = hs.get_auth_handler()
+ self.handlers = hs.get_handlers()
def on_GET(self, request):
if self.hs.config.enable_registration_captcha:
@@ -383,6 +384,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
super(CreateUserRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
+ self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_POST(self, request):
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 866a1e91..0d817570 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -35,6 +35,10 @@ logger = logging.getLogger(__name__)
class RoomCreateRestServlet(ClientV1RestServlet):
# No PATTERN; we have custom dispatch rules here
+ def __init__(self, hs):
+ super(RoomCreateRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
def register(self, http_server):
PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server)
@@ -82,6 +86,10 @@ class RoomCreateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(ClientV1RestServlet):
+ def __init__(self, hs):
+ super(RoomStateEventRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
def register(self, http_server):
# /room/$roomid/state/$eventtype
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
@@ -166,6 +174,10 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(ClientV1RestServlet):
+ def __init__(self, hs):
+ super(RoomSendEventRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)")
@@ -210,6 +222,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ClientV1RestServlet):
+ def __init__(self, hs):
+ super(JoinRoomAliasServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
def register(self, http_server):
# /join/$room_identifier[/$txn_id]
@@ -253,6 +268,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
action="join",
txn_id=txn_id,
remote_room_hosts=remote_room_hosts,
+ content=content,
third_party_signed=content.get("third_party_signed", None),
)
@@ -296,6 +312,10 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
class RoomMemberListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$")
+ def __init__(self, hs):
+ super(RoomMemberListRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
@defer.inlineCallbacks
def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
@@ -322,6 +342,10 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
class RoomMessageListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
+ def __init__(self, hs):
+ super(RoomMessageListRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@@ -351,6 +375,10 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
class RoomStateRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$")
+ def __init__(self, hs):
+ super(RoomStateRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@@ -368,6 +396,10 @@ class RoomStateRestServlet(ClientV1RestServlet):
class RoomInitialSyncRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$")
+ def __init__(self, hs):
+ super(RoomInitialSyncRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@@ -388,6 +420,7 @@ class RoomEventContext(ClientV1RestServlet):
def __init__(self, hs):
super(RoomEventContext, self).__init__(hs)
self.clock = hs.get_clock()
+ self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_id):
@@ -424,6 +457,10 @@ class RoomEventContext(ClientV1RestServlet):
class RoomForgetRestServlet(ClientV1RestServlet):
+ def __init__(self, hs):
+ super(RoomForgetRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
register_txn_path(self, PATTERNS, http_server)
@@ -462,6 +499,10 @@ class RoomForgetRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing
class RoomMembershipRestServlet(ClientV1RestServlet):
+ def __init__(self, hs):
+ super(RoomMembershipRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
def register(self, http_server):
# /rooms/$roomid/[invite|join|leave]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
@@ -542,6 +583,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
class RoomRedactEventRestServlet(ClientV1RestServlet):
+ def __init__(self, hs):
+ super(RoomRedactEventRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
register_txn_path(self, PATTERNS, http_server)
@@ -624,6 +669,10 @@ class SearchRestServlet(ClientV1RestServlet):
"/search$"
)
+ def __init__(self, hs):
+ super(SearchRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
@defer.inlineCallbacks
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py
new file mode 100644
index 00000000..f1a48acf
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/notifications.py
@@ -0,0 +1,99 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.http.servlet import (
+ RestServlet, parse_string, parse_integer
+)
+from synapse.events.utils import (
+ serialize_event, format_event_for_client_v2_without_room_id,
+)
+
+from ._base import client_v2_patterns
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class NotificationsServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/notifications$", releases=())
+
+ def __init__(self, hs):
+ super(NotificationsServlet, self).__init__()
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ from_token = parse_string(request, "from", required=False)
+ limit = parse_integer(request, "limit", default=50)
+
+ limit = min(limit, 500)
+
+ push_actions = yield self.store.get_push_actions_for_user(
+ user_id, from_token, limit
+ )
+
+ receipts_by_room = yield self.store.get_receipts_for_user_with_orderings(
+ user_id, 'm.read'
+ )
+
+ notif_event_ids = [pa["event_id"] for pa in push_actions]
+ notif_events = yield self.store.get_events(notif_event_ids)
+
+ returned_push_actions = []
+
+ next_token = None
+
+ for pa in push_actions:
+ returned_pa = {
+ "room_id": pa["room_id"],
+ "profile_tag": pa["profile_tag"],
+ "actions": pa["actions"],
+ "ts": pa["received_ts"],
+ "event": serialize_event(
+ notif_events[pa["event_id"]],
+ self.clock.time_msec(),
+ event_format=format_event_for_client_v2_without_room_id,
+ ),
+ }
+
+ if pa["room_id"] not in receipts_by_room:
+ returned_pa["read"] = False
+ else:
+ receipt = receipts_by_room[pa["room_id"]]
+
+ returned_pa["read"] = (
+ receipt["topological_ordering"], receipt["stream_ordering"]
+ ) >= (
+ pa["topological_ordering"], pa["stream_ordering"]
+ )
+ returned_push_actions.append(returned_pa)
+ next_token = pa["stream_ordering"]
+
+ defer.returnValue((200, {
+ "notifications": returned_push_actions,
+ "next_token": next_token,
+ }))
+
+
+def register_servlets(hs, http_server):
+ NotificationsServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 943f5676..2121bd75 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -403,10 +403,9 @@ class RegisterRestServlet(RestServlet):
# register the user's device
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
- device_id = self.device_handler.check_device_registered(
+ return self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
- return device_id
@defer.inlineCallbacks
def _do_guest_registration(self):
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 43d8e0bf..b11acdbe 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -146,7 +146,7 @@ class SyncRestServlet(RestServlet):
affect_presence = set_presence != PresenceState.OFFLINE
if affect_presence:
- yield self.presence_handler.set_state(user, {"presence": set_presence})
+ yield self.presence_handler.set_state(user, {"presence": set_presence}, True)
context = yield self.presence_handler.user_syncing(
user.to_string(), affect_presence=affect_presence,
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
new file mode 100644
index 00000000..9abca3a8
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -0,0 +1,78 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.http.servlet import RestServlet
+from synapse.types import ThirdPartyEntityKind
+from ._base import client_v2_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class ThirdPartyUserServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/3pu(/(?P<protocol>[^/]+))?$",
+ releases=())
+
+ def __init__(self, hs):
+ super(ThirdPartyUserServlet, self).__init__()
+
+ self.auth = hs.get_auth()
+ self.appservice_handler = hs.get_application_service_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, protocol):
+ yield self.auth.get_user_by_req(request)
+
+ fields = request.args
+ del fields["access_token"]
+
+ results = yield self.appservice_handler.query_3pe(
+ ThirdPartyEntityKind.USER, protocol, fields
+ )
+
+ defer.returnValue((200, results))
+
+
+class ThirdPartyLocationServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/3pl(/(?P<protocol>[^/]+))?$",
+ releases=())
+
+ def __init__(self, hs):
+ super(ThirdPartyLocationServlet, self).__init__()
+
+ self.auth = hs.get_auth()
+ self.appservice_handler = hs.get_application_service_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, protocol):
+ yield self.auth.get_user_by_req(request)
+
+ fields = request.args
+ del fields["access_token"]
+
+ results = yield self.appservice_handler.query_3pe(
+ ThirdPartyEntityKind.LOCATION, protocol, fields
+ )
+
+ defer.returnValue((200, results))
+
+
+def register_servlets(hs, http_server):
+ ThirdPartyUserServlet(hs).register(http_server)
+ ThirdPartyLocationServlet(hs).register(http_server)
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 7209d5a3..9fe20136 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -15,6 +15,7 @@
from synapse.http.server import request_handler, respond_with_json_bytes
from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.api.errors import SynapseError, Codes
+from synapse.crypto.keyring import KeyLookupError
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
@@ -210,9 +211,10 @@ class RemoteKey(Resource):
yield self.keyring.get_server_verify_key_v2_direct(
server_name, key_ids
)
+ except KeyLookupError as e:
+ logger.info("Failed to fetch key: %s", e)
except:
logger.exception("Failed to get key for %r", server_name)
- pass
yield self.query_keys(
request, query, query_remote_on_cache_miss=False
)
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 9f696207..9f0625a8 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -45,6 +45,7 @@ class DownloadResource(Resource):
@request_handler()
@defer.inlineCallbacks
def _async_render_GET(self, request):
+ request.setHeader("Content-Security-Policy", "sandbox")
server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name:
yield self._respond_local_file(request, media_id, name)
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index bdd0e60c..33f35fb4 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -29,14 +29,13 @@ from synapse.http.server import (
from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii
-from copy import deepcopy
-
import os
import re
import fnmatch
import cgi
import ujson as json
import urlparse
+import itertools
import logging
logger = logging.getLogger(__name__)
@@ -163,7 +162,7 @@ class PreviewUrlResource(Resource):
logger.debug("got media_info of '%s'" % media_info)
- if self._is_media(media_info['media_type']):
+ if _is_media(media_info['media_type']):
dims = yield self.media_repo._generate_local_thumbnails(
media_info['filesystem_id'], media_info
)
@@ -184,11 +183,9 @@ class PreviewUrlResource(Resource):
logger.warn("Couldn't get dims for %s" % url)
# define our OG response for this media
- elif self._is_html(media_info['media_type']):
+ elif _is_html(media_info['media_type']):
# TODO: somehow stop a big HTML tree from exploding synapse's RAM
- from lxml import etree
-
file = open(media_info['filename'])
body = file.read()
file.close()
@@ -199,17 +196,35 @@ class PreviewUrlResource(Resource):
match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I)
encoding = match.group(1) if match else "utf-8"
- try:
- parser = etree.HTMLParser(recover=True, encoding=encoding)
- tree = etree.fromstring(body, parser)
- og = yield self._calc_og(tree, media_info, requester)
- except UnicodeDecodeError:
- # blindly try decoding the body as utf-8, which seems to fix
- # the charset mismatches on https://google.com
- parser = etree.HTMLParser(recover=True, encoding=encoding)
- tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
- og = yield self._calc_og(tree, media_info, requester)
+ og = decode_and_calc_og(body, media_info['uri'], encoding)
+
+ # pre-cache the image for posterity
+ # FIXME: it might be cleaner to use the same flow as the main /preview_url
+ # request itself and benefit from the same caching etc. But for now we
+ # just rely on the caching on the master request to speed things up.
+ if 'og:image' in og and og['og:image']:
+ image_info = yield self._download_url(
+ _rebase_url(og['og:image'], media_info['uri']), requester.user
+ )
+ if _is_media(image_info['media_type']):
+ # TODO: make sure we don't choke on white-on-transparent images
+ dims = yield self.media_repo._generate_local_thumbnails(
+ image_info['filesystem_id'], image_info
+ )
+ if dims:
+ og["og:image:width"] = dims['width']
+ og["og:image:height"] = dims['height']
+ else:
+ logger.warn("Couldn't get dims for %s" % og["og:image"])
+
+ og["og:image"] = "mxc://%s/%s" % (
+ self.server_name, image_info['filesystem_id']
+ )
+ og["og:image:type"] = image_info['media_type']
+ og["matrix:image:size"] = image_info['media_length']
+ else:
+ del og["og:image"]
else:
logger.warn("Failed to find any OG data in %s", url)
og = {}
@@ -233,139 +248,6 @@ class PreviewUrlResource(Resource):
respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
@defer.inlineCallbacks
- def _calc_og(self, tree, media_info, requester):
- # suck our tree into lxml and define our OG response.
-
- # if we see any image URLs in the OG response, then spider them
- # (although the client could choose to do this by asking for previews of those
- # URLs to avoid DoSing the server)
-
- # "og:type" : "video",
- # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
- # "og:site_name" : "YouTube",
- # "og:video:type" : "application/x-shockwave-flash",
- # "og:description" : "Fun stuff happening here",
- # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
- # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
- # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
- # "og:video:width" : "1280"
- # "og:video:height" : "720",
- # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
-
- og = {}
- for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
- if 'content' in tag.attrib:
- og[tag.attrib['property']] = tag.attrib['content']
-
- # TODO: grab article: meta tags too, e.g.:
-
- # "article:publisher" : "https://www.facebook.com/thethudonline" />
- # "article:author" content="https://www.facebook.com/thethudonline" />
- # "article:tag" content="baby" />
- # "article:section" content="Breaking News" />
- # "article:published_time" content="2016-03-31T19:58:24+00:00" />
- # "article:modified_time" content="2016-04-01T18:31:53+00:00" />
-
- if 'og:title' not in og:
- # do some basic spidering of the HTML
- title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
- og['og:title'] = title[0].text.strip() if title else None
-
- if 'og:image' not in og:
- # TODO: extract a favicon failing all else
- meta_image = tree.xpath(
- "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
- )
- if meta_image:
- og['og:image'] = self._rebase_url(meta_image[0], media_info['uri'])
- else:
- # TODO: consider inlined CSS styles as well as width & height attribs
- images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
- images = sorted(images, key=lambda i: (
- -1 * float(i.attrib['width']) * float(i.attrib['height'])
- ))
- if not images:
- images = tree.xpath("//img[@src]")
- if images:
- og['og:image'] = images[0].attrib['src']
-
- # pre-cache the image for posterity
- # FIXME: it might be cleaner to use the same flow as the main /preview_url
- # request itself and benefit from the same caching etc. But for now we
- # just rely on the caching on the master request to speed things up.
- if 'og:image' in og and og['og:image']:
- image_info = yield self._download_url(
- self._rebase_url(og['og:image'], media_info['uri']), requester.user
- )
-
- if self._is_media(image_info['media_type']):
- # TODO: make sure we don't choke on white-on-transparent images
- dims = yield self.media_repo._generate_local_thumbnails(
- image_info['filesystem_id'], image_info
- )
- if dims:
- og["og:image:width"] = dims['width']
- og["og:image:height"] = dims['height']
- else:
- logger.warn("Couldn't get dims for %s" % og["og:image"])
-
- og["og:image"] = "mxc://%s/%s" % (
- self.server_name, image_info['filesystem_id']
- )
- og["og:image:type"] = image_info['media_type']
- og["matrix:image:size"] = image_info['media_length']
- else:
- del og["og:image"]
-
- if 'og:description' not in og:
- meta_description = tree.xpath(
- "//*/meta"
- "[translate(@name, 'DESCRIPTION', 'description')='description']"
- "/@content")
- if meta_description:
- og['og:description'] = meta_description[0]
- else:
- # grab any text nodes which are inside the <body/> tag...
- # unless they are within an HTML5 semantic markup tag...
- # <header/>, <nav/>, <aside/>, <footer/>
- # ...or if they are within a <script/> or <style/> tag.
- # This is a very very very coarse approximation to a plain text
- # render of the page.
-
- # We don't just use XPATH here as that is slow on some machines.
-
- # We clone `tree` as we modify it.
- cloned_tree = deepcopy(tree.find("body"))
-
- TAGS_TO_REMOVE = ("header", "nav", "aside", "footer", "script", "style",)
- for el in cloned_tree.iter(TAGS_TO_REMOVE):
- el.getparent().remove(el)
-
- # Split all the text nodes into paragraphs (by splitting on new
- # lines)
- text_nodes = (
- re.sub(r'\s+', '\n', el.text).strip()
- for el in cloned_tree.iter()
- if el.text and isinstance(el.tag, basestring) # Removes comments
- )
- og['og:description'] = summarize_paragraphs(text_nodes)
-
- # TODO: delete the url downloads to stop diskfilling,
- # as we only ever cared about its OG
- defer.returnValue(og)
-
- def _rebase_url(self, url, base):
- base = list(urlparse.urlparse(base))
- url = list(urlparse.urlparse(url))
- if not url[0]: # fix up schema
- url[0] = base[0] or "http"
- if not url[1]: # fix up hostname
- url[1] = base[1]
- if not url[2].startswith('/'):
- url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
- return urlparse.urlunparse(url)
-
- @defer.inlineCallbacks
def _download_url(self, url, user):
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
@@ -445,17 +327,171 @@ class PreviewUrlResource(Resource):
"etag": headers["ETag"][0] if "ETag" in headers else None,
})
- def _is_media(self, content_type):
- if content_type.lower().startswith("image/"):
- return True
- def _is_html(self, content_type):
- content_type = content_type.lower()
- if (
- content_type.startswith("text/html") or
- content_type.startswith("application/xhtml")
- ):
- return True
+def decode_and_calc_og(body, media_uri, request_encoding=None):
+ from lxml import etree
+
+ try:
+ parser = etree.HTMLParser(recover=True, encoding=request_encoding)
+ tree = etree.fromstring(body, parser)
+ og = _calc_og(tree, media_uri)
+ except UnicodeDecodeError:
+ # blindly try decoding the body as utf-8, which seems to fix
+ # the charset mismatches on https://google.com
+ parser = etree.HTMLParser(recover=True, encoding=request_encoding)
+ tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
+ og = _calc_og(tree, media_uri)
+
+ return og
+
+
+def _calc_og(tree, media_uri):
+ # suck our tree into lxml and define our OG response.
+
+ # if we see any image URLs in the OG response, then spider them
+ # (although the client could choose to do this by asking for previews of those
+ # URLs to avoid DoSing the server)
+
+ # "og:type" : "video",
+ # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
+ # "og:site_name" : "YouTube",
+ # "og:video:type" : "application/x-shockwave-flash",
+ # "og:description" : "Fun stuff happening here",
+ # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
+ # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
+ # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
+ # "og:video:width" : "1280"
+ # "og:video:height" : "720",
+ # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
+
+ og = {}
+ for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
+ if 'content' in tag.attrib:
+ og[tag.attrib['property']] = tag.attrib['content']
+
+ # TODO: grab article: meta tags too, e.g.:
+
+ # "article:publisher" : "https://www.facebook.com/thethudonline" />
+ # "article:author" content="https://www.facebook.com/thethudonline" />
+ # "article:tag" content="baby" />
+ # "article:section" content="Breaking News" />
+ # "article:published_time" content="2016-03-31T19:58:24+00:00" />
+ # "article:modified_time" content="2016-04-01T18:31:53+00:00" />
+
+ if 'og:title' not in og:
+ # do some basic spidering of the HTML
+ title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
+ og['og:title'] = title[0].text.strip() if title else None
+
+ if 'og:image' not in og:
+ # TODO: extract a favicon failing all else
+ meta_image = tree.xpath(
+ "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
+ )
+ if meta_image:
+ og['og:image'] = _rebase_url(meta_image[0], media_uri)
+ else:
+ # TODO: consider inlined CSS styles as well as width & height attribs
+ images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
+ images = sorted(images, key=lambda i: (
+ -1 * float(i.attrib['width']) * float(i.attrib['height'])
+ ))
+ if not images:
+ images = tree.xpath("//img[@src]")
+ if images:
+ og['og:image'] = images[0].attrib['src']
+
+ if 'og:description' not in og:
+ meta_description = tree.xpath(
+ "//*/meta"
+ "[translate(@name, 'DESCRIPTION', 'description')='description']"
+ "/@content")
+ if meta_description:
+ og['og:description'] = meta_description[0]
+ else:
+ # grab any text nodes which are inside the <body/> tag...
+ # unless they are within an HTML5 semantic markup tag...
+ # <header/>, <nav/>, <aside/>, <footer/>
+ # ...or if they are within a <script/> or <style/> tag.
+ # This is a very very very coarse approximation to a plain text
+ # render of the page.
+
+ # We don't just use XPATH here as that is slow on some machines.
+
+ from lxml import etree
+
+ TAGS_TO_REMOVE = (
+ "header", "nav", "aside", "footer", "script", "style", etree.Comment
+ )
+
+ # Split all the text nodes into paragraphs (by splitting on new
+ # lines)
+ text_nodes = (
+ re.sub(r'\s+', '\n', el).strip()
+ for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
+ )
+ og['og:description'] = summarize_paragraphs(text_nodes)
+
+ # TODO: delete the url downloads to stop diskfilling,
+ # as we only ever cared about its OG
+ return og
+
+
+def _iterate_over_text(tree, *tags_to_ignore):
+ """Iterate over the tree returning text nodes in a depth first fashion,
+ skipping text nodes inside certain tags.
+ """
+ # This is basically a stack that we extend using itertools.chain.
+ # This will either consist of an element to iterate over *or* a string
+ # to be returned.
+ elements = iter([tree])
+ while True:
+ el = elements.next()
+ if isinstance(el, basestring):
+ yield el
+ elif el is not None and el.tag not in tags_to_ignore:
+ # el.text is the text before the first child, so we can immediately
+ # return it if the text exists.
+ if el.text:
+ yield el.text
+
+ # We add to the stack all the elements children, interspersed with
+ # each child's tail text (if it exists). The tail text of a node
+ # is text that comes *after* the node, so we always include it even
+ # if we ignore the child node.
+ elements = itertools.chain(
+ itertools.chain.from_iterable( # Basically a flatmap
+ [child, child.tail] if child.tail else [child]
+ for child in el.iterchildren()
+ ),
+ elements
+ )
+
+
+def _rebase_url(url, base):
+ base = list(urlparse.urlparse(base))
+ url = list(urlparse.urlparse(url))
+ if not url[0]: # fix up schema
+ url[0] = base[0] or "http"
+ if not url[1]: # fix up hostname
+ url[1] = base[1]
+ if not url[2].startswith('/'):
+ url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
+ return urlparse.urlunparse(url)
+
+
+def _is_media(content_type):
+ if content_type.lower().startswith("image/"):
+ return True
+
+
+def _is_html(content_type):
+ content_type = content_type.lower()
+ if (
+ content_type.startswith("text/html") or
+ content_type.startswith("application/xhtml")
+ ):
+ return True
def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
diff --git a/synapse/server.py b/synapse/server.py
index 6bb49883..af324650 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -41,6 +41,7 @@ from synapse.handlers.presence import PresenceHandler
from synapse.handlers.room import RoomListHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
+from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier
@@ -94,6 +95,8 @@ class HomeServer(object):
'auth_handler',
'device_handler',
'e2e_keys_handler',
+ 'event_handler',
+ 'event_stream_handler',
'application_service_api',
'application_service_scheduler',
'application_service_handler',
@@ -214,6 +217,12 @@ class HomeServer(object):
def build_application_service_handler(self):
return ApplicationServicesHandler(self)
+ def build_event_handler(self):
+ return EventHandler(self)
+
+ def build_event_stream_handler(self):
+ return EventStreamHandler(self)
+
def build_event_sources(self):
return EventSources(self)
diff --git a/synapse/server.pyi b/synapse/server.pyi
index c0aa868c..9570df55 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -1,3 +1,4 @@
+import synapse.api.auth
import synapse.handlers
import synapse.handlers.auth
import synapse.handlers.device
@@ -6,6 +7,9 @@ import synapse.storage
import synapse.state
class HomeServer(object):
+ def get_auth(self) -> synapse.api.auth.Auth:
+ pass
+
def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler:
pass
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 73fb334d..7efc5bfe 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -50,6 +50,7 @@ from .openid import OpenIdStore
from .client_ips import ClientIpStore
from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
+from .engines import PostgresEngine
from synapse.api.constants import PresenceState
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -123,6 +124,13 @@ class DataStore(RoomMemberStore, RoomStore,
extra_tables=[("deleted_pushers", "stream_id")],
)
+ if isinstance(self.database_engine, PostgresEngine):
+ self._cache_id_gen = StreamIdGenerator(
+ db_conn, "cache_invalidation_stream", "stream_id",
+ )
+ else:
+ self._cache_id_gen = None
+
events_max = self._stream_id_gen.get_current_token()
event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events",
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 0117fdc6..49fa8614 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -19,6 +19,7 @@ from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache
from synapse.util.caches import intern_dict
+from synapse.storage.engines import PostgresEngine
import synapse.metrics
@@ -165,7 +166,7 @@ class SQLBaseStore(object):
self._txn_perf_counters = PerformanceCounters()
self._get_event_counters = PerformanceCounters()
- self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
+ self._get_event_cache = Cache("*getEvent*", keylen=3,
max_entries=hs.config.event_cache_size)
self._state_group_cache = DictionaryCache(
@@ -305,13 +306,14 @@ class SQLBaseStore(object):
func, *args, **kwargs
)
- with PreserveLoggingContext():
- result = yield self._db_pool.runWithConnection(
- inner_func, *args, **kwargs
- )
-
- for after_callback, after_args in after_callbacks:
- after_callback(*after_args)
+ try:
+ with PreserveLoggingContext():
+ result = yield self._db_pool.runWithConnection(
+ inner_func, *args, **kwargs
+ )
+ finally:
+ for after_callback, after_args in after_callbacks:
+ after_callback(*after_args)
defer.returnValue(result)
@defer.inlineCallbacks
@@ -860,6 +862,62 @@ class SQLBaseStore(object):
return cache, min_val
+ def _invalidate_cache_and_stream(self, txn, cache_func, keys):
+ """Invalidates the cache and adds it to the cache stream so slaves
+ will know to invalidate their caches.
+
+ This should only be used to invalidate caches where slaves won't
+ otherwise know from other replication streams that the cache should
+ be invalidated.
+ """
+ txn.call_after(cache_func.invalidate, keys)
+
+ if isinstance(self.database_engine, PostgresEngine):
+ # get_next() returns a context manager which is designed to wrap
+ # the transaction. However, we want to only get an ID when we want
+ # to use it, here, so we need to call __enter__ manually, and have
+ # __exit__ called after the transaction finishes.
+ ctx = self._cache_id_gen.get_next()
+ stream_id = ctx.__enter__()
+ txn.call_after(ctx.__exit__, None, None, None)
+ txn.call_after(self.hs.get_notifier().on_new_replication_data)
+
+ self._simple_insert_txn(
+ txn,
+ table="cache_invalidation_stream",
+ values={
+ "stream_id": stream_id,
+ "cache_func": cache_func.__name__,
+ "keys": list(keys),
+ "invalidation_ts": self.clock.time_msec(),
+ }
+ )
+
+ def get_all_updated_caches(self, last_id, current_id, limit):
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_updated_caches_txn(txn):
+ # We purposefully don't bound by the current token, as we want to
+ # send across cache invalidations as quickly as possible. Cache
+ # invalidations are idempotent, so duplicates are fine.
+ sql = (
+ "SELECT stream_id, cache_func, keys, invalidation_ts"
+ " FROM cache_invalidation_stream"
+ " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, limit,))
+ return txn.fetchall()
+ return self.runInteraction(
+ "get_all_updated_caches", get_all_updated_caches_txn
+ )
+
+ def get_cache_stream_token(self):
+ if self._cache_id_gen:
+ return self._cache_id_gen.get_current_token()
+ else:
+ return 0
+
class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index d1ee533f..a854a87e 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -218,38 +218,37 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
Returns:
AppServiceTransaction: A new transaction.
"""
+ def _create_appservice_txn(txn):
+ # work out new txn id (highest txn id for this service += 1)
+ # The highest id may be the last one sent (in which case it is last_txn)
+ # or it may be the highest in the txns list (which are waiting to be/are
+ # being sent)
+ last_txn_id = self._get_last_txn(txn, service.id)
+
+ txn.execute(
+ "SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
+ (service.id,)
+ )
+ highest_txn_id = txn.fetchone()[0]
+ if highest_txn_id is None:
+ highest_txn_id = 0
+
+ new_txn_id = max(highest_txn_id, last_txn_id) + 1
+
+ # Insert new txn into txn table
+ event_ids = json.dumps([e.event_id for e in events])
+ txn.execute(
+ "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
+ "VALUES(?,?,?)",
+ (service.id, new_txn_id, event_ids)
+ )
+ return AppServiceTransaction(
+ service=service, id=new_txn_id, events=events
+ )
+
return self.runInteraction(
"create_appservice_txn",
- self._create_appservice_txn,
- service, events
- )
-
- def _create_appservice_txn(self, txn, service, events):
- # work out new txn id (highest txn id for this service += 1)
- # The highest id may be the last one sent (in which case it is last_txn)
- # or it may be the highest in the txns list (which are waiting to be/are
- # being sent)
- last_txn_id = self._get_last_txn(txn, service.id)
-
- txn.execute(
- "SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
- (service.id,)
- )
- highest_txn_id = txn.fetchone()[0]
- if highest_txn_id is None:
- highest_txn_id = 0
-
- new_txn_id = max(highest_txn_id, last_txn_id) + 1
-
- # Insert new txn into txn table
- event_ids = json.dumps([e.event_id for e in events])
- txn.execute(
- "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
- "VALUES(?,?,?)",
- (service.id, new_txn_id, event_ids)
- )
- return AppServiceTransaction(
- service=service, id=new_txn_id, events=events
+ _create_appservice_txn,
)
def complete_appservice_txn(self, txn_id, service):
@@ -263,39 +262,38 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
A Deferred which resolves if this transaction was stored
successfully.
"""
- return self.runInteraction(
- "complete_appservice_txn",
- self._complete_appservice_txn,
- txn_id, service
- )
-
- def _complete_appservice_txn(self, txn, txn_id, service):
txn_id = int(txn_id)
- # Debugging query: Make sure the txn being completed is EXACTLY +1 from
- # what was there before. If it isn't, we've got problems (e.g. the AS
- # has probably missed some events), so whine loudly but still continue,
- # since it shouldn't fail completion of the transaction.
- last_txn_id = self._get_last_txn(txn, service.id)
- if (last_txn_id + 1) != txn_id:
- logger.error(
- "appservice: Completing a transaction which has an ID > 1 from "
- "the last ID sent to this AS. We've either dropped events or "
- "sent it to the AS out of order. FIX ME. last_txn=%s "
- "completing_txn=%s service_id=%s", last_txn_id, txn_id,
- service.id
+ def _complete_appservice_txn(txn):
+ # Debugging query: Make sure the txn being completed is EXACTLY +1 from
+ # what was there before. If it isn't, we've got problems (e.g. the AS
+ # has probably missed some events), so whine loudly but still continue,
+ # since it shouldn't fail completion of the transaction.
+ last_txn_id = self._get_last_txn(txn, service.id)
+ if (last_txn_id + 1) != txn_id:
+ logger.error(
+ "appservice: Completing a transaction which has an ID > 1 from "
+ "the last ID sent to this AS. We've either dropped events or "
+ "sent it to the AS out of order. FIX ME. last_txn=%s "
+ "completing_txn=%s service_id=%s", last_txn_id, txn_id,
+ service.id
+ )
+
+ # Set current txn_id for AS to 'txn_id'
+ self._simple_upsert_txn(
+ txn, "application_services_state", dict(as_id=service.id),
+ dict(last_txn=txn_id)
)
- # Set current txn_id for AS to 'txn_id'
- self._simple_upsert_txn(
- txn, "application_services_state", dict(as_id=service.id),
- dict(last_txn=txn_id)
- )
+ # Delete txn
+ self._simple_delete_txn(
+ txn, "application_services_txns",
+ dict(txn_id=txn_id, as_id=service.id)
+ )
- # Delete txn
- self._simple_delete_txn(
- txn, "application_services_txns",
- dict(txn_id=txn_id, as_id=service.id)
+ return self.runInteraction(
+ "complete_appservice_txn",
+ _complete_appservice_txn,
)
@defer.inlineCallbacks
@@ -309,10 +307,25 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
A Deferred which resolves to an AppServiceTransaction or
None.
"""
+ def _get_oldest_unsent_txn(txn):
+ # Monotonically increasing txn ids, so just select the smallest
+ # one in the txns table (we delete them when they are sent)
+ txn.execute(
+ "SELECT * FROM application_services_txns WHERE as_id=?"
+ " ORDER BY txn_id ASC LIMIT 1",
+ (service.id,)
+ )
+ rows = self.cursor_to_dict(txn)
+ if not rows:
+ return None
+
+ entry = rows[0]
+
+ return entry
+
entry = yield self.runInteraction(
"get_oldest_unsent_appservice_txn",
- self._get_oldest_unsent_txn,
- service
+ _get_oldest_unsent_txn,
)
if not entry:
@@ -326,22 +339,6 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
service=service, id=entry["txn_id"], events=events
))
- def _get_oldest_unsent_txn(self, txn, service):
- # Monotonically increasing txn ids, so just select the smallest
- # one in the txns table (we delete them when they are sent)
- txn.execute(
- "SELECT * FROM application_services_txns WHERE as_id=?"
- " ORDER BY txn_id ASC LIMIT 1",
- (service.id,)
- )
- rows = self.cursor_to_dict(txn)
- if not rows:
- return None
-
- entry = rows[0]
-
- return entry
-
def _get_last_txn(self, txn, service_id):
txn.execute(
"SELECT last_txn FROM application_services_state WHERE as_id=?",
@@ -352,3 +349,45 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
return 0
else:
return int(last_txn_id[0]) # select 'last_txn' col
+
+ def set_appservice_last_pos(self, pos):
+ def set_appservice_last_pos_txn(txn):
+ txn.execute(
+ "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
+ )
+ return self.runInteraction(
+ "set_appservice_last_pos", set_appservice_last_pos_txn
+ )
+
+ @defer.inlineCallbacks
+ def get_new_events_for_appservice(self, current_id, limit):
+ """Get all new evnets"""
+
+ def get_new_events_for_appservice_txn(txn):
+ sql = (
+ "SELECT e.stream_ordering, e.event_id"
+ " FROM events AS e"
+ " WHERE"
+ " (SELECT stream_ordering FROM appservice_stream_position)"
+ " < e.stream_ordering"
+ " AND e.stream_ordering <= ?"
+ " ORDER BY e.stream_ordering ASC"
+ " LIMIT ?"
+ )
+
+ txn.execute(sql, (current_id, limit))
+ rows = txn.fetchall()
+
+ upper_bound = current_id
+ if len(rows) == limit:
+ upper_bound = rows[-1][0]
+
+ return upper_bound, [row[1] for row in rows]
+
+ upper_bound, event_ids = yield self.runInteraction(
+ "get_new_events_for_appservice", get_new_events_for_appservice_txn,
+ )
+
+ events = yield self._get_events(event_ids)
+
+ defer.returnValue((upper_bound, events))
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index ef231a04..9caaf81f 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -82,32 +82,39 @@ class DirectoryStore(SQLBaseStore):
Returns:
Deferred
"""
- try:
- yield self._simple_insert(
+ def alias_txn(txn):
+ self._simple_insert_txn(
+ txn,
"room_aliases",
{
"room_alias": room_alias.to_string(),
"room_id": room_id,
"creator": creator,
},
- desc="create_room_alias_association",
- )
- except self.database_engine.module.IntegrityError:
- raise SynapseError(
- 409, "Room alias %s already exists" % room_alias.to_string()
)
- for server in servers:
- # TODO(erikj): Fix this to bulk insert
- yield self._simple_insert(
- "room_alias_servers",
- {
+ self._simple_insert_many_txn(
+ txn,
+ table="room_alias_servers",
+ values=[{
"room_alias": room_alias.to_string(),
"server": server,
- },
- desc="create_room_alias_association",
+ } for server in servers],
)
- self.get_aliases_for_room.invalidate((room_id,))
+
+ self._invalidate_cache_and_stream(
+ txn, self.get_aliases_for_room, (room_id,)
+ )
+
+ try:
+ ret = yield self.runInteraction(
+ "create_room_alias_association", alias_txn
+ )
+ except self.database_engine.module.IntegrityError:
+ raise SynapseError(
+ 409, "Room alias %s already exists" % room_alias.to_string()
+ )
+ defer.returnValue(ret)
def get_room_alias_creator(self, room_alias):
return self._simple_select_one_onecol(
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index df4000d0..eb15fb75 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -56,7 +56,7 @@ class EventPushActionsStore(SQLBaseStore):
)
self._simple_insert_many_txn(txn, "event_push_actions", values)
- @cachedInlineCallbacks(num_args=3, lru=True, tree=True, max_entries=5000)
+ @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id
):
@@ -338,6 +338,36 @@ class EventPushActionsStore(SQLBaseStore):
defer.returnValue(notifs[:limit])
@defer.inlineCallbacks
+ def get_push_actions_for_user(self, user_id, before=None, limit=50):
+ def f(txn):
+ before_clause = ""
+ if before:
+ before_clause = "AND stream_ordering < ?"
+ args = [user_id, before, limit]
+ else:
+ args = [user_id, limit]
+ sql = (
+ "SELECT epa.event_id, epa.room_id,"
+ " epa.stream_ordering, epa.topological_ordering,"
+ " epa.actions, epa.profile_tag, e.received_ts"
+ " FROM event_push_actions epa, events e"
+ " WHERE epa.room_id = e.room_id AND epa.event_id = e.event_id"
+ " AND epa.user_id = ? %s"
+ " ORDER BY epa.stream_ordering DESC"
+ " LIMIT ?"
+ % (before_clause,)
+ )
+ txn.execute(sql, args)
+ return self.cursor_to_dict(txn)
+
+ push_actions = yield self.runInteraction(
+ "get_push_actions_for_user", f
+ )
+ for pa in push_actions:
+ pa["actions"] = json.loads(pa["actions"])
+ defer.returnValue(push_actions)
+
+ @defer.inlineCallbacks
def get_time_of_last_push_action_before(self, stream_ordering):
def f(txn):
sql = (
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index d2feee8d..57e50052 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -20,8 +20,11 @@ from synapse.events import FrozenEvent, USE_FROZEN_DICTS
from synapse.events.utils import prune_event
from synapse.util.async import ObservableDeferred
-from synapse.util.logcontext import preserve_fn, PreserveLoggingContext
+from synapse.util.logcontext import (
+ preserve_fn, PreserveLoggingContext, preserve_context_over_deferred
+)
from synapse.util.logutils import log_function
+from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
@@ -201,7 +204,7 @@ class EventsStore(SQLBaseStore):
deferreds = []
for room_id, evs_ctxs in partitioned.items():
- d = self._event_persist_queue.add_to_queue(
+ d = preserve_fn(self._event_persist_queue.add_to_queue)(
room_id, evs_ctxs,
backfilled=backfilled,
current_state=None,
@@ -211,7 +214,9 @@ class EventsStore(SQLBaseStore):
for room_id in partitioned.keys():
self._maybe_start_persisting(room_id)
- return defer.gatherResults(deferreds, consumeErrors=True)
+ return preserve_context_over_deferred(
+ defer.gatherResults(deferreds, consumeErrors=True)
+ )
@defer.inlineCallbacks
@log_function
@@ -224,7 +229,7 @@ class EventsStore(SQLBaseStore):
self._maybe_start_persisting(event.room_id)
- yield deferred
+ yield preserve_context_over_deferred(deferred)
max_persisted_id = yield self._stream_id_gen.get_current_token()
defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
@@ -600,7 +605,8 @@ class EventsStore(SQLBaseStore):
"rejections",
"redactions",
"room_memberships",
- "state_events"
+ "state_events",
+ "topics"
):
txn.executemany(
"DELETE FROM %s WHERE event_id = ?" % (table,),
@@ -1086,7 +1092,7 @@ class EventsStore(SQLBaseStore):
if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]]
- res = yield defer.gatherResults(
+ res = yield preserve_context_over_deferred(defer.gatherResults(
[
preserve_fn(self._get_event_from_row)(
row["internal_metadata"], row["json"], row["redacts"],
@@ -1095,7 +1101,7 @@ class EventsStore(SQLBaseStore):
for row in rows
],
consumeErrors=True
- )
+ ))
defer.returnValue({
e.event.event_id: e
@@ -1131,54 +1137,55 @@ class EventsStore(SQLBaseStore):
@defer.inlineCallbacks
def _get_event_from_row(self, internal_metadata, js, redacted,
rejected_reason=None):
- d = json.loads(js)
- internal_metadata = json.loads(internal_metadata)
-
- if rejected_reason:
- rejected_reason = yield self._simple_select_one_onecol(
- table="rejections",
- keyvalues={"event_id": rejected_reason},
- retcol="reason",
- desc="_get_event_from_row_rejected_reason",
- )
+ with Measure(self._clock, "_get_event_from_row"):
+ d = json.loads(js)
+ internal_metadata = json.loads(internal_metadata)
+
+ if rejected_reason:
+ rejected_reason = yield self._simple_select_one_onecol(
+ table="rejections",
+ keyvalues={"event_id": rejected_reason},
+ retcol="reason",
+ desc="_get_event_from_row_rejected_reason",
+ )
- original_ev = FrozenEvent(
- d,
- internal_metadata_dict=internal_metadata,
- rejected_reason=rejected_reason,
- )
+ original_ev = FrozenEvent(
+ d,
+ internal_metadata_dict=internal_metadata,
+ rejected_reason=rejected_reason,
+ )
- redacted_event = None
- if redacted:
- redacted_event = prune_event(original_ev)
+ redacted_event = None
+ if redacted:
+ redacted_event = prune_event(original_ev)
- redaction_id = yield self._simple_select_one_onecol(
- table="redactions",
- keyvalues={"redacts": redacted_event.event_id},
- retcol="event_id",
- desc="_get_event_from_row_redactions",
- )
+ redaction_id = yield self._simple_select_one_onecol(
+ table="redactions",
+ keyvalues={"redacts": redacted_event.event_id},
+ retcol="event_id",
+ desc="_get_event_from_row_redactions",
+ )
- redacted_event.unsigned["redacted_by"] = redaction_id
- # Get the redaction event.
+ redacted_event.unsigned["redacted_by"] = redaction_id
+ # Get the redaction event.
- because = yield self.get_event(
- redaction_id,
- check_redacted=False,
- allow_none=True,
- )
+ because = yield self.get_event(
+ redaction_id,
+ check_redacted=False,
+ allow_none=True,
+ )
- if because:
- # It's fine to do add the event directly, since get_pdu_json
- # will serialise this field correctly
- redacted_event.unsigned["redacted_because"] = because
+ if because:
+ # It's fine to do add the event directly, since get_pdu_json
+ # will serialise this field correctly
+ redacted_event.unsigned["redacted_because"] = because
- cache_entry = _EventCacheEntry(
- event=original_ev,
- redacted_event=redacted_event,
- )
+ cache_entry = _EventCacheEntry(
+ event=original_ev,
+ redacted_event=redacted_event,
+ )
- self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
+ self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
defer.returnValue(cache_entry)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 8801669a..b94ce7be 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 33
+SCHEMA_VERSION = 34
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index d03f7c54..21d06966 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -189,18 +189,30 @@ class PresenceStore(SQLBaseStore):
desc="add_presence_list_pending",
)
- @defer.inlineCallbacks
def set_presence_list_accepted(self, observer_localpart, observed_userid):
- result = yield self._simple_update_one(
- table="presence_list",
- keyvalues={"user_id": observer_localpart,
- "observed_user_id": observed_userid},
- updatevalues={"accepted": True},
- desc="set_presence_list_accepted",
+ def update_presence_list_txn(txn):
+ result = self._simple_update_one_txn(
+ txn,
+ table="presence_list",
+ keyvalues={
+ "user_id": observer_localpart,
+ "observed_user_id": observed_userid
+ },
+ updatevalues={"accepted": True},
+ )
+
+ self._invalidate_cache_and_stream(
+ txn, self.get_presence_list_accepted, (observer_localpart,)
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_presence_list_observers_accepted, (observed_userid,)
+ )
+
+ return result
+
+ return self.runInteraction(
+ "set_presence_list_accepted", update_presence_list_txn,
)
- self.get_presence_list_accepted.invalidate((observer_localpart,))
- self.get_presence_list_observers_accepted.invalidate((observed_userid,))
- defer.returnValue(result)
def get_presence_list(self, observer_localpart, accepted=None):
if accepted:
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 8183b7f1..78334a98 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -16,6 +16,7 @@
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.push.baserules import list_with_base_rules
+from synapse.api.constants import EventTypes, Membership
from twisted.internet import defer
import logging
@@ -48,7 +49,7 @@ def _load_rules(rawrules, enabled_map):
class PushRuleStore(SQLBaseStore):
- @cachedInlineCallbacks(lru=True)
+ @cachedInlineCallbacks()
def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list(
table="push_rules",
@@ -72,7 +73,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rules)
- @cachedInlineCallbacks(lru=True)
+ @cachedInlineCallbacks()
def get_push_rules_enabled_for_user(self, user_id):
results = yield self._simple_select_list(
table="push_rules_enable",
@@ -123,6 +124,61 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(results)
+ def bulk_get_push_rules_for_room(self, room_id, state_group, current_state):
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ return self._bulk_get_push_rules_for_room(room_id, state_group, current_state)
+
+ @cachedInlineCallbacks(num_args=2, cache_context=True)
+ def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state,
+ cache_context):
+ # We don't use `state_group`, its there so that we can cache based
+ # on it. However, its important that its never None, since two current_state's
+ # with a state_group of None are likely to be different.
+ # See bulk_get_push_rules_for_room for how we work around this.
+ assert state_group is not None
+
+ # We also will want to generate notifs for other people in the room so
+ # their unread countss are correct in the event stream, but to avoid
+ # generating them for bot / AS users etc, we only do so for people who've
+ # sent a read receipt into the room.
+ local_users_in_room = set(
+ e.state_key for e in current_state.values()
+ if e.type == EventTypes.Member and e.membership == Membership.JOIN
+ and self.hs.is_mine_id(e.state_key)
+ )
+
+ # users in the room who have pushers need to get push rules run because
+ # that's how their pushers work
+ if_users_with_pushers = yield self.get_if_users_have_pushers(
+ local_users_in_room, on_invalidate=cache_context.invalidate,
+ )
+ user_ids = set(
+ uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
+ )
+
+ users_with_receipts = yield self.get_users_with_read_receipts_in_room(
+ room_id, on_invalidate=cache_context.invalidate,
+ )
+
+ # any users with pushers must be ours: they have pushers
+ for uid in users_with_receipts:
+ if uid in local_users_in_room:
+ user_ids.add(uid)
+
+ rules_by_user = yield self.bulk_get_push_rules(
+ user_ids, on_invalidate=cache_context.invalidate,
+ )
+
+ rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
+
+ defer.returnValue(rules_by_user)
+
@cachedList(cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids", num_args=1, inlineCallbacks=True)
def bulk_get_push_rules_enabled(self, user_ids):
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index a7d7c54d..8f5f8f24 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -135,7 +135,7 @@ class PusherStore(SQLBaseStore):
"get_all_updated_pushers", get_all_updated_pushers_txn
)
- @cachedInlineCallbacks(lru=True, num_args=1, max_entries=15000)
+ @cachedInlineCallbacks(num_args=1, max_entries=15000)
def get_if_user_has_pusher(self, user_id):
result = yield self._simple_select_many_batch(
table='pushers',
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 8c26f39f..ccc3811e 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -95,6 +95,31 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue({row["room_id"]: row["event_id"] for row in rows})
@defer.inlineCallbacks
+ def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
+ def f(txn):
+ sql = (
+ "SELECT rl.room_id, rl.event_id,"
+ " e.topological_ordering, e.stream_ordering"
+ " FROM receipts_linearized AS rl"
+ " INNER JOIN events AS e USING (room_id, event_id)"
+ " WHERE rl.room_id = e.room_id"
+ " AND rl.event_id = e.event_id"
+ " AND user_id = ?"
+ )
+ txn.execute(sql, (user_id,))
+ return txn.fetchall()
+ rows = yield self.runInteraction(
+ "get_receipts_for_user_with_orderings", f
+ )
+ defer.returnValue({
+ row[0]: {
+ "event_id": row[1],
+ "topological_ordering": row[2],
+ "stream_ordering": row[3],
+ } for row in rows
+ })
+
+ @defer.inlineCallbacks
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
"""Get receipts for multiple rooms for sending to clients.
@@ -120,7 +145,7 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue([ev for res in results.values() for ev in res])
- @cachedInlineCallbacks(num_args=3, max_entries=5000, lru=True, tree=True)
+ @cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True)
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients.
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 7e7d32eb..e404fa72 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -93,7 +93,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
desc="add_refresh_token_to_user",
)
- @defer.inlineCallbacks
def register(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_localpart=None, admin=False):
@@ -115,7 +114,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
Raises:
StoreError if the user_id could not be registered.
"""
- yield self.runInteraction(
+ return self.runInteraction(
"register",
self._register,
user_id,
@@ -127,8 +126,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
create_profile_with_localpart,
admin
)
- self.get_user_by_id.invalidate((user_id,))
- self.is_guest.invalidate((user_id,))
def _register(
self,
@@ -210,6 +207,11 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
(create_profile_with_localpart,)
)
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_id, (user_id,)
+ )
+ txn.call_after(self.is_guest.invalidate, (user_id,))
+
@cached()
def get_user_by_id(self, user_id):
return self._simple_select_one(
@@ -236,22 +238,31 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
return self.runInteraction("get_users_by_id_case_insensitive", f)
- @defer.inlineCallbacks
def user_set_password_hash(self, user_id, password_hash):
"""
NB. This does *not* evict any cache because the one use for this
removes most of the entries subsequently anyway so it would be
pointless. Use flush_user separately.
"""
- yield self._simple_update_one('users', {
- 'name': user_id
- }, {
- 'password_hash': password_hash
- })
- self.get_user_by_id.invalidate((user_id,))
+ def user_set_password_hash_txn(txn):
+ self._simple_update_one_txn(
+ txn,
+ 'users', {
+ 'name': user_id
+ },
+ {
+ 'password_hash': password_hash
+ }
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_id, (user_id,)
+ )
+ return self.runInteraction(
+ "user_set_password_hash", user_set_password_hash_txn
+ )
@defer.inlineCallbacks
- def user_delete_access_tokens(self, user_id, except_token_ids=[],
+ def user_delete_access_tokens(self, user_id, except_token_id=None,
device_id=None,
delete_refresh_tokens=False):
"""
@@ -259,7 +270,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
Args:
user_id (str): ID of user the tokens belong to
- except_token_ids (list[str]): list of access_tokens which should
+ except_token_id (str): list of access_tokens IDs which should
*not* be deleted
device_id (str|None): ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
@@ -269,53 +280,45 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
Returns:
defer.Deferred:
"""
- def f(txn, table, except_tokens, call_after_delete):
- sql = "SELECT token FROM %s WHERE user_id = ?" % table
- clauses = [user_id]
-
+ def f(txn):
+ keyvalues = {
+ "user_id": user_id,
+ }
if device_id is not None:
- sql += " AND device_id = ?"
- clauses.append(device_id)
+ keyvalues["device_id"] = device_id
- if except_tokens:
- sql += " AND id NOT IN (%s)" % (
- ",".join(["?" for _ in except_tokens]),
+ if delete_refresh_tokens:
+ self._simple_delete_txn(
+ txn,
+ table="refresh_tokens",
+ keyvalues=keyvalues,
)
- clauses += except_tokens
-
- txn.execute(sql, clauses)
- rows = txn.fetchall()
+ items = keyvalues.items()
+ where_clause = " AND ".join(k + " = ?" for k, _ in items)
+ values = [v for _, v in items]
+ if except_token_id:
+ where_clause += " AND id != ?"
+ values.append(except_token_id)
- n = 100
- chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
- for chunk in chunks:
- if call_after_delete:
- for row in chunk:
- txn.call_after(call_after_delete, (row[0],))
+ txn.execute(
+ "SELECT token FROM access_tokens WHERE %s" % where_clause,
+ values
+ )
+ rows = self.cursor_to_dict(txn)
- txn.execute(
- "DELETE FROM %s WHERE token in (%s)" % (
- table,
- ",".join(["?" for _ in chunk]),
- ), [r[0] for r in chunk]
+ for row in rows:
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_access_token, (row["token"],)
)
- # delete refresh tokens first, to stop new access tokens being
- # allocated while our backs are turned
- if delete_refresh_tokens:
- yield self.runInteraction(
- "user_delete_access_tokens", f,
- table="refresh_tokens",
- except_tokens=[],
- call_after_delete=None,
+ txn.execute(
+ "DELETE FROM access_tokens WHERE %s" % where_clause,
+ values
)
yield self.runInteraction(
"user_delete_access_tokens", f,
- table="access_tokens",
- except_tokens=except_token_ids,
- call_after_delete=self.get_user_by_access_token.invalidate,
)
def delete_access_token(self, access_token):
@@ -328,7 +331,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
},
)
- txn.call_after(self.get_user_by_access_token.invalidate, (access_token,))
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_access_token, (access_token,)
+ )
return self.runInteraction("delete_access_token", f)
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 8bd693be..a422ddf6 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -277,7 +277,6 @@ class RoomMemberStore(SQLBaseStore):
user_id, membership_list=[Membership.JOIN],
)
- @defer.inlineCallbacks
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
@@ -292,10 +291,13 @@ class RoomMemberStore(SQLBaseStore):
" room_id = ?"
)
txn.execute(sql, (user_id, room_id))
- yield self.runInteraction("forget_membership", f)
- self.was_forgotten_at.invalidate_all()
- self.who_forgot_in_room.invalidate_all()
- self.did_forget.invalidate((user_id, room_id))
+
+ txn.call_after(self.was_forgotten_at.invalidate_all)
+ txn.call_after(self.did_forget.invalidate, (user_id, room_id))
+ self._invalidate_cache_and_stream(
+ txn, self.who_forgot_in_room, (room_id,)
+ )
+ return self.runInteraction("forget_membership", f)
@cachedInlineCallbacks(num_args=2)
def did_forget(self, user_id, room_id):
diff --git a/synapse/storage/schema/delta/34/appservice_stream.sql b/synapse/storage/schema/delta/34/appservice_stream.sql
new file mode 100644
index 00000000..69e16eda
--- /dev/null
+++ b/synapse/storage/schema/delta/34/appservice_stream.sql
@@ -0,0 +1,23 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS appservice_stream_position(
+ Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
+ stream_ordering BIGINT,
+ CHECK (Lock='X')
+);
+
+INSERT INTO appservice_stream_position (stream_ordering)
+ SELECT COALESCE(MAX(stream_ordering), 0) FROM events;
diff --git a/synapse/storage/schema/delta/34/cache_stream.py b/synapse/storage/schema/delta/34/cache_stream.py
new file mode 100644
index 00000000..3b63a156
--- /dev/null
+++ b/synapse/storage/schema/delta/34/cache_stream.py
@@ -0,0 +1,46 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.prepare_database import get_statements
+from synapse.storage.engines import PostgresEngine
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+# This stream is used to notify replication slaves that some caches have
+# been invalidated that they cannot infer from the other streams.
+CREATE_TABLE = """
+CREATE TABLE cache_invalidation_stream (
+ stream_id BIGINT,
+ cache_func TEXT,
+ keys TEXT[],
+ invalidation_ts BIGINT
+);
+
+CREATE INDEX cache_invalidation_stream_id ON cache_invalidation_stream(stream_id);
+"""
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ if not isinstance(database_engine, PostgresEngine):
+ return
+
+ for statement in get_statements(CREATE_TABLE.splitlines()):
+ cur.execute(statement)
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+ pass
diff --git a/synapse/storage/schema/delta/34/push_display_name_rename.sql b/synapse/storage/schema/delta/34/push_display_name_rename.sql
new file mode 100644
index 00000000..0d9fe1a9
--- /dev/null
+++ b/synapse/storage/schema/delta/34/push_display_name_rename.sql
@@ -0,0 +1,20 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+DELETE FROM push_rules WHERE rule_id = 'global/override/.m.rule.contains_display_name';
+UPDATE push_rules SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name';
+
+DELETE FROM push_rules_enable WHERE rule_id = 'global/override/.m.rule.contains_display_name';
+UPDATE push_rules_enable SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name';
diff --git a/synapse/storage/schema/delta/34/received_txn_purge.py b/synapse/storage/schema/delta/34/received_txn_purge.py
new file mode 100644
index 00000000..03314434
--- /dev/null
+++ b/synapse/storage/schema/delta/34/received_txn_purge.py
@@ -0,0 +1,32 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.engines import PostgresEngine
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ if isinstance(database_engine, PostgresEngine):
+ cur.execute("TRUNCATE received_transactions")
+ else:
+ cur.execute("DELETE FROM received_transactions")
+
+ cur.execute("CREATE INDEX received_transactions_ts ON received_transactions(ts)")
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+ pass
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index ea6823f1..e1dca927 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -25,7 +25,7 @@ from synapse.util.caches.descriptors import cached, cachedList
class SignatureStore(SQLBaseStore):
"""Persistence for event signatures and hashes"""
- @cached(lru=True)
+ @cached()
def get_event_reference_hash(self, event_id):
return self._get_event_reference_hashes_txn(event_id)
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 5b743db6..0e8fa93e 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -174,7 +174,7 @@ class StateStore(SQLBaseStore):
return [r[0] for r in results]
return self.runInteraction("get_current_state_for_key", f)
- @cached(num_args=2, lru=True, max_entries=1000)
+ @cached(num_args=2, max_entries=1000)
def _get_state_group_from_group(self, group, types):
raise NotImplementedError()
@@ -272,7 +272,7 @@ class StateStore(SQLBaseStore):
state_map = yield self.get_state_for_events([event_id], types)
defer.returnValue(state_map[event_id])
- @cached(num_args=2, lru=True, max_entries=10000)
+ @cached(num_args=2, max_entries=10000)
def _get_state_group_for_event(self, room_id, event_id):
return self._simple_select_one_onecol(
table="event_to_state_groups",
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 862c5c3e..0577a052 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -39,7 +39,7 @@ from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
import logging
@@ -234,12 +234,12 @@ class StreamStore(SQLBaseStore):
results = {}
room_ids = list(room_ids)
for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
- res = yield defer.gatherResults([
+ res = yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(self.get_room_events_stream_for_room)(
room_id, from_key, to_key, limit, order=order,
)
for room_id in rm_ids
- ])
+ ]))
results.update(dict(zip(rm_ids, res)))
defer.returnValue(results)
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 6258ff17..58d4de4f 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -62,10 +62,9 @@ class TransactionStore(SQLBaseStore):
self.last_transaction = {}
reactor.addSystemEventTrigger("before", "shutdown", self._persist_in_mem_txns)
- hs.get_clock().looping_call(
- self._persist_in_mem_txns,
- 1000,
- )
+ self._clock.looping_call(self._persist_in_mem_txns, 1000)
+
+ self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
def get_received_txn_response(self, transaction_id, origin):
"""For an incoming transaction from a given origin, check if we have
@@ -127,6 +126,7 @@ class TransactionStore(SQLBaseStore):
"origin": origin,
"response_code": code,
"response_json": buffer(encode_canonical_json(response_dict)),
+ "ts": self._clock.time_msec(),
},
or_ignore=True,
desc="set_received_txn_response",
@@ -383,3 +383,12 @@ class TransactionStore(SQLBaseStore):
yield self.runInteraction("_persist_in_mem_txns", f)
except:
logger.exception("Failed to persist transactions!")
+
+ def _cleanup_transactions(self):
+ now = self._clock.time_msec()
+ month_ago = now - 30 * 24 * 60 * 60 * 1000
+
+ def _cleanup_transactions_txn(txn):
+ txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
+
+ return self.runInteraction("_persist_in_mem_txns", _cleanup_transactions_txn)
diff --git a/synapse/types.py b/synapse/types.py
index 5349b0c4..fd17ecbb 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -269,3 +269,10 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
return "t%d-%d" % (self.topological, self.stream)
else:
return "s%d" % (self.stream,)
+
+
+# Some arbitrary constants used for internal API enumerations. Don't rely on
+# exact values; always pass or compare symbolically
+class ThirdPartyEntityKind(object):
+ USER = 'user'
+ LOCATION = 'location'
diff --git a/synapse/util/async.py b/synapse/util/async.py
index c84b23ff..347fb1e3 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -146,10 +146,10 @@ def concurrently_execute(func, args, limit):
except StopIteration:
pass
- return defer.gatherResults([
+ return preserve_context_over_deferred(defer.gatherResults([
preserve_fn(_concurrently_execute_inner)()
for _ in xrange(limit)
- ], consumeErrors=True).addErrback(unwrapFirstError)
+ ], consumeErrors=True)).addErrback(unwrapFirstError)
class Linearizer(object):
@@ -181,7 +181,8 @@ class Linearizer(object):
self.key_to_defer[key] = new_defer
if current_defer:
- yield preserve_context_over_deferred(current_defer)
+ with PreserveLoggingContext():
+ yield current_defer
@contextmanager
def _ctx_manager():
@@ -264,7 +265,7 @@ class ReadWriteLock(object):
curr_readers.clear()
self.key_to_current_writer[key] = new_defer
- yield defer.gatherResults(to_wait_on)
+ yield preserve_context_over_deferred(defer.gatherResults(to_wait_on))
@contextmanager
def _ctx_manager():
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index f31dfb22..8dba61d4 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -25,8 +25,7 @@ from synapse.util.logcontext import (
from . import DEBUG_CACHES, register_cache
from twisted.internet import defer
-
-from collections import OrderedDict
+from collections import namedtuple
import os
import functools
@@ -54,16 +53,11 @@ class Cache(object):
"metrics",
)
- def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
- if lru:
- cache_type = TreeCache if tree else dict
- self.cache = LruCache(
- max_size=max_entries, keylen=keylen, cache_type=cache_type
- )
- self.max_entries = None
- else:
- self.cache = OrderedDict()
- self.max_entries = max_entries
+ def __init__(self, name, max_entries=1000, keylen=1, tree=False):
+ cache_type = TreeCache if tree else dict
+ self.cache = LruCache(
+ max_size=max_entries, keylen=keylen, cache_type=cache_type
+ )
self.name = name
self.keylen = keylen
@@ -81,8 +75,8 @@ class Cache(object):
"Cache objects can only be accessed from the main thread"
)
- def get(self, key, default=_CacheSentinel):
- val = self.cache.get(key, _CacheSentinel)
+ def get(self, key, default=_CacheSentinel, callback=None):
+ val = self.cache.get(key, _CacheSentinel, callback=callback)
if val is not _CacheSentinel:
self.metrics.inc_hits()
return val
@@ -94,19 +88,15 @@ class Cache(object):
else:
return default
- def update(self, sequence, key, value):
+ def update(self, sequence, key, value, callback=None):
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
- self.prefill(key, value)
-
- def prefill(self, key, value):
- if self.max_entries is not None:
- while len(self.cache) >= self.max_entries:
- self.cache.popitem(last=False)
+ self.prefill(key, value, callback=callback)
- self.cache[key] = value
+ def prefill(self, key, value, callback=None):
+ self.cache.set(key, value, callback=callback)
def invalidate(self, key):
self.check_thread()
@@ -151,9 +141,21 @@ class CacheDescriptor(object):
The wrapped function has another additional callable, called "prefill",
which can be used to insert values into the cache specifically, without
calling the calculation function.
+
+ Cached functions can be "chained" (i.e. a cached function can call other cached
+ functions and get appropriately invalidated when they called caches are
+ invalidated) by adding a special "cache_context" argument to the function
+ and passing that as a kwarg to all caches called. For example::
+
+ @cachedInlineCallbacks(cache_context=True)
+ def foo(self, key, cache_context):
+ r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
+ r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
+ defer.returnValue(r1 + r2)
+
"""
- def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False,
- inlineCallbacks=False):
+ def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
+ inlineCallbacks=False, cache_context=False):
max_entries = int(max_entries * CACHE_SIZE_FACTOR)
self.orig = orig
@@ -165,15 +167,33 @@ class CacheDescriptor(object):
self.max_entries = max_entries
self.num_args = num_args
- self.lru = lru
self.tree = tree
- self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
+ all_args = inspect.getargspec(orig)
+ self.arg_names = all_args.args[1:num_args + 1]
+
+ if "cache_context" in all_args.args:
+ if not cache_context:
+ raise ValueError(
+ "Cannot have a 'cache_context' arg without setting"
+ " cache_context=True"
+ )
+ try:
+ self.arg_names.remove("cache_context")
+ except ValueError:
+ pass
+ elif cache_context:
+ raise ValueError(
+ "Cannot have cache_context=True without having an arg"
+ " named `cache_context`"
+ )
+
+ self.add_cache_context = cache_context
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
- " (@cached cannot key off of *args or **kwars)"
+ " (@cached cannot key off of *args or **kwargs)"
% (orig.__name__,)
)
@@ -182,16 +202,29 @@ class CacheDescriptor(object):
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
- lru=self.lru,
tree=self.tree,
)
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
+ # If we're passed a cache_context then we'll want to call its invalidate()
+ # whenever we are invalidated
+ invalidate_callback = kwargs.pop("on_invalidate", None)
+
+ # Add temp cache_context so inspect.getcallargs doesn't explode
+ if self.add_cache_context:
+ kwargs["cache_context"] = None
+
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
+
+ # Add our own `cache_context` to argument list if the wrapped function
+ # has asked for one
+ if self.add_cache_context:
+ kwargs["cache_context"] = _CacheContext(cache, cache_key)
+
try:
- cached_result_d = cache.get(cache_key)
+ cached_result_d = cache.get(cache_key, callback=invalidate_callback)
observer = cached_result_d.observe()
if DEBUG_CACHES:
@@ -228,7 +261,7 @@ class CacheDescriptor(object):
ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True)
- cache.update(sequence, cache_key, ret)
+ cache.update(sequence, cache_key, ret, callback=invalidate_callback)
return preserve_context_over_deferred(ret.observe())
@@ -297,6 +330,10 @@ class CacheListDescriptor(object):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
+ # If we're passed a cache_context then we'll want to call its invalidate()
+ # whenever we are invalidated
+ invalidate_callback = kwargs.pop("on_invalidate", None)
+
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name]
@@ -311,7 +348,7 @@ class CacheListDescriptor(object):
key[self.list_pos] = arg
try:
- res = cache.get(tuple(key))
+ res = cache.get(tuple(key), callback=invalidate_callback)
if not res.has_succeeded():
res = res.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
@@ -345,7 +382,10 @@ class CacheListDescriptor(object):
key = list(keyargs)
key[self.list_pos] = arg
- cache.update(sequence, tuple(key), observer)
+ cache.update(
+ sequence, tuple(key), observer,
+ callback=invalidate_callback
+ )
def invalidate(f, key):
cache.invalidate(key)
@@ -376,24 +416,29 @@ class CacheListDescriptor(object):
return wrapped
-def cached(max_entries=1000, num_args=1, lru=True, tree=False):
+class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
+ def invalidate(self):
+ self.cache.invalidate(self.key)
+
+
+def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
- lru=lru,
tree=tree,
+ cache_context=cache_context,
)
-def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
- lru=lru,
tree=tree,
inlineCallbacks=True,
+ cache_context=cache_context,
)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index f9df445a..9c4c6791 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -30,13 +30,14 @@ def enumerate_leaves(node, depth):
class _Node(object):
- __slots__ = ["prev_node", "next_node", "key", "value"]
+ __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
- def __init__(self, prev_node, next_node, key, value):
+ def __init__(self, prev_node, next_node, key, value, callbacks=set()):
self.prev_node = prev_node
self.next_node = next_node
self.key = key
self.value = value
+ self.callbacks = callbacks
class LruCache(object):
@@ -44,6 +45,9 @@ class LruCache(object):
Least-recently-used cache.
Supports del_multi only if cache_type=TreeCache
If cache_type=TreeCache, all keys must be tuples.
+
+ Can also set callbacks on objects when getting/setting which are fired
+ when that key gets invalidated/evicted.
"""
def __init__(self, max_size, keylen=1, cache_type=dict):
cache = cache_type()
@@ -62,10 +66,10 @@ class LruCache(object):
return inner
- def add_node(key, value):
+ def add_node(key, value, callbacks=set()):
prev_node = list_root
next_node = prev_node.next_node
- node = _Node(prev_node, next_node, key, value)
+ node = _Node(prev_node, next_node, key, value, callbacks)
prev_node.next_node = node
next_node.prev_node = node
cache[key] = node
@@ -88,23 +92,41 @@ class LruCache(object):
prev_node.next_node = next_node
next_node.prev_node = prev_node
+ for cb in node.callbacks:
+ cb()
+ node.callbacks.clear()
+
@synchronized
- def cache_get(key, default=None):
+ def cache_get(key, default=None, callback=None):
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
+ if callback:
+ node.callbacks.add(callback)
return node.value
else:
return default
@synchronized
- def cache_set(key, value):
+ def cache_set(key, value, callback=None):
node = cache.get(key, None)
if node is not None:
+ if value != node.value:
+ for cb in node.callbacks:
+ cb()
+ node.callbacks.clear()
+
+ if callback:
+ node.callbacks.add(callback)
+
move_node_to_front(node)
node.value = value
else:
- add_node(key, value)
+ if callback:
+ callbacks = set([callback])
+ else:
+ callbacks = set()
+ add_node(key, value, callbacks)
if len(cache) > max_size:
todelete = list_root.prev_node
delete_node(todelete)
@@ -148,6 +170,9 @@ class LruCache(object):
def cache_clear():
list_root.next_node = list_root
list_root.prev_node = list_root
+ for node in cache.values():
+ for cb in node.callbacks:
+ cb()
cache.clear()
@synchronized
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index 03bc1401..c31585ae 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -64,6 +64,9 @@ class TreeCache(object):
self.size -= cnt
return popped
+ def values(self):
+ return [e.value for e in self.root.values()]
+
def __len__(self):
return self.size
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 5316259d..6c83eb21 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -297,12 +297,13 @@ def preserve_context_over_fn(fn, *args, **kwargs):
return res
-def preserve_context_over_deferred(deferred):
+def preserve_context_over_deferred(deferred, context=None):
"""Given a deferred wrap it such that any callbacks added later to it will
be invoked with the current context.
"""
- current_context = LoggingContext.current_context()
- d = _PreservingContextDeferred(current_context)
+ if context is None:
+ context = LoggingContext.current_context()
+ d = _PreservingContextDeferred(context)
deferred.chainDeferred(d)
return d
@@ -316,8 +317,13 @@ def preserve_fn(f):
def g(*args, **kwargs):
with PreserveLoggingContext(current):
- return f(*args, **kwargs)
-
+ res = f(*args, **kwargs)
+ if isinstance(res, defer.Deferred):
+ return preserve_context_over_deferred(
+ res, context=LoggingContext.sentinel
+ )
+ else:
+ return res
return g
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 0b944d3e..4ea930d3 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
from synapse.util.logcontext import LoggingContext
import synapse.metrics
+from functools import wraps
import logging
@@ -47,6 +49,18 @@ block_db_txn_duration = metrics.register_distribution(
)
+def measure_func(name):
+ def wrapper(func):
+ @wraps(func)
+ @defer.inlineCallbacks
+ def measured_func(self, *args, **kwargs):
+ with Measure(self.clock, name):
+ r = yield func(self, *args, **kwargs)
+ defer.returnValue(r)
+ return measured_func
+ return wrapper
+
+
class Measure(object):
__slots__ = [
"clock", "name", "start_context", "start", "new_context", "ru_utime",
@@ -64,7 +78,6 @@ class Measure(object):
self.start = self.clock.time_msec()
self.start_context = LoggingContext.current_context()
if not self.start_context:
- logger.warn("Entered Measure without log context: %s", self.name)
self.start_context = LoggingContext("Measure")
self.start_context.__enter__()
self.created_context = True
@@ -74,7 +87,7 @@ class Measure(object):
self.db_txn_duration = self.start_context.db_txn_duration
def __exit__(self, exc_type, exc_val, exc_tb):
- if exc_type is not None or not self.start_context:
+ if isinstance(exc_type, Exception) or not self.start_context:
return
duration = self.clock.time_msec() - self.start
@@ -85,7 +98,7 @@ class Measure(object):
if context != self.start_context:
logger.warn(
"Context has unexpectedly changed from '%s' to '%s'. (%r)",
- context, self.start_context, self.name
+ self.start_context, context, self.name
)
return
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 948ad517..cc12c0a2 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -17,7 +17,7 @@ from twisted.internet import defer
from synapse.api.constants import Membership, EventTypes
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
import logging
@@ -55,12 +55,12 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
given events
events ([synapse.events.EventBase]): list of events to filter
"""
- forgotten = yield defer.gatherResults([
+ forgotten = yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(store.who_forgot_in_room)(
room_id,
)
for room_id in frozenset(e.room_id for e in events)
- ], consumeErrors=True)
+ ], consumeErrors=True))
# Set of membership event_ids that have been forgotten
event_id_forgotten = frozenset(
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index d6cc1881..aa8cc505 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -14,6 +14,8 @@
# limitations under the License.
from synapse.appservice import ApplicationService
+from twisted.internet import defer
+
from mock import Mock
from tests import unittest
@@ -42,20 +44,25 @@ class ApplicationServiceTestCase(unittest.TestCase):
type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
)
+ self.store = Mock()
+
+ @defer.inlineCallbacks
def test_regex_user_id_prefix_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
self.event.sender = "@irc_foobar:matrix.org"
- self.assertTrue(self.service.is_interested(self.event))
+ self.assertTrue((yield self.service.is_interested(self.event)))
+ @defer.inlineCallbacks
def test_regex_user_id_prefix_no_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
self.event.sender = "@someone_else:matrix.org"
- self.assertFalse(self.service.is_interested(self.event))
+ self.assertFalse((yield self.service.is_interested(self.event)))
+ @defer.inlineCallbacks
def test_regex_room_member_is_checked(self):
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
@@ -63,30 +70,36 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.event.sender = "@someone_else:matrix.org"
self.event.type = "m.room.member"
self.event.state_key = "@irc_foobar:matrix.org"
- self.assertTrue(self.service.is_interested(self.event))
+ self.assertTrue((yield self.service.is_interested(self.event)))
+ @defer.inlineCallbacks
def test_regex_room_id_match(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!some_prefix.*some_suffix:matrix.org")
)
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
- self.assertTrue(self.service.is_interested(self.event))
+ self.assertTrue((yield self.service.is_interested(self.event)))
+ @defer.inlineCallbacks
def test_regex_room_id_no_match(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!some_prefix.*some_suffix:matrix.org")
)
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
- self.assertFalse(self.service.is_interested(self.event))
+ self.assertFalse((yield self.service.is_interested(self.event)))
+ @defer.inlineCallbacks
def test_regex_alias_match(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
- self.assertTrue(self.service.is_interested(
- self.event,
- aliases_for_event=["#irc_foobar:matrix.org", "#athing:matrix.org"]
- ))
+ self.store.get_aliases_for_room.return_value = [
+ "#irc_foobar:matrix.org", "#athing:matrix.org"
+ ]
+ self.store.get_users_in_room.return_value = []
+ self.assertTrue((yield self.service.is_interested(
+ self.event, self.store
+ )))
def test_non_exclusive_alias(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
@@ -136,15 +149,20 @@ class ApplicationServiceTestCase(unittest.TestCase):
"!irc_foobar:matrix.org"
))
+ @defer.inlineCallbacks
def test_regex_alias_no_match(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
- self.assertFalse(self.service.is_interested(
- self.event,
- aliases_for_event=["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
- ))
+ self.store.get_aliases_for_room.return_value = [
+ "#xmpp_foobar:matrix.org", "#athing:matrix.org"
+ ]
+ self.store.get_users_in_room.return_value = []
+ self.assertFalse((yield self.service.is_interested(
+ self.event, self.store
+ )))
+ @defer.inlineCallbacks
def test_regex_multiple_matches(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
@@ -153,53 +171,13 @@ class ApplicationServiceTestCase(unittest.TestCase):
_regex("@irc_.*")
)
self.event.sender = "@irc_foobar:matrix.org"
- self.assertTrue(self.service.is_interested(
- self.event,
- aliases_for_event=["#irc_barfoo:matrix.org"]
- ))
-
- def test_restrict_to_rooms(self):
- self.service.namespaces[ApplicationService.NS_ROOMS].append(
- _regex("!flibble_.*:matrix.org")
- )
- self.service.namespaces[ApplicationService.NS_USERS].append(
- _regex("@irc_.*")
- )
- self.event.sender = "@irc_foobar:matrix.org"
- self.event.room_id = "!wibblewoo:matrix.org"
- self.assertFalse(self.service.is_interested(
- self.event,
- restrict_to=ApplicationService.NS_ROOMS
- ))
-
- def test_restrict_to_aliases(self):
- self.service.namespaces[ApplicationService.NS_ALIASES].append(
- _regex("#xmpp_.*:matrix.org")
- )
- self.service.namespaces[ApplicationService.NS_USERS].append(
- _regex("@irc_.*")
- )
- self.event.sender = "@irc_foobar:matrix.org"
- self.assertFalse(self.service.is_interested(
- self.event,
- restrict_to=ApplicationService.NS_ALIASES,
- aliases_for_event=["#irc_barfoo:matrix.org"]
- ))
-
- def test_restrict_to_senders(self):
- self.service.namespaces[ApplicationService.NS_ALIASES].append(
- _regex("#xmpp_.*:matrix.org")
- )
- self.service.namespaces[ApplicationService.NS_USERS].append(
- _regex("@irc_.*")
- )
- self.event.sender = "@xmpp_foobar:matrix.org"
- self.assertFalse(self.service.is_interested(
- self.event,
- restrict_to=ApplicationService.NS_USERS,
- aliases_for_event=["#xmpp_barfoo:matrix.org"]
- ))
+ self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"]
+ self.store.get_users_in_room.return_value = []
+ self.assertTrue((yield self.service.is_interested(
+ self.event, self.store
+ )))
+ @defer.inlineCallbacks
def test_interested_in_self(self):
# make sure invites get through
self.service.sender = "@appservice:name"
@@ -211,20 +189,21 @@ class ApplicationServiceTestCase(unittest.TestCase):
"membership": "invite"
}
self.event.state_key = self.service.sender
- self.assertTrue(self.service.is_interested(self.event))
+ self.assertTrue((yield self.service.is_interested(self.event)))
+ @defer.inlineCallbacks
def test_member_list_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
- join_list = [
+ self.store.get_users_in_room.return_value = [
"@alice:here",
"@irc_fo:here", # AS user
"@bob:here",
]
+ self.store.get_aliases_for_room.return_value = []
self.event.sender = "@xmpp_foobar:matrix.org"
- self.assertTrue(self.service.is_interested(
- event=self.event,
- member_list=join_list
- ))
+ self.assertTrue((yield self.service.is_interested(
+ event=self.event, store=self.store
+ )))
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 631a2293..e5a902f7 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -193,7 +193,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
def setUp(self):
self.txn_ctrl = Mock()
- self.queuer = _ServiceQueuer(self.txn_ctrl)
+ self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock())
def test_send_single_event_no_queue(self):
# Expect the event to be sent immediately.
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index a884c95f..7fe88172 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -15,6 +15,7 @@
from twisted.internet import defer
from .. import unittest
+from tests.utils import MockClock
from synapse.handlers.appservice import ApplicationServicesHandler
@@ -32,6 +33,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
hs.get_datastore = Mock(return_value=self.mock_store)
hs.get_application_service_api = Mock(return_value=self.mock_as_api)
hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler)
+ hs.get_clock.return_value = MockClock()
self.handler = ApplicationServicesHandler(hs)
@defer.inlineCallbacks
@@ -51,8 +53,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
type="m.room.message",
room_id="!foo:bar"
)
+ self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
self.mock_as_api.push = Mock()
- yield self.handler.notify_interested_services(event)
+ yield self.handler.notify_interested_services(0)
self.mock_scheduler.submit_event_for_as.assert_called_once_with(
interested_service, event
)
@@ -72,7 +75,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
)
self.mock_as_api.push = Mock()
self.mock_as_api.query_user = Mock()
- yield self.handler.notify_interested_services(event)
+ self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
+ yield self.handler.notify_interested_services(0)
self.mock_as_api.query_user.assert_called_once_with(
services[0], user_id
)
@@ -94,7 +98,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
)
self.mock_as_api.push = Mock()
self.mock_as_api.query_user = Mock()
- yield self.handler.notify_interested_services(event)
+ self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
+ yield self.handler.notify_interested_services(0)
self.assertFalse(
self.mock_as_api.query_user.called,
"query_user called when it shouldn't have been."
@@ -108,11 +113,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
room_id = "!alpha:bet"
servers = ["aperture"]
- interested_service = self._mkservice(is_interested=True)
+ interested_service = self._mkservice_alias(is_interested_in_alias=True)
services = [
- self._mkservice(is_interested=False),
+ self._mkservice_alias(is_interested_in_alias=False),
interested_service,
- self._mkservice(is_interested=False)
+ self._mkservice_alias(is_interested_in_alias=False)
]
self.mock_store.get_app_services = Mock(return_value=services)
@@ -135,3 +140,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
service.token = "mock_service_token"
service.url = "mock_service_url"
return service
+
+ def _mkservice_alias(self, is_interested_in_alias):
+ service = Mock()
+ service.is_interested_in_alias = Mock(return_value=is_interested_in_alias)
+ service.token = "mock_service_token"
+ service.url = "mock_service_url"
+ return service
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 21077cbe..4a8cd19a 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -14,11 +14,13 @@
# limitations under the License.
import pymacaroons
+from twisted.internet import defer
+import synapse
+import synapse.api.errors
from synapse.handlers.auth import AuthHandler
from tests import unittest
from tests.utils import setup_test_homeserver
-from twisted.internet import defer
class AuthHandlers(object):
@@ -31,11 +33,12 @@ class AuthTestCase(unittest.TestCase):
def setUp(self):
self.hs = yield setup_test_homeserver(handlers=None)
self.hs.handlers = AuthHandlers(self.hs)
+ self.auth_handler = self.hs.handlers.auth_handler
def test_token_is_a_macaroon(self):
self.hs.config.macaroon_secret_key = "this key is a huge secret"
- token = self.hs.handlers.auth_handler.generate_access_token("some_user")
+ token = self.auth_handler.generate_access_token("some_user")
# Check that we can parse the thing with pymacaroons
macaroon = pymacaroons.Macaroon.deserialize(token)
# The most basic of sanity checks
@@ -46,7 +49,7 @@ class AuthTestCase(unittest.TestCase):
self.hs.config.macaroon_secret_key = "this key is a massive secret"
self.hs.clock.now = 5000
- token = self.hs.handlers.auth_handler.generate_access_token("a_user")
+ token = self.auth_handler.generate_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token)
def verify_gen(caveat):
@@ -67,3 +70,46 @@ class AuthTestCase(unittest.TestCase):
v.satisfy_general(verify_type)
v.satisfy_general(verify_expiry)
v.verify(macaroon, self.hs.config.macaroon_secret_key)
+
+ def test_short_term_login_token_gives_user_id(self):
+ self.hs.clock.now = 1000
+
+ token = self.auth_handler.generate_short_term_login_token(
+ "a_user", 5000
+ )
+
+ self.assertEqual(
+ "a_user",
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ token
+ )
+ )
+
+ # when we advance the clock, the token should be rejected
+ self.hs.clock.now = 6000
+ with self.assertRaises(synapse.api.errors.AuthError):
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ token
+ )
+
+ def test_short_term_login_token_cannot_replace_user_id(self):
+ token = self.auth_handler.generate_short_term_login_token(
+ "a_user", 5000
+ )
+ macaroon = pymacaroons.Macaroon.deserialize(token)
+
+ self.assertEqual(
+ "a_user",
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ macaroon.serialize()
+ )
+ )
+
+ # add another "user_id" caveat, which might allow us to override the
+ # user_id.
+ macaroon.add_first_party_caveat("user_id = b_user")
+
+ with self.assertRaises(synapse.api.errors.AuthError):
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ macaroon.serialize()
+ )
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 96b7dba5..ab609556 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -17,6 +17,8 @@
from tests import unittest
from twisted.internet import defer
+from mock import Mock
+
from synapse.util.async import ObservableDeferred
from synapse.util.caches.descriptors import Cache, cached
@@ -72,7 +74,7 @@ class CacheTestCase(unittest.TestCase):
cache.get(3)
def test_eviction_lru(self):
- cache = Cache("test", max_entries=2, lru=True)
+ cache = Cache("test", max_entries=2)
cache.prefill(1, "one")
cache.prefill(2, "two")
@@ -199,3 +201,115 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0)
+
+ @defer.inlineCallbacks
+ def test_invalidate_context(self):
+ callcount = [0]
+ callcount2 = [0]
+
+ class A(object):
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ @cached(cache_context=True)
+ def func2(self, key, cache_context):
+ callcount2[0] += 1
+ return self.func(key, on_invalidate=cache_context.invalidate)
+
+ a = A()
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 1)
+ self.assertEquals(callcount2[0], 1)
+
+ a.func.invalidate(("foo",))
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 1)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ @defer.inlineCallbacks
+ def test_eviction_context(self):
+ callcount = [0]
+ callcount2 = [0]
+
+ class A(object):
+ @cached(max_entries=2)
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ @cached(cache_context=True)
+ def func2(self, key, cache_context):
+ callcount2[0] += 1
+ return self.func(key, on_invalidate=cache_context.invalidate)
+
+ a = A()
+ yield a.func2("foo")
+ yield a.func2("foo2")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func("foo3")
+
+ self.assertEquals(callcount[0], 3)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 4)
+ self.assertEquals(callcount2[0], 3)
+
+ @defer.inlineCallbacks
+ def test_double_get(self):
+ callcount = [0]
+ callcount2 = [0]
+
+ class A(object):
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ @cached(cache_context=True)
+ def func2(self, key, cache_context):
+ callcount2[0] += 1
+ return self.func(key, on_invalidate=cache_context.invalidate)
+
+ a = A()
+ a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 1)
+ self.assertEquals(callcount2[0], 1)
+
+ a.func2.invalidate(("foo",))
+ self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
+
+ yield a.func2("foo")
+ a.func2.invalidate(("foo",))
+ self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
+
+ self.assertEquals(callcount[0], 1)
+ self.assertEquals(callcount2[0], 2)
+
+ a.func.invalidate(("foo",))
+ self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 3)
diff --git a/tests/test_preview.py b/tests/test_preview.py
index 2a801173..c8d6525a 100644
--- a/tests/test_preview.py
+++ b/tests/test_preview.py
@@ -15,7 +15,9 @@
from . import unittest
-from synapse.rest.media.v1.preview_url_resource import summarize_paragraphs
+from synapse.rest.media.v1.preview_url_resource import (
+ summarize_paragraphs, decode_and_calc_og
+)
class PreviewTestCase(unittest.TestCase):
@@ -137,3 +139,79 @@ class PreviewTestCase(unittest.TestCase):
" of old wooden houses in Northern Norway, the oldest house dating from"
" 1789. The Arctic Cathedral, a modern church…"
)
+
+
+class PreviewUrlTestCase(unittest.TestCase):
+ def test_simple(self):
+ html = """
+ <html>
+ <head><title>Foo</title></head>
+ <body>
+ Some text.
+ </body>
+ </html>
+ """
+
+ og = decode_and_calc_og(html, "http://example.com/test.html")
+
+ self.assertEquals(og, {
+ "og:title": "Foo",
+ "og:description": "Some text."
+ })
+
+ def test_comment(self):
+ html = """
+ <html>
+ <head><title>Foo</title></head>
+ <body>
+ <!-- HTML comment -->
+ Some text.
+ </body>
+ </html>
+ """
+
+ og = decode_and_calc_og(html, "http://example.com/test.html")
+
+ self.assertEquals(og, {
+ "og:title": "Foo",
+ "og:description": "Some text."
+ })
+
+ def test_comment2(self):
+ html = """
+ <html>
+ <head><title>Foo</title></head>
+ <body>
+ Some text.
+ <!-- HTML comment -->
+ Some more text.
+ <p>Text</p>
+ More text
+ </body>
+ </html>
+ """
+
+ og = decode_and_calc_og(html, "http://example.com/test.html")
+
+ self.assertEquals(og, {
+ "og:title": "Foo",
+ "og:description": "Some text.\n\nSome more text.\n\nText\n\nMore text"
+ })
+
+ def test_script(self):
+ html = """
+ <html>
+ <head><title>Foo</title></head>
+ <body>
+ <script> (function() {})() </script>
+ Some text.
+ </body>
+ </html>
+ """
+
+ og = decode_and_calc_og(html, "http://example.com/test.html")
+
+ self.assertEquals(og, {
+ "og:title": "Foo",
+ "og:description": "Some text."
+ })
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index bab366fb..1eba5b53 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -19,6 +19,8 @@ from .. import unittest
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
+from mock import Mock
+
class LruCacheTestCase(unittest.TestCase):
@@ -48,6 +50,8 @@ class LruCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get("key"), 1)
self.assertEquals(cache.setdefault("key", 2), 1)
self.assertEquals(cache.get("key"), 1)
+ cache["key"] = 2 # Make sure overriding works.
+ self.assertEquals(cache.get("key"), 2)
def test_pop(self):
cache = LruCache(1)
@@ -79,3 +83,152 @@ class LruCacheTestCase(unittest.TestCase):
cache["key"] = 1
cache.clear()
self.assertEquals(len(cache), 0)
+
+
+class LruCacheCallbacksTestCase(unittest.TestCase):
+ def test_get(self):
+ m = Mock()
+ cache = LruCache(1)
+
+ cache.set("key", "value")
+ self.assertFalse(m.called)
+
+ cache.get("key", callback=m)
+ self.assertFalse(m.called)
+
+ cache.get("key", "value")
+ self.assertFalse(m.called)
+
+ cache.set("key", "value2")
+ self.assertEquals(m.call_count, 1)
+
+ cache.set("key", "value")
+ self.assertEquals(m.call_count, 1)
+
+ def test_multi_get(self):
+ m = Mock()
+ cache = LruCache(1)
+
+ cache.set("key", "value")
+ self.assertFalse(m.called)
+
+ cache.get("key", callback=m)
+ self.assertFalse(m.called)
+
+ cache.get("key", callback=m)
+ self.assertFalse(m.called)
+
+ cache.set("key", "value2")
+ self.assertEquals(m.call_count, 1)
+
+ cache.set("key", "value")
+ self.assertEquals(m.call_count, 1)
+
+ def test_set(self):
+ m = Mock()
+ cache = LruCache(1)
+
+ cache.set("key", "value", m)
+ self.assertFalse(m.called)
+
+ cache.set("key", "value")
+ self.assertFalse(m.called)
+
+ cache.set("key", "value2")
+ self.assertEquals(m.call_count, 1)
+
+ cache.set("key", "value")
+ self.assertEquals(m.call_count, 1)
+
+ def test_pop(self):
+ m = Mock()
+ cache = LruCache(1)
+
+ cache.set("key", "value", m)
+ self.assertFalse(m.called)
+
+ cache.pop("key")
+ self.assertEquals(m.call_count, 1)
+
+ cache.set("key", "value")
+ self.assertEquals(m.call_count, 1)
+
+ cache.pop("key")
+ self.assertEquals(m.call_count, 1)
+
+ def test_del_multi(self):
+ m1 = Mock()
+ m2 = Mock()
+ m3 = Mock()
+ m4 = Mock()
+ cache = LruCache(4, 2, cache_type=TreeCache)
+
+ cache.set(("a", "1"), "value", m1)
+ cache.set(("a", "2"), "value", m2)
+ cache.set(("b", "1"), "value", m3)
+ cache.set(("b", "2"), "value", m4)
+
+ self.assertEquals(m1.call_count, 0)
+ self.assertEquals(m2.call_count, 0)
+ self.assertEquals(m3.call_count, 0)
+ self.assertEquals(m4.call_count, 0)
+
+ cache.del_multi(("a",))
+
+ self.assertEquals(m1.call_count, 1)
+ self.assertEquals(m2.call_count, 1)
+ self.assertEquals(m3.call_count, 0)
+ self.assertEquals(m4.call_count, 0)
+
+ def test_clear(self):
+ m1 = Mock()
+ m2 = Mock()
+ cache = LruCache(5)
+
+ cache.set("key1", "value", m1)
+ cache.set("key2", "value", m2)
+
+ self.assertEquals(m1.call_count, 0)
+ self.assertEquals(m2.call_count, 0)
+
+ cache.clear()
+
+ self.assertEquals(m1.call_count, 1)
+ self.assertEquals(m2.call_count, 1)
+
+ def test_eviction(self):
+ m1 = Mock(name="m1")
+ m2 = Mock(name="m2")
+ m3 = Mock(name="m3")
+ cache = LruCache(2)
+
+ cache.set("key1", "value", m1)
+ cache.set("key2", "value", m2)
+
+ self.assertEquals(m1.call_count, 0)
+ self.assertEquals(m2.call_count, 0)
+ self.assertEquals(m3.call_count, 0)
+
+ cache.set("key3", "value", m3)
+
+ self.assertEquals(m1.call_count, 1)
+ self.assertEquals(m2.call_count, 0)
+ self.assertEquals(m3.call_count, 0)
+
+ cache.set("key3", "value")
+
+ self.assertEquals(m1.call_count, 1)
+ self.assertEquals(m2.call_count, 0)
+ self.assertEquals(m3.call_count, 0)
+
+ cache.get("key2")
+
+ self.assertEquals(m1.call_count, 1)
+ self.assertEquals(m2.call_count, 0)
+ self.assertEquals(m3.call_count, 0)
+
+ cache.set("key1", "value", m1)
+
+ self.assertEquals(m1.call_count, 1)
+ self.assertEquals(m2.call_count, 0)
+ self.assertEquals(m3.call_count, 1)