summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndrej Shadura <andrewsh@debian.org>2019-12-01 11:02:18 +0100
committerAndrej Shadura <andrewsh@debian.org>2019-12-01 11:02:18 +0100
commit1c1b52e4c5d41ad1a138c34d27994c052a55b83b (patch)
tree194a31b7847880b75cb52345f6535f2df4dff5c2
parented93259bb49a2b1808408b55dcb72883a8bc4d33 (diff)
parent11311f62092daee712cb2aac8b0333fd0904bdfd (diff)
Merge branch 'debian/master' into debian/buster-backports
-rw-r--r--.buildkite/postgres-config.yaml21
-rwxr-xr-x.buildkite/scripts/create_postgres_db.py36
-rwxr-xr-x.buildkite/scripts/test_synapse_port_db.sh36
-rw-r--r--.buildkite/sqlite-config.yaml18
-rw-r--r--.buildkite/test_db.dbbin0 -> 18825216 bytes
-rw-r--r--.github/PULL_REQUEST_TEMPLATE.md1
-rw-r--r--CHANGES.md107
-rw-r--r--CONTRIBUTING.rst19
-rw-r--r--INSTALL.md18
-rw-r--r--contrib/experiments/test_messaging.py4
-rw-r--r--debian/changelog12
-rw-r--r--debian/copyright4
-rw-r--r--docker/README.md2
-rwxr-xr-xdocker/start.py11
-rw-r--r--docs/CAPTCHA_SETUP.md6
-rw-r--r--docs/sample_config.yaml12
-rw-r--r--docs/tcp_replication.md15
-rw-r--r--mypy.ini11
-rwxr-xr-xscripts-dev/build_debian_packages2
-rwxr-xr-xscripts-dev/lint.sh14
-rwxr-xr-xscripts-dev/update_database124
-rwxr-xr-xscripts/move_remote_media_to_new_store.py2
-rwxr-xr-xscripts/synapse_port_db6
-rw-r--r--snap/snapcraft.yaml20
-rw-r--r--synapse/__init__.py4
-rw-r--r--synapse/_scripts/register_new_matrix_user.py6
-rw-r--r--synapse/api/auth.py2
-rw-r--r--synapse/api/constants.py7
-rw-r--r--synapse/api/errors.py2
-rw-r--r--synapse/api/filtering.py15
-rw-r--r--synapse/app/__init__.py4
-rw-r--r--synapse/app/appservice.py4
-rw-r--r--synapse/app/client_reader.py4
-rw-r--r--synapse/app/event_creator.py4
-rw-r--r--synapse/app/federation_reader.py4
-rw-r--r--synapse/app/federation_sender.py4
-rw-r--r--synapse/app/frontend_proxy.py4
-rw-r--r--synapse/app/homeserver.py174
-rw-r--r--synapse/app/media_repository.py4
-rw-r--r--synapse/app/pusher.py4
-rw-r--r--synapse/app/synchrotron.py4
-rw-r--r--synapse/app/user_dir.py4
-rw-r--r--synapse/appservice/__init__.py4
-rw-r--r--synapse/config/captcha.py4
-rw-r--r--synapse/config/emailconfig.py2
-rw-r--r--synapse/config/key.py4
-rw-r--r--synapse/config/logger.py6
-rw-r--r--synapse/config/registration.py2
-rw-r--r--synapse/config/server.py6
-rw-r--r--synapse/crypto/event_signing.py6
-rw-r--r--synapse/event_auth.py2
-rw-r--r--synapse/events/snapshot.py318
-rw-r--r--synapse/events/spamcheck.py14
-rw-r--r--synapse/federation/federation_base.py6
-rw-r--r--synapse/federation/federation_client.py18
-rw-r--r--synapse/federation/federation_server.py239
-rw-r--r--synapse/federation/send_queue.py4
-rw-r--r--synapse/federation/sender/per_destination_queue.py19
-rw-r--r--synapse/federation/sender/transaction_manager.py4
-rw-r--r--synapse/federation/transport/__init__.py4
-rw-r--r--synapse/federation/transport/client.py10
-rw-r--r--synapse/federation/transport/server.py10
-rw-r--r--synapse/groups/attestations.py2
-rw-r--r--synapse/groups/groups_server.py2
-rw-r--r--synapse/handlers/account_data.py7
-rw-r--r--synapse/handlers/admin.py7
-rw-r--r--synapse/handlers/appservice.py5
-rw-r--r--synapse/handlers/auth.py94
-rw-r--r--synapse/handlers/deactivate_account.py3
-rw-r--r--synapse/handlers/device.py20
-rw-r--r--synapse/handlers/devicemessage.py2
-rw-r--r--synapse/handlers/directory.py4
-rw-r--r--synapse/handlers/e2e_keys.py174
-rw-r--r--synapse/handlers/events.py6
-rw-r--r--synapse/handlers/federation.py140
-rw-r--r--synapse/handlers/groups_local.py2
-rw-r--r--synapse/handlers/identity.py6
-rw-r--r--synapse/handlers/initial_sync.py18
-rw-r--r--synapse/handlers/message.py29
-rw-r--r--synapse/handlers/pagination.py25
-rw-r--r--synapse/handlers/profile.py8
-rw-r--r--synapse/handlers/read_marker.py13
-rw-r--r--synapse/handlers/receipts.py37
-rw-r--r--synapse/handlers/register.py56
-rw-r--r--synapse/handlers/room.py45
-rw-r--r--synapse/handlers/room_member.py125
-rw-r--r--synapse/handlers/search.py24
-rw-r--r--synapse/handlers/stats.py5
-rw-r--r--synapse/handlers/sync.py36
-rw-r--r--synapse/handlers/typing.py4
-rw-r--r--synapse/handlers/ui_auth/checkers.py2
-rw-r--r--synapse/http/client.py21
-rw-r--r--synapse/http/connectproxyclient.py195
-rw-r--r--synapse/http/federation/srv_resolver.py2
-rw-r--r--synapse/http/matrixfederationclient.py12
-rw-r--r--synapse/http/proxyagent.py195
-rw-r--r--synapse/http/request_metrics.py2
-rw-r--r--synapse/http/server.py2
-rw-r--r--synapse/http/servlet.py4
-rw-r--r--synapse/http/site.py4
-rw-r--r--synapse/logging/_structured.py2
-rw-r--r--synapse/logging/_terse_json.py2
-rw-r--r--synapse/logging/context.py2
-rw-r--r--synapse/notifier.py6
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py9
-rw-r--r--synapse/push/emailpusher.py14
-rw-r--r--synapse/push/httppusher.py23
-rw-r--r--synapse/push/mailer.py3
-rw-r--r--synapse/push/push_rule_evaluator.py4
-rw-r--r--synapse/push/push_tools.py9
-rw-r--r--synapse/push/pusherpool.py4
-rw-r--r--synapse/python_dependencies.py1
-rw-r--r--synapse/replication/http/_base.py8
-rw-r--r--synapse/replication/http/federation.py24
-rw-r--r--synapse/replication/http/login.py7
-rw-r--r--synapse/replication/http/membership.py16
-rw-r--r--synapse/replication/http/register.py14
-rw-r--r--synapse/replication/http/send_event.py7
-rw-r--r--synapse/replication/slave/storage/_base.py10
-rw-r--r--synapse/replication/slave/storage/devices.py13
-rw-r--r--synapse/replication/tcp/client.py22
-rw-r--r--synapse/replication/tcp/protocol.py76
-rw-r--r--synapse/replication/tcp/streams/__init__.py1
-rw-r--r--synapse/replication/tcp/streams/_base.py18
-rw-r--r--synapse/rest/admin/__init__.py569
-rw-r--r--synapse/rest/admin/groups.py46
-rw-r--r--synapse/rest/admin/rooms.py157
-rw-r--r--synapse/rest/admin/users.py406
-rw-r--r--synapse/rest/client/v1/login.py128
-rw-r--r--synapse/rest/client/v1/room.py166
-rw-r--r--synapse/rest/client/v2_alpha/account.py18
-rw-r--r--synapse/rest/client/v2_alpha/read_marker.py13
-rw-r--r--synapse/rest/client/v2_alpha/receipts.py11
-rw-r--r--synapse/rest/client/v2_alpha/register.py14
-rw-r--r--synapse/rest/client/v2_alpha/sync.py13
-rw-r--r--synapse/rest/client/versions.py3
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py2
-rw-r--r--synapse/rest/media/v1/media_repository.py12
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py30
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py4
-rw-r--r--synapse/server.py23
-rw-r--r--synapse/server.pyi25
-rw-r--r--synapse/server_notices/resource_limits_server_notices.py2
-rw-r--r--synapse/spam_checker_api/__init__.py51
-rw-r--r--synapse/state/__init__.py182
-rw-r--r--synapse/storage/__init__.py23
-rw-r--r--synapse/storage/_base.py15
-rw-r--r--synapse/storage/background_updates.py9
-rw-r--r--synapse/storage/data_stores/__init__.py12
-rw-r--r--synapse/storage/data_stores/main/__init__.py9
-rw-r--r--synapse/storage/data_stores/main/deviceinbox.py17
-rw-r--r--synapse/storage/data_stores/main/devices.py109
-rw-r--r--synapse/storage/data_stores/main/e2e_room_keys.py8
-rw-r--r--synapse/storage/data_stores/main/end_to_end_keys.py24
-rw-r--r--synapse/storage/data_stores/main/event_federation.py4
-rw-r--r--synapse/storage/data_stores/main/event_push_actions.py2
-rw-r--r--synapse/storage/data_stores/main/events.py1121
-rw-r--r--synapse/storage/data_stores/main/events_bg_updates.py72
-rw-r--r--synapse/storage/data_stores/main/group_server.py19
-rw-r--r--synapse/storage/data_stores/main/monthly_active_users.py2
-rw-r--r--synapse/storage/data_stores/main/push_rule.py2
-rw-r--r--synapse/storage/data_stores/main/pusher.py2
-rw-r--r--synapse/storage/data_stores/main/registration.py23
-rw-r--r--synapse/storage/data_stores/main/roommember.py4
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql25
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/event_labels.sql30
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql17
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite42
-rw-r--r--synapse/storage/data_stores/main/search.py4
-rw-r--r--synapse/storage/data_stores/main/state.py58
-rw-r--r--synapse/storage/data_stores/main/stats.py6
-rw-r--r--synapse/storage/data_stores/main/stream.py47
-rw-r--r--synapse/storage/persist_events.py649
-rw-r--r--synapse/storage/purge_events.py117
-rw-r--r--synapse/storage/state.py233
-rw-r--r--synapse/storage/util/id_generators.py2
-rw-r--r--synapse/util/async_helpers.py18
-rw-r--r--synapse/util/caches/__init__.py2
-rw-r--r--synapse/util/caches/descriptors.py57
-rw-r--r--synapse/util/httpresourcetree.py2
-rw-r--r--synapse/util/metrics.py6
-rw-r--r--synapse/util/rlimit.py2
-rw-r--r--synapse/util/versionstring.py10
-rw-r--r--synapse/visibility.py30
-rw-r--r--tests/api/test_filtering.py43
-rw-r--r--tests/crypto/test_keyring.py4
-rw-r--r--tests/handlers/test_federation.py126
-rw-r--r--tests/handlers/test_typing.py7
-rw-r--r--tests/http/__init__.py17
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py11
-rw-r--r--tests/http/test_proxyagent.py334
-rw-r--r--tests/push/test_http.py2
-rw-r--r--tests/replication/slave/storage/_base.py1
-rw-r--r--tests/replication/slave/storage/test_events.py10
-rw-r--r--tests/rest/admin/test_admin.py82
-rw-r--r--tests/rest/client/v1/test_rooms.py101
-rw-r--r--tests/rest/client/v1/utils.py15
-rw-r--r--tests/rest/client/v2_alpha/test_sync.py143
-rw-r--r--tests/server.py30
-rw-r--r--tests/storage/test__base.py2
-rw-r--r--tests/storage/test_devices.py16
-rw-r--r--tests/storage/test_e2e_room_keys.py75
-rw-r--r--tests/storage/test_purge.py15
-rw-r--r--tests/storage/test_redaction.py11
-rw-r--r--tests/storage/test_room.py3
-rw-r--r--tests/storage/test_roommember.py3
-rw-r--r--tests/storage/test_state.py153
-rw-r--r--tests/test_federation.py15
-rw-r--r--tests/test_phone_home.py51
-rw-r--r--tests/test_state.py64
-rw-r--r--tests/test_visibility.py18
-rw-r--r--tests/util/caches/test_descriptors.py8
-rw-r--r--tests/utils.py10
-rw-r--r--tox.ini10
214 files changed, 6337 insertions, 2822 deletions
diff --git a/.buildkite/postgres-config.yaml b/.buildkite/postgres-config.yaml
new file mode 100644
index 00000000..a35fec39
--- /dev/null
+++ b/.buildkite/postgres-config.yaml
@@ -0,0 +1,21 @@
+# Configuration file used for testing the 'synapse_port_db' script.
+# Tells the script to connect to the postgresql database that will be available in the
+# CI's Docker setup at the point where this file is considered.
+server_name: "test"
+
+signing_key_path: "/src/.buildkite/test.signing.key"
+
+report_stats: false
+
+database:
+ name: "psycopg2"
+ args:
+ user: postgres
+ host: postgres
+ password: postgres
+ database: synapse
+
+# Suppress the key server warning.
+trusted_key_servers:
+ - server_name: "matrix.org"
+suppress_key_server_warning: true
diff --git a/.buildkite/scripts/create_postgres_db.py b/.buildkite/scripts/create_postgres_db.py
new file mode 100755
index 00000000..df6082b0
--- /dev/null
+++ b/.buildkite/scripts/create_postgres_db.py
@@ -0,0 +1,36 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from synapse.storage.engines import create_engine
+
+logger = logging.getLogger("create_postgres_db")
+
+if __name__ == "__main__":
+ # Create a PostgresEngine.
+ db_engine = create_engine({"name": "psycopg2", "args": {}})
+
+ # Connect to postgres to create the base database.
+ # We use "postgres" as a database because it's bound to exist and the "synapse" one
+ # doesn't exist yet.
+ db_conn = db_engine.module.connect(
+ user="postgres", host="postgres", password="postgres", dbname="postgres"
+ )
+ db_conn.autocommit = True
+ cur = db_conn.cursor()
+ cur.execute("CREATE DATABASE synapse;")
+ cur.close()
+ db_conn.close()
diff --git a/.buildkite/scripts/test_synapse_port_db.sh b/.buildkite/scripts/test_synapse_port_db.sh
new file mode 100755
index 00000000..9ed21776
--- /dev/null
+++ b/.buildkite/scripts/test_synapse_port_db.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+#
+# Test script for 'synapse_port_db', which creates a virtualenv, installs Synapse along
+# with additional dependencies needed for the test (such as coverage or the PostgreSQL
+# driver), update the schema of the test SQLite database and run background updates on it,
+# create an empty test database in PostgreSQL, then run the 'synapse_port_db' script to
+# test porting the SQLite database to the PostgreSQL database (with coverage).
+
+set -xe
+cd `dirname $0`/../..
+
+echo "--- Install dependencies"
+
+# Install dependencies for this test.
+pip install psycopg2 coverage coverage-enable-subprocess
+
+# Install Synapse itself. This won't update any libraries.
+pip install -e .
+
+echo "--- Generate the signing key"
+
+# Generate the server's signing key.
+python -m synapse.app.homeserver --generate-keys -c .buildkite/sqlite-config.yaml
+
+echo "--- Prepare the databases"
+
+# Make sure the SQLite3 database is using the latest schema and has no pending background update.
+scripts-dev/update_database --database-config .buildkite/sqlite-config.yaml
+
+# Create the PostgreSQL database.
+./.buildkite/scripts/create_postgres_db.py
+
+echo "+++ Run synapse_port_db"
+
+# Run the script
+coverage run scripts/synapse_port_db --sqlite-database .buildkite/test_db.db --postgres-config .buildkite/postgres-config.yaml
diff --git a/.buildkite/sqlite-config.yaml b/.buildkite/sqlite-config.yaml
new file mode 100644
index 00000000..635b9217
--- /dev/null
+++ b/.buildkite/sqlite-config.yaml
@@ -0,0 +1,18 @@
+# Configuration file used for testing the 'synapse_port_db' script.
+# Tells the 'update_database' script to connect to the test SQLite database to upgrade its
+# schema and run background updates on it.
+server_name: "test"
+
+signing_key_path: "/src/.buildkite/test.signing.key"
+
+report_stats: false
+
+database:
+ name: "sqlite3"
+ args:
+ database: ".buildkite/test_db.db"
+
+# Suppress the key server warning.
+trusted_key_servers:
+ - server_name: "matrix.org"
+suppress_key_server_warning: true
diff --git a/.buildkite/test_db.db b/.buildkite/test_db.db
new file mode 100644
index 00000000..f20567ba
--- /dev/null
+++ b/.buildkite/test_db.db
Binary files differ
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index 1ead0d00..8939fda6 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -5,3 +5,4 @@
* [ ] Pull request is based on the develop branch
* [ ] Pull request includes a [changelog file](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#changelog)
* [ ] Pull request includes a [sign off](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#sign-off)
+* [ ] Code style is correct (run the [linters](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#code-style))
diff --git a/CHANGES.md b/CHANGES.md
index 9312dc29..a9afd36d 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,110 @@
+Synapse 1.6.1 (2019-11-28)
+==========================
+
+Security updates
+----------------
+
+This release includes a security fix ([\#6426](https://github.com/matrix-org/synapse/issues/6426), below). Administrators are encouraged to upgrade as soon as possible.
+
+Bugfixes
+--------
+
+- Clean up local threepids from user on account deactivation. ([\#6426](https://github.com/matrix-org/synapse/issues/6426))
+- Fix startup error when http proxy is defined. ([\#6421](https://github.com/matrix-org/synapse/issues/6421))
+
+
+Synapse 1.6.0 (2019-11-26)
+==========================
+
+Bugfixes
+--------
+
+- Fix phone home stats reporting. ([\#6418](https://github.com/matrix-org/synapse/issues/6418))
+
+
+Synapse 1.6.0rc2 (2019-11-25)
+=============================
+
+Bugfixes
+--------
+
+- Fix a bug which could cause the background database update hander for event labels to get stuck in a loop raising exceptions. ([\#6407](https://github.com/matrix-org/synapse/issues/6407))
+
+
+Synapse 1.6.0rc1 (2019-11-20)
+=============================
+
+Features
+--------
+
+- Add federation support for cross-signing. ([\#5727](https://github.com/matrix-org/synapse/issues/5727))
+- Increase default room version from 4 to 5, thereby enforcing server key validity period checks. ([\#6220](https://github.com/matrix-org/synapse/issues/6220))
+- Add support for outbound http proxying via http_proxy/HTTPS_PROXY env vars. ([\#6238](https://github.com/matrix-org/synapse/issues/6238))
+- Implement label-based filtering on `/sync` and `/messages` ([MSC2326](https://github.com/matrix-org/matrix-doc/pull/2326)). ([\#6301](https://github.com/matrix-org/synapse/issues/6301), [\#6310](https://github.com/matrix-org/synapse/issues/6310), [\#6340](https://github.com/matrix-org/synapse/issues/6340))
+
+
+Bugfixes
+--------
+
+- Fix LruCache callback deduplication for Python 3.8. Contributed by @V02460. ([\#6213](https://github.com/matrix-org/synapse/issues/6213))
+- Remove a room from a server's public rooms list on room upgrade. ([\#6232](https://github.com/matrix-org/synapse/issues/6232), [\#6235](https://github.com/matrix-org/synapse/issues/6235))
+- Delete keys from key backup when deleting backup versions. ([\#6253](https://github.com/matrix-org/synapse/issues/6253))
+- Make notification of cross-signing signatures work with workers. ([\#6254](https://github.com/matrix-org/synapse/issues/6254))
+- Fix exception when remote servers attempt to join a room that they're not allowed to join. ([\#6278](https://github.com/matrix-org/synapse/issues/6278))
+- Prevent errors from appearing on Synapse startup if `git` is not installed. ([\#6284](https://github.com/matrix-org/synapse/issues/6284))
+- Appservice requests will no longer contain a double slash prefix when the appservice url provided ends in a slash. ([\#6306](https://github.com/matrix-org/synapse/issues/6306))
+- Fix `/purge_room` admin API. ([\#6307](https://github.com/matrix-org/synapse/issues/6307))
+- Fix the `hidden` field in the `devices` table for SQLite versions prior to 3.23.0. ([\#6313](https://github.com/matrix-org/synapse/issues/6313))
+- Fix bug which casued rejected events to be persisted with the wrong room state. ([\#6320](https://github.com/matrix-org/synapse/issues/6320))
+- Fix bug where `rc_login` ratelimiting would prematurely kick in. ([\#6335](https://github.com/matrix-org/synapse/issues/6335))
+- Prevent the server taking a long time to start up when guest registration is enabled. ([\#6338](https://github.com/matrix-org/synapse/issues/6338))
+- Fix bug where upgrading a guest account to a full user would fail when account validity is enabled. ([\#6359](https://github.com/matrix-org/synapse/issues/6359))
+- Fix `to_device` stream ID getting reset every time Synapse restarts, which had the potential to cause unable to decrypt errors. ([\#6363](https://github.com/matrix-org/synapse/issues/6363))
+- Fix permission denied error when trying to generate a config file with the docker image. ([\#6389](https://github.com/matrix-org/synapse/issues/6389))
+
+
+Improved Documentation
+----------------------
+
+- Contributor documentation now mentions script to run linters. ([\#6164](https://github.com/matrix-org/synapse/issues/6164))
+- Modify CAPTCHA_SETUP.md to update the terms `private key` and `public key` to `secret key` and `site key` respectively. Contributed by Yash Jipkate. ([\#6257](https://github.com/matrix-org/synapse/issues/6257))
+- Update `INSTALL.md` Email section to talk about `account_threepid_delegates`. ([\#6272](https://github.com/matrix-org/synapse/issues/6272))
+- Fix a small typo in `account_threepid_delegates` configuration option. ([\#6273](https://github.com/matrix-org/synapse/issues/6273))
+
+
+Internal Changes
+----------------
+
+- Add a CI job to test the `synapse_port_db` script. ([\#6140](https://github.com/matrix-org/synapse/issues/6140), [\#6276](https://github.com/matrix-org/synapse/issues/6276))
+- Convert EventContext to an attrs. ([\#6218](https://github.com/matrix-org/synapse/issues/6218))
+- Move `persist_events` out from main data store. ([\#6240](https://github.com/matrix-org/synapse/issues/6240), [\#6300](https://github.com/matrix-org/synapse/issues/6300))
+- Reduce verbosity of user/room stats. ([\#6250](https://github.com/matrix-org/synapse/issues/6250))
+- Reduce impact of debug logging. ([\#6251](https://github.com/matrix-org/synapse/issues/6251))
+- Expose some homeserver functionality to spam checkers. ([\#6259](https://github.com/matrix-org/synapse/issues/6259))
+- Change cache descriptors to always return deferreds. ([\#6263](https://github.com/matrix-org/synapse/issues/6263), [\#6291](https://github.com/matrix-org/synapse/issues/6291))
+- Fix incorrect comment regarding the functionality of an `if` statement. ([\#6269](https://github.com/matrix-org/synapse/issues/6269))
+- Update CI to run `isort` over the `scripts` and `scripts-dev` directories. ([\#6270](https://github.com/matrix-org/synapse/issues/6270))
+- Replace every instance of `logger.warn` method with `logger.warning` as the former is deprecated. ([\#6271](https://github.com/matrix-org/synapse/issues/6271), [\#6314](https://github.com/matrix-org/synapse/issues/6314))
+- Port replication http server endpoints to async/await. ([\#6274](https://github.com/matrix-org/synapse/issues/6274))
+- Port room rest handlers to async/await. ([\#6275](https://github.com/matrix-org/synapse/issues/6275))
+- Remove redundant CLI parameters on CI's `flake8` step. ([\#6277](https://github.com/matrix-org/synapse/issues/6277))
+- Port `federation_server.py` to async/await. ([\#6279](https://github.com/matrix-org/synapse/issues/6279))
+- Port receipt and read markers to async/wait. ([\#6280](https://github.com/matrix-org/synapse/issues/6280))
+- Split out state storage into separate data store. ([\#6294](https://github.com/matrix-org/synapse/issues/6294), [\#6295](https://github.com/matrix-org/synapse/issues/6295))
+- Refactor EventContext for clarity. ([\#6298](https://github.com/matrix-org/synapse/issues/6298))
+- Update the version of black used to 19.10b0. ([\#6304](https://github.com/matrix-org/synapse/issues/6304))
+- Add some documentation about worker replication. ([\#6305](https://github.com/matrix-org/synapse/issues/6305))
+- Move admin endpoints into separate files. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#6308](https://github.com/matrix-org/synapse/issues/6308))
+- Document the use of `lint.sh` for code style enforcement & extend it to run on specified paths only. ([\#6312](https://github.com/matrix-org/synapse/issues/6312))
+- Add optional python dependencies and dependant binary libraries to snapcraft packaging. ([\#6317](https://github.com/matrix-org/synapse/issues/6317))
+- Remove the dependency on psutil and replace functionality with the stdlib `resource` module. ([\#6318](https://github.com/matrix-org/synapse/issues/6318), [\#6336](https://github.com/matrix-org/synapse/issues/6336))
+- Improve documentation for EventContext fields. ([\#6319](https://github.com/matrix-org/synapse/issues/6319))
+- Add some checks that we aren't using state from rejected events. ([\#6330](https://github.com/matrix-org/synapse/issues/6330))
+- Add continuous integration for python 3.8. ([\#6341](https://github.com/matrix-org/synapse/issues/6341))
+- Correct spacing/case of various instances of the word "homeserver". ([\#6357](https://github.com/matrix-org/synapse/issues/6357))
+- Temporarily blacklist the failing unit test PurgeRoomTestCase.test_purge_room. ([\#6361](https://github.com/matrix-org/synapse/issues/6361))
+
+
Synapse 1.5.1 (2019-11-06)
==========================
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index a71a4a69..df81f6e5 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -58,10 +58,29 @@ All Matrix projects have a well-defined code-style - and sometimes we've even
got as far as documenting it... For instance, synapse's code style doc lives
at https://github.com/matrix-org/synapse/tree/master/docs/code_style.md.
+To facilitate meeting these criteria you can run ``scripts-dev/lint.sh``
+locally. Since this runs the tools listed in the above document, you'll need
+python 3.6 and to install each tool. **Note that the script does not just
+test/check, but also reformats code, so you may wish to ensure any new code is
+committed first**. By default this script checks all files and can take some
+time; if you alter only certain files, you might wish to specify paths as
+arguments to reduce the run-time.
+
Please ensure your changes match the cosmetic style of the existing project,
and **never** mix cosmetic and functional changes in the same commit, as it
makes it horribly hard to review otherwise.
+Before doing a commit, ensure the changes you've made don't produce
+linting errors. You can do this by running the linters as follows. Ensure to
+commit any files that were corrected.
+
+::
+ # Install the dependencies
+ pip install -U black flake8 isort
+
+ # Run the linter script
+ ./scripts-dev/lint.sh
+
Changelog
~~~~~~~~~
diff --git a/INSTALL.md b/INSTALL.md
index 69e42392..29e0abaf 100644
--- a/INSTALL.md
+++ b/INSTALL.md
@@ -36,7 +36,7 @@ that your email address is probably `user@example.com` rather than
System requirements:
- POSIX-compliant system (tested on Linux & OS X)
-- Python 3.5, 3.6, or 3.7
+- Python 3.5, 3.6, 3.7 or 3.8.
- At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org
Synapse is written in Python but some of the libraries it uses are written in
@@ -413,16 +413,18 @@ For a more detailed guide to configuring your server for federation, see
## Email
-It is desirable for Synapse to have the capability to send email. For example,
-this is required to support the 'password reset' feature.
+It is desirable for Synapse to have the capability to send email. This allows
+Synapse to send password reset emails, send verifications when an email address
+is added to a user's account, and send email notifications to users when they
+receive new messages.
To configure an SMTP server for Synapse, modify the configuration section
-headed ``email``, and be sure to have at least the ``smtp_host``, ``smtp_port``
-and ``notif_from`` fields filled out. You may also need to set ``smtp_user``,
-``smtp_pass``, and ``require_transport_security``.
+headed `email`, and be sure to have at least the `smtp_host`, `smtp_port`
+and `notif_from` fields filled out. You may also need to set `smtp_user`,
+`smtp_pass`, and `require_transport_security`.
-If Synapse is not configured with an SMTP server, password reset via email will
- be disabled by default.
+If email is not configured, password reset, registration and notifications via
+email will be disabled.
## Registering a user
diff --git a/contrib/experiments/test_messaging.py b/contrib/experiments/test_messaging.py
index 6b22400a..3bbbcfa1 100644
--- a/contrib/experiments/test_messaging.py
+++ b/contrib/experiments/test_messaging.py
@@ -78,7 +78,7 @@ class InputOutput(object):
m = re.match("^join (\S+)$", line)
if m:
# The `sender` wants to join a room.
- room_name, = m.groups()
+ (room_name,) = m.groups()
self.print_line("%s joining %s" % (self.user, room_name))
self.server.join_room(room_name, self.user, self.user)
# self.print_line("OK.")
@@ -105,7 +105,7 @@ class InputOutput(object):
m = re.match("^backfill (\S+)$", line)
if m:
# we want to backfill a room
- room_name, = m.groups()
+ (room_name,) = m.groups()
self.print_line("backfill %s" % room_name)
self.server.backfill(room_name)
return
diff --git a/debian/changelog b/debian/changelog
index dd326dea..30bba618 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,15 @@
+matrix-synapse (1.6.1-1) unstable; urgency=high
+
+ * New upstream release.
+
+ -- Andrej Shadura <andrewsh@debian.org> Fri, 29 Nov 2019 14:43:09 +0100
+
+matrix-synapse (1.6.0-1) unstable; urgency=medium
+
+ * New upstream release.
+
+ -- Andrej Shadura <andrewsh@debian.org> Tue, 26 Nov 2019 22:19:26 +0100
+
matrix-synapse (1.5.1-1~bpo10+1) buster-backports; urgency=medium
* Rebuild for buster-backports.
diff --git a/debian/copyright b/debian/copyright
index 4b4d571b..3c37f454 100644
--- a/debian/copyright
+++ b/debian/copyright
@@ -5,7 +5,9 @@ Files: *
Copyright:
2014-2017 OpenMarket Ltd
2017 Vector Creations Ltd
- 2017—2018 New Vector Ltd
+ 2017-2018 Vector Creations Ltd
+ 2017—2019 New Vector Ltd
+ 2019 The Matrix.org Foundation C.I.C.
License: Apache-2.0
Files: synapse/config/jwt.py
diff --git a/docker/README.md b/docker/README.md
index 4b712f3f..24dfa77d 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -101,7 +101,7 @@ is suitable for local testing, but for any practical use, you will either need
to use a reverse proxy, or configure Synapse to expose an HTTPS port.
For documentation on using a reverse proxy, see
-https://github.com/matrix-org/synapse/blob/master/docs/reverse_proxy.rst.
+https://github.com/matrix-org/synapse/blob/master/docs/reverse_proxy.md.
For more information on enabling TLS support in synapse itself, see
https://github.com/matrix-org/synapse/blob/master/INSTALL.md#tls-certificates. Of
diff --git a/docker/start.py b/docker/start.py
index e41ea20e..97fd247f 100755
--- a/docker/start.py
+++ b/docker/start.py
@@ -169,11 +169,11 @@ def run_generate_config(environ, ownership):
# log("running %s" % (args, ))
if ownership is not None:
- args = ["su-exec", ownership] + args
- os.execv("/sbin/su-exec", args)
-
# make sure that synapse has perms to write to the data dir.
subprocess.check_output(["chown", ownership, data_dir])
+
+ args = ["su-exec", ownership] + args
+ os.execv("/sbin/su-exec", args)
else:
os.execv("/usr/local/bin/python", args)
@@ -217,8 +217,9 @@ def main(args, environ):
# backwards-compatibility generate-a-config-on-the-fly mode
if "SYNAPSE_CONFIG_PATH" in environ:
error(
- "SYNAPSE_SERVER_NAME and SYNAPSE_CONFIG_PATH are mutually exclusive "
- "except in `generate` or `migrate_config` mode."
+ "SYNAPSE_SERVER_NAME can only be combined with SYNAPSE_CONFIG_PATH "
+ "in `generate` or `migrate_config` mode. To start synapse using a "
+ "config file, unset the SYNAPSE_SERVER_NAME environment variable."
)
config_path = "/compiled/homeserver.yaml"
diff --git a/docs/CAPTCHA_SETUP.md b/docs/CAPTCHA_SETUP.md
index 5f905753..331e5d05 100644
--- a/docs/CAPTCHA_SETUP.md
+++ b/docs/CAPTCHA_SETUP.md
@@ -4,7 +4,7 @@ The captcha mechanism used is Google's ReCaptcha. This requires API keys from Go
## Getting keys
-Requires a public/private key pair from:
+Requires a site/secret key pair from:
<https://developers.google.com/recaptcha/>
@@ -15,8 +15,8 @@ Must be a reCAPTCHA v2 key using the "I'm not a robot" Checkbox option
The keys are a config option on the home server config. If they are not
visible, you can generate them via `--generate-config`. Set the following value:
- recaptcha_public_key: YOUR_PUBLIC_KEY
- recaptcha_private_key: YOUR_PRIVATE_KEY
+ recaptcha_public_key: YOUR_SITE_KEY
+ recaptcha_private_key: YOUR_SECRET_KEY
In addition, you MUST enable captchas via:
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 6c81c0db..89615939 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -72,7 +72,7 @@ pid_file: DATADIR/homeserver.pid
# For example, for room version 1, default_room_version should be set
# to "1".
#
-#default_room_version: "4"
+#default_room_version: "5"
# The GC threshold parameters to pass to `gc.set_threshold`, if defined
#
@@ -287,7 +287,7 @@ listeners:
# Used by phonehome stats to group together related servers.
#server_context: context
-# Resource-constrained Homeserver Settings
+# Resource-constrained homeserver Settings
#
# If limit_remote_rooms.enabled is True, the room complexity will be
# checked before a user joins a new remote room. If it is above
@@ -743,11 +743,11 @@ uploads_path: "DATADIR/uploads"
## Captcha ##
# See docs/CAPTCHA_SETUP for full details of configuring this.
-# This Home Server's ReCAPTCHA public key.
+# This homeserver's ReCAPTCHA public key.
#
#recaptcha_public_key: "YOUR_PUBLIC_KEY"
-# This Home Server's ReCAPTCHA private key.
+# This homeserver's ReCAPTCHA private key.
#
#recaptcha_private_key: "YOUR_PRIVATE_KEY"
@@ -955,7 +955,7 @@ uploads_path: "DATADIR/uploads"
# If a delegate is specified, the config option public_baseurl must also be filled out.
#
account_threepid_delegates:
- #email: https://example.com # Delegate email sending to example.org
+ #email: https://example.com # Delegate email sending to example.com
#msisdn: http://localhost:8090 # Delegate SMS sending to this local process
# Users who register on this homeserver will automatically be joined
@@ -1270,7 +1270,7 @@ password_config:
# smtp_user: "exampleusername"
# smtp_pass: "examplepassword"
# require_transport_security: false
-# notif_from: "Your Friendly %(app)s Home Server <noreply@example.com>"
+# notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>"
# app_name: Matrix
#
# # Enable email notifications by default
diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md
index e099d8a8..ba9e874d 100644
--- a/docs/tcp_replication.md
+++ b/docs/tcp_replication.md
@@ -199,7 +199,20 @@ client (C):
#### REPLICATE (C)
- Asks the server to replicate a given stream
+Asks the server to replicate a given stream. The syntax is:
+
+```
+ REPLICATE <stream_name> <token>
+```
+
+Where `<token>` may be either:
+ * a numeric stream_id to stream updates since (exclusive)
+ * `NOW` to stream all subsequent updates.
+
+The `<stream_name>` is the name of a replication stream to subscribe
+to (see [here](../synapse/replication/tcp/streams/_base.py) for a list
+of streams). It can also be `ALL` to subscribe to all known streams,
+in which case the `<token>` must be set to `NOW`.
#### USER_SYNC (C)
diff --git a/mypy.ini b/mypy.ini
index ffadaddc..1d77c0ec 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -1,8 +1,11 @@
[mypy]
-namespace_packages=True
-plugins=mypy_zope:plugin
-follow_imports=skip
-mypy_path=stubs
+namespace_packages = True
+plugins = mypy_zope:plugin
+follow_imports = normal
+check_untyped_defs = True
+show_error_codes = True
+show_traceback = True
+mypy_path = stubs
[mypy-zope]
ignore_missing_imports = True
diff --git a/scripts-dev/build_debian_packages b/scripts-dev/build_debian_packages
index 93305ee9..84eaec6a 100755
--- a/scripts-dev/build_debian_packages
+++ b/scripts-dev/build_debian_packages
@@ -20,11 +20,13 @@ from concurrent.futures import ThreadPoolExecutor
DISTS = (
"debian:stretch",
"debian:buster",
+ "debian:bullseye",
"debian:sid",
"ubuntu:xenial",
"ubuntu:bionic",
"ubuntu:cosmic",
"ubuntu:disco",
+ "ubuntu:eoan",
)
DESC = '''\
diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh
index 02a2ca39..34c4854e 100755
--- a/scripts-dev/lint.sh
+++ b/scripts-dev/lint.sh
@@ -7,7 +7,15 @@
set -e
-isort -y -rc synapse tests scripts-dev scripts
-flake8 synapse tests
-python3 -m black synapse tests scripts-dev scripts
+if [ $# -ge 1 ]
+then
+ files=$*
+else
+ files="synapse tests scripts-dev scripts"
+fi
+
+echo "Linting these locations: $files"
+isort -y -rc $files
+flake8 $files
+python3 -m black $files
./scripts-dev/config-lint.sh
diff --git a/scripts-dev/update_database b/scripts-dev/update_database
new file mode 100755
index 00000000..27a1ad1e
--- /dev/null
+++ b/scripts-dev/update_database
@@ -0,0 +1,124 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import sys
+
+import yaml
+
+from twisted.internet import defer, reactor
+
+from synapse.config.homeserver import HomeServerConfig
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.server import HomeServer
+from synapse.storage import DataStore
+from synapse.storage.engines import create_engine
+from synapse.storage.prepare_database import prepare_database
+
+logger = logging.getLogger("update_database")
+
+
+class MockHomeserver(HomeServer):
+ DATASTORE_CLASS = DataStore
+
+ def __init__(self, config, database_engine, db_conn, **kwargs):
+ super(MockHomeserver, self).__init__(
+ config.server_name,
+ reactor=reactor,
+ config=config,
+ database_engine=database_engine,
+ **kwargs
+ )
+
+ self.database_engine = database_engine
+ self.db_conn = db_conn
+
+ def get_db_conn(self):
+ return self.db_conn
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description=(
+ "Updates a synapse database to the latest schema and runs background updates"
+ " on it."
+ )
+ )
+ parser.add_argument("-v", action='store_true')
+ parser.add_argument(
+ "--database-config",
+ type=argparse.FileType('r'),
+ required=True,
+ help="A database config file for either a SQLite3 database or a PostgreSQL one.",
+ )
+
+ args = parser.parse_args()
+
+ logging_config = {
+ "level": logging.DEBUG if args.v else logging.INFO,
+ "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
+ }
+
+ logging.basicConfig(**logging_config)
+
+ # Load, process and sanity-check the config.
+ hs_config = yaml.safe_load(args.database_config)
+
+ if "database" not in hs_config:
+ sys.stderr.write("The configuration file must have a 'database' section.\n")
+ sys.exit(4)
+
+ config = HomeServerConfig()
+ config.parse_config_dict(hs_config, "", "")
+
+ # Create the database engine and a connection to it.
+ database_engine = create_engine(config.database_config)
+ db_conn = database_engine.module.connect(
+ **{
+ k: v
+ for k, v in config.database_config.get("args", {}).items()
+ if not k.startswith("cp_")
+ }
+ )
+
+ # Update the database to the latest schema.
+ prepare_database(db_conn, database_engine, config=config)
+ db_conn.commit()
+
+ # Instantiate and initialise the homeserver object.
+ hs = MockHomeserver(
+ config,
+ database_engine,
+ db_conn,
+ db_config=config.database_config,
+ )
+ # setup instantiates the store within the homeserver object.
+ hs.setup()
+ store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def run_background_updates():
+ yield store.run_background_updates(sleep=False)
+ # Stop the reactor to exit the script once every background update is run.
+ reactor.stop()
+
+ # Apply all background updates on the database.
+ reactor.callWhenRunning(lambda: run_as_background_process(
+ "background_updates", run_background_updates
+ ))
+
+ reactor.run()
diff --git a/scripts/move_remote_media_to_new_store.py b/scripts/move_remote_media_to_new_store.py
index 12747c60..b5b63933 100755
--- a/scripts/move_remote_media_to_new_store.py
+++ b/scripts/move_remote_media_to_new_store.py
@@ -72,7 +72,7 @@ def move_media(origin_server, file_id, src_paths, dest_paths):
# check that the original exists
original_file = src_paths.remote_media_filepath(origin_server, file_id)
if not os.path.exists(original_file):
- logger.warn(
+ logger.warning(
"Original for %s/%s (%s) does not exist",
origin_server,
file_id,
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 54faed1e..0d332168 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -157,7 +157,7 @@ class Store(
)
except self.database_engine.module.DatabaseError as e:
if self.database_engine.is_deadlock(e):
- logger.warn("[TXN DEADLOCK] {%s} %d/%d", desc, i, N)
+ logger.warning("[TXN DEADLOCK] {%s} %d/%d", desc, i, N)
if i < N:
i += 1
conn.rollback()
@@ -432,7 +432,7 @@ class Porter(object):
for row in rows:
d = dict(zip(headers, row))
if "\0" in d['value']:
- logger.warn('dropping search row %s', d)
+ logger.warning('dropping search row %s', d)
else:
rows_dict.append(d)
@@ -647,7 +647,7 @@ class Porter(object):
if isinstance(col, bytes):
return bytearray(col)
elif isinstance(col, string_types) and "\0" in col:
- logger.warn(
+ logger.warning(
"DROPPING ROW: NUL value in table %s col %s: %r",
table,
headers[j],
diff --git a/snap/snapcraft.yaml b/snap/snapcraft.yaml
index 1f7df71d..9e644e85 100644
--- a/snap/snapcraft.yaml
+++ b/snap/snapcraft.yaml
@@ -20,3 +20,23 @@ parts:
source: .
plugin: python
python-version: python3
+ python-packages:
+ - '.[all]'
+ build-packages:
+ - libffi-dev
+ - libturbojpeg0-dev
+ - libssl-dev
+ - libxslt1-dev
+ - libpq-dev
+ - zlib1g-dev
+ stage-packages:
+ - libasn1-8-heimdal
+ - libgssapi3-heimdal
+ - libhcrypto4-heimdal
+ - libheimbase1-heimdal
+ - libheimntlm0-heimdal
+ - libhx509-5-heimdal
+ - libkrb5-26-heimdal
+ - libldap-2.4-2
+ - libpq5
+ - libsasl2-2
diff --git a/synapse/__init__.py b/synapse/__init__.py
index ec16f54a..f99de2f3 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-""" This is a reference implementation of a Matrix home server.
+""" This is a reference implementation of a Matrix homeserver.
"""
import os
@@ -36,7 +36,7 @@ try:
except ImportError:
pass
-__version__ = "1.5.1"
+__version__ = "1.6.1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
index bdcd915b..d528450c 100644
--- a/synapse/_scripts/register_new_matrix_user.py
+++ b/synapse/_scripts/register_new_matrix_user.py
@@ -144,8 +144,8 @@ def main():
logging.captureWarnings(True)
parser = argparse.ArgumentParser(
- description="Used to register new users with a given home server when"
- " registration has been disabled. The home server must be"
+ description="Used to register new users with a given homeserver when"
+ " registration has been disabled. The homeserver must be"
" configured with the 'registration_shared_secret' option"
" set."
)
@@ -202,7 +202,7 @@ def main():
"server_url",
default="https://localhost:8448",
nargs="?",
- help="URL to use to talk to the home server. Defaults to "
+ help="URL to use to talk to the homeserver. Defaults to "
" 'https://localhost:8448'.",
)
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 53f3bb0f..5d0b7d28 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -497,7 +497,7 @@ class Auth(object):
token = self.get_access_token_from_request(request)
service = self.store.get_app_service_by_token(token)
if not service:
- logger.warn("Unrecognised appservice access token.")
+ logger.warning("Unrecognised appservice access token.")
raise InvalidClientTokenError()
request.authenticated_entity = service.sender
return defer.succeed(service)
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 31219667..49c4b850 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -138,3 +138,10 @@ class LimitBlockingTypes(object):
MONTHLY_ACTIVE_USER = "monthly_active_user"
HS_DISABLED = "hs_disabled"
+
+
+class EventContentFields(object):
+ """Fields found in events' content, regardless of type."""
+
+ # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326
+ LABELS = "org.matrix.labels"
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index cca92c34..5853a54c 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -457,7 +457,7 @@ def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
class FederationError(RuntimeError):
- """ This class is used to inform remote home servers about erroneous
+ """ This class is used to inform remote homeservers about erroneous
PDUs they sent us.
FATAL: The remote server could not interpret the source event.
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 9f06556b..bec13f08 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -20,6 +20,7 @@ from jsonschema import FormatChecker
from twisted.internet import defer
+from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError
from synapse.storage.presence import UserPresenceState
from synapse.types import RoomID, UserID
@@ -66,6 +67,10 @@ ROOM_EVENT_FILTER_SCHEMA = {
"contains_url": {"type": "boolean"},
"lazy_load_members": {"type": "boolean"},
"include_redundant_members": {"type": "boolean"},
+ # Include or exclude events with the provided labels.
+ # cf https://github.com/matrix-org/matrix-doc/pull/2326
+ "org.matrix.labels": {"type": "array", "items": {"type": "string"}},
+ "org.matrix.not_labels": {"type": "array", "items": {"type": "string"}},
},
}
@@ -259,6 +264,9 @@ class Filter(object):
self.contains_url = self.filter_json.get("contains_url", None)
+ self.labels = self.filter_json.get("org.matrix.labels", None)
+ self.not_labels = self.filter_json.get("org.matrix.not_labels", [])
+
def filters_all_types(self):
return "*" in self.not_types
@@ -282,6 +290,7 @@ class Filter(object):
room_id = None
ev_type = "m.presence"
contains_url = False
+ labels = []
else:
sender = event.get("sender", None)
if not sender:
@@ -300,10 +309,11 @@ class Filter(object):
content = event.get("content", {})
# check if there is a string url field in the content for filtering purposes
contains_url = isinstance(content.get("url"), text_type)
+ labels = content.get(EventContentFields.LABELS, [])
- return self.check_fields(room_id, sender, ev_type, contains_url)
+ return self.check_fields(room_id, sender, ev_type, labels, contains_url)
- def check_fields(self, room_id, sender, event_type, contains_url):
+ def check_fields(self, room_id, sender, event_type, labels, contains_url):
"""Checks whether the filter matches the given event fields.
Returns:
@@ -313,6 +323,7 @@ class Filter(object):
"rooms": lambda v: room_id == v,
"senders": lambda v: sender == v,
"types": lambda v: _matches_wildcard(event_type, v),
+ "labels": lambda v: v in labels,
}
for name, match_func in literal_keys.items():
diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py
index d877c778..a01bac29 100644
--- a/synapse/app/__init__.py
+++ b/synapse/app/__init__.py
@@ -44,6 +44,8 @@ def check_bind_error(e, address, bind_addresses):
bind_addresses (list): Addresses on which the service listens.
"""
if address == "0.0.0.0" and "::" in bind_addresses:
- logger.warn("Failed to listen on 0.0.0.0, continuing because listening on [::]")
+ logger.warning(
+ "Failed to listen on 0.0.0.0, continuing because listening on [::]"
+ )
else:
raise e
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index 767b87d2..02b900f3 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -94,7 +94,7 @@ class AppserviceServer(HomeServer):
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(
+ logger.warning(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -103,7 +103,7 @@ class AppserviceServer(HomeServer):
else:
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index dbcc414c..dadb487d 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -153,7 +153,7 @@ class ClientReaderServer(HomeServer):
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(
+ logger.warning(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -162,7 +162,7 @@ class ClientReaderServer(HomeServer):
else:
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
index f20d810e..d110599a 100644
--- a/synapse/app/event_creator.py
+++ b/synapse/app/event_creator.py
@@ -147,7 +147,7 @@ class EventCreatorServer(HomeServer):
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(
+ logger.warning(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -156,7 +156,7 @@ class EventCreatorServer(HomeServer):
else:
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index 1ef027a8..418c0862 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -132,7 +132,7 @@ class FederationReaderServer(HomeServer):
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(
+ logger.warning(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -141,7 +141,7 @@ class FederationReaderServer(HomeServer):
else:
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 04fbb407..139221ad 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -123,7 +123,7 @@ class FederationSenderServer(HomeServer):
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(
+ logger.warning(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -132,7 +132,7 @@ class FederationSenderServer(HomeServer):
else:
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index 9504bfbc..e647459d 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -204,7 +204,7 @@ class FrontendProxyServer(HomeServer):
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(
+ logger.warning(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -213,7 +213,7 @@ class FrontendProxyServer(HomeServer):
else:
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index eb54f568..883b3fb7 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -19,12 +19,13 @@ from __future__ import print_function
import gc
import logging
+import math
import os
+import resource
import sys
from six import iteritems
-import psutil
from prometheus_client import Gauge
from twisted.application import service
@@ -282,7 +283,7 @@ class SynapseHomeServer(HomeServer):
reactor.addSystemEventTrigger("before", "shutdown", s.stopListening)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(
+ logger.warning(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -291,7 +292,7 @@ class SynapseHomeServer(HomeServer):
else:
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unrecognized listener type: %s", listener["type"])
def run_startup_checks(self, db_conn, database_engine):
all_users_native = are_all_users_on_domain(
@@ -471,6 +472,87 @@ class SynapseService(service.Service):
return self._port.stopListening()
+# Contains the list of processes we will be monitoring
+# currently either 0 or 1
+_stats_process = []
+
+
+@defer.inlineCallbacks
+def phone_stats_home(hs, stats, stats_process=_stats_process):
+ logger.info("Gathering stats for reporting")
+ now = int(hs.get_clock().time())
+ uptime = int(now - hs.start_time)
+ if uptime < 0:
+ uptime = 0
+
+ stats["homeserver"] = hs.config.server_name
+ stats["server_context"] = hs.config.server_context
+ stats["timestamp"] = now
+ stats["uptime_seconds"] = uptime
+ version = sys.version_info
+ stats["python_version"] = "{}.{}.{}".format(
+ version.major, version.minor, version.micro
+ )
+ stats["total_users"] = yield hs.get_datastore().count_all_users()
+
+ total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users()
+ stats["total_nonbridged_users"] = total_nonbridged_users
+
+ daily_user_type_results = yield hs.get_datastore().count_daily_user_type()
+ for name, count in iteritems(daily_user_type_results):
+ stats["daily_user_type_" + name] = count
+
+ room_count = yield hs.get_datastore().get_room_count()
+ stats["total_room_count"] = room_count
+
+ stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
+ stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users()
+ stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms()
+ stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
+
+ r30_results = yield hs.get_datastore().count_r30_users()
+ for name, count in iteritems(r30_results):
+ stats["r30_users_" + name] = count
+
+ daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
+ stats["daily_sent_messages"] = daily_sent_messages
+ stats["cache_factor"] = CACHE_SIZE_FACTOR
+ stats["event_cache_size"] = hs.config.event_cache_size
+
+ #
+ # Performance statistics
+ #
+ old = stats_process[0]
+ new = (now, resource.getrusage(resource.RUSAGE_SELF))
+ stats_process[0] = new
+
+ # Get RSS in bytes
+ stats["memory_rss"] = new[1].ru_maxrss
+
+ # Get CPU time in % of a single core, not % of all cores
+ used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - (
+ old[1].ru_utime + old[1].ru_stime
+ )
+ if used_cpu_time == 0 or new[0] == old[0]:
+ stats["cpu_average"] = 0
+ else:
+ stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100)
+
+ #
+ # Database version
+ #
+
+ stats["database_engine"] = hs.get_datastore().database_engine_name
+ stats["database_server_version"] = hs.get_datastore().get_server_version()
+ logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
+ try:
+ yield hs.get_proxied_http_client().put_json(
+ hs.config.report_stats_endpoint, stats
+ )
+ except Exception as e:
+ logger.warning("Error reporting stats: %s", e)
+
+
def run(hs):
PROFILE_SYNAPSE = False
if PROFILE_SYNAPSE:
@@ -497,91 +579,19 @@ def run(hs):
reactor.run = profile(reactor.run)
clock = hs.get_clock()
- start_time = clock.time()
stats = {}
- # Contains the list of processes we will be monitoring
- # currently either 0 or 1
- stats_process = []
+ def performance_stats_init():
+ _stats_process.clear()
+ _stats_process.append(
+ (int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF))
+ )
def start_phone_stats_home():
- return run_as_background_process("phone_stats_home", phone_stats_home)
-
- @defer.inlineCallbacks
- def phone_stats_home():
- logger.info("Gathering stats for reporting")
- now = int(hs.get_clock().time())
- uptime = int(now - start_time)
- if uptime < 0:
- uptime = 0
-
- stats["homeserver"] = hs.config.server_name
- stats["server_context"] = hs.config.server_context
- stats["timestamp"] = now
- stats["uptime_seconds"] = uptime
- version = sys.version_info
- stats["python_version"] = "{}.{}.{}".format(
- version.major, version.minor, version.micro
- )
- stats["total_users"] = yield hs.get_datastore().count_all_users()
-
- total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users()
- stats["total_nonbridged_users"] = total_nonbridged_users
-
- daily_user_type_results = yield hs.get_datastore().count_daily_user_type()
- for name, count in iteritems(daily_user_type_results):
- stats["daily_user_type_" + name] = count
-
- room_count = yield hs.get_datastore().get_room_count()
- stats["total_room_count"] = room_count
-
- stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
- stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users()
- stats[
- "daily_active_rooms"
- ] = yield hs.get_datastore().count_daily_active_rooms()
- stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
-
- r30_results = yield hs.get_datastore().count_r30_users()
- for name, count in iteritems(r30_results):
- stats["r30_users_" + name] = count
-
- daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
- stats["daily_sent_messages"] = daily_sent_messages
- stats["cache_factor"] = CACHE_SIZE_FACTOR
- stats["event_cache_size"] = hs.config.event_cache_size
-
- if len(stats_process) > 0:
- stats["memory_rss"] = 0
- stats["cpu_average"] = 0
- for process in stats_process:
- stats["memory_rss"] += process.memory_info().rss
- stats["cpu_average"] += int(process.cpu_percent(interval=None))
-
- stats["database_engine"] = hs.get_datastore().database_engine_name
- stats["database_server_version"] = hs.get_datastore().get_server_version()
- logger.info(
- "Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)
+ return run_as_background_process(
+ "phone_stats_home", phone_stats_home, hs, stats
)
- try:
- yield hs.get_simple_http_client().put_json(
- hs.config.report_stats_endpoint, stats
- )
- except Exception as e:
- logger.warn("Error reporting stats: %s", e)
-
- def performance_stats_init():
- try:
- process = psutil.Process()
- # Ensure we can fetch both, and make the initial request for cpu_percent
- # so the next request will use this as the initial point.
- process.memory_info().rss
- process.cpu_percent(interval=None)
- logger.info("report_stats can use psutil")
- stats_process.append(process)
- except (AttributeError):
- logger.warning("Unable to read memory/cpu stats. Disabling reporting.")
def generate_user_daily_visit_stats():
return run_as_background_process(
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index 6bc7202f..2c6dd3ef 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -120,7 +120,7 @@ class MediaRepositoryServer(HomeServer):
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(
+ logger.warning(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -129,7 +129,7 @@ class MediaRepositoryServer(HomeServer):
else:
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index d84732ee..01a5ffc3 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -114,7 +114,7 @@ class PusherServer(HomeServer):
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(
+ logger.warning(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -123,7 +123,7 @@ class PusherServer(HomeServer):
else:
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 6a7e2fa7..b14da09f 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -326,7 +326,7 @@ class SynchrotronServer(HomeServer):
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(
+ logger.warning(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -335,7 +335,7 @@ class SynchrotronServer(HomeServer):
else:
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index a5d6dc79..6cb10031 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -150,7 +150,7 @@ class UserDirectoryServer(HomeServer):
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(
+ logger.warning(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -159,7 +159,7 @@ class UserDirectoryServer(HomeServer):
else:
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 33b35794..aea3985a 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -94,7 +94,9 @@ class ApplicationService(object):
ip_range_whitelist=None,
):
self.token = token
- self.url = url
+ self.url = (
+ url.rstrip("/") if isinstance(url, str) else None
+ ) # url must not end with a slash
self.hs_token = hs_token
self.sender = sender
self.server_name = hostname
diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py
index 44bd5c67..f0171bb5 100644
--- a/synapse/config/captcha.py
+++ b/synapse/config/captcha.py
@@ -35,11 +35,11 @@ class CaptchaConfig(Config):
## Captcha ##
# See docs/CAPTCHA_SETUP for full details of configuring this.
- # This Home Server's ReCAPTCHA public key.
+ # This homeserver's ReCAPTCHA public key.
#
#recaptcha_public_key: "YOUR_PUBLIC_KEY"
- # This Home Server's ReCAPTCHA private key.
+ # This homeserver's ReCAPTCHA private key.
#
#recaptcha_private_key: "YOUR_PRIVATE_KEY"
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 39e7a1dd..43fad0bf 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -305,7 +305,7 @@ class EmailConfig(Config):
# smtp_user: "exampleusername"
# smtp_pass: "examplepassword"
# require_transport_security: false
- # notif_from: "Your Friendly %(app)s Home Server <noreply@example.com>"
+ # notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>"
# app_name: Matrix
#
# # Enable email notifications by default
diff --git a/synapse/config/key.py b/synapse/config/key.py
index ec5d430a..52ff1b26 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -125,7 +125,7 @@ class KeyConfig(Config):
# if neither trusted_key_servers nor perspectives are given, use the default.
if "perspectives" not in config and "trusted_key_servers" not in config:
- logger.warn(TRUSTED_KEY_SERVER_NOT_CONFIGURED_WARN)
+ logger.warning(TRUSTED_KEY_SERVER_NOT_CONFIGURED_WARN)
key_servers = [{"server_name": "matrix.org"}]
else:
key_servers = config.get("trusted_key_servers", [])
@@ -156,7 +156,7 @@ class KeyConfig(Config):
if not self.macaroon_secret_key:
# Unfortunately, there are people out there that don't have this
# set. Lets just be "nice" and derive one from their secret key.
- logger.warn("Config is missing macaroon_secret_key")
+ logger.warning("Config is missing macaroon_secret_key")
seed = bytes(self.signing_key[0])
self.macaroon_secret_key = hashlib.sha256(seed).digest()
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index be92e33f..75bb9047 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -182,7 +182,7 @@ def _reload_stdlib_logging(*args, log_config=None):
logger = logging.getLogger("")
if not log_config:
- logger.warn("Reloaded a blank config?")
+ logger.warning("Reloaded a blank config?")
logging.config.dictConfig(log_config)
@@ -234,8 +234,8 @@ def setup_logging(
# make sure that the first thing we log is a thing we can grep backwards
# for
- logging.warn("***** STARTING SERVER *****")
- logging.warn("Server %s version %s", sys.argv[0], get_version_string(synapse))
+ logging.warning("***** STARTING SERVER *****")
+ logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse))
logging.info("Server hostname: %s", config.server_name)
return logger
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index ab41623b..1f6dac69 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -300,7 +300,7 @@ class RegistrationConfig(Config):
# If a delegate is specified, the config option public_baseurl must also be filled out.
#
account_threepid_delegates:
- #email: https://example.com # Delegate email sending to example.org
+ #email: https://example.com # Delegate email sending to example.com
#msisdn: http://localhost:8090 # Delegate SMS sending to this local process
# Users who register on this homeserver will automatically be joined
diff --git a/synapse/config/server.py b/synapse/config/server.py
index d556df30..00d01c43 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -41,7 +41,7 @@ logger = logging.Logger(__name__)
# in the list.
DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"]
-DEFAULT_ROOM_VERSION = "4"
+DEFAULT_ROOM_VERSION = "5"
ROOM_COMPLEXITY_TOO_GREAT = (
"Your homeserver is unable to join rooms this large or complex. "
@@ -721,7 +721,7 @@ class ServerConfig(Config):
# Used by phonehome stats to group together related servers.
#server_context: context
- # Resource-constrained Homeserver Settings
+ # Resource-constrained homeserver Settings
#
# If limit_remote_rooms.enabled is True, the room complexity will be
# checked before a user joins a new remote room. If it is above
@@ -781,7 +781,7 @@ class ServerConfig(Config):
"--daemonize",
action="store_true",
default=None,
- help="Daemonize the home server",
+ help="Daemonize the homeserver",
)
server_group.add_argument(
"--print-pidfile",
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index 694fb2c8..ccaa8a99 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -125,9 +125,11 @@ def compute_event_signature(event_dict, signature_name, signing_key):
redact_json = prune_event_dict(event_dict)
redact_json.pop("age_ts", None)
redact_json.pop("unsigned", None)
- logger.debug("Signing event: %s", encode_canonical_json(redact_json))
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug("Signing event: %s", encode_canonical_json(redact_json))
redact_json = sign_json(redact_json, signature_name, signing_key)
- logger.debug("Signed event: %s", encode_canonical_json(redact_json))
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug("Signed event: %s", encode_canonical_json(redact_json))
return redact_json["signatures"]
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index e7b72254..ec3243b2 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -77,7 +77,7 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
if auth_events is None:
# Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now)
- logger.warn("Trusting event: %s", event.event_id)
+ logger.warning("Trusting event: %s", event.event_id)
return
if event.type == EventTypes.Create:
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index acbcbeec..64e898f4 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -12,104 +12,125 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Dict, Optional, Tuple, Union
from six import iteritems
+import attr
from frozendict import frozendict
from twisted.internet import defer
+from synapse.appservice import ApplicationService
from synapse.logging.context import make_deferred_yieldable, run_in_background
-class EventContext(object):
+@attr.s(slots=True)
+class EventContext:
"""
+ Holds information relevant to persisting an event
+
Attributes:
- state_group (int|None): state group id, if the state has been stored
- as a state group. This is usually only None if e.g. the event is
- an outlier.
- rejected (bool|str): A rejection reason if the event was rejected, else
- False
-
- push_actions (list[(str, list[object])]): list of (user_id, actions)
- tuples
-
- prev_group (int): Previously persisted state group. ``None`` for an
- outlier.
- delta_ids (dict[(str, str), str]): Delta from ``prev_group``.
- (type, state_key) -> event_id. ``None`` for an outlier.
-
- prev_state_events (?): XXX: is this ever set to anything other than
- the empty list?
-
- _current_state_ids (dict[(str, str), str]|None):
- The current state map including the current event. None if outlier
- or we haven't fetched the state from DB yet.
- (type, state_key) -> event_id
+ rejected: A rejection reason if the event was rejected, else False
+
+ _state_group: The ID of the state group for this event. Note that state events
+ are persisted with a state group which includes the new event, so this is
+ effectively the state *after* the event in question.
+
+ For a *rejected* state event, where the state of the rejected event is
+ ignored, this state_group should never make it into the
+ event_to_state_groups table. Indeed, inspecting this value for a rejected
+ state event is almost certainly incorrect.
+
+ For an outlier, where we don't have the state at the event, this will be
+ None.
+
+ Note that this is a private attribute: it should be accessed via
+ the ``state_group`` property.
+
+ state_group_before_event: The ID of the state group representing the state
+ of the room before this event.
+
+ If this is a non-state event, this will be the same as ``state_group``. If
+ it's a state event, it will be the same as ``prev_group``.
+
+ If ``state_group`` is None (ie, the event is an outlier),
+ ``state_group_before_event`` will always also be ``None``.
+
+ prev_group: If it is known, ``state_group``'s prev_group. Note that this being
+ None does not necessarily mean that ``state_group`` does not have
+ a prev_group!
+
+ If the event is a state event, this is normally the same as ``prev_group``.
+
+ If ``state_group`` is None (ie, the event is an outlier), ``prev_group``
+ will always also be ``None``.
+
+ Note that this *not* (necessarily) the state group associated with
+ ``_prev_state_ids``.
+
+ delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group``
+ and ``state_group``.
+
+ app_service: If this event is being sent by a (local) application service, that
+ app service.
+
+ _current_state_ids: The room state map, including this event - ie, the state
+ in ``state_group``.
- _prev_state_ids (dict[(str, str), str]|None):
- The current state map excluding the current event. None if outlier
- or we haven't fetched the state from DB yet.
(type, state_key) -> event_id
- _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
- been calculated. None if we haven't started calculating yet
+ FIXME: what is this for an outlier? it seems ill-defined. It seems like
+ it could be either {}, or the state we were given by the remote
+ server, depending on $THINGS
- _event_type (str): The type of the event the context is associated with.
- Only set when state has not been fetched yet.
+ Note that this is a private attribute: it should be accessed via
+ ``get_current_state_ids``. _AsyncEventContext impl calculates this
+ on-demand: it will be None until that happens.
- _event_state_key (str|None): The state_key of the event the context is
- associated with. Only set when state has not been fetched yet.
+ _prev_state_ids: The room state map, excluding this event - ie, the state
+ in ``state_group_before_event``. For a non-state
+ event, this will be the same as _current_state_events.
- _prev_state_id (str|None): If the event associated with the context is
- a state event, then `_prev_state_id` is the event_id of the state
- that was replaced.
- Only set when state has not been fetched yet.
- """
+ Note that it is a completely different thing to prev_group!
- __slots__ = [
- "state_group",
- "rejected",
- "prev_group",
- "delta_ids",
- "prev_state_events",
- "app_service",
- "_current_state_ids",
- "_prev_state_ids",
- "_prev_state_id",
- "_event_type",
- "_event_state_key",
- "_fetching_state_deferred",
- ]
-
- def __init__(self):
- self.prev_state_events = []
- self.rejected = False
- self.app_service = None
+ (type, state_key) -> event_id
- @staticmethod
- def with_state(
- state_group, current_state_ids, prev_state_ids, prev_group=None, delta_ids=None
- ):
- context = EventContext()
+ FIXME: again, what is this for an outlier?
- # The current state including the current event
- context._current_state_ids = current_state_ids
- # The current state excluding the current event
- context._prev_state_ids = prev_state_ids
- context.state_group = state_group
+ As with _current_state_ids, this is a private attribute. It should be
+ accessed via get_prev_state_ids.
+ """
- context._prev_state_id = None
- context._event_type = None
- context._event_state_key = None
- context._fetching_state_deferred = defer.succeed(None)
+ rejected = attr.ib(default=False, type=Union[bool, str])
+ _state_group = attr.ib(default=None, type=Optional[int])
+ state_group_before_event = attr.ib(default=None, type=Optional[int])
+ prev_group = attr.ib(default=None, type=Optional[int])
+ delta_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]])
+ app_service = attr.ib(default=None, type=Optional[ApplicationService])
- # A previously persisted state group and a delta between that
- # and this state.
- context.prev_group = prev_group
- context.delta_ids = delta_ids
+ _current_state_ids = attr.ib(
+ default=None, type=Optional[Dict[Tuple[str, str], str]]
+ )
+ _prev_state_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]])
- return context
+ @staticmethod
+ def with_state(
+ state_group,
+ state_group_before_event,
+ current_state_ids,
+ prev_state_ids,
+ prev_group=None,
+ delta_ids=None,
+ ):
+ return EventContext(
+ current_state_ids=current_state_ids,
+ prev_state_ids=prev_state_ids,
+ state_group=state_group,
+ state_group_before_event=state_group_before_event,
+ prev_group=prev_group,
+ delta_ids=delta_ids,
+ )
@defer.inlineCallbacks
def serialize(self, event, store):
@@ -137,11 +158,11 @@ class EventContext(object):
"prev_state_id": prev_state_id,
"event_type": event.type,
"event_state_key": event.state_key if event.is_state() else None,
- "state_group": self.state_group,
+ "state_group": self._state_group,
+ "state_group_before_event": self.state_group_before_event,
"rejected": self.rejected,
"prev_group": self.prev_group,
"delta_ids": _encode_state_dict(self.delta_ids),
- "prev_state_events": self.prev_state_events,
"app_service_id": self.app_service.id if self.app_service else None,
}
@@ -157,24 +178,18 @@ class EventContext(object):
Returns:
EventContext
"""
- context = EventContext()
-
- # We use the state_group and prev_state_id stuff to pull the
- # current_state_ids out of the DB and construct prev_state_ids.
- context._prev_state_id = input["prev_state_id"]
- context._event_type = input["event_type"]
- context._event_state_key = input["event_state_key"]
-
- context._current_state_ids = None
- context._prev_state_ids = None
- context._fetching_state_deferred = None
-
- context.state_group = input["state_group"]
- context.prev_group = input["prev_group"]
- context.delta_ids = _decode_state_dict(input["delta_ids"])
-
- context.rejected = input["rejected"]
- context.prev_state_events = input["prev_state_events"]
+ context = _AsyncEventContextImpl(
+ # We use the state_group and prev_state_id stuff to pull the
+ # current_state_ids out of the DB and construct prev_state_ids.
+ prev_state_id=input["prev_state_id"],
+ event_type=input["event_type"],
+ event_state_key=input["event_state_key"],
+ state_group=input["state_group"],
+ state_group_before_event=input["state_group_before_event"],
+ prev_group=input["prev_group"],
+ delta_ids=_decode_state_dict(input["delta_ids"]),
+ rejected=input["rejected"],
+ )
app_service_id = input["app_service_id"]
if app_service_id:
@@ -182,29 +197,52 @@ class EventContext(object):
return context
+ @property
+ def state_group(self) -> Optional[int]:
+ """The ID of the state group for this event.
+
+ Note that state events are persisted with a state group which includes the new
+ event, so this is effectively the state *after* the event in question.
+
+ For an outlier, where we don't have the state at the event, this will be None.
+
+ It is an error to access this for a rejected event, since rejected state should
+ not make it into the room state. Accessing this property will raise an exception
+ if ``rejected`` is set.
+ """
+ if self.rejected:
+ raise RuntimeError("Attempt to access state_group of rejected event")
+
+ return self._state_group
+
@defer.inlineCallbacks
def get_current_state_ids(self, store):
- """Gets the current state IDs
+ """
+ Gets the room state map, including this event - ie, the state in ``state_group``
+
+ It is an error to access this for a rejected event, since rejected state should
+ not make it into the room state. This method will raise an exception if
+ ``rejected`` is set.
Returns:
Deferred[dict[(str, str), str]|None]: Returns None if state_group
is None, which happens when the associated event is an outlier.
+
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
+ if self.rejected:
+ raise RuntimeError("Attempt to access state_ids of rejected event")
- if not self._fetching_state_deferred:
- self._fetching_state_deferred = run_in_background(
- self._fill_out_state, store
- )
-
- yield make_deferred_yieldable(self._fetching_state_deferred)
-
+ yield self._ensure_fetched(store)
return self._current_state_ids
@defer.inlineCallbacks
def get_prev_state_ids(self, store):
- """Gets the prev state IDs
+ """
+ Gets the room state map, excluding this event.
+
+ For a non-state event, this will be the same as get_current_state_ids().
Returns:
Deferred[dict[(str, str), str]|None]: Returns None if state_group
@@ -212,27 +250,64 @@ class EventContext(object):
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
-
- if not self._fetching_state_deferred:
- self._fetching_state_deferred = run_in_background(
- self._fill_out_state, store
- )
-
- yield make_deferred_yieldable(self._fetching_state_deferred)
-
+ yield self._ensure_fetched(store)
return self._prev_state_ids
def get_cached_current_state_ids(self):
"""Gets the current state IDs if we have them already cached.
+ It is an error to access this for a rejected event, since rejected state should
+ not make it into the room state. This method will raise an exception if
+ ``rejected`` is set.
+
Returns:
dict[(str, str), str]|None: Returns None if we haven't cached the
state or if state_group is None, which happens when the associated
event is an outlier.
"""
+ if self.rejected:
+ raise RuntimeError("Attempt to access state_ids of rejected event")
return self._current_state_ids
+ def _ensure_fetched(self, store):
+ return defer.succeed(None)
+
+
+@attr.s(slots=True)
+class _AsyncEventContextImpl(EventContext):
+ """
+ An implementation of EventContext which fetches _current_state_ids and
+ _prev_state_ids from the database on demand.
+
+ Attributes:
+
+ _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
+ been calculated. None if we haven't started calculating yet
+
+ _event_type (str): The type of the event the context is associated with.
+
+ _event_state_key (str): The state_key of the event the context is
+ associated with.
+
+ _prev_state_id (str|None): If the event associated with the context is
+ a state event, then `_prev_state_id` is the event_id of the state
+ that was replaced.
+ """
+
+ _prev_state_id = attr.ib(default=None)
+ _event_type = attr.ib(default=None)
+ _event_state_key = attr.ib(default=None)
+ _fetching_state_deferred = attr.ib(default=None)
+
+ def _ensure_fetched(self, store):
+ if not self._fetching_state_deferred:
+ self._fetching_state_deferred = run_in_background(
+ self._fill_out_state, store
+ )
+
+ return make_deferred_yieldable(self._fetching_state_deferred)
+
@defer.inlineCallbacks
def _fill_out_state(self, store):
"""Called to populate the _current_state_ids and _prev_state_ids
@@ -250,27 +325,6 @@ class EventContext(object):
else:
self._prev_state_ids = self._current_state_ids
- @defer.inlineCallbacks
- def update_state(
- self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids
- ):
- """Replace the state in the context
- """
-
- # We need to make sure we wait for any ongoing fetching of state
- # to complete so that the updated state doesn't get clobbered
- if self._fetching_state_deferred:
- yield make_deferred_yieldable(self._fetching_state_deferred)
-
- self.state_group = state_group
- self._prev_state_ids = prev_state_ids
- self.prev_group = prev_group
- self._current_state_ids = current_state_ids
- self.delta_ids = delta_ids
-
- # We need to ensure that that we've marked as having fetched the state
- self._fetching_state_deferred = defer.succeed(None)
-
def _encode_state_dict(state_dict):
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 129771f1..5a907718 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,6 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
+
+from synapse.spam_checker_api import SpamCheckerApi
+
class SpamChecker(object):
def __init__(self, hs):
@@ -26,7 +31,14 @@ class SpamChecker(object):
pass
if module is not None:
- self.spam_checker = module(config=config)
+ # Older spam checkers don't accept the `api` argument, so we
+ # try and detect support.
+ spam_args = inspect.getfullargspec(module)
+ if "api" in spam_args.args:
+ api = SpamCheckerApi(hs)
+ self.spam_checker = module(config=config, api=api)
+ else:
+ self.spam_checker = module(config=config)
def check_event_for_spam(self, event):
"""Checks if a given event is considered "spammy" by this server.
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 223aace0..0e221832 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -102,7 +102,7 @@ class FederationBase(object):
pass
if not res:
- logger.warn(
+ logger.warning(
"Failed to find copy of %s with valid signature", pdu.event_id
)
@@ -173,7 +173,7 @@ class FederationBase(object):
return redacted_event
if self.spam_checker.check_event_for_spam(pdu):
- logger.warn(
+ logger.warning(
"Event contains spam, redacting %s: %s",
pdu.event_id,
pdu.get_pdu_json(),
@@ -185,7 +185,7 @@ class FederationBase(object):
def errback(failure, pdu):
failure.trap(SynapseError)
with PreserveLoggingContext(ctx):
- logger.warn(
+ logger.warning(
"Signature check failed for %s: %s",
pdu.event_id,
failure.getErrorMessage(),
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 5b22a39b..27f6aff0 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -177,7 +177,7 @@ class FederationClient(FederationBase):
given destination server.
Args:
- dest (str): The remote home server to ask.
+ dest (str): The remote homeserver to ask.
room_id (str): The room_id to backfill.
limit (int): The maximum number of PDUs to return.
extremities (list): List of PDU id and origins of the first pdus
@@ -196,7 +196,7 @@ class FederationClient(FederationBase):
dest, room_id, extremities, limit
)
- logger.debug("backfill transaction_data=%s", repr(transaction_data))
+ logger.debug("backfill transaction_data=%r", transaction_data)
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
@@ -227,7 +227,7 @@ class FederationClient(FederationBase):
one succeeds.
Args:
- destinations (list): Which home servers to query
+ destinations (list): Which homeservers to query
event_id (str): event to fetch
room_version (str): version of the room
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
@@ -312,7 +312,7 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
@log_function
def get_state_for_room(self, destination, room_id, event_id):
- """Requests all of the room state at a given event from a remote home server.
+ """Requests all of the room state at a given event from a remote homeserver.
Args:
destination (str): The remote homeserver to query for the state.
@@ -522,12 +522,12 @@ class FederationClient(FederationBase):
res = yield callback(destination)
return res
except InvalidResponseError as e:
- logger.warn("Failed to %s via %s: %s", description, destination, e)
+ logger.warning("Failed to %s via %s: %s", description, destination, e)
except HttpResponseException as e:
if not 500 <= e.code < 600:
raise e.to_synapse_error()
else:
- logger.warn(
+ logger.warning(
"Failed to %s via %s: %i %s",
description,
destination,
@@ -535,7 +535,9 @@ class FederationClient(FederationBase):
e.args[0],
)
except Exception:
- logger.warn("Failed to %s via %s", description, destination, exc_info=1)
+ logger.warning(
+ "Failed to %s via %s", description, destination, exc_info=1
+ )
raise SynapseError(502, "Failed to %s via any server" % (description,))
@@ -553,7 +555,7 @@ class FederationClient(FederationBase):
Note that this does not append any events to any graphs.
Args:
- destinations (str): Candidate homeservers which are probably
+ destinations (Iterable[str]): Candidate homeservers which are probably
participating in the room.
room_id (str): The room in which the event will happen.
user_id (str): The user whose membership is being evented.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 5fc7c1d6..d942d77a 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -21,7 +21,6 @@ from six import iteritems
from canonicaljson import json
from prometheus_client import Counter
-from twisted.internet import defer
from twisted.internet.abstract import isIPAddress
from twisted.python import failure
@@ -86,14 +85,12 @@ class FederationServer(FederationBase):
# come in waves.
self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
- @defer.inlineCallbacks
- @log_function
- def on_backfill_request(self, origin, room_id, versions, limit):
- with (yield self._server_linearizer.queue((origin, room_id))):
+ async def on_backfill_request(self, origin, room_id, versions, limit):
+ with (await self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
- pdus = yield self.handler.on_backfill_request(
+ pdus = await self.handler.on_backfill_request(
origin, room_id, versions, limit
)
@@ -101,9 +98,7 @@ class FederationServer(FederationBase):
return 200, res
- @defer.inlineCallbacks
- @log_function
- def on_incoming_transaction(self, origin, transaction_data):
+ async def on_incoming_transaction(self, origin, transaction_data):
# keep this as early as possible to make the calculated origin ts as
# accurate as possible.
request_time = self._clock.time_msec()
@@ -118,18 +113,17 @@ class FederationServer(FederationBase):
# use a linearizer to ensure that we don't process the same transaction
# multiple times in parallel.
with (
- yield self._transaction_linearizer.queue(
+ await self._transaction_linearizer.queue(
(origin, transaction.transaction_id)
)
):
- result = yield self._handle_incoming_transaction(
+ result = await self._handle_incoming_transaction(
origin, transaction, request_time
)
return result
- @defer.inlineCallbacks
- def _handle_incoming_transaction(self, origin, transaction, request_time):
+ async def _handle_incoming_transaction(self, origin, transaction, request_time):
""" Process an incoming transaction and return the HTTP response
Args:
@@ -140,7 +134,7 @@ class FederationServer(FederationBase):
Returns:
Deferred[(int, object)]: http response code and body
"""
- response = yield self.transaction_actions.have_responded(origin, transaction)
+ response = await self.transaction_actions.have_responded(origin, transaction)
if response:
logger.debug(
@@ -151,7 +145,7 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id)
- # Reject if PDU count > 50 and EDU count > 100
+ # Reject if PDU count > 50 or EDU count > 100
if len(transaction.pdus) > 50 or (
hasattr(transaction, "edus") and len(transaction.edus) > 100
):
@@ -159,7 +153,7 @@ class FederationServer(FederationBase):
logger.info("Transaction PDU or EDU count too large. Returning 400")
response = {}
- yield self.transaction_actions.set_response(
+ await self.transaction_actions.set_response(
origin, transaction, 400, response
)
return 400, response
@@ -195,7 +189,7 @@ class FederationServer(FederationBase):
continue
try:
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version(room_id)
except NotFoundError:
logger.info("Ignoring PDU for unknown room_id: %s", room_id)
continue
@@ -221,13 +215,12 @@ class FederationServer(FederationBase):
# require callouts to other servers to fetch missing events), but
# impose a limit to avoid going too crazy with ram/cpu.
- @defer.inlineCallbacks
- def process_pdus_for_room(room_id):
+ async def process_pdus_for_room(room_id):
logger.debug("Processing PDUs for %s", room_id)
try:
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
except AuthError as e:
- logger.warn("Ignoring PDUs for room %s from banned server", room_id)
+ logger.warning("Ignoring PDUs for room %s from banned server", room_id)
for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id
pdu_results[event_id] = e.error_dict()
@@ -237,10 +230,10 @@ class FederationServer(FederationBase):
event_id = pdu.event_id
with nested_logging_context(event_id):
try:
- yield self._handle_received_pdu(origin, pdu)
+ await self._handle_received_pdu(origin, pdu)
pdu_results[event_id] = {}
except FederationError as e:
- logger.warn("Error handling PDU %s: %s", event_id, e)
+ logger.warning("Error handling PDU %s: %s", event_id, e)
pdu_results[event_id] = {"error": str(e)}
except Exception as e:
f = failure.Failure()
@@ -251,36 +244,33 @@ class FederationServer(FederationBase):
exc_info=(f.type, f.value, f.getTracebackObject()),
)
- yield concurrently_execute(
+ await concurrently_execute(
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
)
if hasattr(transaction, "edus"):
for edu in (Edu(**x) for x in transaction.edus):
- yield self.received_edu(origin, edu.edu_type, edu.content)
+ await self.received_edu(origin, edu.edu_type, edu.content)
response = {"pdus": pdu_results}
logger.debug("Returning: %s", str(response))
- yield self.transaction_actions.set_response(origin, transaction, 200, response)
+ await self.transaction_actions.set_response(origin, transaction, 200, response)
return 200, response
- @defer.inlineCallbacks
- def received_edu(self, origin, edu_type, content):
+ async def received_edu(self, origin, edu_type, content):
received_edus_counter.inc()
- yield self.registry.on_edu(edu_type, origin, content)
+ await self.registry.on_edu(edu_type, origin, content)
- @defer.inlineCallbacks
- @log_function
- def on_context_state_request(self, origin, room_id, event_id):
+ async def on_context_state_request(self, origin, room_id, event_id):
if not event_id:
raise NotImplementedError("Specify an event")
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
- in_room = yield self.auth.check_host_in_room(room_id, origin)
+ in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -289,8 +279,8 @@ class FederationServer(FederationBase):
# in the cache so we could return it without waiting for the linearizer
# - but that's non-trivial to get right, and anyway somewhat defeats
# the point of the linearizer.
- with (yield self._server_linearizer.queue((origin, room_id))):
- resp = yield self._state_resp_cache.wrap(
+ with (await self._server_linearizer.queue((origin, room_id))):
+ resp = await self._state_resp_cache.wrap(
(room_id, event_id),
self._on_context_state_request_compute,
room_id,
@@ -299,65 +289,60 @@ class FederationServer(FederationBase):
return 200, resp
- @defer.inlineCallbacks
- def on_state_ids_request(self, origin, room_id, event_id):
+ async def on_state_ids_request(self, origin, room_id, event_id):
if not event_id:
raise NotImplementedError("Specify an event")
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
- in_room = yield self.auth.check_host_in_room(room_id, origin)
+ in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
- state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id)
- auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids)
+ state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
+ auth_chain_ids = await self.store.get_auth_chain_ids(state_ids)
return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
- @defer.inlineCallbacks
- def _on_context_state_request_compute(self, room_id, event_id):
- pdus = yield self.handler.get_state_for_pdu(room_id, event_id)
- auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus])
+ async def _on_context_state_request_compute(self, room_id, event_id):
+ pdus = await self.handler.get_state_for_pdu(room_id, event_id)
+ auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus])
return {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
}
- @defer.inlineCallbacks
- @log_function
- def on_pdu_request(self, origin, event_id):
- pdu = yield self.handler.get_persisted_pdu(origin, event_id)
+ async def on_pdu_request(self, origin, event_id):
+ pdu = await self.handler.get_persisted_pdu(origin, event_id)
if pdu:
return 200, self._transaction_from_pdus([pdu]).get_dict()
else:
return 404, ""
- @defer.inlineCallbacks
- def on_query_request(self, query_type, args):
+ async def on_query_request(self, query_type, args):
received_queries_counter.labels(query_type).inc()
- resp = yield self.registry.on_query(query_type, args)
+ resp = await self.registry.on_query(query_type, args)
return 200, resp
- @defer.inlineCallbacks
- def on_make_join_request(self, origin, room_id, user_id, supported_versions):
+ async def on_make_join_request(self, origin, room_id, user_id, supported_versions):
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version(room_id)
if room_version not in supported_versions:
- logger.warn("Room version %s not in %s", room_version, supported_versions)
+ logger.warning(
+ "Room version %s not in %s", room_version, supported_versions
+ )
raise IncompatibleRoomVersionError(room_version=room_version)
- pdu = yield self.handler.on_make_join_request(origin, room_id, user_id)
+ pdu = await self.handler.on_make_join_request(origin, room_id, user_id)
time_now = self._clock.time_msec()
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
- @defer.inlineCallbacks
- def on_invite_request(self, origin, content, room_version):
+ async def on_invite_request(self, origin, content, room_version):
if room_version not in KNOWN_ROOM_VERSIONS:
raise SynapseError(
400,
@@ -369,28 +354,27 @@ class FederationServer(FederationBase):
pdu = event_from_pdu_json(content, format_ver)
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, pdu.room_id)
- pdu = yield self._check_sigs_and_hash(room_version, pdu)
- ret_pdu = yield self.handler.on_invite_request(origin, pdu)
+ await self.check_server_matches_acl(origin_host, pdu.room_id)
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
+ ret_pdu = await self.handler.on_invite_request(origin, pdu)
time_now = self._clock.time_msec()
return {"event": ret_pdu.get_pdu_json(time_now)}
- @defer.inlineCallbacks
- def on_send_join_request(self, origin, content, room_id):
+ async def on_send_join_request(self, origin, content, room_id):
logger.debug("on_send_join_request: content: %s", content)
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
pdu = event_from_pdu_json(content, format_ver)
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, pdu.room_id)
+ await self.check_server_matches_acl(origin_host, pdu.room_id)
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
- pdu = yield self._check_sigs_and_hash(room_version, pdu)
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
- res_pdus = yield self.handler.on_send_join_request(origin, pdu)
+ res_pdus = await self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
return (
200,
@@ -402,48 +386,44 @@ class FederationServer(FederationBase):
},
)
- @defer.inlineCallbacks
- def on_make_leave_request(self, origin, room_id, user_id):
+ async def on_make_leave_request(self, origin, room_id, user_id):
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
- pdu = yield self.handler.on_make_leave_request(origin, room_id, user_id)
+ await self.check_server_matches_acl(origin_host, room_id)
+ pdu = await self.handler.on_make_leave_request(origin, room_id, user_id)
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version(room_id)
time_now = self._clock.time_msec()
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
- @defer.inlineCallbacks
- def on_send_leave_request(self, origin, content, room_id):
+ async def on_send_leave_request(self, origin, content, room_id):
logger.debug("on_send_leave_request: content: %s", content)
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
pdu = event_from_pdu_json(content, format_ver)
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, pdu.room_id)
+ await self.check_server_matches_acl(origin_host, pdu.room_id)
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
- pdu = yield self._check_sigs_and_hash(room_version, pdu)
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
- yield self.handler.on_send_leave_request(origin, pdu)
+ await self.handler.on_send_leave_request(origin, pdu)
return 200, {}
- @defer.inlineCallbacks
- def on_event_auth(self, origin, room_id, event_id):
- with (yield self._server_linearizer.queue((origin, room_id))):
+ async def on_event_auth(self, origin, room_id, event_id):
+ with (await self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
time_now = self._clock.time_msec()
- auth_pdus = yield self.handler.on_event_auth(event_id)
+ auth_pdus = await self.handler.on_event_auth(event_id)
res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
return 200, res
- @defer.inlineCallbacks
- def on_query_auth_request(self, origin, content, room_id, event_id):
+ async def on_query_auth_request(self, origin, content, room_id, event_id):
"""
Content is a dict with keys::
auth_chain (list): A list of events that give the auth chain.
@@ -462,22 +442,22 @@ class FederationServer(FederationBase):
Returns:
Deferred: Results in `dict` with the same format as `content`
"""
- with (yield self._server_linearizer.queue((origin, room_id))):
+ with (await self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
auth_chain = [
event_from_pdu_json(e, format_ver) for e in content["auth_chain"]
]
- signed_auth = yield self._check_sigs_and_hash_and_fetch(
+ signed_auth = await self._check_sigs_and_hash_and_fetch(
origin, auth_chain, outlier=True, room_version=room_version
)
- ret = yield self.handler.on_query_auth(
+ ret = await self.handler.on_query_auth(
origin,
event_id,
room_id,
@@ -503,16 +483,14 @@ class FederationServer(FederationBase):
return self.on_query_request("user_devices", user_id)
@trace
- @defer.inlineCallbacks
- @log_function
- def on_claim_client_keys(self, origin, content):
+ async def on_claim_client_keys(self, origin, content):
query = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
query.append((user_id, device_id, algorithm))
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
- results = yield self.store.claim_e2e_one_time_keys(query)
+ results = await self.store.claim_e2e_one_time_keys(query)
json_result = {}
for user_id, device_keys in results.items():
@@ -536,14 +514,12 @@ class FederationServer(FederationBase):
return {"one_time_keys": json_result}
- @defer.inlineCallbacks
- @log_function
- def on_get_missing_events(
+ async def on_get_missing_events(
self, origin, room_id, earliest_events, latest_events, limit
):
- with (yield self._server_linearizer.queue((origin, room_id))):
+ with (await self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
logger.info(
"on_get_missing_events: earliest_events: %r, latest_events: %r,"
@@ -553,7 +529,7 @@ class FederationServer(FederationBase):
limit,
)
- missing_events = yield self.handler.on_get_missing_events(
+ missing_events = await self.handler.on_get_missing_events(
origin, room_id, earliest_events, latest_events, limit
)
@@ -586,8 +562,7 @@ class FederationServer(FederationBase):
destination=None,
)
- @defer.inlineCallbacks
- def _handle_received_pdu(self, origin, pdu):
+ async def _handle_received_pdu(self, origin, pdu):
""" Process a PDU received in a federation /send/ transaction.
If the event is invalid, then this method throws a FederationError.
@@ -640,37 +615,34 @@ class FederationServer(FederationBase):
logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)
# We've already checked that we know the room version by this point
- room_version = yield self.store.get_room_version(pdu.room_id)
+ room_version = await self.store.get_room_version(pdu.room_id)
# Check signature.
try:
- pdu = yield self._check_sigs_and_hash(room_version, pdu)
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
except SynapseError as e:
raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id)
- yield self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
+ await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name
- @defer.inlineCallbacks
- def exchange_third_party_invite(
+ async def exchange_third_party_invite(
self, sender_user_id, target_user_id, room_id, signed
):
- ret = yield self.handler.exchange_third_party_invite(
+ ret = await self.handler.exchange_third_party_invite(
sender_user_id, target_user_id, room_id, signed
)
return ret
- @defer.inlineCallbacks
- def on_exchange_third_party_invite_request(self, room_id, event_dict):
- ret = yield self.handler.on_exchange_third_party_invite_request(
+ async def on_exchange_third_party_invite_request(self, room_id, event_dict):
+ ret = await self.handler.on_exchange_third_party_invite_request(
room_id, event_dict
)
return ret
- @defer.inlineCallbacks
- def check_server_matches_acl(self, server_name, room_id):
+ async def check_server_matches_acl(self, server_name, room_id):
"""Check if the given server is allowed by the server ACLs in the room
Args:
@@ -680,13 +652,13 @@ class FederationServer(FederationBase):
Raises:
AuthError if the server does not match the ACL
"""
- state_ids = yield self.store.get_current_state_ids(room_id)
+ state_ids = await self.store.get_current_state_ids(room_id)
acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
if not acl_event_id:
return
- acl_event = yield self.store.get_event(acl_event_id)
+ acl_event = await self.store.get_event(acl_event_id)
if server_matches_acl_event(server_name, acl_event):
return
@@ -709,7 +681,7 @@ def server_matches_acl_event(server_name, acl_event):
# server name is a literal IP
allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
if not isinstance(allow_ip_literals, bool):
- logger.warn("Ignorning non-bool allow_ip_literals flag")
+ logger.warning("Ignorning non-bool allow_ip_literals flag")
allow_ip_literals = True
if not allow_ip_literals:
# check for ipv6 literals. These start with '['.
@@ -723,7 +695,7 @@ def server_matches_acl_event(server_name, acl_event):
# next, check the deny list
deny = acl_event.content.get("deny", [])
if not isinstance(deny, (list, tuple)):
- logger.warn("Ignorning non-list deny ACL %s", deny)
+ logger.warning("Ignorning non-list deny ACL %s", deny)
deny = []
for e in deny:
if _acl_entry_matches(server_name, e):
@@ -733,7 +705,7 @@ def server_matches_acl_event(server_name, acl_event):
# then the allow list.
allow = acl_event.content.get("allow", [])
if not isinstance(allow, (list, tuple)):
- logger.warn("Ignorning non-list allow ACL %s", allow)
+ logger.warning("Ignorning non-list allow ACL %s", allow)
allow = []
for e in allow:
if _acl_entry_matches(server_name, e):
@@ -747,7 +719,7 @@ def server_matches_acl_event(server_name, acl_event):
def _acl_entry_matches(server_name, acl_entry):
if not isinstance(acl_entry, six.string_types):
- logger.warn(
+ logger.warning(
"Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
)
return False
@@ -799,15 +771,14 @@ class FederationHandlerRegistry(object):
self.query_handlers[query_type] = handler
- @defer.inlineCallbacks
- def on_edu(self, edu_type, origin, content):
+ async def on_edu(self, edu_type, origin, content):
handler = self.edu_handlers.get(edu_type)
if not handler:
- logger.warn("No handler registered for EDU type %s", edu_type)
+ logger.warning("No handler registered for EDU type %s", edu_type)
with start_active_span_from_edu(content, "handle_edu"):
try:
- yield handler(origin, content)
+ await handler(origin, content)
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception:
@@ -816,7 +787,7 @@ class FederationHandlerRegistry(object):
def on_query(self, query_type, args):
handler = self.query_handlers.get(query_type)
if not handler:
- logger.warn("No handler registered for query type %s", query_type)
+ logger.warning("No handler registered for query type %s", query_type)
raise NotFoundError("No handler for Query type '%s'" % (query_type,))
return handler(args)
@@ -840,7 +811,7 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
super(ReplicationFederationHandlerRegistry, self).__init__()
- def on_edu(self, edu_type, origin, content):
+ async def on_edu(self, edu_type, origin, content):
"""Overrides FederationHandlerRegistry
"""
if not self.config.use_presence and edu_type == "m.presence":
@@ -848,17 +819,17 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
handler = self.edu_handlers.get(edu_type)
if handler:
- return super(ReplicationFederationHandlerRegistry, self).on_edu(
+ return await super(ReplicationFederationHandlerRegistry, self).on_edu(
edu_type, origin, content
)
- return self._send_edu(edu_type=edu_type, origin=origin, content=content)
+ return await self._send_edu(edu_type=edu_type, origin=origin, content=content)
- def on_query(self, query_type, args):
+ async def on_query(self, query_type, args):
"""Overrides FederationHandlerRegistry
"""
handler = self.query_handlers.get(query_type)
if handler:
- return handler(args)
+ return await handler(args)
- return self._get_query_client(query_type=query_type, args=args)
+ return await self._get_query_client(query_type=query_type, args=args)
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 454456a5..ced4925a 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -36,6 +36,8 @@ from six import iteritems
from sortedcontainers import SortedDict
+from twisted.internet import defer
+
from synapse.metrics import LaterGauge
from synapse.storage.presence import UserPresenceState
from synapse.util.metrics import Measure
@@ -212,7 +214,7 @@ class FederationRemoteSendQueue(object):
receipt (synapse.types.ReadReceipt):
"""
# nothing to do here: the replication listener will handle it.
- pass
+ return defer.succeed(None)
def send_presence(self, states):
"""As per FederationSender
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index cc75c394..a5b36b18 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -192,15 +192,16 @@ class PerDestinationQueue(object):
# We have to keep 2 free slots for presence and rr_edus
limit = MAX_EDUS_PER_TRANSACTION - 2
- device_update_edus, dev_list_id = (
- yield self._get_device_update_edus(limit)
+ device_update_edus, dev_list_id = yield self._get_device_update_edus(
+ limit
)
limit -= len(device_update_edus)
- to_device_edus, device_stream_id = (
- yield self._get_to_device_message_edus(limit)
- )
+ (
+ to_device_edus,
+ device_stream_id,
+ ) = yield self._get_to_device_message_edus(limit)
pending_edus = device_update_edus + to_device_edus
@@ -359,20 +360,20 @@ class PerDestinationQueue(object):
last_device_list = self._last_device_list_stream_id
# Retrieve list of new device updates to send to the destination
- now_stream_id, results = yield self._store.get_devices_by_remote(
+ now_stream_id, results = yield self._store.get_device_updates_by_remote(
self._destination, last_device_list, limit=limit
)
edus = [
Edu(
origin=self._server_name,
destination=self._destination,
- edu_type="m.device_list_update",
+ edu_type=edu_type,
content=content,
)
- for content in results
+ for (edu_type, content) in results
]
- assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
+ assert len(edus) <= limit, "get_device_updates_by_remote returned too many EDUs"
return (edus, now_stream_id)
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 5b6c79c5..67b3e1ab 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -146,7 +146,7 @@ class TransactionManager(object):
if code == 200:
for e_id, r in response.get("pdus", {}).items():
if "error" in r:
- logger.warn(
+ logger.warning(
"TX [%s] {%s} Remote returned error for %s: %s",
destination,
txn_id,
@@ -155,7 +155,7 @@ class TransactionManager(object):
)
else:
for p in pdus:
- logger.warn(
+ logger.warning(
"TX [%s] {%s} Failed to send event %s",
destination,
txn_id,
diff --git a/synapse/federation/transport/__init__.py b/synapse/federation/transport/__init__.py
index d9fcc520..5db733af 100644
--- a/synapse/federation/transport/__init__.py
+++ b/synapse/federation/transport/__init__.py
@@ -14,9 +14,9 @@
# limitations under the License.
"""The transport layer is responsible for both sending transactions to remote
-home servers and receiving a variety of requests from other home servers.
+homeservers and receiving a variety of requests from other homeservers.
-By default this is done over HTTPS (and all home servers are required to
+By default this is done over HTTPS (and all homeservers are required to
support HTTPS), however individual pairings of servers may decide to
communicate over a different (albeit still reliable) protocol.
"""
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 7b184081..dc95ab21 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -44,7 +44,7 @@ class TransportLayerClient(object):
given event.
Args:
- destination (str): The host name of the remote home server we want
+ destination (str): The host name of the remote homeserver we want
to get the state from.
context (str): The name of the context we want the state of
event_id (str): The event we want the context at.
@@ -68,7 +68,7 @@ class TransportLayerClient(object):
given event. Returns the state's event_id's
Args:
- destination (str): The host name of the remote home server we want
+ destination (str): The host name of the remote homeserver we want
to get the state from.
context (str): The name of the context we want the state of
event_id (str): The event we want the context at.
@@ -91,7 +91,7 @@ class TransportLayerClient(object):
""" Requests the pdu with give id and origin from the given server.
Args:
- destination (str): The host name of the remote home server we want
+ destination (str): The host name of the remote homeserver we want
to get the state from.
event_id (str): The id of the event being requested.
timeout (int): How long to try (in ms) the destination for before
@@ -122,10 +122,10 @@ class TransportLayerClient(object):
Deferred: Results in a dict received from the remote homeserver.
"""
logger.debug(
- "backfill dest=%s, room_id=%s, event_tuples=%s, limit=%s",
+ "backfill dest=%s, room_id=%s, event_tuples=%r, limit=%s",
destination,
room_id,
- repr(event_tuples),
+ event_tuples,
str(limit),
)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 0f16f21c..09baa9c5 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -202,7 +202,7 @@ def _parse_auth_header(header_bytes):
sig = strip_quotes(param_dict["sig"])
return origin, key, sig
except Exception as e:
- logger.warn(
+ logger.warning(
"Error parsing auth header '%s': %s",
header_bytes.decode("ascii", "replace"),
e,
@@ -287,10 +287,12 @@ class BaseFederationServlet(object):
except NoAuthenticationError:
origin = None
if self.REQUIRE_AUTH:
- logger.warn("authenticate_request failed: missing authentication")
+ logger.warning(
+ "authenticate_request failed: missing authentication"
+ )
raise
except Exception as e:
- logger.warn("authenticate_request failed: %s", e)
+ logger.warning("authenticate_request failed: %s", e)
raise
request_tags = {
@@ -712,7 +714,7 @@ class PublicRoomList(BaseFederationServlet):
This API returns information in the same format as /publicRooms on the
client API, but will only ever include local public rooms and hence is
- intended for consumption by other home servers.
+ intended for consumption by other homeservers.
GET /publicRooms HTTP/1.1
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index dfd7ae04..d950a8b2 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -181,7 +181,7 @@ class GroupAttestionRenewer(object):
elif not self.is_mine_id(user_id):
destination = get_domain_from_id(user_id)
else:
- logger.warn(
+ logger.warning(
"Incorrectly trying to do attestations for user: %r in %r",
user_id,
group_id,
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 8f10b6ad..29e8ffc2 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -488,7 +488,7 @@ class GroupsServerHandler(object):
profile = yield self.profile_handler.get_profile_from_cache(user_id)
user_profile.update(profile)
except Exception as e:
- logger.warn("Error getting profile for %s: %s", user_id, e)
+ logger.warning("Error getting profile for %s: %s", user_id, e)
user_profiles.append(user_profile)
return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)}
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 38bc6719..2d7e6df6 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -38,9 +38,10 @@ class AccountDataEventSource(object):
{"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id}
)
- account_data, room_account_data = (
- yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
- )
+ (
+ account_data,
+ room_account_data,
+ ) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
for account_data_type, content in account_data.items():
results.append({"type": account_data_type, "content": content})
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 1a87b588..6407d56f 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -30,6 +30,9 @@ class AdminHandler(BaseHandler):
def __init__(self, hs):
super(AdminHandler, self).__init__(hs)
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
+
@defer.inlineCallbacks
def get_whois(self, user):
connections = []
@@ -205,7 +208,7 @@ class AdminHandler(BaseHandler):
from_key = events[-1].internal_metadata.after
- events = yield filter_events_for_client(self.store, user_id, events)
+ events = yield filter_events_for_client(self.storage, user_id, events)
writer.write_events(room_id, events)
@@ -241,7 +244,7 @@ class AdminHandler(BaseHandler):
for event_id in extremities:
if not event_to_unseen_prevs[event_id]:
continue
- state = yield self.store.get_state_for_event(event_id)
+ state = yield self.state_store.get_state_for_event(event_id)
writer.write_state(room_id, event_id, state)
return writer.finished()
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 3e9b2981..fe62f78e 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -73,7 +73,10 @@ class ApplicationServicesHandler(object):
try:
limit = 100
while True:
- upper_bound, events = yield self.store.get_new_events_for_appservice(
+ (
+ upper_bound,
+ events,
+ ) = yield self.store.get_new_events_for_appservice(
self.current_max, limit
)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 333eb306..54a71c49 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -102,8 +102,9 @@ class AuthHandler(BaseHandler):
login_types.append(t)
self._supported_login_types = login_types
- self._account_ratelimiter = Ratelimiter()
- self._failed_attempts_ratelimiter = Ratelimiter()
+ # Ratelimiter for failed auth during UIA. Uses same ratelimit config
+ # as per `rc_login.failed_attempts`.
+ self._failed_uia_attempts_ratelimiter = Ratelimiter()
self._clock = self.hs.get_clock()
@@ -133,12 +134,38 @@ class AuthHandler(BaseHandler):
AuthError if the client has completed a login flow, and it gives
a different user to `requester`
+
+ LimitExceededError if the ratelimiter's failed request count for this
+ user is too high to proceed
+
"""
+ user_id = requester.user.to_string()
+
+ # Check if we should be ratelimited due to too many previous failed attempts
+ self._failed_uia_attempts_ratelimiter.ratelimit(
+ user_id,
+ time_now_s=self._clock.time(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ update=False,
+ )
+
# build a list of supported flows
flows = [[login_type] for login_type in self._supported_login_types]
- result, params, _ = yield self.check_auth(flows, request_body, clientip)
+ try:
+ result, params, _ = yield self.check_auth(flows, request_body, clientip)
+ except LoginError:
+ # Update the ratelimite to say we failed (`can_do_action` doesn't raise).
+ self._failed_uia_attempts_ratelimiter.can_do_action(
+ user_id,
+ time_now_s=self._clock.time(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ update=True,
+ )
+ raise
# find the completed login type
for login_type in self._supported_login_types:
@@ -223,7 +250,7 @@ class AuthHandler(BaseHandler):
# could continue registration from your phone having clicked the
# email auth link on there). It's probably too open to abuse
# because it lets unauthenticated clients store arbitrary objects
- # on a home server.
+ # on a homeserver.
# Revisit: Assumimg the REST APIs do sensible validation, the data
# isn't arbintrary.
session["clientdict"] = clientdict
@@ -501,11 +528,8 @@ class AuthHandler(BaseHandler):
multiple matches
Raises:
- LimitExceededError if the ratelimiter's login requests count for this
- user is too high too proceed.
UserDeactivatedError if a user is found but is deactivated.
"""
- self.ratelimit_login_per_account(user_id)
res = yield self._find_user_id_and_pwd_hash(user_id)
if res is not None:
return res[0]
@@ -525,7 +549,7 @@ class AuthHandler(BaseHandler):
result = None
if not user_infos:
- logger.warn("Attempted to login as %s but they do not exist", user_id)
+ logger.warning("Attempted to login as %s but they do not exist", user_id)
elif len(user_infos) == 1:
# a single match (possibly not exact)
result = user_infos.popitem()
@@ -534,7 +558,7 @@ class AuthHandler(BaseHandler):
result = (user_id, user_infos[user_id])
else:
# multiple matches, none of them exact
- logger.warn(
+ logger.warning(
"Attempted to login as %s but it matches more than one user "
"inexactly: %r",
user_id,
@@ -572,8 +596,6 @@ class AuthHandler(BaseHandler):
StoreError if there was a problem accessing the database
SynapseError if there was a problem with the request
LoginError if there was an authentication problem.
- LimitExceededError if the ratelimiter's login requests count for this
- user is too high too proceed.
"""
if username.startswith("@"):
@@ -581,8 +603,6 @@ class AuthHandler(BaseHandler):
else:
qualified_user_id = UserID(username, self.hs.hostname).to_string()
- self.ratelimit_login_per_account(qualified_user_id)
-
login_type = login_submission.get("type")
known_login_type = False
@@ -650,15 +670,6 @@ class AuthHandler(BaseHandler):
if not known_login_type:
raise SynapseError(400, "Unknown login type %s" % login_type)
- # unknown username or invalid password.
- self._failed_attempts_ratelimiter.ratelimit(
- qualified_user_id.lower(),
- time_now_s=self._clock.time(),
- rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
- burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
- update=True,
- )
-
# We raise a 403 here, but note that if we're doing user-interactive
# login, it turns all LoginErrors into a 401 anyway.
raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)
@@ -710,10 +721,6 @@ class AuthHandler(BaseHandler):
Returns:
Deferred[unicode] the canonical_user_id, or Deferred[None] if
unknown user/bad password
-
- Raises:
- LimitExceededError if the ratelimiter's login requests count for this
- user is too high too proceed.
"""
lookupres = yield self._find_user_id_and_pwd_hash(user_id)
if not lookupres:
@@ -728,7 +735,7 @@ class AuthHandler(BaseHandler):
result = yield self.validate_hash(password, password_hash)
if not result:
- logger.warn("Failed password login for user %s", user_id)
+ logger.warning("Failed password login for user %s", user_id)
return None
return user_id
@@ -742,7 +749,7 @@ class AuthHandler(BaseHandler):
auth_api.validate_macaroon(macaroon, "login", user_id)
except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
- self.ratelimit_login_per_account(user_id)
+
yield self.auth.check_auth_blocking(user_id)
return user_id
@@ -810,7 +817,7 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at):
# 'Canonicalise' email addresses down to lower case.
- # We've now moving towards the Home Server being the entity that
+ # We've now moving towards the homeserver being the entity that
# is responsible for validating threepids used for resetting passwords
# on accounts, so in future Synapse will gain knowledge of specific
# types (mediums) of threepid. For now, we still use the existing
@@ -912,35 +919,6 @@ class AuthHandler(BaseHandler):
else:
return defer.succeed(False)
- def ratelimit_login_per_account(self, user_id):
- """Checks whether the process must be stopped because of ratelimiting.
-
- Checks against two ratelimiters: the generic one for login attempts per
- account and the one specific to failed attempts.
-
- Args:
- user_id (unicode): complete @user:id
-
- Raises:
- LimitExceededError if one of the ratelimiters' login requests count
- for this user is too high too proceed.
- """
- self._failed_attempts_ratelimiter.ratelimit(
- user_id.lower(),
- time_now_s=self._clock.time(),
- rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
- burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
- update=False,
- )
-
- self._account_ratelimiter.ratelimit(
- user_id.lower(),
- time_now_s=self._clock.time(),
- rate_hz=self.hs.config.rc_login_account.per_second,
- burst_count=self.hs.config.rc_login_account.burst_count,
- update=True,
- )
-
@attr.s
class MacaroonGenerator(object):
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 63267a0a..6dedaaff 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -95,6 +95,9 @@ class DeactivateAccountHandler(BaseHandler):
user_id, threepid["medium"], threepid["address"]
)
+ # Remove all 3PIDs this user has bound to the homeserver
+ yield self.store.user_delete_threepids(user_id)
+
# delete any devices belonging to the user, which will also
# delete corresponding access tokens.
yield self._device_handler.delete_all_devices_for_user(user_id)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 5f23ee44..26ef5e15 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -46,6 +46,7 @@ class DeviceWorkerHandler(BaseHandler):
self.hs = hs
self.state = hs.get_state_handler()
+ self.state_store = hs.get_storage().state
self._auth_handler = hs.get_auth_handler()
@trace
@@ -178,7 +179,7 @@ class DeviceWorkerHandler(BaseHandler):
continue
# mapping from event_id -> state_dict
- prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
+ prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids)
# Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users.
@@ -458,7 +459,18 @@ class DeviceHandler(DeviceWorkerHandler):
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
- return {"user_id": user_id, "stream_id": stream_id, "devices": devices}
+ master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
+ self_signing_key = yield self.store.get_e2e_cross_signing_key(
+ user_id, "self_signing"
+ )
+
+ return {
+ "user_id": user_id,
+ "stream_id": stream_id,
+ "devices": devices,
+ "master_key": master_key,
+ "self_signing_key": self_signing_key,
+ }
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
@@ -656,7 +668,7 @@ class DeviceListUpdater(object):
except (NotRetryingDestination, RequestSendFailed, HttpResponseException):
# TODO: Remember that we are now out of sync and try again
# later
- logger.warn("Failed to handle device list update for %s", user_id)
+ logger.warning("Failed to handle device list update for %s", user_id)
# We abort on exceptions rather than accepting the update
# as otherwise synapse will 'forget' that its device list
# is out of date. If we bail then we will retry the resync
@@ -694,7 +706,7 @@ class DeviceListUpdater(object):
# up on storing the total list of devices and only handle the
# delta instead.
if len(devices) > 1000:
- logger.warn(
+ logger.warning(
"Ignoring device list snapshot for %s as it has >1K devs (%d)",
user_id,
len(devices),
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 0043cbea..73b9e120 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -52,7 +52,7 @@ class DeviceMessageHandler(object):
local_messages = {}
sender_user_id = content["sender"]
if origin != get_domain_from_id(sender_user_id):
- logger.warn(
+ logger.warning(
"Dropping device message from %r with spoofed sender %r",
origin,
sender_user_id,
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 526379c6..69051101 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -250,7 +250,7 @@ class DirectoryHandler(BaseHandler):
ignore_backoff=True,
)
except CodeMessageException as e:
- logging.warn("Error retrieving alias")
+ logging.warning("Error retrieving alias")
if e.code == 404:
result = None
else:
@@ -283,7 +283,7 @@ class DirectoryHandler(BaseHandler):
def on_directory_query(self, args):
room_alias = RoomAlias.from_string(args["room_alias"])
if not self.hs.is_mine(room_alias):
- raise SynapseError(400, "Room Alias is not hosted on this Home Server")
+ raise SynapseError(400, "Room Alias is not hosted on this homeserver")
result = yield self.get_association_from_room_alias(room_alias)
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 5ea54f60..f09a0b73 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -36,6 +36,8 @@ from synapse.types import (
get_verify_key_from_cross_signing_key,
)
from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import Linearizer
+from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
@@ -49,10 +51,19 @@ class E2eKeysHandler(object):
self.is_mine = hs.is_mine
self.clock = hs.get_clock()
+ self._edu_updater = SigningKeyEduUpdater(hs, self)
+
+ federation_registry = hs.get_federation_registry()
+
+ # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+ federation_registry.register_edu_handler(
+ "org.matrix.signing_key_update",
+ self._edu_updater.incoming_signing_key_update,
+ )
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
- hs.get_federation_registry().register_query_handler(
+ federation_registry.register_query_handler(
"client_keys", self.on_federation_query_client_keys
)
@@ -119,9 +130,10 @@ class E2eKeysHandler(object):
else:
query_list.append((user_id, None))
- user_ids_not_in_cache, remote_results = (
- yield self.store.get_user_devices_from_cache(query_list)
- )
+ (
+ user_ids_not_in_cache,
+ remote_results,
+ ) = yield self.store.get_user_devices_from_cache(query_list)
for user_id, devices in iteritems(remote_results):
user_devices = results.setdefault(user_id, {})
for device_id, device in iteritems(devices):
@@ -207,13 +219,15 @@ class E2eKeysHandler(object):
if user_id in destination_query:
results[user_id] = keys
- for user_id, key in remote_result["master_keys"].items():
- if user_id in destination_query:
- cross_signing_keys["master_keys"][user_id] = key
+ if "master_keys" in remote_result:
+ for user_id, key in remote_result["master_keys"].items():
+ if user_id in destination_query:
+ cross_signing_keys["master_keys"][user_id] = key
- for user_id, key in remote_result["self_signing_keys"].items():
- if user_id in destination_query:
- cross_signing_keys["self_signing_keys"][user_id] = key
+ if "self_signing_keys" in remote_result:
+ for user_id, key in remote_result["self_signing_keys"].items():
+ if user_id in destination_query:
+ cross_signing_keys["self_signing_keys"][user_id] = key
except Exception as e:
failure = _exception_to_failure(e)
@@ -251,7 +265,7 @@ class E2eKeysHandler(object):
Returns:
defer.Deferred[dict[str, dict[str, dict]]]: map from
- (master|self_signing|user_signing) -> user_id -> key
+ (master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
"""
master_keys = {}
self_signing_keys = {}
@@ -343,7 +357,16 @@ class E2eKeysHandler(object):
"""
device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query)
- return {"device_keys": res}
+ ret = {"device_keys": res}
+
+ # add in the cross-signing keys
+ cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
+ device_keys_query, None
+ )
+
+ ret.update(cross_signing_keys)
+
+ return ret
@trace
@defer.inlineCallbacks
@@ -688,17 +711,21 @@ class E2eKeysHandler(object):
try:
# get our self-signing key to verify the signatures
- _, self_signing_key_id, self_signing_verify_key = yield self._get_e2e_cross_signing_verify_key(
- user_id, "self_signing"
- )
+ (
+ _,
+ self_signing_key_id,
+ self_signing_verify_key,
+ ) = yield self._get_e2e_cross_signing_verify_key(user_id, "self_signing")
# get our master key, since we may have received a signature of it.
# We need to fetch it here so that we know what its key ID is, so
# that we can check if a signature that was sent is a signature of
# the master key or of a device
- master_key, _, master_verify_key = yield self._get_e2e_cross_signing_verify_key(
- user_id, "master"
- )
+ (
+ master_key,
+ _,
+ master_verify_key,
+ ) = yield self._get_e2e_cross_signing_verify_key(user_id, "master")
# fetch our stored devices. This is used to 1. verify
# signatures on the master key, and 2. to compare with what
@@ -838,9 +865,11 @@ class E2eKeysHandler(object):
try:
# get our user-signing key to verify the signatures
- user_signing_key, user_signing_key_id, user_signing_verify_key = yield self._get_e2e_cross_signing_verify_key(
- user_id, "user_signing"
- )
+ (
+ user_signing_key,
+ user_signing_key_id,
+ user_signing_verify_key,
+ ) = yield self._get_e2e_cross_signing_verify_key(user_id, "user_signing")
except SynapseError as e:
failure = _exception_to_failure(e)
for user, devicemap in signatures.items():
@@ -859,7 +888,11 @@ class E2eKeysHandler(object):
try:
# get the target user's master key, to make sure it matches
# what was sent
- master_key, master_key_id, _ = yield self._get_e2e_cross_signing_verify_key(
+ (
+ master_key,
+ master_key_id,
+ _,
+ ) = yield self._get_e2e_cross_signing_verify_key(
target_user, "master", user_id
)
@@ -1047,3 +1080,100 @@ class SignatureListItem:
target_user_id = attr.ib()
target_device_id = attr.ib()
signature = attr.ib()
+
+
+class SigningKeyEduUpdater(object):
+ """Handles incoming signing key updates from federation and updates the DB"""
+
+ def __init__(self, hs, e2e_keys_handler):
+ self.store = hs.get_datastore()
+ self.federation = hs.get_federation_client()
+ self.clock = hs.get_clock()
+ self.e2e_keys_handler = e2e_keys_handler
+
+ self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
+
+ # user_id -> list of updates waiting to be handled.
+ self._pending_updates = {}
+
+ # Recently seen stream ids. We don't bother keeping these in the DB,
+ # but they're useful to have them about to reduce the number of spurious
+ # resyncs.
+ self._seen_updates = ExpiringCache(
+ cache_name="signing_key_update_edu",
+ clock=self.clock,
+ max_len=10000,
+ expiry_ms=30 * 60 * 1000,
+ iterable=True,
+ )
+
+ @defer.inlineCallbacks
+ def incoming_signing_key_update(self, origin, edu_content):
+ """Called on incoming signing key update from federation. Responsible for
+ parsing the EDU and adding to pending updates list.
+
+ Args:
+ origin (string): the server that sent the EDU
+ edu_content (dict): the contents of the EDU
+ """
+
+ user_id = edu_content.pop("user_id")
+ master_key = edu_content.pop("master_key", None)
+ self_signing_key = edu_content.pop("self_signing_key", None)
+
+ if get_domain_from_id(user_id) != origin:
+ logger.warning("Got signing key update edu for %r from %r", user_id, origin)
+ return
+
+ room_ids = yield self.store.get_rooms_for_user(user_id)
+ if not room_ids:
+ # We don't share any rooms with this user. Ignore update, as we
+ # probably won't get any further updates.
+ return
+
+ self._pending_updates.setdefault(user_id, []).append(
+ (master_key, self_signing_key)
+ )
+
+ yield self._handle_signing_key_updates(user_id)
+
+ @defer.inlineCallbacks
+ def _handle_signing_key_updates(self, user_id):
+ """Actually handle pending updates.
+
+ Args:
+ user_id (string): the user whose updates we are processing
+ """
+
+ device_handler = self.e2e_keys_handler.device_handler
+
+ with (yield self._remote_edu_linearizer.queue(user_id)):
+ pending_updates = self._pending_updates.pop(user_id, [])
+ if not pending_updates:
+ # This can happen since we batch updates
+ return
+
+ device_ids = []
+
+ logger.info("pending updates: %r", pending_updates)
+
+ for master_key, self_signing_key in pending_updates:
+ if master_key:
+ yield self.store.set_e2e_cross_signing_key(
+ user_id, "master", master_key
+ )
+ _, verify_key = get_verify_key_from_cross_signing_key(master_key)
+ # verify_key is a VerifyKey from signedjson, which uses
+ # .version to denote the portion of the key ID after the
+ # algorithm and colon, which is the device ID
+ device_ids.append(verify_key.version)
+ if self_signing_key:
+ yield self.store.set_e2e_cross_signing_key(
+ user_id, "self_signing", self_signing_key
+ )
+ _, verify_key = get_verify_key_from_cross_signing_key(
+ self_signing_key
+ )
+ device_ids.append(verify_key.version)
+
+ yield device_handler.notify_device_update(user_id, device_ids)
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 5e748687..45fe13c6 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -147,6 +147,10 @@ class EventStreamHandler(BaseHandler):
class EventHandler(BaseHandler):
+ def __init__(self, hs):
+ super(EventHandler, self).__init__(hs)
+ self.storage = hs.get_storage()
+
@defer.inlineCallbacks
def get_event(self, user, room_id, event_id):
"""Retrieve a single specified event.
@@ -172,7 +176,7 @@ class EventHandler(BaseHandler):
is_peeking = user.to_string() not in users
filtered = yield filter_events_for_client(
- self.store, user.to_string(), [event], is_peeking=is_peeking
+ self.storage, user.to_string(), [event], is_peeking=is_peeking
)
if not filtered:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 488058fe..0e904f2d 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -45,6 +45,7 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.crypto.event_signing import compute_event_signature
from synapse.event_auth import auth_types_for_event
+from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.logging.context import (
make_deferred_yieldable,
@@ -96,9 +97,9 @@ class FederationHandler(BaseHandler):
"""Handles events that originated from federation.
Responsible for:
a) handling received Pdus before handing them on as Events to the rest
- of the home server (including auth and state conflict resoultion)
+ of the homeserver (including auth and state conflict resoultion)
b) converting events that were produced by local clients that may need
- to be sent to remote home servers.
+ to be sent to remote homeservers.
c) doing the necessary dances to invite remote users and join remote
rooms.
"""
@@ -109,6 +110,8 @@ class FederationHandler(BaseHandler):
self.hs = hs
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname
@@ -180,7 +183,7 @@ class FederationHandler(BaseHandler):
try:
self._sanity_check_event(pdu)
except SynapseError as err:
- logger.warn(
+ logger.warning(
"[%s %s] Received event failed sanity checks", room_id, event_id
)
raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id)
@@ -301,7 +304,7 @@ class FederationHandler(BaseHandler):
# following.
if sent_to_us_directly:
- logger.warn(
+ logger.warning(
"[%s %s] Rejecting: failed to fetch %d prev events: %s",
room_id,
event_id,
@@ -324,7 +327,7 @@ class FederationHandler(BaseHandler):
event_map = {event_id: pdu}
try:
# Get the state of the events we know about
- ours = yield self.store.get_state_groups_ids(room_id, seen)
+ ours = yield self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list(
@@ -350,10 +353,11 @@ class FederationHandler(BaseHandler):
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
- remote_state, got_auth_chain = (
- yield self.federation_client.get_state_for_room(
- origin, room_id, p
- )
+ (
+ remote_state,
+ got_auth_chain,
+ ) = yield self.federation_client.get_state_for_room(
+ origin, room_id, p
)
# we want the state *after* p; get_state_for_room returns the
@@ -405,7 +409,7 @@ class FederationHandler(BaseHandler):
state = [event_map[e] for e in six.itervalues(state_map)]
auth_chain = list(auth_chains)
except Exception:
- logger.warn(
+ logger.warning(
"[%s %s] Error attempting to resolve state at missing "
"prev_events",
room_id,
@@ -518,7 +522,9 @@ class FederationHandler(BaseHandler):
# We failed to get the missing events, but since we need to handle
# the case of `get_missing_events` not returning the necessary
# events anyway, it is safe to simply log the error and continue.
- logger.warn("[%s %s]: Failed to get prev_events: %s", room_id, event_id, e)
+ logger.warning(
+ "[%s %s]: Failed to get prev_events: %s", room_id, event_id, e
+ )
return
logger.info(
@@ -545,7 +551,7 @@ class FederationHandler(BaseHandler):
yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
except FederationError as e:
if e.code == 403:
- logger.warn(
+ logger.warning(
"[%s %s] Received prev_event %s failed history check.",
room_id,
event_id,
@@ -888,7 +894,7 @@ class FederationHandler(BaseHandler):
# We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased.
filtered_extremities = yield filter_events_for_server(
- self.store,
+ self.storage,
self.server_name,
list(extremities_events.values()),
redact=False,
@@ -1059,7 +1065,7 @@ class FederationHandler(BaseHandler):
SynapseError if the event does not pass muster
"""
if len(ev.prev_event_ids()) > 20:
- logger.warn(
+ logger.warning(
"Rejecting event %s which has %i prev_events",
ev.event_id,
len(ev.prev_event_ids()),
@@ -1067,7 +1073,7 @@ class FederationHandler(BaseHandler):
raise SynapseError(http_client.BAD_REQUEST, "Too many prev_events")
if len(ev.auth_event_ids()) > 10:
- logger.warn(
+ logger.warning(
"Rejecting event %s which has %i auth_events",
ev.event_id,
len(ev.auth_event_ids()),
@@ -1101,7 +1107,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def do_invite_join(self, target_hosts, room_id, joinee, content):
""" Attempts to join the `joinee` to the room `room_id` via the
- server `target_host`.
+ servers contained in `target_hosts`.
This first triggers a /make_join/ request that returns a partial
event that we can fill out and sign. This is then sent to the
@@ -1110,6 +1116,15 @@ class FederationHandler(BaseHandler):
We suspend processing of any received events from this room until we
have finished processing the join.
+
+ Args:
+ target_hosts (Iterable[str]): List of servers to attempt to join the room with.
+
+ room_id (str): The ID of the room to join.
+
+ joinee (str): The User ID of the joining user.
+
+ content (dict): The event content to use for the join event.
"""
logger.debug("Joining %s to %s", joinee, room_id)
@@ -1169,6 +1184,22 @@ class FederationHandler(BaseHandler):
yield self._persist_auth_tree(origin, auth_chain, state, event)
+ # Check whether this room is the result of an upgrade of a room we already know
+ # about. If so, migrate over user information
+ predecessor = yield self.store.get_room_predecessor(room_id)
+ if not predecessor:
+ return
+ old_room_id = predecessor["room_id"]
+ logger.debug(
+ "Found predecessor for %s during remote join: %s", room_id, old_room_id
+ )
+
+ # We retrieve the room member handler here as to not cause a cyclic dependency
+ member_handler = self.hs.get_room_member_handler()
+ yield member_handler.transfer_room_state_on_room_upgrade(
+ old_room_id, room_id
+ )
+
logger.debug("Finished joining %s to %s", joinee, room_id)
finally:
room_queue = self.room_queues[room_id]
@@ -1203,7 +1234,7 @@ class FederationHandler(BaseHandler):
with nested_logging_context(p.event_id):
yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
except Exception as e:
- logger.warn(
+ logger.warning(
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e
)
@@ -1250,7 +1281,7 @@ class FederationHandler(BaseHandler):
builder=builder
)
except AuthError as e:
- logger.warn("Failed to create join %r because %s", event, e)
+ logger.warning("Failed to create join to %s because %s", room_id, e)
raise e
event_allowed = yield self.third_party_event_rules.check_event_allowed(
@@ -1494,7 +1525,7 @@ class FederationHandler(BaseHandler):
room_version, event, context, do_sig_check=False
)
except AuthError as e:
- logger.warn("Failed to create new leave %r because %s", event, e)
+ logger.warning("Failed to create new leave %r because %s", event, e)
raise e
return event
@@ -1549,7 +1580,7 @@ class FederationHandler(BaseHandler):
event_id, allow_none=False, check_room_id=room_id
)
- state_groups = yield self.store.get_state_groups(room_id, [event_id])
+ state_groups = yield self.state_store.get_state_groups(room_id, [event_id])
if state_groups:
_, state = list(iteritems(state_groups)).pop()
@@ -1578,7 +1609,7 @@ class FederationHandler(BaseHandler):
event_id, allow_none=False, check_room_id=room_id
)
- state_groups = yield self.store.get_state_groups_ids(room_id, [event_id])
+ state_groups = yield self.state_store.get_state_groups_ids(room_id, [event_id])
if state_groups:
_, state = list(state_groups.items()).pop()
@@ -1606,7 +1637,7 @@ class FederationHandler(BaseHandler):
events = yield self.store.get_backfill_events(room_id, pdu_list, limit)
- events = yield filter_events_for_server(self.store, origin, events)
+ events = yield filter_events_for_server(self.storage, origin, events)
return events
@@ -1636,7 +1667,7 @@ class FederationHandler(BaseHandler):
if not in_room:
raise AuthError(403, "Host not in room.")
- events = yield filter_events_for_server(self.store, origin, [event])
+ events = yield filter_events_for_server(self.storage, origin, [event])
event = events[0]
return event
else:
@@ -1657,7 +1688,11 @@ class FederationHandler(BaseHandler):
# hack around with a try/finally instead.
success = False
try:
- if not event.internal_metadata.is_outlier() and not backfilled:
+ if (
+ not event.internal_metadata.is_outlier()
+ and not backfilled
+ and not context.rejected
+ ):
yield self.action_generator.handle_push_actions_for_event(
event, context
)
@@ -1788,7 +1823,7 @@ class FederationHandler(BaseHandler):
# cause SynapseErrors in auth.check. We don't want to give up
# the attempt to federate altogether in such cases.
- logger.warn("Rejecting %s because %s", e.event_id, err.msg)
+ logger.warning("Rejecting %s because %s", e.event_id, err.msg)
if e == event:
raise
@@ -1841,12 +1876,7 @@ class FederationHandler(BaseHandler):
if c and c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c
- try:
- yield self.do_auth(origin, event, context, auth_events=auth_events)
- except AuthError as e:
- logger.warn("[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg)
-
- context.rejected = RejectedReason.AUTH_ERROR
+ context = yield self.do_auth(origin, event, context, auth_events=auth_events)
if not context.rejected:
yield self._check_for_soft_fail(event, state, backfilled)
@@ -1902,7 +1932,7 @@ class FederationHandler(BaseHandler):
# given state at the event. This should correctly handle cases
# like bans, especially with state res v2.
- state_sets = yield self.store.get_state_groups(
+ state_sets = yield self.state_store.get_state_groups(
event.room_id, extrem_ids
)
state_sets = list(state_sets.values())
@@ -1938,7 +1968,7 @@ class FederationHandler(BaseHandler):
try:
event_auth.check(room_version, event, auth_events=current_auth_events)
except AuthError as e:
- logger.warn("Soft-failing %r because %s", event, e)
+ logger.warning("Soft-failing %r because %s", event, e)
event.internal_metadata.soft_failed = True
@defer.inlineCallbacks
@@ -1993,7 +2023,7 @@ class FederationHandler(BaseHandler):
)
missing_events = yield filter_events_for_server(
- self.store, origin, missing_events
+ self.storage, origin, missing_events
)
return missing_events
@@ -2015,12 +2045,12 @@ class FederationHandler(BaseHandler):
Also NB that this function adds entries to it.
Returns:
- defer.Deferred[None]
+ defer.Deferred[EventContext]: updated context object
"""
room_version = yield self.store.get_room_version(event.room_id)
try:
- yield self._update_auth_events_and_context_for_auth(
+ context = yield self._update_auth_events_and_context_for_auth(
origin, event, context, auth_events
)
except Exception:
@@ -2037,8 +2067,10 @@ class FederationHandler(BaseHandler):
try:
event_auth.check(room_version, event, auth_events=auth_events)
except AuthError as e:
- logger.warn("Failed auth resolution for %r because %s", event, e)
- raise e
+ logger.warning("Failed auth resolution for %r because %s", event, e)
+ context.rejected = RejectedReason.AUTH_ERROR
+
+ return context
@defer.inlineCallbacks
def _update_auth_events_and_context_for_auth(
@@ -2062,7 +2094,7 @@ class FederationHandler(BaseHandler):
auth_events (dict[(str, str)->synapse.events.EventBase]):
Returns:
- defer.Deferred[None]
+ defer.Deferred[EventContext]: updated context
"""
event_auth_events = set(event.auth_event_ids())
@@ -2101,7 +2133,7 @@ class FederationHandler(BaseHandler):
# The other side isn't around or doesn't implement the
# endpoint, so lets just bail out.
logger.info("Failed to get event auth from remote: %s", e)
- return
+ return context
seen_remotes = yield self.store.have_seen_events(
[e.event_id for e in remote_auth_chain]
@@ -2142,7 +2174,7 @@ class FederationHandler(BaseHandler):
if event.internal_metadata.is_outlier():
logger.info("Skipping auth_event fetch for outlier")
- return
+ return context
# FIXME: Assumes we have and stored all the state for all the
# prev_events
@@ -2151,7 +2183,7 @@ class FederationHandler(BaseHandler):
)
if not different_auth:
- return
+ return context
logger.info(
"auth_events refers to events which are not in our calculated auth "
@@ -2198,10 +2230,12 @@ class FederationHandler(BaseHandler):
auth_events.update(new_state)
- yield self._update_context_for_auth_events(
+ context = yield self._update_context_for_auth_events(
event, context, auth_events, event_key
)
+ return context
+
@defer.inlineCallbacks
def _update_context_for_auth_events(self, event, context, auth_events, event_key):
"""Update the state_ids in an event context after auth event resolution,
@@ -2210,14 +2244,16 @@ class FederationHandler(BaseHandler):
Args:
event (Event): The event we're handling the context for
- context (synapse.events.snapshot.EventContext): event context
- to be updated
+ context (synapse.events.snapshot.EventContext): initial event context
auth_events (dict[(str, str)->str]): Events to update in the event
context.
event_key ((str, str)): (type, state_key) for the current event.
this will not be included in the current_state in the context.
+
+ Returns:
+ Deferred[EventContext]: new event context
"""
state_updates = {
k: a.event_id for k, a in iteritems(auth_events) if k != event_key
@@ -2234,7 +2270,7 @@ class FederationHandler(BaseHandler):
# create a new state group as a delta from the existing one.
prev_group = context.state_group
- state_group = yield self.store.store_state_group(
+ state_group = yield self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=prev_group,
@@ -2242,8 +2278,9 @@ class FederationHandler(BaseHandler):
current_state_ids=current_state_ids,
)
- yield context.update_state(
+ return EventContext.with_state(
state_group=state_group,
+ state_group_before_event=context.state_group_before_event,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
prev_group=prev_group,
@@ -2431,10 +2468,12 @@ class FederationHandler(BaseHandler):
try:
yield self.auth.check_from_context(room_version, event, context)
except AuthError as e:
- logger.warn("Denying new third party invite %r because %s", event, e)
+ logger.warning("Denying new third party invite %r because %s", event, e)
raise e
yield self._check_signature(event, context)
+
+ # We retrieve the room member handler here as to not cause a cyclic dependency
member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context)
else:
@@ -2487,7 +2526,7 @@ class FederationHandler(BaseHandler):
try:
yield self.auth.check_from_context(room_version, event, context)
except AuthError as e:
- logger.warn("Denying third party invite %r because %s", event, e)
+ logger.warning("Denying third party invite %r because %s", event, e)
raise e
yield self._check_signature(event, context)
@@ -2495,6 +2534,7 @@ class FederationHandler(BaseHandler):
# though the sender isn't a local user.
event.internal_metadata.send_on_behalf_of = get_domain_from_id(event.sender)
+ # We retrieve the room member handler here as to not cause a cyclic dependency
member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context)
@@ -2664,7 +2704,7 @@ class FederationHandler(BaseHandler):
backfilled=backfilled,
)
else:
- max_stream_id = yield self.store.persist_events(
+ max_stream_id = yield self.storage.persistence.persist_events(
event_and_contexts, backfilled=backfilled
)
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 46eb9ee8..92fecbfc 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -392,7 +392,7 @@ class GroupsLocalHandler(object):
try:
user_profile = yield self.profile_handler.get_profile(user_id)
except Exception as e:
- logger.warn("No profile for user %s: %s", user_id, e)
+ logger.warning("No profile for user %s: %s", user_id, e)
user_profile = {}
return {"state": "invite", "user_profile": user_profile}
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index ba99ddf7..000fbf09 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -272,7 +272,7 @@ class IdentityHandler(BaseHandler):
changed = False
if e.code in (400, 404, 501):
# The remote server probably doesn't support unbinding (yet)
- logger.warn("Received %d response while unbinding threepid", e.code)
+ logger.warning("Received %d response while unbinding threepid", e.code)
else:
logger.error("Failed to unbind threepid on identity server: %s", e)
raise SynapseError(500, "Failed to contact identity server")
@@ -403,7 +403,7 @@ class IdentityHandler(BaseHandler):
if self.hs.config.using_identity_server_from_trusted_list:
# Warn that a deprecated config option is in use
- logger.warn(
+ logger.warning(
'The config option "trust_identity_server_for_password_resets" '
'has been replaced by "account_threepid_delegate". '
"Please consult the sample config at docs/sample_config.yaml for "
@@ -457,7 +457,7 @@ class IdentityHandler(BaseHandler):
if self.hs.config.using_identity_server_from_trusted_list:
# Warn that a deprecated config option is in use
- logger.warn(
+ logger.warning(
'The config option "trust_identity_server_for_password_resets" '
'has been replaced by "account_threepid_delegate". '
"Please consult the sample config at docs/sample_config.yaml for "
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index f991efee..81dce96f 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -43,6 +43,8 @@ class InitialSyncHandler(BaseHandler):
self.validator = EventValidator()
self.snapshot_cache = SnapshotCache()
self._event_serializer = hs.get_event_client_serializer()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
def snapshot_all_rooms(
self,
@@ -126,8 +128,8 @@ class InitialSyncHandler(BaseHandler):
tags_by_room = yield self.store.get_tags_for_user(user_id)
- account_data, account_data_by_room = (
- yield self.store.get_account_data_for_user(user_id)
+ account_data, account_data_by_room = yield self.store.get_account_data_for_user(
+ user_id
)
public_room_ids = yield self.store.get_public_room_ids()
@@ -169,7 +171,7 @@ class InitialSyncHandler(BaseHandler):
elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = run_in_background(
- self.store.get_state_for_events, [event.event_id]
+ self.state_store.get_state_for_events, [event.event_id]
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
@@ -189,7 +191,9 @@ class InitialSyncHandler(BaseHandler):
)
).addErrback(unwrapFirstError)
- messages = yield filter_events_for_client(self.store, user_id, messages)
+ messages = yield filter_events_for_client(
+ self.storage, user_id, messages
+ )
start_token = now_token.copy_and_replace("room_key", token)
end_token = now_token.copy_and_replace("room_key", room_end_token)
@@ -307,7 +311,7 @@ class InitialSyncHandler(BaseHandler):
def _room_initial_sync_parted(
self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
):
- room_state = yield self.store.get_state_for_events([member_event_id])
+ room_state = yield self.state_store.get_state_for_events([member_event_id])
room_state = room_state[member_event_id]
@@ -322,7 +326,7 @@ class InitialSyncHandler(BaseHandler):
)
messages = yield filter_events_for_client(
- self.store, user_id, messages, is_peeking=is_peeking
+ self.storage, user_id, messages, is_peeking=is_peeking
)
start_token = StreamToken.START.copy_and_replace("room_key", token)
@@ -414,7 +418,7 @@ class InitialSyncHandler(BaseHandler):
)
messages = yield filter_events_for_client(
- self.store, user_id, messages, is_peeking=is_peeking
+ self.storage, user_id, messages, is_peeking=is_peeking
)
start_token = now_token.copy_and_replace("room_key", token)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 0f8cce8f..d682dc2b 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -59,6 +59,8 @@ class MessageHandler(object):
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
@@ -74,15 +76,16 @@ class MessageHandler(object):
Raises:
SynapseError if something went wrong.
"""
- membership, membership_event_id = yield self.auth.check_in_room_or_world_readable(
- room_id, user_id
- )
+ (
+ membership,
+ membership_event_id,
+ ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
if membership == Membership.JOIN:
data = yield self.state.get_current_state(room_id, event_type, state_key)
elif membership == Membership.LEAVE:
key = (event_type, state_key)
- room_state = yield self.store.get_state_for_events(
+ room_state = yield self.state_store.get_state_for_events(
[membership_event_id], StateFilter.from_types([key])
)
data = room_state[membership_event_id].get(key)
@@ -135,12 +138,12 @@ class MessageHandler(object):
raise NotFoundError("Can't find event for token %s" % (at_token,))
visible_events = yield filter_events_for_client(
- self.store, user_id, last_events
+ self.storage, user_id, last_events
)
event = last_events[0]
if visible_events:
- room_state = yield self.store.get_state_for_events(
+ room_state = yield self.state_store.get_state_for_events(
[event.event_id], state_filter=state_filter
)
room_state = room_state[event.event_id]
@@ -151,9 +154,10 @@ class MessageHandler(object):
% (user_id, room_id, at_token),
)
else:
- membership, membership_event_id = (
- yield self.auth.check_in_room_or_world_readable(room_id, user_id)
- )
+ (
+ membership,
+ membership_event_id,
+ ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
if membership == Membership.JOIN:
state_ids = yield self.store.get_filtered_current_state_ids(
@@ -161,7 +165,7 @@ class MessageHandler(object):
)
room_state = yield self.store.get_events(state_ids.values())
elif membership == Membership.LEAVE:
- room_state = yield self.store.get_state_for_events(
+ room_state = yield self.state_store.get_state_for_events(
[membership_event_id], state_filter=state_filter
)
room_state = room_state[membership_event_id]
@@ -234,6 +238,7 @@ class EventCreationHandler(object):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
@@ -687,7 +692,7 @@ class EventCreationHandler(object):
try:
yield self.auth.check_from_context(room_version, event, context)
except AuthError as err:
- logger.warn("Denying new event %r because %s", event, err)
+ logger.warning("Denying new event %r because %s", event, err)
raise err
# Ensure that we can round trip before trying to persist in db
@@ -868,7 +873,7 @@ class EventCreationHandler(object):
if prev_state_ids:
raise AuthError(403, "Changing the room create event is forbidden")
- (event_stream_id, max_stream_id) = yield self.store.persist_event(
+ event_stream_id, max_stream_id = yield self.storage.persistence.persist_event(
event, context=context
)
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 5744f457..260a4351 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -69,6 +69,8 @@ class PaginationHandler(object):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
self.clock = hs.get_clock()
self._server_name = hs.hostname
@@ -125,7 +127,9 @@ class PaginationHandler(object):
self._purges_in_progress_by_room.add(room_id)
try:
with (yield self.pagination_lock.write(room_id)):
- yield self.store.purge_history(room_id, token, delete_local_events)
+ yield self.storage.purge_events.purge_history(
+ room_id, token, delete_local_events
+ )
logger.info("[purge] complete")
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE
except Exception:
@@ -168,7 +172,7 @@ class PaginationHandler(object):
if joined:
raise SynapseError(400, "Users are still joined to this room")
- await self.store.purge_room(room_id)
+ await self.storage.purge_events.purge_room(room_id)
@defer.inlineCallbacks
def get_messages(
@@ -210,9 +214,10 @@ class PaginationHandler(object):
source_config = pagin_config.get_source_config("room")
with (yield self.pagination_lock.read(room_id)):
- membership, member_event_id = yield self.auth.check_in_room_or_world_readable(
- room_id, user_id
- )
+ (
+ membership,
+ member_event_id,
+ ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
if source_config.direction == "b":
# if we're going backwards, we might need to backfill. This
@@ -255,7 +260,7 @@ class PaginationHandler(object):
events = event_filter.filter(events)
events = yield filter_events_for_client(
- self.store, user_id, events, is_peeking=(member_event_id is None)
+ self.storage, user_id, events, is_peeking=(member_event_id is None)
)
if not events:
@@ -274,7 +279,7 @@ class PaginationHandler(object):
(EventTypes.Member, event.sender) for event in events
)
- state_ids = yield self.store.get_state_ids_for_event(
+ state_ids = yield self.state_store.get_state_ids_for_event(
events[0].event_id, state_filter=state_filter
)
@@ -295,10 +300,8 @@ class PaginationHandler(object):
}
if state:
- chunk["state"] = (
- yield self._event_serializer.serialize_events(
- state, time_now, as_client_event=as_client_event
- )
+ chunk["state"] = yield self._event_serializer.serialize_events(
+ state, time_now, as_client_event=as_client_event
)
return chunk
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 8690f69d..1e5a4613 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -152,7 +152,7 @@ class BaseProfileHandler(BaseHandler):
by_admin (bool): Whether this change was made by an administrator.
"""
if not self.hs.is_mine(target_user):
- raise SynapseError(400, "User is not hosted on this Home Server")
+ raise SynapseError(400, "User is not hosted on this homeserver")
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's displayname")
@@ -207,7 +207,7 @@ class BaseProfileHandler(BaseHandler):
"""target_user is the user whose avatar_url is to be changed;
auth_user is the user attempting to make this change."""
if not self.hs.is_mine(target_user):
- raise SynapseError(400, "User is not hosted on this Home Server")
+ raise SynapseError(400, "User is not hosted on this homeserver")
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's avatar_url")
@@ -231,7 +231,7 @@ class BaseProfileHandler(BaseHandler):
def on_profile_query(self, args):
user = UserID.from_string(args["user_id"])
if not self.hs.is_mine(user):
- raise SynapseError(400, "User is not hosted on this Home Server")
+ raise SynapseError(400, "User is not hosted on this homeserver")
just_field = args.get("field", None)
@@ -275,7 +275,7 @@ class BaseProfileHandler(BaseHandler):
ratelimit=False, # Try to hide that these events aren't atomic.
)
except Exception as e:
- logger.warn(
+ logger.warning(
"Failed to update join event for room %s - %s", room_id, str(e)
)
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index 3e4d8c93..e3b528d2 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.util.async_helpers import Linearizer
from ._base import BaseHandler
@@ -32,8 +30,7 @@ class ReadMarkerHandler(BaseHandler):
self.read_marker_linearizer = Linearizer(name="read_marker")
self.notifier = hs.get_notifier()
- @defer.inlineCallbacks
- def received_client_read_marker(self, room_id, user_id, event_id):
+ async def received_client_read_marker(self, room_id, user_id, event_id):
"""Updates the read marker for a given user in a given room if the event ID given
is ahead in the stream relative to the current read marker.
@@ -41,8 +38,8 @@ class ReadMarkerHandler(BaseHandler):
the read marker has changed.
"""
- with (yield self.read_marker_linearizer.queue((room_id, user_id))):
- existing_read_marker = yield self.store.get_account_data_for_room_and_type(
+ with await self.read_marker_linearizer.queue((room_id, user_id)):
+ existing_read_marker = await self.store.get_account_data_for_room_and_type(
user_id, room_id, "m.fully_read"
)
@@ -50,13 +47,13 @@ class ReadMarkerHandler(BaseHandler):
if existing_read_marker:
# Only update if the new marker is ahead in the stream
- should_update = yield self.store.is_event_after(
+ should_update = await self.store.is_event_after(
event_id, existing_read_marker["event_id"]
)
if should_update:
content = {"event_id": event_id}
- max_id = yield self.store.add_account_data_to_room(
+ max_id = await self.store.add_account_data_to_room(
user_id, room_id, "m.fully_read", content
)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 6854c751..9283c039 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.handlers._base import BaseHandler
from synapse.types import ReadReceipt, get_domain_from_id
+from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@@ -36,8 +37,7 @@ class ReceiptsHandler(BaseHandler):
self.clock = self.hs.get_clock()
self.state = hs.get_state_handler()
- @defer.inlineCallbacks
- def _received_remote_receipt(self, origin, content):
+ async def _received_remote_receipt(self, origin, content):
"""Called when we receive an EDU of type m.receipt from a remote HS.
"""
receipts = []
@@ -62,17 +62,16 @@ class ReceiptsHandler(BaseHandler):
)
)
- yield self._handle_new_receipts(receipts)
+ await self._handle_new_receipts(receipts)
- @defer.inlineCallbacks
- def _handle_new_receipts(self, receipts):
+ async def _handle_new_receipts(self, receipts):
"""Takes a list of receipts, stores them and informs the notifier.
"""
min_batch_id = None
max_batch_id = None
for receipt in receipts:
- res = yield self.store.insert_receipt(
+ res = await self.store.insert_receipt(
receipt.room_id,
receipt.receipt_type,
receipt.user_id,
@@ -99,14 +98,15 @@ class ReceiptsHandler(BaseHandler):
self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
# Note that the min here shouldn't be relied upon to be accurate.
- yield self.hs.get_pusherpool().on_new_receipts(
- min_batch_id, max_batch_id, affected_room_ids
+ await maybe_awaitable(
+ self.hs.get_pusherpool().on_new_receipts(
+ min_batch_id, max_batch_id, affected_room_ids
+ )
)
return True
- @defer.inlineCallbacks
- def received_client_receipt(self, room_id, receipt_type, user_id, event_id):
+ async def received_client_receipt(self, room_id, receipt_type, user_id, event_id):
"""Called when a client tells us a local user has read up to the given
event_id in the room.
"""
@@ -118,24 +118,11 @@ class ReceiptsHandler(BaseHandler):
data={"ts": int(self.clock.time_msec())},
)
- is_new = yield self._handle_new_receipts([receipt])
+ is_new = await self._handle_new_receipts([receipt])
if not is_new:
return
- yield self.federation.send_read_receipt(receipt)
-
- @defer.inlineCallbacks
- def get_receipts_for_room(self, room_id, to_key):
- """Gets all receipts for a room, upto the given key.
- """
- result = yield self.store.get_linearized_receipts_for_room(
- room_id, to_key=to_key
- )
-
- if not result:
- return []
-
- return result
+ await self.federation.send_read_receipt(receipt)
class ReceiptEventSource(object):
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 53410f12..95806af4 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -24,7 +24,6 @@ from synapse.api.errors import (
AuthError,
Codes,
ConsentNotGivenError,
- LimitExceededError,
RegistrationError,
SynapseError,
)
@@ -168,6 +167,7 @@ class RegistrationHandler(BaseHandler):
Raises:
RegistrationError if there was a problem registering.
"""
+ yield self.check_registration_ratelimit(address)
yield self.auth.check_auth_blocking(threepid=threepid)
password_hash = None
@@ -217,8 +217,13 @@ class RegistrationHandler(BaseHandler):
else:
# autogen a sequential user ID
+ fail_count = 0
user = None
while not user:
+ # Fail after being unable to find a suitable ID a few times
+ if fail_count > 10:
+ raise SynapseError(500, "Unable to find a suitable guest user ID")
+
localpart = yield self._generate_user_id()
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
@@ -233,10 +238,14 @@ class RegistrationHandler(BaseHandler):
create_profile_with_displayname=default_display_name,
address=address,
)
+
+ # Successfully registered
+ break
except SynapseError:
# if user id is taken, just generate another
user = None
user_id = None
+ fail_count += 1
if not self.hs.config.user_consent_at_registration:
yield self._auto_join_rooms(user_id)
@@ -396,8 +405,8 @@ class RegistrationHandler(BaseHandler):
room_id = room_identifier
elif RoomAlias.is_valid(room_identifier):
room_alias = RoomAlias.from_string(room_identifier)
- room_id, remote_room_hosts = (
- yield room_member_handler.lookup_room_alias(room_alias)
+ room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias(
+ room_alias
)
room_id = room_id.to_string()
else:
@@ -414,6 +423,29 @@ class RegistrationHandler(BaseHandler):
ratelimit=False,
)
+ def check_registration_ratelimit(self, address):
+ """A simple helper method to check whether the registration rate limit has been hit
+ for a given IP address
+
+ Args:
+ address (str|None): the IP address used to perform the registration. If this is
+ None, no ratelimiting will be performed.
+
+ Raises:
+ LimitExceededError: If the rate limit has been exceeded.
+ """
+ if not address:
+ return
+
+ time_now = self.clock.time()
+
+ self.ratelimiter.ratelimit(
+ address,
+ time_now_s=time_now,
+ rate_hz=self.hs.config.rc_registration.per_second,
+ burst_count=self.hs.config.rc_registration.burst_count,
+ )
+
def register_with_store(
self,
user_id,
@@ -446,22 +478,6 @@ class RegistrationHandler(BaseHandler):
Returns:
Deferred
"""
- # Don't rate limit for app services
- if appservice_id is None and address is not None:
- time_now = self.clock.time()
-
- allowed, time_allowed = self.ratelimiter.can_do_action(
- address,
- time_now_s=time_now,
- rate_hz=self.hs.config.rc_registration.per_second,
- burst_count=self.hs.config.rc_registration.burst_count,
- )
-
- if not allowed:
- raise LimitExceededError(
- retry_after_ms=int(1000 * (time_allowed - time_now))
- )
-
if self.hs.config.worker_app:
return self._register_client(
user_id=user_id,
@@ -614,7 +630,7 @@ class RegistrationHandler(BaseHandler):
# And we add an email pusher for them by default, but only
# if email notifications are enabled (so people don't start
# getting mail spam where they weren't before if email
- # notifs are set up on a home server)
+ # notifs are set up on a homeserver)
if (
self.hs.config.email_enable_notifs
and self.hs.config.email_notif_for_new_users
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 2816bd8f..e92b2eaf 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -129,6 +129,7 @@ class RoomCreationHandler(BaseHandler):
old_room_id,
new_version, # args for _upgrade_room
)
+
return ret
@defer.inlineCallbacks
@@ -147,21 +148,22 @@ class RoomCreationHandler(BaseHandler):
# we create and auth the tombstone event before properly creating the new
# room, to check our user has perms in the old room.
- tombstone_event, tombstone_context = (
- yield self.event_creation_handler.create_event(
- requester,
- {
- "type": EventTypes.Tombstone,
- "state_key": "",
- "room_id": old_room_id,
- "sender": user_id,
- "content": {
- "body": "This room has been replaced",
- "replacement_room": new_room_id,
- },
+ (
+ tombstone_event,
+ tombstone_context,
+ ) = yield self.event_creation_handler.create_event(
+ requester,
+ {
+ "type": EventTypes.Tombstone,
+ "state_key": "",
+ "room_id": old_room_id,
+ "sender": user_id,
+ "content": {
+ "body": "This room has been replaced",
+ "replacement_room": new_room_id,
},
- token_id=requester.access_token_id,
- )
+ },
+ token_id=requester.access_token_id,
)
old_room_version = yield self.store.get_room_version(old_room_id)
yield self.auth.check_from_context(
@@ -188,7 +190,12 @@ class RoomCreationHandler(BaseHandler):
requester, old_room_id, new_room_id, old_room_state
)
- # and finally, shut down the PLs in the old room, and update them in the new
+ # Copy over user push rules, tags and migrate room directory state
+ yield self.room_member_handler.transfer_room_state_on_room_upgrade(
+ old_room_id, new_room_id
+ )
+
+ # finally, shut down the PLs in the old room, and update them in the new
# room.
yield self._update_upgraded_room_pls(
requester, old_room_id, new_room_id, old_room_state
@@ -822,6 +829,8 @@ class RoomContextHandler(object):
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
@defer.inlineCallbacks
def get_event_context(self, user, room_id, event_id, limit, event_filter):
@@ -848,7 +857,7 @@ class RoomContextHandler(object):
def filter_evts(events):
return filter_events_for_client(
- self.store, user.to_string(), events, is_peeking=is_peeking
+ self.storage, user.to_string(), events, is_peeking=is_peeking
)
event = yield self.store.get_event(
@@ -890,7 +899,7 @@ class RoomContextHandler(object):
# first? Shouldn't we be consistent with /sync?
# https://github.com/matrix-org/matrix-doc/issues/687
- state = yield self.store.get_state_for_events(
+ state = yield self.state_store.get_state_for_events(
[last_event_id], state_filter=state_filter
)
results["state"] = list(state[last_event_id].values())
@@ -922,7 +931,7 @@ class RoomEventSource(object):
from_token = RoomStreamToken.parse(from_key)
if from_token.topological:
- logger.warn("Stream has topological part!!!! %r", from_key)
+ logger.warning("Stream has topological part!!!! %r", from_key)
from_key = "s%s" % (from_token.stream,)
app_service = self.store.get_app_service_by_user_id(user.to_string())
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 380e2fad..6cfee4b3 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -203,10 +203,6 @@ class RoomMemberHandler(object):
prev_member_event = yield self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
- # Copy over user state if we're joining an upgraded room
- yield self.copy_user_state_if_room_upgrade(
- room_id, requester.user.to_string()
- )
yield self._user_joined_room(target, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
@@ -455,11 +451,6 @@ class RoomMemberHandler(object):
requester, remote_room_hosts, room_id, target, content
)
- # Copy over user state if this is a join on an remote upgraded room
- yield self.copy_user_state_if_room_upgrade(
- room_id, requester.user.to_string()
- )
-
return remote_join_response
elif effective_membership_state == Membership.LEAVE:
@@ -498,36 +489,81 @@ class RoomMemberHandler(object):
return res
@defer.inlineCallbacks
- def copy_user_state_if_room_upgrade(self, new_room_id, user_id):
- """Copy user-specific information when they join a new room if that new room is the
+ def transfer_room_state_on_room_upgrade(self, old_room_id, room_id):
+ """Upon our server becoming aware of an upgraded room, either by upgrading a room
+ ourselves or joining one, we can transfer over information from the previous room.
+
+ Copies user state (tags/push rules) for every local user that was in the old room, as
+ well as migrating the room directory state.
+
+ Args:
+ old_room_id (str): The ID of the old room
+
+ room_id (str): The ID of the new room
+
+ Returns:
+ Deferred
+ """
+ # Find all local users that were in the old room and copy over each user's state
+ users = yield self.store.get_users_in_room(old_room_id)
+ yield self.copy_user_state_on_room_upgrade(old_room_id, room_id, users)
+
+ # Add new room to the room directory if the old room was there
+ # Remove old room from the room directory
+ old_room = yield self.store.get_room(old_room_id)
+ if old_room and old_room["is_public"]:
+ yield self.store.set_room_is_public(old_room_id, False)
+ yield self.store.set_room_is_public(room_id, True)
+
+ # Check if any groups we own contain the predecessor room
+ local_group_ids = yield self.store.get_local_groups_for_room(old_room_id)
+ for group_id in local_group_ids:
+ # Add new the new room to those groups
+ yield self.store.add_room_to_group(group_id, room_id, old_room["is_public"])
+
+ # Remove the old room from those groups
+ yield self.store.remove_room_from_group(group_id, old_room_id)
+
+ @defer.inlineCallbacks
+ def copy_user_state_on_room_upgrade(self, old_room_id, new_room_id, user_ids):
+ """Copy user-specific information when they join a new room when that new room is the
result of a room upgrade
Args:
- new_room_id (str): The ID of the room the user is joining
- user_id (str): The ID of the user
+ old_room_id (str): The ID of upgraded room
+ new_room_id (str): The ID of the new room
+ user_ids (Iterable[str]): User IDs to copy state for
Returns:
Deferred
"""
- # Check if the new room is an upgraded room
- predecessor = yield self.store.get_room_predecessor(new_room_id)
- if not predecessor:
- return
logger.debug(
- "Found predecessor for %s: %s. Copying over room tags and push " "rules",
+ "Copying over room tags and push rules from %s to %s for users %s",
+ old_room_id,
new_room_id,
- predecessor,
+ user_ids,
)
- # It is an upgraded room. Copy over old tags
- yield self.copy_room_tags_and_direct_to_room(
- predecessor["room_id"], new_room_id, user_id
- )
- # Copy over push rules
- yield self.store.copy_push_rules_from_room_to_room_for_user(
- predecessor["room_id"], new_room_id, user_id
- )
+ for user_id in user_ids:
+ try:
+ # It is an upgraded room. Copy over old tags
+ yield self.copy_room_tags_and_direct_to_room(
+ old_room_id, new_room_id, user_id
+ )
+ # Copy over push rules
+ yield self.store.copy_push_rules_from_room_to_room_for_user(
+ old_room_id, new_room_id, user_id
+ )
+ except Exception:
+ logger.exception(
+ "Error copying tags and/or push rules from rooms %s to %s for user %s. "
+ "Skipping...",
+ old_room_id,
+ new_room_id,
+ user_id,
+ )
+ continue
@defer.inlineCallbacks
def send_membership_event(self, requester, event, context, ratelimit=True):
@@ -759,22 +795,25 @@ class RoomMemberHandler(object):
if room_avatar_event:
room_avatar_url = room_avatar_event.content.get("url", "")
- token, public_keys, fallback_public_key, display_name = (
- yield self.identity_handler.ask_id_server_for_third_party_invite(
- requester=requester,
- id_server=id_server,
- medium=medium,
- address=address,
- room_id=room_id,
- inviter_user_id=user.to_string(),
- room_alias=canonical_room_alias,
- room_avatar_url=room_avatar_url,
- room_join_rules=room_join_rules,
- room_name=room_name,
- inviter_display_name=inviter_display_name,
- inviter_avatar_url=inviter_avatar_url,
- id_access_token=id_access_token,
- )
+ (
+ token,
+ public_keys,
+ fallback_public_key,
+ display_name,
+ ) = yield self.identity_handler.ask_id_server_for_third_party_invite(
+ requester=requester,
+ id_server=id_server,
+ medium=medium,
+ address=address,
+ room_id=room_id,
+ inviter_user_id=user.to_string(),
+ room_alias=canonical_room_alias,
+ room_avatar_url=room_avatar_url,
+ room_join_rules=room_join_rules,
+ room_name=room_name,
+ inviter_display_name=inviter_display_name,
+ inviter_avatar_url=inviter_avatar_url,
+ id_access_token=id_access_token,
)
yield self.event_creation_handler.create_and_send_nonmember_event(
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index cd5e90ba..56ed262a 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -35,6 +35,8 @@ class SearchHandler(BaseHandler):
def __init__(self, hs):
super(SearchHandler, self).__init__(hs)
self._event_serializer = hs.get_event_client_serializer()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
@defer.inlineCallbacks
def get_old_rooms_from_upgraded_room(self, room_id):
@@ -221,7 +223,7 @@ class SearchHandler(BaseHandler):
filtered_events = search_filter.filter([r["event"] for r in results])
events = yield filter_events_for_client(
- self.store, user.to_string(), filtered_events
+ self.storage, user.to_string(), filtered_events
)
events.sort(key=lambda e: -rank_map[e.event_id])
@@ -271,7 +273,7 @@ class SearchHandler(BaseHandler):
filtered_events = search_filter.filter([r["event"] for r in results])
events = yield filter_events_for_client(
- self.store, user.to_string(), filtered_events
+ self.storage, user.to_string(), filtered_events
)
room_events.extend(events)
@@ -340,11 +342,11 @@ class SearchHandler(BaseHandler):
)
res["events_before"] = yield filter_events_for_client(
- self.store, user.to_string(), res["events_before"]
+ self.storage, user.to_string(), res["events_before"]
)
res["events_after"] = yield filter_events_for_client(
- self.store, user.to_string(), res["events_after"]
+ self.storage, user.to_string(), res["events_after"]
)
res["start"] = now_token.copy_and_replace(
@@ -372,7 +374,7 @@ class SearchHandler(BaseHandler):
[(EventTypes.Member, sender) for sender in senders]
)
- state = yield self.store.get_state_for_event(
+ state = yield self.state_store.get_state_for_event(
last_event_id, state_filter
)
@@ -394,15 +396,11 @@ class SearchHandler(BaseHandler):
time_now = self.clock.time_msec()
for context in contexts.values():
- context["events_before"] = (
- yield self._event_serializer.serialize_events(
- context["events_before"], time_now
- )
+ context["events_before"] = yield self._event_serializer.serialize_events(
+ context["events_before"], time_now
)
- context["events_after"] = (
- yield self._event_serializer.serialize_events(
- context["events_after"], time_now
- )
+ context["events_after"] = yield self._event_serializer.serialize_events(
+ context["events_after"], time_now
)
state_results = {}
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 26bc2766..7f7d5639 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -108,7 +108,10 @@ class StatsHandler(StateDeltasHandler):
user_deltas = {}
# Then count deltas for total_events and total_event_bytes.
- room_count, user_count = yield self.store.get_changes_room_total_events_and_bytes(
+ (
+ room_count,
+ user_count,
+ ) = yield self.store.get_changes_room_total_events_and_bytes(
self.pos, max_pos
)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index d99160e9..b536d410 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -230,6 +230,8 @@ class SyncHandler(object):
self.response_cache = ResponseCache(hs, "sync")
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
# ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
self.lazy_loaded_members_cache = ExpiringCache(
@@ -417,7 +419,7 @@ class SyncHandler(object):
current_state_ids = frozenset(itervalues(current_state_ids))
recents = yield filter_events_for_client(
- self.store,
+ self.storage,
sync_config.user.to_string(),
recents,
always_include_ids=current_state_ids,
@@ -470,7 +472,7 @@ class SyncHandler(object):
current_state_ids = frozenset(itervalues(current_state_ids))
loaded_recents = yield filter_events_for_client(
- self.store,
+ self.storage,
sync_config.user.to_string(),
loaded_recents,
always_include_ids=current_state_ids,
@@ -509,7 +511,7 @@ class SyncHandler(object):
Returns:
A Deferred map from ((type, state_key)->Event)
"""
- state_ids = yield self.store.get_state_ids_for_event(
+ state_ids = yield self.state_store.get_state_ids_for_event(
event.event_id, state_filter=state_filter
)
if event.is_state():
@@ -580,7 +582,7 @@ class SyncHandler(object):
return None
last_event = last_events[-1]
- state_ids = yield self.store.get_state_ids_for_event(
+ state_ids = yield self.state_store.get_state_ids_for_event(
last_event.event_id,
state_filter=StateFilter.from_types(
[(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
@@ -757,11 +759,11 @@ class SyncHandler(object):
if full_state:
if batch:
- current_state_ids = yield self.store.get_state_ids_for_event(
+ current_state_ids = yield self.state_store.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter
)
- state_ids = yield self.store.get_state_ids_for_event(
+ state_ids = yield self.state_store.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter
)
@@ -781,7 +783,7 @@ class SyncHandler(object):
)
elif batch.limited:
if batch:
- state_at_timeline_start = yield self.store.get_state_ids_for_event(
+ state_at_timeline_start = yield self.state_store.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter
)
else:
@@ -810,7 +812,7 @@ class SyncHandler(object):
)
if batch:
- current_state_ids = yield self.store.get_state_ids_for_event(
+ current_state_ids = yield self.state_store.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter
)
else:
@@ -841,7 +843,7 @@ class SyncHandler(object):
# So we fish out all the member events corresponding to the
# timeline here, and then dedupe any redundant ones below.
- state_ids = yield self.store.get_state_ids_for_event(
+ state_ids = yield self.state_store.get_state_ids_for_event(
batch.events[0].event_id,
# we only want members!
state_filter=StateFilter.from_types(
@@ -1204,10 +1206,11 @@ class SyncHandler(object):
since_token = sync_result_builder.since_token
if since_token and not sync_result_builder.full_state:
- account_data, account_data_by_room = (
- yield self.store.get_updated_account_data_for_user(
- user_id, since_token.account_data_key
- )
+ (
+ account_data,
+ account_data_by_room,
+ ) = yield self.store.get_updated_account_data_for_user(
+ user_id, since_token.account_data_key
)
push_rules_changed = yield self.store.have_push_rules_changed_for_user(
@@ -1219,9 +1222,10 @@ class SyncHandler(object):
sync_config.user
)
else:
- account_data, account_data_by_room = (
- yield self.store.get_account_data_for_user(sync_config.user.to_string())
- )
+ (
+ account_data,
+ account_data_by_room,
+ ) = yield self.store.get_account_data_for_user(sync_config.user.to_string())
account_data["m.push_rules"] = yield self.push_rules_for_user(
sync_config.user
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index ca8ae9fb..856337b7 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -120,7 +120,7 @@ class TypingHandler(object):
auth_user_id = auth_user.to_string()
if not self.is_mine_id(target_user_id):
- raise SynapseError(400, "User is not hosted on this Home Server")
+ raise SynapseError(400, "User is not hosted on this homeserver")
if target_user_id != auth_user_id:
raise AuthError(400, "Cannot set another user's typing state")
@@ -150,7 +150,7 @@ class TypingHandler(object):
auth_user_id = auth_user.to_string()
if not self.is_mine_id(target_user_id):
- raise SynapseError(400, "User is not hosted on this Home Server")
+ raise SynapseError(400, "User is not hosted on this homeserver")
if target_user_id != auth_user_id:
raise AuthError(400, "Cannot set another user's typing state")
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 29aa1e5a..8363d887 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -81,7 +81,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
def __init__(self, hs):
super().__init__(hs)
self._enabled = bool(hs.config.recaptcha_private_key)
- self._http_client = hs.get_simple_http_client()
+ self._http_client = hs.get_proxied_http_client()
self._url = hs.config.recaptcha_siteverify_api
self._secret = hs.config.recaptcha_private_key
diff --git a/synapse/http/client.py b/synapse/http/client.py
index cdf828a4..d4c28544 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -45,6 +45,7 @@ from synapse.http import (
cancelled_to_request_timed_out_error,
redact_uri,
)
+from synapse.http.proxyagent import ProxyAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.util.async_helpers import timeout_deferred
@@ -183,7 +184,15 @@ class SimpleHttpClient(object):
using HTTP in Matrix
"""
- def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None):
+ def __init__(
+ self,
+ hs,
+ treq_args={},
+ ip_whitelist=None,
+ ip_blacklist=None,
+ http_proxy=None,
+ https_proxy=None,
+ ):
"""
Args:
hs (synapse.server.HomeServer)
@@ -192,6 +201,8 @@ class SimpleHttpClient(object):
we may not request.
ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
request if it were otherwise caught in a blacklist.
+ http_proxy (bytes): proxy server to use for http connections. host[:port]
+ https_proxy (bytes): proxy server to use for https connections. host[:port]
"""
self.hs = hs
@@ -236,11 +247,13 @@ class SimpleHttpClient(object):
# The default context factory in Twisted 14.0.0 (which we require) is
# BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser'
- self.agent = Agent(
+ self.agent = ProxyAgent(
self.reactor,
connectTimeout=15,
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
+ http_proxy=http_proxy,
+ https_proxy=https_proxy,
)
if self._ip_blacklist:
@@ -535,7 +548,7 @@ class SimpleHttpClient(object):
b"Content-Length" in resp_headers
and int(resp_headers[b"Content-Length"][0]) > max_size
):
- logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
+ logger.warning("Requested URL is too large > %r bytes" % (self.max_size,))
raise SynapseError(
502,
"Requested file is too large > %r bytes" % (self.max_size,),
@@ -543,7 +556,7 @@ class SimpleHttpClient(object):
)
if response.code > 299:
- logger.warn("Got %d when downloading %s" % (response.code, url))
+ logger.warning("Got %d when downloading %s" % (response.code, url))
raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN)
# TODO: if our Content-Type is HTML or something, just read the first
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
new file mode 100644
index 00000000..be7b2ceb
--- /dev/null
+++ b/synapse/http/connectproxyclient.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from zope.interface import implementer
+
+from twisted.internet import defer, protocol
+from twisted.internet.error import ConnectError
+from twisted.internet.interfaces import IStreamClientEndpoint
+from twisted.internet.protocol import connectionDone
+from twisted.web import http
+
+logger = logging.getLogger(__name__)
+
+
+class ProxyConnectError(ConnectError):
+ pass
+
+
+@implementer(IStreamClientEndpoint)
+class HTTPConnectProxyEndpoint(object):
+ """An Endpoint implementation which will send a CONNECT request to an http proxy
+
+ Wraps an existing HostnameEndpoint for the proxy.
+
+ When we get the connect() request from the connection pool (via the TLS wrapper),
+ we'll first connect to the proxy endpoint with a ProtocolFactory which will make the
+ CONNECT request. Once that completes, we invoke the protocolFactory which was passed
+ in.
+
+ Args:
+ reactor: the Twisted reactor to use for the connection
+ proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the
+ proxy
+ host (bytes): hostname that we want to CONNECT to
+ port (int): port that we want to connect to
+ """
+
+ def __init__(self, reactor, proxy_endpoint, host, port):
+ self._reactor = reactor
+ self._proxy_endpoint = proxy_endpoint
+ self._host = host
+ self._port = port
+
+ def __repr__(self):
+ return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
+
+ def connect(self, protocolFactory):
+ f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory)
+ d = self._proxy_endpoint.connect(f)
+ # once the tcp socket connects successfully, we need to wait for the
+ # CONNECT to complete.
+ d.addCallback(lambda conn: f.on_connection)
+ return d
+
+
+class HTTPProxiedClientFactory(protocol.ClientFactory):
+ """ClientFactory wrapper that triggers an HTTP proxy CONNECT on connect.
+
+ Once the CONNECT completes, invokes the original ClientFactory to build the
+ HTTP Protocol object and run the rest of the connection.
+
+ Args:
+ dst_host (bytes): hostname that we want to CONNECT to
+ dst_port (int): port that we want to connect to
+ wrapped_factory (protocol.ClientFactory): The original Factory
+ """
+
+ def __init__(self, dst_host, dst_port, wrapped_factory):
+ self.dst_host = dst_host
+ self.dst_port = dst_port
+ self.wrapped_factory = wrapped_factory
+ self.on_connection = defer.Deferred()
+
+ def startedConnecting(self, connector):
+ return self.wrapped_factory.startedConnecting(connector)
+
+ def buildProtocol(self, addr):
+ wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
+
+ return HTTPConnectProtocol(
+ self.dst_host, self.dst_port, wrapped_protocol, self.on_connection
+ )
+
+ def clientConnectionFailed(self, connector, reason):
+ logger.debug("Connection to proxy failed: %s", reason)
+ if not self.on_connection.called:
+ self.on_connection.errback(reason)
+ return self.wrapped_factory.clientConnectionFailed(connector, reason)
+
+ def clientConnectionLost(self, connector, reason):
+ logger.debug("Connection to proxy lost: %s", reason)
+ if not self.on_connection.called:
+ self.on_connection.errback(reason)
+ return self.wrapped_factory.clientConnectionLost(connector, reason)
+
+
+class HTTPConnectProtocol(protocol.Protocol):
+ """Protocol that wraps an existing Protocol to do a CONNECT handshake at connect
+
+ Args:
+ host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal
+ to put in the CONNECT request
+
+ port (int): The original HTTP(s) port to put in the CONNECT request
+
+ wrapped_protocol (interfaces.IProtocol): the original protocol (probably
+ HTTPChannel or TLSMemoryBIOProtocol, but could be anything really)
+
+ connected_deferred (Deferred): a Deferred which will be callbacked with
+ wrapped_protocol when the CONNECT completes
+ """
+
+ def __init__(self, host, port, wrapped_protocol, connected_deferred):
+ self.host = host
+ self.port = port
+ self.wrapped_protocol = wrapped_protocol
+ self.connected_deferred = connected_deferred
+ self.http_setup_client = HTTPConnectSetupClient(self.host, self.port)
+ self.http_setup_client.on_connected.addCallback(self.proxyConnected)
+
+ def connectionMade(self):
+ self.http_setup_client.makeConnection(self.transport)
+
+ def connectionLost(self, reason=connectionDone):
+ if self.wrapped_protocol.connected:
+ self.wrapped_protocol.connectionLost(reason)
+
+ self.http_setup_client.connectionLost(reason)
+
+ if not self.connected_deferred.called:
+ self.connected_deferred.errback(reason)
+
+ def proxyConnected(self, _):
+ self.wrapped_protocol.makeConnection(self.transport)
+
+ self.connected_deferred.callback(self.wrapped_protocol)
+
+ # Get any pending data from the http buf and forward it to the original protocol
+ buf = self.http_setup_client.clearLineBuffer()
+ if buf:
+ self.wrapped_protocol.dataReceived(buf)
+
+ def dataReceived(self, data):
+ # if we've set up the HTTP protocol, we can send the data there
+ if self.wrapped_protocol.connected:
+ return self.wrapped_protocol.dataReceived(data)
+
+ # otherwise, we must still be setting up the connection: send the data to the
+ # setup client
+ return self.http_setup_client.dataReceived(data)
+
+
+class HTTPConnectSetupClient(http.HTTPClient):
+ """HTTPClient protocol to send a CONNECT message for proxies and read the response.
+
+ Args:
+ host (bytes): The hostname to send in the CONNECT message
+ port (int): The port to send in the CONNECT message
+ """
+
+ def __init__(self, host, port):
+ self.host = host
+ self.port = port
+ self.on_connected = defer.Deferred()
+
+ def connectionMade(self):
+ logger.debug("Connected to proxy, sending CONNECT")
+ self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
+ self.endHeaders()
+
+ def handleStatus(self, version, status, message):
+ logger.debug("Got Status: %s %s %s", status, message, version)
+ if status != b"200":
+ raise ProxyConnectError("Unexpected status on CONNECT: %s" % status)
+
+ def handleEndHeaders(self):
+ logger.debug("End Headers")
+ self.on_connected.callback(None)
+
+ def handleResponse(self, body):
+ pass
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index 3fe4ffb9..021b233a 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -148,7 +148,7 @@ class SrvResolver(object):
# Try something in the cache, else rereaise
cache_entry = self._cache.get(service_name, None)
if cache_entry:
- logger.warn(
+ logger.warning(
"Failed to resolve %r, falling back to cache. %r", service_name, e
)
return list(cache_entry)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 3f7c93ff..16765d54 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -149,7 +149,7 @@ def _handle_json_response(reactor, timeout_sec, request, response):
body = yield make_deferred_yieldable(d)
except Exception as e:
- logger.warn(
+ logger.warning(
"{%s} [%s] Error reading response: %s",
request.txn_id,
request.destination,
@@ -457,7 +457,7 @@ class MatrixFederationHttpClient(object):
except Exception as e:
# Eh, we're already going to raise an exception so lets
# ignore if this fails.
- logger.warn(
+ logger.warning(
"{%s} [%s] Failed to get error response: %s %s: %s",
request.txn_id,
request.destination,
@@ -478,7 +478,7 @@ class MatrixFederationHttpClient(object):
break
except RequestSendFailed as e:
- logger.warn(
+ logger.warning(
"{%s} [%s] Request failed: %s %s: %s",
request.txn_id,
request.destination,
@@ -513,7 +513,7 @@ class MatrixFederationHttpClient(object):
raise
except Exception as e:
- logger.warn(
+ logger.warning(
"{%s} [%s] Request failed: %s %s: %s",
request.txn_id,
request.destination,
@@ -530,7 +530,7 @@ class MatrixFederationHttpClient(object):
"""
Builds the Authorization headers for a federation request
Args:
- destination (bytes|None): The desination home server of the request.
+ destination (bytes|None): The desination homeserver of the request.
May be None if the destination is an identity server, in which case
destination_is must be non-None.
method (bytes): The HTTP method of the request
@@ -889,7 +889,7 @@ class MatrixFederationHttpClient(object):
d.addTimeout(self.default_timeout, self.reactor)
length = yield make_deferred_yieldable(d)
except Exception as e:
- logger.warn(
+ logger.warning(
"{%s} [%s] Error reading response: %s",
request.txn_id,
request.destination,
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
new file mode 100644
index 00000000..332da02a
--- /dev/null
+++ b/synapse/http/proxyagent.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import re
+
+from zope.interface import implementer
+
+from twisted.internet import defer
+from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
+from twisted.python.failure import Failure
+from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
+from twisted.web.error import SchemeNotSupported
+from twisted.web.iweb import IAgent
+
+from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
+
+logger = logging.getLogger(__name__)
+
+_VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z")
+
+
+@implementer(IAgent)
+class ProxyAgent(_AgentBase):
+ """An Agent implementation which will use an HTTP proxy if one was requested
+
+ Args:
+ reactor: twisted reactor to place outgoing
+ connections.
+
+ contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the
+ verification parameters of OpenSSL. The default is to use a
+ `BrowserLikePolicyForHTTPS`, so unless you have special
+ requirements you can leave this as-is.
+
+ connectTimeout (float): The amount of time that this Agent will wait
+ for the peer to accept a connection.
+
+ bindAddress (bytes): The local address for client sockets to bind to.
+
+ pool (HTTPConnectionPool|None): connection pool to be used. If None, a
+ non-persistent pool instance will be created.
+ """
+
+ def __init__(
+ self,
+ reactor,
+ contextFactory=BrowserLikePolicyForHTTPS(),
+ connectTimeout=None,
+ bindAddress=None,
+ pool=None,
+ http_proxy=None,
+ https_proxy=None,
+ ):
+ _AgentBase.__init__(self, reactor, pool)
+
+ self._endpoint_kwargs = {}
+ if connectTimeout is not None:
+ self._endpoint_kwargs["timeout"] = connectTimeout
+ if bindAddress is not None:
+ self._endpoint_kwargs["bindAddress"] = bindAddress
+
+ self.http_proxy_endpoint = _http_proxy_endpoint(
+ http_proxy, reactor, **self._endpoint_kwargs
+ )
+
+ self.https_proxy_endpoint = _http_proxy_endpoint(
+ https_proxy, reactor, **self._endpoint_kwargs
+ )
+
+ self._policy_for_https = contextFactory
+ self._reactor = reactor
+
+ def request(self, method, uri, headers=None, bodyProducer=None):
+ """
+ Issue a request to the server indicated by the given uri.
+
+ Supports `http` and `https` schemes.
+
+ An existing connection from the connection pool may be used or a new one may be
+ created.
+
+ See also: twisted.web.iweb.IAgent.request
+
+ Args:
+ method (bytes): The request method to use, such as `GET`, `POST`, etc
+
+ uri (bytes): The location of the resource to request.
+
+ headers (Headers|None): Extra headers to send with the request
+
+ bodyProducer (IBodyProducer|None): An object which can generate bytes to
+ make up the body of this request (for example, the properly encoded
+ contents of a file for a file upload). Or, None if the request is to
+ have no body.
+
+ Returns:
+ Deferred[IResponse]: completes when the header of the response has
+ been received (regardless of the response status code).
+ """
+ uri = uri.strip()
+ if not _VALID_URI.match(uri):
+ raise ValueError("Invalid URI {!r}".format(uri))
+
+ parsed_uri = URI.fromBytes(uri)
+ pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
+ request_path = parsed_uri.originForm
+
+ if parsed_uri.scheme == b"http" and self.http_proxy_endpoint:
+ # Cache *all* connections under the same key, since we are only
+ # connecting to a single destination, the proxy:
+ pool_key = ("http-proxy", self.http_proxy_endpoint)
+ endpoint = self.http_proxy_endpoint
+ request_path = uri
+ elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint:
+ endpoint = HTTPConnectProxyEndpoint(
+ self._reactor,
+ self.https_proxy_endpoint,
+ parsed_uri.host,
+ parsed_uri.port,
+ )
+ else:
+ # not using a proxy
+ endpoint = HostnameEndpoint(
+ self._reactor, parsed_uri.host, parsed_uri.port, **self._endpoint_kwargs
+ )
+
+ logger.debug("Requesting %s via %s", uri, endpoint)
+
+ if parsed_uri.scheme == b"https":
+ tls_connection_creator = self._policy_for_https.creatorForNetloc(
+ parsed_uri.host, parsed_uri.port
+ )
+ endpoint = wrapClientTLS(tls_connection_creator, endpoint)
+ elif parsed_uri.scheme == b"http":
+ pass
+ else:
+ return defer.fail(
+ Failure(
+ SchemeNotSupported("Unsupported scheme: %r" % (parsed_uri.scheme,))
+ )
+ )
+
+ return self._requestWithEndpoint(
+ pool_key, endpoint, method, parsed_uri, headers, bodyProducer, request_path
+ )
+
+
+def _http_proxy_endpoint(proxy, reactor, **kwargs):
+ """Parses an http proxy setting and returns an endpoint for the proxy
+
+ Args:
+ proxy (bytes|None): the proxy setting
+ reactor: reactor to be used to connect to the proxy
+ kwargs: other args to be passed to HostnameEndpoint
+
+ Returns:
+ interfaces.IStreamClientEndpoint|None: endpoint to use to connect to the proxy,
+ or None
+ """
+ if proxy is None:
+ return None
+
+ # currently we only support hostname:port. Some apps also support
+ # protocol://<host>[:port], which allows a way of requiring a TLS connection to the
+ # proxy.
+
+ host, port = parse_host_port(proxy, default_port=1080)
+ return HostnameEndpoint(reactor, host, port, **kwargs)
+
+
+def parse_host_port(hostport, default_port=None):
+ # could have sworn we had one of these somewhere else...
+ if b":" in hostport:
+ host, port = hostport.rsplit(b":", 1)
+ try:
+ port = int(port)
+ return host, port
+ except ValueError:
+ # the thing after the : wasn't a valid port; presumably this is an
+ # IPv6 address.
+ pass
+
+ return hostport, default_port
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 46af27c8..58f9cc61 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -170,7 +170,7 @@ class RequestMetrics(object):
tag = context.tag
if context != self.start_context:
- logger.warn(
+ logger.warning(
"Context have unexpectedly changed %r, %r",
context,
self.start_context,
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 2ccb210f..943d12c9 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -454,7 +454,7 @@ def respond_with_json(
# the Deferred fires, but since the flag is RIGHT THERE it seems like
# a waste.
if request._disconnected:
- logger.warn(
+ logger.warning(
"Not sending response to request %s, already disconnected.", request
)
return
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 274c1a6a..e9a5e46c 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -219,13 +219,13 @@ def parse_json_value_from_request(request, allow_empty_body=False):
try:
content_unicode = content_bytes.decode("utf8")
except UnicodeDecodeError:
- logger.warn("Unable to decode UTF-8")
+ logger.warning("Unable to decode UTF-8")
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
try:
content = json.loads(content_unicode)
except Exception as e:
- logger.warn("Unable to parse JSON: %s", e)
+ logger.warning("Unable to parse JSON: %s", e)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
return content
diff --git a/synapse/http/site.py b/synapse/http/site.py
index df5274c1..ff8184a3 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -199,7 +199,7 @@ class SynapseRequest(Request):
# It's useful to log it here so that we can get an idea of when
# the client disconnects.
with PreserveLoggingContext(self.logcontext):
- logger.warn(
+ logger.warning(
"Error processing request %r: %s %s", self, reason.type, reason.value
)
@@ -305,7 +305,7 @@ class SynapseRequest(Request):
try:
self.request_metrics.stop(self.finish_time, self.code, self.sentLength)
except Exception as e:
- logger.warn("Failed to stop metrics: %r", e)
+ logger.warning("Failed to stop metrics: %r", e)
class XForwardedForRequest(SynapseRequest):
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
index 3220e985..334ddaf3 100644
--- a/synapse/logging/_structured.py
+++ b/synapse/logging/_structured.py
@@ -185,7 +185,7 @@ DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}}
def parse_drain_configs(
- drains: dict
+ drains: dict,
) -> typing.Generator[DrainConfiguration, None, None]:
"""
Parse the drain configurations.
diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py
index 0ebbde06..76ce7d88 100644
--- a/synapse/logging/_terse_json.py
+++ b/synapse/logging/_terse_json.py
@@ -153,7 +153,7 @@ class TerseJSONToTCPLogObserver(object):
An IObserver that writes JSON logs to a TCP target.
Args:
- hs (HomeServer): The Homeserver that is being logged for.
+ hs (HomeServer): The homeserver that is being logged for.
host: The host of the logging target.
port: The logging target's port.
metadata: Metadata to be added to each log entry.
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 370000e3..2c1fb9dd 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -294,7 +294,7 @@ class LoggingContext(object):
"""Enters this logging context into thread local storage"""
old_context = self.set_current_context(self)
if self.previous_context != old_context:
- logger.warn(
+ logger.warning(
"Expected previous context %r, found %r",
self.previous_context,
old_context,
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 4e091314..af161a81 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -159,6 +159,7 @@ class Notifier(object):
self.room_to_user_streams = {}
self.hs = hs
+ self.storage = hs.get_storage()
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore()
self.pending_new_room_events = []
@@ -425,7 +426,10 @@ class Notifier(object):
if name == "room":
new_events = yield filter_events_for_client(
- self.store, user.to_string(), new_events, is_peeking=is_peeking
+ self.storage,
+ user.to_string(),
+ new_events,
+ is_peeking=is_peeking,
)
elif name == "presence":
now = self.clock.time_msec()
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 22491f37..1ba7bcd4 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -79,7 +79,7 @@ class BulkPushRuleEvaluator(object):
dict of user_id -> push_rules
"""
room_id = event.room_id
- rules_for_room = self._get_rules_for_room(room_id)
+ rules_for_room = yield self._get_rules_for_room(room_id)
rules_by_user = yield rules_for_room.get_rules(event, context)
@@ -149,9 +149,10 @@ class BulkPushRuleEvaluator(object):
room_members = yield self.store.get_joined_users_from_context(event, context)
- (power_levels, sender_power_level) = (
- yield self._get_power_levels_and_sender_level(event, context)
- )
+ (
+ power_levels,
+ sender_power_level,
+ ) = yield self._get_power_levels_and_sender_level(event, context)
evaluator = PushRuleEvaluatorForEvent(
event, len(room_members), sender_power_level, power_levels
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 42e5b0c0..8c818a86 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -234,14 +234,12 @@ class EmailPusher(object):
return
self.last_stream_ordering = last_stream_ordering
- pusher_still_exists = (
- yield self.store.update_pusher_last_stream_ordering_and_success(
- self.app_id,
- self.email,
- self.user_id,
- last_stream_ordering,
- self.clock.time_msec(),
- )
+ pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success(
+ self.app_id,
+ self.email,
+ self.user_id,
+ last_stream_ordering,
+ self.clock.time_msec(),
)
if not pusher_still_exists:
# The pusher has been deleted while we were processing, so
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 62995878..e994037b 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -64,6 +64,7 @@ class HttpPusher(object):
def __init__(self, hs, pusherdict):
self.hs = hs
self.store = self.hs.get_datastore()
+ self.storage = self.hs.get_storage()
self.clock = self.hs.get_clock()
self.state_handler = self.hs.get_state_handler()
self.user_id = pusherdict["user_name"]
@@ -102,7 +103,7 @@ class HttpPusher(object):
if "url" not in self.data:
raise PusherConfigException("'url' required in data for HTTP pusher")
self.url = self.data["url"]
- self.http_client = hs.get_simple_http_client()
+ self.http_client = hs.get_proxied_http_client()
self.data_minus_url = {}
self.data_minus_url.update(self.data)
del self.data_minus_url["url"]
@@ -210,14 +211,12 @@ class HttpPusher(object):
http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"]
- pusher_still_exists = (
- yield self.store.update_pusher_last_stream_ordering_and_success(
- self.app_id,
- self.pushkey,
- self.user_id,
- self.last_stream_ordering,
- self.clock.time_msec(),
- )
+ pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success(
+ self.app_id,
+ self.pushkey,
+ self.user_id,
+ self.last_stream_ordering,
+ self.clock.time_msec(),
)
if not pusher_still_exists:
# The pusher has been deleted while we were processing, so
@@ -246,7 +245,7 @@ class HttpPusher(object):
# we really only give up so that if the URL gets
# fixed, we don't suddenly deliver a load
# of old notifications.
- logger.warn(
+ logger.warning(
"Giving up on a notification to user %s, " "pushkey %s",
self.user_id,
self.pushkey,
@@ -299,7 +298,7 @@ class HttpPusher(object):
if pk != self.pushkey:
# for sanity, we only remove the pushkey if it
# was the one we actually sent...
- logger.warn(
+ logger.warning(
("Ignoring rejected pushkey %s because we" " didn't send it"),
pk,
)
@@ -329,7 +328,7 @@ class HttpPusher(object):
return d
ctx = yield push_tools.get_context_for_event(
- self.store, self.state_handler, event, self.user_id
+ self.storage, self.state_handler, event, self.user_id
)
d = {
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 5b16ab4a..1d15a06a 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -119,6 +119,7 @@ class Mailer(object):
self.store = self.hs.get_datastore()
self.macaroon_gen = self.hs.get_macaroon_generator()
self.state_handler = self.hs.get_state_handler()
+ self.storage = hs.get_storage()
self.app_name = app_name
logger.info("Created Mailer for app_name %s" % app_name)
@@ -389,7 +390,7 @@ class Mailer(object):
}
the_events = yield filter_events_for_client(
- self.store, user_id, results["events_before"]
+ self.storage, user_id, results["events_before"]
)
the_events.append(notif_event)
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 5ed9147d..b1587183 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -117,7 +117,7 @@ class PushRuleEvaluatorForEvent(object):
pattern = UserID.from_string(user_id).localpart
if not pattern:
- logger.warn("event_match condition with no pattern")
+ logger.warning("event_match condition with no pattern")
return False
# XXX: optimisation: cache our pattern regexps
@@ -173,7 +173,7 @@ def _glob_matches(glob, value, word_boundary=False):
regex_cache[(glob, word_boundary)] = r
return r.search(value)
except re.error:
- logger.warn("Failed to parse glob to regex: %r", glob)
+ logger.warning("Failed to parse glob to regex: %r", glob)
return False
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index a54051a7..de5c101a 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -16,6 +16,7 @@
from twisted.internet import defer
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
+from synapse.storage import Storage
@defer.inlineCallbacks
@@ -43,22 +44,22 @@ def get_badge_count(store, user_id):
@defer.inlineCallbacks
-def get_context_for_event(store, state_handler, ev, user_id):
+def get_context_for_event(storage: Storage, state_handler, ev, user_id):
ctx = {}
- room_state_ids = yield store.get_state_ids_for_event(ev.event_id)
+ room_state_ids = yield storage.state.get_state_ids_for_event(ev.event_id)
# we no longer bother setting room_alias, and make room_name the
# human-readable name instead, be that m.room.name, an alias or
# a list of people in the room
name = yield calculate_room_name(
- store, room_state_ids, user_id, fallback_to_single_member=False
+ storage.main, room_state_ids, user_id, fallback_to_single_member=False
)
if name:
ctx["name"] = name
sender_state_event_id = room_state_ids[("m.room.member", ev.sender)]
- sender_state_event = yield store.get_event(sender_state_event_id)
+ sender_state_event = yield storage.main.get_event(sender_state_event_id)
ctx["sender_display_name"] = name_from_member_event(sender_state_event)
return ctx
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 08e840fd..0f699220 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -103,9 +103,7 @@ class PusherPool:
# create the pusher setting last_stream_ordering to the current maximum
# stream ordering in event_push_actions, so it will process
# pushes from this point onwards.
- last_stream_ordering = (
- yield self.store.get_latest_push_action_stream_ordering()
- )
+ last_stream_ordering = yield self.store.get_latest_push_action_stream_ordering()
yield self.store.add_pusher(
user_id=user_id,
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index aa7da1c5..5871feaa 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -61,7 +61,6 @@ REQUIREMENTS = [
"bcrypt>=3.1.0",
"pillow>=4.3.0",
"sortedcontainers>=1.4.4",
- "psutil>=2.0.0",
"pymacaroons>=0.13.0",
"msgpack>=0.5.2",
"phonenumbers>=8.2.0",
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 03560c1f..c8056b0c 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -110,14 +110,14 @@ class ReplicationEndpoint(object):
return {}
@abc.abstractmethod
- def _handle_request(self, request, **kwargs):
+ async def _handle_request(self, request, **kwargs):
"""Handle incoming request.
This is called with the request object and PATH_ARGS.
Returns:
- Deferred[dict]: A JSON serialisable dict to be used as response
- body of request.
+ tuple[int, dict]: HTTP status code and a JSON serialisable dict
+ to be used as response body of request.
"""
pass
@@ -180,7 +180,7 @@ class ReplicationEndpoint(object):
if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
raise
- logger.warn("%s request timed out", cls.NAME)
+ logger.warning("%s request timed out", cls.NAME)
# If we timed out we probably don't need to worry about backing
# off too much, but lets just wait a little anyway.
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 2f169559..9af4e7e1 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -82,8 +82,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
return payload
- @defer.inlineCallbacks
- def _handle_request(self, request):
+ async def _handle_request(self, request):
with Measure(self.clock, "repl_fed_send_events_parse"):
content = parse_json_object_from_request(request)
@@ -101,15 +100,13 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
EventType = event_type_from_format_version(format_ver)
event = EventType(event_dict, internal_metadata, rejected_reason)
- context = yield EventContext.deserialize(
- self.store, event_payload["context"]
- )
+ context = EventContext.deserialize(self.store, event_payload["context"])
event_and_contexts.append((event, context))
logger.info("Got %d events from federation", len(event_and_contexts))
- yield self.federation_handler.persist_events_and_notify(
+ await self.federation_handler.persist_events_and_notify(
event_and_contexts, backfilled
)
@@ -144,8 +141,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
def _serialize_payload(edu_type, origin, content):
return {"origin": origin, "content": content}
- @defer.inlineCallbacks
- def _handle_request(self, request, edu_type):
+ async def _handle_request(self, request, edu_type):
with Measure(self.clock, "repl_fed_send_edu_parse"):
content = parse_json_object_from_request(request)
@@ -154,7 +150,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
logger.info("Got %r edu from %s", edu_type, origin)
- result = yield self.registry.on_edu(edu_type, origin, edu_content)
+ result = await self.registry.on_edu(edu_type, origin, edu_content)
return 200, result
@@ -193,8 +189,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
"""
return {"args": args}
- @defer.inlineCallbacks
- def _handle_request(self, request, query_type):
+ async def _handle_request(self, request, query_type):
with Measure(self.clock, "repl_fed_query_parse"):
content = parse_json_object_from_request(request)
@@ -202,7 +197,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
logger.info("Got %r query", query_type)
- result = yield self.registry.on_query(query_type, args)
+ result = await self.registry.on_query(query_type, args)
return 200, result
@@ -234,9 +229,8 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
"""
return {}
- @defer.inlineCallbacks
- def _handle_request(self, request, room_id):
- yield self.store.clean_room_for_join(room_id)
+ async def _handle_request(self, request, room_id):
+ await self.store.clean_room_for_join(room_id)
return 200, {}
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 786f5232..798b9d3a 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
@@ -52,15 +50,14 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
"is_guest": is_guest,
}
- @defer.inlineCallbacks
- def _handle_request(self, request, user_id):
+ async def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
device_id = content["device_id"]
initial_display_name = content["initial_display_name"]
is_guest = content["is_guest"]
- device_id, access_token = yield self.registration_handler.register_device(
+ device_id, access_token = await self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest
)
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index b9ce3477..cc1f2497 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import Requester, UserID
@@ -65,8 +63,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
"content": content,
}
- @defer.inlineCallbacks
- def _handle_request(self, request, room_id, user_id):
+ async def _handle_request(self, request, room_id, user_id):
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
@@ -79,7 +76,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
logger.info("remote_join: %s into room: %s", user_id, room_id)
- yield self.federation_handler.do_invite_join(
+ await self.federation_handler.do_invite_join(
remote_room_hosts, room_id, user_id, event_content
)
@@ -123,8 +120,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
"remote_room_hosts": remote_room_hosts,
}
- @defer.inlineCallbacks
- def _handle_request(self, request, room_id, user_id):
+ async def _handle_request(self, request, room_id, user_id):
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
@@ -137,7 +133,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id)
try:
- event = yield self.federation_handler.do_remotely_reject_invite(
+ event = await self.federation_handler.do_remotely_reject_invite(
remote_room_hosts, room_id, user_id
)
ret = event.get_pdu_json()
@@ -148,9 +144,9 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
# The 'except' clause is very broad, but we need to
# capture everything from DNS failures upwards
#
- logger.warn("Failed to reject invite: %s", e)
+ logger.warning("Failed to reject invite: %s", e)
- yield self.store.locally_reject_invite(user_id, room_id)
+ await self.store.locally_reject_invite(user_id, room_id)
ret = {}
return 200, ret
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 38260256..0c4aca12 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
@@ -74,11 +72,12 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
"address": address,
}
- @defer.inlineCallbacks
- def _handle_request(self, request, user_id):
+ async def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
- yield self.registration_handler.register_with_store(
+ self.registration_handler.check_registration_ratelimit(content["address"])
+
+ await self.registration_handler.register_with_store(
user_id=user_id,
password_hash=content["password_hash"],
was_guest=content["was_guest"],
@@ -117,14 +116,13 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
"""
return {"auth_result": auth_result, "access_token": access_token}
- @defer.inlineCallbacks
- def _handle_request(self, request, user_id):
+ async def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
auth_result = content["auth_result"]
access_token = content["access_token"]
- yield self.registration_handler.post_registration_actions(
+ await self.registration_handler.post_registration_actions(
user_id=user_id, auth_result=auth_result, access_token=access_token
)
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index adb9b2f7..9bafd60b 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -87,8 +87,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
return payload
- @defer.inlineCallbacks
- def _handle_request(self, request, event_id):
+ async def _handle_request(self, request, event_id):
with Measure(self.clock, "repl_send_event_parse"):
content = parse_json_object_from_request(request)
@@ -101,7 +100,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
event = EventType(event_dict, internal_metadata, rejected_reason)
requester = Requester.deserialize(self.store, content["requester"])
- context = yield EventContext.deserialize(self.store, content["context"])
+ context = EventContext.deserialize(self.store, content["context"])
ratelimit = content["ratelimit"]
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
@@ -113,7 +112,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
"Got event to send with ID: %s into room: %s", event.event_id, event.room_id
)
- yield self.event_creation_handler.persist_and_notify_client_event(
+ await self.event_creation_handler.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 182cb2a1..456bc005 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Dict
import six
@@ -44,7 +45,14 @@ class BaseSlavedStore(SQLBaseStore):
self.hs = hs
- def stream_positions(self):
+ def stream_positions(self) -> Dict[str, int]:
+ """
+ Get the current positions of all the streams this store wants to subscribe to
+
+ Returns:
+ map from stream name to the most recent update we have for
+ that stream (ie, the point we want to start replicating from)
+ """
pos = {}
if self._cache_id_gen:
pos["caches"] = self._cache_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 61557665..de50748c 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -15,6 +15,7 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.storage.data_stores.main.devices import DeviceWorkerStore
from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -42,14 +43,22 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions()
- result["device_lists"] = self._device_list_id_gen.get_current_token()
+ # The user signature stream uses the same stream ID generator as the
+ # device list stream, so set them both to the device list ID
+ # generator's current token.
+ current_token = self._device_list_id_gen.get_current_token()
+ result[DeviceListsStream.NAME] = current_token
+ result[UserSignatureStream.NAME] = current_token
return result
def process_replication_rows(self, stream_name, token, rows):
- if stream_name == "device_lists":
+ if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)
for row in rows:
self._invalidate_caches_for_devices(token, row.user_id, row.destination)
+ elif stream_name == UserSignatureStream.NAME:
+ for row in rows:
+ self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows
)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index a44ceb00..fead7838 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -16,10 +16,17 @@
"""
import logging
+from typing import Dict
from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.tcp.protocol import (
+ AbstractReplicationClientHandler,
+ ClientReplicationStreamProtocol,
+)
+
from .commands import (
FederationAckCommand,
InvalidateCacheCommand,
@@ -27,7 +34,6 @@ from .commands import (
UserIpCommand,
UserSyncCommand,
)
-from .protocol import ClientReplicationStreamProtocol
logger = logging.getLogger(__name__)
@@ -42,7 +48,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
maxDelay = 30 # Try at least once every N seconds
- def __init__(self, hs, client_name, handler):
+ def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
self.client_name = client_name
self.handler = handler
self.server_name = hs.config.server_name
@@ -68,13 +74,13 @@ class ReplicationClientFactory(ReconnectingClientFactory):
ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
-class ReplicationClientHandler(object):
+class ReplicationClientHandler(AbstractReplicationClientHandler):
"""A base handler that can be passed to the ReplicationClientFactory.
By default proxies incoming replication data to the SlaveStore.
"""
- def __init__(self, store):
+ def __init__(self, store: BaseSlavedStore):
self.store = store
# The current connection. None if we are currently (re)connecting
@@ -138,11 +144,13 @@ class ReplicationClientHandler(object):
if d:
d.callback(data)
- def get_streams_to_replicate(self):
+ def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.
- Returns a dictionary of stream name to token.
+ Returns:
+ map from stream name to the most recent update we have for
+ that stream (ie, the point we want to start replicating from)
"""
args = self.store.stream_positions()
user_account_data = args.pop("user_account_data", None)
@@ -168,7 +176,7 @@ class ReplicationClientHandler(object):
if self.connection:
self.connection.send_command(cmd)
else:
- logger.warn("Queuing command as not connected: %r", cmd.NAME)
+ logger.warning("Queuing command as not connected: %r", cmd.NAME)
self.pending_commands.append(cmd)
def send_federation_ack(self, token):
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 5ffdf267..afaf002f 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -48,7 +48,7 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping
* connection closed by server *
"""
-
+import abc
import fcntl
import logging
import struct
@@ -65,6 +65,7 @@ from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import Clock
from synapse.util.stringutils import random_string
from .commands import (
@@ -249,7 +250,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
return handler(cmd)
def close(self):
- logger.warn("[%s] Closing connection", self.id())
+ logger.warning("[%s] Closing connection", self.id())
self.time_we_closed = self.clock.time_msec()
self.transport.loseConnection()
self.on_connection_closed()
@@ -558,11 +559,80 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.streamer.lost_connection(self)
+class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
+ """
+ The interface for the handler that should be passed to
+ ClientReplicationStreamProtocol
+ """
+
+ @abc.abstractmethod
+ def on_rdata(self, stream_name, token, rows):
+ """Called to handle a batch of replication data with a given stream token.
+
+ Args:
+ stream_name (str): name of the replication stream for this batch of rows
+ token (int): stream token for this batch of rows
+ rows (list): a list of Stream.ROW_TYPE objects as returned by
+ Stream.parse_row.
+
+ Returns:
+ Deferred|None
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def on_position(self, stream_name, token):
+ """Called when we get new position data."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def on_sync(self, data):
+ """Called when get a new SYNC command."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_streams_to_replicate(self):
+ """Called when a new connection has been established and we need to
+ subscribe to streams.
+
+ Returns:
+ map from stream name to the most recent update we have for
+ that stream (ie, the point we want to start replicating from)
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_currently_syncing_users(self):
+ """Get the list of currently syncing users (if any). This is called
+ when a connection has been established and we need to send the
+ currently syncing users."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def update_connection(self, connection):
+ """Called when a connection has been established (or lost with None).
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def finished_connecting(self):
+ """Called when we have successfully subscribed and caught up to all
+ streams we're interested in.
+ """
+ raise NotImplementedError()
+
+
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
- def __init__(self, client_name, server_name, clock, handler):
+ def __init__(
+ self,
+ client_name: str,
+ server_name: str,
+ clock: Clock,
+ handler: AbstractReplicationClientHandler,
+ ):
BaseReplicationStreamProtocol.__init__(self, clock)
self.client_name = client_name
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 634f636d..5f52264e 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -45,5 +45,6 @@ STREAMS_MAP = {
_base.TagAccountDataStream,
_base.AccountDataStream,
_base.GroupServerStream,
+ _base.UserSignatureStream,
)
}
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index f03111c2..9e45429d 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -95,6 +95,7 @@ GroupsStreamRow = namedtuple(
"GroupsStreamRow",
("group_id", "user_id", "type", "content"), # str # str # str # dict
)
+UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
class Stream(object):
@@ -438,3 +439,20 @@ class GroupServerStream(Stream):
self.update_function = store.get_all_groups_changes
super(GroupServerStream, self).__init__(hs)
+
+
+class UserSignatureStream(Stream):
+ """A user has signed their own device with their user-signing key
+ """
+
+ NAME = "user_signature"
+ _LIMITED = False
+ ROW_TYPE = UserSignatureStreamRow
+
+ def __init__(self, hs):
+ store = hs.get_datastore()
+
+ self.current_token = store.get_device_stream_token
+ self.update_function = store.get_all_user_signature_changes_for_remotes
+
+ super(UserSignatureStream, self).__init__(hs)
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 939418ee..68a59a34 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -14,62 +14,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import hashlib
-import hmac
import logging
import platform
import re
-from six import text_type
-from six.moves import http_client
-
import synapse
-from synapse.api.constants import Membership, UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import JsonResource
-from synapse.http.servlet import (
- RestServlet,
- assert_params_in_dict,
- parse_integer,
- parse_json_object_from_request,
- parse_string,
-)
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.admin._base import (
assert_requester_is_admin,
- assert_user_is_admin,
historical_admin_path_patterns,
)
+from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
+from synapse.rest.admin.rooms import ShutdownRoomRestServlet
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
-from synapse.rest.admin.users import UserAdminServlet
-from synapse.types import UserID, create_requester
-from synapse.util.async_helpers import maybe_awaitable
+from synapse.rest.admin.users import (
+ AccountValidityRenewServlet,
+ DeactivateAccountRestServlet,
+ GetUsersPaginatedRestServlet,
+ ResetPasswordRestServlet,
+ SearchUsersRestServlet,
+ UserAdminServlet,
+ UserRegisterServlet,
+ UsersRestServlet,
+ WhoisRestServlet,
+)
from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__)
-class UsersRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$")
-
- def __init__(self, hs):
- self.hs = hs
- self.auth = hs.get_auth()
- self.handlers = hs.get_handlers()
-
- async def on_GET(self, request, user_id):
- target_user = UserID.from_string(user_id)
- await assert_requester_is_admin(self.auth, request)
-
- if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only users a local user")
-
- ret = await self.handlers.admin_handler.get_users()
-
- return 200, ret
-
-
class VersionServlet(RestServlet):
PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"),)
@@ -83,159 +60,6 @@ class VersionServlet(RestServlet):
return 200, self.res
-class UserRegisterServlet(RestServlet):
- """
- Attributes:
- NONCE_TIMEOUT (int): Seconds until a generated nonce won't be accepted
- nonces (dict[str, int]): The nonces that we will accept. A dict of
- nonce to the time it was generated, in int seconds.
- """
-
- PATTERNS = historical_admin_path_patterns("/register")
- NONCE_TIMEOUT = 60
-
- def __init__(self, hs):
- self.handlers = hs.get_handlers()
- self.reactor = hs.get_reactor()
- self.nonces = {}
- self.hs = hs
-
- def _clear_old_nonces(self):
- """
- Clear out old nonces that are older than NONCE_TIMEOUT.
- """
- now = int(self.reactor.seconds())
-
- for k, v in list(self.nonces.items()):
- if now - v > self.NONCE_TIMEOUT:
- del self.nonces[k]
-
- def on_GET(self, request):
- """
- Generate a new nonce.
- """
- self._clear_old_nonces()
-
- nonce = self.hs.get_secrets().token_hex(64)
- self.nonces[nonce] = int(self.reactor.seconds())
- return 200, {"nonce": nonce}
-
- async def on_POST(self, request):
- self._clear_old_nonces()
-
- if not self.hs.config.registration_shared_secret:
- raise SynapseError(400, "Shared secret registration is not enabled")
-
- body = parse_json_object_from_request(request)
-
- if "nonce" not in body:
- raise SynapseError(400, "nonce must be specified", errcode=Codes.BAD_JSON)
-
- nonce = body["nonce"]
-
- if nonce not in self.nonces:
- raise SynapseError(400, "unrecognised nonce")
-
- # Delete the nonce, so it can't be reused, even if it's invalid
- del self.nonces[nonce]
-
- if "username" not in body:
- raise SynapseError(
- 400, "username must be specified", errcode=Codes.BAD_JSON
- )
- else:
- if (
- not isinstance(body["username"], text_type)
- or len(body["username"]) > 512
- ):
- raise SynapseError(400, "Invalid username")
-
- username = body["username"].encode("utf-8")
- if b"\x00" in username:
- raise SynapseError(400, "Invalid username")
-
- if "password" not in body:
- raise SynapseError(
- 400, "password must be specified", errcode=Codes.BAD_JSON
- )
- else:
- if (
- not isinstance(body["password"], text_type)
- or len(body["password"]) > 512
- ):
- raise SynapseError(400, "Invalid password")
-
- password = body["password"].encode("utf-8")
- if b"\x00" in password:
- raise SynapseError(400, "Invalid password")
-
- admin = body.get("admin", None)
- user_type = body.get("user_type", None)
-
- if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
- raise SynapseError(400, "Invalid user type")
-
- got_mac = body["mac"]
-
- want_mac = hmac.new(
- key=self.hs.config.registration_shared_secret.encode(),
- digestmod=hashlib.sha1,
- )
- want_mac.update(nonce.encode("utf8"))
- want_mac.update(b"\x00")
- want_mac.update(username)
- want_mac.update(b"\x00")
- want_mac.update(password)
- want_mac.update(b"\x00")
- want_mac.update(b"admin" if admin else b"notadmin")
- if user_type:
- want_mac.update(b"\x00")
- want_mac.update(user_type.encode("utf8"))
- want_mac = want_mac.hexdigest()
-
- if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")):
- raise SynapseError(403, "HMAC incorrect")
-
- # Reuse the parts of RegisterRestServlet to reduce code duplication
- from synapse.rest.client.v2_alpha.register import RegisterRestServlet
-
- register = RegisterRestServlet(self.hs)
-
- user_id = await register.registration_handler.register_user(
- localpart=body["username"].lower(),
- password=body["password"],
- admin=bool(admin),
- user_type=user_type,
- )
-
- result = await register._create_registration_details(user_id, body)
- return 200, result
-
-
-class WhoisRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/whois/(?P<user_id>[^/]*)")
-
- def __init__(self, hs):
- self.hs = hs
- self.auth = hs.get_auth()
- self.handlers = hs.get_handlers()
-
- async def on_GET(self, request, user_id):
- target_user = UserID.from_string(user_id)
- requester = await self.auth.get_user_by_req(request)
- auth_user = requester.user
-
- if target_user != auth_user:
- await assert_user_is_admin(self.auth, auth_user)
-
- if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only whois a local user")
-
- ret = await self.handlers.admin_handler.get_whois(target_user)
-
- return 200, ret
-
-
class PurgeHistoryRestServlet(RestServlet):
PATTERNS = historical_admin_path_patterns(
"/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
@@ -286,7 +110,7 @@ class PurgeHistoryRestServlet(RestServlet):
room_id, stream_ordering
)
if not r:
- logger.warn(
+ logger.warning(
"[purge] purging events not possible: No event found "
"(received_ts %i => stream_ordering %i)",
ts,
@@ -342,369 +166,6 @@ class PurgeHistoryStatusRestServlet(RestServlet):
return 200, purge_status.asdict()
-class DeactivateAccountRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/deactivate/(?P<target_user_id>[^/]*)")
-
- def __init__(self, hs):
- self._deactivate_account_handler = hs.get_deactivate_account_handler()
- self.auth = hs.get_auth()
-
- async def on_POST(self, request, target_user_id):
- await assert_requester_is_admin(self.auth, request)
- body = parse_json_object_from_request(request, allow_empty_body=True)
- erase = body.get("erase", False)
- if not isinstance(erase, bool):
- raise SynapseError(
- http_client.BAD_REQUEST,
- "Param 'erase' must be a boolean, if given",
- Codes.BAD_JSON,
- )
-
- UserID.from_string(target_user_id)
-
- result = await self._deactivate_account_handler.deactivate_account(
- target_user_id, erase
- )
- if result:
- id_server_unbind_result = "success"
- else:
- id_server_unbind_result = "no-support"
-
- return 200, {"id_server_unbind_result": id_server_unbind_result}
-
-
-class ShutdownRoomRestServlet(RestServlet):
- """Shuts down a room by removing all local users from the room and blocking
- all future invites and joins to the room. Any local aliases will be repointed
- to a new room created by `new_room_user_id` and kicked users will be auto
- joined to the new room.
- """
-
- PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)")
-
- DEFAULT_MESSAGE = (
- "Sharing illegal content on this server is not permitted and rooms in"
- " violation will be blocked."
- )
-
- def __init__(self, hs):
- self.hs = hs
- self.store = hs.get_datastore()
- self.state = hs.get_state_handler()
- self._room_creation_handler = hs.get_room_creation_handler()
- self.event_creation_handler = hs.get_event_creation_handler()
- self.room_member_handler = hs.get_room_member_handler()
- self.auth = hs.get_auth()
-
- async def on_POST(self, request, room_id):
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
-
- content = parse_json_object_from_request(request)
- assert_params_in_dict(content, ["new_room_user_id"])
- new_room_user_id = content["new_room_user_id"]
-
- room_creator_requester = create_requester(new_room_user_id)
-
- message = content.get("message", self.DEFAULT_MESSAGE)
- room_name = content.get("room_name", "Content Violation Notification")
-
- info = await self._room_creation_handler.create_room(
- room_creator_requester,
- config={
- "preset": "public_chat",
- "name": room_name,
- "power_level_content_override": {"users_default": -10},
- },
- ratelimit=False,
- )
- new_room_id = info["room_id"]
-
- requester_user_id = requester.user.to_string()
-
- logger.info(
- "Shutting down room %r, joining to new room: %r", room_id, new_room_id
- )
-
- # This will work even if the room is already blocked, but that is
- # desirable in case the first attempt at blocking the room failed below.
- await self.store.block_room(room_id, requester_user_id)
-
- users = await self.state.get_current_users_in_room(room_id)
- kicked_users = []
- failed_to_kick_users = []
- for user_id in users:
- if not self.hs.is_mine_id(user_id):
- continue
-
- logger.info("Kicking %r from %r...", user_id, room_id)
-
- try:
- target_requester = create_requester(user_id)
- await self.room_member_handler.update_membership(
- requester=target_requester,
- target=target_requester.user,
- room_id=room_id,
- action=Membership.LEAVE,
- content={},
- ratelimit=False,
- require_consent=False,
- )
-
- await self.room_member_handler.forget(target_requester.user, room_id)
-
- await self.room_member_handler.update_membership(
- requester=target_requester,
- target=target_requester.user,
- room_id=new_room_id,
- action=Membership.JOIN,
- content={},
- ratelimit=False,
- require_consent=False,
- )
-
- kicked_users.append(user_id)
- except Exception:
- logger.exception(
- "Failed to leave old room and join new room for %r", user_id
- )
- failed_to_kick_users.append(user_id)
-
- await self.event_creation_handler.create_and_send_nonmember_event(
- room_creator_requester,
- {
- "type": "m.room.message",
- "content": {"body": message, "msgtype": "m.text"},
- "room_id": new_room_id,
- "sender": new_room_user_id,
- },
- ratelimit=False,
- )
-
- aliases_for_room = await maybe_awaitable(
- self.store.get_aliases_for_room(room_id)
- )
-
- await self.store.update_aliases_for_room(
- room_id, new_room_id, requester_user_id
- )
-
- return (
- 200,
- {
- "kicked_users": kicked_users,
- "failed_to_kick_users": failed_to_kick_users,
- "local_aliases": aliases_for_room,
- "new_room_id": new_room_id,
- },
- )
-
-
-class ResetPasswordRestServlet(RestServlet):
- """Post request to allow an administrator reset password for a user.
- This needs user to have administrator access in Synapse.
- Example:
- http://localhost:8008/_synapse/admin/v1/reset_password/
- @user:to_reset_password?access_token=admin_access_token
- JsonBodyToSend:
- {
- "new_password": "secret"
- }
- Returns:
- 200 OK with empty object if success otherwise an error.
- """
-
- PATTERNS = historical_admin_path_patterns(
- "/reset_password/(?P<target_user_id>[^/]*)"
- )
-
- def __init__(self, hs):
- self.store = hs.get_datastore()
- self.hs = hs
- self.auth = hs.get_auth()
- self._set_password_handler = hs.get_set_password_handler()
-
- async def on_POST(self, request, target_user_id):
- """Post request to allow an administrator reset password for a user.
- This needs user to have administrator access in Synapse.
- """
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
-
- UserID.from_string(target_user_id)
-
- params = parse_json_object_from_request(request)
- assert_params_in_dict(params, ["new_password"])
- new_password = params["new_password"]
-
- await self._set_password_handler.set_password(
- target_user_id, new_password, requester
- )
- return 200, {}
-
-
-class GetUsersPaginatedRestServlet(RestServlet):
- """Get request to get specific number of users from Synapse.
- This needs user to have administrator access in Synapse.
- Example:
- http://localhost:8008/_synapse/admin/v1/users_paginate/
- @admin:user?access_token=admin_access_token&start=0&limit=10
- Returns:
- 200 OK with json object {list[dict[str, Any]], count} or empty object.
- """
-
- PATTERNS = historical_admin_path_patterns(
- "/users_paginate/(?P<target_user_id>[^/]*)"
- )
-
- def __init__(self, hs):
- self.store = hs.get_datastore()
- self.hs = hs
- self.auth = hs.get_auth()
- self.handlers = hs.get_handlers()
-
- async def on_GET(self, request, target_user_id):
- """Get request to get specific number of users from Synapse.
- This needs user to have administrator access in Synapse.
- """
- await assert_requester_is_admin(self.auth, request)
-
- target_user = UserID.from_string(target_user_id)
-
- if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only users a local user")
-
- order = "name" # order by name in user table
- start = parse_integer(request, "start", required=True)
- limit = parse_integer(request, "limit", required=True)
-
- logger.info("limit: %s, start: %s", limit, start)
-
- ret = await self.handlers.admin_handler.get_users_paginate(order, start, limit)
- return 200, ret
-
- async def on_POST(self, request, target_user_id):
- """Post request to get specific number of users from Synapse..
- This needs user to have administrator access in Synapse.
- Example:
- http://localhost:8008/_synapse/admin/v1/users_paginate/
- @admin:user?access_token=admin_access_token
- JsonBodyToSend:
- {
- "start": "0",
- "limit": "10
- }
- Returns:
- 200 OK with json object {list[dict[str, Any]], count} or empty object.
- """
- await assert_requester_is_admin(self.auth, request)
- UserID.from_string(target_user_id)
-
- order = "name" # order by name in user table
- params = parse_json_object_from_request(request)
- assert_params_in_dict(params, ["limit", "start"])
- limit = params["limit"]
- start = params["start"]
- logger.info("limit: %s, start: %s", limit, start)
-
- ret = await self.handlers.admin_handler.get_users_paginate(order, start, limit)
- return 200, ret
-
-
-class SearchUsersRestServlet(RestServlet):
- """Get request to search user table for specific users according to
- search term.
- This needs user to have administrator access in Synapse.
- Example:
- http://localhost:8008/_synapse/admin/v1/search_users/
- @admin:user?access_token=admin_access_token&term=alice
- Returns:
- 200 OK with json object {list[dict[str, Any]], count} or empty object.
- """
-
- PATTERNS = historical_admin_path_patterns("/search_users/(?P<target_user_id>[^/]*)")
-
- def __init__(self, hs):
- self.store = hs.get_datastore()
- self.hs = hs
- self.auth = hs.get_auth()
- self.handlers = hs.get_handlers()
-
- async def on_GET(self, request, target_user_id):
- """Get request to search user table for specific users according to
- search term.
- This needs user to have a administrator access in Synapse.
- """
- await assert_requester_is_admin(self.auth, request)
-
- target_user = UserID.from_string(target_user_id)
-
- # To allow all users to get the users list
- # if not is_admin and target_user != auth_user:
- # raise AuthError(403, "You are not a server admin")
-
- if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only users a local user")
-
- term = parse_string(request, "term", required=True)
- logger.info("term: %s ", term)
-
- ret = await self.handlers.admin_handler.search_users(term)
- return 200, ret
-
-
-class DeleteGroupAdminRestServlet(RestServlet):
- """Allows deleting of local groups
- """
-
- PATTERNS = historical_admin_path_patterns("/delete_group/(?P<group_id>[^/]*)")
-
- def __init__(self, hs):
- self.group_server = hs.get_groups_server_handler()
- self.is_mine_id = hs.is_mine_id
- self.auth = hs.get_auth()
-
- async def on_POST(self, request, group_id):
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
-
- if not self.is_mine_id(group_id):
- raise SynapseError(400, "Can only delete local groups")
-
- await self.group_server.delete_group(group_id, requester.user.to_string())
- return 200, {}
-
-
-class AccountValidityRenewServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/account_validity/validity$")
-
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
- self.hs = hs
- self.account_activity_handler = hs.get_account_validity_handler()
- self.auth = hs.get_auth()
-
- async def on_POST(self, request):
- await assert_requester_is_admin(self.auth, request)
-
- body = parse_json_object_from_request(request)
-
- if "user_id" not in body:
- raise SynapseError(400, "Missing property 'user_id' in the request body")
-
- expiration_ts = await self.account_activity_handler.renew_account_for_user(
- body["user_id"],
- body.get("expiration_ts"),
- not body.get("enable_renewal_emails", True),
- )
-
- res = {"expiration_ts": expiration_ts}
- return 200, res
-
-
########################################################################################
#
# please don't add more servlets here: this file is already long and unwieldy. Put
diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py
new file mode 100644
index 00000000..0b54ca09
--- /dev/null
+++ b/synapse/rest/admin/groups.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import RestServlet
+from synapse.rest.admin._base import (
+ assert_user_is_admin,
+ historical_admin_path_patterns,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class DeleteGroupAdminRestServlet(RestServlet):
+ """Allows deleting of local groups
+ """
+
+ PATTERNS = historical_admin_path_patterns("/delete_group/(?P<group_id>[^/]*)")
+
+ def __init__(self, hs):
+ self.group_server = hs.get_groups_server_handler()
+ self.is_mine_id = hs.is_mine_id
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ if not self.is_mine_id(group_id):
+ raise SynapseError(400, "Can only delete local groups")
+
+ await self.group_server.delete_group(group_id, requester.user.to_string())
+ return 200, {}
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
new file mode 100644
index 00000000..f7cc5e9b
--- /dev/null
+++ b/synapse/rest/admin/rooms.py
@@ -0,0 +1,157 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.api.constants import Membership
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
+from synapse.rest.admin._base import (
+ assert_user_is_admin,
+ historical_admin_path_patterns,
+)
+from synapse.types import create_requester
+from synapse.util.async_helpers import maybe_awaitable
+
+logger = logging.getLogger(__name__)
+
+
+class ShutdownRoomRestServlet(RestServlet):
+ """Shuts down a room by removing all local users from the room and blocking
+ all future invites and joins to the room. Any local aliases will be repointed
+ to a new room created by `new_room_user_id` and kicked users will be auto
+ joined to the new room.
+ """
+
+ PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)")
+
+ DEFAULT_MESSAGE = (
+ "Sharing illegal content on this server is not permitted and rooms in"
+ " violation will be blocked."
+ )
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.state = hs.get_state_handler()
+ self._room_creation_handler = hs.get_room_creation_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.room_member_handler = hs.get_room_member_handler()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ content = parse_json_object_from_request(request)
+ assert_params_in_dict(content, ["new_room_user_id"])
+ new_room_user_id = content["new_room_user_id"]
+
+ room_creator_requester = create_requester(new_room_user_id)
+
+ message = content.get("message", self.DEFAULT_MESSAGE)
+ room_name = content.get("room_name", "Content Violation Notification")
+
+ info = await self._room_creation_handler.create_room(
+ room_creator_requester,
+ config={
+ "preset": "public_chat",
+ "name": room_name,
+ "power_level_content_override": {"users_default": -10},
+ },
+ ratelimit=False,
+ )
+ new_room_id = info["room_id"]
+
+ requester_user_id = requester.user.to_string()
+
+ logger.info(
+ "Shutting down room %r, joining to new room: %r", room_id, new_room_id
+ )
+
+ # This will work even if the room is already blocked, but that is
+ # desirable in case the first attempt at blocking the room failed below.
+ await self.store.block_room(room_id, requester_user_id)
+
+ users = await self.state.get_current_users_in_room(room_id)
+ kicked_users = []
+ failed_to_kick_users = []
+ for user_id in users:
+ if not self.hs.is_mine_id(user_id):
+ continue
+
+ logger.info("Kicking %r from %r...", user_id, room_id)
+
+ try:
+ target_requester = create_requester(user_id)
+ await self.room_member_handler.update_membership(
+ requester=target_requester,
+ target=target_requester.user,
+ room_id=room_id,
+ action=Membership.LEAVE,
+ content={},
+ ratelimit=False,
+ require_consent=False,
+ )
+
+ await self.room_member_handler.forget(target_requester.user, room_id)
+
+ await self.room_member_handler.update_membership(
+ requester=target_requester,
+ target=target_requester.user,
+ room_id=new_room_id,
+ action=Membership.JOIN,
+ content={},
+ ratelimit=False,
+ require_consent=False,
+ )
+
+ kicked_users.append(user_id)
+ except Exception:
+ logger.exception(
+ "Failed to leave old room and join new room for %r", user_id
+ )
+ failed_to_kick_users.append(user_id)
+
+ await self.event_creation_handler.create_and_send_nonmember_event(
+ room_creator_requester,
+ {
+ "type": "m.room.message",
+ "content": {"body": message, "msgtype": "m.text"},
+ "room_id": new_room_id,
+ "sender": new_room_user_id,
+ },
+ ratelimit=False,
+ )
+
+ aliases_for_room = await maybe_awaitable(
+ self.store.get_aliases_for_room(room_id)
+ )
+
+ await self.store.update_aliases_for_room(
+ room_id, new_room_id, requester_user_id
+ )
+
+ return (
+ 200,
+ {
+ "kicked_users": kicked_users,
+ "failed_to_kick_users": failed_to_kick_users,
+ "local_aliases": aliases_for_room,
+ "new_room_id": new_room_id,
+ },
+ )
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index d5d124a0..58a83f93 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -12,17 +12,419 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import hashlib
+import hmac
+import logging
import re
-from synapse.api.errors import SynapseError
+from six import text_type
+from six.moves import http_client
+
+from synapse.api.constants import UserTypes
+from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
+ parse_integer,
parse_json_object_from_request,
+ parse_string,
+)
+from synapse.rest.admin._base import (
+ assert_requester_is_admin,
+ assert_user_is_admin,
+ historical_admin_path_patterns,
)
-from synapse.rest.admin import assert_requester_is_admin, assert_user_is_admin
from synapse.types import UserID
+logger = logging.getLogger(__name__)
+
+
+class UsersRestServlet(RestServlet):
+ PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.admin_handler = hs.get_handlers().admin_handler
+
+ async def on_GET(self, request, user_id):
+ target_user = UserID.from_string(user_id)
+ await assert_requester_is_admin(self.auth, request)
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only users a local user")
+
+ ret = await self.admin_handler.get_users()
+
+ return 200, ret
+
+
+class GetUsersPaginatedRestServlet(RestServlet):
+ """Get request to get specific number of users from Synapse.
+ This needs user to have administrator access in Synapse.
+ Example:
+ http://localhost:8008/_synapse/admin/v1/users_paginate/
+ @admin:user?access_token=admin_access_token&start=0&limit=10
+ Returns:
+ 200 OK with json object {list[dict[str, Any]], count} or empty object.
+ """
+
+ PATTERNS = historical_admin_path_patterns(
+ "/users_paginate/(?P<target_user_id>[^/]*)"
+ )
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.handlers = hs.get_handlers()
+
+ async def on_GET(self, request, target_user_id):
+ """Get request to get specific number of users from Synapse.
+ This needs user to have administrator access in Synapse.
+ """
+ await assert_requester_is_admin(self.auth, request)
+
+ target_user = UserID.from_string(target_user_id)
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only users a local user")
+
+ order = "name" # order by name in user table
+ start = parse_integer(request, "start", required=True)
+ limit = parse_integer(request, "limit", required=True)
+
+ logger.info("limit: %s, start: %s", limit, start)
+
+ ret = await self.handlers.admin_handler.get_users_paginate(order, start, limit)
+ return 200, ret
+
+ async def on_POST(self, request, target_user_id):
+ """Post request to get specific number of users from Synapse..
+ This needs user to have administrator access in Synapse.
+ Example:
+ http://localhost:8008/_synapse/admin/v1/users_paginate/
+ @admin:user?access_token=admin_access_token
+ JsonBodyToSend:
+ {
+ "start": "0",
+ "limit": "10
+ }
+ Returns:
+ 200 OK with json object {list[dict[str, Any]], count} or empty object.
+ """
+ await assert_requester_is_admin(self.auth, request)
+ UserID.from_string(target_user_id)
+
+ order = "name" # order by name in user table
+ params = parse_json_object_from_request(request)
+ assert_params_in_dict(params, ["limit", "start"])
+ limit = params["limit"]
+ start = params["start"]
+ logger.info("limit: %s, start: %s", limit, start)
+
+ ret = await self.handlers.admin_handler.get_users_paginate(order, start, limit)
+ return 200, ret
+
+
+class UserRegisterServlet(RestServlet):
+ """
+ Attributes:
+ NONCE_TIMEOUT (int): Seconds until a generated nonce won't be accepted
+ nonces (dict[str, int]): The nonces that we will accept. A dict of
+ nonce to the time it was generated, in int seconds.
+ """
+
+ PATTERNS = historical_admin_path_patterns("/register")
+ NONCE_TIMEOUT = 60
+
+ def __init__(self, hs):
+ self.handlers = hs.get_handlers()
+ self.reactor = hs.get_reactor()
+ self.nonces = {}
+ self.hs = hs
+
+ def _clear_old_nonces(self):
+ """
+ Clear out old nonces that are older than NONCE_TIMEOUT.
+ """
+ now = int(self.reactor.seconds())
+
+ for k, v in list(self.nonces.items()):
+ if now - v > self.NONCE_TIMEOUT:
+ del self.nonces[k]
+
+ def on_GET(self, request):
+ """
+ Generate a new nonce.
+ """
+ self._clear_old_nonces()
+
+ nonce = self.hs.get_secrets().token_hex(64)
+ self.nonces[nonce] = int(self.reactor.seconds())
+ return 200, {"nonce": nonce}
+
+ async def on_POST(self, request):
+ self._clear_old_nonces()
+
+ if not self.hs.config.registration_shared_secret:
+ raise SynapseError(400, "Shared secret registration is not enabled")
+
+ body = parse_json_object_from_request(request)
+
+ if "nonce" not in body:
+ raise SynapseError(400, "nonce must be specified", errcode=Codes.BAD_JSON)
+
+ nonce = body["nonce"]
+
+ if nonce not in self.nonces:
+ raise SynapseError(400, "unrecognised nonce")
+
+ # Delete the nonce, so it can't be reused, even if it's invalid
+ del self.nonces[nonce]
+
+ if "username" not in body:
+ raise SynapseError(
+ 400, "username must be specified", errcode=Codes.BAD_JSON
+ )
+ else:
+ if (
+ not isinstance(body["username"], text_type)
+ or len(body["username"]) > 512
+ ):
+ raise SynapseError(400, "Invalid username")
+
+ username = body["username"].encode("utf-8")
+ if b"\x00" in username:
+ raise SynapseError(400, "Invalid username")
+
+ if "password" not in body:
+ raise SynapseError(
+ 400, "password must be specified", errcode=Codes.BAD_JSON
+ )
+ else:
+ if (
+ not isinstance(body["password"], text_type)
+ or len(body["password"]) > 512
+ ):
+ raise SynapseError(400, "Invalid password")
+
+ password = body["password"].encode("utf-8")
+ if b"\x00" in password:
+ raise SynapseError(400, "Invalid password")
+
+ admin = body.get("admin", None)
+ user_type = body.get("user_type", None)
+
+ if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
+ raise SynapseError(400, "Invalid user type")
+
+ got_mac = body["mac"]
+
+ want_mac = hmac.new(
+ key=self.hs.config.registration_shared_secret.encode(),
+ digestmod=hashlib.sha1,
+ )
+ want_mac.update(nonce.encode("utf8"))
+ want_mac.update(b"\x00")
+ want_mac.update(username)
+ want_mac.update(b"\x00")
+ want_mac.update(password)
+ want_mac.update(b"\x00")
+ want_mac.update(b"admin" if admin else b"notadmin")
+ if user_type:
+ want_mac.update(b"\x00")
+ want_mac.update(user_type.encode("utf8"))
+ want_mac = want_mac.hexdigest()
+
+ if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")):
+ raise SynapseError(403, "HMAC incorrect")
+
+ # Reuse the parts of RegisterRestServlet to reduce code duplication
+ from synapse.rest.client.v2_alpha.register import RegisterRestServlet
+
+ register = RegisterRestServlet(self.hs)
+
+ user_id = await register.registration_handler.register_user(
+ localpart=body["username"].lower(),
+ password=body["password"],
+ admin=bool(admin),
+ user_type=user_type,
+ )
+
+ result = await register._create_registration_details(user_id, body)
+ return 200, result
+
+
+class WhoisRestServlet(RestServlet):
+ PATTERNS = historical_admin_path_patterns("/whois/(?P<user_id>[^/]*)")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.handlers = hs.get_handlers()
+
+ async def on_GET(self, request, user_id):
+ target_user = UserID.from_string(user_id)
+ requester = await self.auth.get_user_by_req(request)
+ auth_user = requester.user
+
+ if target_user != auth_user:
+ await assert_user_is_admin(self.auth, auth_user)
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only whois a local user")
+
+ ret = await self.handlers.admin_handler.get_whois(target_user)
+
+ return 200, ret
+
+
+class DeactivateAccountRestServlet(RestServlet):
+ PATTERNS = historical_admin_path_patterns("/deactivate/(?P<target_user_id>[^/]*)")
+
+ def __init__(self, hs):
+ self._deactivate_account_handler = hs.get_deactivate_account_handler()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, target_user_id):
+ await assert_requester_is_admin(self.auth, request)
+ body = parse_json_object_from_request(request, allow_empty_body=True)
+ erase = body.get("erase", False)
+ if not isinstance(erase, bool):
+ raise SynapseError(
+ http_client.BAD_REQUEST,
+ "Param 'erase' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
+
+ UserID.from_string(target_user_id)
+
+ result = await self._deactivate_account_handler.deactivate_account(
+ target_user_id, erase
+ )
+ if result:
+ id_server_unbind_result = "success"
+ else:
+ id_server_unbind_result = "no-support"
+
+ return 200, {"id_server_unbind_result": id_server_unbind_result}
+
+
+class AccountValidityRenewServlet(RestServlet):
+ PATTERNS = historical_admin_path_patterns("/account_validity/validity$")
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ self.hs = hs
+ self.account_activity_handler = hs.get_account_validity_handler()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request):
+ await assert_requester_is_admin(self.auth, request)
+
+ body = parse_json_object_from_request(request)
+
+ if "user_id" not in body:
+ raise SynapseError(400, "Missing property 'user_id' in the request body")
+
+ expiration_ts = await self.account_activity_handler.renew_account_for_user(
+ body["user_id"],
+ body.get("expiration_ts"),
+ not body.get("enable_renewal_emails", True),
+ )
+
+ res = {"expiration_ts": expiration_ts}
+ return 200, res
+
+
+class ResetPasswordRestServlet(RestServlet):
+ """Post request to allow an administrator reset password for a user.
+ This needs user to have administrator access in Synapse.
+ Example:
+ http://localhost:8008/_synapse/admin/v1/reset_password/
+ @user:to_reset_password?access_token=admin_access_token
+ JsonBodyToSend:
+ {
+ "new_password": "secret"
+ }
+ Returns:
+ 200 OK with empty object if success otherwise an error.
+ """
+
+ PATTERNS = historical_admin_path_patterns(
+ "/reset_password/(?P<target_user_id>[^/]*)"
+ )
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self._set_password_handler = hs.get_set_password_handler()
+
+ async def on_POST(self, request, target_user_id):
+ """Post request to allow an administrator reset password for a user.
+ This needs user to have administrator access in Synapse.
+ """
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ UserID.from_string(target_user_id)
+
+ params = parse_json_object_from_request(request)
+ assert_params_in_dict(params, ["new_password"])
+ new_password = params["new_password"]
+
+ await self._set_password_handler.set_password(
+ target_user_id, new_password, requester
+ )
+ return 200, {}
+
+
+class SearchUsersRestServlet(RestServlet):
+ """Get request to search user table for specific users according to
+ search term.
+ This needs user to have administrator access in Synapse.
+ Example:
+ http://localhost:8008/_synapse/admin/v1/search_users/
+ @admin:user?access_token=admin_access_token&term=alice
+ Returns:
+ 200 OK with json object {list[dict[str, Any]], count} or empty object.
+ """
+
+ PATTERNS = historical_admin_path_patterns("/search_users/(?P<target_user_id>[^/]*)")
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.handlers = hs.get_handlers()
+
+ async def on_GET(self, request, target_user_id):
+ """Get request to search user table for specific users according to
+ search term.
+ This needs user to have a administrator access in Synapse.
+ """
+ await assert_requester_is_admin(self.auth, request)
+
+ target_user = UserID.from_string(target_user_id)
+
+ # To allow all users to get the users list
+ # if not is_admin and target_user != auth_user:
+ # raise AuthError(403, "You are not a server admin")
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only users a local user")
+
+ term = parse_string(request, "term", required=True)
+ logger.info("term: %s ", term)
+
+ ret = await self.handlers.admin_handler.search_users(term)
+ return 200, ret
+
class UserAdminServlet(RestServlet):
"""
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 8414af08..19eb1500 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -92,8 +92,11 @@ class LoginRestServlet(RestServlet):
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self.handlers = hs.get_handlers()
+ self._clock = hs.get_clock()
self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter()
+ self._account_ratelimiter = Ratelimiter()
+ self._failed_attempts_ratelimiter = Ratelimiter()
def on_GET(self, request):
flows = []
@@ -202,15 +205,27 @@ class LoginRestServlet(RestServlet):
# (See add_threepid in synapse/handlers/auth.py)
address = address.lower()
+ # We also apply account rate limiting using the 3PID as a key, as
+ # otherwise using 3PID bypasses the ratelimiting based on user ID.
+ self._failed_attempts_ratelimiter.ratelimit(
+ (medium, address),
+ time_now_s=self._clock.time(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ update=False,
+ )
+
# Check for login providers that support 3pid login types
- canonical_user_id, callback_3pid = (
- yield self.auth_handler.check_password_provider_3pid(
- medium, address, login_submission["password"]
- )
+ (
+ canonical_user_id,
+ callback_3pid,
+ ) = yield self.auth_handler.check_password_provider_3pid(
+ medium, address, login_submission["password"]
)
if canonical_user_id:
# Authentication through password provider and 3pid succeeded
- result = yield self._register_device_with_callback(
+
+ result = yield self._complete_login(
canonical_user_id, login_submission, callback_3pid
)
return result
@@ -221,9 +236,24 @@ class LoginRestServlet(RestServlet):
medium, address
)
if not user_id:
- logger.warn(
+ logger.warning(
"unknown 3pid identifier medium %s, address %r", medium, address
)
+ # We mark that we've failed to log in here, as
+ # `check_password_provider_3pid` might have returned `None` due
+ # to an incorrect password, rather than the account not
+ # existing.
+ #
+ # If it returned None but the 3PID was bound then we won't hit
+ # this code path, which is fine as then the per-user ratelimit
+ # will kick in below.
+ self._failed_attempts_ratelimiter.can_do_action(
+ (medium, address),
+ time_now_s=self._clock.time(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ update=True,
+ )
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
identifier = {"type": "m.id.user", "user": user_id}
@@ -235,29 +265,84 @@ class LoginRestServlet(RestServlet):
if "user" not in identifier:
raise SynapseError(400, "User identifier is missing 'user' key")
- canonical_user_id, callback = yield self.auth_handler.validate_login(
- identifier["user"], login_submission
+ if identifier["user"].startswith("@"):
+ qualified_user_id = identifier["user"]
+ else:
+ qualified_user_id = UserID(identifier["user"], self.hs.hostname).to_string()
+
+ # Check if we've hit the failed ratelimit (but don't update it)
+ self._failed_attempts_ratelimiter.ratelimit(
+ qualified_user_id.lower(),
+ time_now_s=self._clock.time(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ update=False,
)
- result = yield self._register_device_with_callback(
+ try:
+ canonical_user_id, callback = yield self.auth_handler.validate_login(
+ identifier["user"], login_submission
+ )
+ except LoginError:
+ # The user has failed to log in, so we need to update the rate
+ # limiter. Using `can_do_action` avoids us raising a ratelimit
+ # exception and masking the LoginError. The actual ratelimiting
+ # should have happened above.
+ self._failed_attempts_ratelimiter.can_do_action(
+ qualified_user_id.lower(),
+ time_now_s=self._clock.time(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ update=True,
+ )
+ raise
+
+ result = yield self._complete_login(
canonical_user_id, login_submission, callback
)
return result
@defer.inlineCallbacks
- def _register_device_with_callback(self, user_id, login_submission, callback=None):
- """ Registers a device with a given user_id. Optionally run a callback
- function after registration has completed.
+ def _complete_login(
+ self, user_id, login_submission, callback=None, create_non_existant_users=False
+ ):
+ """Called when we've successfully authed the user and now need to
+ actually login them in (e.g. create devices). This gets called on
+ all succesful logins.
+
+ Applies the ratelimiting for succesful login attempts against an
+ account.
Args:
user_id (str): ID of the user to register.
login_submission (dict): Dictionary of login information.
callback (func|None): Callback function to run after registration.
+ create_non_existant_users (bool): Whether to create the user if
+ they don't exist. Defaults to False.
Returns:
result (Dict[str,str]): Dictionary of account information after
successful registration.
"""
+
+ # Before we actually log them in we check if they've already logged in
+ # too often. This happens here rather than before as we don't
+ # necessarily know the user before now.
+ self._account_ratelimiter.ratelimit(
+ user_id.lower(),
+ time_now_s=self._clock.time(),
+ rate_hz=self.hs.config.rc_login_account.per_second,
+ burst_count=self.hs.config.rc_login_account.burst_count,
+ update=True,
+ )
+
+ if create_non_existant_users:
+ user_id = yield self.auth_handler.check_user_exists(user_id)
+ if not user_id:
+ user_id = yield self.registration_handler.register_user(
+ localpart=UserID.from_string(user_id).localpart
+ )
+
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
@@ -280,11 +365,11 @@ class LoginRestServlet(RestServlet):
def do_token_login(self, login_submission):
token = login_submission["token"]
auth_handler = self.auth_handler
- user_id = (
- yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
+ user_id = yield auth_handler.validate_short_term_login_token_and_get_user_id(
+ token
)
- result = yield self._register_device_with_callback(user_id, login_submission)
+ result = yield self._complete_login(user_id, login_submission)
return result
@defer.inlineCallbacks
@@ -312,15 +397,8 @@ class LoginRestServlet(RestServlet):
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
user_id = UserID(user, self.hs.hostname).to_string()
-
- registered_user_id = yield self.auth_handler.check_user_exists(user_id)
- if not registered_user_id:
- registered_user_id = yield self.registration_handler.register_user(
- localpart=user
- )
-
- result = yield self._register_device_with_callback(
- registered_user_id, login_submission
+ result = yield self._complete_login(
+ user_id, login_submission, create_non_existant_users=True
)
return result
@@ -380,7 +458,7 @@ class CasTicketServlet(RestServlet):
self.cas_displayname_attribute = hs.config.cas_displayname_attribute
self.cas_required_attributes = hs.config.cas_required_attributes
self._sso_auth_handler = SSOAuthHandler(hs)
- self._http_client = hs.get_simple_http_client()
+ self._http_client = hs.get_proxied_http_client()
@defer.inlineCallbacks
def on_GET(self, request):
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 9c1d4142..86bbcc0e 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -21,8 +21,6 @@ from six.moves.urllib import parse as urlparse
from canonicaljson import json
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
AuthError,
@@ -85,11 +83,10 @@ class RoomCreateRestServlet(TransactionRestServlet):
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(request, self.on_POST, request)
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
- info = yield self._room_creation_handler.create_room(
+ info = await self._room_creation_handler.create_room(
requester, self.get_room_config(request)
)
@@ -154,15 +151,14 @@ class RoomStateEventRestServlet(TransactionRestServlet):
def on_PUT_no_state_key(self, request, room_id, event_type):
return self.on_PUT(request, room_id, event_type, "")
- @defer.inlineCallbacks
- def on_GET(self, request, room_id, event_type, state_key):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, room_id, event_type, state_key):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
format = parse_string(
request, "format", default="content", allowed_values=["content", "event"]
)
msg_handler = self.message_handler
- data = yield msg_handler.get_room_data(
+ data = await msg_handler.get_room_data(
user_id=requester.user.to_string(),
room_id=room_id,
event_type=event_type,
@@ -179,9 +175,8 @@ class RoomStateEventRestServlet(TransactionRestServlet):
elif format == "content":
return 200, data.get_dict()["content"]
- @defer.inlineCallbacks
- def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
+ requester = await self.auth.get_user_by_req(request)
if txn_id:
set_tag("txn_id", txn_id)
@@ -200,7 +195,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
if event_type == EventTypes.Member:
membership = content.get("membership", None)
- event = yield self.room_member_handler.update_membership(
+ event = await self.room_member_handler.update_membership(
requester,
target=UserID.from_string(state_key),
room_id=room_id,
@@ -208,7 +203,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
content=content,
)
else:
- event = yield self.event_creation_handler.create_and_send_nonmember_event(
+ event = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id
)
@@ -231,9 +226,8 @@ class RoomSendEventRestServlet(TransactionRestServlet):
PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
register_txn_path(self, PATTERNS, http_server, with_get=True)
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, event_type, txn_id=None):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request, room_id, event_type, txn_id=None):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
event_dict = {
@@ -246,7 +240,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
if b"ts" in request.args and requester.app_service:
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
- event = yield self.event_creation_handler.create_and_send_nonmember_event(
+ event = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id
)
@@ -276,9 +270,8 @@ class JoinRoomAliasServlet(TransactionRestServlet):
PATTERNS = "/join/(?P<room_identifier>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
- @defer.inlineCallbacks
- def on_POST(self, request, room_identifier, txn_id=None):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request, room_identifier, txn_id=None):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
try:
content = parse_json_object_from_request(request)
@@ -298,14 +291,14 @@ class JoinRoomAliasServlet(TransactionRestServlet):
elif RoomAlias.is_valid(room_identifier):
handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier)
- room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
+ room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
room_id = room_id.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
- yield self.room_member_handler.update_membership(
+ await self.room_member_handler.update_membership(
requester=requester,
target=requester.user,
room_id=room_id,
@@ -335,12 +328,11 @@ class PublicRoomListRestServlet(TransactionRestServlet):
self.hs = hs
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request):
+ async def on_GET(self, request):
server = parse_string(request, "server", default=None)
try:
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ await self.auth.get_user_by_req(request, allow_guest=True)
except InvalidClientCredentialsError as e:
# Option to allow servers to require auth when accessing
# /publicRooms via CS API. This is especially helpful in private
@@ -367,19 +359,18 @@ class PublicRoomListRestServlet(TransactionRestServlet):
handler = self.hs.get_room_list_handler()
if server:
- data = yield handler.get_remote_public_room_list(
+ data = await handler.get_remote_public_room_list(
server, limit=limit, since_token=since_token
)
else:
- data = yield handler.get_local_public_room_list(
+ data = await handler.get_local_public_room_list(
limit=limit, since_token=since_token
)
return 200, data
- @defer.inlineCallbacks
- def on_POST(self, request):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request):
+ await self.auth.get_user_by_req(request, allow_guest=True)
server = parse_string(request, "server", default=None)
content = parse_json_object_from_request(request)
@@ -408,7 +399,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
handler = self.hs.get_room_list_handler()
if server:
- data = yield handler.get_remote_public_room_list(
+ data = await handler.get_remote_public_room_list(
server,
limit=limit,
since_token=since_token,
@@ -417,7 +408,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
third_party_instance_id=third_party_instance_id,
)
else:
- data = yield handler.get_local_public_room_list(
+ data = await handler.get_local_public_room_list(
limit=limit,
since_token=since_token,
search_filter=search_filter,
@@ -436,10 +427,9 @@ class RoomMemberListRestServlet(RestServlet):
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
+ async def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
handler = self.message_handler
# request the state as of a given event, as identified by a stream token,
@@ -459,7 +449,7 @@ class RoomMemberListRestServlet(RestServlet):
membership = parse_string(request, "membership")
not_membership = parse_string(request, "not_membership")
- events = yield handler.get_state_events(
+ events = await handler.get_state_events(
room_id=room_id,
user_id=requester.user.to_string(),
at_token=at_token,
@@ -488,11 +478,10 @@ class JoinedRoomMemberListRestServlet(RestServlet):
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
- users_with_profile = yield self.message_handler.get_joined_members(
+ users_with_profile = await self.message_handler.get_joined_members(
requester, room_id
)
@@ -508,9 +497,8 @@ class RoomMessageListRestServlet(RestServlet):
self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request, default_limit=10)
as_client_event = b"raw" not in request.args
filter_bytes = parse_string(request, b"filter", encoding=None)
@@ -521,7 +509,7 @@ class RoomMessageListRestServlet(RestServlet):
as_client_event = False
else:
event_filter = None
- msgs = yield self.pagination_handler.get_messages(
+ msgs = await self.pagination_handler.get_messages(
room_id=room_id,
requester=requester,
pagin_config=pagination_config,
@@ -541,11 +529,10 @@ class RoomStateRestServlet(RestServlet):
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
# Get all the current state for this room
- events = yield self.message_handler.get_state_events(
+ events = await self.message_handler.get_state_events(
room_id=room_id,
user_id=requester.user.to_string(),
is_guest=requester.is_guest,
@@ -562,11 +549,10 @@ class RoomInitialSyncRestServlet(RestServlet):
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request)
- content = yield self.initial_sync_handler.room_initial_sync(
+ content = await self.initial_sync_handler.room_initial_sync(
room_id=room_id, requester=requester, pagin_config=pagination_config
)
return 200, content
@@ -584,11 +570,10 @@ class RoomEventServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id, event_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, room_id, event_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
try:
- event = yield self.event_handler.get_event(
+ event = await self.event_handler.get_event(
requester.user, room_id, event_id
)
except AuthError:
@@ -599,7 +584,7 @@ class RoomEventServlet(RestServlet):
time_now = self.clock.time_msec()
if event:
- event = yield self._event_serializer.serialize_event(event, time_now)
+ event = await self._event_serializer.serialize_event(event, time_now)
return 200, event
return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
@@ -617,9 +602,8 @@ class RoomEventContextServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id, event_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, room_id, event_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
limit = parse_integer(request, "limit", default=10)
@@ -631,7 +615,7 @@ class RoomEventContextServlet(RestServlet):
else:
event_filter = None
- results = yield self.room_context_handler.get_event_context(
+ results = await self.room_context_handler.get_event_context(
requester.user, room_id, event_id, limit, event_filter
)
@@ -639,16 +623,16 @@ class RoomEventContextServlet(RestServlet):
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
time_now = self.clock.time_msec()
- results["events_before"] = yield self._event_serializer.serialize_events(
+ results["events_before"] = await self._event_serializer.serialize_events(
results["events_before"], time_now
)
- results["event"] = yield self._event_serializer.serialize_event(
+ results["event"] = await self._event_serializer.serialize_event(
results["event"], time_now
)
- results["events_after"] = yield self._event_serializer.serialize_events(
+ results["events_after"] = await self._event_serializer.serialize_events(
results["events_after"], time_now
)
- results["state"] = yield self._event_serializer.serialize_events(
+ results["state"] = await self._event_serializer.serialize_events(
results["state"], time_now
)
@@ -665,11 +649,10 @@ class RoomForgetRestServlet(TransactionRestServlet):
PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
register_txn_path(self, PATTERNS, http_server)
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, txn_id=None):
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ async def on_POST(self, request, room_id, txn_id=None):
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
- yield self.room_member_handler.forget(user=requester.user, room_id=room_id)
+ await self.room_member_handler.forget(user=requester.user, room_id=room_id)
return 200, {}
@@ -696,9 +679,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
)
register_txn_path(self, PATTERNS, http_server)
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, membership_action, txn_id=None):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request, room_id, membership_action, txn_id=None):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
if requester.is_guest and membership_action not in {
Membership.JOIN,
@@ -714,7 +696,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
content = {}
if membership_action == "invite" and self._has_3pid_invite_keys(content):
- yield self.room_member_handler.do_3pid_invite(
+ await self.room_member_handler.do_3pid_invite(
room_id,
requester.user,
content["medium"],
@@ -735,7 +717,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
if "reason" in content and membership_action in ["kick", "ban"]:
event_content = {"reason": content["reason"]}
- yield self.room_member_handler.update_membership(
+ await self.room_member_handler.update_membership(
requester=requester,
target=target,
room_id=room_id,
@@ -777,12 +759,11 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, event_id, txn_id=None):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request, room_id, event_id, txn_id=None):
+ requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
- event = yield self.event_creation_handler.create_and_send_nonmember_event(
+ event = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Redaction,
@@ -816,29 +797,28 @@ class RoomTypingRestServlet(RestServlet):
self.typing_handler = hs.get_typing_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_PUT(self, request, room_id, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, room_id, user_id):
+ requester = await self.auth.get_user_by_req(request)
room_id = urlparse.unquote(room_id)
target_user = UserID.from_string(urlparse.unquote(user_id))
content = parse_json_object_from_request(request)
- yield self.presence_handler.bump_presence_active_time(requester.user)
+ await self.presence_handler.bump_presence_active_time(requester.user)
# Limit timeout to stop people from setting silly typing timeouts.
timeout = min(content.get("timeout", 30000), 120000)
if content["typing"]:
- yield self.typing_handler.started_typing(
+ await self.typing_handler.started_typing(
target_user=target_user,
auth_user=requester.user,
room_id=room_id,
timeout=timeout,
)
else:
- yield self.typing_handler.stopped_typing(
+ await self.typing_handler.stopped_typing(
target_user=target_user, auth_user=requester.user, room_id=room_id
)
@@ -853,14 +833,13 @@ class SearchRestServlet(RestServlet):
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
batch = parse_string(request, "next_batch")
- results = yield self.handlers.search_handler.search(
+ results = await self.handlers.search_handler.search(
requester.user, content, batch
)
@@ -875,11 +854,10 @@ class JoinedRoomsRestServlet(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
- room_ids = yield self.store.get_rooms_for_user(requester.user.to_string())
+ room_ids = await self.store.get_rooms_for_user(requester.user.to_string())
return 200, {"joined_rooms": list(room_ids)}
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 80cf7126..f26eae79 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -71,7 +71,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
- logger.warn(
+ logger.warning(
"User password resets have been disabled due to lack of email config"
)
raise SynapseError(
@@ -148,7 +148,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- self.failure_email_template, = load_jinja2_templates(
+ (self.failure_email_template,) = load_jinja2_templates(
self.config.email_template_dir,
[self.config.email_password_reset_template_failure_html],
)
@@ -162,7 +162,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
)
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
- logger.warn(
+ logger.warning(
"Password reset emails have been disabled due to lack of an email config"
)
raise SynapseError(
@@ -183,7 +183,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
# Perform a 302 redirect if next_link is set
if next_link:
if next_link.startswith("file:///"):
- logger.warn(
+ logger.warning(
"Not redirecting to next_link as it is a local file: address"
)
else:
@@ -350,7 +350,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
- logger.warn(
+ logger.warning(
"Adding emails have been disabled due to lack of an email config"
)
raise SynapseError(
@@ -441,7 +441,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
if not self.hs.config.account_threepid_delegate_msisdn:
- logger.warn(
+ logger.warning(
"No upstream msisdn account_threepid_delegate configured on the server to "
"handle this request"
)
@@ -479,7 +479,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- self.failure_email_template, = load_jinja2_templates(
+ (self.failure_email_template,) = load_jinja2_templates(
self.config.email_template_dir,
[self.config.email_add_threepid_template_failure_html],
)
@@ -488,7 +488,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
def on_GET(self, request):
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
- logger.warn(
+ logger.warning(
"Adding emails have been disabled due to lack of an email config"
)
raise SynapseError(
@@ -515,7 +515,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
# Perform a 302 redirect if next_link is set
if next_link:
if next_link.startswith("file:///"):
- logger.warn(
+ logger.warning(
"Not redirecting to next_link as it is a local file: address"
)
else:
diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py
index b3bf8567..67cbc373 100644
--- a/synapse/rest/client/v2_alpha/read_marker.py
+++ b/synapse/rest/client/v2_alpha/read_marker.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_patterns
@@ -34,17 +32,16 @@ class ReadMarkerRestServlet(RestServlet):
self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler()
- @defer.inlineCallbacks
- def on_POST(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
- yield self.presence_handler.bump_presence_active_time(requester.user)
+ await self.presence_handler.bump_presence_active_time(requester.user)
body = parse_json_object_from_request(request)
read_event_id = body.get("m.read", None)
if read_event_id:
- yield self.receipts_handler.received_client_receipt(
+ await self.receipts_handler.received_client_receipt(
room_id,
"m.read",
user_id=requester.user.to_string(),
@@ -53,7 +50,7 @@ class ReadMarkerRestServlet(RestServlet):
read_marker_event_id = body.get("m.fully_read", None)
if read_marker_event_id:
- yield self.read_marker_handler.received_client_read_marker(
+ await self.read_marker_handler.received_client_read_marker(
room_id,
user_id=requester.user.to_string(),
event_id=read_marker_event_id,
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
index 0dab03d2..92555bd4 100644
--- a/synapse/rest/client/v2_alpha/receipts.py
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
@@ -39,16 +37,15 @@ class ReceiptRestServlet(RestServlet):
self.receipts_handler = hs.get_receipts_handler()
self.presence_handler = hs.get_presence_handler()
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, receipt_type, event_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request, room_id, receipt_type, event_id):
+ requester = await self.auth.get_user_by_req(request)
if receipt_type != "m.read":
raise SynapseError(400, "Receipt type must be 'm.read'")
- yield self.presence_handler.bump_presence_active_time(requester.user)
+ await self.presence_handler.bump_presence_active_time(requester.user)
- yield self.receipts_handler.received_client_receipt(
+ await self.receipts_handler.received_client_receipt(
room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id
)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 4f24a124..91db9238 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -106,7 +106,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.hs.config.local_threepid_handling_disabled_due_to_email_config:
- logger.warn(
+ logger.warning(
"Email registration has been disabled due to lack of email config"
)
raise SynapseError(
@@ -207,7 +207,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
)
if not self.hs.config.account_threepid_delegate_msisdn:
- logger.warn(
+ logger.warning(
"No upstream msisdn account_threepid_delegate configured on the server to "
"handle this request"
)
@@ -247,13 +247,13 @@ class RegistrationSubmitTokenServlet(RestServlet):
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- self.failure_email_template, = load_jinja2_templates(
+ (self.failure_email_template,) = load_jinja2_templates(
self.config.email_template_dir,
[self.config.email_registration_template_failure_html],
)
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- self.failure_email_template, = load_jinja2_templates(
+ (self.failure_email_template,) = load_jinja2_templates(
self.config.email_template_dir,
[self.config.email_registration_template_failure_html],
)
@@ -266,7 +266,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
)
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
- logger.warn(
+ logger.warning(
"User registration via email has been disabled due to lack of email config"
)
raise SynapseError(
@@ -287,7 +287,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
# Perform a 302 redirect if next_link is set
if next_link:
if next_link.startswith("file:///"):
- logger.warn(
+ logger.warning(
"Not redirecting to next_link as it is a local file: address"
)
else:
@@ -480,7 +480,7 @@ class RegisterRestServlet(RestServlet):
# a password to work around a client bug where it sent
# the 'initial_device_display_name' param alone, wiping out
# the original registration params
- logger.warn("Ignoring initial_device_display_name without password")
+ logger.warning("Ignoring initial_device_display_name without password")
del body["initial_device_display_name"]
session_id = self.auth_handler.get_session_id(body)
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index a883c8ad..ccd8b17b 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -112,9 +112,14 @@ class SyncRestServlet(RestServlet):
full_state = parse_boolean(request, "full_state", default=False)
logger.debug(
- "/sync: user=%r, timeout=%r, since=%r,"
- " set_presence=%r, filter_id=%r, device_id=%r"
- % (user, timeout, since, set_presence, filter_id, device_id)
+ "/sync: user=%r, timeout=%r, since=%r, "
+ "set_presence=%r, filter_id=%r, device_id=%r",
+ user,
+ timeout,
+ since,
+ set_presence,
+ filter_id,
+ device_id,
)
request_key = (user, timeout, since, filter_id, full_state, device_id)
@@ -389,7 +394,7 @@ class SyncRestServlet(RestServlet):
# We've had bug reports that events were coming down under the
# wrong room.
if event.room_id != room.room_id:
- logger.warn(
+ logger.warning(
"Event %r is under room %r instead of %r",
event.event_id,
room.room_id,
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 1044ae7b..bb30ce3f 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -65,6 +65,9 @@ class VersionsRestServlet(RestServlet):
"m.require_identity_server": False,
# as per MSC2290
"m.separate_add_and_bind": True,
+ # Implements support for label-based filtering as described in
+ # MSC2326.
+ "org.matrix.label_based_filtering": True,
},
},
)
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 55580bc5..e7fc3f04 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -102,7 +102,7 @@ class RemoteKey(DirectServeResource):
@wrap_json_request_handler
async def _async_render_GET(self, request):
if len(request.postpath) == 1:
- server, = request.postpath
+ (server,) = request.postpath
query = {server.decode("ascii"): {}}
elif len(request.postpath) == 2:
server, key_id = request.postpath
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index b972e152..bd9186fe 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -363,7 +363,7 @@ class MediaRepository(object):
},
)
except RequestSendFailed as e:
- logger.warn(
+ logger.warning(
"Request failed fetching remote media %s/%s: %r",
server_name,
media_id,
@@ -372,7 +372,7 @@ class MediaRepository(object):
raise SynapseError(502, "Failed to fetch remote media")
except HttpResponseException as e:
- logger.warn(
+ logger.warning(
"HTTP error fetching remote media %s/%s: %s",
server_name,
media_id,
@@ -383,10 +383,12 @@ class MediaRepository(object):
raise SynapseError(502, "Failed to fetch remote media")
except SynapseError:
- logger.warn("Failed to fetch remote media %s/%s", server_name, media_id)
+ logger.warning(
+ "Failed to fetch remote media %s/%s", server_name, media_id
+ )
raise
except NotRetryingDestination:
- logger.warn("Not retrying destination %r", server_name)
+ logger.warning("Not retrying destination %r", server_name)
raise SynapseError(502, "Failed to fetch remote media")
except Exception:
logger.exception(
@@ -691,7 +693,7 @@ class MediaRepository(object):
try:
os.remove(full_path)
except OSError as e:
- logger.warn("Failed to remove file: %r", full_path)
+ logger.warning("Failed to remove file: %r", full_path)
if e.errno == errno.ENOENT:
pass
else:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index ec9c4619..87343d9d 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -77,6 +77,8 @@ class PreviewUrlResource(DirectServeResource):
treq_args={"browser_like_redirects": True},
ip_whitelist=hs.config.url_preview_ip_range_whitelist,
ip_blacklist=hs.config.url_preview_ip_range_blacklist,
+ http_proxy=os.getenvb(b"http_proxy"),
+ https_proxy=os.getenvb(b"HTTPS_PROXY"),
)
self.media_repo = media_repo
self.primary_base_path = media_repo.primary_base_path
@@ -120,8 +122,10 @@ class PreviewUrlResource(DirectServeResource):
pattern = entry[attrib]
value = getattr(url_tuple, attrib)
logger.debug(
- ("Matching attrib '%s' with value '%s' against" " pattern '%s'")
- % (attrib, value, pattern)
+ "Matching attrib '%s' with value '%s' against" " pattern '%s'",
+ attrib,
+ value,
+ pattern,
)
if value is None:
@@ -137,7 +141,7 @@ class PreviewUrlResource(DirectServeResource):
match = False
continue
if match:
- logger.warn("URL %s blocked by url_blacklist entry %s", url, entry)
+ logger.warning("URL %s blocked by url_blacklist entry %s", url, entry)
raise SynapseError(
403, "URL blocked by url pattern blacklist entry", Codes.UNKNOWN
)
@@ -189,7 +193,7 @@ class PreviewUrlResource(DirectServeResource):
media_info = yield self._download_url(url, user)
- logger.debug("got media_info of '%s'" % media_info)
+ logger.debug("got media_info of '%s'", media_info)
if _is_media(media_info["media_type"]):
file_id = media_info["filesystem_id"]
@@ -209,7 +213,7 @@ class PreviewUrlResource(DirectServeResource):
og["og:image:width"] = dims["width"]
og["og:image:height"] = dims["height"]
else:
- logger.warn("Couldn't get dims for %s" % url)
+ logger.warning("Couldn't get dims for %s" % url)
# define our OG response for this media
elif _is_html(media_info["media_type"]):
@@ -257,7 +261,7 @@ class PreviewUrlResource(DirectServeResource):
og["og:image:width"] = dims["width"]
og["og:image:height"] = dims["height"]
else:
- logger.warn("Couldn't get dims for %s" % og["og:image"])
+ logger.warning("Couldn't get dims for %s", og["og:image"])
og["og:image"] = "mxc://%s/%s" % (
self.server_name,
@@ -268,7 +272,7 @@ class PreviewUrlResource(DirectServeResource):
else:
del og["og:image"]
else:
- logger.warn("Failed to find any OG data in %s", url)
+ logger.warning("Failed to find any OG data in %s", url)
og = {}
# filter out any stupidly long values
@@ -283,7 +287,7 @@ class PreviewUrlResource(DirectServeResource):
for k in keys_to_remove:
del og[k]
- logger.debug("Calculated OG for %s as %s" % (url, og))
+ logger.debug("Calculated OG for %s as %s", url, og)
jsonog = json.dumps(og)
@@ -312,7 +316,7 @@ class PreviewUrlResource(DirectServeResource):
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
try:
- logger.debug("Trying to get url '%s'" % url)
+ logger.debug("Trying to get url '%s'", url)
length, headers, uri, code = yield self.client.get_file(
url, output_stream=f, max_size=self.max_spider_size
)
@@ -332,7 +336,7 @@ class PreviewUrlResource(DirectServeResource):
)
except Exception as e:
# FIXME: pass through 404s and other error messages nicely
- logger.warn("Error downloading %s: %r", url, e)
+ logger.warning("Error downloading %s: %r", url, e)
raise SynapseError(
500,
@@ -413,7 +417,7 @@ class PreviewUrlResource(DirectServeResource):
except OSError as e:
# If the path doesn't exist, meh
if e.errno != errno.ENOENT:
- logger.warn("Failed to remove media: %r: %s", media_id, e)
+ logger.warning("Failed to remove media: %r: %s", media_id, e)
continue
removed_media.append(media_id)
@@ -445,7 +449,7 @@ class PreviewUrlResource(DirectServeResource):
except OSError as e:
# If the path doesn't exist, meh
if e.errno != errno.ENOENT:
- logger.warn("Failed to remove media: %r: %s", media_id, e)
+ logger.warning("Failed to remove media: %r: %s", media_id, e)
continue
try:
@@ -461,7 +465,7 @@ class PreviewUrlResource(DirectServeResource):
except OSError as e:
# If the path doesn't exist, meh
if e.errno != errno.ENOENT:
- logger.warn("Failed to remove media: %r: %s", media_id, e)
+ logger.warning("Failed to remove media: %r: %s", media_id, e)
continue
removed_media.append(media_id)
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 08329884..931ce79b 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -182,7 +182,7 @@ class ThumbnailResource(DirectServeResource):
if file_path:
yield respond_with_file(request, desired_type, file_path)
else:
- logger.warn("Failed to generate thumbnail")
+ logger.warning("Failed to generate thumbnail")
respond_404(request)
@defer.inlineCallbacks
@@ -245,7 +245,7 @@ class ThumbnailResource(DirectServeResource):
if file_path:
yield respond_with_file(request, desired_type, file_path)
else:
- logger.warn("Failed to generate thumbnail")
+ logger.warning("Failed to generate thumbnail")
respond_404(request)
@defer.inlineCallbacks
diff --git a/synapse/server.py b/synapse/server.py
index 1fcc7375..be9af7f9 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -23,6 +23,7 @@
# Imports required for the default HomeServer() implementation
import abc
import logging
+import os
from twisted.enterprise import adbapi
from twisted.mail.smtp import sendmail
@@ -95,6 +96,7 @@ from synapse.server_notices.worker_server_notices_sender import (
WorkerServerNoticesSender,
)
from synapse.state import StateHandler, StateResolutionHandler
+from synapse.storage import DataStores, Storage
from synapse.streams.events import EventSources
from synapse.util import Clock
from synapse.util.distributor import Distributor
@@ -167,6 +169,7 @@ class HomeServer(object):
"filtering",
"http_client_context_factory",
"simple_http_client",
+ "proxied_http_client",
"media_repository",
"media_repository_resource",
"federation_transport_client",
@@ -196,6 +199,7 @@ class HomeServer(object):
"account_validity_handler",
"saml_handler",
"event_client_serializer",
+ "storage",
]
REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
@@ -217,6 +221,7 @@ class HomeServer(object):
self.hostname = hostname
self._building = {}
self._listening_services = []
+ self.start_time = None
self.clock = Clock(reactor)
self.distributor = Distributor()
@@ -224,7 +229,7 @@ class HomeServer(object):
self.admin_redaction_ratelimiter = Ratelimiter()
self.registration_ratelimiter = Ratelimiter()
- self.datastore = None
+ self.datastores = None
# Other kwargs are explicit dependencies
for depname in kwargs:
@@ -233,8 +238,10 @@ class HomeServer(object):
def setup(self):
logger.info("Setting up.")
with self.get_db_conn() as conn:
- self.datastore = self.DATASTORE_CLASS(conn, self)
+ datastore = self.DATASTORE_CLASS(conn, self)
+ self.datastores = DataStores(datastore, conn, self)
conn.commit()
+ self.start_time = int(self.get_clock().time())
logger.info("Finished setting up.")
def setup_master(self):
@@ -266,7 +273,7 @@ class HomeServer(object):
return self.clock
def get_datastore(self):
- return self.datastore
+ return self.datastores.main
def get_config(self):
return self.config
@@ -308,6 +315,13 @@ class HomeServer(object):
def build_simple_http_client(self):
return SimpleHttpClient(self)
+ def build_proxied_http_client(self):
+ return SimpleHttpClient(
+ self,
+ http_proxy=os.getenvb(b"http_proxy"),
+ https_proxy=os.getenvb(b"HTTPS_PROXY"),
+ )
+
def build_room_creation_handler(self):
return RoomCreationHandler(self)
@@ -537,6 +551,9 @@ class HomeServer(object):
def build_event_client_serializer(self):
return EventClientSerializer(self)
+ def build_storage(self) -> Storage:
+ return Storage(self, self.datastores)
+
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 16f8f6b5..b5e0b570 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -12,6 +12,7 @@ import synapse.handlers.message
import synapse.handlers.room
import synapse.handlers.room_member
import synapse.handlers.set_password
+import synapse.http.client
import synapse.rest.media.v1.media_repository
import synapse.server_notices.server_notices_manager
import synapse.server_notices.server_notices_sender
@@ -38,8 +39,16 @@ class HomeServer(object):
pass
def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
pass
+ def get_simple_http_client(self) -> synapse.http.client.SimpleHttpClient:
+ """Fetch an HTTP client implementation which doesn't do any blacklisting
+ or support any HTTP_PROXY settings"""
+ pass
+ def get_proxied_http_client(self) -> synapse.http.client.SimpleHttpClient:
+ """Fetch an HTTP client implementation which doesn't do any blacklisting
+ but does support HTTP_PROXY settings"""
+ pass
def get_deactivate_account_handler(
- self
+ self,
) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
pass
def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler:
@@ -47,32 +56,32 @@ class HomeServer(object):
def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler:
pass
def get_event_creation_handler(
- self
+ self,
) -> synapse.handlers.message.EventCreationHandler:
pass
def get_set_password_handler(
- self
+ self,
) -> synapse.handlers.set_password.SetPasswordHandler:
pass
def get_federation_sender(self) -> synapse.federation.sender.FederationSender:
pass
def get_federation_transport_client(
- self
+ self,
) -> synapse.federation.transport.client.TransportLayerClient:
pass
def get_media_repository_resource(
- self
+ self,
) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource:
pass
def get_media_repository(
- self
+ self,
) -> synapse.rest.media.v1.media_repository.MediaRepository:
pass
def get_server_notices_manager(
- self
+ self,
) -> synapse.server_notices.server_notices_manager.ServerNoticesManager:
pass
def get_server_notices_sender(
- self
+ self,
) -> synapse.server_notices.server_notices_sender.ServerNoticesSender:
pass
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index c0e7f475..9fae2e0a 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -83,7 +83,7 @@ class ResourceLimitsServerNotices(object):
room_id = yield self._server_notices_manager.get_notice_room_for_user(user_id)
if not room_id:
- logger.warn("Failed to get server notices room")
+ logger.warning("Failed to get server notices room")
return
yield self._check_and_set_tags(user_id, room_id)
diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py
new file mode 100644
index 00000000..efcc10f8
--- /dev/null
+++ b/synapse/spam_checker_api/__init__.py
@@ -0,0 +1,51 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+
+from synapse.storage.state import StateFilter
+
+logger = logging.getLogger(__name__)
+
+
+class SpamCheckerApi(object):
+ """A proxy object that gets passed to spam checkers so they can get
+ access to rooms and other relevant information.
+ """
+
+ def __init__(self, hs):
+ self.hs = hs
+
+ self._store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def get_state_events_in_room(self, room_id, types):
+ """Gets state events for the given room.
+
+ Args:
+ room_id (string): The room ID to get state events in.
+ types (tuple): The event type and state key (using None
+ to represent 'any') of the room state to acquire.
+
+ Returns:
+ twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]:
+ The filtered state events in the room.
+ """
+ state_ids = yield self._store.get_filtered_current_state_ids(
+ room_id=room_id, state_filter=StateFilter.from_types(types)
+ )
+ state = yield self._store.get_events(state_ids.values())
+ return state.values()
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index dc9f5a90..139beef8 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -16,6 +16,7 @@
import logging
from collections import namedtuple
+from typing import Iterable, Optional
from six import iteritems, itervalues
@@ -27,6 +28,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
+from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.utils import log_function
from synapse.state import v1, v2
@@ -103,6 +105,7 @@ class StateHandler(object):
def __init__(self, hs):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
+ self.state_store = hs.get_storage().state
self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
@@ -211,15 +214,17 @@ class StateHandler(object):
return joined_hosts
@defer.inlineCallbacks
- def compute_event_context(self, event, old_state=None):
+ def compute_event_context(
+ self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
+ ):
"""Build an EventContext structure for the event.
This works out what the current state should be for the event, and
generates a new state group if necessary.
Args:
- event (synapse.events.EventBase):
- old_state (dict|None): The state at the event if it can't be
+ event:
+ old_state: The state at the event if it can't be
calculated from existing events. This is normally only specified
when receiving an event from federation where we don't have the
prev events for, e.g. when backfilling.
@@ -231,6 +236,9 @@ class StateHandler(object):
# If this is an outlier, then we know it shouldn't have any current
# state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group.
+
+ # FIXME: why do we populate current_state_ids? I thought the point was
+ # that we weren't supposed to have any state for outliers?
if old_state:
prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state}
if event.is_state():
@@ -247,113 +255,103 @@ class StateHandler(object):
# group for it.
context = EventContext.with_state(
state_group=None,
+ state_group_before_event=None,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
)
return context
+ #
+ # first of all, figure out the state before the event
+ #
+
if old_state:
- # We already have the state, so we don't need to calculate it.
- # Let's just correctly fill out the context and create a
- # new state group for it.
-
- prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state}
-
- if event.is_state():
- key = (event.type, event.state_key)
- if key in prev_state_ids:
- replaces = prev_state_ids[key]
- if replaces != event.event_id: # Paranoia check
- event.unsigned["replaces_state"] = replaces
- current_state_ids = dict(prev_state_ids)
- current_state_ids[key] = event.event_id
- else:
- current_state_ids = prev_state_ids
+ # if we're given the state before the event, then we use that
+ state_ids_before_event = {
+ (s.type, s.state_key): s.event_id for s in old_state
+ }
+ state_group_before_event = None
+ state_group_before_event_prev_group = None
+ deltas_to_state_group_before_event = None
- state_group = yield self.store.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=None,
- delta_ids=None,
- current_state_ids=current_state_ids,
- )
+ else:
+ # otherwise, we'll need to resolve the state across the prev_events.
+ logger.debug("calling resolve_state_groups from compute_event_context")
- context = EventContext.with_state(
- state_group=state_group,
- current_state_ids=current_state_ids,
- prev_state_ids=prev_state_ids,
+ entry = yield self.resolve_state_groups_for_events(
+ event.room_id, event.prev_event_ids()
)
- return context
+ state_ids_before_event = entry.state
+ state_group_before_event = entry.state_group
+ state_group_before_event_prev_group = entry.prev_group
+ deltas_to_state_group_before_event = entry.delta_ids
- logger.debug("calling resolve_state_groups from compute_event_context")
+ #
+ # make sure that we have a state group at that point. If it's not a state event,
+ # that will be the state group for the new event. If it *is* a state event,
+ # it might get rejected (in which case we'll need to persist it with the
+ # previous state group)
+ #
- entry = yield self.resolve_state_groups_for_events(
- event.room_id, event.prev_event_ids()
- )
+ if not state_group_before_event:
+ state_group_before_event = yield self.state_store.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=state_group_before_event_prev_group,
+ delta_ids=deltas_to_state_group_before_event,
+ current_state_ids=state_ids_before_event,
+ )
- prev_state_ids = entry.state
- prev_group = None
- delta_ids = None
+ # XXX: can we update the state cache entry for the new state group? or
+ # could we set a flag on resolve_state_groups_for_events to tell it to
+ # always make a state group?
+
+ #
+ # now if it's not a state event, we're done
+ #
+
+ if not event.is_state():
+ return EventContext.with_state(
+ state_group_before_event=state_group_before_event,
+ state_group=state_group_before_event,
+ current_state_ids=state_ids_before_event,
+ prev_state_ids=state_ids_before_event,
+ prev_group=state_group_before_event_prev_group,
+ delta_ids=deltas_to_state_group_before_event,
+ )
- if event.is_state():
- # If this is a state event then we need to create a new state
- # group for the state after this event.
+ #
+ # otherwise, we'll need to create a new state group for after the event
+ #
- key = (event.type, event.state_key)
- if key in prev_state_ids:
- replaces = prev_state_ids[key]
+ key = (event.type, event.state_key)
+ if key in state_ids_before_event:
+ replaces = state_ids_before_event[key]
+ if replaces != event.event_id:
event.unsigned["replaces_state"] = replaces
- current_state_ids = dict(prev_state_ids)
- current_state_ids[key] = event.event_id
-
- if entry.state_group:
- # If the state at the event has a state group assigned then
- # we can use that as the prev group
- prev_group = entry.state_group
- delta_ids = {key: event.event_id}
- elif entry.prev_group:
- # If the state at the event only has a prev group, then we can
- # use that as a prev group too.
- prev_group = entry.prev_group
- delta_ids = dict(entry.delta_ids)
- delta_ids[key] = event.event_id
-
- state_group = yield self.store.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=prev_group,
- delta_ids=delta_ids,
- current_state_ids=current_state_ids,
- )
- else:
- current_state_ids = prev_state_ids
- prev_group = entry.prev_group
- delta_ids = entry.delta_ids
-
- if entry.state_group is None:
- entry.state_group = yield self.store.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=entry.prev_group,
- delta_ids=entry.delta_ids,
- current_state_ids=current_state_ids,
- )
- entry.state_id = entry.state_group
-
- state_group = entry.state_group
-
- context = EventContext.with_state(
- state_group=state_group,
- current_state_ids=current_state_ids,
- prev_state_ids=prev_state_ids,
- prev_group=prev_group,
+ state_ids_after_event = dict(state_ids_before_event)
+ state_ids_after_event[key] = event.event_id
+ delta_ids = {key: event.event_id}
+
+ state_group_after_event = yield self.state_store.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=state_group_before_event,
delta_ids=delta_ids,
+ current_state_ids=state_ids_after_event,
)
- return context
+ return EventContext.with_state(
+ state_group=state_group_after_event,
+ state_group_before_event=state_group_before_event,
+ current_state_ids=state_ids_after_event,
+ prev_state_ids=state_ids_before_event,
+ prev_group=state_group_before_event,
+ delta_ids=delta_ids,
+ )
@measure_func()
@defer.inlineCallbacks
@@ -376,14 +374,16 @@ class StateHandler(object):
# map from state group id to the state in that state group (where
# 'state' is a map from state key to event id)
# dict[int, dict[(str, str), str]]
- state_groups_ids = yield self.store.get_state_groups_ids(room_id, event_ids)
+ state_groups_ids = yield self.state_store.get_state_groups_ids(
+ room_id, event_ids
+ )
if len(state_groups_ids) == 0:
return _StateCacheEntry(state={}, state_group=None)
elif len(state_groups_ids) == 1:
name, state_list = list(state_groups_ids.items()).pop()
- prev_group, delta_ids = yield self.store.get_state_group_delta(name)
+ prev_group, delta_ids = yield self.state_store.get_state_group_delta(name)
return _StateCacheEntry(
state=state_list,
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index a249ecd2..0460fe8c 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -27,7 +27,28 @@ data stores associated with them (e.g. the schema version tables), which are
stored in `synapse.storage.schema`.
"""
-from synapse.storage.data_stores.main import DataStore # noqa: F401
+from synapse.storage.data_stores import DataStores
+from synapse.storage.data_stores.main import DataStore
+from synapse.storage.persist_events import EventsPersistenceStorage
+from synapse.storage.purge_events import PurgeEventsStorage
+from synapse.storage.state import StateGroupStorage
+
+__all__ = ["DataStores", "DataStore"]
+
+
+class Storage(object):
+ """The high level interfaces for talking to various storage layers.
+ """
+
+ def __init__(self, hs, stores: DataStores):
+ # We include the main data store here mainly so that we don't have to
+ # rewrite all the existing code to split it into high vs low level
+ # interfaces.
+ self.main = stores.main
+
+ self.persistence = EventsPersistenceStorage(hs, stores)
+ self.purge_events = PurgeEventsStorage(hs, stores)
+ self.state = StateGroupStorage(hs, stores)
def are_all_users_on_domain(txn, database_engine, domain):
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index f5906fcd..ab596fa6 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -361,14 +361,11 @@ class SQLBaseStore(object):
expiration_ts,
)
- self._simple_insert_txn(
+ self._simple_upsert_txn(
txn,
"account_validity",
- values={
- "user_id": user_id,
- "expiration_ts_ms": expiration_ts,
- "email_sent": False,
- },
+ keyvalues={"user_id": user_id},
+ values={"expiration_ts_ms": expiration_ts, "email_sent": False},
)
def start_profiling(self):
@@ -494,7 +491,7 @@ class SQLBaseStore(object):
exception_callbacks = []
if LoggingContext.current_context() == LoggingContext.sentinel:
- logger.warn("Starting db txn '%s' from sentinel context", desc)
+ logger.warning("Starting db txn '%s' from sentinel context", desc)
try:
result = yield self.runWithConnection(
@@ -532,7 +529,7 @@ class SQLBaseStore(object):
"""
parent_context = LoggingContext.current_context()
if parent_context == LoggingContext.sentinel:
- logger.warn(
+ logger.warning(
"Starting db connection from sentinel context: metrics will be lost"
)
parent_context = None
@@ -719,7 +716,7 @@ class SQLBaseStore(object):
raise
# presumably we raced with another transaction: let's retry.
- logger.warn(
+ logger.warning(
"IntegrityError when upserting into %s; retrying: %s", table, e
)
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 80b57a94..37d469ff 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -94,13 +94,16 @@ class BackgroundUpdateStore(SQLBaseStore):
self._all_done = False
def start_doing_background_updates(self):
- run_as_background_process("background_updates", self._run_background_updates)
+ run_as_background_process("background_updates", self.run_background_updates)
@defer.inlineCallbacks
- def _run_background_updates(self):
+ def run_background_updates(self, sleep=True):
logger.info("Starting background schema updates")
while True:
- yield self.hs.get_clock().sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
+ if sleep:
+ yield self.hs.get_clock().sleep(
+ self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0
+ )
try:
result = yield self.do_next_background_update(
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
index 56094078..cb184a98 100644
--- a/synapse/storage/data_stores/__init__.py
+++ b/synapse/storage/data_stores/__init__.py
@@ -12,3 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+
+class DataStores(object):
+ """The various data stores.
+
+ These are low level interfaces to physical databases.
+ """
+
+ def __init__(self, main_store, db_conn, hs):
+ # Note we pass in the main store here as workers use a different main
+ # store.
+ self.main = main_store
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index b185ba0b..10c940df 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -139,7 +139,10 @@ class DataStore(
db_conn, "public_room_list_stream", "stream_id"
)
self._device_list_id_gen = StreamIdGenerator(
- db_conn, "device_lists_stream", "stream_id"
+ db_conn,
+ "device_lists_stream",
+ "stream_id",
+ extra_tables=[("user_signature_stream", "stream_id")],
)
self._cross_signing_id_gen = StreamIdGenerator(
db_conn, "e2e_cross_signing_keys", "stream_id"
@@ -317,7 +320,7 @@ class DataStore(
) u
"""
txn.execute(sql, (time_from,))
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
def count_r30_users(self):
@@ -396,7 +399,7 @@ class DataStore(
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
results["all"] = count
return results
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index f04aad07..96cd0fb7 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -358,8 +358,21 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
def _add_messages_to_local_device_inbox_txn(
self, txn, stream_id, messages_by_user_then_device
):
- sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?"
- txn.execute(sql, (stream_id, stream_id))
+ # Compatible method of performing an upsert
+ sql = "SELECT stream_id FROM device_max_stream_id"
+
+ txn.execute(sql)
+ rows = txn.fetchone()
+ if rows:
+ db_stream_id = rows[0]
+ if db_stream_id < stream_id:
+ # Insert the new stream_id
+ sql = "UPDATE device_max_stream_id SET stream_id = ?"
+ else:
+ # No rows, perform an insert
+ sql = "INSERT INTO device_max_stream_id (stream_id) VALUES (?)"
+
+ txn.execute(sql, (stream_id,))
local_by_user_then_device = {}
for user_id, messages_by_device in messages_by_user_then_device.items():
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index f7a35423..71f62036 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -37,6 +37,7 @@ from synapse.storage._base import (
make_in_list_sql_clause,
)
from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.types import get_verify_key_from_cross_signing_key
from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
@@ -90,13 +91,18 @@ class DeviceWorkerStore(SQLBaseStore):
@trace
@defer.inlineCallbacks
- def get_devices_by_remote(self, destination, from_stream_id, limit):
- """Get stream of updates to send to remote servers
+ def get_device_updates_by_remote(self, destination, from_stream_id, limit):
+ """Get a stream of device updates to send to the given remote server.
+ Args:
+ destination (str): The host the device updates are intended for
+ from_stream_id (int): The minimum stream_id to filter updates by, exclusive
+ limit (int): Maximum number of device updates to return
Returns:
- Deferred[tuple[int, list[dict]]]:
+ Deferred[tuple[int, list[tuple[string,dict]]]]:
current stream id (ie, the stream id of the last update included in the
- response), and the list of updates
+ response), and the list of updates, where each update is a pair of EDU
+ type and EDU contents
"""
now_stream_id = self._device_list_id_gen.get_current_token()
@@ -117,8 +123,8 @@ class DeviceWorkerStore(SQLBaseStore):
# stream_id; the rationale being that such a large device list update
# is likely an error.
updates = yield self.runInteraction(
- "get_devices_by_remote",
- self._get_devices_by_remote_txn,
+ "get_device_updates_by_remote",
+ self._get_device_updates_by_remote_txn,
destination,
from_stream_id,
now_stream_id,
@@ -129,6 +135,37 @@ class DeviceWorkerStore(SQLBaseStore):
if not updates:
return now_stream_id, []
+ # get the cross-signing keys of the users in the list, so that we can
+ # determine which of the device changes were cross-signing keys
+ users = set(r[0] for r in updates)
+ master_key_by_user = {}
+ self_signing_key_by_user = {}
+ for user in users:
+ cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
+ if cross_signing_key:
+ key_id, verify_key = get_verify_key_from_cross_signing_key(
+ cross_signing_key
+ )
+ # verify_key is a VerifyKey from signedjson, which uses
+ # .version to denote the portion of the key ID after the
+ # algorithm and colon, which is the device ID
+ master_key_by_user[user] = {
+ "key_info": cross_signing_key,
+ "device_id": verify_key.version,
+ }
+
+ cross_signing_key = yield self.get_e2e_cross_signing_key(
+ user, "self_signing"
+ )
+ if cross_signing_key:
+ key_id, verify_key = get_verify_key_from_cross_signing_key(
+ cross_signing_key
+ )
+ self_signing_key_by_user[user] = {
+ "key_info": cross_signing_key,
+ "device_id": verify_key.version,
+ }
+
# if we have exceeded the limit, we need to exclude any results with the
# same stream_id as the last row.
if len(updates) > limit:
@@ -153,20 +190,33 @@ class DeviceWorkerStore(SQLBaseStore):
# context which created the Edu.
query_map = {}
- for update in updates:
- if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
+ cross_signing_keys_by_user = {}
+ for user_id, device_id, update_stream_id, update_context in updates:
+ if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
# Stop processing updates
break
- key = (update[0], update[1])
-
- update_context = update[3]
- update_stream_id = update[2]
+ if (
+ user_id in master_key_by_user
+ and device_id == master_key_by_user[user_id]["device_id"]
+ ):
+ result = cross_signing_keys_by_user.setdefault(user_id, {})
+ result["master_key"] = master_key_by_user[user_id]["key_info"]
+ elif (
+ user_id in self_signing_key_by_user
+ and device_id == self_signing_key_by_user[user_id]["device_id"]
+ ):
+ result = cross_signing_keys_by_user.setdefault(user_id, {})
+ result["self_signing_key"] = self_signing_key_by_user[user_id][
+ "key_info"
+ ]
+ else:
+ key = (user_id, device_id)
- previous_update_stream_id, _ = query_map.get(key, (0, None))
+ previous_update_stream_id, _ = query_map.get(key, (0, None))
- if update_stream_id > previous_update_stream_id:
- query_map[key] = (update_stream_id, update_context)
+ if update_stream_id > previous_update_stream_id:
+ query_map[key] = (update_stream_id, update_context)
# If we didn't find any updates with a stream_id lower than the cutoff, it
# means that there are more than limit updates all of which have the same
@@ -176,16 +226,22 @@ class DeviceWorkerStore(SQLBaseStore):
# devices, in which case E2E isn't going to work well anyway. We'll just
# skip that stream_id and return an empty list, and continue with the next
# stream_id next time.
- if not query_map:
+ if not query_map and not cross_signing_keys_by_user:
return stream_id_cutoff, []
results = yield self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
+ # add the updated cross-signing keys to the results list
+ for user_id, result in iteritems(cross_signing_keys_by_user):
+ result["user_id"] = user_id
+ # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+ results.append(("org.matrix.signing_key_update", result))
+
return now_stream_id, results
- def _get_devices_by_remote_txn(
+ def _get_device_updates_by_remote_txn(
self, txn, destination, from_stream_id, now_stream_id, limit
):
"""Return device update information for a given remote destination
@@ -200,6 +256,7 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
List: List of device updates
"""
+ # get the list of device updates that need to be sent
sql = """
SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
@@ -225,12 +282,16 @@ class DeviceWorkerStore(SQLBaseStore):
List[Dict]: List of objects representing an device update EDU
"""
- devices = yield self.runInteraction(
- "_get_e2e_device_keys_txn",
- self._get_e2e_device_keys_txn,
- query_map.keys(),
- include_all_devices=True,
- include_deleted_devices=True,
+ devices = (
+ yield self.runInteraction(
+ "_get_e2e_device_keys_txn",
+ self._get_e2e_device_keys_txn,
+ query_map.keys(),
+ include_all_devices=True,
+ include_deleted_devices=True,
+ )
+ if query_map
+ else {}
)
results = []
@@ -262,7 +323,7 @@ class DeviceWorkerStore(SQLBaseStore):
else:
result["deleted"] = True
- results.append(result)
+ results.append(("m.device_list_update", result))
return results
diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py
index ef88e792..1cbbae5b 100644
--- a/synapse/storage/data_stores/main/e2e_room_keys.py
+++ b/synapse/storage/data_stores/main/e2e_room_keys.py
@@ -321,9 +321,17 @@ class EndToEndRoomKeyStore(SQLBaseStore):
def _delete_e2e_room_keys_version_txn(txn):
if version is None:
this_version = self._get_current_version(txn, user_id)
+ if this_version is None:
+ raise StoreError(404, "No current backup version")
else:
this_version = version
+ self._simple_delete_txn(
+ txn,
+ table="e2e_room_keys",
+ keyvalues={"user_id": user_id, "version": this_version},
+ )
+
return self._simple_update_one_txn(
txn,
table="e2e_room_keys_versions",
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index a0bc6f2d..073412a7 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -315,6 +315,30 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
from_user_id,
)
+ def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
+ """Return a list of changes from the user signature stream to notify remotes.
+ Note that the user signature stream represents when a user signs their
+ device with their user-signing key, which is not published to other
+ users or servers, so no `destination` is needed in the returned
+ list. However, this is needed to poke workers.
+
+ Args:
+ from_key (int): the stream ID to start at (exclusive)
+ to_key (int): the stream ID to end at (inclusive)
+
+ Returns:
+ Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
+ """
+ sql = """
+ SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id
+ FROM user_signature_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ GROUP BY user_id
+ """
+ return self._execute(
+ "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
+ )
+
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index a470a48e..90bef0cd 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -364,9 +364,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
)
def _get_backfill_events(self, txn, room_id, event_list, limit):
- logger.debug(
- "_get_backfill_events: %s, %s, %s", room_id, repr(event_list), limit
- )
+ logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
event_results = set()
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index 22025eff..04ce21ac 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -863,7 +863,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
stream_row = txn.fetchone()
if stream_row:
- offset_stream_ordering, = stream_row
+ (offset_stream_ordering,) = stream_row
rotate_to_stream_ordering = min(
self.stream_ordering_day_ago, offset_stream_ordering
)
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 03b5111c..878f7568 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -17,28 +17,26 @@
import itertools
import logging
-from collections import Counter as c_counter, OrderedDict, deque, namedtuple
+from collections import Counter as c_counter, OrderedDict, namedtuple
from functools import wraps
from six import iteritems, text_type
from six.moves import range
from canonicaljson import json
-from prometheus_client import Counter, Histogram
+from prometheus_client import Counter
from twisted.internet import defer
import synapse.metrics
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventContentFields, EventTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
from synapse.events.utils import prune_event_dict
-from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.logging.utils import log_function
from synapse.metrics import BucketCollector
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.state import StateResolutionStore
from synapse.storage._base import make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.data_stores.main.event_federation import EventFederationStore
@@ -46,10 +44,8 @@ from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.state import StateGroupWorkerStore
from synapse.types import RoomStreamToken, get_domain_from_id
from synapse.util import batch_iter
-from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.frozenutils import frozendict_json_encoder
-from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -60,37 +56,6 @@ event_counter = Counter(
["type", "origin_type", "origin_entity"],
)
-# The number of times we are recalculating the current state
-state_delta_counter = Counter("synapse_storage_events_state_delta", "")
-
-# The number of times we are recalculating state when there is only a
-# single forward extremity
-state_delta_single_event_counter = Counter(
- "synapse_storage_events_state_delta_single_event", ""
-)
-
-# The number of times we are reculating state when we could have resonably
-# calculated the delta when we calculated the state for an event we were
-# persisting.
-state_delta_reuse_delta_counter = Counter(
- "synapse_storage_events_state_delta_reuse_delta", ""
-)
-
-# The number of forward extremities for each new event.
-forward_extremities_counter = Histogram(
- "synapse_storage_events_forward_extremities_persisted",
- "Number of forward extremities for each new event",
- buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
-)
-
-# The number of stale forward extremities for each new event. Stale extremities
-# are those that were in the previous set of extremities as well as the new.
-stale_forward_extremities_counter = Histogram(
- "synapse_storage_events_stale_forward_extremities_persisted",
- "Number of unchanged forward extremities for each new event",
- buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
-)
-
def encode_json(json_object):
"""
@@ -102,110 +67,6 @@ def encode_json(json_object):
return out
-class _EventPeristenceQueue(object):
- """Queues up events so that they can be persisted in bulk with only one
- concurrent transaction per room.
- """
-
- _EventPersistQueueItem = namedtuple(
- "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred")
- )
-
- def __init__(self):
- self._event_persist_queues = {}
- self._currently_persisting_rooms = set()
-
- def add_to_queue(self, room_id, events_and_contexts, backfilled):
- """Add events to the queue, with the given persist_event options.
-
- NB: due to the normal usage pattern of this method, it does *not*
- follow the synapse logcontext rules, and leaves the logcontext in
- place whether or not the returned deferred is ready.
-
- Args:
- room_id (str):
- events_and_contexts (list[(EventBase, EventContext)]):
- backfilled (bool):
-
- Returns:
- defer.Deferred: a deferred which will resolve once the events are
- persisted. Runs its callbacks *without* a logcontext.
- """
- queue = self._event_persist_queues.setdefault(room_id, deque())
- if queue:
- # if the last item in the queue has the same `backfilled` setting,
- # we can just add these new events to that item.
- end_item = queue[-1]
- if end_item.backfilled == backfilled:
- end_item.events_and_contexts.extend(events_and_contexts)
- return end_item.deferred.observe()
-
- deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
-
- queue.append(
- self._EventPersistQueueItem(
- events_and_contexts=events_and_contexts,
- backfilled=backfilled,
- deferred=deferred,
- )
- )
-
- return deferred.observe()
-
- def handle_queue(self, room_id, per_item_callback):
- """Attempts to handle the queue for a room if not already being handled.
-
- The given callback will be invoked with for each item in the queue,
- of type _EventPersistQueueItem. The per_item_callback will continuously
- be called with new items, unless the queue becomnes empty. The return
- value of the function will be given to the deferreds waiting on the item,
- exceptions will be passed to the deferreds as well.
-
- This function should therefore be called whenever anything is added
- to the queue.
-
- If another callback is currently handling the queue then it will not be
- invoked.
- """
-
- if room_id in self._currently_persisting_rooms:
- return
-
- self._currently_persisting_rooms.add(room_id)
-
- @defer.inlineCallbacks
- def handle_queue_loop():
- try:
- queue = self._get_drainining_queue(room_id)
- for item in queue:
- try:
- ret = yield per_item_callback(item)
- except Exception:
- with PreserveLoggingContext():
- item.deferred.errback()
- else:
- with PreserveLoggingContext():
- item.deferred.callback(ret)
- finally:
- queue = self._event_persist_queues.pop(room_id, None)
- if queue:
- self._event_persist_queues[room_id] = queue
- self._currently_persisting_rooms.discard(room_id)
-
- # set handle_queue_loop off in the background
- run_as_background_process("persist_events", handle_queue_loop)
-
- def _get_drainining_queue(self, room_id):
- queue = self._event_persist_queues.setdefault(room_id, deque())
-
- try:
- while True:
- yield queue.popleft()
- except IndexError:
- # Queue has been drained.
- pass
-
-
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
@@ -221,7 +82,7 @@ def _retry_on_integrity_error(func):
@defer.inlineCallbacks
def f(self, *args, **kwargs):
try:
- res = yield func(self, *args, **kwargs)
+ res = yield func(self, *args, delete_existing=False, **kwargs)
except self.database_engine.module.IntegrityError:
logger.exception("IntegrityError, retrying.")
res = yield func(self, *args, delete_existing=True, **kwargs)
@@ -241,9 +102,6 @@ class EventsStore(
def __init__(self, db_conn, hs):
super(EventsStore, self).__init__(db_conn, hs)
- self._event_persist_queue = _EventPeristenceQueue()
- self._state_resolution_handler = hs.get_state_resolution_handler()
-
# Collect metrics on the number of forward extremities that exist.
# Counter of number of extremities to count
self._current_forward_extremities_amount = c_counter()
@@ -286,340 +144,106 @@ class EventsStore(
res = yield self.runInteraction("read_forward_extremities", fetch)
self._current_forward_extremities_amount = c_counter(list(x[0] for x in res))
- @defer.inlineCallbacks
- def persist_events(self, events_and_contexts, backfilled=False):
- """
- Write events to the database
- Args:
- events_and_contexts: list of tuples of (event, context)
- backfilled (bool): Whether the results are retrieved from federation
- via backfill or not. Used to determine if they're "new" events
- which might update the current state etc.
-
- Returns:
- Deferred[int]: the stream ordering of the latest persisted event
- """
- partitioned = {}
- for event, ctx in events_and_contexts:
- partitioned.setdefault(event.room_id, []).append((event, ctx))
-
- deferreds = []
- for room_id, evs_ctxs in iteritems(partitioned):
- d = self._event_persist_queue.add_to_queue(
- room_id, evs_ctxs, backfilled=backfilled
- )
- deferreds.append(d)
-
- for room_id in partitioned:
- self._maybe_start_persisting(room_id)
-
- yield make_deferred_yieldable(
- defer.gatherResults(deferreds, consumeErrors=True)
- )
-
- max_persisted_id = yield self._stream_id_gen.get_current_token()
-
- return max_persisted_id
-
- @defer.inlineCallbacks
- @log_function
- def persist_event(self, event, context, backfilled=False):
- """
-
- Args:
- event (EventBase):
- context (EventContext):
- backfilled (bool):
-
- Returns:
- Deferred: resolves to (int, int): the stream ordering of ``event``,
- and the stream ordering of the latest persisted event
- """
- deferred = self._event_persist_queue.add_to_queue(
- event.room_id, [(event, context)], backfilled=backfilled
- )
-
- self._maybe_start_persisting(event.room_id)
-
- yield make_deferred_yieldable(deferred)
-
- max_persisted_id = yield self._stream_id_gen.get_current_token()
- return (event.internal_metadata.stream_ordering, max_persisted_id)
-
- def _maybe_start_persisting(self, room_id):
- @defer.inlineCallbacks
- def persisting_queue(item):
- with Measure(self._clock, "persist_events"):
- yield self._persist_events(
- item.events_and_contexts, backfilled=item.backfilled
- )
-
- self._event_persist_queue.handle_queue(room_id, persisting_queue)
-
@_retry_on_integrity_error
@defer.inlineCallbacks
- def _persist_events(
- self, events_and_contexts, backfilled=False, delete_existing=False
+ def _persist_events_and_state_updates(
+ self,
+ events_and_contexts,
+ current_state_for_room,
+ state_delta_for_room,
+ new_forward_extremeties,
+ backfilled=False,
+ delete_existing=False,
):
- """Persist events to db
+ """Persist a set of events alongside updates to the current state and
+ forward extremities tables.
Args:
events_and_contexts (list[(EventBase, EventContext)]):
- backfilled (bool):
+ current_state_for_room (dict[str, dict]): Map from room_id to the
+ current state of the room based on forward extremities
+ state_delta_for_room (dict[str, tuple]): Map from room_id to tuple
+ of `(to_delete, to_insert)` where to_delete is a list
+ of type/state keys to remove from current state, and to_insert
+ is a map (type,key)->event_id giving the state delta in each
+ room.
+ new_forward_extremities (dict[str, list[str]]): Map from room_id
+ to list of event IDs that are the new forward extremities of
+ the room.
+ backfilled (bool)
delete_existing (bool):
Returns:
Deferred: resolves when the events have been persisted
"""
- if not events_and_contexts:
- return
- chunks = [
- events_and_contexts[x : x + 100]
- for x in range(0, len(events_and_contexts), 100)
- ]
-
- for chunk in chunks:
- # We can't easily parallelize these since different chunks
- # might contain the same event. :(
-
- # NB: Assumes that we are only persisting events for one room
- # at a time.
-
- # map room_id->list[event_ids] giving the new forward
- # extremities in each room
- new_forward_extremeties = {}
+ # We want to calculate the stream orderings as late as possible, as
+ # we only notify after all events with a lesser stream ordering have
+ # been persisted. I.e. if we spend 10s inside the with block then
+ # that will delay all subsequent events from being notified about.
+ # Hence why we do it down here rather than wrapping the entire
+ # function.
+ #
+ # Its safe to do this after calculating the state deltas etc as we
+ # only need to protect the *persistence* of the events. This is to
+ # ensure that queries of the form "fetch events since X" don't
+ # return events and stream positions after events that are still in
+ # flight, as otherwise subsequent requests "fetch event since Y"
+ # will not return those events.
+ #
+ # Note: Multiple instances of this function cannot be in flight at
+ # the same time for the same room.
+ if backfilled:
+ stream_ordering_manager = self._backfill_id_gen.get_next_mult(
+ len(events_and_contexts)
+ )
+ else:
+ stream_ordering_manager = self._stream_id_gen.get_next_mult(
+ len(events_and_contexts)
+ )
- # map room_id->(type,state_key)->event_id tracking the full
- # state in each room after adding these events.
- # This is simply used to prefill the get_current_state_ids
- # cache
- current_state_for_room = {}
+ with stream_ordering_manager as stream_orderings:
+ for (event, context), stream in zip(events_and_contexts, stream_orderings):
+ event.internal_metadata.stream_ordering = stream
- # map room_id->(to_delete, to_insert) where to_delete is a list
- # of type/state keys to remove from current state, and to_insert
- # is a map (type,key)->event_id giving the state delta in each
- # room
- state_delta_for_room = {}
+ yield self.runInteraction(
+ "persist_events",
+ self._persist_events_txn,
+ events_and_contexts=events_and_contexts,
+ backfilled=backfilled,
+ delete_existing=delete_existing,
+ state_delta_for_room=state_delta_for_room,
+ new_forward_extremeties=new_forward_extremeties,
+ )
+ persist_event_counter.inc(len(events_and_contexts))
if not backfilled:
- with Measure(self._clock, "_calculate_state_and_extrem"):
- # Work out the new "current state" for each room.
- # We do this by working out what the new extremities are and then
- # calculating the state from that.
- events_by_room = {}
- for event, context in chunk:
- events_by_room.setdefault(event.room_id, []).append(
- (event, context)
- )
-
- for room_id, ev_ctx_rm in iteritems(events_by_room):
- latest_event_ids = yield self.get_latest_event_ids_in_room(
- room_id
- )
- new_latest_event_ids = yield self._calculate_new_extremities(
- room_id, ev_ctx_rm, latest_event_ids
- )
-
- latest_event_ids = set(latest_event_ids)
- if new_latest_event_ids == latest_event_ids:
- # No change in extremities, so no change in state
- continue
-
- # there should always be at least one forward extremity.
- # (except during the initial persistence of the send_join
- # results, in which case there will be no existing
- # extremities, so we'll `continue` above and skip this bit.)
- assert new_latest_event_ids, "No forward extremities left!"
-
- new_forward_extremeties[room_id] = new_latest_event_ids
-
- len_1 = (
- len(latest_event_ids) == 1
- and len(new_latest_event_ids) == 1
- )
- if len_1:
- all_single_prev_not_state = all(
- len(event.prev_event_ids()) == 1
- and not event.is_state()
- for event, ctx in ev_ctx_rm
- )
- # Don't bother calculating state if they're just
- # a long chain of single ancestor non-state events.
- if all_single_prev_not_state:
- continue
-
- state_delta_counter.inc()
- if len(new_latest_event_ids) == 1:
- state_delta_single_event_counter.inc()
-
- # This is a fairly handwavey check to see if we could
- # have guessed what the delta would have been when
- # processing one of these events.
- # What we're interested in is if the latest extremities
- # were the same when we created the event as they are
- # now. When this server creates a new event (as opposed
- # to receiving it over federation) it will use the
- # forward extremities as the prev_events, so we can
- # guess this by looking at the prev_events and checking
- # if they match the current forward extremities.
- for ev, _ in ev_ctx_rm:
- prev_event_ids = set(ev.prev_event_ids())
- if latest_event_ids == prev_event_ids:
- state_delta_reuse_delta_counter.inc()
- break
-
- logger.info("Calculating state delta for room %s", room_id)
- with Measure(
- self._clock, "persist_events.get_new_state_after_events"
- ):
- res = yield self._get_new_state_after_events(
- room_id,
- ev_ctx_rm,
- latest_event_ids,
- new_latest_event_ids,
- )
- current_state, delta_ids = res
-
- # If either are not None then there has been a change,
- # and we need to work out the delta (or use that
- # given)
- if delta_ids is not None:
- # If there is a delta we know that we've
- # only added or replaced state, never
- # removed keys entirely.
- state_delta_for_room[room_id] = ([], delta_ids)
- elif current_state is not None:
- with Measure(
- self._clock, "persist_events.calculate_state_delta"
- ):
- delta = yield self._calculate_state_delta(
- room_id, current_state
- )
- state_delta_for_room[room_id] = delta
-
- # If we have the current_state then lets prefill
- # the cache with it.
- if current_state is not None:
- current_state_for_room[room_id] = current_state
-
- # We want to calculate the stream orderings as late as possible, as
- # we only notify after all events with a lesser stream ordering have
- # been persisted. I.e. if we spend 10s inside the with block then
- # that will delay all subsequent events from being notified about.
- # Hence why we do it down here rather than wrapping the entire
- # function.
- #
- # Its safe to do this after calculating the state deltas etc as we
- # only need to protect the *persistence* of the events. This is to
- # ensure that queries of the form "fetch events since X" don't
- # return events and stream positions after events that are still in
- # flight, as otherwise subsequent requests "fetch event since Y"
- # will not return those events.
- #
- # Note: Multiple instances of this function cannot be in flight at
- # the same time for the same room.
- if backfilled:
- stream_ordering_manager = self._backfill_id_gen.get_next_mult(
- len(chunk)
+ # backfilled events have negative stream orderings, so we don't
+ # want to set the event_persisted_position to that.
+ synapse.metrics.event_persisted_position.set(
+ events_and_contexts[-1][0].internal_metadata.stream_ordering
)
- else:
- stream_ordering_manager = self._stream_id_gen.get_next_mult(len(chunk))
-
- with stream_ordering_manager as stream_orderings:
- for (event, context), stream in zip(chunk, stream_orderings):
- event.internal_metadata.stream_ordering = stream
-
- yield self.runInteraction(
- "persist_events",
- self._persist_events_txn,
- events_and_contexts=chunk,
- backfilled=backfilled,
- delete_existing=delete_existing,
- state_delta_for_room=state_delta_for_room,
- new_forward_extremeties=new_forward_extremeties,
- )
- persist_event_counter.inc(len(chunk))
- if not backfilled:
- # backfilled events have negative stream orderings, so we don't
- # want to set the event_persisted_position to that.
- synapse.metrics.event_persisted_position.set(
- chunk[-1][0].internal_metadata.stream_ordering
- )
-
- for event, context in chunk:
- if context.app_service:
- origin_type = "local"
- origin_entity = context.app_service.id
- elif self.hs.is_mine_id(event.sender):
- origin_type = "local"
- origin_entity = "*client*"
- else:
- origin_type = "remote"
- origin_entity = get_domain_from_id(event.sender)
-
- event_counter.labels(event.type, origin_type, origin_entity).inc()
-
- for room_id, new_state in iteritems(current_state_for_room):
- self.get_current_state_ids.prefill((room_id,), new_state)
-
- for room_id, latest_event_ids in iteritems(new_forward_extremeties):
- self.get_latest_event_ids_in_room.prefill(
- (room_id,), list(latest_event_ids)
- )
-
- @defer.inlineCallbacks
- def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids):
- """Calculates the new forward extremities for a room given events to
- persist.
-
- Assumes that we are only persisting events for one room at a time.
- """
-
- # we're only interested in new events which aren't outliers and which aren't
- # being rejected.
- new_events = [
- event
- for event, ctx in event_contexts
- if not event.internal_metadata.is_outlier()
- and not ctx.rejected
- and not event.internal_metadata.is_soft_failed()
- ]
-
- latest_event_ids = set(latest_event_ids)
-
- # start with the existing forward extremities
- result = set(latest_event_ids)
-
- # add all the new events to the list
- result.update(event.event_id for event in new_events)
-
- # Now remove all events which are prev_events of any of the new events
- result.difference_update(
- e_id for event in new_events for e_id in event.prev_event_ids()
- )
-
- # Remove any events which are prev_events of any existing events.
- existing_prevs = yield self._get_events_which_are_prevs(result)
- result.difference_update(existing_prevs)
+ for event, context in events_and_contexts:
+ if context.app_service:
+ origin_type = "local"
+ origin_entity = context.app_service.id
+ elif self.hs.is_mine_id(event.sender):
+ origin_type = "local"
+ origin_entity = "*client*"
+ else:
+ origin_type = "remote"
+ origin_entity = get_domain_from_id(event.sender)
- # Finally handle the case where the new events have soft-failed prev
- # events. If they do we need to remove them and their prev events,
- # otherwise we end up with dangling extremities.
- existing_prevs = yield self._get_prevs_before_rejected(
- e_id for event in new_events for e_id in event.prev_event_ids()
- )
- result.difference_update(existing_prevs)
+ event_counter.labels(event.type, origin_type, origin_entity).inc()
- # We only update metrics for events that change forward extremities
- # (e.g. we ignore backfill/outliers/etc)
- if result != latest_event_ids:
- forward_extremities_counter.observe(len(result))
- stale = latest_event_ids & result
- stale_forward_extremities_counter.observe(len(stale))
+ for room_id, new_state in iteritems(current_state_for_room):
+ self.get_current_state_ids.prefill((room_id,), new_state)
- return result
+ for room_id, latest_event_ids in iteritems(new_forward_extremeties):
+ self.get_latest_event_ids_in_room.prefill(
+ (room_id,), list(latest_event_ids)
+ )
@defer.inlineCallbacks
def _get_events_which_are_prevs(self, event_ids):
@@ -725,188 +349,6 @@ class EventsStore(
return existing_prevs
- @defer.inlineCallbacks
- def _get_new_state_after_events(
- self, room_id, events_context, old_latest_event_ids, new_latest_event_ids
- ):
- """Calculate the current state dict after adding some new events to
- a room
-
- Args:
- room_id (str):
- room to which the events are being added. Used for logging etc
-
- events_context (list[(EventBase, EventContext)]):
- events and contexts which are being added to the room
-
- old_latest_event_ids (iterable[str]):
- the old forward extremities for the room.
-
- new_latest_event_ids (iterable[str]):
- the new forward extremities for the room.
-
- Returns:
- Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]:
- Returns a tuple of two state maps, the first being the full new current
- state and the second being the delta to the existing current state.
- If both are None then there has been no change.
-
- If there has been a change then we only return the delta if its
- already been calculated. Conversely if we do know the delta then
- the new current state is only returned if we've already calculated
- it.
- """
- # map from state_group to ((type, key) -> event_id) state map
- state_groups_map = {}
-
- # Map from (prev state group, new state group) -> delta state dict
- state_group_deltas = {}
-
- for ev, ctx in events_context:
- if ctx.state_group is None:
- # This should only happen for outlier events.
- if not ev.internal_metadata.is_outlier():
- raise Exception(
- "Context for new event %s has no state "
- "group" % (ev.event_id,)
- )
- continue
-
- if ctx.state_group in state_groups_map:
- continue
-
- # We're only interested in pulling out state that has already
- # been cached in the context. We'll pull stuff out of the DB later
- # if necessary.
- current_state_ids = ctx.get_cached_current_state_ids()
- if current_state_ids is not None:
- state_groups_map[ctx.state_group] = current_state_ids
-
- if ctx.prev_group:
- state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
-
- # We need to map the event_ids to their state groups. First, let's
- # check if the event is one we're persisting, in which case we can
- # pull the state group from its context.
- # Otherwise we need to pull the state group from the database.
-
- # Set of events we need to fetch groups for. (We know none of the old
- # extremities are going to be in events_context).
- missing_event_ids = set(old_latest_event_ids)
-
- event_id_to_state_group = {}
- for event_id in new_latest_event_ids:
- # First search in the list of new events we're adding.
- for ev, ctx in events_context:
- if event_id == ev.event_id and ctx.state_group is not None:
- event_id_to_state_group[event_id] = ctx.state_group
- break
- else:
- # If we couldn't find it, then we'll need to pull
- # the state from the database
- missing_event_ids.add(event_id)
-
- if missing_event_ids:
- # Now pull out the state groups for any missing events from DB
- event_to_groups = yield self._get_state_group_for_events(missing_event_ids)
- event_id_to_state_group.update(event_to_groups)
-
- # State groups of old_latest_event_ids
- old_state_groups = set(
- event_id_to_state_group[evid] for evid in old_latest_event_ids
- )
-
- # State groups of new_latest_event_ids
- new_state_groups = set(
- event_id_to_state_group[evid] for evid in new_latest_event_ids
- )
-
- # If they old and new groups are the same then we don't need to do
- # anything.
- if old_state_groups == new_state_groups:
- return None, None
-
- if len(new_state_groups) == 1 and len(old_state_groups) == 1:
- # If we're going from one state group to another, lets check if
- # we have a delta for that transition. If we do then we can just
- # return that.
-
- new_state_group = next(iter(new_state_groups))
- old_state_group = next(iter(old_state_groups))
-
- delta_ids = state_group_deltas.get((old_state_group, new_state_group), None)
- if delta_ids is not None:
- # We have a delta from the existing to new current state,
- # so lets just return that. If we happen to already have
- # the current state in memory then lets also return that,
- # but it doesn't matter if we don't.
- new_state = state_groups_map.get(new_state_group)
- return new_state, delta_ids
-
- # Now that we have calculated new_state_groups we need to get
- # their state IDs so we can resolve to a single state set.
- missing_state = new_state_groups - set(state_groups_map)
- if missing_state:
- group_to_state = yield self._get_state_for_groups(missing_state)
- state_groups_map.update(group_to_state)
-
- if len(new_state_groups) == 1:
- # If there is only one state group, then we know what the current
- # state is.
- return state_groups_map[new_state_groups.pop()], None
-
- # Ok, we need to defer to the state handler to resolve our state sets.
-
- state_groups = {sg: state_groups_map[sg] for sg in new_state_groups}
-
- events_map = {ev.event_id: ev for ev, _ in events_context}
-
- # We need to get the room version, which is in the create event.
- # Normally that'd be in the database, but its also possible that we're
- # currently trying to persist it.
- room_version = None
- for ev, _ in events_context:
- if ev.type == EventTypes.Create and ev.state_key == "":
- room_version = ev.content.get("room_version", "1")
- break
-
- if not room_version:
- room_version = yield self.get_room_version(room_id)
-
- logger.debug("calling resolve_state_groups from preserve_events")
- res = yield self._state_resolution_handler.resolve_state_groups(
- room_id,
- room_version,
- state_groups,
- events_map,
- state_res_store=StateResolutionStore(self),
- )
-
- return res.state, None
-
- @defer.inlineCallbacks
- def _calculate_state_delta(self, room_id, current_state):
- """Calculate the new state deltas for a room.
-
- Assumes that we are only persisting events for one room at a time.
-
- Returns:
- tuple[list, dict] (to_delete, to_insert): where to_delete are the
- type/state_keys to remove from current_state_events and `to_insert`
- are the updates to current_state_events.
- """
- existing_state = yield self.get_current_state_ids(room_id)
-
- to_delete = [key for key in existing_state if key not in current_state]
-
- to_insert = {
- key: ev_id
- for key, ev_id in iteritems(current_state)
- if ev_id != existing_state.get(key)
- }
-
- return to_delete, to_insert
-
@log_function
def _persist_events_txn(
self,
@@ -1490,6 +932,13 @@ class EventsStore(
self._handle_event_relations(txn, event)
+ # Store the labels for this event.
+ labels = event.content.get(EventContentFields.LABELS)
+ if labels:
+ self.insert_labels_for_event_txn(
+ txn, event.event_id, labels, event.room_id, event.depth
+ )
+
# Insert into the room_memberships table.
self._store_room_members_txn(
txn,
@@ -1683,7 +1132,7 @@ class EventsStore(
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
ret = yield self.runInteraction("count_messages", _count_messages)
@@ -1704,7 +1153,7 @@ class EventsStore(
"""
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
ret = yield self.runInteraction("count_daily_sent_messages", _count_messages)
@@ -1719,7 +1168,7 @@ class EventsStore(
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
ret = yield self.runInteraction("count_daily_active_rooms", _count)
@@ -1926,6 +1375,10 @@ class EventsStore(
if True, we will delete local events as well as remote ones
(instead of just marking them as outliers and deleting their
state groups).
+
+ Returns:
+ Deferred[set[int]]: The set of state groups that are referenced by
+ deleted events.
"""
return self.runInteraction(
@@ -2062,11 +1515,10 @@ class EventsStore(
[(room_id, event_id) for event_id, in new_backwards_extrems],
)
- logger.info("[purge] finding redundant state groups")
+ logger.info("[purge] finding state groups referenced by deleted events")
# Get all state groups that are referenced by events that are to be
- # deleted. We then go and check if they are referenced by other events
- # or state groups, and if not we delete them.
+ # deleted.
txn.execute(
"""
SELECT DISTINCT state_group FROM events_to_purge
@@ -2079,60 +1531,6 @@ class EventsStore(
"[purge] found %i referenced state groups", len(referenced_state_groups)
)
- logger.info("[purge] finding state groups that can be deleted")
-
- _ = self._find_unreferenced_groups_during_purge(txn, referenced_state_groups)
- state_groups_to_delete, remaining_state_groups = _
-
- logger.info(
- "[purge] found %i state groups to delete", len(state_groups_to_delete)
- )
-
- logger.info(
- "[purge] de-delta-ing %i remaining state groups",
- len(remaining_state_groups),
- )
-
- # Now we turn the state groups that reference to-be-deleted state
- # groups to non delta versions.
- for sg in remaining_state_groups:
- logger.info("[purge] de-delta-ing remaining state group %s", sg)
- curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
- curr_state = curr_state[sg]
-
- self._simple_delete_txn(
- txn, table="state_groups_state", keyvalues={"state_group": sg}
- )
-
- self._simple_delete_txn(
- txn, table="state_group_edges", keyvalues={"state_group": sg}
- )
-
- self._simple_insert_many_txn(
- txn,
- table="state_groups_state",
- values=[
- {
- "state_group": sg,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
- for key, state_id in iteritems(curr_state)
- ],
- )
-
- logger.info("[purge] removing redundant state groups")
- txn.executemany(
- "DELETE FROM state_groups_state WHERE state_group = ?",
- ((sg,) for sg in state_groups_to_delete),
- )
- txn.executemany(
- "DELETE FROM state_groups WHERE id = ?",
- ((sg,) for sg in state_groups_to_delete),
- )
-
logger.info("[purge] removing events from event_to_state_groups")
txn.execute(
"DELETE FROM event_to_state_groups "
@@ -2204,7 +1602,7 @@ class EventsStore(
""",
(room_id,),
)
- min_depth, = txn.fetchone()
+ (min_depth,) = txn.fetchone()
logger.info("[purge] updating room_depth to %d", min_depth)
@@ -2219,138 +1617,35 @@ class EventsStore(
logger.info("[purge] done")
- def _find_unreferenced_groups_during_purge(self, txn, state_groups):
- """Used when purging history to figure out which state groups can be
- deleted and which need to be de-delta'ed (due to one of its prev groups
- being scheduled for deletion).
-
- Args:
- txn
- state_groups (set[int]): Set of state groups referenced by events
- that are going to be deleted.
-
- Returns:
- tuple[set[int], set[int]]: The set of state groups that can be
- deleted and the set of state groups that need to be de-delta'ed
- """
- # Graph of state group -> previous group
- graph = {}
-
- # Set of events that we have found to be referenced by events
- referenced_groups = set()
-
- # Set of state groups we've already seen
- state_groups_seen = set(state_groups)
-
- # Set of state groups to handle next.
- next_to_search = set(state_groups)
- while next_to_search:
- # We bound size of groups we're looking up at once, to stop the
- # SQL query getting too big
- if len(next_to_search) < 100:
- current_search = next_to_search
- next_to_search = set()
- else:
- current_search = set(itertools.islice(next_to_search, 100))
- next_to_search -= current_search
-
- # Check if state groups are referenced
- sql = """
- SELECT DISTINCT state_group FROM event_to_state_groups
- LEFT JOIN events_to_purge AS ep USING (event_id)
- WHERE ep.event_id IS NULL AND
- """
- clause, args = make_in_list_sql_clause(
- txn.database_engine, "state_group", current_search
- )
- txn.execute(sql + clause, list(args))
-
- referenced = set(sg for sg, in txn)
- referenced_groups |= referenced
-
- # We don't continue iterating up the state group graphs for state
- # groups that are referenced.
- current_search -= referenced
-
- rows = self._simple_select_many_txn(
- txn,
- table="state_group_edges",
- column="prev_state_group",
- iterable=current_search,
- keyvalues={},
- retcols=("prev_state_group", "state_group"),
- )
-
- prevs = set(row["state_group"] for row in rows)
- # We don't bother re-handling groups we've already seen
- prevs -= state_groups_seen
- next_to_search |= prevs
- state_groups_seen |= prevs
-
- for row in rows:
- # Note: Each state group can have at most one prev group
- graph[row["state_group"]] = row["prev_state_group"]
-
- to_delete = state_groups_seen - referenced_groups
-
- to_dedelta = set()
- for sg in referenced_groups:
- prev_sg = graph.get(sg)
- if prev_sg and prev_sg in to_delete:
- to_dedelta.add(sg)
-
- return to_delete, to_dedelta
+ return referenced_state_groups
def purge_room(self, room_id):
"""Deletes all record of a room
Args:
- room_id (str):
+ room_id (str)
+
+ Returns:
+ Deferred[List[int]]: The list of state groups to delete.
"""
return self.runInteraction("purge_room", self._purge_room_txn, room_id)
def _purge_room_txn(self, txn, room_id):
- # first we have to delete the state groups states
- logger.info("[purge] removing %s from state_groups_state", room_id)
-
+ # First we fetch all the state groups that should be deleted, before
+ # we delete that information.
txn.execute(
"""
- DELETE FROM state_groups_state WHERE state_group IN (
- SELECT state_group FROM events JOIN event_to_state_groups USING(event_id)
- WHERE events.room_id=?
- )
+ SELECT DISTINCT state_group FROM events
+ INNER JOIN event_to_state_groups USING(event_id)
+ WHERE events.room_id = ?
""",
(room_id,),
)
- # ... and the state group edges
- logger.info("[purge] removing %s from state_group_edges", room_id)
-
- txn.execute(
- """
- DELETE FROM state_group_edges WHERE state_group IN (
- SELECT state_group FROM events JOIN event_to_state_groups USING(event_id)
- WHERE events.room_id=?
- )
- """,
- (room_id,),
- )
-
- # ... and the state groups
- logger.info("[purge] removing %s from state_groups", room_id)
-
- txn.execute(
- """
- DELETE FROM state_groups WHERE id IN (
- SELECT state_group FROM events JOIN event_to_state_groups USING(event_id)
- WHERE events.room_id=?
- )
- """,
- (room_id,),
- )
+ state_groups = [row[0] for row in txn]
- # and then tables which lack an index on room_id but have one on event_id
+ # Now we delete tables which lack an index on room_id but have one on event_id
for table in (
"event_auth",
"event_edges",
@@ -2396,7 +1691,6 @@ class EventsStore(
"room_stats_earliest_token",
"rooms",
"stream_ordering_to_exterm",
- "topics",
"users_in_public_rooms",
"users_who_share_private_rooms",
# no useful index, but let's clear them anyway
@@ -2439,12 +1733,170 @@ class EventsStore(
logger.info("[purge] done")
+ return state_groups
+
+ def purge_unreferenced_state_groups(
+ self, room_id: str, state_groups_to_delete
+ ) -> defer.Deferred:
+ """Deletes no longer referenced state groups and de-deltas any state
+ groups that reference them.
+
+ Args:
+ room_id: The room the state groups belong to (must all be in the
+ same room).
+ state_groups_to_delete (Collection[int]): Set of all state groups
+ to delete.
+ """
+
+ return self.runInteraction(
+ "purge_unreferenced_state_groups",
+ self._purge_unreferenced_state_groups,
+ room_id,
+ state_groups_to_delete,
+ )
+
+ def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete):
+ logger.info(
+ "[purge] found %i state groups to delete", len(state_groups_to_delete)
+ )
+
+ rows = self._simple_select_many_txn(
+ txn,
+ table="state_group_edges",
+ column="prev_state_group",
+ iterable=state_groups_to_delete,
+ keyvalues={},
+ retcols=("state_group",),
+ )
+
+ remaining_state_groups = set(
+ row["state_group"]
+ for row in rows
+ if row["state_group"] not in state_groups_to_delete
+ )
+
+ logger.info(
+ "[purge] de-delta-ing %i remaining state groups",
+ len(remaining_state_groups),
+ )
+
+ # Now we turn the state groups that reference to-be-deleted state
+ # groups to non delta versions.
+ for sg in remaining_state_groups:
+ logger.info("[purge] de-delta-ing remaining state group %s", sg)
+ curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
+ curr_state = curr_state[sg]
+
+ self._simple_delete_txn(
+ txn, table="state_groups_state", keyvalues={"state_group": sg}
+ )
+
+ self._simple_delete_txn(
+ txn, table="state_group_edges", keyvalues={"state_group": sg}
+ )
+
+ self._simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ values=[
+ {
+ "state_group": sg,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
+ }
+ for key, state_id in iteritems(curr_state)
+ ],
+ )
+
+ logger.info("[purge] removing redundant state groups")
+ txn.executemany(
+ "DELETE FROM state_groups_state WHERE state_group = ?",
+ ((sg,) for sg in state_groups_to_delete),
+ )
+ txn.executemany(
+ "DELETE FROM state_groups WHERE id = ?",
+ ((sg,) for sg in state_groups_to_delete),
+ )
+
@defer.inlineCallbacks
- def is_event_after(self, event_id1, event_id2):
+ def get_previous_state_groups(self, state_groups):
+ """Fetch the previous groups of the given state groups.
+
+ Args:
+ state_groups (Iterable[int])
+
+ Returns:
+ Deferred[dict[int, int]]: mapping from state group to previous
+ state group.
+ """
+
+ rows = yield self._simple_select_many_batch(
+ table="state_group_edges",
+ column="prev_state_group",
+ iterable=state_groups,
+ keyvalues={},
+ retcols=("prev_state_group", "state_group"),
+ desc="get_previous_state_groups",
+ )
+
+ return {row["state_group"]: row["prev_state_group"] for row in rows}
+
+ def purge_room_state(self, room_id, state_groups_to_delete):
+ """Deletes all record of a room from state tables
+
+ Args:
+ room_id (str):
+ state_groups_to_delete (list[int]): State groups to delete
+ """
+
+ return self.runInteraction(
+ "purge_room_state",
+ self._purge_room_state_txn,
+ room_id,
+ state_groups_to_delete,
+ )
+
+ def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete):
+ # first we have to delete the state groups states
+ logger.info("[purge] removing %s from state_groups_state", room_id)
+
+ self._simple_delete_many_txn(
+ txn,
+ table="state_groups_state",
+ column="state_group",
+ iterable=state_groups_to_delete,
+ keyvalues={},
+ )
+
+ # ... and the state group edges
+ logger.info("[purge] removing %s from state_group_edges", room_id)
+
+ self._simple_delete_many_txn(
+ txn,
+ table="state_group_edges",
+ column="state_group",
+ iterable=state_groups_to_delete,
+ keyvalues={},
+ )
+
+ # ... and the state groups
+ logger.info("[purge] removing %s from state_groups", room_id)
+
+ self._simple_delete_many_txn(
+ txn,
+ table="state_groups",
+ column="id",
+ iterable=state_groups_to_delete,
+ keyvalues={},
+ )
+
+ async def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream
"""
- to_1, so_1 = yield self._get_event_ordering(event_id1)
- to_2, so_2 = yield self._get_event_ordering(event_id2)
+ to_1, so_1 = await self._get_event_ordering(event_id1)
+ to_2, so_2 = await self._get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
@cachedInlineCallbacks(max_entries=5000)
@@ -2477,6 +1929,33 @@ class EventsStore(
get_all_updated_current_state_deltas_txn,
)
+ def insert_labels_for_event_txn(
+ self, txn, event_id, labels, room_id, topological_ordering
+ ):
+ """Store the mapping between an event's ID and its labels, with one row per
+ (event_id, label) tuple.
+
+ Args:
+ txn (LoggingTransaction): The transaction to execute.
+ event_id (str): The event's ID.
+ labels (list[str]): A list of text labels.
+ room_id (str): The ID of the room the event was sent to.
+ topological_ordering (int): The position of the event in the room's topology.
+ """
+ return self._simple_insert_many_txn(
+ txn=txn,
+ table="event_labels",
+ values=[
+ {
+ "event_id": event_id,
+ "label": label,
+ "room_id": room_id,
+ "topological_ordering": topological_ordering,
+ }
+ for label in labels
+ ],
+ )
+
AllNewEventsResult = namedtuple(
"AllNewEventsResult",
diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py
index 31ea6f91..aa87f9ab 100644
--- a/synapse/storage/data_stores/main/events_bg_updates.py
+++ b/synapse/storage/data_stores/main/events_bg_updates.py
@@ -21,6 +21,7 @@ from canonicaljson import json
from twisted.internet import defer
+from synapse.api.constants import EventContentFields
from synapse.storage._base import make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore
@@ -85,6 +86,10 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
"event_fix_redactions_bytes", self._event_fix_redactions_bytes
)
+ self.register_background_update_handler(
+ "event_store_labels", self._event_store_labels
+ )
+
@defer.inlineCallbacks
def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
@@ -438,7 +443,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
if not rows:
return 0
- upper_event_id, = rows[-1]
+ (upper_event_id,) = rows[-1]
# Update the redactions with the received_ts.
#
@@ -503,3 +508,68 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
yield self._end_background_update("event_fix_redactions_bytes")
return 1
+
+ @defer.inlineCallbacks
+ def _event_store_labels(self, progress, batch_size):
+ """Background update handler which will store labels for existing events."""
+ last_event_id = progress.get("last_event_id", "")
+
+ def _event_store_labels_txn(txn):
+ txn.execute(
+ """
+ SELECT event_id, json FROM event_json
+ LEFT JOIN event_labels USING (event_id)
+ WHERE event_id > ? AND label IS NULL
+ ORDER BY event_id LIMIT ?
+ """,
+ (last_event_id, batch_size),
+ )
+
+ results = list(txn)
+
+ nbrows = 0
+ last_row_event_id = ""
+ for (event_id, event_json_raw) in results:
+ try:
+ event_json = json.loads(event_json_raw)
+
+ self._simple_insert_many_txn(
+ txn=txn,
+ table="event_labels",
+ values=[
+ {
+ "event_id": event_id,
+ "label": label,
+ "room_id": event_json["room_id"],
+ "topological_ordering": event_json["depth"],
+ }
+ for label in event_json["content"].get(
+ EventContentFields.LABELS, []
+ )
+ if isinstance(label, str)
+ ],
+ )
+ except Exception as e:
+ logger.warning(
+ "Unable to load event %s (no labels will be imported): %s",
+ event_id,
+ e,
+ )
+
+ nbrows += 1
+ last_row_event_id = event_id
+
+ self._background_update_progress_txn(
+ txn, "event_store_labels", {"last_event_id": last_row_event_id}
+ )
+
+ return nbrows
+
+ num_rows = yield self.runInteraction(
+ desc="event_store_labels", func=_event_store_labels_txn
+ )
+
+ if not num_rows:
+ yield self._end_background_update("event_store_labels")
+
+ return num_rows
diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py
index aeae5a2b..5ded539a 100644
--- a/synapse/storage/data_stores/main/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -249,7 +249,7 @@ class GroupServerStore(SQLBaseStore):
WHERE group_id = ? AND category_id = ?
"""
txn.execute(sql, (group_id, category_id))
- order, = txn.fetchone()
+ (order,) = txn.fetchone()
if existing:
to_update = {}
@@ -509,7 +509,7 @@ class GroupServerStore(SQLBaseStore):
WHERE group_id = ? AND role_id = ?
"""
txn.execute(sql, (group_id, role_id))
- order, = txn.fetchone()
+ (order,) = txn.fetchone()
if existing:
to_update = {}
@@ -553,6 +553,21 @@ class GroupServerStore(SQLBaseStore):
desc="remove_user_from_summary",
)
+ def get_local_groups_for_room(self, room_id):
+ """Get all of the local group that contain a given room
+ Args:
+ room_id (str): The ID of a room
+ Returns:
+ Deferred[list[str]]: A twisted.Deferred containing a list of group ids
+ containing this room
+ """
+ return self._simple_select_onecol(
+ table="group_rooms",
+ keyvalues={"room_id": room_id},
+ retcol="group_id",
+ desc="get_local_groups_for_room",
+ )
+
def get_users_for_summary_by_role(self, group_id, include_private=False):
"""Get the users and roles that should be included in a summary request
diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py
index e6ee1e4a..b41c3d31 100644
--- a/synapse/storage/data_stores/main/monthly_active_users.py
+++ b/synapse/storage/data_stores/main/monthly_active_users.py
@@ -171,7 +171,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
txn.execute(sql)
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
return self.runInteraction("count_users", _count_users)
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index cd95f1ce..b520062d 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -143,7 +143,7 @@ class PushRulesWorkerStore(
" WHERE user_id = ? AND ? < stream_id"
)
txn.execute(sql, (user_id, last_id))
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return bool(count)
return self.runInteraction(
diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py
index f005c1ae..d76861cd 100644
--- a/synapse/storage/data_stores/main/pusher.py
+++ b/synapse/storage/data_stores/main/pusher.py
@@ -44,7 +44,7 @@ class PusherWorkerStore(SQLBaseStore):
r["data"] = json.loads(dataJson)
except Exception as e:
- logger.warn(
+ logger.warning(
"Invalid JSON in data for pusher %d: %s, %s",
r["id"],
dataJson,
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index 6c5b2928..89147ad5 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -459,7 +459,7 @@ class RegistrationWorkerStore(SQLBaseStore):
WHERE appservice_id IS NULL
"""
)
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
ret = yield self.runInteraction("count_users", _count_users)
@@ -488,14 +488,14 @@ class RegistrationWorkerStore(SQLBaseStore):
we can. Unfortunately, it's possible some of them are already taken by
existing users, and there may be gaps in the already taken range. This
function returns the start of the first allocatable gap. This is to
- avoid the case of ID 10000000 being pre-allocated, so us wasting the
- first (and shortest) many generated user IDs.
+ avoid the case of ID 1000 being pre-allocated and starting at 1001 while
+ 0-999 are available.
"""
def _find_next_generated_user_id(txn):
- # We bound between '@1' and '@a' to avoid pulling the entire table
+ # We bound between '@0' and '@a' to avoid pulling the entire table
# out.
- txn.execute("SELECT name FROM users WHERE '@1' <= name AND name < '@a'")
+ txn.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")
regex = re.compile(r"^@(\d+):")
@@ -577,6 +577,19 @@ class RegistrationWorkerStore(SQLBaseStore):
return self._simple_delete(
"user_threepids",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
+ desc="user_delete_threepid",
+ )
+
+ def user_delete_threepids(self, user_id: str):
+ """Delete all threepid this user has bound
+
+ Args:
+ user_id: The user id to delete all threepids of
+
+ """
+ return self._simple_delete(
+ "user_threepids",
+ keyvalues={"user_id": user_id},
desc="user_delete_threepids",
)
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index e47ab604..2af24a20 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -720,7 +720,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
- cache = self._get_joined_hosts_cache(room_id)
+ cache = yield self._get_joined_hosts_cache(room_id)
joined_hosts = yield cache.get_destinations(state_entry)
return joined_hosts
@@ -927,7 +927,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
if not row or not row[0]:
return processed, True
- next_room, = row
+ (next_room,) = row
sql = """
UPDATE current_state_events
diff --git a/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql b/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql
new file mode 100644
index 00000000..1d2ddb1b
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql
@@ -0,0 +1,25 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* delete room keys that belong to deleted room key version, or to room key
+ * versions that don't exist (anymore)
+ */
+DELETE FROM e2e_room_keys
+WHERE version NOT IN (
+ SELECT version
+ FROM e2e_room_keys_versions
+ WHERE e2e_room_keys.user_id = e2e_room_keys_versions.user_id
+ AND e2e_room_keys_versions.deleted = 0
+);
diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql
new file mode 100644
index 00000000..5e29c1da
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql
@@ -0,0 +1,30 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- room_id and topoligical_ordering are denormalised from the events table in order to
+-- make the index work.
+CREATE TABLE IF NOT EXISTS event_labels (
+ event_id TEXT,
+ label TEXT,
+ room_id TEXT NOT NULL,
+ topological_ordering BIGINT NOT NULL,
+ PRIMARY KEY(event_id, label)
+);
+
+
+-- This index enables an event pagination looking for a particular label to index the
+-- event_labels table first, which is much quicker than scanning the events table and then
+-- filtering by label, if the label is rarely used relative to the size of the room.
+CREATE INDEX event_labels_room_id_label_idx ON event_labels(room_id, label, topological_ordering);
diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql b/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql
new file mode 100644
index 00000000..5f5e0499
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('event_store_labels', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite
new file mode 100644
index 00000000..e8b1fd35
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite
@@ -0,0 +1,42 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Change the hidden column from a default value of FALSE to a default value of
+ * 0, because sqlite3 prior to 3.23.0 caused the hidden column to contain the
+ * string 'FALSE', which is truthy.
+ *
+ * Since sqlite doesn't allow us to just change the default value, we have to
+ * recreate the table, copy the data, fix the rows that have incorrect data, and
+ * replace the old table with the new table.
+ */
+
+CREATE TABLE IF NOT EXISTS devices2 (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ display_name TEXT,
+ last_seen BIGINT,
+ ip TEXT,
+ user_agent TEXT,
+ hidden BOOLEAN DEFAULT 0,
+ CONSTRAINT device_uniqueness UNIQUE (user_id, device_id)
+);
+
+INSERT INTO devices2 SELECT * FROM devices;
+
+UPDATE devices2 SET hidden = 0 WHERE hidden = 'FALSE';
+
+DROP TABLE devices;
+
+ALTER TABLE devices2 RENAME TO devices;
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index 0e084974..d1d7c686 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -196,7 +196,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
" ON event_search USING GIN (vector)"
)
except psycopg2.ProgrammingError as e:
- logger.warn(
+ logger.warning(
"Ignoring error %r when trying to switch from GIST to GIN", e
)
@@ -672,7 +672,7 @@ class SearchStore(SearchBackgroundUpdateStore):
)
)
txn.execute(query, (value, search_query))
- headline, = txn.fetchall()[0]
+ (headline,) = txn.fetchall()[0]
# Now we need to pick the possible highlights out of the haedline
# result.
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index d54442e5..6a90daea 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -15,6 +15,7 @@
import logging
from collections import namedtuple
+from typing import Iterable, Tuple
from six import iteritems, itervalues
from six.moves import range
@@ -23,6 +24,8 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import NotFoundError
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
from synapse.storage._base import SQLBaseStore
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
@@ -282,7 +285,11 @@ class StateGroupWorkerStore(
room_id (str)
Returns:
- Deferred[unicode|None]: predecessor room id
+ Deferred[dict|None]: A dictionary containing the structure of the predecessor
+ field from the room's create event. The structure is subject to other servers,
+ but it is expected to be:
+ * room_id (str): The room ID of the predecessor room
+ * event_id (str): The ID of the tombstone event in the predecessor room
Raises:
NotFoundError if the room is unknown
@@ -722,16 +729,18 @@ class StateGroupWorkerStore(
member_filter, non_member_filter = state_filter.get_member_split()
# Now we look them up in the member and non-member caches
- non_member_state, incomplete_groups_nm, = (
- yield self._get_state_for_groups_using_cache(
- groups, self._state_group_cache, state_filter=non_member_filter
- )
+ (
+ non_member_state,
+ incomplete_groups_nm,
+ ) = yield self._get_state_for_groups_using_cache(
+ groups, self._state_group_cache, state_filter=non_member_filter
)
- member_state, incomplete_groups_m, = (
- yield self._get_state_for_groups_using_cache(
- groups, self._state_group_members_cache, state_filter=member_filter
- )
+ (
+ member_state,
+ incomplete_groups_m,
+ ) = yield self._get_state_for_groups_using_cache(
+ groups, self._state_group_members_cache, state_filter=member_filter
)
state = dict(non_member_state)
@@ -986,6 +995,29 @@ class StateGroupWorkerStore(
return self.runInteraction("store_state_group", _store_state_group_txn)
+ @defer.inlineCallbacks
+ def get_referenced_state_groups(self, state_groups):
+ """Check if the state groups are referenced by events.
+
+ Args:
+ state_groups (Iterable[int])
+
+ Returns:
+ Deferred[set[int]]: The subset of state groups that are
+ referenced.
+ """
+
+ rows = yield self._simple_select_many_batch(
+ table="event_to_state_groups",
+ column="state_group",
+ iterable=state_groups,
+ keyvalues={},
+ retcols=("DISTINCT state_group",),
+ desc="get_referenced_state_groups",
+ )
+
+ return set(row["state_group"] for row in rows)
+
class StateBackgroundUpdateStore(
StateGroupBackgroundUpdateStore, BackgroundUpdateStore
@@ -1073,7 +1105,7 @@ class StateBackgroundUpdateStore(
" WHERE id < ? AND room_id = ?",
(state_group, room_id),
)
- prev_group, = txn.fetchone()
+ (prev_group,) = txn.fetchone()
new_last_state_group = state_group
if prev_group:
@@ -1215,7 +1247,9 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(StateStore, self).__init__(db_conn, hs)
- def _store_event_state_mappings_txn(self, txn, events_and_contexts):
+ def _store_event_state_mappings_txn(
+ self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
+ ):
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
@@ -1224,7 +1258,7 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
# if the event was rejected, just give it the same state as its
# predecessor.
if context.rejected:
- state_groups[event.event_id] = context.prev_group
+ state_groups[event.event_id] = context.state_group_before_event
continue
state_groups[event.event_id] = context.state_group
diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py
index 5ab639b2..45b3de7d 100644
--- a/synapse/storage/data_stores/main/stats.py
+++ b/synapse/storage/data_stores/main/stats.py
@@ -332,7 +332,7 @@ class StatsStore(StateDeltasStore):
def _bulk_update_stats_delta_txn(txn):
for stats_type, stats_updates in updates.items():
for stats_id, fields in stats_updates.items():
- logger.info(
+ logger.debug(
"Updating %s stats for %s: %s", stats_type, stats_id, fields
)
self._update_stats_delta_txn(
@@ -773,7 +773,7 @@ class StatsStore(StateDeltasStore):
(room_id,),
)
- current_state_events_count, = txn.fetchone()
+ (current_state_events_count,) = txn.fetchone()
users_in_room = self.get_users_in_room_txn(txn, room_id)
@@ -863,7 +863,7 @@ class StatsStore(StateDeltasStore):
""",
(user_id,),
)
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count, pos
joined_rooms, pos = yield self.runInteraction(
diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
index 263999df..8780fdd9 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -229,6 +229,14 @@ def filter_to_clause(event_filter):
clauses.append("contains_url = ?")
args.append(event_filter.contains_url)
+ # We're only applying the "labels" filter on the database query, because applying the
+ # "not_labels" filter via a SQL query is non-trivial. Instead, we let
+ # event_filter.check_fields apply it, which is not as efficient but makes the
+ # implementation simpler.
+ if event_filter.labels:
+ clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels))
+ args.extend(event_filter.labels)
+
return " AND ".join(clauses), args
@@ -863,13 +871,38 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
args.append(int(limit))
- sql = (
- "SELECT event_id, topological_ordering, stream_ordering"
- " FROM events"
- " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
- " ORDER BY topological_ordering %(order)s,"
- " stream_ordering %(order)s LIMIT ?"
- ) % {"bounds": bounds, "order": order}
+ select_keywords = "SELECT"
+ join_clause = ""
+ if event_filter and event_filter.labels:
+ # If we're not filtering on a label, then joining on event_labels will
+ # return as many row for a single event as the number of labels it has. To
+ # avoid this, only join if we're filtering on at least one label.
+ join_clause = """
+ LEFT JOIN event_labels
+ USING (event_id, room_id, topological_ordering)
+ """
+ if len(event_filter.labels) > 1:
+ # Using DISTINCT in this SELECT query is quite expensive, because it
+ # requires the engine to sort on the entire (not limited) result set,
+ # i.e. the entire events table. We only need to use it when we're
+ # filtering on more than two labels, because that's the only scenario
+ # in which we can possibly to get multiple times the same event ID in
+ # the results.
+ select_keywords += "DISTINCT"
+
+ sql = """
+ %(select_keywords)s event_id, topological_ordering, stream_ordering
+ FROM events
+ %(join_clause)s
+ WHERE outlier = ? AND room_id = ? AND %(bounds)s
+ ORDER BY topological_ordering %(order)s,
+ stream_ordering %(order)s LIMIT ?
+ """ % {
+ "select_keywords": select_keywords,
+ "join_clause": join_clause,
+ "bounds": bounds,
+ "order": order,
+ }
txn.execute(sql, args)
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
new file mode 100644
index 00000000..fa03ca9f
--- /dev/null
+++ b/synapse/storage/persist_events.py
@@ -0,0 +1,649 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from collections import deque, namedtuple
+
+from six import iteritems
+from six.moves import range
+
+from prometheus_client import Counter, Histogram
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.state import StateResolutionStore
+from synapse.storage.data_stores import DataStores
+from synapse.util.async_helpers import ObservableDeferred
+from synapse.util.metrics import Measure
+
+logger = logging.getLogger(__name__)
+
+# The number of times we are recalculating the current state
+state_delta_counter = Counter("synapse_storage_events_state_delta", "")
+
+# The number of times we are recalculating state when there is only a
+# single forward extremity
+state_delta_single_event_counter = Counter(
+ "synapse_storage_events_state_delta_single_event", ""
+)
+
+# The number of times we are reculating state when we could have resonably
+# calculated the delta when we calculated the state for an event we were
+# persisting.
+state_delta_reuse_delta_counter = Counter(
+ "synapse_storage_events_state_delta_reuse_delta", ""
+)
+
+# The number of forward extremities for each new event.
+forward_extremities_counter = Histogram(
+ "synapse_storage_events_forward_extremities_persisted",
+ "Number of forward extremities for each new event",
+ buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
+)
+
+# The number of stale forward extremities for each new event. Stale extremities
+# are those that were in the previous set of extremities as well as the new.
+stale_forward_extremities_counter = Histogram(
+ "synapse_storage_events_stale_forward_extremities_persisted",
+ "Number of unchanged forward extremities for each new event",
+ buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
+)
+
+
+class _EventPeristenceQueue(object):
+ """Queues up events so that they can be persisted in bulk with only one
+ concurrent transaction per room.
+ """
+
+ _EventPersistQueueItem = namedtuple(
+ "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred")
+ )
+
+ def __init__(self):
+ self._event_persist_queues = {}
+ self._currently_persisting_rooms = set()
+
+ def add_to_queue(self, room_id, events_and_contexts, backfilled):
+ """Add events to the queue, with the given persist_event options.
+
+ NB: due to the normal usage pattern of this method, it does *not*
+ follow the synapse logcontext rules, and leaves the logcontext in
+ place whether or not the returned deferred is ready.
+
+ Args:
+ room_id (str):
+ events_and_contexts (list[(EventBase, EventContext)]):
+ backfilled (bool):
+
+ Returns:
+ defer.Deferred: a deferred which will resolve once the events are
+ persisted. Runs its callbacks *without* a logcontext.
+ """
+ queue = self._event_persist_queues.setdefault(room_id, deque())
+ if queue:
+ # if the last item in the queue has the same `backfilled` setting,
+ # we can just add these new events to that item.
+ end_item = queue[-1]
+ if end_item.backfilled == backfilled:
+ end_item.events_and_contexts.extend(events_and_contexts)
+ return end_item.deferred.observe()
+
+ deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
+
+ queue.append(
+ self._EventPersistQueueItem(
+ events_and_contexts=events_and_contexts,
+ backfilled=backfilled,
+ deferred=deferred,
+ )
+ )
+
+ return deferred.observe()
+
+ def handle_queue(self, room_id, per_item_callback):
+ """Attempts to handle the queue for a room if not already being handled.
+
+ The given callback will be invoked with for each item in the queue,
+ of type _EventPersistQueueItem. The per_item_callback will continuously
+ be called with new items, unless the queue becomnes empty. The return
+ value of the function will be given to the deferreds waiting on the item,
+ exceptions will be passed to the deferreds as well.
+
+ This function should therefore be called whenever anything is added
+ to the queue.
+
+ If another callback is currently handling the queue then it will not be
+ invoked.
+ """
+
+ if room_id in self._currently_persisting_rooms:
+ return
+
+ self._currently_persisting_rooms.add(room_id)
+
+ @defer.inlineCallbacks
+ def handle_queue_loop():
+ try:
+ queue = self._get_drainining_queue(room_id)
+ for item in queue:
+ try:
+ ret = yield per_item_callback(item)
+ except Exception:
+ with PreserveLoggingContext():
+ item.deferred.errback()
+ else:
+ with PreserveLoggingContext():
+ item.deferred.callback(ret)
+ finally:
+ queue = self._event_persist_queues.pop(room_id, None)
+ if queue:
+ self._event_persist_queues[room_id] = queue
+ self._currently_persisting_rooms.discard(room_id)
+
+ # set handle_queue_loop off in the background
+ run_as_background_process("persist_events", handle_queue_loop)
+
+ def _get_drainining_queue(self, room_id):
+ queue = self._event_persist_queues.setdefault(room_id, deque())
+
+ try:
+ while True:
+ yield queue.popleft()
+ except IndexError:
+ # Queue has been drained.
+ pass
+
+
+class EventsPersistenceStorage(object):
+ """High level interface for handling persisting newly received events.
+
+ Takes care of batching up events by room, and calculating the necessary
+ current state and forward extremity changes.
+ """
+
+ def __init__(self, hs, stores: DataStores):
+ # We ultimately want to split out the state store from the main store,
+ # so we use separate variables here even though they point to the same
+ # store for now.
+ self.main_store = stores.main
+ self.state_store = stores.main
+
+ self._clock = hs.get_clock()
+ self.is_mine_id = hs.is_mine_id
+ self._event_persist_queue = _EventPeristenceQueue()
+ self._state_resolution_handler = hs.get_state_resolution_handler()
+
+ @defer.inlineCallbacks
+ def persist_events(self, events_and_contexts, backfilled=False):
+ """
+ Write events to the database
+ Args:
+ events_and_contexts: list of tuples of (event, context)
+ backfilled (bool): Whether the results are retrieved from federation
+ via backfill or not. Used to determine if they're "new" events
+ which might update the current state etc.
+
+ Returns:
+ Deferred[int]: the stream ordering of the latest persisted event
+ """
+ partitioned = {}
+ for event, ctx in events_and_contexts:
+ partitioned.setdefault(event.room_id, []).append((event, ctx))
+
+ deferreds = []
+ for room_id, evs_ctxs in iteritems(partitioned):
+ d = self._event_persist_queue.add_to_queue(
+ room_id, evs_ctxs, backfilled=backfilled
+ )
+ deferreds.append(d)
+
+ for room_id in partitioned:
+ self._maybe_start_persisting(room_id)
+
+ yield make_deferred_yieldable(
+ defer.gatherResults(deferreds, consumeErrors=True)
+ )
+
+ max_persisted_id = yield self.main_store.get_current_events_token()
+
+ return max_persisted_id
+
+ @defer.inlineCallbacks
+ def persist_event(self, event, context, backfilled=False):
+ """
+
+ Args:
+ event (EventBase):
+ context (EventContext):
+ backfilled (bool):
+
+ Returns:
+ Deferred: resolves to (int, int): the stream ordering of ``event``,
+ and the stream ordering of the latest persisted event
+ """
+ deferred = self._event_persist_queue.add_to_queue(
+ event.room_id, [(event, context)], backfilled=backfilled
+ )
+
+ self._maybe_start_persisting(event.room_id)
+
+ yield make_deferred_yieldable(deferred)
+
+ max_persisted_id = yield self.main_store.get_current_events_token()
+ return (event.internal_metadata.stream_ordering, max_persisted_id)
+
+ def _maybe_start_persisting(self, room_id):
+ @defer.inlineCallbacks
+ def persisting_queue(item):
+ with Measure(self._clock, "persist_events"):
+ yield self._persist_events(
+ item.events_and_contexts, backfilled=item.backfilled
+ )
+
+ self._event_persist_queue.handle_queue(room_id, persisting_queue)
+
+ @defer.inlineCallbacks
+ def _persist_events(self, events_and_contexts, backfilled=False):
+ """Calculates the change to current state and forward extremities, and
+ persists the given events and with those updates.
+
+ Args:
+ events_and_contexts (list[(EventBase, EventContext)]):
+ backfilled (bool):
+ delete_existing (bool):
+
+ Returns:
+ Deferred: resolves when the events have been persisted
+ """
+ if not events_and_contexts:
+ return
+
+ chunks = [
+ events_and_contexts[x : x + 100]
+ for x in range(0, len(events_and_contexts), 100)
+ ]
+
+ for chunk in chunks:
+ # We can't easily parallelize these since different chunks
+ # might contain the same event. :(
+
+ # NB: Assumes that we are only persisting events for one room
+ # at a time.
+
+ # map room_id->list[event_ids] giving the new forward
+ # extremities in each room
+ new_forward_extremeties = {}
+
+ # map room_id->(type,state_key)->event_id tracking the full
+ # state in each room after adding these events.
+ # This is simply used to prefill the get_current_state_ids
+ # cache
+ current_state_for_room = {}
+
+ # map room_id->(to_delete, to_insert) where to_delete is a list
+ # of type/state keys to remove from current state, and to_insert
+ # is a map (type,key)->event_id giving the state delta in each
+ # room
+ state_delta_for_room = {}
+
+ if not backfilled:
+ with Measure(self._clock, "_calculate_state_and_extrem"):
+ # Work out the new "current state" for each room.
+ # We do this by working out what the new extremities are and then
+ # calculating the state from that.
+ events_by_room = {}
+ for event, context in chunk:
+ events_by_room.setdefault(event.room_id, []).append(
+ (event, context)
+ )
+
+ for room_id, ev_ctx_rm in iteritems(events_by_room):
+ latest_event_ids = yield self.main_store.get_latest_event_ids_in_room(
+ room_id
+ )
+ new_latest_event_ids = yield self._calculate_new_extremities(
+ room_id, ev_ctx_rm, latest_event_ids
+ )
+
+ latest_event_ids = set(latest_event_ids)
+ if new_latest_event_ids == latest_event_ids:
+ # No change in extremities, so no change in state
+ continue
+
+ # there should always be at least one forward extremity.
+ # (except during the initial persistence of the send_join
+ # results, in which case there will be no existing
+ # extremities, so we'll `continue` above and skip this bit.)
+ assert new_latest_event_ids, "No forward extremities left!"
+
+ new_forward_extremeties[room_id] = new_latest_event_ids
+
+ len_1 = (
+ len(latest_event_ids) == 1
+ and len(new_latest_event_ids) == 1
+ )
+ if len_1:
+ all_single_prev_not_state = all(
+ len(event.prev_event_ids()) == 1
+ and not event.is_state()
+ for event, ctx in ev_ctx_rm
+ )
+ # Don't bother calculating state if they're just
+ # a long chain of single ancestor non-state events.
+ if all_single_prev_not_state:
+ continue
+
+ state_delta_counter.inc()
+ if len(new_latest_event_ids) == 1:
+ state_delta_single_event_counter.inc()
+
+ # This is a fairly handwavey check to see if we could
+ # have guessed what the delta would have been when
+ # processing one of these events.
+ # What we're interested in is if the latest extremities
+ # were the same when we created the event as they are
+ # now. When this server creates a new event (as opposed
+ # to receiving it over federation) it will use the
+ # forward extremities as the prev_events, so we can
+ # guess this by looking at the prev_events and checking
+ # if they match the current forward extremities.
+ for ev, _ in ev_ctx_rm:
+ prev_event_ids = set(ev.prev_event_ids())
+ if latest_event_ids == prev_event_ids:
+ state_delta_reuse_delta_counter.inc()
+ break
+
+ logger.info("Calculating state delta for room %s", room_id)
+ with Measure(
+ self._clock, "persist_events.get_new_state_after_events"
+ ):
+ res = yield self._get_new_state_after_events(
+ room_id,
+ ev_ctx_rm,
+ latest_event_ids,
+ new_latest_event_ids,
+ )
+ current_state, delta_ids = res
+
+ # If either are not None then there has been a change,
+ # and we need to work out the delta (or use that
+ # given)
+ if delta_ids is not None:
+ # If there is a delta we know that we've
+ # only added or replaced state, never
+ # removed keys entirely.
+ state_delta_for_room[room_id] = ([], delta_ids)
+ elif current_state is not None:
+ with Measure(
+ self._clock, "persist_events.calculate_state_delta"
+ ):
+ delta = yield self._calculate_state_delta(
+ room_id, current_state
+ )
+ state_delta_for_room[room_id] = delta
+
+ # If we have the current_state then lets prefill
+ # the cache with it.
+ if current_state is not None:
+ current_state_for_room[room_id] = current_state
+
+ yield self.main_store._persist_events_and_state_updates(
+ chunk,
+ current_state_for_room=current_state_for_room,
+ state_delta_for_room=state_delta_for_room,
+ new_forward_extremeties=new_forward_extremeties,
+ backfilled=backfilled,
+ )
+
+ @defer.inlineCallbacks
+ def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids):
+ """Calculates the new forward extremities for a room given events to
+ persist.
+
+ Assumes that we are only persisting events for one room at a time.
+ """
+
+ # we're only interested in new events which aren't outliers and which aren't
+ # being rejected.
+ new_events = [
+ event
+ for event, ctx in event_contexts
+ if not event.internal_metadata.is_outlier()
+ and not ctx.rejected
+ and not event.internal_metadata.is_soft_failed()
+ ]
+
+ latest_event_ids = set(latest_event_ids)
+
+ # start with the existing forward extremities
+ result = set(latest_event_ids)
+
+ # add all the new events to the list
+ result.update(event.event_id for event in new_events)
+
+ # Now remove all events which are prev_events of any of the new events
+ result.difference_update(
+ e_id for event in new_events for e_id in event.prev_event_ids()
+ )
+
+ # Remove any events which are prev_events of any existing events.
+ existing_prevs = yield self.main_store._get_events_which_are_prevs(result)
+ result.difference_update(existing_prevs)
+
+ # Finally handle the case where the new events have soft-failed prev
+ # events. If they do we need to remove them and their prev events,
+ # otherwise we end up with dangling extremities.
+ existing_prevs = yield self.main_store._get_prevs_before_rejected(
+ e_id for event in new_events for e_id in event.prev_event_ids()
+ )
+ result.difference_update(existing_prevs)
+
+ # We only update metrics for events that change forward extremities
+ # (e.g. we ignore backfill/outliers/etc)
+ if result != latest_event_ids:
+ forward_extremities_counter.observe(len(result))
+ stale = latest_event_ids & result
+ stale_forward_extremities_counter.observe(len(stale))
+
+ return result
+
+ @defer.inlineCallbacks
+ def _get_new_state_after_events(
+ self, room_id, events_context, old_latest_event_ids, new_latest_event_ids
+ ):
+ """Calculate the current state dict after adding some new events to
+ a room
+
+ Args:
+ room_id (str):
+ room to which the events are being added. Used for logging etc
+
+ events_context (list[(EventBase, EventContext)]):
+ events and contexts which are being added to the room
+
+ old_latest_event_ids (iterable[str]):
+ the old forward extremities for the room.
+
+ new_latest_event_ids (iterable[str]):
+ the new forward extremities for the room.
+
+ Returns:
+ Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]:
+ Returns a tuple of two state maps, the first being the full new current
+ state and the second being the delta to the existing current state.
+ If both are None then there has been no change.
+
+ If there has been a change then we only return the delta if its
+ already been calculated. Conversely if we do know the delta then
+ the new current state is only returned if we've already calculated
+ it.
+ """
+ # map from state_group to ((type, key) -> event_id) state map
+ state_groups_map = {}
+
+ # Map from (prev state group, new state group) -> delta state dict
+ state_group_deltas = {}
+
+ for ev, ctx in events_context:
+ if ctx.state_group is None:
+ # This should only happen for outlier events.
+ if not ev.internal_metadata.is_outlier():
+ raise Exception(
+ "Context for new event %s has no state "
+ "group" % (ev.event_id,)
+ )
+ continue
+
+ if ctx.state_group in state_groups_map:
+ continue
+
+ # We're only interested in pulling out state that has already
+ # been cached in the context. We'll pull stuff out of the DB later
+ # if necessary.
+ current_state_ids = ctx.get_cached_current_state_ids()
+ if current_state_ids is not None:
+ state_groups_map[ctx.state_group] = current_state_ids
+
+ if ctx.prev_group:
+ state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
+
+ # We need to map the event_ids to their state groups. First, let's
+ # check if the event is one we're persisting, in which case we can
+ # pull the state group from its context.
+ # Otherwise we need to pull the state group from the database.
+
+ # Set of events we need to fetch groups for. (We know none of the old
+ # extremities are going to be in events_context).
+ missing_event_ids = set(old_latest_event_ids)
+
+ event_id_to_state_group = {}
+ for event_id in new_latest_event_ids:
+ # First search in the list of new events we're adding.
+ for ev, ctx in events_context:
+ if event_id == ev.event_id and ctx.state_group is not None:
+ event_id_to_state_group[event_id] = ctx.state_group
+ break
+ else:
+ # If we couldn't find it, then we'll need to pull
+ # the state from the database
+ missing_event_ids.add(event_id)
+
+ if missing_event_ids:
+ # Now pull out the state groups for any missing events from DB
+ event_to_groups = yield self.main_store._get_state_group_for_events(
+ missing_event_ids
+ )
+ event_id_to_state_group.update(event_to_groups)
+
+ # State groups of old_latest_event_ids
+ old_state_groups = set(
+ event_id_to_state_group[evid] for evid in old_latest_event_ids
+ )
+
+ # State groups of new_latest_event_ids
+ new_state_groups = set(
+ event_id_to_state_group[evid] for evid in new_latest_event_ids
+ )
+
+ # If they old and new groups are the same then we don't need to do
+ # anything.
+ if old_state_groups == new_state_groups:
+ return None, None
+
+ if len(new_state_groups) == 1 and len(old_state_groups) == 1:
+ # If we're going from one state group to another, lets check if
+ # we have a delta for that transition. If we do then we can just
+ # return that.
+
+ new_state_group = next(iter(new_state_groups))
+ old_state_group = next(iter(old_state_groups))
+
+ delta_ids = state_group_deltas.get((old_state_group, new_state_group), None)
+ if delta_ids is not None:
+ # We have a delta from the existing to new current state,
+ # so lets just return that. If we happen to already have
+ # the current state in memory then lets also return that,
+ # but it doesn't matter if we don't.
+ new_state = state_groups_map.get(new_state_group)
+ return new_state, delta_ids
+
+ # Now that we have calculated new_state_groups we need to get
+ # their state IDs so we can resolve to a single state set.
+ missing_state = new_state_groups - set(state_groups_map)
+ if missing_state:
+ group_to_state = yield self.state_store._get_state_for_groups(missing_state)
+ state_groups_map.update(group_to_state)
+
+ if len(new_state_groups) == 1:
+ # If there is only one state group, then we know what the current
+ # state is.
+ return state_groups_map[new_state_groups.pop()], None
+
+ # Ok, we need to defer to the state handler to resolve our state sets.
+
+ state_groups = {sg: state_groups_map[sg] for sg in new_state_groups}
+
+ events_map = {ev.event_id: ev for ev, _ in events_context}
+
+ # We need to get the room version, which is in the create event.
+ # Normally that'd be in the database, but its also possible that we're
+ # currently trying to persist it.
+ room_version = None
+ for ev, _ in events_context:
+ if ev.type == EventTypes.Create and ev.state_key == "":
+ room_version = ev.content.get("room_version", "1")
+ break
+
+ if not room_version:
+ room_version = yield self.main_store.get_room_version(room_id)
+
+ logger.debug("calling resolve_state_groups from preserve_events")
+ res = yield self._state_resolution_handler.resolve_state_groups(
+ room_id,
+ room_version,
+ state_groups,
+ events_map,
+ state_res_store=StateResolutionStore(self.main_store),
+ )
+
+ return res.state, None
+
+ @defer.inlineCallbacks
+ def _calculate_state_delta(self, room_id, current_state):
+ """Calculate the new state deltas for a room.
+
+ Assumes that we are only persisting events for one room at a time.
+
+ Returns:
+ tuple[list, dict] (to_delete, to_insert): where to_delete are the
+ type/state_keys to remove from current_state_events and `to_insert`
+ are the updates to current_state_events.
+ """
+ existing_state = yield self.main_store.get_current_state_ids(room_id)
+
+ to_delete = [key for key in existing_state if key not in current_state]
+
+ to_insert = {
+ key: ev_id
+ for key, ev_id in iteritems(current_state)
+ if ev_id != existing_state.get(key)
+ }
+
+ return to_delete, to_insert
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
new file mode 100644
index 00000000..a3681820
--- /dev/null
+++ b/synapse/storage/purge_events.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import logging
+
+from twisted.internet import defer
+
+logger = logging.getLogger(__name__)
+
+
+class PurgeEventsStorage(object):
+ """High level interface for purging rooms and event history.
+ """
+
+ def __init__(self, hs, stores):
+ self.stores = stores
+
+ @defer.inlineCallbacks
+ def purge_room(self, room_id: str):
+ """Deletes all record of a room
+ """
+
+ state_groups_to_delete = yield self.stores.main.purge_room(room_id)
+ yield self.stores.main.purge_room_state(room_id, state_groups_to_delete)
+
+ @defer.inlineCallbacks
+ def purge_history(self, room_id, token, delete_local_events):
+ """Deletes room history before a certain point
+
+ Args:
+ room_id (str):
+
+ token (str): A topological token to delete events before
+
+ delete_local_events (bool):
+ if True, we will delete local events as well as remote ones
+ (instead of just marking them as outliers and deleting their
+ state groups).
+ """
+ state_groups = yield self.stores.main.purge_history(
+ room_id, token, delete_local_events
+ )
+
+ logger.info("[purge] finding state groups that can be deleted")
+
+ sg_to_delete = yield self._find_unreferenced_groups(state_groups)
+
+ yield self.stores.main.purge_unreferenced_state_groups(room_id, sg_to_delete)
+
+ @defer.inlineCallbacks
+ def _find_unreferenced_groups(self, state_groups):
+ """Used when purging history to figure out which state groups can be
+ deleted.
+
+ Args:
+ state_groups (set[int]): Set of state groups referenced by events
+ that are going to be deleted.
+
+ Returns:
+ Deferred[set[int]] The set of state groups that can be deleted.
+ """
+ # Graph of state group -> previous group
+ graph = {}
+
+ # Set of events that we have found to be referenced by events
+ referenced_groups = set()
+
+ # Set of state groups we've already seen
+ state_groups_seen = set(state_groups)
+
+ # Set of state groups to handle next.
+ next_to_search = set(state_groups)
+ while next_to_search:
+ # We bound size of groups we're looking up at once, to stop the
+ # SQL query getting too big
+ if len(next_to_search) < 100:
+ current_search = next_to_search
+ next_to_search = set()
+ else:
+ current_search = set(itertools.islice(next_to_search, 100))
+ next_to_search -= current_search
+
+ referenced = yield self.stores.main.get_referenced_state_groups(
+ current_search
+ )
+ referenced_groups |= referenced
+
+ # We don't continue iterating up the state group graphs for state
+ # groups that are referenced.
+ current_search -= referenced
+
+ edges = yield self.stores.main.get_previous_state_groups(current_search)
+
+ prevs = set(edges.values())
+ # We don't bother re-handling groups we've already seen
+ prevs -= state_groups_seen
+ next_to_search |= prevs
+ state_groups_seen |= prevs
+
+ graph.update(edges)
+
+ to_delete = state_groups_seen - referenced_groups
+
+ return to_delete
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index a2df8fa8..37358468 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -19,6 +19,8 @@ from six import iteritems, itervalues
import attr
+from twisted.internet import defer
+
from synapse.api.constants import EventTypes
logger = logging.getLogger(__name__)
@@ -322,3 +324,234 @@ class StateFilter(object):
)
return member_filter, non_member_filter
+
+
+class StateGroupStorage(object):
+ """High level interface to fetching state for event.
+ """
+
+ def __init__(self, hs, stores):
+ self.stores = stores
+
+ def get_state_group_delta(self, state_group):
+ """Given a state group try to return a previous group and a delta between
+ the old and the new.
+
+ Returns:
+ Deferred[Tuple[Optional[int], Optional[list[dict[tuple[str, str], str]]]]]):
+ (prev_group, delta_ids)
+ """
+
+ return self.stores.main.get_state_group_delta(state_group)
+
+ @defer.inlineCallbacks
+ def get_state_groups_ids(self, _room_id, event_ids):
+ """Get the event IDs of all the state for the state groups for the given events
+
+ Args:
+ _room_id (str): id of the room for these events
+ event_ids (iterable[str]): ids of the events
+
+ Returns:
+ Deferred[dict[int, dict[tuple[str, str], str]]]:
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
+ """
+ if not event_ids:
+ return {}
+
+ event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
+
+ groups = set(itervalues(event_to_groups))
+ group_to_state = yield self.stores.main._get_state_for_groups(groups)
+
+ return group_to_state
+
+ @defer.inlineCallbacks
+ def get_state_ids_for_group(self, state_group):
+ """Get the event IDs of all the state in the given state group
+
+ Args:
+ state_group (int)
+
+ Returns:
+ Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
+ """
+ group_to_state = yield self._get_state_for_groups((state_group,))
+
+ return group_to_state[state_group]
+
+ @defer.inlineCallbacks
+ def get_state_groups(self, room_id, event_ids):
+ """ Get the state groups for the given list of event_ids
+ Returns:
+ Deferred[dict[int, list[EventBase]]]:
+ dict of state_group_id -> list of state events.
+ """
+ if not event_ids:
+ return {}
+
+ group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
+
+ state_event_map = yield self.stores.main.get_events(
+ [
+ ev_id
+ for group_ids in itervalues(group_to_ids)
+ for ev_id in itervalues(group_ids)
+ ],
+ get_prev_content=False,
+ )
+
+ return {
+ group: [
+ state_event_map[v]
+ for v in itervalues(event_id_map)
+ if v in state_event_map
+ ]
+ for group, event_id_map in iteritems(group_to_ids)
+ }
+
+ def _get_state_groups_from_groups(self, groups, state_filter):
+ """Returns the state groups for a given set of groups, filtering on
+ types of state events.
+
+ Args:
+ groups(list[int]): list of state group IDs to query
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+ Returns:
+ Deferred[dict[int, dict[tuple[str, str], str]]]:
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
+ """
+
+ return self.stores.main._get_state_groups_from_groups(groups, state_filter)
+
+ @defer.inlineCallbacks
+ def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
+ """Given a list of event_ids and type tuples, return a list of state
+ dicts for each event.
+ Args:
+ event_ids (list[string])
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+ Returns:
+ deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
+ """
+ event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
+
+ groups = set(itervalues(event_to_groups))
+ group_to_state = yield self.stores.main._get_state_for_groups(
+ groups, state_filter
+ )
+
+ state_event_map = yield self.stores.main.get_events(
+ [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
+ get_prev_content=False,
+ )
+
+ event_to_state = {
+ event_id: {
+ k: state_event_map[v]
+ for k, v in iteritems(group_to_state[group])
+ if v in state_event_map
+ }
+ for event_id, group in iteritems(event_to_groups)
+ }
+
+ return {event: event_to_state[event] for event in event_ids}
+
+ @defer.inlineCallbacks
+ def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
+ """
+ Get the state dicts corresponding to a list of events, containing the event_ids
+ of the state events (as opposed to the events themselves)
+
+ Args:
+ event_ids(list(str)): events whose state should be returned
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+
+ Returns:
+ A deferred dict from event_id -> (type, state_key) -> event_id
+ """
+ event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
+
+ groups = set(itervalues(event_to_groups))
+ group_to_state = yield self.stores.main._get_state_for_groups(
+ groups, state_filter
+ )
+
+ event_to_state = {
+ event_id: group_to_state[group]
+ for event_id, group in iteritems(event_to_groups)
+ }
+
+ return {event: event_to_state[event] for event in event_ids}
+
+ @defer.inlineCallbacks
+ def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
+ """
+ Get the state dict corresponding to a particular event
+
+ Args:
+ event_id(str): event whose state should be returned
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+
+ Returns:
+ A deferred dict from (type, state_key) -> state_event
+ """
+ state_map = yield self.get_state_for_events([event_id], state_filter)
+ return state_map[event_id]
+
+ @defer.inlineCallbacks
+ def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
+ """
+ Get the state dict corresponding to a particular event
+
+ Args:
+ event_id(str): event whose state should be returned
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+
+ Returns:
+ A deferred dict from (type, state_key) -> state_event
+ """
+ state_map = yield self.get_state_ids_for_events([event_id], state_filter)
+ return state_map[event_id]
+
+ def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
+ """Gets the state at each of a list of state groups, optionally
+ filtering by type/state_key
+
+ Args:
+ groups (iterable[int]): list of state groups for which we want
+ to get the state.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+ Returns:
+ Deferred[dict[int, dict[tuple[str, str], str]]]:
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
+ """
+ return self.stores.main._get_state_for_groups(groups, state_filter)
+
+ def store_state_group(
+ self, event_id, room_id, prev_group, delta_ids, current_state_ids
+ ):
+ """Store a new set of state, returning a newly assigned state group.
+
+ Args:
+ event_id (str): The event ID for which the state was calculated
+ room_id (str)
+ prev_group (int|None): A previous state group for the room, optional.
+ delta_ids (dict|None): The delta between state at `prev_group` and
+ `current_state_ids`, if `prev_group` was given. Same format as
+ `current_state_ids`.
+ current_state_ids (dict): The state to store. Map of (type, state_key)
+ to event_id.
+
+ Returns:
+ Deferred[int]: The state group ID
+ """
+ return self.stores.main.store_state_group(
+ event_id, room_id, prev_group, delta_ids, current_state_ids
+ )
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index cbb0a481..9d851bea 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -46,7 +46,7 @@ def _load_current_id(db_conn, table, column, step=1):
cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
else:
cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
- val, = cur.fetchone()
+ (val,) = cur.fetchone()
cur.close()
current_id = int(val) if val else step
return (max if step > 0 else min)(current_id, step)
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 804dbca4..5c4de2e6 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -86,11 +86,12 @@ class ObservableDeferred(object):
deferred.addCallbacks(callback, errback)
- def observe(self):
+ def observe(self) -> defer.Deferred:
"""Observe the underlying deferred.
- Can return either a deferred if the underlying deferred is still pending
- (or has failed), or the actual value. Callers may need to use maybeDeferred.
+ This returns a brand new deferred that is resolved when the underlying
+ deferred is resolved. Interacting with the returned deferred does not
+ effect the underdlying deferred.
"""
if not self._result:
d = defer.Deferred()
@@ -105,7 +106,7 @@ class ObservableDeferred(object):
return d
else:
success, res = self._result
- return res if success else defer.fail(res)
+ return defer.succeed(res) if success else defer.fail(res)
def observers(self):
return self._observers
@@ -138,7 +139,7 @@ def concurrently_execute(func, args, limit):
the number of concurrent executions.
Args:
- func (func): Function to execute, should return a deferred.
+ func (func): Function to execute, should return a deferred or coroutine.
args (list): List of arguments to pass to func, each invocation of func
gets a signle argument.
limit (int): Maximum number of conccurent executions.
@@ -148,11 +149,10 @@ def concurrently_execute(func, args, limit):
"""
it = iter(args)
- @defer.inlineCallbacks
- def _concurrently_execute_inner():
+ async def _concurrently_execute_inner():
try:
while True:
- yield func(next(it))
+ await maybe_awaitable(func(next(it)))
except StopIteration:
pass
@@ -309,7 +309,7 @@ class Linearizer(object):
)
else:
- logger.warn(
+ logger.warning(
"Unexpected exception waiting for linearizer lock %r for key %r",
self.name,
key,
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 43fd65d6..da5077b4 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -107,7 +107,7 @@ def register_cache(cache_type, cache_name, cache, collect_callback=None):
if collect_callback:
collect_callback()
except Exception as e:
- logger.warn("Error calculating metrics for %s: %s", cache_name, e)
+ logger.warning("Error calculating metrics for %s: %s", cache_name, e)
raise
yield GaugeMetricFamily("__unused", "")
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 5ac2530a..84f5ae22 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -17,8 +17,8 @@ import functools
import inspect
import logging
import threading
-from collections import namedtuple
-from typing import Any, cast
+from typing import Any, Tuple, Union, cast
+from weakref import WeakValueDictionary
from six import itervalues
@@ -38,6 +38,8 @@ from . import register_cache
logger = logging.getLogger(__name__)
+CacheKey = Union[Tuple, Any]
+
class _CachedFunction(Protocol):
invalidate = None # type: Any
@@ -430,7 +432,7 @@ class CacheDescriptor(_CacheDescriptorBase):
# Add our own `cache_context` to argument list if the wrapped function
# has asked for one
if self.add_cache_context:
- kwargs["cache_context"] = _CacheContext(cache, cache_key)
+ kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
try:
cached_result_d = cache.get(cache_key, callback=invalidate_callback)
@@ -438,7 +440,7 @@ class CacheDescriptor(_CacheDescriptorBase):
if isinstance(cached_result_d, ObservableDeferred):
observer = cached_result_d.observe()
else:
- observer = cached_result_d
+ observer = defer.succeed(cached_result_d)
except KeyError:
ret = defer.maybeDeferred(
@@ -482,9 +484,8 @@ class CacheListDescriptor(_CacheDescriptorBase):
Given a list of keys it looks in the cache to find any hits, then passes
the list of missing keys to the wrapped function.
- Once wrapped, the function returns either a Deferred which resolves to
- the list of results, or (if all results were cached), just the list of
- results.
+ Once wrapped, the function returns a Deferred which resolves to the list
+ of results.
"""
def __init__(
@@ -618,21 +619,45 @@ class CacheListDescriptor(_CacheDescriptorBase):
)
return make_deferred_yieldable(d)
else:
- return results
+ return defer.succeed(results)
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
-class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
- # We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
- # which namedtuple does for us (i.e. two _CacheContext are the same if
- # their caches and keys match). This is important in particular to
- # dedupe when we add callbacks to lru cache nodes, otherwise the number
- # of callbacks would grow.
- def invalidate(self):
- self.cache.invalidate(self.key)
+class _CacheContext:
+ """Holds cache information from the cached function higher in the calling order.
+
+ Can be used to invalidate the higher level cache entry if something changes
+ on a lower level.
+ """
+
+ _cache_context_objects = (
+ WeakValueDictionary()
+ ) # type: WeakValueDictionary[Tuple[Cache, CacheKey], _CacheContext]
+
+ def __init__(self, cache, cache_key): # type: (Cache, CacheKey) -> None
+ self._cache = cache
+ self._cache_key = cache_key
+
+ def invalidate(self): # type: () -> None
+ """Invalidates the cache entry referred to by the context."""
+ self._cache.invalidate(self._cache_key)
+
+ @classmethod
+ def get_instance(cls, cache, cache_key): # type: (Cache, CacheKey) -> _CacheContext
+ """Returns an instance constructed with the given arguments.
+
+ A new instance is only created if none already exists.
+ """
+
+ # We make sure there are no identical _CacheContext instances. This is
+ # important in particular to dedupe when we add callbacks to lru cache
+ # nodes, otherwise the number of callbacks would grow.
+ return cls._cache_context_objects.setdefault(
+ (cache, cache_key), cls(cache, cache_key)
+ )
def cached(
diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py
index 1a20c596..3c0e8469 100644
--- a/synapse/util/httpresourcetree.py
+++ b/synapse/util/httpresourcetree.py
@@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
def create_resource_tree(desired_tree, root_resource):
- """Create the resource tree for this Home Server.
+ """Create the resource tree for this homeserver.
This in unduly complicated because Twisted does not support putting
child resources more than 1 level deep at a time.
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 4b1bcdf2..32868043 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -119,7 +119,7 @@ class Measure(object):
context = LoggingContext.current_context()
if context != self.start_context:
- logger.warn(
+ logger.warning(
"Context has unexpectedly changed from '%s' to '%s'. (%r)",
self.start_context,
context,
@@ -128,7 +128,7 @@ class Measure(object):
return
if not context:
- logger.warn("Expected context. (%r)", self.name)
+ logger.warning("Expected context. (%r)", self.name)
return
current = context.get_resource_usage()
@@ -140,7 +140,7 @@ class Measure(object):
block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec)
block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec)
except ValueError:
- logger.warn(
+ logger.warning(
"Failed to save metrics! OLD: %r, NEW: %r", self.start_usage, current
)
diff --git a/synapse/util/rlimit.py b/synapse/util/rlimit.py
index 6c0f2bb0..207cd17c 100644
--- a/synapse/util/rlimit.py
+++ b/synapse/util/rlimit.py
@@ -33,4 +33,4 @@ def change_resource_limit(soft_file_no):
resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY)
)
except (ValueError, resource.error) as e:
- logger.warn("Failed to set file or core limit: %s", e)
+ logger.warning("Failed to set file or core limit: %s", e)
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index fa404b9d..ab7d03af 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -42,6 +42,7 @@ def get_version_string(module):
try:
null = open(os.devnull, "w")
cwd = os.path.dirname(os.path.abspath(module.__file__))
+
try:
git_branch = (
subprocess.check_output(
@@ -51,7 +52,8 @@ def get_version_string(module):
.decode("ascii")
)
git_branch = "b=" + git_branch
- except subprocess.CalledProcessError:
+ except (subprocess.CalledProcessError, FileNotFoundError):
+ # FileNotFoundError can arise when git is not installed
git_branch = ""
try:
@@ -63,7 +65,7 @@ def get_version_string(module):
.decode("ascii")
)
git_tag = "t=" + git_tag
- except subprocess.CalledProcessError:
+ except (subprocess.CalledProcessError, FileNotFoundError):
git_tag = ""
try:
@@ -74,7 +76,7 @@ def get_version_string(module):
.strip()
.decode("ascii")
)
- except subprocess.CalledProcessError:
+ except (subprocess.CalledProcessError, FileNotFoundError):
git_commit = ""
try:
@@ -89,7 +91,7 @@ def get_version_string(module):
)
git_dirty = "dirty" if is_dirty else ""
- except subprocess.CalledProcessError:
+ except (subprocess.CalledProcessError, FileNotFoundError):
git_dirty = ""
if git_branch or git_tag or git_commit or git_dirty:
diff --git a/synapse/visibility.py b/synapse/visibility.py
index bf0f1eeb..8c843feb 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -23,6 +23,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events.utils import prune_event
+from synapse.storage import Storage
from synapse.storage.state import StateFilter
from synapse.types import get_domain_from_id
@@ -43,14 +44,13 @@ MEMBERSHIP_PRIORITY = (
@defer.inlineCallbacks
def filter_events_for_client(
- store, user_id, events, is_peeking=False, always_include_ids=frozenset()
+ storage: Storage, user_id, events, is_peeking=False, always_include_ids=frozenset()
):
"""
Check which events a user is allowed to see
Args:
- store (synapse.storage.DataStore): our datastore (can also be a worker
- store)
+ storage
user_id(str): user id to be checked
events(list[synapse.events.EventBase]): sequence of events to be checked
is_peeking(bool): should be True if:
@@ -68,12 +68,12 @@ def filter_events_for_client(
events = list(e for e in events if not e.internal_metadata.is_soft_failed())
types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
- event_id_to_state = yield store.get_state_for_events(
+ event_id_to_state = yield storage.state.get_state_for_events(
frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(types),
)
- ignore_dict_content = yield store.get_global_account_data_by_type_for_user(
+ ignore_dict_content = yield storage.main.get_global_account_data_by_type_for_user(
"m.ignored_user_list", user_id
)
@@ -84,7 +84,7 @@ def filter_events_for_client(
else []
)
- erased_senders = yield store.are_users_erased((e.sender for e in events))
+ erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
def allowed(event):
"""
@@ -213,13 +213,17 @@ def filter_events_for_client(
@defer.inlineCallbacks
def filter_events_for_server(
- store, server_name, events, redact=True, check_history_visibility_only=False
+ storage: Storage,
+ server_name,
+ events,
+ redact=True,
+ check_history_visibility_only=False,
):
"""Filter a list of events based on whether given server is allowed to
see them.
Args:
- store (DataStore)
+ storage
server_name (str)
events (iterable[FrozenEvent])
redact (bool): Whether to return a redacted version of the event, or
@@ -274,7 +278,7 @@ def filter_events_for_server(
# Lets check to see if all the events have a history visibility
# of "shared" or "world_readable". If thats the case then we don't
# need to check membership (as we know the server is in the room).
- event_to_state_ids = yield store.get_state_ids_for_events(
+ event_to_state_ids = yield storage.state.get_state_ids_for_events(
frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(
types=((EventTypes.RoomHistoryVisibility, ""),)
@@ -292,14 +296,14 @@ def filter_events_for_server(
if not visibility_ids:
all_open = True
else:
- event_map = yield store.get_events(visibility_ids)
+ event_map = yield storage.main.get_events(visibility_ids)
all_open = all(
e.content.get("history_visibility") in (None, "shared", "world_readable")
for e in itervalues(event_map)
)
if not check_history_visibility_only:
- erased_senders = yield store.are_users_erased((e.sender for e in events))
+ erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
else:
# We don't want to check whether users are erased, which is equivalent
# to no users having been erased.
@@ -328,7 +332,7 @@ def filter_events_for_server(
# first, for each event we're wanting to return, get the event_ids
# of the history vis and membership state at those events.
- event_to_state_ids = yield store.get_state_ids_for_events(
+ event_to_state_ids = yield storage.state.get_state_ids_for_events(
frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(
types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None))
@@ -358,7 +362,7 @@ def filter_events_for_server(
return False
return state_key[idx + 1 :] == server_name
- event_map = yield store.get_events(
+ event_map = yield storage.main.get_events(
[
e_id
for e_id, key in iteritems(event_id_to_state_key)
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 6ba623de..2dc50522 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -19,6 +19,7 @@ import jsonschema
from twisted.internet import defer
+from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.events import FrozenEvent
@@ -95,6 +96,8 @@ class FilteringTestCase(unittest.TestCase):
"types": ["m.room.message"],
"not_rooms": ["!726s6s6q:example.com"],
"not_senders": ["@spam:example.com"],
+ "org.matrix.labels": ["#fun"],
+ "org.matrix.not_labels": ["#work"],
},
"ephemeral": {
"types": ["m.receipt", "m.typing"],
@@ -320,6 +323,46 @@ class FilteringTestCase(unittest.TestCase):
)
self.assertFalse(Filter(definition).check(event))
+ def test_filter_labels(self):
+ definition = {"org.matrix.labels": ["#fun"]}
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={EventContentFields.LABELS: ["#fun"]},
+ )
+
+ self.assertTrue(Filter(definition).check(event))
+
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={EventContentFields.LABELS: ["#notfun"]},
+ )
+
+ self.assertFalse(Filter(definition).check(event))
+
+ def test_filter_not_labels(self):
+ definition = {"org.matrix.not_labels": ["#fun"]}
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={EventContentFields.LABELS: ["#fun"]},
+ )
+
+ self.assertFalse(Filter(definition).check(event))
+
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={EventContentFields.LABELS: ["#notfun"]},
+ )
+
+ self.assertTrue(Filter(definition).check(event))
+
@defer.inlineCallbacks
def test_filter_presence_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}}
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index c4f0bbd3..8efd39c7 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -178,7 +178,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
kr = keyring.Keyring(self.hs)
key1 = signedjson.key.generate_signing_key(1)
- r = self.hs.datastore.store_server_verify_keys(
+ r = self.hs.get_datastore().store_server_verify_keys(
"server9",
time.time() * 1000,
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
@@ -209,7 +209,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
)
key1 = signedjson.key.generate_signing_key(1)
- r = self.hs.datastore.store_server_verify_keys(
+ r = self.hs.get_datastore().store_server_verify_keys(
"server9",
time.time() * 1000,
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))],
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index d56220f4..b4d92cf7 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -12,13 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError, Codes
+from synapse.federation.federation_base import event_from_pdu_json
+from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from tests import unittest
+logger = logging.getLogger(__name__)
+
class FederationTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -79,3 +85,123 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(failure.code, 403, failure)
self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
self.assertEqual(failure.msg, "You are not invited to this room.")
+
+ def test_rejected_message_event_state(self):
+ """
+ Check that we store the state group correctly for rejected non-state events.
+
+ Regression test for #6289.
+ """
+ OTHER_SERVER = "otherserver"
+ OTHER_USER = "@otheruser:" + OTHER_SERVER
+
+ # create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+ # pretend that another server has joined
+ join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id)
+
+ # check the state group
+ sg = self.successResultOf(
+ self.store._get_state_group_for_event(join_event.event_id)
+ )
+
+ # build and send an event which will be rejected
+ ev = event_from_pdu_json(
+ {
+ "type": EventTypes.Message,
+ "content": {},
+ "room_id": room_id,
+ "sender": "@yetanotheruser:" + OTHER_SERVER,
+ "depth": join_event["depth"] + 1,
+ "prev_events": [join_event.event_id],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ join_event.format_version,
+ )
+
+ with LoggingContext(request="send_rejected"):
+ d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
+ self.get_success(d)
+
+ # that should have been rejected
+ e = self.get_success(self.store.get_event(ev.event_id, allow_rejected=True))
+ self.assertIsNotNone(e.rejected_reason)
+
+ # ... and the state group should be the same as before
+ sg2 = self.successResultOf(self.store._get_state_group_for_event(ev.event_id))
+
+ self.assertEqual(sg, sg2)
+
+ def test_rejected_state_event_state(self):
+ """
+ Check that we store the state group correctly for rejected state events.
+
+ Regression test for #6289.
+ """
+ OTHER_SERVER = "otherserver"
+ OTHER_USER = "@otheruser:" + OTHER_SERVER
+
+ # create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+ # pretend that another server has joined
+ join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id)
+
+ # check the state group
+ sg = self.successResultOf(
+ self.store._get_state_group_for_event(join_event.event_id)
+ )
+
+ # build and send an event which will be rejected
+ ev = event_from_pdu_json(
+ {
+ "type": "org.matrix.test",
+ "state_key": "test_key",
+ "content": {},
+ "room_id": room_id,
+ "sender": "@yetanotheruser:" + OTHER_SERVER,
+ "depth": join_event["depth"] + 1,
+ "prev_events": [join_event.event_id],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ join_event.format_version,
+ )
+
+ with LoggingContext(request="send_rejected"):
+ d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
+ self.get_success(d)
+
+ # that should have been rejected
+ e = self.get_success(self.store.get_event(ev.event_id, allow_rejected=True))
+ self.assertIsNotNone(e.rejected_reason)
+
+ # ... and the state group should be the same as before
+ sg2 = self.successResultOf(self.store._get_state_group_for_event(ev.event_id))
+
+ self.assertEqual(sg, sg2)
+
+ def _build_and_send_join_event(self, other_server, other_user, room_id):
+ join_event = self.get_success(
+ self.handler.on_make_join_request(other_server, room_id, other_user)
+ )
+ # the auth code requires that a signature exists, but doesn't check that
+ # signature... go figure.
+ join_event.signatures[other_server] = {"x": "y"}
+ with LoggingContext(request="send_join"):
+ d = run_in_background(
+ self.handler.on_send_join_request, other_server, join_event
+ )
+ self.get_success(d)
+
+ # sanity-check: the room should show that the new user is a member
+ r = self.get_success(self.store.get_current_state_ids(room_id))
+ self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id)
+
+ return join_event
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 67f10130..5ec568f4 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -73,7 +73,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"get_received_txn_response",
"set_received_txn_response",
"get_destination_retry_timings",
- "get_devices_by_remote",
+ "get_device_updates_by_remote",
# Bits that user_directory needs
"get_user_directory_stream_pos",
"get_current_state_deltas",
@@ -109,7 +109,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
retry_timings_res
)
- self.datastore.get_devices_by_remote.return_value = (0, [])
+ self.datastore.get_device_updates_by_remote.return_value = (0, [])
def get_received_txn_response(*args):
return defer.succeed(None)
@@ -144,6 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_to_device_stream_token = lambda: 0
self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0)
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
+ self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
+ None
+ )
def test_started_typing_local(self):
self.room_members = [U_APPLE, U_BANANA]
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
index 2d5dba64..2096ba3c 100644
--- a/tests/http/__init__.py
+++ b/tests/http/__init__.py
@@ -20,6 +20,23 @@ from zope.interface import implementer
from OpenSSL import SSL
from OpenSSL.SSL import Connection
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
+from twisted.internet.ssl import Certificate, trustRootFromCertificates
+from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401
+from twisted.web.iweb import IPolicyForHTTPS # noqa: F401
+
+
+def get_test_https_policy():
+ """Get a test IPolicyForHTTPS which trusts the test CA cert
+
+ Returns:
+ IPolicyForHTTPS
+ """
+ ca_file = get_test_ca_cert_file()
+ with open(ca_file) as stream:
+ content = stream.read()
+ cert = Certificate.loadPEM(content)
+ trust_root = trustRootFromCertificates([cert])
+ return BrowserLikePolicyForHTTPS(trustRoot=trust_root)
def get_test_ca_cert_file():
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 71d70252..cfcd98ff 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -124,19 +124,24 @@ class MatrixFederationAgentTests(unittest.TestCase):
FakeTransport(client_protocol, self.reactor, server_tls_protocol)
)
+ # grab a hold of the TLS connection, in case it gets torn down
+ server_tls_connection = server_tls_protocol._tlsConnection
+
+ # fish the test server back out of the server-side TLS protocol.
+ http_protocol = server_tls_protocol.wrappedProtocol
+
# give the reactor a pump to get the TLS juices flowing.
self.reactor.pump((0.1,))
# check the SNI
- server_name = server_tls_protocol._tlsConnection.get_servername()
+ server_name = server_tls_connection.get_servername()
self.assertEqual(
server_name,
expected_sni,
"Expected SNI %s but got %s" % (expected_sni, server_name),
)
- # fish the test server back out of the server-side TLS protocol.
- return server_tls_protocol.wrappedProtocol
+ return http_protocol
@defer.inlineCallbacks
def _make_get_request(self, uri):
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
new file mode 100644
index 00000000..22abf765
--- /dev/null
+++ b/tests/http/test_proxyagent.py
@@ -0,0 +1,334 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import treq
+
+from twisted.internet import interfaces # noqa: F401
+from twisted.internet.protocol import Factory
+from twisted.protocols.tls import TLSMemoryBIOFactory
+from twisted.web.http import HTTPChannel
+
+from synapse.http.proxyagent import ProxyAgent
+
+from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
+from tests.server import FakeTransport, ThreadedMemoryReactorClock
+from tests.unittest import TestCase
+
+logger = logging.getLogger(__name__)
+
+HTTPFactory = Factory.forProtocol(HTTPChannel)
+
+
+class MatrixFederationAgentTests(TestCase):
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+
+ def _make_connection(
+ self, client_factory, server_factory, ssl=False, expected_sni=None
+ ):
+ """Builds a test server, and completes the outgoing client connection
+
+ Args:
+ client_factory (interfaces.IProtocolFactory): the the factory that the
+ application is trying to use to make the outbound connection. We will
+ invoke it to build the client Protocol
+
+ server_factory (interfaces.IProtocolFactory): a factory to build the
+ server-side protocol
+
+ ssl (bool): If true, we will expect an ssl connection and wrap
+ server_factory with a TLSMemoryBIOFactory
+
+ expected_sni (bytes|None): the expected SNI value
+
+ Returns:
+ IProtocol: the server Protocol returned by server_factory
+ """
+ if ssl:
+ server_factory = _wrap_server_factory_for_tls(server_factory)
+
+ server_protocol = server_factory.buildProtocol(None)
+
+ # now, tell the client protocol factory to build the client protocol,
+ # and wire the output of said protocol up to the server via
+ # a FakeTransport.
+ #
+ # Normally this would be done by the TCP socket code in Twisted, but we are
+ # stubbing that out here.
+ client_protocol = client_factory.buildProtocol(None)
+ client_protocol.makeConnection(
+ FakeTransport(server_protocol, self.reactor, client_protocol)
+ )
+
+ # tell the server protocol to send its stuff back to the client, too
+ server_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, server_protocol)
+ )
+
+ if ssl:
+ http_protocol = server_protocol.wrappedProtocol
+ tls_connection = server_protocol._tlsConnection
+ else:
+ http_protocol = server_protocol
+ tls_connection = None
+
+ # give the reactor a pump to get the TLS juices flowing (if needed)
+ self.reactor.advance(0)
+
+ if expected_sni is not None:
+ server_name = tls_connection.get_servername()
+ self.assertEqual(
+ server_name,
+ expected_sni,
+ "Expected SNI %s but got %s" % (expected_sni, server_name),
+ )
+
+ return http_protocol
+
+ def test_http_request(self):
+ agent = ProxyAgent(self.reactor)
+
+ self.reactor.lookups["test.com"] = "1.2.3.4"
+ d = agent.request(b"GET", b"http://test.com")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 80)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+ def test_https_request(self):
+ agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
+
+ self.reactor.lookups["test.com"] = "1.2.3.4"
+ d = agent.request(b"GET", b"https://test.com/abc")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 443)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ _get_test_protocol_factory(),
+ ssl=True,
+ expected_sni=b"test.com",
+ )
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/abc")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+ def test_http_request_via_proxy(self):
+ agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888")
+
+ self.reactor.lookups["proxy.com"] = "1.2.3.5"
+ d = agent.request(b"GET", b"http://test.com")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.5")
+ self.assertEqual(port, 8888)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"http://test.com")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+ def test_https_request_via_proxy(self):
+ agent = ProxyAgent(
+ self.reactor,
+ contextFactory=get_test_https_policy(),
+ https_proxy=b"proxy.com",
+ )
+
+ self.reactor.lookups["proxy.com"] = "1.2.3.5"
+ d = agent.request(b"GET", b"https://test.com/abc")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.5")
+ self.assertEqual(port, 1080)
+
+ # make a test HTTP server, and wire up the client
+ proxy_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # fish the transports back out so that we can do the old switcheroo
+ s2c_transport = proxy_server.transport
+ client_protocol = s2c_transport.other
+ c2s_transport = client_protocol.transport
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending CONNECT request
+ self.assertEqual(len(proxy_server.requests), 1)
+
+ request = proxy_server.requests[0]
+ self.assertEqual(request.method, b"CONNECT")
+ self.assertEqual(request.path, b"test.com:443")
+
+ # tell the proxy server not to close the connection
+ proxy_server.persistent = True
+
+ # this just stops the http Request trying to do a chunked response
+ # request.setHeader(b"Content-Length", b"0")
+ request.finish()
+
+ # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
+ ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
+ ssl_protocol = ssl_factory.buildProtocol(None)
+ http_server = ssl_protocol.wrappedProtocol
+
+ ssl_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, ssl_protocol)
+ )
+ c2s_transport.other = ssl_protocol
+
+ self.reactor.advance(0)
+
+ server_name = ssl_protocol._tlsConnection.get_servername()
+ expected_sni = b"test.com"
+ self.assertEqual(
+ server_name,
+ expected_sni,
+ "Expected SNI %s but got %s" % (expected_sni, server_name),
+ )
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/abc")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+
+def _wrap_server_factory_for_tls(factory, sanlist=None):
+ """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
+
+ The resultant factory will create a TLS server which presents a certificate
+ signed by our test CA, valid for the domains in `sanlist`
+
+ Args:
+ factory (interfaces.IProtocolFactory): protocol factory to wrap
+ sanlist (iterable[bytes]): list of domains the cert should be valid for
+
+ Returns:
+ interfaces.IProtocolFactory
+ """
+ if sanlist is None:
+ sanlist = [b"DNS:test.com"]
+
+ connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
+ return TLSMemoryBIOFactory(
+ connection_creator, isClient=False, wrappedFactory=factory
+ )
+
+
+def _get_test_protocol_factory():
+ """Get a protocol Factory which will build an HTTPChannel
+
+ Returns:
+ interfaces.IProtocolFactory
+ """
+ server_factory = Factory.forProtocol(HTTPChannel)
+
+ # Request.finish expects the factory to have a 'log' method.
+ server_factory.log = _log_request
+
+ return server_factory
+
+
+def _log_request(request):
+ """Implements Factory.log, which is expected by Request.finish"""
+ logger.info("Completed request %s", request)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 8ce6bb62..af2327fb 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -50,7 +50,7 @@ class HTTPPusherTests(HomeserverTestCase):
config = self.default_config()
config["start_pushers"] = True
- hs = self.setup_test_homeserver(config=config, simple_http_client=m)
+ hs = self.setup_test_homeserver(config=config, proxied_http_client=m)
return hs
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 104349cd..4f924ce4 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -41,6 +41,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.master_store = self.hs.get_datastore()
+ self.storage = hs.get_storage()
self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
self.event_id = 0
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index a368117b..b68e9fe0 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -234,7 +234,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
)
msg, msgctx = self.build_event()
- self.get_success(self.master_store.persist_events([(j2, j2ctx), (msg, msgctx)]))
+ self.get_success(
+ self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)])
+ )
self.replicate()
event_source = RoomEventSource(self.hs)
@@ -290,10 +292,12 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
if backfill:
self.get_success(
- self.master_store.persist_events([(event, context)], backfilled=True)
+ self.storage.persistence.persist_events(
+ [(event, context)], backfilled=True
+ )
)
else:
- self.get_success(self.master_store.persist_event(event, context))
+ self.get_success(self.storage.persistence.persist_event(event, context))
return event
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index d3a4f717..95750582 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -561,3 +561,85 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
return channel.json_body["groups"]
+
+
+class PurgeRoomTestCase(unittest.HomeserverTestCase):
+ """Test /purge_room admin API.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ def test_purge_room(self):
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # All users have to have left the room.
+ self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
+
+ url = "/_synapse/admin/v1/purge_room"
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Test that the following tables have been purged of all rows related to the room.
+ for table in (
+ "current_state_events",
+ "event_backward_extremities",
+ "event_forward_extremities",
+ "event_json",
+ "event_push_actions",
+ "event_search",
+ "events",
+ "group_rooms",
+ "public_room_list_stream",
+ "receipts_graph",
+ "receipts_linearized",
+ "room_aliases",
+ "room_depth",
+ "room_memberships",
+ "room_stats_state",
+ "room_stats_current",
+ "room_stats_historical",
+ "room_stats_earliest_token",
+ "rooms",
+ "stream_ordering_to_exterm",
+ "users_in_public_rooms",
+ "users_who_share_private_rooms",
+ "appservice_room_list",
+ "e2e_room_keys",
+ "event_push_summary",
+ "pusher_throttle",
+ "group_summary_rooms",
+ "local_invites",
+ "room_account_data",
+ "room_tags",
+ "state_groups",
+ "state_groups_state",
+ ):
+ count = self.get_success(
+ self.store._simple_select_one_onecol(
+ table=table,
+ keyvalues={"room_id": room_id},
+ retcol="COUNT(*)",
+ desc="test_purge_room",
+ )
+ )
+
+ self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+
+ test_purge_room.skip = "Disabled because it's currently broken"
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 2f2ca746..5e38fd6c 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -24,7 +24,7 @@ from six.moves.urllib import parse as urlparse
from twisted.internet import defer
import synapse.rest.admin
-from synapse.api.constants import Membership
+from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.rest.client.v1 import login, profile, room
from tests import unittest
@@ -811,6 +811,105 @@ class RoomMessageListTestCase(RoomBase):
self.assertTrue("chunk" in channel.json_body)
self.assertTrue("end" in channel.json_body)
+ def test_filter_labels(self):
+ """Test that we can filter by a label."""
+ message_filter = json.dumps(
+ {"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]}
+ )
+
+ events = self._test_filter_labels(message_filter)
+
+ self.assertEqual(len(events), 2, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
+ self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
+
+ def test_filter_not_labels(self):
+ """Test that we can filter by the absence of a label."""
+ message_filter = json.dumps(
+ {"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]}
+ )
+
+ events = self._test_filter_labels(message_filter)
+
+ self.assertEqual(len(events), 3, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "without label", events[0])
+ self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1])
+ self.assertEqual(
+ events[2]["content"]["body"], "with two wrong labels", events[2]
+ )
+
+ def test_filter_labels_not_labels(self):
+ """Test that we can filter by both a label and the absence of another label."""
+ sync_filter = json.dumps(
+ {
+ "types": [EventTypes.Message],
+ "org.matrix.labels": ["#work"],
+ "org.matrix.not_labels": ["#notfun"],
+ }
+ )
+
+ events = self._test_filter_labels(sync_filter)
+
+ self.assertEqual(len(events), 1, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
+
+ def _test_filter_labels(self, message_filter):
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with right label",
+ EventContentFields.LABELS: ["#fun"],
+ },
+ )
+
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "without label"},
+ )
+
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with wrong label",
+ EventContentFields.LABELS: ["#work"],
+ },
+ )
+
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with two wrong labels",
+ EventContentFields.LABELS: ["#work", "#notfun"],
+ },
+ )
+
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with right label",
+ EventContentFields.LABELS: ["#fun"],
+ },
+ )
+
+ token = "s0_0_0_0_0_0_0_0_0"
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/messages?access_token=x&from=%s&filter=%s"
+ % (self.room_id, token, message_filter),
+ )
+ self.render(request)
+
+ return channel.json_body["chunk"]
+
class RoomSearchTestCase(unittest.HomeserverTestCase):
servlets = [
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index cdded88b..8ea0cb05 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -106,13 +106,22 @@ class RestHelper(object):
self.auth_user_id = temp_id
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
- if txn_id is None:
- txn_id = "m%s" % (str(time.time()))
if body is None:
body = "body_text_here"
- path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
content = {"msgtype": "m.text", "body": body}
+
+ return self.send_event(
+ room_id, "m.room.message", content, txn_id, tok, expect_code
+ )
+
+ def send_event(
+ self, room_id, type, content={}, txn_id=None, tok=None, expect_code=200
+ ):
+ if txn_id is None:
+ txn_id = "m%s" % (str(time.time()))
+
+ path = "/_matrix/client/r0/rooms/%s/send/%s/%s" % (room_id, type, txn_id)
if tok:
path = path + "?access_token=%s" % tok
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 71895094..3283c0e4 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -12,10 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
from mock import Mock
import synapse.rest.admin
+from synapse.api.constants import EventContentFields, EventTypes
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import sync
@@ -26,7 +28,12 @@ from tests.server import TimedOutException
class FilterTestCase(unittest.HomeserverTestCase):
user_id = "@apple:test"
- servlets = [sync.register_servlets]
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
def make_homeserver(self, reactor, clock):
@@ -70,6 +77,140 @@ class FilterTestCase(unittest.HomeserverTestCase):
)
+class SyncFilterTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def test_sync_filter_labels(self):
+ """Test that we can filter by a label."""
+ sync_filter = json.dumps(
+ {
+ "room": {
+ "timeline": {
+ "types": [EventTypes.Message],
+ "org.matrix.labels": ["#fun"],
+ }
+ }
+ }
+ )
+
+ events = self._test_sync_filter_labels(sync_filter)
+
+ self.assertEqual(len(events), 2, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
+ self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
+
+ def test_sync_filter_not_labels(self):
+ """Test that we can filter by the absence of a label."""
+ sync_filter = json.dumps(
+ {
+ "room": {
+ "timeline": {
+ "types": [EventTypes.Message],
+ "org.matrix.not_labels": ["#fun"],
+ }
+ }
+ }
+ )
+
+ events = self._test_sync_filter_labels(sync_filter)
+
+ self.assertEqual(len(events), 3, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "without label", events[0])
+ self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1])
+ self.assertEqual(
+ events[2]["content"]["body"], "with two wrong labels", events[2]
+ )
+
+ def test_sync_filter_labels_not_labels(self):
+ """Test that we can filter by both a label and the absence of another label."""
+ sync_filter = json.dumps(
+ {
+ "room": {
+ "timeline": {
+ "types": [EventTypes.Message],
+ "org.matrix.labels": ["#work"],
+ "org.matrix.not_labels": ["#notfun"],
+ }
+ }
+ }
+ )
+
+ events = self._test_sync_filter_labels(sync_filter)
+
+ self.assertEqual(len(events), 1, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
+
+ def _test_sync_filter_labels(self, sync_filter):
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+
+ room_id = self.helper.create_room_as(user_id, tok=tok)
+
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with right label",
+ EventContentFields.LABELS: ["#fun"],
+ },
+ tok=tok,
+ )
+
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "without label"},
+ tok=tok,
+ )
+
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with wrong label",
+ EventContentFields.LABELS: ["#work"],
+ },
+ tok=tok,
+ )
+
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with two wrong labels",
+ EventContentFields.LABELS: ["#work", "#notfun"],
+ },
+ tok=tok,
+ )
+
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with right label",
+ EventContentFields.LABELS: ["#fun"],
+ },
+ tok=tok,
+ )
+
+ request, channel = self.make_request(
+ "GET", "/sync?filter=%s" % sync_filter, access_token=tok
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ return channel.json_body["rooms"]["join"][room_id]["timeline"]["events"]
+
+
class SyncTypingTests(unittest.HomeserverTestCase):
servlets = [
diff --git a/tests/server.py b/tests/server.py
index e397ebe8..f878aeaa 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -161,7 +161,11 @@ def make_request(
path = path.encode("ascii")
# Decorate it to be the full path, if we're using shorthand
- if shorthand and not path.startswith(b"/_matrix"):
+ if (
+ shorthand
+ and not path.startswith(b"/_matrix")
+ and not path.startswith(b"/_synapse")
+ ):
path = b"/_matrix/client/r0/" + path
path = path.replace(b"//", b"/")
@@ -391,11 +395,24 @@ class FakeTransport(object):
self.disconnecting = True
if self._protocol:
self._protocol.connectionLost(reason)
- self.disconnected = True
+
+ # if we still have data to write, delay until that is done
+ if self.buffer:
+ logger.info(
+ "FakeTransport: Delaying disconnect until buffer is flushed"
+ )
+ else:
+ self.disconnected = True
def abortConnection(self):
logger.info("FakeTransport: abortConnection()")
- self.loseConnection()
+
+ if not self.disconnecting:
+ self.disconnecting = True
+ if self._protocol:
+ self._protocol.connectionLost(None)
+
+ self.disconnected = True
def pauseProducing(self):
if not self.producer:
@@ -426,6 +443,9 @@ class FakeTransport(object):
self._reactor.callLater(0.0, _produce)
def write(self, byt):
+ if self.disconnecting:
+ raise Exception("Writing to disconnecting FakeTransport")
+
self.buffer = self.buffer + byt
# always actually do the write asynchronously. Some protocols (notably the
@@ -470,6 +490,10 @@ class FakeTransport(object):
if self.buffer and self.autoflush:
self._reactor.callLater(0.0, self.flush)
+ if not self.buffer and self.disconnecting:
+ logger.info("FakeTransport: Buffer now empty, completing disconnect")
+ self.disconnected = True
+
def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol:
"""
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index dd49a145..9b81b536 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -197,7 +197,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
a.func.prefill(("foo",), ObservableDeferred(d))
- self.assertEquals(a.func("foo"), d.result)
+ self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0)
@defer.inlineCallbacks
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 3cc18f9f..6f8d9909 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -72,7 +72,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
@defer.inlineCallbacks
- def test_get_devices_by_remote(self):
+ def test_get_device_updates_by_remote(self):
device_ids = ["device_id1", "device_id2"]
# Add two device updates with a single stream_id
@@ -81,7 +81,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
# Get all device updates ever meant for this remote
- now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+ now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"somehost", -1, limit=100
)
@@ -89,7 +89,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self._check_devices_in_updates(device_ids, device_updates)
@defer.inlineCallbacks
- def test_get_devices_by_remote_limited(self):
+ def test_get_device_updates_by_remote_limited(self):
# Test breaking the update limit in 1, 101, and 1 device_id segments
# first add one device
@@ -115,20 +115,20 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
#
# first we should get a single update
- now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+ now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"someotherhost", -1, limit=100
)
self._check_devices_in_updates(device_ids1, device_updates)
# Then we should get an empty list back as the 101 devices broke the limit
- now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+ now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"someotherhost", now_stream_id, limit=100
)
self.assertEqual(len(device_updates), 0)
# The 101 devices should've been cleared, so we should now just get one device
# update
- now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+ now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"someotherhost", now_stream_id, limit=100
)
self._check_devices_in_updates(device_ids3, device_updates)
@@ -137,7 +137,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
"""Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids))
- received_device_ids = {update["device_id"] for update in device_updates}
+ received_device_ids = {
+ update["device_id"] for edu_type, update in device_updates
+ }
self.assertEqual(received_device_ids, set(expected_device_ids))
@defer.inlineCallbacks
diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
new file mode 100644
index 00000000..d128fde4
--- /dev/null
+++ b/tests/storage/test_e2e_room_keys.py
@@ -0,0 +1,75 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tests import unittest
+
+# sample room_key data for use in the tests
+room_key = {
+ "first_message_index": 1,
+ "forwarded_count": 1,
+ "is_verified": False,
+ "session_data": "SSBBTSBBIEZJU0gK",
+}
+
+
+class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver("server", http_client=None)
+ self.store = hs.get_datastore()
+ return hs
+
+ def test_room_keys_version_delete(self):
+ # test that deleting a room key backup deletes the keys
+ version1 = self.get_success(
+ self.store.create_e2e_room_keys_version(
+ "user_id", {"algorithm": "rot13", "auth_data": {}}
+ )
+ )
+
+ self.get_success(
+ self.store.set_e2e_room_key(
+ "user_id", version1, "room", "session", room_key
+ )
+ )
+
+ version2 = self.get_success(
+ self.store.create_e2e_room_keys_version(
+ "user_id", {"algorithm": "rot13", "auth_data": {}}
+ )
+ )
+
+ self.get_success(
+ self.store.set_e2e_room_key(
+ "user_id", version2, "room", "session", room_key
+ )
+ )
+
+ # make sure the keys were stored properly
+ keys = self.get_success(self.store.get_e2e_room_keys("user_id", version1))
+ self.assertEqual(len(keys["rooms"]), 1)
+
+ keys = self.get_success(self.store.get_e2e_room_keys("user_id", version2))
+ self.assertEqual(len(keys["rooms"]), 1)
+
+ # delete version1
+ self.get_success(self.store.delete_e2e_room_keys_version("user_id", version1))
+
+ # make sure the key from version1 is gone, and the key from version2 is
+ # still there
+ keys = self.get_success(self.store.get_e2e_room_keys("user_id", version1))
+ self.assertEqual(len(keys["rooms"]), 0)
+
+ keys = self.get_success(self.store.get_e2e_room_keys("user_id", version2))
+ self.assertEqual(len(keys["rooms"]), 1)
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index f671599c..b9fafaa1 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -40,23 +40,24 @@ class PurgeTests(HomeserverTestCase):
third = self.helper.send(self.room_id, body="test3")
last = self.helper.send(self.room_id, body="test4")
- storage = self.hs.get_datastore()
+ store = self.hs.get_datastore()
+ storage = self.hs.get_storage()
# Get the topological token
- event = storage.get_topological_token_for_event(last["event_id"])
+ event = store.get_topological_token_for_event(last["event_id"])
self.pump()
event = self.successResultOf(event)
# Purge everything before this topological token
- purge = storage.purge_history(self.room_id, event, True)
+ purge = storage.purge_events.purge_history(self.room_id, event, True)
self.pump()
self.assertEqual(self.successResultOf(purge), None)
# Try and get the events
- get_first = storage.get_event(first["event_id"])
- get_second = storage.get_event(second["event_id"])
- get_third = storage.get_event(third["event_id"])
- get_last = storage.get_event(last["event_id"])
+ get_first = store.get_event(first["event_id"])
+ get_second = store.get_event(second["event_id"])
+ get_third = store.get_event(third["event_id"])
+ get_last = store.get_event(last["event_id"])
self.pump()
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 427d3c49..4561c3e3 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -39,6 +39,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -73,7 +74,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.store.persist_event(event, context))
+ self.get_success(self.storage.persistence.persist_event(event, context))
return event
@@ -95,7 +96,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.store.persist_event(event, context))
+ self.get_success(self.storage.persistence.persist_event(event, context))
return event
@@ -116,7 +117,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.store.persist_event(event, context))
+ self.get_success(self.storage.persistence.persist_event(event, context))
return event
@@ -263,7 +264,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
- self.get_success(self.store.persist_event(event_1, context_1))
+ self.get_success(self.storage.persistence.persist_event(event_1, context_1))
event_2, context_2 = self.get_success(
self.event_creation_handler.create_new_client_event(
@@ -282,7 +283,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
)
- self.get_success(self.store.persist_event(event_2, context_2))
+ self.get_success(self.storage.persistence.persist_event(event_2, context_2))
# fetch one of the redactions
fetched = self.get_success(self.store.get_event(redaction_event_id1))
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 1bee4570..3ddaa151 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -62,6 +62,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
# Room events need the full datastore, for persist_event() and
# get_room_state()
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
self.event_factory = hs.get_event_factory()
self.room = RoomID.from_string("!abcde:test")
@@ -72,7 +73,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def inject_room_event(self, **kwargs):
- yield self.store.persist_event(
+ yield self.storage.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 447a3c6f..9ddd17f7 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -44,6 +44,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# We can't test the RoomMemberStore on its own without the other event
# storage logic
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -70,7 +71,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.store.persist_event(event, context))
+ self.get_success(self.storage.persistence.persist_event(event, context))
return event
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 5c2cf3c2..43200654 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -34,6 +34,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
+ self.state_datastore = self.store
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -63,7 +65,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
builder
)
- yield self.store.persist_event(event, context)
+ yield self.storage.persistence.persist_event(event, context)
return event
@@ -82,7 +84,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield self.store.get_state_groups_ids(
+ state_group_map = yield self.storage.state.get_state_groups_ids(
self.room, [e2.event_id]
)
self.assertEqual(len(state_group_map), 1)
@@ -101,7 +103,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id])
+ state_group_map = yield self.storage.state.get_state_groups(
+ self.room, [e2.event_id]
+ )
self.assertEqual(len(state_group_map), 1)
state_list = list(state_group_map.values())[0]
@@ -141,7 +145,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we get the full state as of the final event
- state = yield self.store.get_state_for_event(e5.event_id)
+ state = yield self.storage.state.get_state_for_event(e5.event_id)
self.assertIsNotNone(e4)
@@ -157,21 +161,21 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we can filter to the m.room.name event (with a '' state key)
- state = yield self.store.get_state_for_event(
+ state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key)
- state = yield self.store.get_state_for_event(
+ state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key)
- state = yield self.store.get_state_for_event(
+ state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
)
@@ -181,7 +185,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can grab a specific room member without filtering out the
# other event types
- state = yield self.store.get_state_for_event(
+ state = yield self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: {self.u_alice.to_string()}},
@@ -199,7 +203,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check that we can grab everything except members
- state = yield self.store.get_state_for_event(
+ state = yield self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
@@ -215,13 +219,18 @@ class StateStoreTestCase(tests.unittest.TestCase):
#######################################################
room_id = self.room.to_string()
- group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
+ group_ids = yield self.storage.state.get_state_groups_ids(
+ room_id, [e5.event_id]
+ )
group = list(group_ids.keys())[0]
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
@@ -237,8 +246,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
@@ -250,8 +262,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with wildcard types
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
@@ -267,8 +282,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
@@ -287,8 +305,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -304,8 +325,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -317,8 +341,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False
@@ -331,9 +358,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
#######################################################
# deliberately remove e2 (room name) from the _state_group_cache
- (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
- group
- )
+ (
+ is_all,
+ known_absent,
+ state_dict_ids,
+ ) = self.state_datastore._state_group_cache.get(group)
self.assertEqual(is_all, True)
self.assertEqual(known_absent, set())
@@ -346,18 +375,20 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
state_dict_ids.pop((e2.type, e2.state_key))
- self.store._state_group_cache.invalidate(group)
- self.store._state_group_cache.update(
- sequence=self.store._state_group_cache.sequence,
+ self.state_datastore._state_group_cache.invalidate(group)
+ self.state_datastore._state_group_cache.update(
+ sequence=self.state_datastore._state_group_cache.sequence,
key=group,
value=state_dict_ids,
# list fetched keys so it knows it's partial
fetched_keys=((e1.type, e1.state_key),),
)
- (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
- group
- )
+ (
+ is_all,
+ known_absent,
+ state_dict_ids,
+ ) = self.state_datastore._state_group_cache.get(group)
self.assertEqual(is_all, False)
self.assertEqual(known_absent, set([(e1.type, e1.state_key)]))
@@ -369,8 +400,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
room_id = self.room.to_string()
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
@@ -381,8 +415,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
room_id = self.room.to_string()
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
@@ -394,8 +431,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# wildcard types
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
@@ -405,8 +445,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
@@ -424,8 +467,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -435,8 +481,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -448,8 +497,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False
@@ -459,8 +511,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({}, state_dict)
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False
diff --git a/tests/test_federation.py b/tests/test_federation.py
index a73f18f8..7d82b584 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -36,7 +36,8 @@ class MessageAcceptTests(unittest.TestCase):
# Figure out what the most recent event is
most_recent = self.successResultOf(
maybeDeferred(
- self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ self.homeserver.get_datastore().get_latest_event_ids_in_room,
+ self.room_id,
)
)[0]
@@ -58,7 +59,9 @@ class MessageAcceptTests(unittest.TestCase):
)
self.handler = self.homeserver.get_handlers().federation_handler
- self.handler.do_auth = lambda *a, **b: succeed(True)
+ self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
+ context
+ )
self.client = self.homeserver.get_federation_client()
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
pdus
@@ -75,7 +78,8 @@ class MessageAcceptTests(unittest.TestCase):
self.assertEqual(
self.successResultOf(
maybeDeferred(
- self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ self.homeserver.get_datastore().get_latest_event_ids_in_room,
+ self.room_id,
)
)[0],
"$join:test.serv",
@@ -97,7 +101,8 @@ class MessageAcceptTests(unittest.TestCase):
# Figure out what the most recent event is
most_recent = self.successResultOf(
maybeDeferred(
- self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ self.homeserver.get_datastore().get_latest_event_ids_in_room,
+ self.room_id,
)
)[0]
@@ -137,6 +142,6 @@ class MessageAcceptTests(unittest.TestCase):
# Make sure the invalid event isn't there
extrem = maybeDeferred(
- self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ self.homeserver.get_datastore().get_latest_event_ids_in_room, self.room_id
)
self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py
new file mode 100644
index 00000000..7657bdde
--- /dev/null
+++ b/tests/test_phone_home.py
@@ -0,0 +1,51 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import resource
+
+import mock
+
+from synapse.app.homeserver import phone_stats_home
+
+from tests.unittest import HomeserverTestCase
+
+
+class PhoneHomeStatsTestCase(HomeserverTestCase):
+ def test_performance_frozen_clock(self):
+ """
+ If time doesn't move, don't error out.
+ """
+ past_stats = [
+ (self.hs.get_clock().time(), resource.getrusage(resource.RUSAGE_SELF))
+ ]
+ stats = {}
+ self.get_success(phone_stats_home(self.hs, stats, past_stats))
+ self.assertEqual(stats["cpu_average"], 0)
+
+ def test_performance_100(self):
+ """
+ 1 second of usage over 1 second is 100% CPU usage.
+ """
+ real_res = resource.getrusage(resource.RUSAGE_SELF)
+ old_resource = mock.Mock(spec=real_res)
+ old_resource.ru_utime = real_res.ru_utime - 1
+ old_resource.ru_stime = real_res.ru_stime
+ old_resource.ru_maxrss = real_res.ru_maxrss
+
+ past_stats = [(self.hs.get_clock().time(), old_resource)]
+ stats = {}
+ self.reactor.advance(1)
+ self.get_success(phone_stats_home(self.hs, stats, past_stats))
+ self.assertApproximates(stats["cpu_average"], 100, tolerance=2.5)
diff --git a/tests/test_state.py b/tests/test_state.py
index 610ec9fb..17653594 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -21,6 +21,7 @@ from synapse.api.auth import Auth
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent
+from synapse.events.snapshot import EventContext
from synapse.state import StateHandler, StateResolutionHandler
from tests import unittest
@@ -158,10 +159,12 @@ class Graph(object):
class StateTestCase(unittest.TestCase):
def setUp(self):
self.store = StateGroupStore()
+ storage = Mock(main=self.store, state=self.store)
hs = Mock(
spec_set=[
"config",
"get_datastore",
+ "get_storage",
"get_auth",
"get_state_handler",
"get_clock",
@@ -174,6 +177,7 @@ class StateTestCase(unittest.TestCase):
hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs)
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
+ hs.get_storage.return_value = storage
self.state = StateHandler(hs)
self.event_id = 0
@@ -195,16 +199,22 @@ class StateTestCase(unittest.TestCase):
self.store.register_events(graph.walk())
- context_store = {}
+ context_store = {} # type: dict[str, EventContext]
for event in graph.walk():
context = yield self.state.compute_event_context(event)
self.store.register_event_context(event, context)
context_store[event.event_id] = context
- prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+ ctx_c = context_store["C"]
+ ctx_d = context_store["D"]
+
+ prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
self.assertEqual(2, len(prev_state_ids))
+ self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
+ self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
+
@defer.inlineCallbacks
def test_branch_basic_conflict(self):
graph = Graph(
@@ -238,12 +248,19 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
- prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+ # C ends up winning the resolution between B and C
+
+ ctx_c = context_store["C"]
+ ctx_d = context_store["D"]
+ prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
self.assertSetEqual(
{"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()}
)
+ self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
+ self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
+
@defer.inlineCallbacks
def test_branch_have_banned_conflict(self):
graph = Graph(
@@ -289,11 +306,18 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
- prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store)
+ # C ends up winning the resolution between C and D because bans win over other
+ # changes
+
+ ctx_c = context_store["C"]
+ ctx_e = context_store["E"]
+ prev_state_ids = yield ctx_e.get_prev_state_ids(self.store)
self.assertSetEqual(
{"START", "A", "B", "C"}, {e for e in prev_state_ids.values()}
)
+ self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
+ self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
@defer.inlineCallbacks
def test_branch_have_perms_conflict(self):
@@ -357,12 +381,20 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
- prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+ # B ends up winning the resolution between B and C because power levels
+ # win over other changes.
+ ctx_b = context_store["B"]
+ ctx_d = context_store["D"]
+
+ prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
self.assertSetEqual(
{"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()}
)
+ self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
+ self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
+
def _add_depths(self, nodes, edges):
def _get_depth(ev):
node = nodes[ev]
@@ -387,13 +419,16 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event, old_state=old_state)
- current_state_ids = yield context.get_current_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+ self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
- self.assertEqual(
- set(e.event_id for e in old_state), set(current_state_ids.values())
+ current_state_ids = yield context.get_current_state_ids(self.store)
+ self.assertCountEqual(
+ (e.event_id for e in old_state), current_state_ids.values()
)
- self.assertIsNotNone(context.state_group)
+ self.assertIsNotNone(context.state_group_before_event)
+ self.assertEqual(context.state_group_before_event, context.state_group)
@defer.inlineCallbacks
def test_annotate_with_old_state(self):
@@ -408,11 +443,18 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event, old_state=old_state)
prev_state_ids = yield context.get_prev_state_ids(self.store)
+ self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
- self.assertEqual(
- set(e.event_id for e in old_state), set(prev_state_ids.values())
+ current_state_ids = yield context.get_current_state_ids(self.store)
+ self.assertCountEqual(
+ (e.event_id for e in old_state + [event]), current_state_ids.values()
)
+ self.assertIsNotNone(context.state_group_before_event)
+ self.assertNotEqual(context.state_group_before_event, context.state_group)
+ self.assertEqual(context.state_group_before_event, context.prev_group)
+ self.assertEqual({("state", ""): event.event_id}, context.delta_ids)
+
@defer.inlineCallbacks
def test_trivial_annotate_message(self):
prev_event_id = "prev_event_id"
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 18f1a003..f7381b28 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -14,6 +14,8 @@
# limitations under the License.
import logging
+from mock import Mock
+
from twisted.internet import defer
from twisted.internet.defer import succeed
@@ -36,6 +38,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory()
self.store = self.hs.get_datastore()
+ self.storage = self.hs.get_storage()
yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")
@@ -62,7 +65,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
events_to_filter.append(evt)
filtered = yield filter_events_for_server(
- self.store, "test_server", events_to_filter
+ self.storage, "test_server", events_to_filter
)
# the result should be 5 redacted events, and 5 unredacted events.
@@ -100,7 +103,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
# ... and the filtering happens.
filtered = yield filter_events_for_server(
- self.store, "test_server", events_to_filter
+ self.storage, "test_server", events_to_filter
)
for i in range(0, len(events_to_filter)):
@@ -137,7 +140,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
event, context = yield self.event_creation_handler.create_new_client_event(
builder
)
- yield self.hs.get_datastore().persist_event(event, context)
+ yield self.storage.persistence.persist_event(event, context)
return event
@defer.inlineCallbacks
@@ -159,7 +162,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
builder
)
- yield self.hs.get_datastore().persist_event(event, context)
+ yield self.storage.persistence.persist_event(event, context)
return event
@defer.inlineCallbacks
@@ -180,7 +183,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
builder
)
- yield self.hs.get_datastore().persist_event(event, context)
+ yield self.storage.persistence.persist_event(event, context)
return event
@defer.inlineCallbacks
@@ -257,6 +260,11 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
logger.info("Starting filtering")
start = time.time()
+
+ storage = Mock()
+ storage.main = test_store
+ storage.state = test_store
+
filtered = yield filter_events_for_server(
test_store, "test_server", events_to_filter
)
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 5713870f..39e360fe 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -310,14 +310,14 @@ class DescriptorTestCase(unittest.TestCase):
obj.mock.return_value = ["spam", "eggs"]
r = obj.fn(1, 2)
- self.assertEqual(r, ["spam", "eggs"])
+ self.assertEqual(r.result, ["spam", "eggs"])
obj.mock.assert_called_once_with(1, 2)
obj.mock.reset_mock()
# a call with different params should call the mock again
obj.mock.return_value = ["chips"]
r = obj.fn(1, 3)
- self.assertEqual(r, ["chips"])
+ self.assertEqual(r.result, ["chips"])
obj.mock.assert_called_once_with(1, 3)
obj.mock.reset_mock()
@@ -325,9 +325,9 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(len(obj.fn.cache.cache), 3)
r = obj.fn(1, 2)
- self.assertEqual(r, ["spam", "eggs"])
+ self.assertEqual(r.result, ["spam", "eggs"])
r = obj.fn(1, 3)
- self.assertEqual(r, ["chips"])
+ self.assertEqual(r.result, ["chips"])
obj.mock.assert_not_called()
def test_cache_iterable_with_sync_exception(self):
diff --git a/tests/utils.py b/tests/utils.py
index 8cced4b7..7dc9bdc5 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -325,10 +325,16 @@ def setup_test_homeserver(
if homeserverToUse.__name__ == "TestHomeServer":
hs.setup_master()
else:
+ # If we have been given an explicit datastore we probably want to mock
+ # out the DataStores somehow too. This all feels a bit wrong, but then
+ # mocking the stores feels wrong too.
+ datastores = Mock(datastore=datastore)
+
hs = homeserverToUse(
name,
db_pool=None,
datastore=datastore,
+ datastores=datastores,
config=config,
version_string="Synapse/tests",
database_engine=db_engine,
@@ -646,7 +652,7 @@ def create_room(hs, room_id, creator_id):
creator_id (str)
"""
- store = hs.get_datastore()
+ persistence_store = hs.get_storage().persistence
event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler()
@@ -663,4 +669,4 @@ def create_room(hs, room_id, creator_id):
event, context = yield event_creation_handler.create_new_client_event(builder)
- yield store.persist_event(event, context)
+ yield persistence_store.persist_event(event, context)
diff --git a/tox.ini b/tox.ini
index e3a53f34..62b350ea 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,5 +1,5 @@
[tox]
-envlist = packaging, py35, py36, py37, check_codestyle, check_isort
+envlist = packaging, py35, py36, py37, py38, check_codestyle, check_isort
[base]
basepython = python3.7
@@ -114,16 +114,16 @@ skip_install = True
basepython = python3.6
deps =
flake8
- black==19.3b0 # We pin so that our tests don't start failing on new releases of black.
+ black==19.10b0 # We pin so that our tests don't start failing on new releases of black.
commands =
python -m black --check --diff .
- /bin/sh -c "flake8 synapse tests scripts scripts-dev scripts/hash_password scripts/register_new_matrix_user scripts/synapse_port_db synctl {env:PEP8SUFFIX:}"
+ /bin/sh -c "flake8 synapse tests scripts scripts-dev synctl {env:PEP8SUFFIX:}"
{toxinidir}/scripts-dev/config-lint.sh
[testenv:check_isort]
skip_install = True
deps = isort
-commands = /bin/sh -c "isort -c -df -sp setup.cfg -rc synapse tests"
+commands = /bin/sh -c "isort -c -df -sp setup.cfg -rc synapse tests scripts-dev scripts"
[testenv:check-newsfragment]
skip_install = True
@@ -167,6 +167,6 @@ deps =
env =
MYPYPATH = stubs/
extras = all
-commands = mypy --show-traceback --check-untyped-defs --show-error-codes --follow-imports=normal \
+commands = mypy \
synapse/logging/ \
synapse/config/