author     Andrej Shadura <andrewsh@debian.org>    2020-06-14 18:05:46 +0200
committer  Andrej Shadura <andrewsh@debian.org>    2020-06-14 18:05:46 +0200
commit     08d5e062dc43f8af47c39699483652c1235ec5d2
tree       fe437819faf5d7e1173b41388ac2973cc0c342d1
parent     da7f96aa2a3b1485dafa016f38aac1d4376b64e7

New upstream version 1.15.0
-rw-r--r--  .github/ISSUE_TEMPLATE.md | 5
-rw-r--r--  .github/ISSUE_TEMPLATE/BUG_REPORT.md | 8
-rw-r--r--  CHANGES.md | 179
-rw-r--r--  CONTRIBUTING.md | 184
-rw-r--r--  INSTALL.md | 62
-rw-r--r--  README.rst | 24
-rw-r--r--  UPGRADE.rst | 12
-rw-r--r--  contrib/grafana/synapse.json | 613
-rw-r--r--  contrib/systemd/matrix-synapse.service | 3
-rw-r--r--  debian/changelog | 12
-rw-r--r--  docker/Dockerfile | 2
-rw-r--r--  docker/Dockerfile-dhvirtualenv | 8
-rw-r--r--  docs/admin_api/README.rst | 18
-rw-r--r--  docs/admin_api/delete_group.md | 4
-rw-r--r--  docs/admin_api/media_admin_api.md | 6
-rw-r--r--  docs/admin_api/purge_history_api.rst | 9
-rw-r--r--  docs/admin_api/purge_remote_media.rst | 7
-rw-r--r--  docs/admin_api/room_membership.md | 3
-rw-r--r--  docs/admin_api/rooms.md | 54
-rw-r--r--  docs/admin_api/user_admin_api.rst | 285
-rw-r--r--  docs/dev/git.md | 148
-rw-r--r--  docs/dev/git/branches.jpg | bin 0 -> 72228 bytes
-rw-r--r--  docs/dev/git/clean.png | bin 0 -> 110840 bytes
-rw-r--r--  docs/dev/git/squash.png | bin 0 -> 29667 bytes
-rw-r--r--  docs/openid.md | 206
-rw-r--r--  docs/reverse_proxy.md | 154
-rw-r--r--  docs/saml_mapping_providers.md | 77
-rw-r--r--  docs/sample_config.yaml | 267
-rw-r--r--  docs/spam_checker.md | 19
-rw-r--r--  docs/sso_mapping_providers.md | 148
-rw-r--r--  docs/systemd-with-workers/system/matrix-synapse-worker@.service | 2
-rw-r--r--  docs/tcp_replication.md | 4
-rw-r--r--  docs/turn-howto.md | 57
-rw-r--r--  mypy.ini | 3
-rwxr-xr-x  scripts-dev/build_debian_packages | 2
-rwxr-xr-x  scripts-dev/check-newsfragment | 20
-rw-r--r--  scripts-dev/convert_server_keys.py | 9
-rwxr-xr-x  scripts/synapse_port_db | 7
-rwxr-xr-x  setup.py | 1
-rw-r--r--  synapse/__init__.py | 2
-rw-r--r--  synapse/api/auth.py | 62
-rw-r--r--  synapse/api/constants.py | 6
-rw-r--r--  synapse/api/ratelimiting.py | 153
-rw-r--r--  synapse/api/room_versions.py | 30
-rw-r--r--  synapse/app/generic_worker.py | 220
-rw-r--r--  synapse/app/homeserver.py | 82
-rw-r--r--  synapse/appservice/__init__.py | 2
-rw-r--r--  synapse/config/_base.pyi | 2
-rw-r--r--  synapse/config/cache.py | 198
-rw-r--r--  synapse/config/captcha.py | 17
-rw-r--r--  synapse/config/database.py | 6
-rw-r--r--  synapse/config/emailconfig.py | 4
-rw-r--r--  synapse/config/homeserver.py | 4
-rw-r--r--  synapse/config/key.py | 4
-rw-r--r--  synapse/config/logger.py | 1
-rw-r--r--  synapse/config/metrics.py | 3
-rw-r--r--  synapse/config/oidc_config.py | 203
-rw-r--r--  synapse/config/ratelimiting.py | 8
-rw-r--r--  synapse/config/registration.py | 12
-rw-r--r--  synapse/config/repository.py | 1
-rw-r--r--  synapse/config/saml2_config.py | 20
-rw-r--r--  synapse/config/server.py | 33
-rw-r--r--  synapse/config/server_notices_config.py | 2
-rw-r--r--  synapse/config/spam_checker.py | 38
-rw-r--r--  synapse/config/sso.py | 20
-rw-r--r--  synapse/config/workers.py | 45
-rw-r--r--  synapse/event_auth.py | 88
-rw-r--r--  synapse/events/spamcheck.py | 78
-rw-r--r--  synapse/events/utils.py | 35
-rw-r--r--  synapse/events/validator.py | 7
-rw-r--r--  synapse/federation/federation_base.py | 6
-rw-r--r--  synapse/federation/send_queue.py | 84
-rw-r--r--  synapse/federation/sender/__init__.py | 12
-rw-r--r--  synapse/federation/sender/per_destination_queue.py | 9
-rw-r--r--  synapse/federation/sender/transaction_manager.py | 6
-rw-r--r--  synapse/groups/groups_server.py | 257
-rw-r--r--  synapse/handlers/_base.py | 60
-rw-r--r--  synapse/handlers/auth.py | 32
-rw-r--r--  synapse/handlers/device.py | 144
-rw-r--r--  synapse/handlers/e2e_keys.py | 22
-rw-r--r--  synapse/handlers/federation.py | 86
-rw-r--r--  synapse/handlers/groups_local.py | 82
-rw-r--r--  synapse/handlers/identity.py | 103
-rw-r--r--  synapse/handlers/message.py | 59
-rw-r--r--  synapse/handlers/oidc_handler.py | 1049
-rw-r--r--  synapse/handlers/presence.py | 12
-rw-r--r--  synapse/handlers/register.py | 141
-rw-r--r--  synapse/handlers/room.py | 229
-rw-r--r--  synapse/handlers/room_list.py | 20
-rw-r--r--  synapse/handlers/room_member.py | 434
-rw-r--r--  synapse/handlers/room_member_worker.py | 42
-rw-r--r--  synapse/handlers/saml_handler.py | 81
-rw-r--r--  synapse/handlers/search.py | 44
-rw-r--r--  synapse/handlers/set_password.py | 5
-rw-r--r--  synapse/handlers/state_deltas.py | 9
-rw-r--r--  synapse/handlers/stats.py | 47
-rw-r--r--  synapse/handlers/sync.py | 2
-rw-r--r--  synapse/handlers/ui_auth/checkers.py | 15
-rw-r--r--  synapse/handlers/user_directory.py | 118
-rw-r--r--  synapse/http/client.py | 13
-rw-r--r--  synapse/http/matrixfederationclient.py | 21
-rw-r--r--  synapse/http/server.py | 66
-rw-r--r--  synapse/http/site.py | 16
-rw-r--r--  synapse/logging/utils.py | 10
-rw-r--r--  synapse/metrics/__init__.py | 79
-rw-r--r--  synapse/metrics/_exposition.py | 12
-rw-r--r--  synapse/metrics/background_process_metrics.py | 104
-rw-r--r--  synapse/module_api/__init__.py | 8
-rw-r--r--  synapse/push/bulk_push_rule_evaluator.py | 4
-rw-r--r--  synapse/push/emailpusher.py | 38
-rw-r--r--  synapse/push/httppusher.py | 11
-rw-r--r--  synapse/push/mailer.py | 84
-rw-r--r--  synapse/push/push_rule_evaluator.py | 4
-rw-r--r--  synapse/python_dependencies.py | 1
-rw-r--r--  synapse/replication/http/__init__.py | 15
-rw-r--r--  synapse/replication/http/_base.py | 24
-rw-r--r--  synapse/replication/http/federation.py | 13
-rw-r--r--  synapse/replication/http/membership.py | 52
-rw-r--r--  synapse/replication/http/presence.py | 116
-rw-r--r--  synapse/replication/http/send_event.py | 4
-rw-r--r--  synapse/replication/http/streams.py | 5
-rw-r--r--  synapse/replication/slave/storage/_base.py | 59
-rw-r--r--  synapse/replication/slave/storage/account_data.py | 14
-rw-r--r--  synapse/replication/slave/storage/client_ips.py | 3
-rw-r--r--  synapse/replication/slave/storage/deviceinbox.py | 6
-rw-r--r--  synapse/replication/slave/storage/devices.py | 6
-rw-r--r--  synapse/replication/slave/storage/events.py | 91
-rw-r--r--  synapse/replication/slave/storage/groups.py | 6
-rw-r--r--  synapse/replication/slave/storage/presence.py | 14
-rw-r--r--  synapse/replication/slave/storage/push_rule.py | 14
-rw-r--r--  synapse/replication/slave/storage/pushers.py | 6
-rw-r--r--  synapse/replication/slave/storage/receipts.py | 6
-rw-r--r--  synapse/replication/slave/storage/room.py | 4
-rw-r--r--  synapse/replication/tcp/client.py | 133
-rw-r--r--  synapse/replication/tcp/commands.py | 33
-rw-r--r--  synapse/replication/tcp/handler.py | 111
-rw-r--r--  synapse/replication/tcp/redis.py | 7
-rw-r--r--  synapse/replication/tcp/resource.py | 33
-rw-r--r--  synapse/replication/tcp/streams/_base.py | 160
-rw-r--r--  synapse/replication/tcp/streams/events.py | 4
-rw-r--r--  synapse/replication/tcp/streams/federation.py | 36
-rw-r--r--  synapse/res/templates/notice_expiry.html | 2
-rw-r--r--  synapse/res/templates/notice_expiry.txt | 2
-rw-r--r--  synapse/res/templates/sso_error.html | 18
-rw-r--r--  synapse/rest/admin/__init__.py | 10
-rw-r--r--  synapse/rest/admin/_base.py | 33
-rw-r--r--  synapse/rest/admin/devices.py | 161
-rw-r--r--  synapse/rest/admin/rooms.py | 45
-rw-r--r--  synapse/rest/admin/users.py | 31
-rw-r--r--  synapse/rest/client/v1/login.py | 126
-rw-r--r--  synapse/rest/client/v1/logout.py | 6
-rw-r--r--  synapse/rest/client/v1/room.py | 20
-rw-r--r--  synapse/rest/client/v2_alpha/account.py | 21
-rw-r--r--  synapse/rest/client/v2_alpha/auth.py | 22
-rw-r--r--  synapse/rest/client/v2_alpha/keys.py | 8
-rw-r--r--  synapse/rest/client/v2_alpha/register.py | 20
-rw-r--r--  synapse/rest/client/v2_alpha/relations.py | 2
-rw-r--r--  synapse/rest/client/versions.py | 16
-rw-r--r--  synapse/rest/media/v1/_base.py | 27
-rw-r--r--  synapse/rest/oidc/__init__.py | 27
-rw-r--r--  synapse/rest/oidc/callback_resource.py | 31
-rw-r--r--  synapse/rest/saml2/response_resource.py | 26
-rw-r--r--  synapse/server.py | 30
-rw-r--r--  synapse/server.pyi | 13
-rw-r--r--  synapse/server_notices/resource_limits_server_notices.py | 15
-rw-r--r--  synapse/server_notices/server_notices_manager.py | 6
-rw-r--r--  synapse/state/__init__.py | 4
-rw-r--r--  synapse/static/client/login/js/login.js | 155
-rw-r--r--  synapse/storage/_base.py | 11
-rw-r--r--  synapse/storage/data_stores/__init__.py | 9
-rw-r--r--  synapse/storage/data_stores/main/__init__.py | 40
-rw-r--r--  synapse/storage/data_stores/main/account_data.py | 78
-rw-r--r--  synapse/storage/data_stores/main/appservice.py | 4
-rw-r--r--  synapse/storage/data_stores/main/cache.py | 146
-rw-r--r--  synapse/storage/data_stores/main/censor_events.py | 208
-rw-r--r--  synapse/storage/data_stores/main/client_ips.py | 3
-rw-r--r--  synapse/storage/data_stores/main/devices.py | 100
-rw-r--r--  synapse/storage/data_stores/main/end_to_end_keys.py | 98
-rw-r--r--  synapse/storage/data_stores/main/event_federation.py | 83
-rw-r--r--  synapse/storage/data_stores/main/event_push_actions.py | 74
-rw-r--r--  synapse/storage/data_stores/main/events.py | 1232
-rw-r--r--  synapse/storage/data_stores/main/events_worker.py | 192
-rw-r--r--  synapse/storage/data_stores/main/group_server.py | 92
-rw-r--r--  synapse/storage/data_stores/main/keys.py | 10
-rw-r--r--  synapse/storage/data_stores/main/metrics.py | 128
-rw-r--r--  synapse/storage/data_stores/main/monthly_active_users.py | 149
-rw-r--r--  synapse/storage/data_stores/main/profile.py | 2
-rw-r--r--  synapse/storage/data_stores/main/purge_events.py | 399
-rw-r--r--  synapse/storage/data_stores/main/push_rule.py | 14
-rw-r--r--  synapse/storage/data_stores/main/receipts.py | 13
-rw-r--r--  synapse/storage/data_stores/main/registration.py | 22
-rw-r--r--  synapse/storage/data_stores/main/rejections.py | 11
-rw-r--r--  synapse/storage/data_stores/main/relations.py | 60
-rw-r--r--  synapse/storage/data_stores/main/room.py | 80
-rw-r--r--  synapse/storage/data_stores/main/roommember.py | 148
-rw-r--r--  synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres | 30
-rw-r--r--  synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py | 80
-rw-r--r--  synapse/storage/data_stores/main/search.py | 119
-rw-r--r--  synapse/storage/data_stores/main/signatures.py | 30
-rw-r--r--  synapse/storage/data_stores/main/state.py | 40
-rw-r--r--  synapse/storage/data_stores/main/tags.py | 3
-rw-r--r--  synapse/storage/data_stores/main/transactions.py | 9
-rw-r--r--  synapse/storage/data_stores/state/bg_updates.py | 12
-rw-r--r--  synapse/storage/data_stores/state/store.py | 6
-rw-r--r--  synapse/storage/database.py | 88
-rw-r--r--  synapse/storage/persist_events.py | 37
-rw-r--r--  synapse/storage/prepare_database.py | 21
-rw-r--r--  synapse/storage/util/id_generators.py | 180
-rw-r--r--  synapse/util/async_helpers.py | 12
-rw-r--r--  synapse/util/caches/__init__.py | 149
-rw-r--r--  synapse/util/caches/descriptors.py | 36
-rw-r--r--  synapse/util/caches/expiringcache.py | 29
-rw-r--r--  synapse/util/caches/lrucache.py | 56
-rw-r--r--  synapse/util/caches/response_cache.py | 2
-rw-r--r--  synapse/util/caches/stream_change_cache.py | 33
-rw-r--r--  synapse/util/caches/ttlcache.py | 2
-rw-r--r--  synapse/util/frozenutils.py | 2
-rw-r--r--  synapse/util/patch_inline_callbacks.py | 9
-rw-r--r--  synapse/util/ratelimitutils.py | 2
-rw-r--r--  synapse/util/stringutils.py | 82
-rwxr-xr-x  synctl | 31
-rw-r--r--  tests/api/test_ratelimiting.py | 96
-rw-r--r--  tests/config/test_cache.py | 171
-rw-r--r--  tests/events/test_utils.py | 2
-rw-r--r--  tests/federation/test_complexity.py | 8
-rw-r--r--  tests/federation/test_federation_sender.py | 2
-rw-r--r--  tests/handlers/test_e2e_keys.py | 10
-rw-r--r--  tests/handlers/test_federation.py | 67
-rw-r--r--  tests/handlers/test_oidc.py | 570
-rw-r--r--  tests/handlers/test_profile.py | 6
-rw-r--r--  tests/handlers/test_register.py | 10
-rw-r--r--  tests/handlers/test_typing.py | 5
-rw-r--r--  tests/handlers/test_user_directory.py | 12
-rw-r--r--  tests/replication/_base.py (renamed from tests/replication/tcp/streams/_base.py) | 20
-rw-r--r--  tests/replication/slave/storage/_base.py | 59
-rw-r--r--  tests/replication/slave/storage/test_events.py | 24
-rw-r--r--  tests/replication/tcp/streams/test_account_data.py | 117
-rw-r--r--  tests/replication/tcp/streams/test_events.py | 2
-rw-r--r--  tests/replication/tcp/streams/test_federation.py | 81
-rw-r--r--  tests/replication/tcp/streams/test_receipts.py | 2
-rw-r--r--  tests/replication/tcp/streams/test_typing.py | 7
-rw-r--r--  tests/replication/tcp/test_commands.py | 4
-rw-r--r--  tests/replication/test_federation_ack.py | 71
-rw-r--r--  tests/rest/admin/test_device.py | 541
-rw-r--r--  tests/rest/admin/test_room.py | 41
-rw-r--r--  tests/rest/admin/test_user.py | 241
-rw-r--r--  tests/rest/client/v1/test_events.py | 8
-rw-r--r--  tests/rest/client/v1/test_login.py | 271
-rw-r--r--  tests/rest/client/v1/test_rooms.py | 9
-rw-r--r--  tests/rest/client/v1/test_typing.py | 10
-rw-r--r--  tests/rest/client/v2_alpha/test_account.py | 4
-rw-r--r--  tests/rest/client/v2_alpha/test_register.py | 45
-rw-r--r--  tests/rest/media/v1/test_media_storage.py | 135
-rw-r--r--  tests/server_notices/test_resource_limits_server_notices.py | 60
-rw-r--r--  tests/storage/test__base.py | 8
-rw-r--r--  tests/storage/test_appservice.py | 10
-rw-r--r--  tests/storage/test_base.py | 3
-rw-r--r--  tests/storage/test_cleanup_extrems.py | 4
-rw-r--r--  tests/storage/test_client_ips.py | 13
-rw-r--r--  tests/storage/test_event_metrics.py | 2
-rw-r--r--  tests/storage/test_event_push_actions.py | 3
-rw-r--r--  tests/storage/test_id_generators.py | 184
-rw-r--r--  tests/storage/test_monthly_active_users.py | 361
-rw-r--r--  tests/storage/test_room.py | 11
-rw-r--r--  tests/test_event_auth.py | 43
-rw-r--r--  tests/test_federation.py | 123
-rw-r--r--  tests/test_mau.py | 68
-rw-r--r--  tests/test_metrics.py | 34
-rw-r--r--  tests/test_server.py | 81
-rw-r--r--  tests/util/test_expiring_cache.py | 2
-rw-r--r--  tests/util/test_linearizer.py | 32
-rw-r--r--  tests/util/test_lrucache.py | 6
-rw-r--r--  tests/util/test_stream_change_cache.py | 5
-rw-r--r--  tests/utils.py | 1
-rw-r--r--  tox.ini | 16
275 files changed, 12859 insertions, 4909 deletions
diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md
new file mode 100644
index 00000000..1632170c
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE.md
@@ -0,0 +1,5 @@
+**If you are looking for support** please ask in **#synapse:matrix.org**
+(using a matrix.org account if necessary). We do not use GitHub issues for
+support.
+
+**If you want to report a security issue** please see https://matrix.org/security-disclosure-policy/
diff --git a/.github/ISSUE_TEMPLATE/BUG_REPORT.md b/.github/ISSUE_TEMPLATE/BUG_REPORT.md
index 9dd05bcb..75c9b2c9 100644
--- a/.github/ISSUE_TEMPLATE/BUG_REPORT.md
+++ b/.github/ISSUE_TEMPLATE/BUG_REPORT.md
@@ -4,11 +4,13 @@ about: Create a report to help us improve
---
-<!--
+**THIS IS NOT A SUPPORT CHANNEL!**
+**IF YOU HAVE SUPPORT QUESTIONS ABOUT RUNNING OR CONFIGURING YOUR OWN HOME SERVER**,
+please ask in **#synapse:matrix.org** (using a matrix.org account if necessary)
-**IF YOU HAVE SUPPORT QUESTIONS ABOUT RUNNING OR CONFIGURING YOUR OWN HOME SERVER**:
-You will likely get better support more quickly if you ask in ** #synapse:matrix.org ** ;)
+<!--
+If you want to report a security issue, please see https://matrix.org/security-disclosure-policy/
This is a bug report template. By following the instructions below and
filling out the sections with your information, you will help us to get all
diff --git a/CHANGES.md b/CHANGES.md
index 225fced2..4a356442 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,182 @@
+Synapse 1.15.0 (2020-06-11)
+===========================
+
+No significant changes.
+
+
+Synapse 1.15.0rc1 (2020-06-09)
+==============================
+
+Features
+--------
+
+- Advertise support for Client-Server API r0.6.0 and remove related unstable feature flags. ([\#6585](https://github.com/matrix-org/synapse/issues/6585))
+- Add an option to disable autojoining rooms for guest accounts. ([\#6637](https://github.com/matrix-org/synapse/issues/6637))
+- For SAML authentication, add the ability to pass email addresses to be added to new users' accounts via SAML attributes. Contributed by Christopher Cooper. ([\#7385](https://github.com/matrix-org/synapse/issues/7385))
+- Add admin APIs to allow server admins to manage users' devices. Contributed by @dklimpel. ([\#7481](https://github.com/matrix-org/synapse/issues/7481))
+- Add support for generating thumbnails for WebP images. Previously, users would see an empty box instead of a preview image. Contributed by @WGH-. ([\#7586](https://github.com/matrix-org/synapse/issues/7586))
+- Support the standardized `m.login.sso` user-interactive authentication flow. ([\#7630](https://github.com/matrix-org/synapse/issues/7630))
+
+
+Bugfixes
+--------
+
+- Allow new users to be registered via the admin API even if the monthly active user limit has been reached. Contributed by @dklimpel. ([\#7263](https://github.com/matrix-org/synapse/issues/7263))
+- Fix email notifications not being enabled for new users when created via the Admin API. ([\#7267](https://github.com/matrix-org/synapse/issues/7267))
+- Fix str placeholders in an instance of `PrepareDatabaseException`. Introduced in Synapse v1.8.0. ([\#7575](https://github.com/matrix-org/synapse/issues/7575))
+- Fix a bug in automatic user creation during first time login with `m.login.jwt`. Regression in v1.6.0. Contributed by @olof. ([\#7585](https://github.com/matrix-org/synapse/issues/7585))
+- Fix a bug causing the cross-signing keys to be ignored when resyncing a device list. ([\#7594](https://github.com/matrix-org/synapse/issues/7594))
+- Fix metrics failing when there is a large number of active background processes. ([\#7597](https://github.com/matrix-org/synapse/issues/7597))
+- Fix bug where returning rooms for a group would fail if it included a room that the server was not in. ([\#7599](https://github.com/matrix-org/synapse/issues/7599))
+- Fix duplicate key violation when persisting read markers. ([\#7607](https://github.com/matrix-org/synapse/issues/7607))
+- Prevent an entire iteration of the device list resync loop from failing if one server responds with a malformed result. ([\#7609](https://github.com/matrix-org/synapse/issues/7609))
+- Fix exceptions when fetching events from a remote host fails. ([\#7622](https://github.com/matrix-org/synapse/issues/7622))
+- Make `synctl restart` start synapse if it wasn't running. ([\#7624](https://github.com/matrix-org/synapse/issues/7624))
+- Pass device information through to the login endpoint when using the login fallback. ([\#7629](https://github.com/matrix-org/synapse/issues/7629))
+- Advertise the `m.login.token` login flow when OpenID Connect is enabled. ([\#7631](https://github.com/matrix-org/synapse/issues/7631))
+- Fix bug in account data replication stream. ([\#7656](https://github.com/matrix-org/synapse/issues/7656))
+
+
+Improved Documentation
+----------------------
+
+- Update the OpenBSD installation instructions. ([\#7587](https://github.com/matrix-org/synapse/issues/7587))
+- Advertise Python 3.8 support in `setup.py`. ([\#7602](https://github.com/matrix-org/synapse/issues/7602))
+- Add a link to `#synapse:matrix.org` in the troubleshooting section of the README. ([\#7603](https://github.com/matrix-org/synapse/issues/7603))
+- Clarifications to the admin api documentation. ([\#7647](https://github.com/matrix-org/synapse/issues/7647))
+
+
+Internal Changes
+----------------
+
+- Convert the identity handler to async/await. ([\#7561](https://github.com/matrix-org/synapse/issues/7561))
+- Improve query performance for fetching state from a PostgreSQL database. Contributed by @ilmari. ([\#7567](https://github.com/matrix-org/synapse/issues/7567))
+- Speed up processing of federation stream RDATA rows. ([\#7584](https://github.com/matrix-org/synapse/issues/7584))
+- Add comment to systemd example to show postgresql dependency. ([\#7591](https://github.com/matrix-org/synapse/issues/7591))
+- Refactor `Ratelimiter` to limit the amount of expensive config value accesses. ([\#7595](https://github.com/matrix-org/synapse/issues/7595))
+- Convert groups handlers to async/await. ([\#7600](https://github.com/matrix-org/synapse/issues/7600))
+- Clean up exception handling in `SAML2ResponseResource`. ([\#7614](https://github.com/matrix-org/synapse/issues/7614))
+- Check that all asynchronous tasks succeed and general cleanup of `MonthlyActiveUsersTestCase` and `TestMauLimit`. ([\#7619](https://github.com/matrix-org/synapse/issues/7619))
+- Convert `get_user_id_by_threepid` to async/await. ([\#7620](https://github.com/matrix-org/synapse/issues/7620))
+- Switch to upstream `dh-virtualenv` rather than our fork for Debian package builds. ([\#7621](https://github.com/matrix-org/synapse/issues/7621))
+- Update CI scripts to check the number in the newsfile fragment. ([\#7623](https://github.com/matrix-org/synapse/issues/7623))
+- Check if the localpart of a Matrix ID is reserved for guest users earlier in the registration flow, as well as when responding to requests to `/register/available`. ([\#7625](https://github.com/matrix-org/synapse/issues/7625))
+- Minor cleanups to OpenID Connect integration. ([\#7628](https://github.com/matrix-org/synapse/issues/7628))
+- Attempt to fix flaky test: `PhoneHomeStatsTestCase.test_performance_100`. ([\#7634](https://github.com/matrix-org/synapse/issues/7634))
+- Fix typos of `m.olm.curve25519-aes-sha2` and `m.megolm.v1.aes-sha2` in comments, test files. ([\#7637](https://github.com/matrix-org/synapse/issues/7637))
+- Convert user directory, state deltas, and stats handlers to async/await. ([\#7640](https://github.com/matrix-org/synapse/issues/7640))
+- Remove some unused constants. ([\#7644](https://github.com/matrix-org/synapse/issues/7644))
+- Fix type information on `assert_*_is_admin` methods. ([\#7645](https://github.com/matrix-org/synapse/issues/7645))
+- Convert registration handler to async/await. ([\#7649](https://github.com/matrix-org/synapse/issues/7649))
+
+
+Synapse 1.14.0 (2020-05-28)
+===========================
+
+No significant changes.
+
+
+Synapse 1.14.0rc2 (2020-05-27)
+==============================
+
+Bugfixes
+--------
+
+- Fix cache config to not apply cache factor to event cache. Regression in v1.14.0rc1. ([\#7578](https://github.com/matrix-org/synapse/issues/7578))
+- Fix bug where `ReplicationStreamer` was not always started when replication was enabled. Bug introduced in v1.14.0rc1. ([\#7579](https://github.com/matrix-org/synapse/issues/7579))
+- Fix specifying individual cache factors for caches with special characters in their name. Regression in v1.14.0rc1. ([\#7580](https://github.com/matrix-org/synapse/issues/7580))
+
+
+Improved Documentation
+----------------------
+
+- Fix the OIDC `client_auth_method` value in the sample config. ([\#7581](https://github.com/matrix-org/synapse/issues/7581))
+
+
+Synapse 1.14.0rc1 (2020-05-26)
+==============================
+
+Features
+--------
+
+- Synapse's cache factor can now be configured in `homeserver.yaml` by the `caches.global_factor` setting. Additionally, `caches.per_cache_factors` controls the cache factors for individual caches. ([\#6391](https://github.com/matrix-org/synapse/issues/6391))
+- Add OpenID Connect login/registration support. Contributed by Quentin Gliech, on behalf of [les Connecteurs](https://connecteu.rs). ([\#7256](https://github.com/matrix-org/synapse/issues/7256), [\#7457](https://github.com/matrix-org/synapse/issues/7457))
+- Add room details admin endpoint. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#7317](https://github.com/matrix-org/synapse/issues/7317))
+- Allow for using more than one spam checker module at once. ([\#7435](https://github.com/matrix-org/synapse/issues/7435))
+- Add additional authentication checks for `m.room.power_levels` event per [MSC2209](https://github.com/matrix-org/matrix-doc/pull/2209). ([\#7502](https://github.com/matrix-org/synapse/issues/7502))
+- Implement room version 6 per [MSC2240](https://github.com/matrix-org/matrix-doc/pull/2240). ([\#7506](https://github.com/matrix-org/synapse/issues/7506))
+- Add highly experimental option to move event persistence off master. ([\#7281](https://github.com/matrix-org/synapse/issues/7281), [\#7374](https://github.com/matrix-org/synapse/issues/7374), [\#7436](https://github.com/matrix-org/synapse/issues/7436), [\#7440](https://github.com/matrix-org/synapse/issues/7440), [\#7475](https://github.com/matrix-org/synapse/issues/7475), [\#7490](https://github.com/matrix-org/synapse/issues/7490), [\#7491](https://github.com/matrix-org/synapse/issues/7491), [\#7492](https://github.com/matrix-org/synapse/issues/7492), [\#7493](https://github.com/matrix-org/synapse/issues/7493), [\#7495](https://github.com/matrix-org/synapse/issues/7495), [\#7515](https://github.com/matrix-org/synapse/issues/7515), [\#7516](https://github.com/matrix-org/synapse/issues/7516), [\#7517](https://github.com/matrix-org/synapse/issues/7517), [\#7542](https://github.com/matrix-org/synapse/issues/7542))
+
+
+Bugfixes
+--------
+
+- Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind. ([\#7384](https://github.com/matrix-org/synapse/issues/7384))
+- Allow expired user accounts to log out their device sessions. ([\#7443](https://github.com/matrix-org/synapse/issues/7443))
+- Fix a bug that would cause Synapse not to resync out-of-sync device lists. ([\#7453](https://github.com/matrix-org/synapse/issues/7453))
+- Prevent rooms with 0 members or with invalid version strings from breaking group queries. ([\#7465](https://github.com/matrix-org/synapse/issues/7465))
+- Workaround for an upstream Twisted bug that caused Synapse to become unresponsive after startup. ([\#7473](https://github.com/matrix-org/synapse/issues/7473))
+- Fix Redis reconnection logic that can result in missed updates over replication if master reconnects to Redis without restarting. ([\#7482](https://github.com/matrix-org/synapse/issues/7482))
+- When sending `m.room.member` events, omit `displayname` and `avatar_url` if they aren't set instead of setting them to `null`. Contributed by Aaron Raimist. ([\#7497](https://github.com/matrix-org/synapse/issues/7497))
+- Fix incorrect `method` label on `synapse_http_matrixfederationclient_{requests,responses}` prometheus metrics. ([\#7503](https://github.com/matrix-org/synapse/issues/7503))
+- Ignore incoming presence events from other homeservers if presence is disabled locally. ([\#7508](https://github.com/matrix-org/synapse/issues/7508))
+- Fix a long-standing bug that broke the update remote profile background process. ([\#7511](https://github.com/matrix-org/synapse/issues/7511))
+- Hash passwords as early as possible during password reset. ([\#7538](https://github.com/matrix-org/synapse/issues/7538))
+- Fix bug where a local user leaving a room could fail under rare circumstances. ([\#7548](https://github.com/matrix-org/synapse/issues/7548))
+- Fix "Missing RelayState parameter" error when using user interactive authentication with SAML for some SAML providers. ([\#7552](https://github.com/matrix-org/synapse/issues/7552))
+- Fix exception `'GenericWorkerReplicationHandler' object has no attribute 'send_federation_ack'`, introduced in v1.13.0. ([\#7564](https://github.com/matrix-org/synapse/issues/7564))
+- `synctl` now warns if it was unable to stop Synapse and will not attempt to start Synapse if nothing was stopped. Contributed by Romain Bouyé. ([\#6598](https://github.com/matrix-org/synapse/issues/6598))
+
+
+Updates to the Docker image
+---------------------------
+
+- Update docker runtime image to Alpine v3.11. Contributed by @Starbix. ([\#7398](https://github.com/matrix-org/synapse/issues/7398))
+
+
+Improved Documentation
+----------------------
+
+- Update information about mapping providers for SAML and OpenID. ([\#7458](https://github.com/matrix-org/synapse/issues/7458))
+- Add additional reverse proxy example for Caddy v2. Contributed by Jeff Peeler. ([\#7463](https://github.com/matrix-org/synapse/issues/7463))
+- Fix copy-paste error in `ServerNoticesConfig` docstring. Contributed by @ptman. ([\#7477](https://github.com/matrix-org/synapse/issues/7477))
+- Improve the formatting of `reverse_proxy.md`. ([\#7514](https://github.com/matrix-org/synapse/issues/7514))
+- Change the systemd worker service to check that the worker config file exists instead of silently failing. Contributed by David Vo. ([\#7528](https://github.com/matrix-org/synapse/issues/7528))
+- Minor clarifications to the TURN docs. ([\#7533](https://github.com/matrix-org/synapse/issues/7533))
+
+
+Internal Changes
+----------------
+
+- Add typing annotations in `synapse.federation`. ([\#7382](https://github.com/matrix-org/synapse/issues/7382))
+- Convert the room handler to async/await. ([\#7396](https://github.com/matrix-org/synapse/issues/7396))
+- Improve performance of `get_e2e_cross_signing_key`. ([\#7428](https://github.com/matrix-org/synapse/issues/7428))
+- Improve performance of `mark_as_sent_devices_by_remote`. ([\#7429](https://github.com/matrix-org/synapse/issues/7429), [\#7562](https://github.com/matrix-org/synapse/issues/7562))
+- Add type hints to the SAML handler. ([\#7445](https://github.com/matrix-org/synapse/issues/7445))
+- Remove storage method `get_hosts_in_room` that is no longer called anywhere. ([\#7448](https://github.com/matrix-org/synapse/issues/7448))
+- Fix some typos in the `notice_expiry` templates. ([\#7449](https://github.com/matrix-org/synapse/issues/7449))
+- Convert the federation handler to async/await. ([\#7459](https://github.com/matrix-org/synapse/issues/7459))
+- Convert the search handler to async/await. ([\#7460](https://github.com/matrix-org/synapse/issues/7460))
+- Add type hints to `synapse.event_auth`. ([\#7505](https://github.com/matrix-org/synapse/issues/7505))
+- Convert the room member handler to async/await. ([\#7507](https://github.com/matrix-org/synapse/issues/7507))
+- Add type hints to room member handler. ([\#7513](https://github.com/matrix-org/synapse/issues/7513))
+- Fix typing annotations in `tests.replication`. ([\#7518](https://github.com/matrix-org/synapse/issues/7518))
+- Remove some redundant Python 2 support code. ([\#7519](https://github.com/matrix-org/synapse/issues/7519))
+- All endpoints now respond with a 200 OK for `OPTIONS` requests. ([\#7534](https://github.com/matrix-org/synapse/issues/7534), [\#7560](https://github.com/matrix-org/synapse/issues/7560))
+- Synapse now exports [detailed allocator statistics](https://doc.pypy.org/en/latest/gc_info.html#gc-get-stats) and basic GC timings as Prometheus metrics (`pypy_gc_time_seconds_total` and `pypy_memory_bytes`) when run under PyPy. Contributed by Ivan Shapovalov. ([\#7536](https://github.com/matrix-org/synapse/issues/7536))
+- Remove Ubuntu Cosmic and Disco from the list of distributions which we provide `.deb`s for, due to end-of-life. ([\#7539](https://github.com/matrix-org/synapse/issues/7539))
+- Make worker processes return a stubbed-out response to `GET /presence` requests. ([\#7545](https://github.com/matrix-org/synapse/issues/7545))
+- Optimise some references to `hs.config`. ([\#7546](https://github.com/matrix-org/synapse/issues/7546))
+- When upgrading a room, only send the canonical alias once. ([\#7547](https://github.com/matrix-org/synapse/issues/7547))
+- Fix some indentation inconsistencies in the sample config. ([\#7550](https://github.com/matrix-org/synapse/issues/7550))
+- Include `synapse.http.site` in type checking. ([\#7553](https://github.com/matrix-org/synapse/issues/7553))
+- Fix some test code to not mangle stacktraces, to make it easier to debug errors. ([\#7554](https://github.com/matrix-org/synapse/issues/7554))
+- Refresh apt cache when building `dh_virtualenv` docker image. ([\#7555](https://github.com/matrix-org/synapse/issues/7555))
+- Stop logging some expected HTTP request errors as exceptions. ([\#7556](https://github.com/matrix-org/synapse/issues/7556), [\#7563](https://github.com/matrix-org/synapse/issues/7563))
+- Convert sending mail to async/await. ([\#7557](https://github.com/matrix-org/synapse/issues/7557))
+- Simplify `reap_monthly_active_users`. ([\#7558](https://github.com/matrix-org/synapse/issues/7558))
+
+
Synapse 1.13.0 (2020-05-19)
===========================
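As an aside on the new cache settings announced for 1.14.0rc1 above: a rough sketch of how `caches.global_factor` and `caches.per_cache_factors` might be laid out in `homeserver.yaml` (the nesting and the cache name shown here are illustrative assumptions, not taken from this diff):

```
# Hypothetical homeserver.yaml excerpt -- the cache name is an example only
caches:
  # scale all cache sizes by this factor
  global_factor: 1.0
  # override the factor for individual caches by name
  per_cache_factors:
    get_users_who_share_room_with_user: 2.0
```

As described in the feature entry, the per-cache entries are intended to control the factor for the named caches individually, while the global factor applies to the rest.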
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 253a0ca6..062413e9 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,62 +1,48 @@
-# Contributing code to Matrix
+# Contributing code to Synapse
-Everyone is welcome to contribute code to Matrix
-(https://github.com/matrix-org), provided that they are willing to license
-their contributions under the same license as the project itself. We follow a
-simple 'inbound=outbound' model for contributions: the act of submitting an
-'inbound' contribution means that the contributor agrees to license the code
-under the same terms as the project's overall 'outbound' license - in our
-case, this is almost always Apache Software License v2 (see [LICENSE](LICENSE)).
+Everyone is welcome to contribute code to [matrix.org
+projects](https://github.com/matrix-org), provided that they are willing to
+license their contributions under the same license as the project itself. We
+follow a simple 'inbound=outbound' model for contributions: the act of
+submitting an 'inbound' contribution means that the contributor agrees to
+license the code under the same terms as the project's overall 'outbound'
+license - in our case, this is almost always Apache Software License v2 (see
+[LICENSE](LICENSE)).
## How to contribute
-The preferred and easiest way to contribute changes to Matrix is to fork the
-relevant project on github, and then [create a pull request](
-https://help.github.com/articles/using-pull-requests/) to ask us to pull
-your changes into our repo.
+The preferred and easiest way to contribute changes is to fork the relevant
+project on github, and then [create a pull request](
+https://help.github.com/articles/using-pull-requests/) to ask us to pull your
+changes into our repo.
-**The single biggest thing you need to know is: please base your changes on
-the develop branch - *not* master.**
+Some other points to follow:
+
+ * Please base your changes on the `develop` branch.
+
+ * Please follow the [code style requirements](#code-style).
-We use the master branch to track the most recent release, so that folks who
-blindly clone the repo and automatically check out master get something that
-works. Develop is the unstable branch where all the development actually
-happens: the workflow is that contributors should fork the develop branch to
-make a 'feature' branch for a particular contribution, and then make a pull
-request to merge this back into the matrix.org 'official' develop branch. We
-use github's pull request workflow to review the contribution, and either ask
-you to make any refinements needed or merge it and make them ourselves. The
-changes will then land on master when we next do a release.
+ * Please include a [changelog entry](#changelog) with each PR.
-We use [Buildkite](https://buildkite.com/matrix-dot-org/synapse) for continuous
-integration. If your change breaks the build, this will be shown in GitHub, so
-please keep an eye on the pull request for feedback.
+ * Please [sign off](#sign-off) your contribution.
-To run unit tests in a local development environment, you can use:
+ * Please keep an eye on the pull request for feedback from the [continuous
+ integration system](#continuous-integration-and-testing) and try to fix any
+ errors that come up.
-- ``tox -e py35`` (requires tox to be installed by ``pip install tox``)
- for SQLite-backed Synapse on Python 3.5.
-- ``tox -e py36`` for SQLite-backed Synapse on Python 3.6.
-- ``tox -e py36-postgres`` for PostgreSQL-backed Synapse on Python 3.6
- (requires a running local PostgreSQL with access to create databases).
-- ``./test_postgresql.sh`` for PostgreSQL-backed Synapse on Python 3.5
- (requires Docker). Entirely self-contained, recommended if you don't want to
- set up PostgreSQL yourself.
-
-Docker images are available for running the integration tests (SyTest) locally,
-see the [documentation in the SyTest repo](
-https://github.com/matrix-org/sytest/blob/develop/docker/README.md) for more
-information.
+ * If you need to [update your PR](#updating-your-pull-request), just add new
+ commits to your branch rather than rebasing.
## Code style
-All Matrix projects have a well-defined code-style - and sometimes we've even
-got as far as documenting it... For instance, synapse's code style doc lives
-[here](docs/code_style.md).
+Synapse's code style is documented [here](docs/code_style.md). Please follow
+it, including the conventions for the [sample configuration
+file](docs/code_style.md#configuration-file-format).
-To facilitate meeting these criteria you can run `scripts-dev/lint.sh`
-locally. Since this runs the tools listed in the above document, you'll need
-python 3.6 and to install each tool:
+Many of the conventions are enforced by scripts which are run as part of the
+[continuous integration system](#continuous-integration-and-testing). To help
+check if you have followed the code style, you can run `scripts-dev/lint.sh`
+locally. You'll need python 3.6 or later, and to install a number of tools:
```
# Install the dependencies
@@ -67,9 +53,11 @@ pip install -U black flake8 flake8-comprehensions isort
```
**Note that the script does not just test/check, but also reformats code, so you
-may wish to ensure any new code is committed first**. By default this script
-checks all files and can take some time; if you alter only certain files, you
-might wish to specify paths as arguments to reduce the run-time:
+may wish to ensure any new code is committed first**.
+
+By default, this script checks all files and can take some time; if you alter
+only certain files, you might wish to specify paths as arguments to reduce the
+run-time:
```
./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder
@@ -82,7 +70,6 @@ Please ensure your changes match the cosmetic style of the existing project,
and **never** mix cosmetic and functional changes in the same commit, as it
makes it horribly hard to review otherwise.
-
## Changelog
All changes, even minor ones, need a corresponding changelog / newsfragment
@@ -98,24 +85,55 @@ in the format of `PRnumber.type`. The type can be one of the following:
* `removal` (also used for deprecations)
* `misc` (for internal-only changes)
-The content of the file is your changelog entry, which should be a short
-description of your change in the same style as the rest of our [changelog](
-https://github.com/matrix-org/synapse/blob/master/CHANGES.md). The file can
-contain Markdown formatting, and should end with a full stop (.) or an
-exclamation mark (!) for consistency.
+This file will become part of our [changelog](
+https://github.com/matrix-org/synapse/blob/master/CHANGES.md) at the next
+release, so the content of the file should be a short description of your
+change in the same style as the rest of the changelog. The file can contain Markdown
+formatting, and should end with a full stop (.) or an exclamation mark (!) for
+consistency.
Adding credits to the changelog is encouraged; we value your
contributions and would like to have you shouted out in the release notes!
For example, a fix in PR #1234 would have its changelog entry in
-`changelog.d/1234.bugfix`, and contain content like "The security levels of
-Florbs are now validated when received over federation. Contributed by Jane
-Matrix.".
+`changelog.d/1234.bugfix`, and contain content like:
+
+> The security levels of Florbs are now validated when received
+> via the `/federation/florb` endpoint. Contributed by Jane Matrix.
+
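For illustration only (the PR number and wording are those of the hypothetical example above), creating such a newsfragment from the command line might look like:

```
# create a changelog.d newsfragment for (hypothetical) PR #1234
cat > changelog.d/1234.bugfix <<'EOF'
The security levels of Florbs are now validated when received
via the `/federation/florb` endpoint. Contributed by Jane Matrix.
EOF
```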
+If there are multiple pull requests involved in a single bugfix/feature/etc,
+then the content for each `changelog.d` file should be the same. Towncrier will
+merge the matching files together into a single changelog entry when we come to
+release.
+
+### How do I know what to call the changelog file before I create the PR?
+
+Obviously, you don't know if you should call your newsfile
+`1234.bugfix` or `5678.bugfix` until you create the PR, which leads to a
+chicken-and-egg problem.
+
+There are two options for solving this:
+
+ 1. Open the PR without a changelog file, see what number you got, and *then*
+ add the changelog file to your branch (see [Updating your pull
+ request](#updating-your-pull-request)), or:
-## Debian changelog
+ 1. Look at the [list of all
+ issues/PRs](https://github.com/matrix-org/synapse/issues?q=), add one to the
+ highest number you see, and quickly open the PR before somebody else claims
+ your number.
+
+ [This
+ script](https://github.com/richvdh/scripts/blob/master/next_github_number.sh)
+ might be helpful if you find yourself doing this a lot.
+
+Sorry, we know it's a bit fiddly, but it's *really* helpful for us when we come
+to put together a release!
+
+### Debian changelog
Changes which affect the debian packaging files (in `debian`) are an
-exception.
+exception to the rule that all changes require a `changelog.d` file.
In this case, you will need to add an entry to the debian changelog for the
next release. For this, run the following command:
@@ -200,19 +218,45 @@ Git allows you to add this signoff automatically when using the `-s`
flag to `git commit`, which uses the name and email set in your
`user.name` and `user.email` git configs.
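For example (the commit message here is made up), signing off a commit with your configured identity looks like:

```
# -s appends a "Signed-off-by: Name <email>" trailer taken from user.name/user.email
git commit -s -m "Fix florb validation"
```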
-## Merge Strategy
+## Continuous integration and testing
+
+[Buildkite](https://buildkite.com/matrix-dot-org/synapse) will automatically
+run a series of checks and tests against any PR which is opened against the
+project; if your change breaks the build, this will be shown in GitHub, with
+links to the build results. If your build fails, please try to fix the errors
+and update your branch.
+
+To run unit tests in a local development environment, you can use:
+
+- ``tox -e py35`` (requires tox to be installed by ``pip install tox``)
+ for SQLite-backed Synapse on Python 3.5.
+- ``tox -e py36`` for SQLite-backed Synapse on Python 3.6.
+- ``tox -e py36-postgres`` for PostgreSQL-backed Synapse on Python 3.6
+ (requires a running local PostgreSQL with access to create databases).
+- ``./test_postgresql.sh`` for PostgreSQL-backed Synapse on Python 3.5
+ (requires Docker). Entirely self-contained, recommended if you don't want to
+ set up PostgreSQL yourself.
+
+Docker images are available for running the integration tests (SyTest) locally,
+see the [documentation in the SyTest repo](
+https://github.com/matrix-org/sytest/blob/develop/docker/README.md) for more
+information.
+
+## Updating your pull request
+
+If you decide to make changes to your pull request - perhaps to address issues
+raised in a review, or to fix problems highlighted by [continuous
+integration](#continuous-integration-and-testing) - just add new commits to your
+branch, and push to GitHub. The pull request will automatically be updated.
-We use the commit history of develop/master extensively to identify
-when regressions were introduced and what changes have been made.
+Please **avoid** rebasing your branch, especially once the PR has been
+reviewed: doing so makes it very difficult for a reviewer to see what has
+changed since a previous review.
-We aim to have a clean merge history, which means we normally squash-merge
-changes into develop. For small changes this means there is no need to rebase
-to clean up your PR before merging. Larger changes with an organised set of
-commits may be merged as-is, if the history is judged to be useful.
+## Notes for maintainers on merging PRs etc
-This use of squash-merging will mean PRs built on each other will be hard to
-merge. We suggest avoiding these where possible, and if required, ensuring
-each PR has a tidy set of commits to ease merging.
+There are some notes for those with commit access to the project on how we
+manage git [here](docs/dev/git.md).
## Conclusion
diff --git a/INSTALL.md b/INSTALL.md
index b8f8a673..ef80a26c 100644
--- a/INSTALL.md
+++ b/INSTALL.md
@@ -180,35 +180,41 @@ sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \
#### OpenBSD
-Installing prerequisites on OpenBSD:
+A port of Synapse is available under `net/synapse`. The filesystem
+underlying the homeserver directory (defaults to `/var/synapse`) has to be
+mounted with `wxallowed` (cf. `mount(8)`), so consider creating a separate
+filesystem and mounting it at `/var/synapse`.
+
+To build Synapse's Python dependency, the `WRKOBJDIR` (cf. `bsd.port.mk(5)`)
+used for building Python also needs to be on a filesystem mounted with
+`wxallowed` (cf. `mount(8)`).
+
+Creating a `WRKOBJDIR` for building python under `/usr/local` (which on a
+default OpenBSD installation is mounted with `wxallowed`):
```
-doas pkg_add python libffi py-pip py-setuptools sqlite3 py-virtualenv \
- libxslt jpeg
+doas mkdir /usr/local/pobj_wxallowed
```
-There is currently no port for OpenBSD. Additionally, OpenBSD's security
-settings require a slightly more difficult installation process.
+Assuming `PORTS_PRIVSEP=Yes` (cf. `bsd.port.mk(5)`) and `SUDO=doas` are
+configured in `/etc/mk.conf`:
+
+```
+doas chown _pbuild:_pbuild /usr/local/pobj_wxallowed
+```
-(XXX: I suspect this is out of date)
+Setting the `WRKOBJDIR` for building python:
-1. Create a new directory in `/usr/local` called `_synapse`. Also, create a
- new user called `_synapse` and set that directory as the new user's home.
- This is required because, by default, OpenBSD only allows binaries which need
- write and execute permissions on the same memory space to be run from
- `/usr/local`.
-2. `su` to the new `_synapse` user and change to their home directory.
-3. Create a new virtualenv: `virtualenv -p python3 ~/.synapse`
-4. Source the virtualenv configuration located at
- `/usr/local/_synapse/.synapse/bin/activate`. This is done in `ksh` by
- using the `.` command, rather than `bash`'s `source`.
-5. Optionally, use `pip` to install `lxml`, which Synapse needs to parse
- webpages for their titles.
-6. Use `pip` to install this repository: `pip install matrix-synapse`
-7. Optionally, change `_synapse`'s shell to `/bin/false` to reduce the
- chance of a compromised Synapse server being used to take over your box.
+```
+echo WRKOBJDIR_lang/python/3.7=/usr/local/pobj_wxallowed \\nWRKOBJDIR_lang/python/2.7=/usr/local/pobj_wxallowed >> /etc/mk.conf
+```
-After this, you may proceed with the rest of the install directions.
+Building Synapse:
+
+```
+cd /usr/ports/net/synapse
+make install
+```
#### Windows
@@ -350,6 +356,18 @@ Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Mo
- Ports: `cd /usr/ports/net-im/py-matrix-synapse && make install clean`
- Packages: `pkg install py37-matrix-synapse`
+### OpenBSD
+
+As of OpenBSD 6.7, Synapse is available as a pre-compiled binary. The filesystem
+underlying the homeserver directory (defaults to `/var/synapse`) has to be
+mounted with `wxallowed` (cf. `mount(8)`), so consider creating a separate
+filesystem and mounting it at `/var/synapse`.
+
+Installing Synapse:
+
+```
+doas pkg_add synapse
+```
### NixOS
diff --git a/README.rst b/README.rst
index 4db7d17e..31d375d1 100644
--- a/README.rst
+++ b/README.rst
@@ -1,3 +1,11 @@
+================
+Synapse |shield|
+================
+
+.. |shield| image:: https://img.shields.io/matrix/synapse:matrix.org?label=support&logo=matrix
+ :alt: (get support on #synapse:matrix.org)
+ :target: https://matrix.to/#/#synapse:matrix.org
+
.. contents::
Introduction
@@ -77,6 +85,17 @@ Thanks for using Matrix!
[1] End-to-end encryption is currently in beta: `blog post <https://matrix.org/blog/2016/11/21/matrixs-olm-end-to-end-encryption-security-assessment-released-and-implemented-cross-platform-on-riot-at-last>`_.
+Support
+=======
+
+For support installing or managing Synapse, please join |room|_ (from a matrix.org
+account if necessary) and ask questions there. We do not use GitHub issues for
+support requests, only for bug reports and feature requests.
+
+.. |room| replace:: ``#synapse:matrix.org``
+.. _room: https://matrix.to/#/#synapse:matrix.org
+
+
Synapse Installation
====================
@@ -248,7 +267,7 @@ First calculate the hash of the new password::
Confirm password:
$2a$12$xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
-Then update the `users` table in the database::
+Then update the ``users`` table in the database::
UPDATE users SET password_hash='$2a$12$xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'
WHERE name='@test:test.com';
@@ -316,6 +335,9 @@ Building internal API documentation::
Troubleshooting
===============
+Need help? Join our community support room on Matrix:
+`#synapse:matrix.org <https://matrix.to/#/#synapse:matrix.org>`_
+
Running out of File Handles
---------------------------
diff --git a/UPGRADE.rst b/UPGRADE.rst
index 41c47e96..3b5627e8 100644
--- a/UPGRADE.rst
+++ b/UPGRADE.rst
@@ -75,9 +75,15 @@ for example:
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
-Upgrading to v1.13.0
+Upgrading to v1.14.0
====================
+This version includes a database update which is run as part of the upgrade,
+and which may take a couple of minutes in the case of a large server. Synapse
+will not respond to HTTP requests while this update is taking place.
+
+Upgrading to v1.13.0
+====================
Incorrect database migration in old synapse versions
----------------------------------------------------
@@ -136,12 +142,12 @@ back to v1.12.4 you need to:
2. Decrease the schema version in the database:
.. code:: sql
-
+
UPDATE schema_version SET version = 57;
3. Downgrade Synapse by following the instructions for your installation method
in the "Rolling back to older versions" section above.
-
+
Upgrading to v1.12.0
====================
diff --git a/contrib/grafana/synapse.json b/contrib/grafana/synapse.json
index 656a4425..30a8681f 100644
--- a/contrib/grafana/synapse.json
+++ b/contrib/grafana/synapse.json
@@ -18,7 +18,7 @@
"gnetId": null,
"graphTooltip": 0,
"id": 1,
- "iteration": 1584612489167,
+ "iteration": 1591098104645,
"links": [
{
"asDropdown": true,
@@ -577,13 +577,15 @@
"editable": true,
"error": false,
"fill": 1,
+ "fillGradient": 0,
"grid": {},
"gridPos": {
"h": 7,
"w": 12,
"x": 0,
- "y": 29
+ "y": 2
},
+ "hiddenSeries": false,
"id": 5,
"legend": {
"alignAsTable": false,
@@ -602,6 +604,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
+ "options": {
+ "dataLinks": []
+ },
"paceLength": 10,
"percentage": false,
"pointradius": 5,
@@ -704,12 +709,14 @@
"dashes": false,
"datasource": "$datasource",
"fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 7,
"w": 12,
"x": 12,
- "y": 29
+ "y": 2
},
+ "hiddenSeries": false,
"id": 37,
"legend": {
"avg": false,
@@ -724,6 +731,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
+ "options": {
+ "dataLinks": []
+ },
"paceLength": 10,
"percentage": false,
"pointradius": 5,
@@ -810,13 +820,15 @@
"editable": true,
"error": false,
"fill": 0,
+ "fillGradient": 0,
"grid": {},
"gridPos": {
"h": 7,
"w": 12,
"x": 0,
- "y": 36
+ "y": 9
},
+ "hiddenSeries": false,
"id": 34,
"legend": {
"avg": false,
@@ -831,6 +843,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
+ "options": {
+ "dataLinks": []
+ },
"paceLength": 10,
"percentage": false,
"pointradius": 5,
@@ -898,13 +913,16 @@
"datasource": "$datasource",
"description": "Shows the time in which the given percentage of reactor ticks completed, over the sampled timespan",
"fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 7,
"w": 12,
"x": 12,
- "y": 36
+ "y": 9
},
+ "hiddenSeries": false,
"id": 105,
+ "interval": "",
"legend": {
"avg": false,
"current": false,
@@ -918,6 +936,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
+ "options": {
+ "dataLinks": []
+ },
"paceLength": 10,
"percentage": false,
"pointradius": 5,
@@ -952,9 +973,10 @@
"refId": "C"
},
{
- "expr": "",
+ "expr": "rate(python_twisted_reactor_tick_time_sum{index=~\"$index\",instance=\"$instance\",job=~\"$job\"}[$bucket_size]) / rate(python_twisted_reactor_tick_time_count{index=~\"$index\",instance=\"$instance\",job=~\"$job\"}[$bucket_size])",
"format": "time_series",
"intervalFactor": 1,
+ "legendFormat": "{{job}}-{{index}} mean",
"refId": "D"
}
],
@@ -1005,14 +1027,16 @@
"dashLength": 10,
"dashes": false,
"datasource": "$datasource",
- "fill": 1,
+ "fill": 0,
+ "fillGradient": 0,
"gridPos": {
"h": 7,
"w": 12,
"x": 0,
- "y": 43
+ "y": 16
},
- "id": 50,
+ "hiddenSeries": false,
+ "id": 53,
"legend": {
"avg": false,
"current": false,
@@ -1026,6 +1050,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
+ "options": {
+ "dataLinks": []
+ },
"paceLength": 10,
"percentage": false,
"pointradius": 5,
@@ -1037,20 +1064,18 @@
"steppedLine": false,
"targets": [
{
- "expr": "rate(python_twisted_reactor_tick_time_sum{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])/rate(python_twisted_reactor_tick_time_count[$bucket_size])",
+ "expr": "min_over_time(up{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])",
"format": "time_series",
- "interval": "",
"intervalFactor": 2,
"legendFormat": "{{job}}-{{index}}",
- "refId": "A",
- "step": 20
+ "refId": "A"
}
],
"thresholds": [],
"timeFrom": null,
"timeRegions": [],
"timeShift": null,
- "title": "Avg reactor tick time",
+ "title": "Up",
"tooltip": {
"shared": true,
"sort": 0,
@@ -1066,7 +1091,7 @@
},
"yaxes": [
{
- "format": "s",
+ "format": "short",
"label": null,
"logBase": 1,
"max": null,
@@ -1079,7 +1104,7 @@
"logBase": 1,
"max": null,
"min": null,
- "show": false
+ "show": true
}
],
"yaxis": {
@@ -1094,12 +1119,14 @@
"dashes": false,
"datasource": "$datasource",
"fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 7,
"w": 12,
"x": 12,
- "y": 43
+ "y": 16
},
+ "hiddenSeries": false,
"id": 49,
"legend": {
"avg": false,
@@ -1114,6 +1141,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
+ "options": {
+ "dataLinks": []
+ },
"paceLength": 10,
"percentage": false,
"pointradius": 5,
@@ -1188,14 +1218,17 @@
"dashLength": 10,
"dashes": false,
"datasource": "$datasource",
- "fill": 0,
+ "fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 7,
"w": 12,
"x": 0,
- "y": 50
+ "y": 23
},
- "id": 53,
+ "hiddenSeries": false,
+ "id": 136,
+ "interval": "",
"legend": {
"avg": false,
"current": false,
@@ -1207,11 +1240,12 @@
},
"lines": true,
"linewidth": 1,
- "links": [],
"nullPointMode": "null",
- "paceLength": 10,
+ "options": {
+ "dataLinks": []
+ },
"percentage": false,
- "pointradius": 5,
+ "pointradius": 2,
"points": false,
"renderer": "flot",
"seriesOverrides": [],
@@ -1220,18 +1254,21 @@
"steppedLine": false,
"targets": [
{
- "expr": "min_over_time(up{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])",
- "format": "time_series",
- "intervalFactor": 2,
- "legendFormat": "{{job}}-{{index}}",
+ "expr": "rate(synapse_http_client_requests{job=~\"$job\",index=~\"$index\",instance=\"$instance\"}[$bucket_size])",
+ "legendFormat": "{{job}}-{{index}} {{method}}",
"refId": "A"
+ },
+ {
+ "expr": "rate(synapse_http_matrixfederationclient_requests{job=~\"$job\",index=~\"$index\",instance=\"$instance\"}[$bucket_size])",
+ "legendFormat": "{{job}}-{{index}} {{method}} (federation)",
+ "refId": "B"
}
],
"thresholds": [],
"timeFrom": null,
"timeRegions": [],
"timeShift": null,
- "title": "Up",
+ "title": "Outgoing HTTP request rate",
"tooltip": {
"shared": true,
"sort": 0,
@@ -1247,7 +1284,7 @@
},
"yaxes": [
{
- "format": "short",
+ "format": "reqps",
"label": null,
"logBase": 1,
"max": null,
@@ -1275,12 +1312,14 @@
"dashes": false,
"datasource": "$datasource",
"fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 7,
"w": 12,
"x": 12,
- "y": 50
+ "y": 23
},
+ "hiddenSeries": false,
"id": 120,
"legend": {
"avg": false,
@@ -1295,6 +1334,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
+ "options": {
+ "dataLinks": []
+ },
"percentage": false,
"pointradius": 2,
"points": false,
@@ -1307,6 +1349,7 @@
{
"expr": "rate(synapse_http_server_response_ru_utime_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])+rate(synapse_http_server_response_ru_stime_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])",
"format": "time_series",
+ "hide": false,
"instant": false,
"intervalFactor": 1,
"legendFormat": "{{job}}-{{index}} {{method}} {{servlet}} {{tag}}",
@@ -1315,6 +1358,7 @@
{
"expr": "rate(synapse_background_process_ru_utime_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])+rate(synapse_background_process_ru_stime_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])",
"format": "time_series",
+ "hide": false,
"instant": false,
"interval": "",
"intervalFactor": 1,
@@ -1770,6 +1814,7 @@
"editable": true,
"error": false,
"fill": 2,
+ "fillGradient": 0,
"grid": {},
"gridPos": {
"h": 8,
@@ -1777,6 +1822,7 @@
"x": 0,
"y": 31
},
+ "hiddenSeries": false,
"id": 4,
"legend": {
"alignAsTable": true,
@@ -1795,7 +1841,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"percentage": false,
"pointradius": 5,
"points": false,
@@ -1878,6 +1926,7 @@
"editable": true,
"error": false,
"fill": 1,
+ "fillGradient": 0,
"grid": {},
"gridPos": {
"h": 8,
@@ -1885,6 +1934,7 @@
"x": 12,
"y": 31
},
+ "hiddenSeries": false,
"id": 32,
"legend": {
"avg": false,
@@ -1899,7 +1949,9 @@
"linewidth": 2,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"percentage": false,
"pointradius": 5,
"points": false,
@@ -1968,6 +2020,7 @@
"editable": true,
"error": false,
"fill": 2,
+ "fillGradient": 0,
"grid": {},
"gridPos": {
"h": 8,
@@ -1975,7 +2028,8 @@
"x": 0,
"y": 39
},
- "id": 23,
+ "hiddenSeries": false,
+ "id": 139,
"legend": {
"alignAsTable": true,
"avg": false,
@@ -1993,7 +2047,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"percentage": false,
"pointradius": 5,
"points": false,
@@ -2004,7 +2060,7 @@
"steppedLine": false,
"targets": [
{
- "expr": "rate(synapse_http_server_response_ru_utime_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])+rate(synapse_http_server_response_ru_stime_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])",
+ "expr": "rate(synapse_http_server_in_flight_requests_ru_utime_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])+rate(synapse_http_server_in_flight_requests_ru_stime_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])",
"format": "time_series",
"interval": "",
"intervalFactor": 1,
@@ -2079,6 +2135,7 @@
"editable": true,
"error": false,
"fill": 2,
+ "fillGradient": 0,
"grid": {},
"gridPos": {
"h": 8,
@@ -2086,6 +2143,7 @@
"x": 12,
"y": 39
},
+ "hiddenSeries": false,
"id": 52,
"legend": {
"alignAsTable": true,
@@ -2104,7 +2162,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"percentage": false,
"pointradius": 5,
"points": false,
@@ -2115,7 +2175,7 @@
"steppedLine": false,
"targets": [
{
- "expr": "(rate(synapse_http_server_response_ru_utime_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])+rate(synapse_http_server_response_ru_stime_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])) / rate(synapse_http_server_response_count{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])",
+ "expr": "(rate(synapse_http_server_in_flight_requests_ru_utime_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])+rate(synapse_http_server_in_flight_requests_ru_stime_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])) / rate(synapse_http_server_requests_received{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])",
"format": "time_series",
"interval": "",
"intervalFactor": 2,
@@ -2187,6 +2247,7 @@
"editable": true,
"error": false,
"fill": 1,
+ "fillGradient": 0,
"grid": {},
"gridPos": {
"h": 8,
@@ -2194,6 +2255,7 @@
"x": 0,
"y": 47
},
+ "hiddenSeries": false,
"id": 7,
"legend": {
"alignAsTable": true,
@@ -2211,7 +2273,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"percentage": false,
"pointradius": 5,
"points": false,
@@ -2222,7 +2286,7 @@
"steppedLine": false,
"targets": [
{
- "expr": "rate(synapse_http_server_response_db_txn_duration_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])",
+ "expr": "rate(synapse_http_server_in_flight_requests_db_txn_duration_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])",
"format": "time_series",
"interval": "",
"intervalFactor": 2,
@@ -2280,6 +2344,7 @@
"editable": true,
"error": false,
"fill": 2,
+ "fillGradient": 0,
"grid": {},
"gridPos": {
"h": 8,
@@ -2287,6 +2352,7 @@
"x": 12,
"y": 47
},
+ "hiddenSeries": false,
"id": 47,
"legend": {
"alignAsTable": true,
@@ -2305,7 +2371,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"percentage": false,
"pointradius": 5,
"points": false,
@@ -2372,12 +2440,14 @@
"dashes": false,
"datasource": "$datasource",
"fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 9,
"w": 12,
"x": 0,
"y": 55
},
+ "hiddenSeries": false,
"id": 103,
"legend": {
"avg": false,
@@ -2392,7 +2462,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"percentage": false,
"pointradius": 5,
"points": false,
@@ -2475,12 +2547,14 @@
"dashes": false,
"datasource": "$datasource",
"fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 9,
"w": 12,
"x": 0,
"y": 32
},
+ "hiddenSeries": false,
"id": 99,
"legend": {
"avg": false,
@@ -2495,7 +2569,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"paceLength": 10,
"percentage": false,
"pointradius": 5,
@@ -2563,12 +2639,14 @@
"dashes": false,
"datasource": "$datasource",
"fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 9,
"w": 12,
"x": 12,
"y": 32
},
+ "hiddenSeries": false,
"id": 101,
"legend": {
"avg": false,
@@ -2583,7 +2661,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"paceLength": 10,
"percentage": false,
"pointradius": 5,
@@ -2649,6 +2729,93 @@
"align": false,
"alignLevel": null
}
+ },
+ {
+ "aliasColors": {},
+ "bars": false,
+ "dashLength": 10,
+ "dashes": false,
+ "datasource": "$datasource",
+ "fill": 1,
+ "fillGradient": 0,
+ "gridPos": {
+ "h": 8,
+ "w": 12,
+ "x": 0,
+ "y": 41
+ },
+ "hiddenSeries": false,
+ "id": 138,
+ "legend": {
+ "avg": false,
+ "current": false,
+ "max": false,
+ "min": false,
+ "show": true,
+ "total": false,
+ "values": false
+ },
+ "lines": true,
+ "linewidth": 1,
+ "nullPointMode": "null",
+ "options": {
+ "dataLinks": []
+ },
+ "percentage": false,
+ "pointradius": 2,
+ "points": false,
+ "renderer": "flot",
+ "seriesOverrides": [],
+ "spaceLength": 10,
+ "stack": false,
+ "steppedLine": false,
+ "targets": [
+ {
+ "expr": "synapse_background_process_in_flight_count{job=~\"$job\",index=~\"$index\",instance=\"$instance\"}",
+ "legendFormat": "{{job}}-{{index}} {{name}}",
+ "refId": "A"
+ }
+ ],
+ "thresholds": [],
+ "timeFrom": null,
+ "timeRegions": [],
+ "timeShift": null,
+ "title": "Background jobs in flight",
+ "tooltip": {
+ "shared": false,
+ "sort": 0,
+ "value_type": "individual"
+ },
+ "type": "graph",
+ "xaxis": {
+ "buckets": null,
+ "mode": "time",
+ "name": null,
+ "show": true,
+ "values": []
+ },
+ "yaxes": [
+ {
+ "format": "short",
+ "label": null,
+ "logBase": 1,
+ "max": null,
+ "min": null,
+ "show": true
+ },
+ {
+ "format": "short",
+ "label": null,
+ "logBase": 1,
+ "max": null,
+ "min": null,
+ "show": true
+ }
+ ],
+ "yaxis": {
+ "align": false,
+ "alignLevel": null
+ }
}
],
"title": "Background jobs",
@@ -2679,6 +2846,7 @@
"x": 0,
"y": 33
},
+ "hiddenSeries": false,
"id": 79,
"legend": {
"avg": false,
@@ -2774,6 +2942,7 @@
"x": 12,
"y": 33
},
+ "hiddenSeries": false,
"id": 83,
"legend": {
"avg": false,
@@ -2871,6 +3040,7 @@
"x": 0,
"y": 42
},
+ "hiddenSeries": false,
"id": 109,
"legend": {
"avg": false,
@@ -2969,6 +3139,7 @@
"x": 12,
"y": 42
},
+ "hiddenSeries": false,
"id": 111,
"legend": {
"avg": false,
@@ -3045,6 +3216,144 @@
"align": false,
"alignLevel": null
}
+ },
+ {
+ "aliasColors": {},
+ "bars": false,
+ "dashLength": 10,
+ "dashes": false,
+ "datasource": "$datasource",
+ "description": "",
+ "fill": 1,
+ "fillGradient": 0,
+ "gridPos": {
+ "h": 9,
+ "w": 12,
+ "x": 0,
+ "y": 51
+ },
+ "hiddenSeries": false,
+ "id": 140,
+ "legend": {
+ "avg": false,
+ "current": false,
+ "max": false,
+ "min": false,
+ "show": true,
+ "total": false,
+ "values": false
+ },
+ "lines": true,
+ "linewidth": 1,
+ "links": [],
+ "nullPointMode": "null",
+ "options": {
+ "dataLinks": []
+ },
+ "paceLength": 10,
+ "percentage": false,
+ "pointradius": 5,
+ "points": false,
+ "renderer": "flot",
+ "seriesOverrides": [],
+ "spaceLength": 10,
+ "stack": false,
+ "steppedLine": false,
+ "targets": [
+ {
+ "expr": "synapse_federation_send_queue_presence_changed_size{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}",
+ "format": "time_series",
+ "interval": "",
+ "intervalFactor": 1,
+ "legendFormat": "presence changed",
+ "refId": "A"
+ },
+ {
+ "expr": "synapse_federation_send_queue_presence_map_size{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}",
+ "format": "time_series",
+ "hide": false,
+ "interval": "",
+ "intervalFactor": 1,
+ "legendFormat": "presence map",
+ "refId": "B"
+ },
+ {
+ "expr": "synapse_federation_send_queue_presence_destinations_size{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}",
+ "format": "time_series",
+ "hide": false,
+ "interval": "",
+ "intervalFactor": 1,
+ "legendFormat": "presence destinations",
+ "refId": "E"
+ },
+ {
+ "expr": "synapse_federation_send_queue_keyed_edu_size{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}",
+ "format": "time_series",
+ "hide": false,
+ "interval": "",
+ "intervalFactor": 1,
+ "legendFormat": "keyed edus",
+ "refId": "C"
+ },
+ {
+ "expr": "synapse_federation_send_queue_edus_size{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}",
+ "format": "time_series",
+ "hide": false,
+ "interval": "",
+ "intervalFactor": 1,
+ "legendFormat": "other edus",
+ "refId": "D"
+ },
+ {
+ "expr": "synapse_federation_send_queue_pos_time_size{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}",
+ "format": "time_series",
+ "hide": false,
+ "interval": "",
+ "intervalFactor": 1,
+ "legendFormat": "stream positions",
+ "refId": "F"
+ }
+ ],
+ "thresholds": [],
+ "timeFrom": null,
+ "timeRegions": [],
+ "timeShift": null,
+ "title": "Outgoing EDU queues on master",
+ "tooltip": {
+ "shared": true,
+ "sort": 0,
+ "value_type": "individual"
+ },
+ "type": "graph",
+ "xaxis": {
+ "buckets": null,
+ "mode": "time",
+ "name": null,
+ "show": true,
+ "values": []
+ },
+ "yaxes": [
+ {
+ "format": "none",
+ "label": null,
+ "logBase": 1,
+ "max": null,
+ "min": "0",
+ "show": true
+ },
+ {
+ "format": "short",
+ "label": null,
+ "logBase": 1,
+ "max": null,
+ "min": null,
+ "show": true
+ }
+ ],
+ "yaxis": {
+ "align": false,
+ "alignLevel": null
+ }
}
],
"title": "Federation",
@@ -3274,12 +3583,14 @@
"dashes": false,
"datasource": "$datasource",
"fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 7,
"w": 12,
"x": 0,
- "y": 35
+ "y": 52
},
+ "hiddenSeries": false,
"id": 48,
"legend": {
"avg": false,
@@ -3294,7 +3605,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"paceLength": 10,
"percentage": false,
"pointradius": 5,
@@ -3364,12 +3677,14 @@
"datasource": "$datasource",
"description": "Shows the time in which the given percentage of database queries were scheduled, over the sampled timespan",
"fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 7,
"w": 12,
"x": 12,
- "y": 35
+ "y": 52
},
+ "hiddenSeries": false,
"id": 104,
"legend": {
"alignAsTable": true,
@@ -3385,7 +3700,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"paceLength": 10,
"percentage": false,
"pointradius": 5,
@@ -3479,13 +3796,15 @@
"editable": true,
"error": false,
"fill": 0,
+ "fillGradient": 0,
"grid": {},
"gridPos": {
"h": 7,
"w": 12,
"x": 0,
- "y": 42
+ "y": 59
},
+ "hiddenSeries": false,
"id": 10,
"legend": {
"avg": false,
@@ -3502,7 +3821,9 @@
"linewidth": 2,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"paceLength": 10,
"percentage": false,
"pointradius": 5,
@@ -3571,13 +3892,15 @@
"editable": true,
"error": false,
"fill": 1,
+ "fillGradient": 0,
"grid": {},
"gridPos": {
"h": 7,
"w": 12,
"x": 12,
- "y": 42
+ "y": 59
},
+ "hiddenSeries": false,
"id": 11,
"legend": {
"avg": false,
@@ -3594,7 +3917,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"paceLength": 10,
"percentage": false,
"pointradius": 5,
@@ -3686,8 +4011,9 @@
"h": 13,
"w": 12,
"x": 0,
- "y": 36
+ "y": 67
},
+ "hiddenSeries": false,
"id": 12,
"legend": {
"alignAsTable": true,
@@ -3780,8 +4106,9 @@
"h": 13,
"w": 12,
"x": 12,
- "y": 36
+ "y": 67
},
+ "hiddenSeries": false,
"id": 26,
"legend": {
"alignAsTable": true,
@@ -3874,8 +4201,9 @@
"h": 13,
"w": 12,
"x": 0,
- "y": 49
+ "y": 80
},
+ "hiddenSeries": false,
"id": 13,
"legend": {
"alignAsTable": true,
@@ -3905,7 +4233,7 @@
"steppedLine": false,
"targets": [
{
- "expr": "rate(synapse_util_metrics_block_db_txn_duration_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\",block_name!=\"wrapped_request_handler\"}[$bucket_size])",
+ "expr": "rate(synapse_util_metrics_block_db_txn_duration_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])",
"format": "time_series",
"interval": "",
"intervalFactor": 2,
@@ -3959,6 +4287,7 @@
"dashLength": 10,
"dashes": false,
"datasource": "$datasource",
+ "description": "The time each database transaction takes to execute, on average, broken down by metrics block.",
"editable": true,
"error": false,
"fill": 1,
@@ -3968,8 +4297,9 @@
"h": 13,
"w": 12,
"x": 12,
- "y": 49
+ "y": 80
},
+ "hiddenSeries": false,
"id": 27,
"legend": {
"alignAsTable": true,
@@ -4012,7 +4342,7 @@
"timeFrom": null,
"timeRegions": [],
"timeShift": null,
- "title": "Average Database Time per Block",
+ "title": "Average Database Transaction time, by Block",
"tooltip": {
"shared": false,
"sort": 0,
@@ -4056,13 +4386,15 @@
"editable": true,
"error": false,
"fill": 1,
+ "fillGradient": 0,
"grid": {},
"gridPos": {
"h": 13,
"w": 12,
"x": 0,
- "y": 62
+ "y": 93
},
+ "hiddenSeries": false,
"id": 28,
"legend": {
"avg": false,
@@ -4077,7 +4409,9 @@
"linewidth": 2,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"paceLength": 10,
"percentage": false,
"pointradius": 5,
@@ -4146,13 +4480,15 @@
"editable": true,
"error": false,
"fill": 1,
+ "fillGradient": 0,
"grid": {},
"gridPos": {
"h": 13,
"w": 12,
"x": 12,
- "y": 62
+ "y": 93
},
+ "hiddenSeries": false,
"id": 25,
"legend": {
"avg": false,
@@ -4167,7 +4503,9 @@
"linewidth": 2,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"paceLength": 10,
"percentage": false,
"pointradius": 5,
@@ -4253,6 +4591,7 @@
"editable": true,
"error": false,
"fill": 0,
+ "fillGradient": 0,
"grid": {},
"gridPos": {
"h": 10,
@@ -4260,6 +4599,7 @@
"x": 0,
"y": 37
},
+ "hiddenSeries": false,
"id": 1,
"legend": {
"alignAsTable": true,
@@ -4277,6 +4617,9 @@
"linewidth": 2,
"links": [],
"nullPointMode": "null",
+ "options": {
+ "dataLinks": []
+ },
"percentage": false,
"pointradius": 5,
"points": false,
@@ -4346,6 +4689,7 @@
"editable": true,
"error": false,
"fill": 1,
+ "fillGradient": 0,
"grid": {},
"gridPos": {
"h": 10,
@@ -4353,6 +4697,7 @@
"x": 12,
"y": 37
},
+ "hiddenSeries": false,
"id": 8,
"legend": {
"alignAsTable": true,
@@ -4369,6 +4714,9 @@
"linewidth": 2,
"links": [],
"nullPointMode": "connected",
+ "options": {
+ "dataLinks": []
+ },
"percentage": false,
"pointradius": 5,
"points": false,
@@ -4437,6 +4785,7 @@
"editable": true,
"error": false,
"fill": 1,
+ "fillGradient": 0,
"grid": {},
"gridPos": {
"h": 10,
@@ -4444,6 +4793,7 @@
"x": 0,
"y": 47
},
+ "hiddenSeries": false,
"id": 38,
"legend": {
"alignAsTable": true,
@@ -4460,6 +4810,9 @@
"linewidth": 2,
"links": [],
"nullPointMode": "connected",
+ "options": {
+ "dataLinks": []
+ },
"percentage": false,
"pointradius": 5,
"points": false,
@@ -4525,12 +4878,14 @@
"dashes": false,
"datasource": "$datasource",
"fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 10,
"w": 12,
"x": 12,
"y": 47
},
+ "hiddenSeries": false,
"id": 39,
"legend": {
"alignAsTable": true,
@@ -4546,6 +4901,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
+ "options": {
+ "dataLinks": []
+ },
"percentage": false,
"pointradius": 5,
"points": false,
@@ -4612,12 +4970,14 @@
"dashes": false,
"datasource": "$datasource",
"fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 9,
"w": 12,
"x": 0,
"y": 57
},
+ "hiddenSeries": false,
"id": 65,
"legend": {
"alignAsTable": true,
@@ -4633,6 +4993,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
+ "options": {
+ "dataLinks": []
+ },
"percentage": false,
"pointradius": 5,
"points": false,
@@ -4715,12 +5078,14 @@
"dashes": false,
"datasource": "$datasource",
"fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 9,
"w": 12,
"x": 0,
"y": 66
},
+ "hiddenSeries": false,
"id": 91,
"legend": {
"avg": false,
@@ -4735,6 +5100,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
+ "options": {
+ "dataLinks": []
+ },
"percentage": false,
"pointradius": 5,
"points": false,
@@ -4805,6 +5173,7 @@
"editable": true,
"error": false,
"fill": 1,
+ "fillGradient": 0,
"grid": {},
"gridPos": {
"h": 9,
@@ -4812,6 +5181,7 @@
"x": 12,
"y": 66
},
+ "hiddenSeries": false,
"id": 21,
"legend": {
"alignAsTable": true,
@@ -4827,6 +5197,9 @@
"linewidth": 2,
"links": [],
"nullPointMode": "null as zero",
+ "options": {
+ "dataLinks": []
+ },
"percentage": false,
"pointradius": 5,
"points": false,
@@ -4893,12 +5266,14 @@
"datasource": "$datasource",
"description": "'gen 0' shows the number of objects allocated since the last gen0 GC.\n'gen 1' / 'gen 2' show the number of gen0/gen1 GCs since the last gen1/gen2 GC.",
"fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 9,
"w": 12,
"x": 0,
"y": 75
},
+ "hiddenSeries": false,
"id": 89,
"legend": {
"avg": false,
@@ -4915,6 +5290,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
+ "options": {
+ "dataLinks": []
+ },
"percentage": false,
"pointradius": 5,
"points": false,
@@ -4986,12 +5364,14 @@
"dashes": false,
"datasource": "$datasource",
"fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 9,
"w": 12,
"x": 12,
"y": 75
},
+ "hiddenSeries": false,
"id": 93,
"legend": {
"avg": false,
@@ -5006,6 +5386,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "connected",
+ "options": {
+ "dataLinks": []
+ },
"percentage": false,
"pointradius": 5,
"points": false,
@@ -5071,12 +5454,14 @@
"dashes": false,
"datasource": "$datasource",
"fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 9,
"w": 12,
"x": 0,
"y": 84
},
+ "hiddenSeries": false,
"id": 95,
"legend": {
"avg": false,
@@ -5091,6 +5476,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
+ "options": {
+ "dataLinks": []
+ },
"percentage": false,
"pointradius": 5,
"points": false,
@@ -5179,6 +5567,7 @@
"show": true
},
"links": [],
+ "options": {},
"reverseYBuckets": false,
"targets": [
{
@@ -5236,12 +5625,14 @@
"dashes": false,
"datasource": "$datasource",
"fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 7,
"w": 12,
"x": 0,
"y": 39
},
+ "hiddenSeries": false,
"id": 2,
"legend": {
"avg": false,
@@ -5256,7 +5647,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"paceLength": 10,
"percentage": false,
"pointradius": 5,
@@ -5356,12 +5749,14 @@
"dashes": false,
"datasource": "$datasource",
"fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 7,
"w": 12,
"x": 12,
"y": 39
},
+ "hiddenSeries": false,
"id": 41,
"legend": {
"avg": false,
@@ -5376,7 +5771,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"paceLength": 10,
"percentage": false,
"pointradius": 5,
@@ -5445,12 +5842,14 @@
"dashes": false,
"datasource": "$datasource",
"fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 7,
"w": 12,
"x": 0,
"y": 46
},
+ "hiddenSeries": false,
"id": 42,
"legend": {
"avg": false,
@@ -5465,7 +5864,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"paceLength": 10,
"percentage": false,
"pointradius": 5,
@@ -5533,12 +5934,14 @@
"dashes": false,
"datasource": "$datasource",
"fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 7,
"w": 12,
"x": 12,
"y": 46
},
+ "hiddenSeries": false,
"id": 43,
"legend": {
"avg": false,
@@ -5553,7 +5956,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"paceLength": 10,
"percentage": false,
"pointradius": 5,
@@ -5621,12 +6026,14 @@
"dashes": false,
"datasource": "$datasource",
"fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 7,
"w": 12,
"x": 0,
"y": 53
},
+ "hiddenSeries": false,
"id": 113,
"legend": {
"avg": false,
@@ -5641,7 +6048,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"paceLength": 10,
"percentage": false,
"pointradius": 5,
@@ -5658,6 +6067,13 @@
"intervalFactor": 1,
"legendFormat": "{{job}}-{{index}} {{stream_name}}",
"refId": "A"
+ },
+ {
+ "expr": "synapse_replication_tcp_resource_total_connections{job=~\"$job\",index=~\"$index\",instance=\"$instance\"}",
+ "format": "time_series",
+ "intervalFactor": 1,
+ "legendFormat": "{{job}}-{{index}}",
+ "refId": "B"
}
],
"thresholds": [],
@@ -5684,7 +6100,7 @@
"label": null,
"logBase": 1,
"max": null,
- "min": null,
+ "min": "0",
"show": true
},
{
@@ -5708,12 +6124,14 @@
"dashes": false,
"datasource": "$datasource",
"fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 7,
"w": 12,
"x": 12,
"y": 53
},
+ "hiddenSeries": false,
"id": 115,
"legend": {
"avg": false,
@@ -5728,7 +6146,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"paceLength": 10,
"percentage": false,
"pointradius": 5,
@@ -5816,8 +6236,9 @@
"h": 9,
"w": 12,
"x": 0,
- "y": 40
+ "y": 58
},
+ "hiddenSeries": false,
"id": 67,
"legend": {
"avg": false,
@@ -5907,8 +6328,9 @@
"h": 9,
"w": 12,
"x": 12,
- "y": 40
+ "y": 58
},
+ "hiddenSeries": false,
"id": 71,
"legend": {
"avg": false,
@@ -5998,8 +6420,9 @@
"h": 9,
"w": 12,
"x": 0,
- "y": 49
+ "y": 67
},
+ "hiddenSeries": false,
"id": 121,
"interval": "",
"legend": {
@@ -6116,7 +6539,7 @@
"h": 8,
"w": 12,
"x": 0,
- "y": 14
+ "y": 41
},
"heatmap": {},
"hideZeroBuckets": true,
@@ -6171,12 +6594,14 @@
"datasource": "$datasource",
"description": "Number of rooms with the given number of forward extremities or fewer.\n\nThis is only updated once an hour.",
"fill": 0,
+ "fillGradient": 0,
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
- "y": 14
+ "y": 41
},
+ "hiddenSeries": false,
"id": 124,
"interval": "",
"legend": {
@@ -6192,7 +6617,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"percentage": false,
"pointradius": 2,
"points": false,
@@ -6273,7 +6700,7 @@
"h": 8,
"w": 12,
"x": 0,
- "y": 22
+ "y": 49
},
"heatmap": {},
"hideZeroBuckets": true,
@@ -6328,12 +6755,14 @@
"datasource": "$datasource",
"description": "For a given percentage P, the number X where P% of events were persisted to rooms with X forward extremities or fewer.",
"fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
- "y": 22
+ "y": 49
},
+ "hiddenSeries": false,
"id": 128,
"legend": {
"avg": false,
@@ -6348,7 +6777,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"percentage": false,
"pointradius": 2,
"points": false,
@@ -6448,7 +6879,7 @@
"h": 8,
"w": 12,
"x": 0,
- "y": 30
+ "y": 57
},
"heatmap": {},
"hideZeroBuckets": true,
@@ -6503,12 +6934,14 @@
"datasource": "$datasource",
"description": "For given percentage P, the number X where P% of events were persisted to rooms with X stale forward extremities or fewer.\n\nStale forward extremities are those that were in the previous set of extremities as well as the new.",
"fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
- "y": 30
+ "y": 57
},
+ "hiddenSeries": false,
"id": 130,
"legend": {
"avg": false,
@@ -6523,7 +6956,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"percentage": false,
"pointradius": 2,
"points": false,
@@ -6623,7 +7058,7 @@
"h": 8,
"w": 12,
"x": 0,
- "y": 38
+ "y": 65
},
"heatmap": {},
"hideZeroBuckets": true,
@@ -6678,12 +7113,14 @@
"datasource": "$datasource",
"description": "For a given percentage P, the number X where P% of state resolution operations took place over X state groups or fewer.",
"fill": 1,
+ "fillGradient": 0,
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
- "y": 38
+ "y": 65
},
+ "hiddenSeries": false,
"id": 132,
"interval": "",
"legend": {
@@ -6699,7 +7136,9 @@
"linewidth": 1,
"links": [],
"nullPointMode": "null",
- "options": {},
+ "options": {
+ "dataLinks": []
+ },
"percentage": false,
"pointradius": 2,
"points": false,
@@ -6794,7 +7233,6 @@
"list": [
{
"current": {
- "selected": true,
"text": "Prometheus",
"value": "Prometheus"
},
@@ -6929,10 +7367,9 @@
"allFormat": "regex wildcard",
"allValue": ".*",
"current": {
+ "selected": false,
"text": "All",
- "value": [
- "$__all"
- ]
+ "value": "$__all"
},
"datasource": "$datasource",
"definition": "",
@@ -6991,5 +7428,5 @@
"timezone": "",
"title": "Synapse",
"uid": "000000012",
- "version": 19
+ "version": 29
} \ No newline at end of file
diff --git a/contrib/systemd/matrix-synapse.service b/contrib/systemd/matrix-synapse.service
index 813717b0..a7540784 100644
--- a/contrib/systemd/matrix-synapse.service
+++ b/contrib/systemd/matrix-synapse.service
@@ -15,6 +15,9 @@
[Unit]
Description=Synapse Matrix homeserver
+# If you are using postgresql to persist data, uncomment this line to make sure
+# synapse starts after the postgresql service.
+# After=postgresql.service
[Service]
Type=notify
diff --git a/debian/changelog b/debian/changelog
index e7842d41..182a50ee 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,15 @@
+matrix-synapse-py3 (1.15.0) stable; urgency=medium
+
+ * New synapse release 1.15.0.
+
+ -- Synapse Packaging team <packages@matrix.org> Thu, 11 Jun 2020 13:27:06 +0100
+
+matrix-synapse-py3 (1.14.0) stable; urgency=medium
+
+ * New synapse release 1.14.0.
+
+ -- Synapse Packaging team <packages@matrix.org> Thu, 28 May 2020 10:37:27 +0000
+
matrix-synapse-py3 (1.13.0) stable; urgency=medium
[ Patrick Cloke ]
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 93d61739..9a3cf7b3 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -55,7 +55,7 @@ RUN pip install --prefix="/install" --no-warn-script-location \
### Stage 1: runtime
###
-FROM docker.io/python:${PYTHON_VERSION}-alpine3.10
+FROM docker.io/python:${PYTHON_VERSION}-alpine3.11
# xmlsec is required for saml support
RUN apk add --no-cache --virtual .runtime_deps \
diff --git a/docker/Dockerfile-dhvirtualenv b/docker/Dockerfile-dhvirtualenv
index 57972468..619585d5 100644
--- a/docker/Dockerfile-dhvirtualenv
+++ b/docker/Dockerfile-dhvirtualenv
@@ -28,11 +28,13 @@ RUN env DEBIAN_FRONTEND=noninteractive apt-get install \
# fetch and unpack the package
RUN mkdir /dh-virtualenv
-RUN wget -q -O /dh-virtualenv.tar.gz https://github.com/matrix-org/dh-virtualenv/archive/matrixorg-20200519.tar.gz
+RUN wget -q -O /dh-virtualenv.tar.gz https://github.com/spotify/dh-virtualenv/archive/ac6e1b1.tar.gz
RUN tar -xv --strip-components=1 -C /dh-virtualenv -f /dh-virtualenv.tar.gz
-# install its build deps
-RUN cd /dh-virtualenv \
+# install its build deps. We do another apt-cache-update here, because we might
+# be using a stale cache from docker build.
+RUN apt-get update -qq -o Acquire::Languages=none \
+ && cd /dh-virtualenv \
&& env DEBIAN_FRONTEND=noninteractive mk-build-deps -ri -t "apt-get -y --no-install-recommends"
# build it
diff --git a/docs/admin_api/README.rst b/docs/admin_api/README.rst
index 191806c5..9587bee0 100644
--- a/docs/admin_api/README.rst
+++ b/docs/admin_api/README.rst
@@ -4,17 +4,21 @@ Admin APIs
This directory includes documentation for the various synapse specific admin
APIs available.
-Only users that are server admins can use these APIs. A user can be marked as a
-server admin by updating the database directly, e.g.:
+Authenticating as a server admin
+--------------------------------
-``UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'``
+Many of the API calls in the admin api will require an `access_token` for a
+server admin. (Note that a server admin is distinct from a room admin.)
-Restarting may be required for the changes to register.
+A user can be marked as a server admin by updating the database directly, e.g.:
-Using an admin access_token
-###########################
+.. code-block:: sql
+
+ UPDATE users SET admin = 1 WHERE name = '@foo:bar.com';
+
+A new server admin user can also be created using the
+``register_new_matrix_user`` script.
-Many of the API calls listed in the documentation here will require to include an admin `access_token`.
Finding your user's `access_token` is client-dependent, but will usually be shown in the client's settings.
Once you have your `access_token`, to include it in a request, the best option is to add the token to a request header:
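As a rough illustration (not part of the upstream docs themselves), passing the token as a bearer token in the `Authorization` header might look like the following Python sketch, where the homeserver URL and token are placeholders and the listed-users endpoint is the one documented further below:

```python
# Minimal sketch: calling a Synapse admin API with an admin access_token.
# HOMESERVER and ADMIN_TOKEN are placeholders, not real values.
import requests

HOMESERVER = "https://matrix.example.com"   # assumption: your homeserver's base URL
ADMIN_TOKEN = "<access_token>"              # an access_token belonging to a server admin

resp = requests.get(
    f"{HOMESERVER}/_synapse/admin/v2/users?from=0&limit=10&guests=false",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
)
resp.raise_for_status()
print(resp.json())
```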
diff --git a/docs/admin_api/delete_group.md b/docs/admin_api/delete_group.md
index 1710488e..c061678e 100644
--- a/docs/admin_api/delete_group.md
+++ b/docs/admin_api/delete_group.md
@@ -4,11 +4,11 @@ This API lets a server admin delete a local group. Doing so will kick all
users out of the group so that their clients will correctly handle the group
being deleted.
-
The API is:
```
POST /_synapse/admin/v1/delete_group/<group_id>
```
-including an `access_token` of a server admin.
+To use it, you will need to authenticate by providing an `access_token` for a
+server admin: see [README.rst](README.rst).
diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md
index 46ba7a1a..26948770 100644
--- a/docs/admin_api/media_admin_api.md
+++ b/docs/admin_api/media_admin_api.md
@@ -6,9 +6,10 @@ The API is:
```
GET /_synapse/admin/v1/room/<room_id>/media
```
-including an `access_token` of a server admin.
+To use it, you will need to authenticate by providing an `access_token` for a
+server admin: see [README.rst](README.rst).
-It returns a JSON body like the following:
+The API returns a JSON body like the following:
```
{
"local": [
@@ -99,4 +100,3 @@ Response:
"num_quarantined": 10 # The number of media items successfully quarantined
}
```
-
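For illustration only, a minimal Python sketch of listing the media in a room via the endpoint above; the homeserver URL, admin token and room ID are placeholders:

```python
# Minimal sketch: listing the media known to be in a room via the admin API.
import requests

HOMESERVER = "https://matrix.example.com"   # placeholder
ADMIN_TOKEN = "<access_token>"              # placeholder admin token
ROOM_ID = "!roomid:example.com"             # placeholder room ID

resp = requests.get(
    f"{HOMESERVER}/_synapse/admin/v1/room/{ROOM_ID}/media",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
)
resp.raise_for_status()
media = resp.json()
print("local media:", media.get("local", []))
print("remote media:", media.get("remote", []))
```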
diff --git a/docs/admin_api/purge_history_api.rst b/docs/admin_api/purge_history_api.rst
index e2a620c5..92cd05f2 100644
--- a/docs/admin_api/purge_history_api.rst
+++ b/docs/admin_api/purge_history_api.rst
@@ -15,7 +15,8 @@ The API is:
``POST /_synapse/admin/v1/purge_history/<room_id>[/<event_id>]``
-including an ``access_token`` of a server admin.
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
By default, events sent by local users are not deleted, as they may represent
the only copies of this content in existence. (Events sent by remote users are
@@ -54,8 +55,10 @@ It is possible to poll for updates on recent purges with a second API;
``GET /_synapse/admin/v1/purge_history_status/<purge_id>``
-(again, with a suitable ``access_token``). This API returns a JSON body like
-the following:
+Again, you will need to authenticate by providing an ``access_token`` for a
+server admin.
+
+This API returns a JSON body like the following:
.. code:: json
diff --git a/docs/admin_api/purge_remote_media.rst b/docs/admin_api/purge_remote_media.rst
index dacd5bc8..00cb6b05 100644
--- a/docs/admin_api/purge_remote_media.rst
+++ b/docs/admin_api/purge_remote_media.rst
@@ -6,12 +6,15 @@ media.
The API is::
- POST /_synapse/admin/v1/purge_media_cache?before_ts=<unix_timestamp_in_ms>&access_token=<access_token>
+ POST /_synapse/admin/v1/purge_media_cache?before_ts=<unix_timestamp_in_ms>
{}
-Which will remove all cached media that was last accessed before
+\... which will remove all cached media that was last accessed before
``<unix_timestamp_in_ms>``.
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
+
If the user re-requests purged remote media, synapse will re-request the media
from the originating server.
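A minimal sketch of calling this purge endpoint from Python, assuming a placeholder homeserver URL and admin token, and purging media last accessed more than 30 days ago:

```python
# Minimal sketch: purge cached remote media last accessed over 30 days ago.
import time
import requests

HOMESERVER = "https://matrix.example.com"   # placeholder
ADMIN_TOKEN = "<access_token>"              # placeholder admin token

# before_ts is a unix timestamp in milliseconds
before_ts = int((time.time() - 30 * 24 * 3600) * 1000)

resp = requests.post(
    f"{HOMESERVER}/_synapse/admin/v1/purge_media_cache",
    params={"before_ts": before_ts},
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
    json={},
)
resp.raise_for_status()
print(resp.json())
```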
diff --git a/docs/admin_api/room_membership.md b/docs/admin_api/room_membership.md
index 16736d3d..b6746ff5 100644
--- a/docs/admin_api/room_membership.md
+++ b/docs/admin_api/room_membership.md
@@ -23,7 +23,8 @@ POST /_synapse/admin/v1/join/<room_id_or_alias>
}
```
-Including an `access_token` of a server admin.
+To use it, you will need to authenticate by providing an `access_token` for a
+server admin: see [README.rst](README.rst).
Response:
diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md
index 26fe8b86..624e7745 100644
--- a/docs/admin_api/rooms.md
+++ b/docs/admin_api/rooms.md
@@ -264,3 +264,57 @@ Response:
Once the `next_token` parameter is no longer present, we know we've reached the
end of the list.
+
+# DRAFT: Room Details API
+
+The Room Details admin API allows server admins to get all details of a room.
+
+This API is still a draft and details might change!
+
+The following fields are possible in the JSON response body:
+
+* `room_id` - The ID of the room.
+* `name` - The name of the room.
+* `canonical_alias` - The canonical (main) alias address of the room.
+* `joined_members` - How many users are currently in the room.
+* `joined_local_members` - How many local users are currently in the room.
+* `version` - The version of the room as a string.
+* `creator` - The `user_id` of the room creator.
+* `encryption` - Algorithm of end-to-end encryption of messages. Is `null` if encryption is not active.
+* `federatable` - Whether users on other servers can join this room.
+* `public` - Whether the room is visible in room directory.
+* `join_rules` - The type of rules used for users wishing to join this room. One of: ["public", "knock", "invite", "private"].
+* `guest_access` - Whether guests can join the room. One of: ["can_join", "forbidden"].
+* `history_visibility` - Who can see the room history. One of: ["invited", "joined", "shared", "world_readable"].
+* `state_events` - Total number of state events in the room, an indication of the room's complexity.
+
+## Usage
+
+A standard request:
+
+```
+GET /_synapse/admin/v1/rooms/<room_id>
+
+{}
+```
+
+Response:
+
+```
+{
+ "room_id": "!mscvqgqpHYjBGDxNym:matrix.org",
+ "name": "Music Theory",
+ "canonical_alias": "#musictheory:matrix.org",
+ "joined_members": 127
+ "joined_local_members": 2,
+ "version": "1",
+ "creator": "@foo:matrix.org",
+ "encryption": null,
+ "federatable": true,
+ "public": true,
+ "join_rules": "invite",
+ "guest_access": null,
+ "history_visibility": "shared",
+ "state_events": 93534
+}
+```
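As a rough sketch (remembering that this API is still a draft and may change), fetching those room details from Python could look like this, with the homeserver URL, admin token and room ID as placeholders:

```python
# Minimal sketch: fetching details of a single room via the draft Room Details API.
import requests

HOMESERVER = "https://matrix.example.com"        # placeholder
ADMIN_TOKEN = "<access_token>"                   # placeholder admin token
ROOM_ID = "!mscvqgqpHYjBGDxNym:matrix.org"       # placeholder room ID

resp = requests.get(
    f"{HOMESERVER}/_synapse/admin/v1/rooms/{ROOM_ID}",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
)
resp.raise_for_status()
details = resp.json()
print(details["name"], details["joined_members"], details["state_events"])
```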
diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
index 859d7f99..7b030a62 100644
--- a/docs/admin_api/user_admin_api.rst
+++ b/docs/admin_api/user_admin_api.rst
@@ -1,9 +1,47 @@
+.. contents::
+
+Query User Account
+==================
+
+This API returns information about a specific user account.
+
+The api is::
+
+ GET /_synapse/admin/v2/users/<user_id>
+
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
+
+It returns a JSON body like the following:
+
+.. code:: json
+
+ {
+ "displayname": "User",
+ "threepids": [
+ {
+ "medium": "email",
+ "address": "<user_mail_1>"
+ },
+ {
+ "medium": "email",
+ "address": "<user_mail_2>"
+ }
+ ],
+ "avatar_url": "<avatar_url>",
+ "admin": false,
+ "deactivated": false
+ }
+
+URL parameters:
+
+- ``user_id``: fully-qualified user id: for example, ``@user:server.com``.
+
Create or modify Account
========================
This API allows an administrator to create or modify a user account with a
-specific ``user_id``. Be aware that ``user_id`` is fully qualified: for example,
-``@user:server.com``.
+specific ``user_id``.
This api is::
@@ -31,23 +69,29 @@ with a body of:
"deactivated": false
}
-including an ``access_token`` of a server admin.
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
+
+URL parameters:
+
+- ``user_id``: fully-qualified user id: for example, ``@user:server.com``.
-The parameter ``displayname`` is optional and defaults to the value of
-``user_id``.
+Body parameters:
-The parameter ``threepids`` is optional and allows setting the third-party IDs
-(email, msisdn) belonging to a user.
+- ``password``, optional. If provided, the user's password is updated and all
+ devices are logged out.
-The parameter ``avatar_url`` is optional. Must be a [MXC
-URI](https://matrix.org/docs/spec/client_server/r0.6.0#matrix-content-mxc-uris).
+- ``displayname``, optional, defaults to the value of ``user_id``.
-The parameter ``admin`` is optional and defaults to ``false``.
+- ``threepids``, optional, allows setting the third-party IDs (email, msisdn)
+ belonging to a user.
-The parameter ``deactivated`` is optional and defaults to ``false``.
+- ``avatar_url``, optional, must be a
+ `MXC URI <https://matrix.org/docs/spec/client_server/r0.6.0#matrix-content-mxc-uris>`_.
-The parameter ``password`` is optional. If provided, the user's password is
-updated and all devices are logged out.
+- ``admin``, optional, defaults to ``false``.
+
+- ``deactivated``, optional, defaults to ``false``.
If the user already exists then optional parameters default to the current value.
@@ -60,7 +104,8 @@ The api is::
GET /_synapse/admin/v2/users?from=0&limit=10&guests=false
-including an ``access_token`` of a server admin.
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
The parameter ``from`` is optional but used for pagination, denoting the
offset in the returned results. This should be treated as an opaque value and
@@ -115,17 +160,17 @@ with ``from`` set to the value of ``next_token``. This will return a new page.
If the endpoint does not return a ``next_token`` then there are no more users
to paginate through.
-Query Account
-=============
+Query current sessions for a user
+=================================
-This API returns information about a specific user account.
+This API returns information about the active sessions for a specific user.
The api is::
- GET /_synapse/admin/v1/whois/<user_id> (deprecated)
- GET /_synapse/admin/v2/users/<user_id>
+ GET /_synapse/admin/v1/whois/<user_id>
-including an ``access_token`` of a server admin.
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
It returns a JSON body like the following:
@@ -178,9 +223,10 @@ with a body of:
"erase": true
}
-including an ``access_token`` of a server admin.
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
-The erase parameter is optional and defaults to 'false'.
+The erase parameter is optional and defaults to ``false``.
An empty body may be passed for backwards compatibility.
@@ -202,7 +248,8 @@ with a body of:
"logout_devices": true,
}
-including an ``access_token`` of a server admin.
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
The parameter ``new_password`` is required.
The parameter ``logout_devices`` is optional and defaults to ``true``.
@@ -215,7 +262,8 @@ The api is::
GET /_synapse/admin/v1/users/<user_id>/admin
-including an ``access_token`` of a server admin.
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
A response body like the following is returned:
@@ -243,4 +291,191 @@ with a body of:
"admin": true
}
-including an ``access_token`` of a server admin.
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
+
+
+User devices
+============
+
+List all devices
+----------------
+Gets information about all devices for a specific ``user_id``.
+
+The API is::
+
+ GET /_synapse/admin/v2/users/<user_id>/devices
+
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
+
+A response body like the following is returned:
+
+.. code:: json
+
+ {
+ "devices": [
+ {
+ "device_id": "QBUAZIFURK",
+ "display_name": "android",
+ "last_seen_ip": "1.2.3.4",
+ "last_seen_ts": 1474491775024,
+ "user_id": "<user_id>"
+ },
+ {
+ "device_id": "AUIECTSRND",
+ "display_name": "ios",
+ "last_seen_ip": "1.2.3.5",
+ "last_seen_ts": 1474491775025,
+ "user_id": "<user_id>"
+ }
+ ]
+ }
+
+**Parameters**
+
+The following parameters should be set in the URL:
+
+- ``user_id`` - fully qualified: for example, ``@user:server.com``.
+
+**Response**
+
+The following fields are returned in the JSON response body:
+
+- ``devices`` - An array of objects, each containing information about a device.
+ Device objects contain the following fields:
+
+ - ``device_id`` - Identifier of device.
+ - ``display_name`` - Display name set by the user for this device.
+ Absent if no name has been set.
+ - ``last_seen_ip`` - The IP address where this device was last seen.
+ (May be a few minutes out of date, for efficiency reasons).
+ - ``last_seen_ts`` - The timestamp (in milliseconds since the unix epoch) when this
+    device was last seen. (May be a few minutes out of date, for efficiency reasons).
+ - ``user_id`` - Owner of device.
+
+Delete multiple devices
+-----------------------
+Deletes the given devices for a specific ``user_id``, and invalidates
+any access token associated with them.
+
+The API is::
+
+ POST /_synapse/admin/v2/users/<user_id>/delete_devices
+
+ {
+ "devices": [
+ "QBUAZIFURK",
+ "AUIECTSRND"
+    ]
+ }
+
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
+
+An empty JSON dict is returned.
+
+**Parameters**
+
+The following parameters should be set in the URL:
+
+- ``user_id`` - fully qualified: for example, ``@user:server.com``.
+
+The following fields are required in the JSON request body:
+
+- ``devices`` - The list of device IDs to delete.
+
+Show a device
+---------------
+Gets information on a single device, by ``device_id`` for a specific ``user_id``.
+
+The API is::
+
+ GET /_synapse/admin/v2/users/<user_id>/devices/<device_id>
+
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
+
+A response body like the following is returned:
+
+.. code:: json
+
+ {
+ "device_id": "<device_id>",
+ "display_name": "android",
+ "last_seen_ip": "1.2.3.4",
+ "last_seen_ts": 1474491775024,
+ "user_id": "<user_id>"
+ }
+
+**Parameters**
+
+The following parameters should be set in the URL:
+
+- ``user_id`` - fully qualified: for example, ``@user:server.com``.
+- ``device_id`` - The device to retrieve.
+
+**Response**
+
+The following fields are returned in the JSON response body:
+
+- ``device_id`` - Identifier of device.
+- ``display_name`` - Display name set by the user for this device.
+ Absent if no name has been set.
+- ``last_seen_ip`` - The IP address where this device was last seen.
+ (May be a few minutes out of date, for efficiency reasons).
+- ``last_seen_ts`` - The timestamp (in milliseconds since the unix epoch) when this
+  device was last seen. (May be a few minutes out of date, for efficiency reasons).
+- ``user_id`` - Owner of device.
+
+Update a device
+---------------
+Updates the metadata on the given ``device_id`` for a specific ``user_id``.
+
+The API is::
+
+ PUT /_synapse/admin/v2/users/<user_id>/devices/<device_id>
+
+ {
+ "display_name": "My other phone"
+ }
+
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
+
+An empty JSON dict is returned.
+
+**Parameters**
+
+The following parameters should be set in the URL:
+
+- ``user_id`` - fully qualified: for example, ``@user:server.com``.
+- ``device_id`` - The device to update.
+
+The following fields can be set in the JSON request body:
+
+- ``display_name`` - The new display name for this device. If not given,
+ the display name is unchanged.
+
+Delete a device
+---------------
+Deletes the given ``device_id`` for a specific ``user_id``,
+and invalidates any access token associated with it.
+
+The API is::
+
+ DELETE /_synapse/admin/v2/users/<user_id>/devices/<device_id>
+
+ {}
+
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
+
+An empty JSON dict is returned.
+
+**Parameters**
+
+The following parameters should be set in the URL:
+
+- ``user_id`` - fully qualified: for example, ``@user:server.com``.
+- ``device_id`` - The device to delete.
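Tying the device APIs above together, a minimal Python sketch (all values are placeholders) that lists a user's devices and then deletes the first one, invalidating its access token:

```python
# Minimal sketch: list a user's devices, then delete one of them.
import requests

HOMESERVER = "https://matrix.example.com"   # placeholder
ADMIN_TOKEN = "<access_token>"              # placeholder admin token
USER_ID = "@user:server.com"                # placeholder fully-qualified user id

headers = {"Authorization": f"Bearer {ADMIN_TOKEN}"}

# List all devices for the user
resp = requests.get(
    f"{HOMESERVER}/_synapse/admin/v2/users/{USER_ID}/devices",
    headers=headers,
)
resp.raise_for_status()
devices = resp.json()["devices"]

# Delete the first device (and invalidate its access token), if there is one
if devices:
    device_id = devices[0]["device_id"]
    resp = requests.delete(
        f"{HOMESERVER}/_synapse/admin/v2/users/{USER_ID}/devices/{device_id}",
        headers=headers,
        json={},
    )
    resp.raise_for_status()
```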
diff --git a/docs/dev/git.md b/docs/dev/git.md
new file mode 100644
index 00000000..b747ff20
--- /dev/null
+++ b/docs/dev/git.md
@@ -0,0 +1,148 @@
+Some notes on how we use git
+============================
+
+On keeping the commit history clean
+-----------------------------------
+
+In an ideal world, our git commit history would be a linear progression of
+commits each of which contains a single change building on what came
+before. Here, by way of an arbitrary example, is the top of `git log --graph
+b2dba0607`:
+
+<img src="git/clean.png" alt="clean git graph" width="500px">
+
+Note how the commit comment explains clearly what is changing and why. Also
+note the *absence* of merge commits, as well as the absence of commits called
+things like (to pick a few culprits):
+[“pep8”](https://github.com/matrix-org/synapse/commit/84691da6c), [“fix broken
+test”](https://github.com/matrix-org/synapse/commit/474810d9d),
+[“oops”](https://github.com/matrix-org/synapse/commit/c9d72e457),
+[“typo”](https://github.com/matrix-org/synapse/commit/836358823), or [“Who's
+the president?”](https://github.com/matrix-org/synapse/commit/707374d5d).
+
+There are a number of reasons why keeping a clean commit history is a good
+thing:
+
+ * From time to time, after a change lands, it turns out to be necessary to
+ revert it, or to backport it to a release branch. Those operations are
+ *much* easier when the change is contained in a single commit.
+
+ * Similarly, it's much easier to answer questions like “is the fix for
+ `/publicRooms` on the release branch?” if that change consists of a single
+ commit.
+
+ * Likewise: “what has changed on this branch in the last week?” is much
+ clearer without merges and “pep8” commits everywhere.
+
+ * Sometimes we need to figure out where a bug got introduced, or some
+ behaviour changed. One way of doing that is with `git bisect`: pick an
+ arbitrary commit between the known good point and the known bad point, and
+ see how the code behaves. However, that strategy fails if the commit you
+ chose is the middle of someone's epic branch in which they broke the world
+ before putting it back together again.
+
+One counterargument is that it is sometimes useful to see how a PR evolved as
+it went through review cycles. This is true, but that information is always
+available via the GitHub UI (or via the little-known [refs/pull
+namespace](https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/checking-out-pull-requests-locally)).
+
+
+Of course, in reality, things are more complicated than that. We have release
+branches as well as `develop` and `master`, and we deliberately merge changes
+between them. Bugs often slip through and have to be fixed later. That's all
+fine: this is not a cast-iron rule which must be obeyed, but an ideal to aim
+towards.
+
+Merges, squashes, rebases: wtf?
+-------------------------------
+
+Ok, so that's what we'd like to achieve. How do we achieve it?
+
+The TL;DR is: when you come to merge a pull request, you *probably* want to
+“squash and merge”:
+
+![squash and merge](git/squash.png).
+
+(This applies whether you are merging your own PR, or that of another
+contributor.)
+
+“Squash and merge”<sup id="a1">[1](#f1)</sup> takes all of the changes in the
+PR, and bundles them into a single commit. GitHub gives you the opportunity to
+edit the commit message before you confirm, and normally you should do so,
+because the default will be useless (again: `* woops typo` is not a useful
+thing to keep in the historical record).
+
+The main problem with this approach comes when you have a series of pull
+requests which build on top of one another: as soon as you squash-merge the
+first PR, you'll end up with a stack of conflicts to resolve in all of the
+others. In general, it's best to avoid this situation in the first place by
+trying not to have multiple related PRs in flight at the same time. Still,
+sometimes that's not possible and doing a regular merge is the lesser evil.
+
+Another occasion in which a regular merge makes more sense is a PR where you've
+deliberately created a series of commits each of which makes sense in its own
+right. For example: [a PR which gradually propagates a refactoring operation
+through the codebase](https://github.com/matrix-org/synapse/pull/6837), or [a
+PR which is the culmination of several other
+PRs](https://github.com/matrix-org/synapse/pull/5987). In this case the ability
+to figure out when a particular change/bug was introduced could be very useful.
+
+Ultimately: **this is not a hard-and-fast-rule**. If in doubt, ask yourself “do
+each of the commits I am about to merge make sense in their own right”, but
+remember that we're just doing our best to balance “keeping the commit history
+clean” with other factors.
+
+Git branching model
+-------------------
+
+A [lot](https://nvie.com/posts/a-successful-git-branching-model/)
+[of](http://scottchacon.com/2011/08/31/github-flow.html)
+[words](https://www.endoflineblog.com/gitflow-considered-harmful) have been
+written in the past about git branching models (no really, [a
+lot](https://martinfowler.com/articles/branching-patterns.html)). I tend to
+think the whole thing is overblown. Fundamentally, it's not that
+complicated. Here's how we do it.
+
+Let's start with a picture:
+
+![branching model](git/branches.jpg)
+
+It looks complicated, but it's really not. There's one basic rule: *anyone* is
+free to merge from *any* more-stable branch to *any* less-stable branch at
+*any* time<sup id="a2">[2](#f2)</sup>. (The principle behind this is that if a
+change is good enough for the more-stable branch, then it's also good enough to
+put in a less-stable branch.)
+
+Meanwhile, merging (or squashing, as per the above) from a less-stable to a
+more-stable branch is a deliberate action in which you want to publish a change
+or a set of changes to (some subset of) the world: for example, this happens
+when a PR is landed, or as part of our release process.
+
+So, what counts as a more- or less-stable branch? A little reflection will show
+that our active branches are ordered thus, from more-stable to less-stable:
+
+ * `master` (tracks our last release).
+ * `release-vX.Y.Z` (the branch where we prepare the next release)<sup
+ id="a3">[3](#f3)</sup>.
+ * PR branches which are targeting the release.
+ * `develop` (our "mainline" branch containing our bleeding-edge).
+ * regular PR branches.
+
+The corollary is: if you have a bugfix that needs to land in both
+`release-vX.Y.Z` *and* `develop`, then you should base your PR on
+`release-vX.Y.Z`, get it merged there, and then merge from `release-vX.Y.Z` to
+`develop`. (If a fix lands in `develop` and we later need it in a
+release-branch, we can of course cherry-pick it, but landing it in the release
+branch first helps reduce the chance of annoying conflicts.)
+
+---
+
+<b id="f1">[1]</b>: “Squash and merge” is GitHub's term for this
+operation. Given that there is no merge involved, I'm not convinced it's the
+most intuitive name. [^](#a1)
+
+<b id="f2">[2]</b>: Well, anyone with commit access.[^](#a2)
+
+<b id="f3">[3]</b>: Very, very occasionally (I think this has happened once in
+the history of Synapse), we've had two releases in flight at once. Obviously,
+`release-v1.2.3` is more-stable than `release-v1.3.0`. [^](#a3)
diff --git a/docs/dev/git/branches.jpg b/docs/dev/git/branches.jpg
new file mode 100644
index 00000000..715ecc8c
--- /dev/null
+++ b/docs/dev/git/branches.jpg
Binary files differ
diff --git a/docs/dev/git/clean.png b/docs/dev/git/clean.png
new file mode 100644
index 00000000..3accd7cc
--- /dev/null
+++ b/docs/dev/git/clean.png
Binary files differ
diff --git a/docs/dev/git/squash.png b/docs/dev/git/squash.png
new file mode 100644
index 00000000..234caca3
--- /dev/null
+++ b/docs/dev/git/squash.png
Binary files differ
diff --git a/docs/openid.md b/docs/openid.md
new file mode 100644
index 00000000..688379dd
--- /dev/null
+++ b/docs/openid.md
@@ -0,0 +1,206 @@
+# Configuring Synapse to authenticate against an OpenID Connect provider
+
+Synapse can be configured to use an OpenID Connect Provider (OP) for
+authentication, instead of its own local password database.
+
+Any OP should work with Synapse, as long as it supports the authorization code
+flow. There are a few options for that:
+
+ - start a local OP. Synapse has been tested with [Hydra][hydra] and
+ [Dex][dex-idp]. Note that for an OP to work, it should be served under a
+ secure (HTTPS) origin. A certificate signed with a self-signed, locally
+ trusted CA should work. In that case, start Synapse with a `SSL_CERT_FILE`
+ environment variable set to the path of the CA.
+
+ - set up a SaaS OP, like [Google][google-idp], [Auth0][auth0] or
+ [Okta][okta]. Synapse has been tested with Auth0 and Google.
+
+It may also be possible to use other OAuth2 providers which provide the
+[authorization code grant type](https://tools.ietf.org/html/rfc6749#section-4.1),
+such as [Github][github-idp].
+
+[google-idp]: https://developers.google.com/identity/protocols/oauth2/openid-connect
+[auth0]: https://auth0.com/
+[okta]: https://www.okta.com/
+[dex-idp]: https://github.com/dexidp/dex
+[hydra]: https://www.ory.sh/docs/hydra/
+[github-idp]: https://developer.github.com/apps/building-oauth-apps/authorizing-oauth-apps
+
+## Preparing Synapse
+
+The OpenID integration in Synapse uses the
+[`authlib`](https://pypi.org/project/Authlib/) library, which must be installed
+as follows:
+
+ * The relevant libraries are included in the Docker images and Debian packages
+ provided by `matrix.org` so no further action is needed.
+
+ * If you installed Synapse into a virtualenv, run `/path/to/env/bin/pip
+ install synapse[oidc]` to install the necessary dependencies.
+
+ * For other installation mechanisms, see the documentation provided by the
+ maintainer.
+
+To enable the OpenID integration, you should then add an `oidc_config` section
+to your configuration file (or uncomment the `enabled: true` line in the
+existing section). See [sample_config.yaml](./sample_config.yaml) for some
+sample settings, as well as the text below for example configurations for
+specific providers.
+
+## Sample configs
+
+Here are a few configs for providers that should work with Synapse.
+
+### [Dex][dex-idp]
+
+[Dex][dex-idp] is a simple, open-source, certified OpenID Connect Provider.
+Although it is designed to help building a full-blown provider with an
+external database, it can be configured with static passwords in a config file.
+
+Follow the [Getting Started
+guide](https://github.com/dexidp/dex/blob/master/Documentation/getting-started.md)
+to install Dex.
+
+Edit `examples/config-dev.yaml` config file from the Dex repo to add a client:
+
+```yaml
+staticClients:
+- id: synapse
+ secret: secret
+ redirectURIs:
+ - '[synapse public baseurl]/_synapse/oidc/callback'
+ name: 'Synapse'
+```
+
+Run with `dex serve examples/config-dev.yaml`.
+
+Synapse config:
+
+```yaml
+oidc_config:
+ enabled: true
+ skip_verification: true # This is needed as Dex is served on an insecure endpoint
+ issuer: "http://127.0.0.1:5556/dex"
+ client_id: "synapse"
+ client_secret: "secret"
+ scopes: ["openid", "profile"]
+ user_mapping_provider:
+ config:
+ localpart_template: "{{ user.name }}"
+ display_name_template: "{{ user.name|capitalize }}"
+```
+
+### [Auth0][auth0]
+
+1. Create a regular web application for Synapse
+2. Set the Allowed Callback URLs to `[synapse public baseurl]/_synapse/oidc/callback`
+3. Add a rule to add the `preferred_username` claim.
+ <details>
+ <summary>Code sample</summary>
+
+ ```js
+ function addPersistenceAttribute(user, context, callback) {
+ user.user_metadata = user.user_metadata || {};
+ user.user_metadata.preferred_username = user.user_metadata.preferred_username || user.user_id;
+ context.idToken.preferred_username = user.user_metadata.preferred_username;
+
+ auth0.users.updateUserMetadata(user.user_id, user.user_metadata)
+ .then(function(){
+ callback(null, user, context);
+ })
+ .catch(function(err){
+ callback(err);
+ });
+ }
+ ```
+ </details>
+
+Synapse config:
+
+```yaml
+oidc_config:
+ enabled: true
+ issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED
+ client_id: "your-client-id" # TO BE FILLED
+ client_secret: "your-client-secret" # TO BE FILLED
+ scopes: ["openid", "profile"]
+ user_mapping_provider:
+ config:
+ localpart_template: "{{ user.preferred_username }}"
+ display_name_template: "{{ user.name }}"
+```
+
+### GitHub
+
+GitHub is a bit special as it is not an OpenID Connect compliant provider, but
+just a regular OAuth2 provider.
+
+The [`/user` API endpoint](https://developer.github.com/v3/users/#get-the-authenticated-user)
+can be used to retrieve information on the authenticated user. As the Synapse
+login mechanism needs an attribute to uniquely identify users, and that endpoint
+does not return a `sub` property, an alternative `subject_claim` has to be set.
+
+1. Create a new OAuth application: https://github.com/settings/applications/new.
+2. Set the callback URL to `[synapse public baseurl]/_synapse/oidc/callback`.
+
+Synapse config:
+
+```yaml
+oidc_config:
+ enabled: true
+ discover: false
+ issuer: "https://github.com/"
+ client_id: "your-client-id" # TO BE FILLED
+ client_secret: "your-client-secret" # TO BE FILLED
+ authorization_endpoint: "https://github.com/login/oauth/authorize"
+ token_endpoint: "https://github.com/login/oauth/access_token"
+ userinfo_endpoint: "https://api.github.com/user"
+ scopes: ["read:user"]
+ user_mapping_provider:
+ config:
+ subject_claim: "id"
+ localpart_template: "{{ user.login }}"
+ display_name_template: "{{ user.name }}"
+```
+
+### [Google][google-idp]
+
+1. Set up a project in the Google API Console (see
+ https://developers.google.com/identity/protocols/oauth2/openid-connect#appsetup).
+2. Add an "OAuth Client ID" for a Web Application under "Credentials".
+3. Copy the Client ID and Client Secret, and add the following to your synapse config:
+ ```yaml
+ oidc_config:
+ enabled: true
+ issuer: "https://accounts.google.com/"
+ client_id: "your-client-id" # TO BE FILLED
+ client_secret: "your-client-secret" # TO BE FILLED
+ scopes: ["openid", "profile"]
+ user_mapping_provider:
+ config:
+ localpart_template: "{{ user.given_name|lower }}"
+ display_name_template: "{{ user.name }}"
+ ```
+4. Back in the Google console, add this Authorized redirect URI: `[synapse
+ public baseurl]/_synapse/oidc/callback`.
+
+### Twitch
+
+1. Set up a developer account on [Twitch](https://dev.twitch.tv/)
+2. Obtain the OAuth 2.0 credentials by [creating an app](https://dev.twitch.tv/console/apps/)
+3. Add this OAuth Redirect URL: `[synapse public baseurl]/_synapse/oidc/callback`
+
+Synapse config:
+
+```yaml
+oidc_config:
+ enabled: true
+ issuer: "https://id.twitch.tv/oauth2/"
+ client_id: "your-client-id" # TO BE FILLED
+ client_secret: "your-client-secret" # TO BE FILLED
+ client_auth_method: "client_secret_post"
+ user_mapping_provider:
+ config:
+ localpart_template: '{{ user.preferred_username }}'
+ display_name_template: '{{ user.name }}'
+```
diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md
index c7222f73..cbb82695 100644
--- a/docs/reverse_proxy.md
+++ b/docs/reverse_proxy.md
@@ -9,7 +9,7 @@ of doing so is that it means that you can expose the default https port
(443) to Matrix clients without needing to run Synapse with root
privileges.
-> **NOTE**: Your reverse proxy must not `canonicalise` or `normalise`
+**NOTE**: Your reverse proxy must not `canonicalise` or `normalise`
the requested URI in any way (for example, by decoding `%xx` escapes).
Beware that Apache *will* canonicalise URIs unless you specify
`nocanon`.
@@ -18,7 +18,7 @@ When setting up a reverse proxy, remember that Matrix clients and other
Matrix servers do not necessarily need to connect to your server via the
same server name or port. Indeed, clients will use port 443 by default,
whereas servers default to port 8448. Where these are different, we
-refer to the 'client port' and the \'federation port\'. See [the Matrix
+refer to the 'client port' and the 'federation port'. See [the Matrix
specification](https://matrix.org/docs/spec/server_server/latest#resolving-server-names)
for more details of the algorithm used for federation connections, and
[delegate.md](<delegate.md>) for instructions on setting up delegation.
@@ -28,93 +28,113 @@ Let's assume that we expect clients to connect to our server at
`https://example.com:8448`. The following sections detail the configuration of
the reverse proxy and the homeserver.
-## Webserver configuration examples
+## Reverse-proxy configuration examples
-> **NOTE**: You only need one of these.
+**NOTE**: You only need one of these.
### nginx
- server {
- listen 443 ssl;
- listen [::]:443 ssl;
- server_name matrix.example.com;
-
- location /_matrix {
- proxy_pass http://localhost:8008;
- proxy_set_header X-Forwarded-For $remote_addr;
- # Nginx by default only allows file uploads up to 1M in size
- # Increase client_max_body_size to match max_upload_size defined in homeserver.yaml
- client_max_body_size 10M;
- }
- }
-
- server {
- listen 8448 ssl default_server;
- listen [::]:8448 ssl default_server;
- server_name example.com;
-
- location / {
- proxy_pass http://localhost:8008;
- proxy_set_header X-Forwarded-For $remote_addr;
- }
- }
-
-> **NOTE**: Do not add a `/` after the port in `proxy_pass`, otherwise nginx will
+```
+server {
+ listen 443 ssl;
+ listen [::]:443 ssl;
+ server_name matrix.example.com;
+
+ location /_matrix {
+ proxy_pass http://localhost:8008;
+ proxy_set_header X-Forwarded-For $remote_addr;
+ # Nginx by default only allows file uploads up to 1M in size
+ # Increase client_max_body_size to match max_upload_size defined in homeserver.yaml
+ client_max_body_size 10M;
+ }
+}
+
+server {
+ listen 8448 ssl default_server;
+ listen [::]:8448 ssl default_server;
+ server_name example.com;
+
+ location / {
+ proxy_pass http://localhost:8008;
+ proxy_set_header X-Forwarded-For $remote_addr;
+ }
+}
+```
+
+**NOTE**: Do not add a path after the port in `proxy_pass`, otherwise nginx will
canonicalise/normalise the URI.
-### Caddy
+### Caddy 1
- matrix.example.com {
- proxy /_matrix http://localhost:8008 {
- transparent
- }
- }
+```
+matrix.example.com {
+ proxy /_matrix http://localhost:8008 {
+ transparent
+ }
+}
- example.com:8448 {
- proxy / http://localhost:8008 {
- transparent
- }
- }
+example.com:8448 {
+ proxy / http://localhost:8008 {
+ transparent
+ }
+}
+```
+
+### Caddy 2
+
+```
+matrix.example.com {
+ reverse_proxy /_matrix/* http://localhost:8008
+}
+
+example.com:8448 {
+ reverse_proxy http://localhost:8008
+}
+```
### Apache
- <VirtualHost *:443>
- SSLEngine on
- ServerName matrix.example.com;
+```
+<VirtualHost *:443>
+ SSLEngine on
+ ServerName matrix.example.com;
- AllowEncodedSlashes NoDecode
- ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
- ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
- </VirtualHost>
+ AllowEncodedSlashes NoDecode
+ ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
+ ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
+</VirtualHost>
- <VirtualHost *:8448>
- SSLEngine on
- ServerName example.com;
+<VirtualHost *:8448>
+ SSLEngine on
+ ServerName example.com;
- AllowEncodedSlashes NoDecode
- ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
- ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
- </VirtualHost>
+ AllowEncodedSlashes NoDecode
+ ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
+ ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
+</VirtualHost>
+```
-> **NOTE**: ensure the `nocanon` options are included.
+**NOTE**: ensure the `nocanon` options are included.
### HAProxy
- frontend https
- bind :::443 v4v6 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1
+```
+frontend https
+ bind :::443 v4v6 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1
- # Matrix client traffic
- acl matrix-host hdr(host) -i matrix.example.com
- acl matrix-path path_beg /_matrix
+ # Matrix client traffic
+ acl matrix-host hdr(host) -i matrix.example.com
+ acl matrix-path path_beg /_matrix
- use_backend matrix if matrix-host matrix-path
+ use_backend matrix if matrix-host matrix-path
- frontend matrix-federation
- bind :::8448 v4v6 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1
- default_backend matrix
+frontend matrix-federation
+ bind :::8448 v4v6 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1
+ default_backend matrix
- backend matrix
- server matrix 127.0.0.1:8008
+backend matrix
+ server matrix 127.0.0.1:8008
+```
## Homeserver Configuration
diff --git a/docs/saml_mapping_providers.md b/docs/saml_mapping_providers.md
deleted file mode 100644
index 92f23804..00000000
--- a/docs/saml_mapping_providers.md
+++ /dev/null
@@ -1,77 +0,0 @@
-# SAML Mapping Providers
-
-A SAML mapping provider is a Python class (loaded via a Python module) that
-works out how to map attributes of a SAML response object to Matrix-specific
-user attributes. Details such as user ID localpart, displayname, and even avatar
-URLs are all things that can be mapped from talking to a SSO service.
-
-As an example, a SSO service may return the email address
-"john.smith@example.com" for a user, whereas Synapse will need to figure out how
-to turn that into a displayname when creating a Matrix user for this individual.
-It may choose `John Smith`, or `Smith, John [Example.com]` or any number of
-variations. As each Synapse configuration may want something different, this is
-where SAML mapping providers come into play.
-
-## Enabling Providers
-
-External mapping providers are provided to Synapse in the form of an external
-Python module. Retrieve this module from [PyPi](https://pypi.org) or elsewhere,
-then tell Synapse where to look for the handler class by editing the
-`saml2_config.user_mapping_provider.module` config option.
-
-`saml2_config.user_mapping_provider.config` allows you to provide custom
-configuration options to the module. Check with the module's documentation for
-what options it provides (if any). The options listed by default are for the
-user mapping provider built in to Synapse. If using a custom module, you should
-comment these options out and use those specified by the module instead.
-
-## Building a Custom Mapping Provider
-
-A custom mapping provider must specify the following methods:
-
-* `__init__(self, parsed_config)`
- - Arguments:
- - `parsed_config` - A configuration object that is the return value of the
- `parse_config` method. You should set any configuration options needed by
- the module here.
-* `saml_response_to_user_attributes(self, saml_response, failures)`
- - Arguments:
- - `saml_response` - A `saml2.response.AuthnResponse` object to extract user
- information from.
- - `failures` - An `int` that represents the amount of times the returned
- mxid localpart mapping has failed. This should be used
- to create a deduplicated mxid localpart which should be
- returned instead. For example, if this method returns
- `john.doe` as the value of `mxid_localpart` in the returned
- dict, and that is already taken on the homeserver, this
- method will be called again with the same parameters but
- with failures=1. The method should then return a different
- `mxid_localpart` value, such as `john.doe1`.
- - This method must return a dictionary, which will then be used by Synapse
- to build a new user. The following keys are allowed:
- * `mxid_localpart` - Required. The mxid localpart of the new user.
- * `displayname` - The displayname of the new user. If not provided, will default to
- the value of `mxid_localpart`.
-* `parse_config(config)`
- - This method should have the `@staticmethod` decoration.
- - Arguments:
- - `config` - A `dict` representing the parsed content of the
- `saml2_config.user_mapping_provider.config` homeserver config option.
- Runs on homeserver startup. Providers should extract any option values
- they need here.
- - Whatever is returned will be passed back to the user mapping provider module's
- `__init__` method during construction.
-* `get_saml_attributes(config)`
- - This method should have the `@staticmethod` decoration.
- - Arguments:
- - `config` - A object resulting from a call to `parse_config`.
- - Returns a tuple of two sets. The first set equates to the saml auth
- response attributes that are required for the module to function, whereas
- the second set consists of those attributes which can be used if available,
- but are not necessary.
-
-## Synapse's Default Provider
-
-Synapse has a built-in SAML mapping provider if a custom provider isn't
-specified in the config. It is located at
-[`synapse.handlers.saml_handler.DefaultSamlMappingProvider`](../synapse/handlers/saml_handler.py).
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 98ead7dc..94e1ec69 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -322,22 +322,27 @@ listeners:
# Used by phonehome stats to group together related servers.
#server_context: context
-# Resource-constrained homeserver Settings
+# Resource-constrained homeserver settings
#
-# If limit_remote_rooms.enabled is True, the room complexity will be
-# checked before a user joins a new remote room. If it is above
-# limit_remote_rooms.complexity, it will disallow joining or
-# instantly leave.
+# When this is enabled, the room "complexity" will be checked before a user
+# joins a new remote room. If it is above the complexity limit, the server will
+# disallow joining, or will instantly leave.
#
-# limit_remote_rooms.complexity_error can be set to customise the text
-# displayed to the user when a room above the complexity threshold has
-# its join cancelled.
+# Room complexity is an arbitrary measure based on factors such as the number of
+# users in the room.
#
-# Uncomment the below lines to enable:
-#limit_remote_rooms:
-# enabled: true
-# complexity: 1.0
-# complexity_error: "This room is too complex."
+limit_remote_rooms:
+ # Uncomment to enable room complexity checking.
+ #
+ #enabled: true
+
+ # the limit above which rooms cannot be joined. The default is 1.0.
+ #
+ #complexity: 0.5
+
+ # override the error which is returned when the room is too complex.
+ #
+ #complexity_error: "This room is too complex."
# Whether to require a user to be in the room to add an alias to it.
# Defaults to 'true'.
@@ -603,6 +608,51 @@ acme:
+## Caching ##
+
+# Caching can be configured through the following options.
+#
+# A cache 'factor' is a multiplier that can be applied to each of
+# Synapse's caches in order to increase or decrease the maximum
+# number of entries that can be stored.
+
+# The number of events to cache in memory. Not affected by
+# caches.global_factor.
+#
+#event_cache_size: 10K
+
+caches:
+ # Controls the global cache factor, which is the default cache factor
+ # for all caches if a specific factor for that cache is not otherwise
+ # set.
+ #
+ # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment
+ # variable. Setting by environment variable takes priority over
+ # setting through the config file.
+ #
+  # Defaults to 0.5, which will halve the size of all caches.
+ #
+ #global_factor: 1.0
+
+ # A dictionary of cache name to cache factor for that individual
+ # cache. Overrides the global cache factor for a given cache.
+ #
+ # These can also be set through environment variables comprised
+ # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital
+ # letters and underscores. Setting by environment variable
+ # takes priority over setting through the config file.
+ # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0
+ #
+ # Some caches have '*' and other characters that are not
+ # alphanumeric or underscores. These caches can be named with or
+ # without the special characters stripped. For example, to specify
+ # the cache factor for `*stateGroupCache*` via an environment
+  # variable, use `SYNAPSE_CACHE_FACTOR_STATEGROUPCACHE=2.0`.
+ #
+ per_cache_factors:
+ #get_users_who_share_room_with_user: 2.0
+
+
## Database ##
# The 'database' setting defines the database that synapse uses to store all of
@@ -646,10 +696,6 @@ database:
args:
database: DATADIR/homeserver.db
-# Number of events to cache in memory.
-#
-#event_cache_size: 10K
-
## Logging ##
@@ -907,25 +953,28 @@ url_preview_accept_language:
## Captcha ##
-# See docs/CAPTCHA_SETUP for full details of configuring this.
+# See docs/CAPTCHA_SETUP.md for full details of configuring this.
-# This homeserver's ReCAPTCHA public key.
+# This homeserver's ReCAPTCHA public key. Must be specified if
+# enable_registration_captcha is enabled.
#
#recaptcha_public_key: "YOUR_PUBLIC_KEY"
-# This homeserver's ReCAPTCHA private key.
+# This homeserver's ReCAPTCHA private key. Must be specified if
+# enable_registration_captcha is enabled.
#
#recaptcha_private_key: "YOUR_PRIVATE_KEY"
-# Enables ReCaptcha checks when registering, preventing signup
+# Uncomment to enable ReCaptcha checks when registering, preventing signup
# unless a captcha is answered. Requires a valid ReCaptcha
-# public/private key.
+# public/private key. Defaults to 'false'.
#
-#enable_registration_captcha: false
+#enable_registration_captcha: true
# The API endpoint to use for verifying m.login.recaptcha responses.
+# Defaults to "https://www.recaptcha.net/recaptcha/api/siteverify".
#
-#recaptcha_siteverify_api: "https://www.recaptcha.net/recaptcha/api/siteverify"
+#recaptcha_siteverify_api: "https://my.recaptcha.site"
## TURN ##
@@ -1069,7 +1118,7 @@ account_validity:
# If set, allows registration of standard or admin accounts by anyone who
# has the shared secret, even if registration is otherwise disabled.
#
-# registration_shared_secret: <PRIVATE STRING>
+#registration_shared_secret: <PRIVATE STRING>
# Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash.
@@ -1174,6 +1223,13 @@ account_threepid_delegates:
#
#autocreate_auto_join_rooms: true
+# When auto_join_rooms is specified, setting this flag to false prevents
+# guest accounts from being automatically joined to the rooms.
+#
+# Defaults to true.
+#
+#auto_join_rooms_for_guests: false
+
## Metrics ###
@@ -1202,7 +1258,8 @@ metrics_flags:
#known_servers: true
# Whether or not to report anonymized homeserver usage statistics.
-# report_stats: true|false
+#
+#report_stats: true|false
# The endpoint to report the anonymized homeserver usage statistics to.
# Defaults to https://matrix.org/report-usage-stats/push
@@ -1238,13 +1295,13 @@ metrics_flags:
# the registration_shared_secret is used, if one is given; otherwise,
# a secret key is derived from the signing key.
#
-# macaroon_secret_key: <PRIVATE STRING>
+#macaroon_secret_key: <PRIVATE STRING>
# a secret which is used to calculate HMACs for form values, to stop
# falsification of values. Must be specified for the User Consent
# forms to work.
#
-# form_secret: <PRIVATE STRING>
+#form_secret: <PRIVATE STRING>
## Signing Keys ##
@@ -1329,6 +1386,8 @@ trusted_key_servers:
#key_server_signing_keys_path: "key_server_signing_keys.key"
+## Single sign-on integration ##
+
# Enable SAML2 for registration and login. Uses pysaml2.
#
# At least one of `sp_config` or `config_path` must be set in this section to
@@ -1462,7 +1521,13 @@ saml2_config:
# * HTML page to display to users if something goes wrong during the
# authentication process: 'saml_error.html'.
#
- # This template doesn't currently need any variable to render.
+ # When rendering, this template is given the following variables:
+  #  * code: an HTTP error code corresponding to the error that is being
+ # returned (typically 400 or 500)
+ #
+ # * msg: a textual message describing the error.
+ #
+ # The variables will automatically be HTML-escaped.
#
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
@@ -1470,6 +1535,121 @@ saml2_config:
#template_dir: "res/templates"
+# OpenID Connect integration. The following settings can be used to make Synapse
+# use an OpenID Connect Provider for authentication, instead of its internal
+# password database.
+#
+# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md.
+#
+oidc_config:
+ # Uncomment the following to enable authorization against an OpenID Connect
+ # server. Defaults to false.
+ #
+ #enabled: true
+
+ # Uncomment the following to disable use of the OIDC discovery mechanism to
+ # discover endpoints. Defaults to true.
+ #
+ #discover: false
+
+ # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
+ # discover the provider's endpoints.
+ #
+ # Required if 'enabled' is true.
+ #
+ #issuer: "https://accounts.example.com/"
+
+ # oauth2 client id to use.
+ #
+ # Required if 'enabled' is true.
+ #
+ #client_id: "provided-by-your-issuer"
+
+ # oauth2 client secret to use.
+ #
+ # Required if 'enabled' is true.
+ #
+ #client_secret: "provided-by-your-issuer"
+
+ # auth method to use when exchanging the token.
+ # Valid values are 'client_secret_basic' (default), 'client_secret_post' and
+ # 'none'.
+ #
+ #client_auth_method: client_secret_post
+
+ # list of scopes to request. This should normally include the "openid" scope.
+ # Defaults to ["openid"].
+ #
+ #scopes: ["openid", "profile"]
+
+ # the oauth2 authorization endpoint. Required if provider discovery is disabled.
+ #
+ #authorization_endpoint: "https://accounts.example.com/oauth2/auth"
+
+ # the oauth2 token endpoint. Required if provider discovery is disabled.
+ #
+ #token_endpoint: "https://accounts.example.com/oauth2/token"
+
+ # the OIDC userinfo endpoint. Required if discovery is disabled and the
+ # "openid" scope is not requested.
+ #
+ #userinfo_endpoint: "https://accounts.example.com/userinfo"
+
+  # URI from which to fetch the JWKS. Required if discovery is disabled and the
+ # "openid" scope is used.
+ #
+ #jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
+
+ # Uncomment to skip metadata verification. Defaults to false.
+ #
+ # Use this if you are connecting to a provider that is not OpenID Connect
+ # compliant.
+ # Avoid this in production.
+ #
+ #skip_verification: true
+
+ # An external module can be provided here as a custom solution to mapping
+# attributes returned from an OIDC provider onto a Matrix user.
+ #
+ user_mapping_provider:
+ # The custom module's class. Uncomment to use a custom module.
+ # Default is 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'.
+ #
+ # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
+ # for information on implementing a custom mapping provider.
+ #
+ #module: mapping_provider.OidcMappingProvider
+
+ # Custom configuration values for the module. This section will be passed as
+ # a Python dictionary to the user mapping provider module's `parse_config`
+ # method.
+ #
+ # The examples below are intended for the default provider: they should be
+ # changed if using a custom provider.
+ #
+ config:
+ # name of the claim containing a unique identifier for the user.
+ # Defaults to `sub`, which OpenID Connect compliant providers should provide.
+ #
+ #subject_claim: "sub"
+
+ # Jinja2 template for the localpart of the MXID.
+ #
+ # When rendering, this template is given the following variables:
+ # * user: The claims returned by the UserInfo Endpoint and/or in the ID
+ # Token
+ #
+ # This must be configured if using the default mapping provider.
+ #
+ localpart_template: "{{ user.preferred_username }}"
+
+ # Jinja2 template for the display name to set on first login.
+ #
+ # If unset, no displayname will be set.
+ #
+ #display_name_template: "{{ user.given_name }} {{ user.last_name }}"
+
+
# Enable CAS for registration and login.
#
@@ -1482,7 +1662,8 @@ saml2_config:
# # name: value
-# Additional settings to use with single-sign on systems such as SAML2 and CAS.
+# Additional settings to use with single-sign on systems such as OpenID Connect,
+# SAML2 and CAS.
#
sso:
# A list of client URLs which are whitelisted so that the user does not
@@ -1554,6 +1735,13 @@ sso:
#
# This template has no additional variables.
#
+ # * HTML page to display to users if something goes wrong during the
+ # OpenID Connect authentication process: 'sso_error.html'.
+ #
+ # When rendering, this template is given two variables:
+ # * error: the technical name of the error
+ # * error_description: a human-readable message for the error
+ #
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#
@@ -1634,8 +1822,8 @@ email:
# Username/password for authentication to the SMTP server. By default, no
# authentication is attempted.
#
- # smtp_user: "exampleusername"
- # smtp_pass: "examplepassword"
+ #smtp_user: "exampleusername"
+ #smtp_pass: "examplepassword"
# Uncomment the following to require TLS transport security for SMTP.
# By default, Synapse will connect over plain text, and will then switch to
@@ -1772,10 +1960,17 @@ password_providers:
# include_content: true
-#spam_checker:
-# module: "my_custom_project.SuperSpamChecker"
-# config:
-# example_option: 'things'
+# Spam checkers are third-party modules that can block specific actions
+# of local users, such as creating rooms and registering undesirable
+# usernames, as well as remote users by redacting incoming events.
+#
+spam_checker:
+ #- module: "my_custom_project.SuperSpamChecker"
+ # config:
+ # example_option: 'things'
+ #- module: "some_other_project.BadEventStopper"
+ # config:
+ # example_stop_events_from: ['@bad:example.com']
# Uncomment to allow non-server-admin users to create groups on this server
diff --git a/docs/spam_checker.md b/docs/spam_checker.md
index 5b5f5000..eb10e115 100644
--- a/docs/spam_checker.md
+++ b/docs/spam_checker.md
@@ -64,10 +64,12 @@ class ExampleSpamChecker:
Modify the `spam_checker` section of your `homeserver.yaml` in the following
manner:
-`module` should point to the fully qualified Python class that implements your
-custom logic, e.g. `my_module.ExampleSpamChecker`.
+Create a list entry with the keys `module` and `config`.
-`config` is a dictionary that gets passed to the spam checker class.
+* `module` should point to the fully qualified Python class that implements your
+ custom logic, e.g. `my_module.ExampleSpamChecker`.
+
+* `config` is a dictionary that gets passed to the spam checker class.
### Example
@@ -75,12 +77,15 @@ This section might look like:
```yaml
spam_checker:
- module: my_module.ExampleSpamChecker
- config:
- # Enable or disable a specific option in ExampleSpamChecker.
- my_custom_option: true
+ - module: my_module.ExampleSpamChecker
+ config:
+ # Enable or disable a specific option in ExampleSpamChecker.
+ my_custom_option: true
```
+More spam checkers can be added in tandem by appending more items to the list. An
+action is blocked when at least one of the configured spam checkers flags it.
+
## Examples
The [Mjolnir](https://github.com/matrix-org/mjolnir) project is a full fledged
diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md
new file mode 100644
index 00000000..abea4323
--- /dev/null
+++ b/docs/sso_mapping_providers.md
@@ -0,0 +1,148 @@
+# SSO Mapping Providers
+
+A mapping provider is a Python class (loaded via a Python module) that
+works out how to map attributes of a SSO response to Matrix-specific
+user attributes. Details such as user ID localpart, displayname, and even avatar
+URLs are all things that can be mapped from talking to a SSO service.
+
+As an example, a SSO service may return the email address
+"john.smith@example.com" for a user, whereas Synapse will need to figure out how
+to turn that into a displayname when creating a Matrix user for this individual.
+It may choose `John Smith`, or `Smith, John [Example.com]` or any number of
+variations. As each Synapse configuration may want something different, this is
+where SSO mapping providers come into play.
+
+SSO mapping providers are currently supported for OpenID and SAML SSO
+configurations. Please see the details below for how to implement your own.
+
+External mapping providers are provided to Synapse in the form of an external
+Python module. You can retrieve this module from [PyPi](https://pypi.org) or elsewhere,
+but it must be importable via Synapse (e.g. it must be in the same virtualenv
+as Synapse). The Synapse config is then modified to point to the mapping provider
+(and optionally provide additional configuration for it).
+
+## OpenID Mapping Providers
+
+The OpenID mapping provider can be customized by editing the
+`oidc_config.user_mapping_provider.module` config option.
+
+`oidc_config.user_mapping_provider.config` allows you to provide custom
+configuration options to the module. Check with the module's documentation for
+what options it provides (if any). The options listed by default are for the
+user mapping provider built in to Synapse. If using a custom module, you should
+comment these options out and use those specified by the module instead.
+
+### Building a Custom OpenID Mapping Provider
+
+A custom mapping provider must specify the following methods (a minimal sketch
+follows the list):
+
+* `__init__(self, parsed_config)`
+ - Arguments:
+ - `parsed_config` - A configuration object that is the return value of the
+ `parse_config` method. You should set any configuration options needed by
+ the module here.
+* `parse_config(config)`
+ - This method should have the `@staticmethod` decoration.
+ - Arguments:
+ - `config` - A `dict` representing the parsed content of the
+ `oidc_config.user_mapping_provider.config` homeserver config option.
+ Runs on homeserver startup. Providers should extract and validate
+ any option values they need here.
+ - Whatever is returned will be passed back to the user mapping provider module's
+ `__init__` method during construction.
+* `get_remote_user_id(self, userinfo)`
+ - Arguments:
+    - `userinfo` - An `authlib.oidc.core.claims.UserInfo` object to extract user
+ information from.
+ - This method must return a string, which is the unique identifier for the
+    user. Commonly the `sub` claim of the response.
+* `map_user_attributes(self, userinfo, token)`
+ - This method should be async.
+ - Arguments:
+    - `userinfo` - An `authlib.oidc.core.claims.UserInfo` object to extract user
+ information from.
+ - `token` - A dictionary which includes information necessary to make
+ further requests to the OpenID provider.
+ - Returns a dictionary with two keys:
+    - `localpart`: A required string, used to generate the Matrix ID.
+    - `displayname`: An optional string, the display name for the user.
+
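+Purely as an illustration, here is a minimal sketch of a provider implementing
+these methods. The class name and the `localpart_claim` option are
+hypothetical; a real provider would add its own validation and error handling.
+
+```python
+class ExampleOidcMappingProvider:
+    def __init__(self, parsed_config):
+        # parsed_config is whatever parse_config returned below.
+        self._config = parsed_config
+
+    @staticmethod
+    def parse_config(config):
+        # config is the dict from oidc_config.user_mapping_provider.config.
+        return {
+            "localpart_claim": config.get("localpart_claim", "preferred_username"),
+        }
+
+    def get_remote_user_id(self, userinfo):
+        # The unique identifier for the user; usually the "sub" claim.
+        return userinfo["sub"]
+
+    async def map_user_attributes(self, userinfo, token):
+        # Turn the OIDC claims into Matrix-specific attributes.
+        return {
+            "localpart": userinfo[self._config["localpart_claim"]],
+            "displayname": userinfo.get("name"),
+        }
+```
+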
+### Default OpenID Mapping Provider
+
+Synapse has a built-in OpenID mapping provider if a custom provider isn't
+specified in the config. It is located at
+[`synapse.handlers.oidc_handler.JinjaOidcMappingProvider`](../synapse/handlers/oidc_handler.py).
+
+## SAML Mapping Providers
+
+The SAML mapping provider can be customized by editing the
+`saml2_config.user_mapping_provider.module` config option.
+
+`saml2_config.user_mapping_provider.config` allows you to provide custom
+configuration options to the module. Check with the module's documentation for
+what options it provides (if any). The options listed by default are for the
+user mapping provider built in to Synapse. If using a custom module, you should
+comment these options out and use those specified by the module instead.
+
+### Building a Custom SAML Mapping Provider
+
+A custom mapping provider must specify the following methods (a minimal sketch
+follows the list):
+
+* `__init__(self, parsed_config)`
+ - Arguments:
+ - `parsed_config` - A configuration object that is the return value of the
+ `parse_config` method. You should set any configuration options needed by
+ the module here.
+* `parse_config(config)`
+ - This method should have the `@staticmethod` decoration.
+ - Arguments:
+ - `config` - A `dict` representing the parsed content of the
+      `saml2_config.user_mapping_provider.config` homeserver config option.
+ Runs on homeserver startup. Providers should extract and validate
+ any option values they need here.
+ - Whatever is returned will be passed back to the user mapping provider module's
+ `__init__` method during construction.
+* `get_saml_attributes(config)`
+ - This method should have the `@staticmethod` decoration.
+ - Arguments:
+    - `config` - An object resulting from a call to `parse_config`.
+ - Returns a tuple of two sets. The first set equates to the SAML auth
+ response attributes that are required for the module to function, whereas
+ the second set consists of those attributes which can be used if available,
+ but are not necessary.
+* `get_remote_user_id(self, saml_response, client_redirect_url)`
+ - Arguments:
+ - `saml_response` - A `saml2.response.AuthnResponse` object to extract user
+ information from.
+ - `client_redirect_url` - A string, the URL that the client will be
+ redirected to.
+ - This method must return a string, which is the unique identifier for the
+    user. Commonly the `uid` claim of the response.
+* `saml_response_to_user_attributes(self, saml_response, failures, client_redirect_url)`
+ - Arguments:
+ - `saml_response` - A `saml2.response.AuthnResponse` object to extract user
+ information from.
+    - `failures` - An `int` that represents the number of times the returned
+ mxid localpart mapping has failed. This should be used
+ to create a deduplicated mxid localpart which should be
+ returned instead. For example, if this method returns
+ `john.doe` as the value of `mxid_localpart` in the returned
+ dict, and that is already taken on the homeserver, this
+ method will be called again with the same parameters but
+ with failures=1. The method should then return a different
+ `mxid_localpart` value, such as `john.doe1`.
+ - `client_redirect_url` - A string, the URL that the client will be
+ redirected to.
+ - This method must return a dictionary, which will then be used by Synapse
+ to build a new user. The following keys are allowed:
+ * `mxid_localpart` - Required. The mxid localpart of the new user.
+ * `displayname` - The displayname of the new user. If not provided, will default to
+ the value of `mxid_localpart`.
+ * `emails` - A list of emails for the new user. If not provided, will
+ default to an empty list.
+
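+Purely as an illustration, here is a minimal sketch of a SAML provider
+implementing these methods. The class name is hypothetical, and it assumes the
+IdP exposes `uid` and `displayName` attributes.
+
+```python
+class ExampleSamlMappingProvider:
+    def __init__(self, parsed_config):
+        self._config = parsed_config
+
+    @staticmethod
+    def parse_config(config):
+        # Keep the raw dict from saml2_config.user_mapping_provider.config.
+        return config
+
+    @staticmethod
+    def get_saml_attributes(config):
+        # "uid" is required to build the mxid; "displayName" is used if present.
+        return {"uid"}, {"displayName"}
+
+    def get_remote_user_id(self, saml_response, client_redirect_url):
+        # saml_response.ava maps attribute names to lists of values.
+        return saml_response.ava["uid"][0]
+
+    def saml_response_to_user_attributes(
+        self, saml_response, failures, client_redirect_url
+    ):
+        # Append the failure count to deduplicate the localpart if needed.
+        localpart = saml_response.ava["uid"][0] + (str(failures) if failures else "")
+        return {
+            "mxid_localpart": localpart,
+            "displayname": saml_response.ava.get("displayName", [None])[0],
+        }
+```
+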
+### Default SAML Mapping Provider
+
+Synapse has a built-in SAML mapping provider if a custom provider isn't
+specified in the config. It is located at
+[`synapse.handlers.saml_handler.DefaultSamlMappingProvider`](../synapse/handlers/saml_handler.py).
diff --git a/docs/systemd-with-workers/system/matrix-synapse-worker@.service b/docs/systemd-with-workers/system/matrix-synapse-worker@.service
index 70589a7a..39bc5e88 100644
--- a/docs/systemd-with-workers/system/matrix-synapse-worker@.service
+++ b/docs/systemd-with-workers/system/matrix-synapse-worker@.service
@@ -1,6 +1,6 @@
[Unit]
Description=Synapse %i
-
+AssertPathExists=/etc/matrix-synapse/workers/%i.yaml
# This service should be restarted when the synapse target is restarted.
PartOf=matrix-synapse.target
diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md
index ab2fffbf..db318baa 100644
--- a/docs/tcp_replication.md
+++ b/docs/tcp_replication.md
@@ -219,10 +219,6 @@ Asks the server for the current position of all streams.
Inform the server a pusher should be removed
-#### INVALIDATE_CACHE (C)
-
- Inform the server a cache should be invalidated
-
### REMOTE_SERVER_UP (S, C)
Inform other processes that a remote server may have come back online.
diff --git a/docs/turn-howto.md b/docs/turn-howto.md
index b8a2ba3e..d4a726be 100644
--- a/docs/turn-howto.md
+++ b/docs/turn-howto.md
@@ -18,7 +18,7 @@ For TURN relaying with `coturn` to work, it must be hosted on a server/endpoint
Hosting TURN behind a NAT (even with appropriate port forwarding) is known to cause issues
and to often not work.
-## `coturn` Setup
+## `coturn` setup
### Initial installation
@@ -26,7 +26,13 @@ The TURN daemon `coturn` is available from a variety of sources such as native p
#### Debian installation
- # apt install coturn
+Just install the Debian package:
+
+```sh
+apt install coturn
+```
+
+This will install and start a systemd service called `coturn`.
#### Source installation
@@ -63,38 +69,52 @@ The TURN daemon `coturn` is available from a variety of sources such as native p
1. Consider your security settings. TURN lets users request a relay which will
connect to arbitrary IP addresses and ports. The following configuration is
suggested as a minimum starting point:
-
+
# VoIP traffic is all UDP. There is no reason to let users connect to arbitrary TCP endpoints via the relay.
no-tcp-relay
-
+
# don't let the relay ever try to connect to private IP address ranges within your network (if any)
# given the turn server is likely behind your firewall, remember to include any privileged public IPs too.
denied-peer-ip=10.0.0.0-10.255.255.255
denied-peer-ip=192.168.0.0-192.168.255.255
denied-peer-ip=172.16.0.0-172.31.255.255
-
+
# special case the turn server itself so that client->TURN->TURN->client flows work
allowed-peer-ip=10.0.0.1
-
+
# consider whether you want to limit the quota of relayed streams per user (or total) to avoid risk of DoS.
user-quota=12 # 4 streams per video call, so 12 streams = 3 simultaneous relayed calls per user.
total-quota=1200
- Ideally coturn should refuse to relay traffic which isn't SRTP; see
- <https://github.com/matrix-org/synapse/issues/2009>
+1. Also consider supporting TLS/DTLS. To do this, add the following settings
+ to `turnserver.conf`:
+
+ # TLS certificates, including intermediate certs.
+ # For Let's Encrypt certificates, use `fullchain.pem` here.
+ cert=/path/to/fullchain.pem
+
+ # TLS private key file
+ pkey=/path/to/privkey.pem
1. Ensure your firewall allows traffic into the TURN server on the ports
- you've configured it to listen on (remember to allow both TCP and UDP TURN
- traffic)
+   you've configured it to listen on (by default: 3478 and 5349 for TURN(S)
+   traffic, for which you should allow both TCP and UDP, and ports 49152-65535
+   for the UDP relay).
+
+1. (Re)start the turn server:
-1. If you've configured coturn to support TLS/DTLS, generate or import your
- private key and certificate.
+ * If you used the Debian package (or have set up a systemd unit yourself):
+ ```sh
+ systemctl restart coturn
+ ```
-1. Start the turn server:
+ * If you installed from source:
- bin/turnserver -o
+ ```sh
+ bin/turnserver -o
+ ```
-## synapse Setup
+## Synapse setup
Your home server configuration file needs the following extra keys:
@@ -126,7 +146,14 @@ As an example, here is the relevant section of the config file for matrix.org:
After updating the homeserver configuration, you must restart synapse:
+ * If you use synctl:
+ ```sh
cd /where/you/run/synapse
./synctl restart
+ ```
+ * If you use systemd:
+   ```sh
+ systemctl restart synapse.service
+ ```
..and your Home Server now supports VoIP relaying!
diff --git a/mypy.ini b/mypy.ini
index 69be2f67..3533797d 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -75,3 +75,6 @@ ignore_missing_imports = True
[mypy-jwt.*]
ignore_missing_imports = True
+
+[mypy-authlib.*]
+ignore_missing_imports = True
diff --git a/scripts-dev/build_debian_packages b/scripts-dev/build_debian_packages
index ae2145d7..e6f4bd1d 100755
--- a/scripts-dev/build_debian_packages
+++ b/scripts-dev/build_debian_packages
@@ -24,8 +24,6 @@ DISTS = (
"debian:sid",
"ubuntu:xenial",
"ubuntu:bionic",
- "ubuntu:cosmic",
- "ubuntu:disco",
"ubuntu:eoan",
"ubuntu:focal",
)
diff --git a/scripts-dev/check-newsfragment b/scripts-dev/check-newsfragment
index 0ec5075e..98a618f6 100755
--- a/scripts-dev/check-newsfragment
+++ b/scripts-dev/check-newsfragment
@@ -7,7 +7,9 @@ set -e
# make sure that origin/develop is up to date
git remote set-branches --add origin develop
-git fetch origin develop
+git fetch -q origin develop
+
+pr="$BUILDKITE_PULL_REQUEST"
# if there are changes in the debian directory, check that the debian changelog
# has been updated
@@ -20,20 +22,30 @@ fi
# if there are changes *outside* the debian directory, check that the
# newsfragments have been updated.
-if git diff --name-only FETCH_HEAD... | grep -qv '^debian/'; then
- tox -e check-newsfragment
+if ! git diff --name-only FETCH_HEAD... | grep -qv '^debian/'; then
+ exit 0
fi
+tox -qe check-newsfragment
+
echo
echo "--------------------------"
echo
-# check that any new newsfiles on this branch end with a full stop.
+matched=0
for f in `git diff --name-only FETCH_HEAD... -- changelog.d`; do
+ # check that any modified newsfiles on this branch end with a full stop.
lastchar=`tr -d '\n' < $f | tail -c 1`
if [ $lastchar != '.' -a $lastchar != '!' ]; then
echo -e "\e[31mERROR: newsfragment $f does not end with a '.' or '!'\e[39m" >&2
exit 1
fi
+
+ # see if this newsfile corresponds to the right PR
+ [[ -n "$pr" && "$f" == changelog.d/"$pr".* ]] && matched=1
done
+if [[ -n "$pr" && "$matched" -eq 0 ]]; then
+ echo -e "\e[31mERROR: Did not find a news fragment with the right number: expected changelog.d/$pr.*.\e[39m" >&2
+ exit 1
+fi
diff --git a/scripts-dev/convert_server_keys.py b/scripts-dev/convert_server_keys.py
index 06b4c1e2..961dc59f 100644
--- a/scripts-dev/convert_server_keys.py
+++ b/scripts-dev/convert_server_keys.py
@@ -3,8 +3,6 @@ import json
import sys
import time
-import six
-
import psycopg2
import yaml
from canonicaljson import encode_canonical_json
@@ -12,10 +10,7 @@ from signedjson.key import read_signing_keys
from signedjson.sign import sign_json
from unpaddedbase64 import encode_base64
-if six.PY2:
- db_type = six.moves.builtins.buffer
-else:
- db_type = memoryview
+db_binary_type = memoryview
def select_v1_keys(connection):
@@ -72,7 +67,7 @@ def rows_v2(server, json):
valid_until = json["valid_until_ts"]
key_json = encode_canonical_json(json)
for key_id in json["verify_keys"]:
- yield (server, key_id, "-", valid_until, valid_until, db_type(key_json))
+ yield (server, key_id, "-", valid_until, valid_until, db_binary_type(key_json))
def main():
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index e8b698f3..9a0fbc61 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -122,7 +122,7 @@ APPEND_ONLY_TABLES = [
"presence_stream",
"push_rules_stream",
"ex_outlier_stream",
- "cache_invalidation_stream",
+ "cache_invalidation_stream_by_instance",
"public_room_list_stream",
"state_group_edges",
"stream_ordering_to_exterm",
@@ -188,7 +188,7 @@ class MockHomeserver:
self.clock = Clock(reactor)
self.config = config
self.hostname = config.server_name
- self.version_string = "Synapse/"+get_version_string(synapse)
+ self.version_string = "Synapse/" + get_version_string(synapse)
def get_clock(self):
return self.clock
@@ -196,6 +196,9 @@ class MockHomeserver:
def get_reactor(self):
return reactor
+ def get_instance_name(self):
+ return "master"
+
class Porter(object):
def __init__(self, **kwargs):
diff --git a/setup.py b/setup.py
index 5ce06c89..54ddec8f 100755
--- a/setup.py
+++ b/setup.py
@@ -114,6 +114,7 @@ setup(
"Programming Language :: Python :: 3.5",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
+ "Programming Language :: Python :: 3.8",
],
scripts=["synctl"] + glob.glob("scripts/*"),
cmdclass={"test": TestCommand},
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 0abf4911..1d9d85a7 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -36,7 +36,7 @@ try:
except ImportError:
pass
-__version__ = "1.13.0"
+__version__ = "1.15.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 1ad5ff94..06ade256 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -22,6 +22,7 @@ import pymacaroons
from netaddr import IPAddress
from twisted.internet import defer
+from twisted.web.server import Request
import synapse.logging.opentracing as opentracing
import synapse.types
@@ -37,7 +38,7 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.types import StateMap, UserID
-from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
+from synapse.util.caches import register_cache
from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import Measure
@@ -73,7 +74,7 @@ class Auth(object):
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
- self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
+ self.token_cache = LruCache(10000)
register_cache("cache", "token_cache", self.token_cache)
self._auth_blocking = AuthBlocking(self.hs)
@@ -162,19 +163,25 @@ class Auth(object):
@defer.inlineCallbacks
def get_user_by_req(
- self, request, allow_guest=False, rights="access", allow_expired=False
+ self,
+ request: Request,
+ allow_guest: bool = False,
+ rights: str = "access",
+ allow_expired: bool = False,
):
""" Get a registered user's ID.
Args:
- request - An HTTP request with an access_token query parameter.
- allow_expired - Whether to allow the request through even if the account is
- expired. If true, Synapse will still require an access token to be
- provided but won't check if the account it belongs to has expired. This
- works thanks to /login delivering access tokens regardless of accounts'
- expiration.
+ request: An HTTP request with an access_token query parameter.
+ allow_guest: If False, will raise an AuthError if the user making the
+ request is a guest.
+ rights: The operation being performed; the access token must allow this
+ allow_expired: If True, allow the request through even if the account
+ is expired, or session token lifetime has ended. Note that
+ /login will deliver access tokens regardless of expiration.
+
Returns:
- defer.Deferred: resolves to a ``synapse.types.Requester`` object
+ defer.Deferred: resolves to a `synapse.types.Requester` object
Raises:
InvalidClientCredentialsError if no user by that token exists or the token
is invalid.
@@ -205,7 +212,9 @@ class Auth(object):
return synapse.types.create_requester(user_id, app_service=app_service)
- user_info = yield self.get_user_by_access_token(access_token, rights)
+ user_info = yield self.get_user_by_access_token(
+ access_token, rights, allow_expired=allow_expired
+ )
user = user_info["user"]
token_id = user_info["token_id"]
is_guest = user_info["is_guest"]
@@ -280,13 +289,17 @@ class Auth(object):
return user_id, app_service
@defer.inlineCallbacks
- def get_user_by_access_token(self, token, rights="access"):
+ def get_user_by_access_token(
+ self, token: str, rights: str = "access", allow_expired: bool = False,
+ ):
""" Validate access token and get user_id from it
Args:
- token (str): The access token to get the user by.
- rights (str): The operation being performed; the access token must
- allow this.
+ token: The access token to get the user by
+ rights: The operation being performed; the access token must
+ allow this
+ allow_expired: If False, raises an InvalidClientTokenError
+ if the token is expired
Returns:
Deferred[dict]: dict that includes:
`user` (UserID)
@@ -294,8 +307,10 @@ class Auth(object):
`token_id` (int|None): access token id. May be None if guest
`device_id` (str|None): device corresponding to access token
Raises:
+ InvalidClientTokenError if a user by that token exists, but the token is
+ expired
InvalidClientCredentialsError if no user by that token exists or the token
- is invalid.
+ is invalid
"""
if rights == "access":
@@ -304,7 +319,8 @@ class Auth(object):
if r:
valid_until_ms = r["valid_until_ms"]
if (
- valid_until_ms is not None
+ not allow_expired
+ and valid_until_ms is not None
and valid_until_ms < self.clock.time_msec()
):
# there was a valid access token, but it has expired.
@@ -494,16 +510,16 @@ class Auth(object):
request.authenticated_entity = service.sender
return defer.succeed(service)
- def is_server_admin(self, user):
+ async def is_server_admin(self, user: UserID) -> bool:
""" Check if the given user is a local server admin.
Args:
- user (UserID): user to check
+ user: user to check
Returns:
- bool: True if the user is an admin
+ True if the user is an admin
"""
- return self.store.is_server_admin(user)
+ return await self.store.is_server_admin(user)
def compute_auth_events(
self, event, current_state_ids: StateMap[str], for_verification: bool = False,
@@ -575,7 +591,7 @@ class Auth(object):
return user_level >= send_level
@staticmethod
- def has_access_token(request):
+ def has_access_token(request: Request):
"""Checks if the request has an access_token.
Returns:
@@ -586,7 +602,7 @@ class Auth(object):
return bool(query_params) or bool(auth_headers)
@staticmethod
- def get_access_token_from_request(request):
+ def get_access_token_from_request(request: Request):
"""Extracts the access_token from the request.
Args:
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index bcaf2c36..5ec4a77c 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -61,13 +61,9 @@ class LoginType(object):
MSISDN = "m.login.msisdn"
RECAPTCHA = "m.login.recaptcha"
TERMS = "m.login.terms"
- SSO = "org.matrix.login.sso"
+ SSO = "m.login.sso"
DUMMY = "m.login.dummy"
- # Only for C/S API v1
- APPLICATION_SERVICE = "m.login.application_service"
- SHARED_SECRET = "org.matrix.login.shared_secret"
-
class EventTypes(object):
Member = "m.room.member"
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 7a049b3a..ec6b3a69 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -1,4 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,75 +17,157 @@ from collections import OrderedDict
from typing import Any, Optional, Tuple
from synapse.api.errors import LimitExceededError
+from synapse.util import Clock
class Ratelimiter(object):
"""
- Ratelimit message sending by user.
+ Ratelimit actions marked by arbitrary keys.
+
+ Args:
+ clock: A homeserver clock, for retrieving the current time
+ rate_hz: The long term number of actions that can be performed in a second.
+        burst_count: How many actions can be performed before being limited.
"""
- def __init__(self):
- self.message_counts = (
- OrderedDict()
- ) # type: OrderedDict[Any, Tuple[float, int, Optional[float]]]
+ def __init__(self, clock: Clock, rate_hz: float, burst_count: int):
+ self.clock = clock
+ self.rate_hz = rate_hz
+ self.burst_count = burst_count
+
+        # An ordered dictionary keeping track of actions, when they were last
+ # performed and how often. Each entry is a mapping from a key of arbitrary type
+ # to a tuple representing:
+ # * How many times an action has occurred since a point in time
+ # * The point in time
+ # * The rate_hz of this particular entry. This can vary per request
+ self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]]
- def can_do_action(self, key, time_now_s, rate_hz, burst_count, update=True):
+ def can_do_action(
+ self,
+ key: Any,
+ rate_hz: Optional[float] = None,
+ burst_count: Optional[int] = None,
+ update: bool = True,
+ _time_now_s: Optional[int] = None,
+ ) -> Tuple[bool, float]:
"""Can the entity (e.g. user or IP address) perform the action?
+
Args:
key: The key we should use when rate limiting. Can be a user ID
(when sending events), an IP address, etc.
- time_now_s: The time now.
- rate_hz: The long term number of messages a user can send in a
- second.
- burst_count: How many messages the user can send before being
- limited.
- update (bool): Whether to update the message rates or not. This is
- useful to check if a message would be allowed to be sent before
- its ready to be actually sent.
+ rate_hz: The long term number of actions that can be performed in a second.
+ Overrides the value set during instantiation if set.
+            burst_count: How many actions can be performed before being limited.
+ Overrides the value set during instantiation if set.
+ update: Whether to count this check as performing the action
+ _time_now_s: The current time. Optional, defaults to the current time according
+ to self.clock. Only used by tests.
+
Returns:
- A pair of a bool indicating if they can send a message now and a
- time in seconds of when they can next send a message.
+ A tuple containing:
+ * A bool indicating if they can perform the action now
+ * The reactor timestamp for when the action can be performed next.
+ -1 if rate_hz is less than or equal to zero
"""
- self.prune_message_counts(time_now_s)
- message_count, time_start, _ignored = self.message_counts.get(
- key, (0.0, time_now_s, None)
- )
+ # Override default values if set
+ time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
+ rate_hz = rate_hz if rate_hz is not None else self.rate_hz
+ burst_count = burst_count if burst_count is not None else self.burst_count
+
+ # Remove any expired entries
+ self._prune_message_counts(time_now_s)
+
+ # Check if there is an existing count entry for this key
+ action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, 0.0))
+
+ # Check whether performing another action is allowed
time_delta = time_now_s - time_start
- sent_count = message_count - time_delta * rate_hz
- if sent_count < 0:
+ performed_count = action_count - time_delta * rate_hz
+ if performed_count < 0:
+ # Allow, reset back to count 1
allowed = True
time_start = time_now_s
- message_count = 1.0
- elif sent_count > burst_count - 1.0:
+ action_count = 1.0
+ elif performed_count > burst_count - 1.0:
+ # Deny, we have exceeded our burst count
allowed = False
else:
+ # We haven't reached our limit yet
allowed = True
- message_count += 1
+ action_count += 1.0
if update:
- self.message_counts[key] = (message_count, time_start, rate_hz)
+ self.actions[key] = (action_count, time_start, rate_hz)
if rate_hz > 0:
- time_allowed = time_start + (message_count - burst_count + 1) / rate_hz
+ # Find out when the count of existing actions expires
+ time_allowed = time_start + (action_count - burst_count + 1) / rate_hz
+
+ # Don't give back a time in the past
if time_allowed < time_now_s:
time_allowed = time_now_s
+
else:
+ # XXX: Why is this -1? This seems to only be used in
+ # self.ratelimit. I guess so that clients get a time in the past and don't
+ # feel afraid to try again immediately
time_allowed = -1
return allowed, time_allowed
- def prune_message_counts(self, time_now_s):
- for key in list(self.message_counts.keys()):
- message_count, time_start, rate_hz = self.message_counts[key]
+ def _prune_message_counts(self, time_now_s: int):
+ """Remove message count entries that have not exceeded their defined
+ rate_hz limit
+
+ Args:
+ time_now_s: The current time
+ """
+ # We create a copy of the key list here as the dictionary is modified during
+ # the loop
+ for key in list(self.actions.keys()):
+ action_count, time_start, rate_hz = self.actions[key]
+
+ # Rate limit = "seconds since we started limiting this action" * rate_hz
+ # If this limit has not been exceeded, wipe our record of this action
time_delta = time_now_s - time_start
- if message_count - time_delta * rate_hz > 0:
- break
+ if action_count - time_delta * rate_hz > 0:
+ continue
else:
- del self.message_counts[key]
+ del self.actions[key]
+
+ def ratelimit(
+ self,
+ key: Any,
+ rate_hz: Optional[float] = None,
+ burst_count: Optional[int] = None,
+ update: bool = True,
+ _time_now_s: Optional[int] = None,
+ ):
+ """Checks if an action can be performed. If not, raises a LimitExceededError
+
+ Args:
+ key: An arbitrary key used to classify an action
+ rate_hz: The long term number of actions that can be performed in a second.
+ Overrides the value set during instantiation if set.
+            burst_count: How many actions can be performed before being limited.
+ Overrides the value set during instantiation if set.
+ update: Whether to count this check as performing the action
+ _time_now_s: The current time. Optional, defaults to the current time according
+ to self.clock. Only used by tests.
+
+ Raises:
+ LimitExceededError: If an action could not be performed, along with the time in
+ milliseconds until the action can be performed again
+ """
+ time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
- def ratelimit(self, key, time_now_s, rate_hz, burst_count, update=True):
allowed, time_allowed = self.can_do_action(
- key, time_now_s, rate_hz, burst_count, update
+ key,
+ rate_hz=rate_hz,
+ burst_count=burst_count,
+ update=update,
+ _time_now_s=time_now_s,
)
if not allowed:
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index 87117974..d7baf2bc 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -58,7 +58,15 @@ class RoomVersion(object):
enforce_key_validity = attr.ib() # bool
# bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules
- special_case_aliases_auth = attr.ib(type=bool, default=False)
+ special_case_aliases_auth = attr.ib(type=bool)
+ # Strictly enforce canonicaljson, do not allow:
+ # * Integers outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1]
+ # * Floats
+ # * NaN, Infinity, -Infinity
+ strict_canonicaljson = attr.ib(type=bool)
+ # bool: MSC2209: Check 'notifications' key while verifying
+ # m.room.power_levels auth rules.
+ limit_notifications_power_levels = attr.ib(type=bool)
class RoomVersions(object):
@@ -69,6 +77,8 @@ class RoomVersions(object):
StateResolutionVersions.V1,
enforce_key_validity=False,
special_case_aliases_auth=True,
+ strict_canonicaljson=False,
+ limit_notifications_power_levels=False,
)
V2 = RoomVersion(
"2",
@@ -77,6 +87,8 @@ class RoomVersions(object):
StateResolutionVersions.V2,
enforce_key_validity=False,
special_case_aliases_auth=True,
+ strict_canonicaljson=False,
+ limit_notifications_power_levels=False,
)
V3 = RoomVersion(
"3",
@@ -85,6 +97,8 @@ class RoomVersions(object):
StateResolutionVersions.V2,
enforce_key_validity=False,
special_case_aliases_auth=True,
+ strict_canonicaljson=False,
+ limit_notifications_power_levels=False,
)
V4 = RoomVersion(
"4",
@@ -93,6 +107,8 @@ class RoomVersions(object):
StateResolutionVersions.V2,
enforce_key_validity=False,
special_case_aliases_auth=True,
+ strict_canonicaljson=False,
+ limit_notifications_power_levels=False,
)
V5 = RoomVersion(
"5",
@@ -101,14 +117,18 @@ class RoomVersions(object):
StateResolutionVersions.V2,
enforce_key_validity=True,
special_case_aliases_auth=True,
+ strict_canonicaljson=False,
+ limit_notifications_power_levels=False,
)
- MSC2432_DEV = RoomVersion(
- "org.matrix.msc2432",
- RoomDisposition.UNSTABLE,
+ V6 = RoomVersion(
+ "6",
+ RoomDisposition.STABLE,
EventFormatVersions.V3,
StateResolutionVersions.V2,
enforce_key_validity=True,
special_case_aliases_auth=False,
+ strict_canonicaljson=True,
+ limit_notifications_power_levels=True,
)
@@ -120,6 +140,6 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V3,
RoomVersions.V4,
RoomVersions.V5,
- RoomVersions.MSC2432_DEV,
+ RoomVersions.V6,
)
} # type: Dict[str, RoomVersion]
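
Room version 6's strict_canonicaljson flag refers to the value constraints listed in the comment above. A self-contained sketch of such a check (illustrative only, not the validation code Synapse itself uses):

CANONICALJSON_MAX_INT = 2 ** 53 - 1
CANONICALJSON_MIN_INT = -(2 ** 53) + 1

def is_strictly_canonical(value):
    # bool must be tested before int, since bool is a subclass of int
    if isinstance(value, bool):
        return True
    if isinstance(value, float):
        return False  # floats, including NaN, Infinity and -Infinity, are rejected
    if isinstance(value, int):
        return CANONICALJSON_MIN_INT <= value <= CANONICALJSON_MAX_INT
    if isinstance(value, dict):
        return all(is_strictly_canonical(v) for v in value.values())
    if isinstance(value, (list, tuple)):
        return all(is_strictly_canonical(v) for v in value)
    return True  # strings, None

print(is_strictly_canonical({"depth": 2 ** 60}))  # False
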
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 667ad204..f3ec2a34 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -17,17 +17,15 @@
import contextlib
import logging
import sys
-from typing import Dict, Iterable
+from typing import Dict, Iterable, Optional, Set
from typing_extensions import ContextManager
from twisted.internet import defer, reactor
-from twisted.web.resource import NoResource
import synapse
import synapse.events
-from synapse.api.constants import EventTypes
-from synapse.api.errors import HttpResponseException, SynapseError
+from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.api.urls import (
CLIENT_API_PREFIX,
FEDERATION_PREFIX,
@@ -41,13 +39,22 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.federation import send_queue
from synapse.federation.transport.server import TransportLayerServer
-from synapse.handlers.presence import BasePresenceHandler, get_interested_parties
-from synapse.http.server import JsonResource
+from synapse.handlers.presence import (
+ BasePresenceHandler,
+ PresenceState,
+ get_interested_parties,
+)
+from synapse.http.server import JsonResource, OptionsResource
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseSite
from synapse.logging.context import LoggingContext
from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
+from synapse.replication.http.presence import (
+ ReplicationBumpPresenceActiveTime,
+ ReplicationPresenceSetState,
+)
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
@@ -81,11 +88,6 @@ from synapse.replication.tcp.streams import (
ToDeviceStream,
TypingStream,
)
-from synapse.replication.tcp.streams.events import (
- EventsStream,
- EventsStreamEventRow,
- EventsStreamRow,
-)
from synapse.rest.admin import register_servlets_for_media_repo
from synapse.rest.client.v1 import events
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
@@ -122,11 +124,13 @@ from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.rest.client.versions import VersionsRestServlet
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.server import HomeServer
+from synapse.storage.data_stores.main.censor_events import CensorEventsStore
from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore
from synapse.storage.data_stores.main.monthly_active_users import (
MonthlyActiveUsersWorkerStore,
)
from synapse.storage.data_stores.main.presence import UserPresenceState
+from synapse.storage.data_stores.main.search import SearchWorkerStore
from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
from synapse.types import ReadReceipt
@@ -140,31 +144,18 @@ logger = logging.getLogger("synapse.app.generic_worker")
class PresenceStatusStubServlet(RestServlet):
"""If presence is disabled this servlet can be used to stub out setting
- presence status, while proxying the getters to the master instance.
+ presence status.
"""
PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status")
def __init__(self, hs):
super(PresenceStatusStubServlet, self).__init__()
- self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
- self.main_uri = hs.config.worker_main_http_uri
async def on_GET(self, request, user_id):
- # Pass through the auth headers, if any, in case the access token
- # is there.
- auth_headers = request.requestHeaders.getRawHeaders("Authorization", [])
- headers = {"Authorization": auth_headers}
-
- try:
- result = await self.http_client.get_json(
- self.main_uri + request.uri.decode("ascii"), headers=headers
- )
- except HttpResponseException as e:
- raise e.to_synapse_error()
-
- return 200, result
+ await self.auth.get_user_by_req(request)
+ return 200, {"presence": "offline"}
async def on_PUT(self, request, user_id):
await self.auth.get_user_by_req(request)
@@ -218,9 +209,14 @@ class KeyUploadServlet(RestServlet):
# is there.
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", [])
headers = {"Authorization": auth_headers}
- result = await self.http_client.post_json_get_json(
- self.main_uri + request.uri.decode("ascii"), body, headers=headers
- )
+ try:
+ result = await self.http_client.post_json_get_json(
+ self.main_uri + request.uri.decode("ascii"), body, headers=headers
+ )
+ except HttpResponseException as e:
+ raise e.to_synapse_error() from e
+ except RequestSendFailed as e:
+ raise SynapseError(502, "Failed to talk to master") from e
return 200, result
else:
@@ -259,6 +255,9 @@ class GenericWorkerPresence(BasePresenceHandler):
# but we haven't notified the master of that yet
self.users_going_offline = {}
+ self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs)
+ self._set_state_client = ReplicationPresenceSetState.make_client(hs)
+
self._send_stop_syncing_loop = self.clock.looping_call(
self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
)
@@ -316,10 +315,6 @@ class GenericWorkerPresence(BasePresenceHandler):
self.users_going_offline.pop(user_id, None)
self.send_user_sync(user_id, False, last_sync_ms)
- def set_state(self, user, state, ignore_status_msg=False):
- # TODO Hows this supposed to work?
- return defer.succeed(None)
-
async def user_syncing(
self, user_id: str, affect_presence: bool
) -> ContextManager[None]:
@@ -398,6 +393,42 @@ class GenericWorkerPresence(BasePresenceHandler):
if count > 0
]
+ async def set_state(self, target_user, state, ignore_status_msg=False):
+ """Set the presence state of the user.
+ """
+ presence = state["presence"]
+
+ valid_presence = (
+ PresenceState.ONLINE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.OFFLINE,
+ )
+ if presence not in valid_presence:
+ raise SynapseError(400, "Invalid presence state")
+
+ user_id = target_user.to_string()
+
+ # If presence is disabled, no-op
+ if not self.hs.config.use_presence:
+ return
+
+ # Proxy request to master
+ await self._set_state_client(
+ user_id=user_id, state=state, ignore_status_msg=ignore_status_msg
+ )
+
+ async def bump_presence_active_time(self, user):
+ """We've seen the user do something that indicates they're interacting
+ with the app.
+ """
+ # If presence is disabled, no-op
+ if not self.hs.config.use_presence:
+ return
+
+ # Proxy request to master
+ user_id = user.to_string()
+ await self._bump_active_client(user_id=user_id)
+
class GenericWorkerTyping(object):
def __init__(self, hs):
@@ -442,6 +473,7 @@ class GenericWorkerSlavedStore(
SlavedGroupServerStore,
SlavedAccountDataStore,
SlavedPusherStore,
+ CensorEventsStore,
SlavedEventStore,
SlavedKeyStore,
RoomStore,
@@ -455,6 +487,7 @@ class GenericWorkerSlavedStore(
SlavedFilteringStore,
MonthlyActiveUsersWorkerStore,
MediaRepositoryStore,
+ SearchWorkerStore,
BaseSlavedStore,
):
def __init__(self, database, db_conn, hs):
@@ -572,7 +605,10 @@ class GenericWorkerServer(HomeServer):
if name in ["keys", "federation"]:
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
- root_resource = create_resource_tree(resources, NoResource())
+ if name == "replication":
+ resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
+
+ root_resource = create_resource_tree(resources, OptionsResource())
_base.listen_tcp(
bind_addresses,
@@ -631,7 +667,7 @@ class GenericWorkerServer(HomeServer):
class GenericWorkerReplicationHandler(ReplicationDataHandler):
def __init__(self, hs):
- super(GenericWorkerReplicationHandler, self).__init__(hs.get_datastore())
+ super(GenericWorkerReplicationHandler, self).__init__(hs)
self.store = hs.get_datastore()
self.typing_handler = hs.get_typing_handler()
@@ -641,10 +677,9 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
self.notify_pushers = hs.config.start_pushers
self.pusher_pool = hs.get_pusherpool()
+ self.send_handler = None # type: Optional[FederationSenderHandler]
if hs.config.send_federation:
- self.send_handler = FederationSenderHandler(hs, self)
- else:
- self.send_handler = None
+ self.send_handler = FederationSenderHandler(hs)
async def on_rdata(self, stream_name, instance_name, token, rows):
await super().on_rdata(stream_name, instance_name, token, rows)
@@ -657,30 +692,7 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
stream_name, token, rows
)
- if stream_name == EventsStream.NAME:
- # We shouldn't get multiple rows per token for events stream, so
- # we don't need to optimise this for multiple rows.
- for row in rows:
- if row.type != EventsStreamEventRow.TypeId:
- continue
- assert isinstance(row, EventsStreamRow)
-
- event = await self.store.get_event(
- row.data.event_id, allow_rejected=True
- )
- if event.rejected_reason:
- continue
-
- extra_users = ()
- if event.type == EventTypes.Member:
- extra_users = (event.state_key,)
- max_token = self.store.get_room_max_stream_ordering()
- self.notifier.on_new_room_event(
- event, token, max_token, extra_users
- )
-
- await self.pusher_pool.on_new_notifications(token, token)
- elif stream_name == PushRulesStream.NAME:
+ if stream_name == PushRulesStream.NAME:
self.notifier.on_new_event(
"push_rules_key", token, users=[row.user_id for row in rows]
)
@@ -705,7 +717,7 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
if entities:
self.notifier.on_new_event("to_device_key", token, users=entities)
elif stream_name == DeviceListsStream.NAME:
- all_room_ids = set()
+ all_room_ids = set() # type: Set[str]
for row in rows:
if row.entity.startswith("@"):
room_ids = await self.store.get_rooms_for_user(row.entity)
@@ -756,24 +768,33 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
class FederationSenderHandler(object):
- """Processes the replication stream and forwards the appropriate entries
- to the federation sender.
+ """Processes the fedration replication stream
+
+ This class is only instantiate on the worker responsible for sending outbound
+ federation transactions. It receives rows from the replication stream and forwards
+ the appropriate entries to the FederationSender class.
"""
- def __init__(self, hs: GenericWorkerServer, replication_client):
+ def __init__(self, hs: GenericWorkerServer):
self.store = hs.get_datastore()
self._is_mine_id = hs.is_mine_id
self.federation_sender = hs.get_federation_sender()
- self.replication_client = replication_client
-
+ self._hs = hs
+
+ # if the worker is restarted, we want to pick up where we left off in
+ # the replication stream, so load the position from the database.
+ #
+ # XXX is this actually worthwhile? Whenever the master is restarted, we'll
+ # drop some rows anyway (which is mostly fine because we're only dropping
+ # typing and presence notifications). If the replication stream is
+ # unreliable, why do we do all this hoop-jumping to store the position in the
+ # database? See also https://github.com/matrix-org/synapse/issues/7535.
+ #
self.federation_position = self.store.federation_out_pos_startup
- self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
+ self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
self._last_ack = self.federation_position
- self._room_serials = {}
- self._room_typing = {}
-
def on_start(self):
# There may be some events that are persisted but haven't been sent,
# so send them now.
@@ -836,22 +857,51 @@ class FederationSenderHandler(object):
await self.federation_sender.send_read_receipt(receipt_info)
async def update_token(self, token):
- try:
- self.federation_position = token
+ """Update the record of where we have processed to in the federation stream.
+
+ Called after we have processed an update received over replication. Sends
+ a FEDERATION_ACK back to the master, and stores the token that we have processed
+ in `federation_stream_position` so that we can restart where we left off.
+ """
+ self.federation_position = token
+
+ # We save and send the ACK to master asynchronously, so we don't block
+ # processing on persistence. We don't need to do this operation for
+ # every single RDATA we receive, we just need to do it periodically.
+
+ if self._fed_position_linearizer.is_queued(None):
+ # There is already a task queued up to save and send the token, so
+ # no need to queue up another task.
+ return
+
+ run_as_background_process("_save_and_send_ack", self._save_and_send_ack)
+ async def _save_and_send_ack(self):
+ """Save the current federation position in the database and send an ACK
+ to master with where we're up to.
+ """
+ try:
# We linearize here to ensure we don't have races updating the token
+ #
+ # XXX this appears to be redundant, since the ReplicationCommandHandler
+ # has a linearizer which ensures that we only process one line of
+ # replication data at a time. Should we remove it, or is it doing useful
+ # service for robustness? Or could we replace it with an assertion that
+ # we're not being re-entered?
+
with (await self._fed_position_linearizer.queue(None)):
- if self._last_ack < self.federation_position:
- await self.store.update_federation_out_pos(
- "federation", self.federation_position
- )
+ # We persist and ack the same position, so we take a copy of it
+ # here as otherwise it can get modified from underneath us.
+ current_position = self.federation_position
- # We ACK this token over replication so that the master can drop
- # its in memory queues
- self.replication_client.send_federation_ack(
- self.federation_position
- )
- self._last_ack = self.federation_position
+ await self.store.update_federation_out_pos(
+ "federation", current_position
+ )
+
+ # We ACK this token over replication so that the master can drop
+ # its in memory queues
+ self._hs.get_tcp_replication().send_federation_ack(current_position)
+ self._last_ack = current_position
except Exception:
logger.exception("Error updating federation stream position")
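
A minimal stand-in (plain asyncio, not Synapse's Linearizer and background-process machinery) for the debounced save-and-ack pattern described in the comments above: record the latest position synchronously, and keep at most one task persisting it and acking the master.

import asyncio

class PositionTracker:
    def __init__(self, persist, ack):
        self._persist = persist  # async callable: store the position in the database
        self._ack = ack          # async callable: send the ACK over replication
        self._position = None
        self._task = None

    def update_token(self, token):
        # meant to be called from code already running inside the event loop
        self._position = token
        if self._task is None or self._task.done():
            # at most one save-and-ack task in flight at a time
            self._task = asyncio.ensure_future(self._save_and_send_ack())

    async def _save_and_send_ack(self):
        # take a copy: self._position may move on while we are persisting
        current = self._position
        await self._persist(current)
        await self._ack(current)
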
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index cbd1ea47..8454d748 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -31,7 +31,7 @@ from prometheus_client import Gauge
from twisted.application import service
from twisted.internet import defer, reactor
from twisted.python.failure import Failure
-from twisted.web.resource import EncodingResourceWrapper, IResource, NoResource
+from twisted.web.resource import EncodingResourceWrapper, IResource
from twisted.web.server import GzipEncoderFactory
from twisted.web.static import File
@@ -52,7 +52,11 @@ from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.federation.transport.server import TransportLayerServer
from synapse.http.additional_resource import AdditionalResource
-from synapse.http.server import RootRedirect
+from synapse.http.server import (
+ OptionsResource,
+ RootOptionsRedirectResource,
+ RootRedirect,
+)
from synapse.http.site import SynapseSite
from synapse.logging.context import LoggingContext
from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
@@ -69,7 +73,6 @@ from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.prepare_database import UpgradeDatabaseException
-from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.module_loader import load_module
@@ -122,11 +125,11 @@ class SynapseHomeServer(HomeServer):
# try to find something useful to redirect '/' to
if WEB_CLIENT_PREFIX in resources:
- root_resource = RootRedirect(WEB_CLIENT_PREFIX)
+ root_resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX)
elif STATIC_PREFIX in resources:
- root_resource = RootRedirect(STATIC_PREFIX)
+ root_resource = RootOptionsRedirectResource(STATIC_PREFIX)
else:
- root_resource = NoResource()
+ root_resource = OptionsResource()
root_resource = create_resource_tree(resources, root_resource)
@@ -192,6 +195,11 @@ class SynapseHomeServer(HomeServer):
}
)
+ if self.get_config().oidc_enabled:
+ from synapse.rest.oidc import OIDCResource
+
+ resources["/_synapse/oidc"] = OIDCResource(self)
+
if self.get_config().saml2_enabled:
from synapse.rest.saml2 import SAML2Resource
@@ -422,6 +430,13 @@ def setup(config_options):
# Check if it needs to be reprovisioned every day.
hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
+ # Load the OIDC provider metadatas, if OIDC is enabled.
+ if hs.config.oidc_enabled:
+ oidc = hs.get_oidc_handler()
+ # Loading the provider metadata also ensures the provider config is valid.
+ yield defer.ensureDeferred(oidc.load_metadata())
+ yield defer.ensureDeferred(oidc.load_jwks())
+
_base.start(hs, config.listeners)
hs.get_datastore().db.updates.start_doing_background_updates()
@@ -473,6 +488,29 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
if uptime < 0:
uptime = 0
+ #
+ # Performance statistics. Keep this early in the function to maintain the reliability of the `test_performance_100` test.
+ #
+ old = stats_process[0]
+ new = (now, resource.getrusage(resource.RUSAGE_SELF))
+ stats_process[0] = new
+
+ # Get RSS in bytes
+ stats["memory_rss"] = new[1].ru_maxrss
+
+ # Get CPU time in % of a single core, not % of all cores
+ used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - (
+ old[1].ru_utime + old[1].ru_stime
+ )
+ if used_cpu_time == 0 or new[0] == old[0]:
+ stats["cpu_average"] = 0
+ else:
+ stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100)
+
+ #
+ # General statistics
+ #
+
stats["homeserver"] = hs.config.server_name
stats["server_context"] = hs.config.server_context
stats["timestamp"] = now
@@ -504,27 +542,8 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
stats["daily_sent_messages"] = daily_sent_messages
- stats["cache_factor"] = CACHE_SIZE_FACTOR
- stats["event_cache_size"] = hs.config.event_cache_size
-
- #
- # Performance statistics
- #
- old = stats_process[0]
- new = (now, resource.getrusage(resource.RUSAGE_SELF))
- stats_process[0] = new
-
- # Get RSS in bytes
- stats["memory_rss"] = new[1].ru_maxrss
-
- # Get CPU time in % of a single core, not % of all cores
- used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - (
- old[1].ru_utime + old[1].ru_stime
- )
- if used_cpu_time == 0 or new[0] == old[0]:
- stats["cpu_average"] = 0
- else:
- stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100)
+ stats["cache_factor"] = hs.config.caches.global_factor
+ stats["event_cache_size"] = hs.config.caches.event_cache_size
#
# Database version
@@ -602,18 +621,17 @@ def run(hs):
clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60)
reap_monthly_active_users()
- @defer.inlineCallbacks
- def generate_monthly_active_users():
+ async def generate_monthly_active_users():
current_mau_count = 0
current_mau_count_by_service = {}
reserved_users = ()
store = hs.get_datastore()
if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
- current_mau_count = yield store.get_monthly_active_count()
+ current_mau_count = await store.get_monthly_active_count()
current_mau_count_by_service = (
- yield store.get_monthly_active_count_by_service()
+ await store.get_monthly_active_count_by_service()
)
- reserved_users = yield store.get_registered_reserved_users()
+ reserved_users = await store.get_registered_reserved_users()
current_mau_gauge.set(float(current_mau_count))
for app_service, count in current_mau_count_by_service.items():
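
A self-contained sketch of the cpu_average calculation that the hunk above moves to the top of phone_stats_home: CPU time used between two samples, expressed as a percentage of one core.

import math
import resource
import time

def cpu_average(old, new):
    (old_now, old_usage), (new_now, new_usage) = old, new
    used = (new_usage.ru_utime + new_usage.ru_stime) - (
        old_usage.ru_utime + old_usage.ru_stime
    )
    if used == 0 or new_now == old_now:
        return 0
    return math.floor(used / (new_now - old_now) * 100)

old = (time.time(), resource.getrusage(resource.RUSAGE_SELF))
time.sleep(0.1)  # stand-in for the interval between two stats reports
new = (time.time(), resource.getrusage(resource.RUSAGE_SELF))
print(cpu_average(old, new))
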
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index aea3985a..1b13e844 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -270,7 +270,7 @@ class ApplicationService(object):
def is_exclusive_room(self, room_id):
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
- def get_exlusive_user_regexes(self):
+ def get_exclusive_user_regexes(self):
"""Get the list of regexes used to determine if a user is exclusively
registered by the AS
"""
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 3053fc9d..9e576060 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -13,6 +13,7 @@ from synapse.config import (
key,
logger,
metrics,
+ oidc_config,
password,
password_auth_providers,
push,
@@ -59,6 +60,7 @@ class RootConfig:
saml2: saml2_config.SAML2Config
cas: cas.CasConfig
sso: sso.SSOConfig
+ oidc: oidc_config.OIDCConfig
jwt: jwt_config.JWTConfig
password: password.PasswordConfig
email: emailconfig.EmailConfig
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
new file mode 100644
index 00000000..06725387
--- /dev/null
+++ b/synapse/config/cache.py
@@ -0,0 +1,198 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+from typing import Callable, Dict
+
+from ._base import Config, ConfigError
+
+# The prefix for all cache factor-related environment variables
+_CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR"
+
+# Map from canonicalised cache name to cache.
+_CACHES = {}
+
+_DEFAULT_FACTOR_SIZE = 0.5
+_DEFAULT_EVENT_CACHE_SIZE = "10K"
+
+
+class CacheProperties(object):
+ def __init__(self):
+ # The default factor size for all caches
+ self.default_factor_size = float(
+ os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
+ )
+ self.resize_all_caches_func = None
+
+
+properties = CacheProperties()
+
+
+def _canonicalise_cache_name(cache_name: str) -> str:
+ """Gets the canonical form of the cache name.
+
+ Since we specify cache names in config and environment variables we need to
+ ignore case and special characters. For example, some caches have asterisks
+ in their name to denote that they're not attached to a particular database
+ function, and these asterisks need to be stripped out
+ """
+
+ cache_name = re.sub(r"[^A-Za-z_1-9]", "", cache_name)
+
+ return cache_name.lower()
+
+
+def add_resizable_cache(cache_name: str, cache_resize_callback: Callable):
+ """Register a cache that's size can dynamically change
+
+ Args:
+ cache_name: A reference to the cache
+ cache_resize_callback: A callback function that will be run whenever
+ the cache needs to be resized
+ """
+ # Some caches have '*' in them which we strip out.
+ cache_name = _canonicalise_cache_name(cache_name)
+
+ _CACHES[cache_name] = cache_resize_callback
+
+ # Ensure all loaded caches are sized appropriately
+ #
+ # This method should only run once the config has been read,
+ # as it uses values read from it
+ if properties.resize_all_caches_func:
+ properties.resize_all_caches_func()
+
+
+class CacheConfig(Config):
+ section = "caches"
+ _environ = os.environ
+
+ @staticmethod
+ def reset():
+ """Resets the caches to their defaults. Used for tests."""
+ properties.default_factor_size = float(
+ os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
+ )
+ properties.resize_all_caches_func = None
+ _CACHES.clear()
+
+ def generate_config_section(self, **kwargs):
+ return """\
+ ## Caching ##
+
+ # Caching can be configured through the following options.
+ #
+ # A cache 'factor' is a multiplier that can be applied to each of
+ # Synapse's caches in order to increase or decrease the maximum
+ # number of entries that can be stored.
+
+ # The number of events to cache in memory. Not affected by
+ # caches.global_factor.
+ #
+ #event_cache_size: 10K
+
+ caches:
+ # Controls the global cache factor, which is the default cache factor
+ # for all caches if a specific factor for that cache is not otherwise
+ # set.
+ #
+ # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment
+ # variable. Setting by environment variable takes priority over
+ # setting through the config file.
+ #
+ # Defaults to 0.5, which will halve the size of all caches.
+ #
+ #global_factor: 1.0
+
+ # A dictionary of cache name to cache factor for that individual
+ # cache. Overrides the global cache factor for a given cache.
+ #
+ # These can also be set through environment variables comprised
+ # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital
+ # letters and underscores. Setting by environment variable
+ # takes priority over setting through the config file.
+ # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0
+ #
+ # Some caches have '*' and other characters that are not
+ # alphanumeric or underscores. These caches can be named with or
+ # without the special characters stripped. For example, to specify
+ # the cache factor for `*stateGroupCache*` via an environment
+ # variable would be `SYNAPSE_CACHE_FACTOR_STATEGROUPCACHE=2.0`.
+ #
+ per_cache_factors:
+ #get_users_who_share_room_with_user: 2.0
+ """
+
+ def read_config(self, config, **kwargs):
+ self.event_cache_size = self.parse_size(
+ config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE)
+ )
+ self.cache_factors = {} # type: Dict[str, float]
+
+ cache_config = config.get("caches") or {}
+ self.global_factor = cache_config.get(
+ "global_factor", properties.default_factor_size
+ )
+ if not isinstance(self.global_factor, (int, float)):
+ raise ConfigError("caches.global_factor must be a number.")
+
+ # Set the global one so that it's reflected in new caches
+ properties.default_factor_size = self.global_factor
+
+ # Load cache factors from the config
+ individual_factors = cache_config.get("per_cache_factors") or {}
+ if not isinstance(individual_factors, dict):
+ raise ConfigError("caches.per_cache_factors must be a dictionary")
+
+ # Canonicalise the cache names *before* updating with the environment
+ # variables.
+ individual_factors = {
+ _canonicalise_cache_name(key): val
+ for key, val in individual_factors.items()
+ }
+
+ # Override factors from environment if necessary
+ individual_factors.update(
+ {
+ _canonicalise_cache_name(key[len(_CACHE_PREFIX) + 1 :]): float(val)
+ for key, val in self._environ.items()
+ if key.startswith(_CACHE_PREFIX + "_")
+ }
+ )
+
+ for cache, factor in individual_factors.items():
+ if not isinstance(factor, (int, float)):
+ raise ConfigError(
+ "caches.per_cache_factors.%s must be a number" % (cache,)
+ )
+ self.cache_factors[cache] = factor
+
+ # Resize all caches (if necessary) with the new factors we've loaded
+ self.resize_all_caches()
+
+ # Store this function so that it can be called from other classes without
+ # needing an instance of Config
+ properties.resize_all_caches_func = self.resize_all_caches
+
+ def resize_all_caches(self):
+ """Ensure all cache sizes are up to date
+
+ For each cache, run the mapped callback function with either
+ a specific cache factor or the default, global one.
+ """
+ for cache_name, callback in _CACHES.items():
+ new_factor = self.cache_factors.get(cache_name, self.global_factor)
+ callback(new_factor)
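
A self-contained sketch of the override order described in the comments above: SYNAPSE_CACHE_FACTOR_<NAME> environment variables beat per_cache_factors from the config file, and cache names are canonicalised the same way in both cases.

import os
import re

_CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR"

def canonicalise(name):
    # mirrors the canonicalisation in the new module: strip '*' etc., lowercase
    return re.sub(r"[^A-Za-z_1-9]", "", name).lower()

def effective_factors(per_cache_factors, environ=os.environ):
    factors = {canonicalise(k): float(v) for k, v in per_cache_factors.items()}
    for key, val in environ.items():
        if key.startswith(_CACHE_PREFIX + "_"):
            factors[canonicalise(key[len(_CACHE_PREFIX) + 1 :])] = float(val)
    return factors

print(effective_factors(
    {"*stateGroupCache*": 1.0},
    environ={"SYNAPSE_CACHE_FACTOR_STATEGROUPCACHE": "2.0"},
))  # {'stategroupcache': 2.0}
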
diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py
index 56c87fa2..82f04d79 100644
--- a/synapse/config/captcha.py
+++ b/synapse/config/captcha.py
@@ -32,23 +32,26 @@ class CaptchaConfig(Config):
def generate_config_section(self, **kwargs):
return """\
## Captcha ##
- # See docs/CAPTCHA_SETUP for full details of configuring this.
+ # See docs/CAPTCHA_SETUP.md for full details of configuring this.
- # This homeserver's ReCAPTCHA public key.
+ # This homeserver's ReCAPTCHA public key. Must be specified if
+ # enable_registration_captcha is enabled.
#
#recaptcha_public_key: "YOUR_PUBLIC_KEY"
- # This homeserver's ReCAPTCHA private key.
+ # This homeserver's ReCAPTCHA private key. Must be specified if
+ # enable_registration_captcha is enabled.
#
#recaptcha_private_key: "YOUR_PRIVATE_KEY"
- # Enables ReCaptcha checks when registering, preventing signup
+ # Uncomment to enable ReCaptcha checks when registering, preventing signup
# unless a captcha is answered. Requires a valid ReCaptcha
- # public/private key.
+ # public/private key. Defaults to 'false'.
#
- #enable_registration_captcha: false
+ #enable_registration_captcha: true
# The API endpoint to use for verifying m.login.recaptcha responses.
+ # Defaults to "https://www.recaptcha.net/recaptcha/api/siteverify".
#
- #recaptcha_siteverify_api: "https://www.recaptcha.net/recaptcha/api/siteverify"
+ #recaptcha_siteverify_api: "https://my.recaptcha.site"
"""
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 5b662d1b..1064c269 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -68,10 +68,6 @@ database:
name: sqlite3
args:
database: %(database_path)s
-
-# Number of events to cache in memory.
-#
-#event_cache_size: 10K
"""
@@ -116,8 +112,6 @@ class DatabaseConfig(Config):
self.databases = []
def read_config(self, config, **kwargs):
- self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
-
# We *experimentally* support specifying multiple databases via the
# `databases` key. This is a map from a label to database config in the
# same format as the `database` config option, plus an extra
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 76b8957e..ca612144 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -311,8 +311,8 @@ class EmailConfig(Config):
# Username/password for authentication to the SMTP server. By default, no
# authentication is attempted.
#
- # smtp_user: "exampleusername"
- # smtp_pass: "examplepassword"
+ #smtp_user: "exampleusername"
+ #smtp_pass: "examplepassword"
# Uncomment the following to require TLS transport security for SMTP.
# By default, Synapse will connect over plain text, and will then switch to
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index be6c6afa..2c7b3a69 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -17,6 +17,7 @@
from ._base import RootConfig
from .api import ApiConfig
from .appservice import AppServiceConfig
+from .cache import CacheConfig
from .captcha import CaptchaConfig
from .cas import CasConfig
from .consent_config import ConsentConfig
@@ -27,6 +28,7 @@ from .jwt_config import JWTConfig
from .key import KeyConfig
from .logger import LoggingConfig
from .metrics import MetricsConfig
+from .oidc_config import OIDCConfig
from .password import PasswordConfig
from .password_auth_providers import PasswordAuthProviderConfig
from .push import PushConfig
@@ -54,6 +56,7 @@ class HomeServerConfig(RootConfig):
config_classes = [
ServerConfig,
TlsConfig,
+ CacheConfig,
DatabaseConfig,
LoggingConfig,
RatelimitConfig,
@@ -66,6 +69,7 @@ class HomeServerConfig(RootConfig):
AppServiceConfig,
KeyConfig,
SAML2Config,
+ OIDCConfig,
CasConfig,
SSOConfig,
JWTConfig,
diff --git a/synapse/config/key.py b/synapse/config/key.py
index 066e7838..b529ea5d 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -175,8 +175,8 @@ class KeyConfig(Config):
)
form_secret = 'form_secret: "%s"' % random_string_with_symbols(50)
else:
- macaroon_secret_key = "# macaroon_secret_key: <PRIVATE STRING>"
- form_secret = "# form_secret: <PRIVATE STRING>"
+ macaroon_secret_key = "#macaroon_secret_key: <PRIVATE STRING>"
+ form_secret = "#form_secret: <PRIVATE STRING>"
return (
"""\
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index a25c70e9..49f6c32b 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -257,5 +257,6 @@ def setup_logging(
logging.warning("***** STARTING SERVER *****")
logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse))
logging.info("Server hostname: %s", config.server_name)
+ logging.info("Instance name: %s", hs.get_instance_name())
return logger
diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py
index 6f517a71..6aad0d37 100644
--- a/synapse/config/metrics.py
+++ b/synapse/config/metrics.py
@@ -93,10 +93,11 @@ class MetricsConfig(Config):
#known_servers: true
# Whether or not to report anonymized homeserver usage statistics.
+ #
"""
if report_stats is None:
- res += "# report_stats: true|false\n"
+ res += "#report_stats: true|false\n"
else:
res += "report_stats: %s\n" % ("true" if report_stats else "false")
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
new file mode 100644
index 00000000..e24dd637
--- /dev/null
+++ b/synapse/config/oidc_config.py
@@ -0,0 +1,203 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.python_dependencies import DependencyException, check_requirements
+from synapse.util.module_loader import load_module
+
+from ._base import Config, ConfigError
+
+DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider"
+
+
+class OIDCConfig(Config):
+ section = "oidc"
+
+ def read_config(self, config, **kwargs):
+ self.oidc_enabled = False
+
+ oidc_config = config.get("oidc_config")
+
+ if not oidc_config or not oidc_config.get("enabled", False):
+ return
+
+ try:
+ check_requirements("oidc")
+ except DependencyException as e:
+ raise ConfigError(e.message)
+
+ public_baseurl = self.public_baseurl
+ if public_baseurl is None:
+ raise ConfigError("oidc_config requires a public_baseurl to be set")
+ self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"
+
+ self.oidc_enabled = True
+ self.oidc_discover = oidc_config.get("discover", True)
+ self.oidc_issuer = oidc_config["issuer"]
+ self.oidc_client_id = oidc_config["client_id"]
+ self.oidc_client_secret = oidc_config["client_secret"]
+ self.oidc_client_auth_method = oidc_config.get(
+ "client_auth_method", "client_secret_basic"
+ )
+ self.oidc_scopes = oidc_config.get("scopes", ["openid"])
+ self.oidc_authorization_endpoint = oidc_config.get("authorization_endpoint")
+ self.oidc_token_endpoint = oidc_config.get("token_endpoint")
+ self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint")
+ self.oidc_jwks_uri = oidc_config.get("jwks_uri")
+ self.oidc_skip_verification = oidc_config.get("skip_verification", False)
+
+ ump_config = oidc_config.get("user_mapping_provider", {})
+ ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
+ ump_config.setdefault("config", {})
+
+ (
+ self.oidc_user_mapping_provider_class,
+ self.oidc_user_mapping_provider_config,
+ ) = load_module(ump_config)
+
+ # Ensure loaded user mapping module has defined all necessary methods
+ required_methods = [
+ "get_remote_user_id",
+ "map_user_attributes",
+ ]
+ missing_methods = [
+ method
+ for method in required_methods
+ if not hasattr(self.oidc_user_mapping_provider_class, method)
+ ]
+ if missing_methods:
+ raise ConfigError(
+ "Class specified by oidc_config."
+ "user_mapping_provider.module is missing required "
+ "methods: %s" % (", ".join(missing_methods),)
+ )
+
+ def generate_config_section(self, config_dir_path, server_name, **kwargs):
+ return """\
+ # OpenID Connect integration. The following settings can be used to make Synapse
+ # use an OpenID Connect Provider for authentication, instead of its internal
+ # password database.
+ #
+ # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md.
+ #
+ oidc_config:
+ # Uncomment the following to enable authorization against an OpenID Connect
+ # server. Defaults to false.
+ #
+ #enabled: true
+
+ # Uncomment the following to disable use of the OIDC discovery mechanism to
+ # discover endpoints. Defaults to true.
+ #
+ #discover: false
+
+ # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
+ # discover the provider's endpoints.
+ #
+ # Required if 'enabled' is true.
+ #
+ #issuer: "https://accounts.example.com/"
+
+ # oauth2 client id to use.
+ #
+ # Required if 'enabled' is true.
+ #
+ #client_id: "provided-by-your-issuer"
+
+ # oauth2 client secret to use.
+ #
+ # Required if 'enabled' is true.
+ #
+ #client_secret: "provided-by-your-issuer"
+
+ # auth method to use when exchanging the token.
+ # Valid values are 'client_secret_basic' (default), 'client_secret_post' and
+ # 'none'.
+ #
+ #client_auth_method: client_secret_post
+
+ # list of scopes to request. This should normally include the "openid" scope.
+ # Defaults to ["openid"].
+ #
+ #scopes: ["openid", "profile"]
+
+ # the oauth2 authorization endpoint. Required if provider discovery is disabled.
+ #
+ #authorization_endpoint: "https://accounts.example.com/oauth2/auth"
+
+ # the oauth2 token endpoint. Required if provider discovery is disabled.
+ #
+ #token_endpoint: "https://accounts.example.com/oauth2/token"
+
+ # the OIDC userinfo endpoint. Required if discovery is disabled and the
+ # "openid" scope is not requested.
+ #
+ #userinfo_endpoint: "https://accounts.example.com/userinfo"
+
+ # URI where to fetch the JWKS. Required if discovery is disabled and the
+ # "openid" scope is used.
+ #
+ #jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
+
+ # Uncomment to skip metadata verification. Defaults to false.
+ #
+ # Use this if you are connecting to a provider that is not OpenID Connect
+ # compliant.
+ # Avoid this in production.
+ #
+ #skip_verification: true
+
+ # An external module can be provided here as a custom solution to mapping
+ # attributes returned from an OIDC provider onto a Matrix user.
+ #
+ user_mapping_provider:
+ # The custom module's class. Uncomment to use a custom module.
+ # Default is {mapping_provider!r}.
+ #
+ # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
+ # for information on implementing a custom mapping provider.
+ #
+ #module: mapping_provider.OidcMappingProvider
+
+ # Custom configuration values for the module. This section will be passed as
+ # a Python dictionary to the user mapping provider module's `parse_config`
+ # method.
+ #
+ # The examples below are intended for the default provider: they should be
+ # changed if using a custom provider.
+ #
+ config:
+ # name of the claim containing a unique identifier for the user.
+ # Defaults to `sub`, which OpenID Connect compliant providers should provide.
+ #
+ #subject_claim: "sub"
+
+ # Jinja2 template for the localpart of the MXID.
+ #
+ # When rendering, this template is given the following variables:
+ # * user: The claims returned by the UserInfo Endpoint and/or in the ID
+ # Token
+ #
+ # This must be configured if using the default mapping provider.
+ #
+ localpart_template: "{{{{ user.preferred_username }}}}"
+
+ # Jinja2 template for the display name to set on first login.
+ #
+ # If unset, no displayname will be set.
+ #
+ #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}"
+ """.format(
+ mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
+ )
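
For orientation, the smallest oidc_config block that satisfies read_config above, written as the dict the YAML parses to (all values are placeholders; note that public_baseurl must also be set elsewhere in the config):

oidc_config = {
    "enabled": True,
    "issuer": "https://accounts.example.com/",   # required
    "client_id": "provided-by-your-issuer",      # required
    "client_secret": "provided-by-your-issuer",  # required
    # everything else falls back to defaults: discovery on, scopes ["openid"],
    # client_secret_basic auth, and the Jinja-based default user mapping provider.
}
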
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 4a3bfc43..2dd94bae 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -12,11 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Dict
+
from ._base import Config
class RateLimitConfig(object):
- def __init__(self, config, defaults={"per_second": 0.17, "burst_count": 3.0}):
+ def __init__(
+ self,
+ config: Dict[str, float],
+ defaults={"per_second": 0.17, "burst_count": 3.0},
+ ):
self.per_second = config.get("per_second", defaults["per_second"])
self.burst_count = config.get("burst_count", defaults["burst_count"])
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index e7ea3a01..fecced2d 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -128,6 +128,7 @@ class RegistrationConfig(Config):
if not RoomAlias.is_valid(room_alias):
raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,))
self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
+ self.auto_join_rooms_for_guests = config.get("auto_join_rooms_for_guests", True)
self.enable_set_displayname = config.get("enable_set_displayname", True)
self.enable_set_avatar_url = config.get("enable_set_avatar_url", True)
@@ -148,9 +149,7 @@ class RegistrationConfig(Config):
random_string_with_symbols(50),
)
else:
- registration_shared_secret = (
- "# registration_shared_secret: <PRIVATE STRING>"
- )
+ registration_shared_secret = "#registration_shared_secret: <PRIVATE STRING>"
return (
"""\
@@ -370,6 +369,13 @@ class RegistrationConfig(Config):
# users cannot be auto-joined since they do not exist.
#
#autocreate_auto_join_rooms: true
+
+ # When auto_join_rooms is specified, setting this flag to false prevents
+ # guest accounts from being automatically joined to the rooms.
+ #
+ # Defaults to true.
+ #
+ #auto_join_rooms_for_guests: false
"""
% locals()
)
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 9d2ce202..b751d02d 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -70,6 +70,7 @@ def parse_thumbnail_requirements(thumbnail_sizes):
jpeg_thumbnail = ThumbnailRequirement(width, height, method, "image/jpeg")
png_thumbnail = ThumbnailRequirement(width, height, method, "image/png")
requirements.setdefault("image/jpeg", []).append(jpeg_thumbnail)
+ requirements.setdefault("image/webp", []).append(jpeg_thumbnail)
requirements.setdefault("image/gif", []).append(png_thumbnail)
requirements.setdefault("image/png", []).append(png_thumbnail)
return {
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index 726a27d7..d0a19751 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -15,8 +15,8 @@
# limitations under the License.
import logging
-import os
+import jinja2
import pkg_resources
from synapse.python_dependencies import DependencyException, check_requirements
@@ -167,9 +167,11 @@ class SAML2Config(Config):
if not template_dir:
template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
- self.saml2_error_html_content = self.read_file(
- os.path.join(template_dir, "saml_error.html"), "saml2_config.saml_error",
- )
+ loader = jinja2.FileSystemLoader(template_dir)
+ # enable auto-escape here, to avoid having to remember to escape manually in the
+ # template
+ env = jinja2.Environment(loader=loader, autoescape=True)
+ self.saml2_error_html_template = env.get_template("saml_error.html")
def _default_saml_config_dict(
self, required_attributes: set, optional_attributes: set
@@ -216,6 +218,8 @@ class SAML2Config(Config):
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
+ ## Single sign-on integration ##
+
# Enable SAML2 for registration and login. Uses pysaml2.
#
# At least one of `sp_config` or `config_path` must be set in this section to
@@ -349,7 +353,13 @@ class SAML2Config(Config):
# * HTML page to display to users if something goes wrong during the
# authentication process: 'saml_error.html'.
#
- # This template doesn't currently need any variable to render.
+ # When rendering, this template is given the following variables:
+ # * code: an HTML error code corresponding to the error that is being
+ # returned (typically 400 or 500)
+ #
+ # * msg: a textual message describing the error.
+ #
+ # The variables will automatically be HTML-escaped.
#
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
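
A sketch of what the template change above enables: saml_error.html is now loaded through a Jinja2 environment with autoescaping and rendered with `code` and `msg`. The template directory below is a placeholder.

import jinja2

env = jinja2.Environment(
    loader=jinja2.FileSystemLoader("res/templates"), autoescape=True
)
template = env.get_template("saml_error.html")
html = template.render(code=400, msg="Bad response from IdP: <script>alert(1)</script>")
# msg is HTML-escaped automatically, so the markup above is rendered harmless
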
diff --git a/synapse/config/server.py b/synapse/config/server.py
index ed28da3d..f57eefc9 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -434,7 +434,7 @@ class ServerConfig(Config):
)
self.limit_remote_rooms = LimitRemoteRoomsConfig(
- **config.get("limit_remote_rooms", {})
+ **(config.get("limit_remote_rooms") or {})
)
bind_port = config.get("bind_port")
@@ -895,22 +895,27 @@ class ServerConfig(Config):
# Used by phonehome stats to group together related servers.
#server_context: context
- # Resource-constrained homeserver Settings
+ # Resource-constrained homeserver settings
#
- # If limit_remote_rooms.enabled is True, the room complexity will be
- # checked before a user joins a new remote room. If it is above
- # limit_remote_rooms.complexity, it will disallow joining or
- # instantly leave.
+ # When this is enabled, the room "complexity" will be checked before a user
+ # joins a new remote room. If it is above the complexity limit, the server will
+ # disallow joining, or will instantly leave.
#
- # limit_remote_rooms.complexity_error can be set to customise the text
- # displayed to the user when a room above the complexity threshold has
- # its join cancelled.
+ # Room complexity is an arbitrary measure based on factors such as the number of
+ # users in the room.
#
- # Uncomment the below lines to enable:
- #limit_remote_rooms:
- # enabled: true
- # complexity: 1.0
- # complexity_error: "This room is too complex."
+ limit_remote_rooms:
+ # Uncomment to enable room complexity checking.
+ #
+ #enabled: true
+
+ # the limit above which rooms cannot be joined. The default is 1.0.
+ #
+ #complexity: 0.5
+
+ # override the error which is returned when the room is too complex.
+ #
+ #complexity_error: "This room is too complex."
# Whether to require a user to be in the room to add an alias to it.
# Defaults to 'true'.
diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices_config.py
index 6ea2ea88..6c427b6f 100644
--- a/synapse/config/server_notices_config.py
+++ b/synapse/config/server_notices_config.py
@@ -51,7 +51,7 @@ class ServerNoticesConfig(Config):
None if server notices are not enabled.
server_notices_mxid_avatar_url (str|None):
- The display name to use for the server notices user.
+ The MXC URL for the avatar of the server notices user.
None if server notices are not enabled.
server_notices_room_name (str|None):
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
index 36e0ddab..3d067d29 100644
--- a/synapse/config/spam_checker.py
+++ b/synapse/config/spam_checker.py
@@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict, List, Tuple
+
+from synapse.config import ConfigError
from synapse.util.module_loader import load_module
from ._base import Config
@@ -22,16 +25,35 @@ class SpamCheckerConfig(Config):
section = "spamchecker"
def read_config(self, config, **kwargs):
- self.spam_checker = None
+ self.spam_checkers = [] # type: List[Tuple[Any, Dict]]
+
+ spam_checkers = config.get("spam_checker") or []
+ if isinstance(spam_checkers, dict):
+ # The spam_checker config option used to only support one
+ # spam checker, and thus was simply a dictionary with module
+ # and config keys. Support this old behaviour by checking
+ # to see if the option resolves to a dictionary
+ self.spam_checkers.append(load_module(spam_checkers))
+ elif isinstance(spam_checkers, list):
+ for spam_checker in spam_checkers:
+ if not isinstance(spam_checker, dict):
+ raise ConfigError("spam_checker syntax is incorrect")
- provider = config.get("spam_checker", None)
- if provider is not None:
- self.spam_checker = load_module(provider)
+ self.spam_checkers.append(load_module(spam_checker))
+ else:
+ raise ConfigError("spam_checker syntax is incorrect")
def generate_config_section(self, **kwargs):
return """\
- #spam_checker:
- # module: "my_custom_project.SuperSpamChecker"
- # config:
- # example_option: 'things'
+ # Spam checkers are third-party modules that can block specific actions
+ # of local users, such as creating rooms and registering undesirable
+ # usernames, and can act on remote users by redacting incoming events.
+ #
+ spam_checker:
+ #- module: "my_custom_project.SuperSpamChecker"
+ # config:
+ # example_option: 'things'
+ #- module: "some_other_project.BadEventStopper"
+ # config:
+ # example_stop_events_from: ['@bad:example.com']
"""
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index cac6bc01..73b72963 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -36,17 +36,13 @@ class SSOConfig(Config):
if not template_dir:
template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
- self.sso_redirect_confirm_template_dir = template_dir
+ self.sso_template_dir = template_dir
self.sso_account_deactivated_template = self.read_file(
- os.path.join(
- self.sso_redirect_confirm_template_dir, "sso_account_deactivated.html"
- ),
+ os.path.join(self.sso_template_dir, "sso_account_deactivated.html"),
"sso_account_deactivated_template",
)
self.sso_auth_success_template = self.read_file(
- os.path.join(
- self.sso_redirect_confirm_template_dir, "sso_auth_success.html"
- ),
+ os.path.join(self.sso_template_dir, "sso_auth_success.html"),
"sso_auth_success_template",
)
@@ -65,7 +61,8 @@ class SSOConfig(Config):
def generate_config_section(self, **kwargs):
return """\
- # Additional settings to use with single-sign on systems such as SAML2 and CAS.
+ # Additional settings to use with single-sign on systems such as OpenID Connect,
+ # SAML2 and CAS.
#
sso:
# A list of client URLs which are whitelisted so that the user does not
@@ -137,6 +134,13 @@ class SSOConfig(Config):
#
# This template has no additional variables.
#
+ # * HTML page to display to users if something goes wrong during the
+ # OpenID Connect authentication process: 'sso_error.html'.
+ #
+ # When rendering, this template is given two variables:
+ # * error: the technical name of the error
+ # * error_description: a human-readable message for the error
+ #
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index fef72ed9..ed06b91a 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -13,7 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import Config
+import attr
+
+from ._base import Config, ConfigError
+
+
+@attr.s
+class InstanceLocationConfig:
+ """The host and port to talk to an instance via HTTP replication.
+ """
+
+ host = attr.ib(type=str)
+ port = attr.ib(type=int)
+
+
+@attr.s
+class WriterLocations:
+ """Specifies the instances that write various streams.
+
+ Attributes:
+ events: The instance that writes to the event and backfill streams.
+ """
+
+ events = attr.ib(default="master", type=str)
class WorkerConfig(Config):
@@ -71,6 +93,27 @@ class WorkerConfig(Config):
elif not bind_addresses:
bind_addresses.append("")
+ # A map from instance name to host/port of their HTTP replication endpoint.
+ instance_map = config.get("instance_map") or {}
+ self.instance_map = {
+ name: InstanceLocationConfig(**c) for name, c in instance_map.items()
+ }
+
+ # Map from type of streams to source, c.f. WriterLocations.
+ writers = config.get("stream_writers") or {}
+ self.writers = WriterLocations(**writers)
+
+ # Check that the configured writer for events also appears in
+ # `instance_map`.
+ if (
+ self.writers.events != "master"
+ and self.writers.events not in self.instance_map
+ ):
+ raise ConfigError(
+ "Instance %r is configured to write events but does not appear in `instance_map` config."
+ % (self.writers.events,)
+ )
+
def read_arguments(self, args):
# We support a bunch of command line arguments that override options in
# the config. A lot of these options have a worker_* prefix when running
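
An example of the two new options as the parsed dicts read_config above consumes (instance name and port are placeholders): an event-stream writer other than "master" must also appear in instance_map, or the ConfigError added above is raised.

config = {
    "instance_map": {
        "event_persister1": {"host": "localhost", "port": 8034},
    },
    "stream_writers": {
        "events": "event_persister1",
    },
}
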
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 46beb533..c5823551 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
-from typing import Set, Tuple
+from typing import List, Optional, Set, Tuple
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
@@ -29,18 +29,19 @@ from synapse.api.room_versions import (
EventFormatVersions,
RoomVersion,
)
-from synapse.types import UserID, get_domain_from_id
+from synapse.events import EventBase
+from synapse.types import StateMap, UserID, get_domain_from_id
logger = logging.getLogger(__name__)
def check(
room_version_obj: RoomVersion,
- event,
- auth_events,
- do_sig_check=True,
- do_size_check=True,
-):
+ event: EventBase,
+ auth_events: StateMap[EventBase],
+ do_sig_check: bool = True,
+ do_size_check: bool = True,
+) -> None:
""" Checks if this event is correctly authed.
Args:
@@ -181,7 +182,7 @@ def check(
_can_send_event(event, auth_events)
if event.type == EventTypes.PowerLevels:
- _check_power_levels(event, auth_events)
+ _check_power_levels(room_version_obj, event, auth_events)
if event.type == EventTypes.Redaction:
check_redaction(room_version_obj, event, auth_events)
@@ -189,7 +190,7 @@ def check(
logger.debug("Allowing! %s", event)
-def _check_size_limits(event):
+def _check_size_limits(event: EventBase) -> None:
def too_big(field):
raise EventSizeError("%s too large" % (field,))
@@ -207,13 +208,18 @@ def _check_size_limits(event):
too_big("event")
-def _can_federate(event, auth_events):
+def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
creation_event = auth_events.get((EventTypes.Create, ""))
+ # There should always be a creation event, but if not don't federate.
+ if not creation_event:
+ return False
return creation_event.content.get("m.federate", True) is True
-def _is_membership_change_allowed(event, auth_events):
+def _is_membership_change_allowed(
+ event: EventBase, auth_events: StateMap[EventBase]
+) -> None:
membership = event.content["membership"]
# Check if this is the room creator joining:
@@ -339,21 +345,25 @@ def _is_membership_change_allowed(event, auth_events):
raise AuthError(500, "Unknown membership %s" % membership)
-def _check_event_sender_in_room(event, auth_events):
+def _check_event_sender_in_room(
+ event: EventBase, auth_events: StateMap[EventBase]
+) -> None:
key = (EventTypes.Member, event.user_id)
member_event = auth_events.get(key)
- return _check_joined_room(member_event, event.user_id, event.room_id)
+ _check_joined_room(member_event, event.user_id, event.room_id)
-def _check_joined_room(member, user_id, room_id):
+def _check_joined_room(member: Optional[EventBase], user_id: str, room_id: str) -> None:
if not member or member.membership != Membership.JOIN:
raise AuthError(
403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member))
)
-def get_send_level(etype, state_key, power_levels_event):
+def get_send_level(
+ etype: str, state_key: Optional[str], power_levels_event: Optional[EventBase]
+) -> int:
"""Get the power level required to send an event of a given type
The federation spec [1] refers to this as "Required Power Level".
@@ -361,13 +371,13 @@ def get_send_level(etype, state_key, power_levels_event):
https://matrix.org/docs/spec/server_server/unstable.html#definitions
Args:
- etype (str): type of event
- state_key (str|None): state_key of state event, or None if it is not
+ etype: type of event
+ state_key: state_key of state event, or None if it is not
a state event.
- power_levels_event (synapse.events.EventBase|None): power levels event
+ power_levels_event: power levels event
in force at this point in the room
Returns:
- int: power level required to send this event.
+ power level required to send this event.
"""
if power_levels_event:
@@ -388,7 +398,7 @@ def get_send_level(etype, state_key, power_levels_event):
return int(send_level)
-def _can_send_event(event, auth_events):
+def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
power_levels_event = _get_power_level_event(auth_events)
send_level = get_send_level(event.type, event.get("state_key"), power_levels_event)
@@ -410,7 +420,9 @@ def _can_send_event(event, auth_events):
return True
-def check_redaction(room_version_obj: RoomVersion, event, auth_events):
+def check_redaction(
+ room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase],
+) -> bool:
"""Check whether the event sender is allowed to redact the target event.
Returns:
@@ -442,7 +454,9 @@ def check_redaction(room_version_obj: RoomVersion, event, auth_events):
raise AuthError(403, "You don't have permission to redact events")
-def _check_power_levels(event, auth_events):
+def _check_power_levels(
+ room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase],
+) -> None:
user_list = event.content.get("users", {})
# Validate users
for k, v in user_list.items():
@@ -473,7 +487,7 @@ def _check_power_levels(event, auth_events):
("redact", None),
("kick", None),
("invite", None),
- ]
+ ] # type: List[Tuple[str, Optional[str]]]
old_list = current_state.content.get("users", {})
for user in set(list(old_list) + list(user_list)):
@@ -484,6 +498,14 @@ def _check_power_levels(event, auth_events):
for ev_id in set(list(old_list) + list(new_list)):
levels_to_check.append((ev_id, "events"))
+ # MSC2209 specifies these checks should also be done for the "notifications"
+ # key.
+ if room_version_obj.limit_notifications_power_levels:
+ old_list = current_state.content.get("notifications", {})
+ new_list = event.content.get("notifications", {})
+ for ev_id in set(list(old_list) + list(new_list)):
+ levels_to_check.append((ev_id, "notifications"))
+
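The added block above implements MSC2209: for room versions with limit_notifications_power_levels set, changes to the "notifications" section of m.room.power_levels are checked just like the "users" and "events" sections. A toy sketch of how the old and new keys are collected for comparison (the contents are made up):

old_content = {"notifications": {"room": 50}}
new_content = {"notifications": {"room": 100}}

levels_to_check = []
old_list = old_content.get("notifications", {})
new_list = new_content.get("notifications", {})
for key in set(list(old_list) + list(new_list)):
    levels_to_check.append((key, "notifications"))

print(levels_to_check)  # [('room', 'notifications')]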
old_state = current_state.content
new_state = event.content
@@ -495,12 +517,12 @@ def _check_power_levels(event, auth_events):
new_loc = new_loc.get(dir, {})
if level_to_check in old_loc:
- old_level = int(old_loc[level_to_check])
+ old_level = int(old_loc[level_to_check]) # type: Optional[int]
else:
old_level = None
if level_to_check in new_loc:
- new_level = int(new_loc[level_to_check])
+ new_level = int(new_loc[level_to_check]) # type: Optional[int]
else:
new_level = None
@@ -526,21 +548,21 @@ def _check_power_levels(event, auth_events):
)
-def _get_power_level_event(auth_events):
+def _get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]:
return auth_events.get((EventTypes.PowerLevels, ""))
-def get_user_power_level(user_id, auth_events):
+def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
"""Get a user's power level
Args:
- user_id (str): user's id to look up in power_levels
- auth_events (dict[(str, str), synapse.events.EventBase]):
+ user_id: user's id to look up in power_levels
+ auth_events:
state in force at this point in the room (or rather, a subset of
it including at least the create event and power levels event.
Returns:
- int: the user's power level in this room.
+ the user's power level in this room.
"""
power_level_event = _get_power_level_event(auth_events)
if power_level_event:
@@ -566,7 +588,7 @@ def get_user_power_level(user_id, auth_events):
return 0
-def _get_named_level(auth_events, name, default):
+def _get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int:
power_level_event = _get_power_level_event(auth_events)
if not power_level_event:
@@ -579,7 +601,7 @@ def _get_named_level(auth_events, name, default):
return default
-def _verify_third_party_invite(event, auth_events):
+def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase]):
"""
Validates that the invite event is authorized by a previous third-party invite.
@@ -654,7 +676,7 @@ def get_public_keys(invite_event):
return public_keys
-def auth_types_for_event(event) -> Set[Tuple[str, str]]:
+def auth_types_for_event(event: EventBase) -> Set[Tuple[str, str]]:
"""Given an event, return a list of (EventType, StateKey) that may be
needed to auth the event. The returned list may be a superset of what
would actually be required depending on the full state of the room.
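Most of the event_auth.py changes above add type annotations; the behavioural reference point is get_send_level, whose docstring is tidied here. A hedged sketch of the fallback order it documents, using the spec defaults of 50 for state events and 0 for other events (an approximation for illustration, not a copy of Synapse's implementation):

from typing import Optional

def required_send_level(etype: str, state_key: Optional[str], power_levels: dict) -> int:
    send_level = power_levels.get("events", {}).get(etype)
    if send_level is None:
        if state_key is not None:
            # State events (state_key present, possibly "") fall back to state_default.
            send_level = power_levels.get("state_default", 50)
        else:
            send_level = power_levels.get("events_default", 0)
    return int(send_level)

print(required_send_level("m.room.message", None, {"events_default": 0}))  # 0
print(required_send_level("m.room.name", "", {}))                          # 50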
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index a23b6b7b..1ffc9525 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,7 +15,7 @@
# limitations under the License.
import inspect
-from typing import Dict
+from typing import Any, Dict, List
from synapse.spam_checker_api import SpamCheckerApi
@@ -26,24 +26,17 @@ if MYPY:
class SpamChecker(object):
def __init__(self, hs: "synapse.server.HomeServer"):
- self.spam_checker = None
+ self.spam_checkers = [] # type: List[Any]
- module = None
- config = None
- try:
- module, config = hs.config.spam_checker
- except Exception:
- pass
-
- if module is not None:
+ for module, config in hs.config.spam_checkers:
# Older spam checkers don't accept the `api` argument, so we
# try and detect support.
spam_args = inspect.getfullargspec(module)
if "api" in spam_args.args:
api = SpamCheckerApi(hs)
- self.spam_checker = module(config=config, api=api)
+ self.spam_checkers.append(module(config=config, api=api))
else:
- self.spam_checker = module(config=config)
+ self.spam_checkers.append(module(config=config))
def check_event_for_spam(self, event: "synapse.events.EventBase") -> bool:
"""Checks if a given event is considered "spammy" by this server.
@@ -58,10 +51,11 @@ class SpamChecker(object):
Returns:
True if the event is spammy.
"""
- if self.spam_checker is None:
- return False
+ for spam_checker in self.spam_checkers:
+ if spam_checker.check_event_for_spam(event):
+ return True
- return self.spam_checker.check_event_for_spam(event)
+ return False
def user_may_invite(
self, inviter_userid: str, invitee_userid: str, room_id: str
@@ -78,12 +72,14 @@ class SpamChecker(object):
Returns:
True if the user may send an invite, otherwise False
"""
- if self.spam_checker is None:
- return True
+ for spam_checker in self.spam_checkers:
+ if (
+ spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
+ is False
+ ):
+ return False
- return self.spam_checker.user_may_invite(
- inviter_userid, invitee_userid, room_id
- )
+ return True
def user_may_create_room(self, userid: str) -> bool:
"""Checks if a given user may create a room
@@ -96,10 +92,11 @@ class SpamChecker(object):
Returns:
True if the user may create a room, otherwise False
"""
- if self.spam_checker is None:
- return True
+ for spam_checker in self.spam_checkers:
+ if spam_checker.user_may_create_room(userid) is False:
+ return False
- return self.spam_checker.user_may_create_room(userid)
+ return True
def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
"""Checks if a given user may create a room alias
@@ -113,10 +110,11 @@ class SpamChecker(object):
Returns:
True if the user may create a room alias, otherwise False
"""
- if self.spam_checker is None:
- return True
+ for spam_checker in self.spam_checkers:
+ if spam_checker.user_may_create_room_alias(userid, room_alias) is False:
+ return False
- return self.spam_checker.user_may_create_room_alias(userid, room_alias)
+ return True
def user_may_publish_room(self, userid: str, room_id: str) -> bool:
"""Checks if a given user may publish a room to the directory
@@ -130,10 +128,11 @@ class SpamChecker(object):
Returns:
True if the user may publish the room, otherwise False
"""
- if self.spam_checker is None:
- return True
+ for spam_checker in self.spam_checkers:
+ if spam_checker.user_may_publish_room(userid, room_id) is False:
+ return False
- return self.spam_checker.user_may_publish_room(userid, room_id)
+ return True
def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server.
@@ -150,13 +149,14 @@ class SpamChecker(object):
Returns:
True if the user is spammy.
"""
- if self.spam_checker is None:
- return False
-
- # For backwards compatibility, if the method does not exist on the spam checker, fallback to not interfering.
- checker = getattr(self.spam_checker, "check_username_for_spam", None)
- if not checker:
- return False
- # Make a copy of the user profile object to ensure the spam checker
- # cannot modify it.
- return checker(user_profile.copy())
+ for spam_checker in self.spam_checkers:
+ # For backwards compatibility, only run if the method exists on the
+ # spam checker
+ checker = getattr(spam_checker, "check_username_for_spam", None)
+ if checker:
+ # Make a copy of the user profile object to ensure the spam checker
+ # cannot modify it.
+ if checker(user_profile.copy()):
+ return True
+
+ return False
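With the change above, `spam_checker` can load several modules: an event is rejected if any checker flags it, and an action is denied if any checker returns False. A hedged sketch of what one loadable checker might look like given the interface called above (the class name and config keys are illustrative, not something shipped with Synapse):

class ExampleSpamChecker:
    def __init__(self, config, api=None):
        # `api` is only passed if the constructor accepts it, as detected above.
        self._blocked_words = config.get("blocked_words", [])

    def check_event_for_spam(self, event) -> bool:
        body = getattr(event, "content", {}).get("body", "")
        return any(word in body for word in self._blocked_words)

    def user_may_invite(self, inviter_userid, invitee_userid, room_id) -> bool:
        return True

    def user_may_create_room(self, userid) -> bool:
        return True

    def user_may_create_room_alias(self, userid, room_alias) -> bool:
        return True

    def user_may_publish_room(self, userid, room_id) -> bool:
        return True

    @staticmethod
    def parse_config(config):
        return config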
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index b75b097e..dd340be9 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -14,7 +14,7 @@
# limitations under the License.
import collections
import re
-from typing import Mapping, Union
+from typing import Any, Mapping, Union
from six import string_types
@@ -23,6 +23,7 @@ from frozendict import frozendict
from twisted.internet import defer
from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.util.async_helpers import yieldable_gather_results
@@ -449,3 +450,35 @@ def copy_power_levels_contents(
raise TypeError("Invalid power_levels value for %s: %r" % (k, v))
return power_levels
+
+
+def validate_canonicaljson(value: Any):
+ """
+ Ensure that the JSON object is valid according to the rules of canonical JSON.
+
+ See the appendix section 3.1: Canonical JSON.
+
+ This rejects JSON that has:
+ * An integer outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1]
+ * Floats
+ * NaN, Infinity, -Infinity
+ """
+ if isinstance(value, int):
+ if value <= -(2 ** 53) or 2 ** 53 <= value:
+ raise SynapseError(400, "JSON integer out of range", Codes.BAD_JSON)
+
+ elif isinstance(value, float):
+ # Note that Infinity, -Infinity, and NaN are also considered floats.
+ raise SynapseError(400, "Bad JSON value: float", Codes.BAD_JSON)
+
+ elif isinstance(value, (dict, frozendict)):
+ for v in value.values():
+ validate_canonicaljson(v)
+
+ elif isinstance(value, (list, tuple)):
+ for i in value:
+ validate_canonicaljson(i)
+
+ elif not isinstance(value, (bool, str)) and value is not None:
+ # Other potential JSON values (bool, None, str) are safe.
+ raise SynapseError(400, "Unknown JSON value", Codes.BAD_JSON)
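A short usage sketch of the validator added above, assuming it is imported from the module shown in the patch (the sample payloads are made up):

from synapse.api.errors import SynapseError
from synapse.events.utils import validate_canonicaljson

validate_canonicaljson({"msgtype": "m.text", "count": 3})  # in range: passes

try:
    validate_canonicaljson({"ratio": 0.5})  # floats are rejected outright
except SynapseError as e:
    print(e.code, e.errcode)  # 400 M_BAD_JSON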
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 9b90c9ce..b001c64b 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -18,6 +18,7 @@ from six import integer_types, string_types
from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import EventFormatVersions
+from synapse.events.utils import validate_canonicaljson
from synapse.types import EventID, RoomID, UserID
@@ -55,6 +56,12 @@ class EventValidator(object):
if not isinstance(getattr(event, s), string_types):
raise SynapseError(400, "'%s' not a string type" % (s,))
+ # Depending on the room version, ensure the data is spec compliant JSON.
+ if event.room_version.strict_canonicaljson:
+ # Note that only the client controlled portion of the event is
+ # checked, since we trust the portions of the event we created.
+ validate_canonicaljson(event.content)
+
if event.type == EventTypes.Aliases:
if "aliases" in event.content:
for alias in event.content["aliases"]:
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 4b115aac..c0012c68 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -29,7 +29,7 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion
from synapse.crypto.event_signing import check_event_content_hash
from synapse.crypto.keyring import Keyring
from synapse.events import EventBase, make_event_from_dict
-from synapse.events.utils import prune_event
+from synapse.events.utils import prune_event, validate_canonicaljson
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
PreserveLoggingContext,
@@ -302,6 +302,10 @@ def event_from_pdu_json(
elif depth > MAX_DEPTH:
raise SynapseError(400, "Depth too large", Codes.BAD_JSON)
+ # Validate that the JSON conforms to the specification.
+ if room_version.strict_canonicaljson:
+ validate_canonicaljson(pdu_json)
+
event = make_event_from_dict(pdu_json, room_version)
event.internal_metadata.outlier = outlier
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index e1700ca8..52f4f542 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -31,6 +31,7 @@ Events are replicated via a separate events stream.
import logging
from collections import namedtuple
+from typing import Dict, List, Tuple, Type
from six import iteritems
@@ -56,21 +57,35 @@ class FederationRemoteSendQueue(object):
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
- self.presence_map = {} # Pending presence map user_id -> UserPresenceState
- self.presence_changed = SortedDict() # Stream position -> list[user_id]
+ # Pending presence map user_id -> UserPresenceState
+ self.presence_map = {} # type: Dict[str, UserPresenceState]
+
+ # Stream position -> list[user_id]
+ self.presence_changed = SortedDict() # type: SortedDict[int, List[str]]
# Stores the destinations we need to explicitly send presence to about a
# given user.
# Stream position -> (user_id, destinations)
- self.presence_destinations = SortedDict()
+ self.presence_destinations = (
+ SortedDict()
+ ) # type: SortedDict[int, Tuple[str, List[str]]]
+
+ # (destination, key) -> EDU
+ self.keyed_edu = {} # type: Dict[Tuple[str, tuple], Edu]
- self.keyed_edu = {} # (destination, key) -> EDU
- self.keyed_edu_changed = SortedDict() # stream position -> (destination, key)
+ # stream position -> (destination, key)
+ self.keyed_edu_changed = (
+ SortedDict()
+ ) # type: SortedDict[int, Tuple[str, tuple]]
- self.edus = SortedDict() # stream position -> Edu
+ self.edus = SortedDict() # type: SortedDict[int, Edu]
+ # stream ID for the next entry into presence_changed/keyed_edu_changed/edus.
self.pos = 1
- self.pos_time = SortedDict()
+
+ # map from stream ID to the time that stream entry was generated, so that we
+ # can clear out entries after a while
+ self.pos_time = SortedDict() # type: SortedDict[int, int]
# EVERYTHING IS SAD. In particular, python only makes new scopes when
# we make a new function, so we need to make a new function so the inner
@@ -158,8 +173,10 @@ class FederationRemoteSendQueue(object):
for edu_key in self.keyed_edu_changed.values():
live_keys.add(edu_key)
- to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys]
- for edu_key in to_del:
+ keys_to_del = [
+ edu_key for edu_key in self.keyed_edu if edu_key not in live_keys
+ ]
+ for edu_key in keys_to_del:
del self.keyed_edu[edu_key]
# Delete things out of edu map
@@ -250,19 +267,23 @@ class FederationRemoteSendQueue(object):
self._clear_queue_before_pos(token)
async def get_replication_rows(
- self, from_token, to_token, limit, federation_ack=None
- ):
+ self, instance_name: str, from_token: int, to_token: int, target_row_count: int
+ ) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
"""Get rows to be sent over federation between the two tokens
Args:
- from_token (int)
- to_token(int)
- limit (int)
- federation_ack (int): Optional. The position where the worker is
- explicitly acknowledged it has handled. Allows us to drop
- data from before that point
+ instance_name: the name of the current process
+ from_token: the previous stream token: the starting point for fetching the
+ updates
+ to_token: the new stream token: the point to get updates up to
+ target_row_count: a target for the number of rows to be returned.
+
+ Returns: a triplet `(updates, new_last_token, limited)`, where:
+ * `updates` is a list of `(token, row)` entries.
+ * `new_last_token` is the new position in stream.
+ * `limited` is whether there are more updates to fetch.
"""
- # TODO: Handle limit.
+ # TODO: Handle target_row_count.
# To handle restarts where we wrap around
if from_token > self.pos:
@@ -270,12 +291,7 @@ class FederationRemoteSendQueue(object):
# list of tuple(int, BaseFederationRow), where the first is the position
# of the federation stream.
- rows = []
-
- # There should be only one reader, so lets delete everything its
- # acknowledged its seen.
- if federation_ack:
- self._clear_queue_before_pos(federation_ack)
+ rows = [] # type: List[Tuple[int, BaseFederationRow]]
# Fetch changed presence
i = self.presence_changed.bisect_right(from_token)
@@ -332,7 +348,11 @@ class FederationRemoteSendQueue(object):
# Sort rows based on pos
rows.sort()
- return [(pos, row.TypeId, row.to_data()) for pos, row in rows]
+ return (
+ [(pos, (row.TypeId, row.to_data())) for pos, row in rows],
+ to_token,
+ False,
+ )
class BaseFederationRow(object):
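The return type of get_replication_rows changes above from a bare list to a `(updates, new_last_token, limited)` triple. A hedged sketch of how a caller might page through updates under that contract (the function and parameter names are illustrative):

async def fetch_all_rows(source, instance_name, from_token, to_token, batch_size=100):
    rows = []
    limited = True
    while limited:
        updates, from_token, limited = await source.get_replication_rows(
            instance_name, from_token, to_token, batch_size
        )
        rows.extend(updates)
    return rows, from_token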
@@ -341,7 +361,7 @@ class BaseFederationRow(object):
Specifies how to identify, serialize and deserialize the different types.
"""
- TypeId = None # Unique string that ids the type. Must be overriden in sub classes.
+ TypeId = "" # Unique string that ids the type. Must be overriden in sub classes.
@staticmethod
def from_data(data):
@@ -454,10 +474,14 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
-TypeToRow = {
- Row.TypeId: Row
- for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow,)
-}
+_rowtypes = (
+ PresenceRow,
+ PresenceDestinationsRow,
+ KeyedEduRow,
+ EduRow,
+) # type: Tuple[Type[BaseFederationRow], ...]
+
+TypeToRow = {Row.TypeId: Row for Row in _rowtypes}
ParsedFederationStreamData = namedtuple(
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index a477578e..d4735769 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Dict, Hashable, Iterable, List, Optional, Set
+from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple
from six import itervalues
@@ -498,14 +498,16 @@ class FederationSender(object):
self._get_per_destination_queue(destination).attempt_new_transaction()
- def get_current_token(self) -> int:
+ @staticmethod
+ def get_current_token() -> int:
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
return 0
+ @staticmethod
async def get_replication_rows(
- self, from_token, to_token, limit, federation_ack=None
- ):
+ instance_name: str, from_token: int, to_token: int, target_row_count: int
+ ) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
- return []
+ return [], 0, False
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index e13cd20f..4e698981 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -15,11 +15,10 @@
# limitations under the License.
import datetime
import logging
-from typing import Dict, Hashable, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple
from prometheus_client import Counter
-import synapse.server
from synapse.api.errors import (
FederationDeniedError,
HttpResponseException,
@@ -34,6 +33,9 @@ from synapse.storage.presence import UserPresenceState
from synapse.types import ReadReceipt
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
+if TYPE_CHECKING:
+ import synapse.server
+
# This is defined in the Matrix spec and enforced by the receiver.
MAX_EDUS_PER_TRANSACTION = 100
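The TYPE_CHECKING guard introduced above keeps `import synapse.server` out of the runtime import graph (avoiding an import cycle) while preserving the type annotation. A generic sketch of the same pattern:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers such as mypy, never at runtime.
    import synapse.server

class SomeComponent:
    def __init__(self, hs: "synapse.server.HomeServer"):
        self.hs = hs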
@@ -78,6 +80,9 @@ class PerDestinationQueue(object):
# a list of tuples of (pending pdu, order)
self._pending_pdus = [] # type: List[Tuple[EventBase, int]]
+
+ # XXX this is never actually used: see
+ # https://github.com/matrix-org/synapse/issues/7549
self._pending_edus = [] # type: List[Edu]
# Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 3c2a02a3..a2752a54 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -13,11 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List
+from typing import TYPE_CHECKING, List
from canonicaljson import json
-import synapse.server
from synapse.api.errors import HttpResponseException
from synapse.events import EventBase
from synapse.federation.persistence import TransactionActions
@@ -31,6 +30,9 @@ from synapse.logging.opentracing import (
)
from synapse.util.metrics import measure_func
+if TYPE_CHECKING:
+ import synapse.server
+
logger = logging.getLogger(__name__)
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 4acb4fa4..8a9de913 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -19,8 +19,6 @@ import logging
from six import string_types
-from twisted.internet import defer
-
from synapse.api.errors import Codes, SynapseError
from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute
@@ -51,8 +49,7 @@ class GroupsServerWorkerHandler(object):
self.transport_client = hs.get_federation_transport_client()
self.profile_handler = hs.get_profile_handler()
- @defer.inlineCallbacks
- def check_group_is_ours(
+ async def check_group_is_ours(
self, group_id, requester_user_id, and_exists=False, and_is_admin=None
):
"""Check that the group is ours, and optionally if it exists.
@@ -68,25 +65,24 @@ class GroupsServerWorkerHandler(object):
if not self.is_mine_id(group_id):
raise SynapseError(400, "Group not on this server")
- group = yield self.store.get_group(group_id)
+ group = await self.store.get_group(group_id)
if and_exists and not group:
raise SynapseError(404, "Unknown group")
- is_user_in_group = yield self.store.is_user_in_group(
+ is_user_in_group = await self.store.is_user_in_group(
requester_user_id, group_id
)
if group and not is_user_in_group and not group["is_public"]:
raise SynapseError(404, "Unknown group")
if and_is_admin:
- is_admin = yield self.store.is_user_admin_in_group(group_id, and_is_admin)
+ is_admin = await self.store.is_user_admin_in_group(group_id, and_is_admin)
if not is_admin:
raise SynapseError(403, "User is not admin in group")
return group
- @defer.inlineCallbacks
- def get_group_summary(self, group_id, requester_user_id):
+ async def get_group_summary(self, group_id, requester_user_id):
"""Get the summary for a group as seen by requester_user_id.
The group summary consists of the profile of the room, and a curated
@@ -95,28 +91,28 @@ class GroupsServerWorkerHandler(object):
A user/room may appear in multiple roles/categories.
"""
- yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+ await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
- is_user_in_group = yield self.store.is_user_in_group(
+ is_user_in_group = await self.store.is_user_in_group(
requester_user_id, group_id
)
- profile = yield self.get_group_profile(group_id, requester_user_id)
+ profile = await self.get_group_profile(group_id, requester_user_id)
- users, roles = yield self.store.get_users_for_summary_by_role(
+ users, roles = await self.store.get_users_for_summary_by_role(
group_id, include_private=is_user_in_group
)
# TODO: Add profiles to users
- rooms, categories = yield self.store.get_rooms_for_summary_by_category(
+ rooms, categories = await self.store.get_rooms_for_summary_by_category(
group_id, include_private=is_user_in_group
)
for room_entry in rooms:
room_id = room_entry["room_id"]
- joined_users = yield self.store.get_users_in_room(room_id)
- entry = yield self.room_list_handler.generate_room_entry(
+ joined_users = await self.store.get_users_in_room(room_id)
+ entry = await self.room_list_handler.generate_room_entry(
room_id, len(joined_users), with_alias=False, allow_private=True
)
entry = dict(entry) # so we don't change whats cached
@@ -130,7 +126,7 @@ class GroupsServerWorkerHandler(object):
user_id = entry["user_id"]
if not self.is_mine_id(requester_user_id):
- attestation = yield self.store.get_remote_attestation(group_id, user_id)
+ attestation = await self.store.get_remote_attestation(group_id, user_id)
if not attestation:
continue
@@ -140,12 +136,12 @@ class GroupsServerWorkerHandler(object):
group_id, user_id
)
- user_profile = yield self.profile_handler.get_profile_from_cache(user_id)
+ user_profile = await self.profile_handler.get_profile_from_cache(user_id)
entry.update(user_profile)
users.sort(key=lambda e: e.get("order", 0))
- membership_info = yield self.store.get_users_membership_info_in_group(
+ membership_info = await self.store.get_users_membership_info_in_group(
group_id, requester_user_id
)
@@ -164,22 +160,20 @@ class GroupsServerWorkerHandler(object):
"user": membership_info,
}
- @defer.inlineCallbacks
- def get_group_categories(self, group_id, requester_user_id):
+ async def get_group_categories(self, group_id, requester_user_id):
"""Get all categories in a group (as seen by user)
"""
- yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+ await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
- categories = yield self.store.get_group_categories(group_id=group_id)
+ categories = await self.store.get_group_categories(group_id=group_id)
return {"categories": categories}
- @defer.inlineCallbacks
- def get_group_category(self, group_id, requester_user_id, category_id):
+ async def get_group_category(self, group_id, requester_user_id, category_id):
"""Get a specific category in a group (as seen by user)
"""
- yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+ await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
- res = yield self.store.get_group_category(
+ res = await self.store.get_group_category(
group_id=group_id, category_id=category_id
)
@@ -187,32 +181,29 @@ class GroupsServerWorkerHandler(object):
return res
- @defer.inlineCallbacks
- def get_group_roles(self, group_id, requester_user_id):
+ async def get_group_roles(self, group_id, requester_user_id):
"""Get all roles in a group (as seen by user)
"""
- yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+ await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
- roles = yield self.store.get_group_roles(group_id=group_id)
+ roles = await self.store.get_group_roles(group_id=group_id)
return {"roles": roles}
- @defer.inlineCallbacks
- def get_group_role(self, group_id, requester_user_id, role_id):
+ async def get_group_role(self, group_id, requester_user_id, role_id):
"""Get a specific role in a group (as seen by user)
"""
- yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+ await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
- res = yield self.store.get_group_role(group_id=group_id, role_id=role_id)
+ res = await self.store.get_group_role(group_id=group_id, role_id=role_id)
return res
- @defer.inlineCallbacks
- def get_group_profile(self, group_id, requester_user_id):
+ async def get_group_profile(self, group_id, requester_user_id):
"""Get the group profile as seen by requester_user_id
"""
- yield self.check_group_is_ours(group_id, requester_user_id)
+ await self.check_group_is_ours(group_id, requester_user_id)
- group = yield self.store.get_group(group_id)
+ group = await self.store.get_group(group_id)
if group:
cols = [
@@ -229,20 +220,19 @@ class GroupsServerWorkerHandler(object):
else:
raise SynapseError(404, "Unknown group")
- @defer.inlineCallbacks
- def get_users_in_group(self, group_id, requester_user_id):
+ async def get_users_in_group(self, group_id, requester_user_id):
"""Get the users in group as seen by requester_user_id.
The ordering is arbitrary at the moment
"""
- yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+ await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
- is_user_in_group = yield self.store.is_user_in_group(
+ is_user_in_group = await self.store.is_user_in_group(
requester_user_id, group_id
)
- user_results = yield self.store.get_users_in_group(
+ user_results = await self.store.get_users_in_group(
group_id, include_private=is_user_in_group
)
@@ -254,14 +244,14 @@ class GroupsServerWorkerHandler(object):
entry = {"user_id": g_user_id}
- profile = yield self.profile_handler.get_profile_from_cache(g_user_id)
+ profile = await self.profile_handler.get_profile_from_cache(g_user_id)
entry.update(profile)
entry["is_public"] = bool(is_public)
entry["is_privileged"] = bool(is_privileged)
if not self.is_mine_id(g_user_id):
- attestation = yield self.store.get_remote_attestation(
+ attestation = await self.store.get_remote_attestation(
group_id, g_user_id
)
if not attestation:
@@ -279,30 +269,29 @@ class GroupsServerWorkerHandler(object):
return {"chunk": chunk, "total_user_count_estimate": len(user_results)}
- @defer.inlineCallbacks
- def get_invited_users_in_group(self, group_id, requester_user_id):
+ async def get_invited_users_in_group(self, group_id, requester_user_id):
"""Get the users that have been invited to a group as seen by requester_user_id.
The ordering is arbitrary at the moment
"""
- yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+ await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
- is_user_in_group = yield self.store.is_user_in_group(
+ is_user_in_group = await self.store.is_user_in_group(
requester_user_id, group_id
)
if not is_user_in_group:
raise SynapseError(403, "User not in group")
- invited_users = yield self.store.get_invited_users_in_group(group_id)
+ invited_users = await self.store.get_invited_users_in_group(group_id)
user_profiles = []
for user_id in invited_users:
user_profile = {"user_id": user_id}
try:
- profile = yield self.profile_handler.get_profile_from_cache(user_id)
+ profile = await self.profile_handler.get_profile_from_cache(user_id)
user_profile.update(profile)
except Exception as e:
logger.warning("Error getting profile for %s: %s", user_id, e)
@@ -310,20 +299,19 @@ class GroupsServerWorkerHandler(object):
return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)}
- @defer.inlineCallbacks
- def get_rooms_in_group(self, group_id, requester_user_id):
+ async def get_rooms_in_group(self, group_id, requester_user_id):
"""Get the rooms in group as seen by requester_user_id
This returns rooms in order of decreasing number of joined users
"""
- yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+ await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
- is_user_in_group = yield self.store.is_user_in_group(
+ is_user_in_group = await self.store.is_user_in_group(
requester_user_id, group_id
)
- room_results = yield self.store.get_rooms_in_group(
+ room_results = await self.store.get_rooms_in_group(
group_id, include_private=is_user_in_group
)
@@ -331,8 +319,8 @@ class GroupsServerWorkerHandler(object):
for room_result in room_results:
room_id = room_result["room_id"]
- joined_users = yield self.store.get_users_in_room(room_id)
- entry = yield self.room_list_handler.generate_room_entry(
+ joined_users = await self.store.get_users_in_room(room_id)
+ entry = await self.room_list_handler.generate_room_entry(
room_id, len(joined_users), with_alias=False, allow_private=True
)
@@ -355,13 +343,12 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
# Ensure attestations get renewed
hs.get_groups_attestation_renewer()
- @defer.inlineCallbacks
- def update_group_summary_room(
+ async def update_group_summary_room(
self, group_id, requester_user_id, room_id, category_id, content
):
"""Add/update a room to the group summary
"""
- yield self.check_group_is_ours(
+ await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
@@ -371,7 +358,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
is_public = _parse_visibility_from_contents(content)
- yield self.store.add_room_to_summary(
+ await self.store.add_room_to_summary(
group_id=group_id,
room_id=room_id,
category_id=category_id,
@@ -381,31 +368,29 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {}
- @defer.inlineCallbacks
- def delete_group_summary_room(
+ async def delete_group_summary_room(
self, group_id, requester_user_id, room_id, category_id
):
"""Remove a room from the summary
"""
- yield self.check_group_is_ours(
+ await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
- yield self.store.remove_room_from_summary(
+ await self.store.remove_room_from_summary(
group_id=group_id, room_id=room_id, category_id=category_id
)
return {}
- @defer.inlineCallbacks
- def set_group_join_policy(self, group_id, requester_user_id, content):
+ async def set_group_join_policy(self, group_id, requester_user_id, content):
"""Sets the group join policy.
Currently supported policies are:
- "invite": an invite must be received and accepted in order to join.
- "open": anyone can join.
"""
- yield self.check_group_is_ours(
+ await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
@@ -413,22 +398,23 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
if join_policy is None:
raise SynapseError(400, "No value specified for 'm.join_policy'")
- yield self.store.set_group_join_policy(group_id, join_policy=join_policy)
+ await self.store.set_group_join_policy(group_id, join_policy=join_policy)
return {}
- @defer.inlineCallbacks
- def update_group_category(self, group_id, requester_user_id, category_id, content):
+ async def update_group_category(
+ self, group_id, requester_user_id, category_id, content
+ ):
"""Add/Update a group category
"""
- yield self.check_group_is_ours(
+ await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
is_public = _parse_visibility_from_contents(content)
profile = content.get("profile")
- yield self.store.upsert_group_category(
+ await self.store.upsert_group_category(
group_id=group_id,
category_id=category_id,
is_public=is_public,
@@ -437,25 +423,23 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {}
- @defer.inlineCallbacks
- def delete_group_category(self, group_id, requester_user_id, category_id):
+ async def delete_group_category(self, group_id, requester_user_id, category_id):
"""Delete a group category
"""
- yield self.check_group_is_ours(
+ await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
- yield self.store.remove_group_category(
+ await self.store.remove_group_category(
group_id=group_id, category_id=category_id
)
return {}
- @defer.inlineCallbacks
- def update_group_role(self, group_id, requester_user_id, role_id, content):
+ async def update_group_role(self, group_id, requester_user_id, role_id, content):
"""Add/update a role in a group
"""
- yield self.check_group_is_ours(
+ await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
@@ -463,31 +447,29 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
profile = content.get("profile")
- yield self.store.upsert_group_role(
+ await self.store.upsert_group_role(
group_id=group_id, role_id=role_id, is_public=is_public, profile=profile
)
return {}
- @defer.inlineCallbacks
- def delete_group_role(self, group_id, requester_user_id, role_id):
+ async def delete_group_role(self, group_id, requester_user_id, role_id):
"""Remove role from group
"""
- yield self.check_group_is_ours(
+ await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
- yield self.store.remove_group_role(group_id=group_id, role_id=role_id)
+ await self.store.remove_group_role(group_id=group_id, role_id=role_id)
return {}
- @defer.inlineCallbacks
- def update_group_summary_user(
+ async def update_group_summary_user(
self, group_id, requester_user_id, user_id, role_id, content
):
"""Add/update a users entry in the group summary
"""
- yield self.check_group_is_ours(
+ await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
@@ -495,7 +477,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
is_public = _parse_visibility_from_contents(content)
- yield self.store.add_user_to_summary(
+ await self.store.add_user_to_summary(
group_id=group_id,
user_id=user_id,
role_id=role_id,
@@ -505,25 +487,25 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {}
- @defer.inlineCallbacks
- def delete_group_summary_user(self, group_id, requester_user_id, user_id, role_id):
+ async def delete_group_summary_user(
+ self, group_id, requester_user_id, user_id, role_id
+ ):
"""Remove a user from the group summary
"""
- yield self.check_group_is_ours(
+ await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
- yield self.store.remove_user_from_summary(
+ await self.store.remove_user_from_summary(
group_id=group_id, user_id=user_id, role_id=role_id
)
return {}
- @defer.inlineCallbacks
- def update_group_profile(self, group_id, requester_user_id, content):
+ async def update_group_profile(self, group_id, requester_user_id, content):
"""Update the group profile
"""
- yield self.check_group_is_ours(
+ await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
@@ -535,40 +517,38 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
raise SynapseError(400, "%r value is not a string" % (keyname,))
profile[keyname] = value
- yield self.store.update_group_profile(group_id, profile)
+ await self.store.update_group_profile(group_id, profile)
- @defer.inlineCallbacks
- def add_room_to_group(self, group_id, requester_user_id, room_id, content):
+ async def add_room_to_group(self, group_id, requester_user_id, room_id, content):
"""Add room to group
"""
RoomID.from_string(room_id) # Ensure valid room id
- yield self.check_group_is_ours(
+ await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
is_public = _parse_visibility_from_contents(content)
- yield self.store.add_room_to_group(group_id, room_id, is_public=is_public)
+ await self.store.add_room_to_group(group_id, room_id, is_public=is_public)
return {}
- @defer.inlineCallbacks
- def update_room_in_group(
+ async def update_room_in_group(
self, group_id, requester_user_id, room_id, config_key, content
):
"""Update room in group
"""
RoomID.from_string(room_id) # Ensure valid room id
- yield self.check_group_is_ours(
+ await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
if config_key == "m.visibility":
is_public = _parse_visibility_dict(content)
- yield self.store.update_room_in_group_visibility(
+ await self.store.update_room_in_group_visibility(
group_id, room_id, is_public=is_public
)
else:
@@ -576,36 +556,34 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {}
- @defer.inlineCallbacks
- def remove_room_from_group(self, group_id, requester_user_id, room_id):
+ async def remove_room_from_group(self, group_id, requester_user_id, room_id):
"""Remove room from group
"""
- yield self.check_group_is_ours(
+ await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
- yield self.store.remove_room_from_group(group_id, room_id)
+ await self.store.remove_room_from_group(group_id, room_id)
return {}
- @defer.inlineCallbacks
- def invite_to_group(self, group_id, user_id, requester_user_id, content):
+ async def invite_to_group(self, group_id, user_id, requester_user_id, content):
"""Invite user to group
"""
- group = yield self.check_group_is_ours(
+ group = await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
# TODO: Check if user knocked
- invited_users = yield self.store.get_invited_users_in_group(group_id)
+ invited_users = await self.store.get_invited_users_in_group(group_id)
if user_id in invited_users:
raise SynapseError(
400, "User already invited to group", errcode=Codes.BAD_STATE
)
- user_results = yield self.store.get_users_in_group(
+ user_results = await self.store.get_users_in_group(
group_id, include_private=True
)
if user_id in (user_result["user_id"] for user_result in user_results):
@@ -618,18 +596,18 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
if self.hs.is_mine_id(user_id):
groups_local = self.hs.get_groups_local_handler()
- res = yield groups_local.on_invite(group_id, user_id, content)
+ res = await groups_local.on_invite(group_id, user_id, content)
local_attestation = None
else:
local_attestation = self.attestations.create_attestation(group_id, user_id)
content.update({"attestation": local_attestation})
- res = yield self.transport_client.invite_to_group_notification(
+ res = await self.transport_client.invite_to_group_notification(
get_domain_from_id(user_id), group_id, user_id, content
)
user_profile = res.get("user_profile", {})
- yield self.store.add_remote_profile_cache(
+ await self.store.add_remote_profile_cache(
user_id,
displayname=user_profile.get("displayname"),
avatar_url=user_profile.get("avatar_url"),
@@ -639,13 +617,13 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
if not self.hs.is_mine_id(user_id):
remote_attestation = res["attestation"]
- yield self.attestations.verify_attestation(
+ await self.attestations.verify_attestation(
remote_attestation, user_id=user_id, group_id=group_id
)
else:
remote_attestation = None
- yield self.store.add_user_to_group(
+ await self.store.add_user_to_group(
group_id,
user_id,
is_admin=False,
@@ -654,15 +632,14 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
remote_attestation=remote_attestation,
)
elif res["state"] == "invite":
- yield self.store.add_group_invite(group_id, user_id)
+ await self.store.add_group_invite(group_id, user_id)
return {"state": "invite"}
elif res["state"] == "reject":
return {"state": "reject"}
else:
raise SynapseError(502, "Unknown state returned by HS")
- @defer.inlineCallbacks
- def _add_user(self, group_id, user_id, content):
+ async def _add_user(self, group_id, user_id, content):
"""Add a user to a group based on a content dict.
See accept_invite, join_group.
@@ -672,7 +649,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
remote_attestation = content["attestation"]
- yield self.attestations.verify_attestation(
+ await self.attestations.verify_attestation(
remote_attestation, user_id=user_id, group_id=group_id
)
else:
@@ -681,7 +658,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
is_public = _parse_visibility_from_contents(content)
- yield self.store.add_user_to_group(
+ await self.store.add_user_to_group(
group_id,
user_id,
is_admin=False,
@@ -692,59 +669,55 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return local_attestation
- @defer.inlineCallbacks
- def accept_invite(self, group_id, requester_user_id, content):
+ async def accept_invite(self, group_id, requester_user_id, content):
"""User tries to accept an invite to the group.
This is different from them asking to join, and so should error if no
invite exists (and they're not a member of the group)
"""
- yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+ await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
- is_invited = yield self.store.is_user_invited_to_local_group(
+ is_invited = await self.store.is_user_invited_to_local_group(
group_id, requester_user_id
)
if not is_invited:
raise SynapseError(403, "User not invited to group")
- local_attestation = yield self._add_user(group_id, requester_user_id, content)
+ local_attestation = await self._add_user(group_id, requester_user_id, content)
return {"state": "join", "attestation": local_attestation}
- @defer.inlineCallbacks
- def join_group(self, group_id, requester_user_id, content):
+ async def join_group(self, group_id, requester_user_id, content):
"""User tries to join the group.
This will error if the group requires an invite/knock to join
"""
- group_info = yield self.check_group_is_ours(
+ group_info = await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True
)
if group_info["join_policy"] != "open":
raise SynapseError(403, "Group is not publicly joinable")
- local_attestation = yield self._add_user(group_id, requester_user_id, content)
+ local_attestation = await self._add_user(group_id, requester_user_id, content)
return {"state": "join", "attestation": local_attestation}
- @defer.inlineCallbacks
- def knock(self, group_id, requester_user_id, content):
+ async def knock(self, group_id, requester_user_id, content):
"""A user requests becoming a member of the group
"""
- yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+ await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
raise NotImplementedError()
- @defer.inlineCallbacks
- def accept_knock(self, group_id, requester_user_id, content):
+ async def accept_knock(self, group_id, requester_user_id, content):
"""Accept a users knock to the room.
Errors if the user hasn't knocked, rather than inviting them.
"""
- yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+ await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
raise NotImplementedError()
@@ -872,8 +845,6 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
group_id (str)
request_user_id (str)
- Returns:
- Deferred
"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
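The groups_server.py changes above are a mechanical conversion from Twisted's @defer.inlineCallbacks generators to native coroutines. A side-by-side sketch of the pattern (the function and store names are illustrative):

from twisted.internet import defer

# Before: a generator driven by inlineCallbacks.
@defer.inlineCallbacks
def get_group_old(store, group_id):
    group = yield store.get_group(group_id)
    return group

# After: a native coroutine awaited by the caller.
async def get_group_new(store, group_id):
    group = await store.get_group(group_id)
    return group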
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 3b781d98..61dc4bea 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
import synapse.types
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import LimitExceededError
+from synapse.api.ratelimiting import Ratelimiter
from synapse.types import UserID
logger = logging.getLogger(__name__)
@@ -44,11 +44,26 @@ class BaseHandler(object):
self.notifier = hs.get_notifier()
self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor()
- self.ratelimiter = hs.get_ratelimiter()
- self.admin_redaction_ratelimiter = hs.get_admin_redaction_ratelimiter()
self.clock = hs.get_clock()
self.hs = hs
+ # The rate_hz and burst_count are overridden on a per-user basis
+ self.request_ratelimiter = Ratelimiter(
+ clock=self.clock, rate_hz=0, burst_count=0
+ )
+ self._rc_message = self.hs.config.rc_message
+
+ # Check whether ratelimiting room admin message redaction is enabled
+ # by the presence of rate limits in the config
+ if self.hs.config.rc_admin_redaction:
+ self.admin_redaction_ratelimiter = Ratelimiter(
+ clock=self.clock,
+ rate_hz=self.hs.config.rc_admin_redaction.per_second,
+ burst_count=self.hs.config.rc_admin_redaction.burst_count,
+ )
+ else:
+ self.admin_redaction_ratelimiter = None
+
self.server_name = hs.hostname
self.event_builder_factory = hs.get_event_builder_factory()
@@ -70,7 +85,6 @@ class BaseHandler(object):
Raises:
LimitExceededError if the request should be ratelimited
"""
- time_now = self.clock.time()
user_id = requester.user.to_string()
# The AS user itself is never rate limited.
@@ -83,48 +97,32 @@ class BaseHandler(object):
if requester.app_service and not requester.app_service.is_rate_limited():
return
+ messages_per_second = self._rc_message.per_second
+ burst_count = self._rc_message.burst_count
+
# Check if there is a per user override in the DB.
override = yield self.store.get_ratelimit_for_user(user_id)
if override:
- # If overriden with a null Hz then ratelimiting has been entirely
+ # If overridden with a null Hz then ratelimiting has been entirely
# disabled for the user
if not override.messages_per_second:
return
messages_per_second = override.messages_per_second
burst_count = override.burst_count
+
+ if is_admin_redaction and self.admin_redaction_ratelimiter:
+ # If we have separate config for admin redactions, use a separate
+ # ratelimiter so as not to have user_ids clash
+ self.admin_redaction_ratelimiter.ratelimit(user_id, update=update)
else:
- # We default to different values if this is an admin redaction and
- # the config is set
- if is_admin_redaction and self.hs.config.rc_admin_redaction:
- messages_per_second = self.hs.config.rc_admin_redaction.per_second
- burst_count = self.hs.config.rc_admin_redaction.burst_count
- else:
- messages_per_second = self.hs.config.rc_message.per_second
- burst_count = self.hs.config.rc_message.burst_count
-
- if is_admin_redaction and self.hs.config.rc_admin_redaction:
- # If we have separate config for admin redactions we use a separate
- # ratelimiter
- allowed, time_allowed = self.admin_redaction_ratelimiter.can_do_action(
- user_id,
- time_now,
- rate_hz=messages_per_second,
- burst_count=burst_count,
- update=update,
- )
- else:
- allowed, time_allowed = self.ratelimiter.can_do_action(
+ # Override rate and burst count per-user
+ self.request_ratelimiter.ratelimit(
user_id,
- time_now,
rate_hz=messages_per_second,
burst_count=burst_count,
update=update,
)
- if not allowed:
- raise LimitExceededError(
- retry_after_ms=int(1000 * (time_allowed - time_now))
- )
async def maybe_kick_guest_users(self, event, context=None):
# Technically this function invalidates current_state by changing it.
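Above, the handler-level ratelimiters move from hs.get_ratelimiter() to Ratelimiter instances constructed with their default rate and burst, with per-user overrides passed at call time. A hedged usage sketch (the stand-in clock and user IDs are illustrative; Synapse's Clock is assumed to expose time() in seconds):

import time
from synapse.api.ratelimiting import Ratelimiter

class _Clock:
    def time(self):
        return time.time()

limiter = Ratelimiter(clock=_Clock(), rate_hz=0.2, burst_count=5)

# Uses the defaults given at construction; raises LimitExceededError once the
# burst of 5 actions is exhausted.
limiter.ratelimit("@alice:example.com")

# Per-key override, e.g. taken from a per-user database entry.
limiter.ratelimit("@bot:example.com", rate_hz=10, burst_count=100)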
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 5c20e291..119678e6 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -80,7 +80,9 @@ class AuthHandler(BaseHandler):
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled
- self._sso_enabled = hs.config.saml2_enabled or hs.config.cas_enabled
+ self._sso_enabled = (
+ hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
+ )
# we keep this as a list despite the O(N^2) implication so that we can
# keep PASSWORD first and avoid confusing clients which pick the first
@@ -106,7 +108,11 @@ class AuthHandler(BaseHandler):
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
- self._failed_uia_attempts_ratelimiter = Ratelimiter()
+ self._failed_uia_attempts_ratelimiter = Ratelimiter(
+ clock=self.clock,
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ )
self._clock = self.hs.get_clock()
@@ -126,13 +132,13 @@ class AuthHandler(BaseHandler):
# It notifies the user they are about to give access to their matrix account
# to the client.
self._sso_redirect_confirm_template = load_jinja2_templates(
- hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
+ hs.config.sso_template_dir, ["sso_redirect_confirm.html"],
)[0]
# The following template is shown during user interactive authentication
# in the fallback auth scenario. It notifies the user that they are
# authenticating for an operation to occur on their account.
self._sso_auth_confirm_template = load_jinja2_templates(
- hs.config.sso_redirect_confirm_template_dir, ["sso_auth_confirm.html"],
+ hs.config.sso_template_dir, ["sso_auth_confirm.html"],
)[0]
# The following template is shown after a successful user interactive
# authentication session. It tells the user they can close the window.
@@ -194,13 +200,7 @@ class AuthHandler(BaseHandler):
user_id = requester.user.to_string()
# Check if we should be ratelimited due to too many previous failed attempts
- self._failed_uia_attempts_ratelimiter.ratelimit(
- user_id,
- time_now_s=self._clock.time(),
- rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
- burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
- update=False,
- )
+ self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
# build a list of supported flows
flows = [[login_type] for login_type in self._supported_ui_auth_types]
@@ -210,14 +210,8 @@ class AuthHandler(BaseHandler):
flows, request, request_body, clientip, description
)
except LoginError:
- # Update the ratelimite to say we failed (`can_do_action` doesn't raise).
- self._failed_uia_attempts_ratelimiter.can_do_action(
- user_id,
- time_now_s=self._clock.time(),
- rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
- burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
- update=True,
- )
+ # Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
+ self._failed_uia_attempts_ratelimiter.can_do_action(user_id)
raise
# find the completed login type
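The UIA code above now relies on the ratelimiter's own configuration: ratelimit(..., update=False) checks the limit without consuming it, while can_do_action(...) records a failed attempt without raising. A hedged sketch of that check-then-record pattern (the helper function is illustrative):

def guarded_auth(limiter, user_id, do_auth):
    # Raises LimitExceededError if the user is already over the limit;
    # update=False means a successful check is not counted as an action.
    limiter.ratelimit(user_id, update=False)
    try:
        return do_auth()
    except Exception:
        # Record the failed attempt; can_do_action never raises.
        limiter.can_do_action(user_id)
        raise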
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 9bd941b5..230d1702 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Any, Dict, Optional
from six import iteritems, itervalues
@@ -29,7 +30,12 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.types import RoomStreamToken, get_domain_from_id
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import (
+ RoomStreamToken,
+ get_domain_from_id,
+ get_verify_key_from_cross_signing_key,
+)
from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
@@ -535,6 +541,15 @@ class DeviceListUpdater(object):
iterable=True,
)
+ # Attempt to resync out of sync device lists every 30s.
+ self._resync_retry_in_progress = False
+ self.clock.looping_call(
+ run_as_background_process,
+ 30 * 1000,
+ func=self._maybe_retry_device_resync,
+ desc="_maybe_retry_device_resync",
+ )
+
@trace
@defer.inlineCallbacks
def incoming_device_list_update(self, origin, edu_content):
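The wiring above schedules _maybe_retry_device_resync every 30 seconds via the clock's looping call, wrapped in run_as_background_process, with a flag preventing overlapping runs. A self-contained asyncio sketch of the same guard-and-retry shape (illustrative only; the patch itself uses Twisted's clock rather than asyncio):

import asyncio

class ResyncScheduler:
    def __init__(self, get_pending, resync_one, interval_s=30):
        self._in_progress = False
        self._get_pending = get_pending     # coroutine returning user IDs to resync
        self._resync_one = resync_one       # coroutine resyncing one user
        self._interval_s = interval_s

    async def run(self):
        while True:
            await asyncio.sleep(self._interval_s)
            await self._maybe_retry()

    async def _maybe_retry(self):
        if self._in_progress:
            # A previous retry pass is still running; skip this tick.
            return
        self._in_progress = True
        try:
            for user_id in await self._get_pending():
                try:
                    await self._resync_one(user_id)
                except Exception:
                    # Log-and-continue in the real code; keep iterating here too.
                    continue
        finally:
            self._in_progress = False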
@@ -679,25 +694,83 @@ class DeviceListUpdater(object):
return False
@defer.inlineCallbacks
- def user_device_resync(self, user_id):
+ def _maybe_retry_device_resync(self):
+ """Retry to resync device lists that are out of sync, except if another retry is
+ in progress.
+ """
+ if self._resync_retry_in_progress:
+ return
+
+ try:
+ # Prevent another call of this function to retry resyncing device lists so
+ # we don't send too many requests.
+ self._resync_retry_in_progress = True
+ # Get all of the users that need resyncing.
+ need_resync = yield self.store.get_user_ids_requiring_device_list_resync()
+ # Iterate over the set of user IDs.
+ for user_id in need_resync:
+ try:
+ # Try to resync the current user's devices list.
+ result = yield self.user_device_resync(
+ user_id=user_id, mark_failed_as_stale=False,
+ )
+
+ # user_device_resync only returns a result if it managed to
+ # successfully resync and update the database. Updating the table
+ # of users requiring resync isn't necessary here as
+ # user_device_resync already does it (through
+ # self.store.update_remote_device_list_cache).
+ if result:
+ logger.debug(
+ "Successfully resynced the device list for %s", user_id,
+ )
+ except Exception as e:
+ # If there was an issue resyncing this user, e.g. if the remote
+ # server sent a malformed result, just log the error instead of
+ # aborting all the subsequent resyncs.
+ logger.debug(
+ "Could not resync the device list for %s: %s", user_id, e,
+ )
+ finally:
+ # Allow future calls to retry resyncing out of sync device lists.
+ self._resync_retry_in_progress = False
+
+ @defer.inlineCallbacks
+ def user_device_resync(self, user_id, mark_failed_as_stale=True):
"""Fetches all devices for a user and updates the device cache with them.
Args:
user_id (str): The user's id whose device_list will be updated.
+ mark_failed_as_stale (bool): Whether to mark the user's device list as stale
+ if the attempt to resync failed.
Returns:
Deferred[dict]: a dict with device info as under the "devices" in the result of this
request:
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
"""
+ logger.debug("Attempting to resync the device list for %s", user_id)
log_kv({"message": "Doing resync to update device list."})
# Fetch all devices for the user.
origin = get_domain_from_id(user_id)
try:
result = yield self.federation.query_user_devices(origin, user_id)
- except (NotRetryingDestination, RequestSendFailed, HttpResponseException):
- # TODO: Remember that we are now out of sync and try again
- # later
- logger.warning("Failed to handle device list update for %s", user_id)
+ except NotRetryingDestination:
+ if mark_failed_as_stale:
+ # Mark the remote user's device list as stale so we know we need to retry
+ # it later.
+ yield self.store.mark_remote_user_device_cache_as_stale(user_id)
+
+ return
+ except (RequestSendFailed, HttpResponseException) as e:
+ logger.warning(
+ "Failed to handle device list update for %s: %s", user_id, e,
+ )
+
+ if mark_failed_as_stale:
+ # Mark the remote user's device list as stale so we know we need to retry
+ # it later.
+ yield self.store.mark_remote_user_device_cache_as_stale(user_id)
+
# We abort on exceptions rather than accepting the update
# as otherwise synapse will 'forget' that its device list
# is out of date. If we bail then we will retry the resync
@@ -711,18 +784,29 @@ class DeviceListUpdater(object):
logger.info(e)
return
except Exception as e:
- # TODO: Remember that we are now out of sync and try again
- # later
set_tag("error", True)
log_kv(
{"message": "Exception raised by federation request", "exception": e}
)
logger.exception("Failed to handle device list update for %s", user_id)
+
+ if mark_failed_as_stale:
+ # Mark the remote user's device list as stale so we know we need to retry
+ # it later.
+ yield self.store.mark_remote_user_device_cache_as_stale(user_id)
+
return
log_kv({"result": result})
stream_id = result["stream_id"]
devices = result["devices"]
+ # Get the master key and the self-signing key for this user if provided in the
+ # response (None if not in the response).
+ # The response will not contain the user signing key, as this key is only used by
+ # its owner, thus it doesn't make sense to send it over federation.
+ master_key = result.get("master_key")
+ self_signing_key = result.get("self_signing_key")
+
# If the remote server has more than ~1000 devices for this user
# we assume that something is going horribly wrong (e.g. a bot
# that logs in and creates a new device every time it tries to
@@ -752,6 +836,13 @@ class DeviceListUpdater(object):
yield self.store.update_remote_device_list_cache(user_id, devices, stream_id)
device_ids = [device["device_id"] for device in devices]
+
+ # Handle cross-signing keys.
+ cross_signing_device_ids = yield self.process_cross_signing_key_update(
+ user_id, master_key, self_signing_key,
+ )
+ device_ids = device_ids + cross_signing_device_ids
+
yield self.device_handler.notify_device_update(user_id, device_ids)
# We clobber the seen updates since we've re-synced from a given
@@ -759,3 +850,40 @@ class DeviceListUpdater(object):
self._seen_updates[user_id] = {stream_id}
defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def process_cross_signing_key_update(
+ self,
+ user_id: str,
+ master_key: Optional[Dict[str, Any]],
+ self_signing_key: Optional[Dict[str, Any]],
+ ) -> list:
+ """Process the given new master and self-signing key for the given remote user.
+
+ Args:
+ user_id: The ID of the user these keys are for.
+ master_key: The dict of the cross-signing master key as returned by the
+ remote server.
+ self_signing_key: The dict of the cross-signing self-signing key as returned
+ by the remote server.
+
+ Return:
+ The device IDs for the given keys.
+ """
+ device_ids = []
+
+ if master_key:
+ yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
+ _, verify_key = get_verify_key_from_cross_signing_key(master_key)
+ # verify_key is a VerifyKey from signedjson, which uses
+ # .version to denote the portion of the key ID after the
+ # algorithm and colon, which is the device ID
+ device_ids.append(verify_key.version)
+ if self_signing_key:
+ yield self.store.set_e2e_cross_signing_key(
+ user_id, "self_signing", self_signing_key
+ )
+ _, verify_key = get_verify_key_from_cross_signing_key(self_signing_key)
+ device_ids.append(verify_key.version)
+
+ return device_ids
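For context on the `verify_key.version` comments above: a cross-signing key dict carries its public key under an ID of the form `ed25519:<unpadded-base64-key>`, and the part after the colon is what gets announced as a device ID. A small illustrative sketch (the key material below is made up):

    # Roughly the shape of a cross-signing master key as returned over
    # federation; the values are fabricated for illustration.
    master_key = {
        "user_id": "@alice:example.org",
        "usage": ["master"],
        "keys": {"ed25519:base64masterpubkey": "base64masterpubkey"},
    }

    # get_verify_key_from_cross_signing_key returns a signedjson VerifyKey whose
    # .version holds the text after the algorithm and colon; done by hand:
    key_id = next(iter(master_key["keys"]))
    algorithm, device_id = key_id.split(":", 1)
    assert algorithm == "ed25519"
    print(device_id)  # this is what process_cross_signing_key_update appends to device_ids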
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 8f1bc032..774a2526 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -1291,6 +1291,7 @@ class SigningKeyEduUpdater(object):
"""
device_handler = self.e2e_keys_handler.device_handler
+ device_list_updater = device_handler.device_list_updater
with (yield self._remote_edu_linearizer.queue(user_id)):
pending_updates = self._pending_updates.pop(user_id, [])
@@ -1303,22 +1304,9 @@ class SigningKeyEduUpdater(object):
logger.info("pending updates: %r", pending_updates)
for master_key, self_signing_key in pending_updates:
- if master_key:
- yield self.store.set_e2e_cross_signing_key(
- user_id, "master", master_key
- )
- _, verify_key = get_verify_key_from_cross_signing_key(master_key)
- # verify_key is a VerifyKey from signedjson, which uses
- # .version to denote the portion of the key ID after the
- # algorithm and colon, which is the device ID
- device_ids.append(verify_key.version)
- if self_signing_key:
- yield self.store.set_e2e_cross_signing_key(
- user_id, "self_signing", self_signing_key
- )
- _, verify_key = get_verify_key_from_cross_signing_key(
- self_signing_key
- )
- device_ids.append(verify_key.version)
+ new_device_ids = yield device_list_updater.process_cross_signing_key_update(
+ user_id, master_key, self_signing_key,
+ )
+ device_ids = device_ids + new_device_ids
yield device_handler.notify_device_update(user_id, device_ids)
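The EDU updater above now delegates key storage to `DeviceListUpdater.process_cross_signing_key_update` and accumulates the resulting device IDs so that a single `notify_device_update` covers the whole batch. A condensed sketch of that loop (the objects passed in are stand-ins):

    # Sketch: collect device IDs from every pending (master, self-signing) pair,
    # then notify once for the batch.
    async def apply_pending_updates(device_list_updater, device_handler, user_id, pending_updates):
        device_ids = []
        for master_key, self_signing_key in pending_updates:
            new_ids = await device_list_updater.process_cross_signing_key_update(
                user_id, master_key, self_signing_key
            )
            device_ids.extend(new_ids)

        await device_handler.notify_device_update(user_id, device_ids)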
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 4e5c6455..3e60774b 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -40,6 +40,7 @@ from synapse.api.errors import (
Codes,
FederationDeniedError,
FederationError,
+ HttpResponseException,
RequestSendFailed,
SynapseError,
)
@@ -125,10 +126,10 @@ class FederationHandler(BaseHandler):
self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config
self.http_client = hs.get_simple_http_client()
+ self._instance_name = hs.get_instance_name()
+ self._replication = hs.get_replication_data_handler()
- self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client(
- hs
- )
+ self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client(
hs
)
@@ -500,7 +501,7 @@ class FederationHandler(BaseHandler):
min_depth=min_depth,
timeout=60000,
)
- except RequestSendFailed as e:
+ except (RequestSendFailed, HttpResponseException, NotRetryingDestination) as e:
# We failed to get the missing events, but since we need to handle
# the case of `get_missing_events` not returning the necessary
# events anyway, it is safe to simply log the error and continue.
@@ -1038,6 +1039,12 @@ class FederationHandler(BaseHandler):
except SynapseError as e:
logger.info("Failed to backfill from %s because %s", dom, e)
continue
+ except HttpResponseException as e:
+ if 400 <= e.code < 500:
+ raise e.to_synapse_error()
+
+ logger.info("Failed to backfill from %s because %s", dom, e)
+ continue
except CodeMessageException as e:
if 400 <= e.code < 500:
raise
@@ -1214,7 +1221,7 @@ class FederationHandler(BaseHandler):
async def do_invite_join(
self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict
- ) -> None:
+ ) -> Tuple[str, int]:
""" Attempts to join the `joinee` to the room `room_id` via the
servers contained in `target_hosts`.
@@ -1235,6 +1242,10 @@ class FederationHandler(BaseHandler):
content: The event content to use for the join event.
"""
+ # TODO: We should be able to call this on workers, but the upgrading of
+ # room stuff after join currently doesn't work on workers.
+ assert self.config.worker.worker_app is None
+
logger.debug("Joining %s to %s", joinee, room_id)
origin, event, room_version_obj = await self._make_and_verify_event(
@@ -1297,15 +1308,23 @@ class FederationHandler(BaseHandler):
room_id=room_id, room_version=room_version_obj,
)
- await self._persist_auth_tree(
+ max_stream_id = await self._persist_auth_tree(
origin, auth_chain, state, event, room_version_obj
)
+ # We wait here until this instance has seen the events come down
+ # replication (if we're using replication) as the below uses caches.
+ #
+ # TODO: Currently the events stream is written to from master
+ await self._replication.wait_for_stream_position(
+ self.config.worker.writers.events, "events", max_stream_id
+ )
+
# Check whether this room is the result of an upgrade of a room we already know
# about. If so, migrate over user information
predecessor = await self.store.get_room_predecessor(room_id)
if not predecessor or not isinstance(predecessor.get("room_id"), str):
- return
+ return event.event_id, max_stream_id
old_room_id = predecessor["room_id"]
logger.debug(
"Found predecessor for %s during remote join: %s", room_id, old_room_id
@@ -1318,6 +1337,7 @@ class FederationHandler(BaseHandler):
)
logger.debug("Finished joining %s to %s", joinee, room_id)
+ return event.event_id, max_stream_id
finally:
room_queue = self.room_queues[room_id]
del self.room_queues[room_id]
@@ -1547,7 +1567,7 @@ class FederationHandler(BaseHandler):
async def do_remotely_reject_invite(
self, target_hosts: Iterable[str], room_id: str, user_id: str, content: JsonDict
- ) -> EventBase:
+ ) -> Tuple[EventBase, int]:
origin, event, room_version = await self._make_and_verify_event(
target_hosts, room_id, user_id, "leave", content=content
)
@@ -1567,9 +1587,9 @@ class FederationHandler(BaseHandler):
await self.federation_client.send_leave(target_hosts, event)
context = await self.state_handler.compute_event_context(event)
- await self.persist_events_and_notify([(event, context)])
+ stream_id = await self.persist_events_and_notify([(event, context)])
- return event
+ return event, stream_id
async def _make_and_verify_event(
self,
@@ -1881,7 +1901,7 @@ class FederationHandler(BaseHandler):
state: List[EventBase],
event: EventBase,
room_version: RoomVersion,
- ) -> None:
+ ) -> int:
"""Checks the auth chain is valid (and passes auth checks) for the
state and event. Then persists the auth chain and state atomically.
Persists the event separately. Notifies about the persisted events
@@ -1975,7 +1995,7 @@ class FederationHandler(BaseHandler):
event, old_state=state
)
- await self.persist_events_and_notify([(event, new_event_context)])
+ return await self.persist_events_and_notify([(event, new_event_context)])
async def _prep_event(
self,
@@ -2681,8 +2701,7 @@ class FederationHandler(BaseHandler):
member_handler = self.hs.get_room_member_handler()
await member_handler.send_membership_event(None, event, context)
- @defer.inlineCallbacks
- def add_display_name_to_third_party_invite(
+ async def add_display_name_to_third_party_invite(
self, room_version, event_dict, event, context
):
key = (
@@ -2690,10 +2709,10 @@ class FederationHandler(BaseHandler):
event.content["third_party_invite"]["signed"]["token"],
)
original_invite = None
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = await context.get_prev_state_ids()
original_invite_id = prev_state_ids.get(key)
if original_invite_id:
- original_invite = yield self.store.get_event(
+ original_invite = await self.store.get_event(
original_invite_id, allow_none=True
)
if original_invite:
@@ -2714,14 +2733,13 @@ class FederationHandler(BaseHandler):
builder = self.event_builder_factory.new(room_version, event_dict)
EventValidator().validate_builder(builder)
- event, context = yield self.event_creation_handler.create_new_client_event(
+ event, context = await self.event_creation_handler.create_new_client_event(
builder=builder
)
EventValidator().validate_new(event, self.config)
return (event, context)
- @defer.inlineCallbacks
- def _check_signature(self, event, context):
+ async def _check_signature(self, event, context):
"""
Checks that the signature in the event is consistent with its invite.
@@ -2738,12 +2756,12 @@ class FederationHandler(BaseHandler):
signed = event.content["third_party_invite"]["signed"]
token = signed["token"]
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = await context.get_prev_state_ids()
invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
invite_event = None
if invite_event_id:
- invite_event = yield self.store.get_event(invite_event_id, allow_none=True)
+ invite_event = await self.store.get_event(invite_event_id, allow_none=True)
if not invite_event:
raise AuthError(403, "Could not find invite")
@@ -2792,7 +2810,7 @@ class FederationHandler(BaseHandler):
raise
try:
if "key_validity_url" in public_key_object:
- yield self._check_key_revocation(
+ await self._check_key_revocation(
public_key, public_key_object["key_validity_url"]
)
except Exception:
@@ -2806,8 +2824,7 @@ class FederationHandler(BaseHandler):
last_exception = e
raise last_exception
- @defer.inlineCallbacks
- def _check_key_revocation(self, public_key, url):
+ async def _check_key_revocation(self, public_key, url):
"""
Checks whether public_key has been revoked.
@@ -2821,7 +2838,7 @@ class FederationHandler(BaseHandler):
for revocation.
"""
try:
- response = yield self.http_client.get_json(url, {"public_key": public_key})
+ response = await self.http_client.get_json(url, {"public_key": public_key})
except Exception:
raise SynapseError(502, "Third party certificate could not be checked")
if "valid" not in response or not response["valid"]:
@@ -2831,7 +2848,7 @@ class FederationHandler(BaseHandler):
self,
event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
backfilled: bool = False,
- ) -> None:
+ ) -> int:
"""Persists events and tells the notifier/pushers about them, if
necessary.
@@ -2840,12 +2857,14 @@ class FederationHandler(BaseHandler):
backfilled: Whether these events are a result of
backfilling or not
"""
- if self.config.worker_app:
- await self._send_events_to_master(
+ if self.config.worker.writers.events != self._instance_name:
+ result = await self._send_events(
+ instance_name=self.config.worker.writers.events,
store=self.store,
event_and_contexts=event_and_contexts,
backfilled=backfilled,
)
+ return result["max_stream_id"]
else:
max_stream_id = await self.storage.persistence.persist_events(
event_and_contexts, backfilled=backfilled
@@ -2860,6 +2879,8 @@ class FederationHandler(BaseHandler):
for event, _ in event_and_contexts:
await self._notify_persisted_event(event, max_stream_id)
+ return max_stream_id
+
async def _notify_persisted_event(
self, event: EventBase, max_stream_id: int
) -> None:
@@ -2916,8 +2937,7 @@ class FederationHandler(BaseHandler):
else:
user_joined_room(self.distributor, user, room_id)
- @defer.inlineCallbacks
- def get_room_complexity(self, remote_room_hosts, room_id):
+ async def get_room_complexity(self, remote_room_hosts, room_id):
"""
Fetch the complexity of a remote room over federation.
@@ -2931,12 +2951,12 @@ class FederationHandler(BaseHandler):
"""
for host in remote_room_hosts:
- res = yield self.federation_client.get_room_complexity(host, room_id)
+ res = await self.federation_client.get_room_complexity(host, room_id)
# We got a result, return it.
if res:
- defer.returnValue(res)
+ return res
# We fell off the bottom, couldn't get the complexity from anyone. Oh
# well.
- defer.returnValue(None)
+ return None
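`persist_events_and_notify` now returns the maximum stream ordering it persisted (locally or via the event-writer worker), and `do_invite_join` threads that value back to its caller. That lets an instance which did not write the events wait for them to arrive over replication before relying on caches. A rough sketch of the calling pattern (`replication` and `writer_instance` stand in for `hs.get_replication_data_handler()` and `config.worker.writers.events`):

    async def join_and_wait(handler, replication, writer_instance, hosts, room_id, user_id):
        # do_invite_join returns (event_id, stream_id) after this change.
        event_id, stream_id = await handler.do_invite_join(hosts, room_id, user_id, {})

        # Block until this instance has seen the "events" stream up to stream_id,
        # so subsequent cache reads observe the join.
        await replication.wait_for_stream_position(writer_instance, "events", stream_id)
        return event_id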
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index ca5c8381..ebe8d25b 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -18,8 +18,6 @@ import logging
from six import iteritems
-from twisted.internet import defer
-
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.types import get_domain_from_id
@@ -92,19 +90,18 @@ class GroupsLocalWorkerHandler(object):
get_group_role = _create_rerouter("get_group_role")
get_group_roles = _create_rerouter("get_group_roles")
- @defer.inlineCallbacks
- def get_group_summary(self, group_id, requester_user_id):
+ async def get_group_summary(self, group_id, requester_user_id):
"""Get the group summary for a group.
If the group is remote we check that the users have valid attestations.
"""
if self.is_mine_id(group_id):
- res = yield self.groups_server_handler.get_group_summary(
+ res = await self.groups_server_handler.get_group_summary(
group_id, requester_user_id
)
else:
try:
- res = yield self.transport_client.get_group_summary(
+ res = await self.transport_client.get_group_summary(
get_domain_from_id(group_id), group_id, requester_user_id
)
except HttpResponseException as e:
@@ -122,7 +119,7 @@ class GroupsLocalWorkerHandler(object):
attestation = entry.pop("attestation", {})
try:
if get_domain_from_id(g_user_id) != group_server_name:
- yield self.attestations.verify_attestation(
+ await self.attestations.verify_attestation(
attestation,
group_id=group_id,
user_id=g_user_id,
@@ -139,19 +136,18 @@ class GroupsLocalWorkerHandler(object):
# Add `is_publicised` flag to indicate whether the user has publicised their
# membership of the group on their profile
- result = yield self.store.get_publicised_groups_for_user(requester_user_id)
+ result = await self.store.get_publicised_groups_for_user(requester_user_id)
is_publicised = group_id in result
res.setdefault("user", {})["is_publicised"] = is_publicised
return res
- @defer.inlineCallbacks
- def get_users_in_group(self, group_id, requester_user_id):
+ async def get_users_in_group(self, group_id, requester_user_id):
"""Get users in a group
"""
if self.is_mine_id(group_id):
- res = yield self.groups_server_handler.get_users_in_group(
+ res = await self.groups_server_handler.get_users_in_group(
group_id, requester_user_id
)
return res
@@ -159,7 +155,7 @@ class GroupsLocalWorkerHandler(object):
group_server_name = get_domain_from_id(group_id)
try:
- res = yield self.transport_client.get_users_in_group(
+ res = await self.transport_client.get_users_in_group(
get_domain_from_id(group_id), group_id, requester_user_id
)
except HttpResponseException as e:
@@ -174,7 +170,7 @@ class GroupsLocalWorkerHandler(object):
attestation = entry.pop("attestation", {})
try:
if get_domain_from_id(g_user_id) != group_server_name:
- yield self.attestations.verify_attestation(
+ await self.attestations.verify_attestation(
attestation,
group_id=group_id,
user_id=g_user_id,
@@ -188,15 +184,13 @@ class GroupsLocalWorkerHandler(object):
return res
- @defer.inlineCallbacks
- def get_joined_groups(self, user_id):
- group_ids = yield self.store.get_joined_groups(user_id)
+ async def get_joined_groups(self, user_id):
+ group_ids = await self.store.get_joined_groups(user_id)
return {"groups": group_ids}
- @defer.inlineCallbacks
- def get_publicised_groups_for_user(self, user_id):
+ async def get_publicised_groups_for_user(self, user_id):
if self.hs.is_mine_id(user_id):
- result = yield self.store.get_publicised_groups_for_user(user_id)
+ result = await self.store.get_publicised_groups_for_user(user_id)
# Check AS associated groups for this user - this depends on the
# RegExps in the AS registration file (under `users`)
@@ -206,7 +200,7 @@ class GroupsLocalWorkerHandler(object):
return {"groups": result}
else:
try:
- bulk_result = yield self.transport_client.bulk_get_publicised_groups(
+ bulk_result = await self.transport_client.bulk_get_publicised_groups(
get_domain_from_id(user_id), [user_id]
)
except HttpResponseException as e:
@@ -218,8 +212,7 @@ class GroupsLocalWorkerHandler(object):
# TODO: Verify attestations
return {"groups": result}
- @defer.inlineCallbacks
- def bulk_get_publicised_groups(self, user_ids, proxy=True):
+ async def bulk_get_publicised_groups(self, user_ids, proxy=True):
destinations = {}
local_users = set()
@@ -236,7 +229,7 @@ class GroupsLocalWorkerHandler(object):
failed_results = []
for destination, dest_user_ids in iteritems(destinations):
try:
- r = yield self.transport_client.bulk_get_publicised_groups(
+ r = await self.transport_client.bulk_get_publicised_groups(
destination, list(dest_user_ids)
)
results.update(r["users"])
@@ -244,7 +237,7 @@ class GroupsLocalWorkerHandler(object):
failed_results.extend(dest_user_ids)
for uid in local_users:
- results[uid] = yield self.store.get_publicised_groups_for_user(uid)
+ results[uid] = await self.store.get_publicised_groups_for_user(uid)
# Check AS associated groups for this user - this depends on the
# RegExps in the AS registration file (under `users`)
@@ -333,12 +326,11 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
- @defer.inlineCallbacks
- def join_group(self, group_id, user_id, content):
+ async def join_group(self, group_id, user_id, content):
"""Request to join a group
"""
if self.is_mine_id(group_id):
- yield self.groups_server_handler.join_group(group_id, user_id, content)
+ await self.groups_server_handler.join_group(group_id, user_id, content)
local_attestation = None
remote_attestation = None
else:
@@ -346,7 +338,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
content["attestation"] = local_attestation
try:
- res = yield self.transport_client.join_group(
+ res = await self.transport_client.join_group(
get_domain_from_id(group_id), group_id, user_id, content
)
except HttpResponseException as e:
@@ -356,7 +348,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
remote_attestation = res["attestation"]
- yield self.attestations.verify_attestation(
+ await self.attestations.verify_attestation(
remote_attestation,
group_id=group_id,
user_id=user_id,
@@ -366,7 +358,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
# TODO: Check that the group is public and we're being added publically
is_publicised = content.get("publicise", False)
- token = yield self.store.register_user_group_membership(
+ token = await self.store.register_user_group_membership(
group_id,
user_id,
membership="join",
@@ -379,12 +371,11 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {}
- @defer.inlineCallbacks
- def accept_invite(self, group_id, user_id, content):
+ async def accept_invite(self, group_id, user_id, content):
"""Accept an invite to a group
"""
if self.is_mine_id(group_id):
- yield self.groups_server_handler.accept_invite(group_id, user_id, content)
+ await self.groups_server_handler.accept_invite(group_id, user_id, content)
local_attestation = None
remote_attestation = None
else:
@@ -392,7 +383,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
content["attestation"] = local_attestation
try:
- res = yield self.transport_client.accept_group_invite(
+ res = await self.transport_client.accept_group_invite(
get_domain_from_id(group_id), group_id, user_id, content
)
except HttpResponseException as e:
@@ -402,7 +393,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
remote_attestation = res["attestation"]
- yield self.attestations.verify_attestation(
+ await self.attestations.verify_attestation(
remote_attestation,
group_id=group_id,
user_id=user_id,
@@ -412,7 +403,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
# TODO: Check that the group is public and we're being added publically
is_publicised = content.get("publicise", False)
- token = yield self.store.register_user_group_membership(
+ token = await self.store.register_user_group_membership(
group_id,
user_id,
membership="join",
@@ -425,18 +416,17 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {}
- @defer.inlineCallbacks
- def invite(self, group_id, user_id, requester_user_id, config):
+ async def invite(self, group_id, user_id, requester_user_id, config):
"""Invite a user to a group
"""
content = {"requester_user_id": requester_user_id, "config": config}
if self.is_mine_id(group_id):
- res = yield self.groups_server_handler.invite_to_group(
+ res = await self.groups_server_handler.invite_to_group(
group_id, user_id, requester_user_id, content
)
else:
try:
- res = yield self.transport_client.invite_to_group(
+ res = await self.transport_client.invite_to_group(
get_domain_from_id(group_id),
group_id,
user_id,
@@ -450,8 +440,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
- @defer.inlineCallbacks
- def on_invite(self, group_id, user_id, content):
+ async def on_invite(self, group_id, user_id, content):
"""One of our users were invited to a group
"""
# TODO: Support auto join and rejection
@@ -466,7 +455,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
if "avatar_url" in content["profile"]:
local_profile["avatar_url"] = content["profile"]["avatar_url"]
- token = yield self.store.register_user_group_membership(
+ token = await self.store.register_user_group_membership(
group_id,
user_id,
membership="invite",
@@ -474,7 +463,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
)
self.notifier.on_new_event("groups_key", token, users=[user_id])
try:
- user_profile = yield self.profile_handler.get_profile(user_id)
+ user_profile = await self.profile_handler.get_profile(user_id)
except Exception as e:
logger.warning("No profile for user %s: %s", user_id, e)
user_profile = {}
@@ -516,12 +505,11 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
- @defer.inlineCallbacks
- def user_removed_from_group(self, group_id, user_id, content):
+ async def user_removed_from_group(self, group_id, user_id, content):
"""One of our users was removed/kicked from a group
"""
# TODO: Check if user in group
- token = yield self.store.register_user_group_membership(
+ token = await self.store.register_user_group_membership(
group_id, user_id, membership="leave"
)
self.notifier.on_new_event("groups_key", token, users=[user_id])
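The groups handlers above are converted wholesale from Twisted's `@defer.inlineCallbacks`/`yield` style to native coroutines. Since Deferreds are awaitable, the store and transport-client calls need no changes beyond swapping `yield` for `await`. A before/after sketch of the pattern (the `get_thing` method is made up):

    from twisted.internet import defer

    # Before: generator-based coroutine driven by inlineCallbacks.
    @defer.inlineCallbacks
    def get_thing_old(store, thing_id):
        thing = yield store.get_thing(thing_id)  # store returns a Deferred
        return {"thing": thing}

    # After: a native coroutine; awaiting the same Deferred works unchanged.
    async def get_thing_new(store, thing_id):
        thing = await store.get_thing(thing_id)
        return {"thing": thing}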
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 0f0e632b..4ba00427 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -25,7 +25,6 @@ from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
-from twisted.internet import defer
from twisted.internet.error import TimeoutError
from synapse.api.errors import (
@@ -60,8 +59,7 @@ class IdentityHandler(BaseHandler):
self.federation_http_client = hs.get_http_client()
self.hs = hs
- @defer.inlineCallbacks
- def threepid_from_creds(self, id_server, creds):
+ async def threepid_from_creds(self, id_server, creds):
"""
Retrieve and validate a threepid identifier from a "credentials" dictionary against a
given identity server
@@ -97,7 +95,7 @@ class IdentityHandler(BaseHandler):
url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid"
try:
- data = yield self.http_client.get_json(url, query_params)
+ data = await self.http_client.get_json(url, query_params)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
@@ -120,8 +118,7 @@ class IdentityHandler(BaseHandler):
logger.info("%s reported non-validated threepid: %s", id_server, creds)
return None
- @defer.inlineCallbacks
- def bind_threepid(
+ async def bind_threepid(
self, client_secret, sid, mxid, id_server, id_access_token=None, use_v2=True
):
"""Bind a 3PID to an identity server
@@ -161,12 +158,12 @@ class IdentityHandler(BaseHandler):
try:
# Use the blacklisting http client as this call is only to identity servers
# provided by a client
- data = yield self.blacklisting_http_client.post_json_get_json(
+ data = await self.blacklisting_http_client.post_json_get_json(
bind_url, bind_data, headers=headers
)
# Remember where we bound the threepid
- yield self.store.add_user_bound_threepid(
+ await self.store.add_user_bound_threepid(
user_id=mxid,
medium=data["medium"],
address=data["address"],
@@ -185,13 +182,12 @@ class IdentityHandler(BaseHandler):
return data
logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)
- res = yield self.bind_threepid(
+ res = await self.bind_threepid(
client_secret, sid, mxid, id_server, id_access_token, use_v2=False
)
return res
- @defer.inlineCallbacks
- def try_unbind_threepid(self, mxid, threepid):
+ async def try_unbind_threepid(self, mxid, threepid):
"""Attempt to remove a 3PID from an identity server, or if one is not provided, all
identity servers we're aware the binding is present on
@@ -211,7 +207,7 @@ class IdentityHandler(BaseHandler):
if threepid.get("id_server"):
id_servers = [threepid["id_server"]]
else:
- id_servers = yield self.store.get_id_servers_user_bound(
+ id_servers = await self.store.get_id_servers_user_bound(
user_id=mxid, medium=threepid["medium"], address=threepid["address"]
)
@@ -221,14 +217,13 @@ class IdentityHandler(BaseHandler):
changed = True
for id_server in id_servers:
- changed &= yield self.try_unbind_threepid_with_id_server(
+ changed &= await self.try_unbind_threepid_with_id_server(
mxid, threepid, id_server
)
return changed
- @defer.inlineCallbacks
- def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
+ async def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
"""Removes a binding from an identity server
Args:
@@ -266,7 +261,7 @@ class IdentityHandler(BaseHandler):
try:
# Use the blacklisting http client as this call is only to identity servers
# provided by a client
- yield self.blacklisting_http_client.post_json_get_json(
+ await self.blacklisting_http_client.post_json_get_json(
url, content, headers
)
changed = True
@@ -281,7 +276,7 @@ class IdentityHandler(BaseHandler):
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
- yield self.store.remove_user_bound_threepid(
+ await self.store.remove_user_bound_threepid(
user_id=mxid,
medium=threepid["medium"],
address=threepid["address"],
@@ -290,8 +285,7 @@ class IdentityHandler(BaseHandler):
return changed
- @defer.inlineCallbacks
- def send_threepid_validation(
+ async def send_threepid_validation(
self,
email_address,
client_secret,
@@ -319,7 +313,7 @@ class IdentityHandler(BaseHandler):
"""
# Check that this email/client_secret/send_attempt combo is new or
# greater than what we've seen previously
- session = yield self.store.get_threepid_validation_session(
+ session = await self.store.get_threepid_validation_session(
"email", client_secret, address=email_address, validated=False
)
@@ -353,7 +347,7 @@ class IdentityHandler(BaseHandler):
# Send the mail with the link containing the token, client_secret
# and session_id
try:
- yield send_email_func(email_address, token, client_secret, session_id)
+ await send_email_func(email_address, token, client_secret, session_id)
except Exception:
logger.exception(
"Error sending threepid validation email to %s", email_address
@@ -364,7 +358,7 @@ class IdentityHandler(BaseHandler):
self.hs.clock.time_msec() + self.hs.config.email_validation_token_lifetime
)
- yield self.store.start_or_continue_validation_session(
+ await self.store.start_or_continue_validation_session(
"email",
email_address,
session_id,
@@ -377,8 +371,7 @@ class IdentityHandler(BaseHandler):
return session_id
- @defer.inlineCallbacks
- def requestEmailToken(
+ async def requestEmailToken(
self, id_server, email, client_secret, send_attempt, next_link=None
):
"""
@@ -413,7 +406,7 @@ class IdentityHandler(BaseHandler):
)
try:
- data = yield self.http_client.post_json_get_json(
+ data = await self.http_client.post_json_get_json(
id_server + "/_matrix/identity/api/v1/validate/email/requestToken",
params,
)
@@ -424,8 +417,7 @@ class IdentityHandler(BaseHandler):
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
- @defer.inlineCallbacks
- def requestMsisdnToken(
+ async def requestMsisdnToken(
self,
id_server,
country,
@@ -467,7 +459,7 @@ class IdentityHandler(BaseHandler):
)
try:
- data = yield self.http_client.post_json_get_json(
+ data = await self.http_client.post_json_get_json(
id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken",
params,
)
@@ -488,8 +480,7 @@ class IdentityHandler(BaseHandler):
)
return data
- @defer.inlineCallbacks
- def validate_threepid_session(self, client_secret, sid):
+ async def validate_threepid_session(self, client_secret, sid):
"""Validates a threepid session with only the client secret and session ID
Tries validating against any configured account_threepid_delegates as well as locally.
@@ -511,12 +502,12 @@ class IdentityHandler(BaseHandler):
# Try to validate as email
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
# Ask our delegated email identity server
- validation_session = yield self.threepid_from_creds(
+ validation_session = await self.threepid_from_creds(
self.hs.config.account_threepid_delegate_email, threepid_creds
)
elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
# Get a validated session matching these details
- validation_session = yield self.store.get_threepid_validation_session(
+ validation_session = await self.store.get_threepid_validation_session(
"email", client_secret, sid=sid, validated=True
)
@@ -526,14 +517,13 @@ class IdentityHandler(BaseHandler):
# Try to validate as msisdn
if self.hs.config.account_threepid_delegate_msisdn:
# Ask our delegated msisdn identity server
- validation_session = yield self.threepid_from_creds(
+ validation_session = await self.threepid_from_creds(
self.hs.config.account_threepid_delegate_msisdn, threepid_creds
)
return validation_session
- @defer.inlineCallbacks
- def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token):
+ async def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token):
"""Proxy a POST submitToken request to an identity server for verification purposes
Args:
@@ -554,11 +544,9 @@ class IdentityHandler(BaseHandler):
body = {"client_secret": client_secret, "sid": sid, "token": token}
try:
- return (
- yield self.http_client.post_json_get_json(
- id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken",
- body,
- )
+ return await self.http_client.post_json_get_json(
+ id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken",
+ body,
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
@@ -566,8 +554,7 @@ class IdentityHandler(BaseHandler):
logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
raise SynapseError(400, "Error contacting the identity server")
- @defer.inlineCallbacks
- def lookup_3pid(self, id_server, medium, address, id_access_token=None):
+ async def lookup_3pid(self, id_server, medium, address, id_access_token=None):
"""Looks up a 3pid in the passed identity server.
Args:
@@ -583,7 +570,7 @@ class IdentityHandler(BaseHandler):
"""
if id_access_token is not None:
try:
- results = yield self._lookup_3pid_v2(
+ results = await self._lookup_3pid_v2(
id_server, id_access_token, medium, address
)
return results
@@ -602,10 +589,9 @@ class IdentityHandler(BaseHandler):
logger.warning("Error when looking up hashing details: %s", e)
return None
- return (yield self._lookup_3pid_v1(id_server, medium, address))
+ return await self._lookup_3pid_v1(id_server, medium, address)
- @defer.inlineCallbacks
- def _lookup_3pid_v1(self, id_server, medium, address):
+ async def _lookup_3pid_v1(self, id_server, medium, address):
"""Looks up a 3pid in the passed identity server using v1 lookup.
Args:
@@ -618,7 +604,7 @@ class IdentityHandler(BaseHandler):
str: the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
- data = yield self.blacklisting_http_client.get_json(
+ data = await self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
{"medium": medium, "address": address},
)
@@ -626,7 +612,7 @@ class IdentityHandler(BaseHandler):
if "mxid" in data:
if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding")
- yield self._verify_any_signature(data, id_server)
+ await self._verify_any_signature(data, id_server)
return data["mxid"]
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
@@ -635,8 +621,7 @@ class IdentityHandler(BaseHandler):
return None
- @defer.inlineCallbacks
- def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
+ async def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
"""Looks up a 3pid in the passed identity server using v2 lookup.
Args:
@@ -651,7 +636,7 @@ class IdentityHandler(BaseHandler):
"""
# Check what hashing details are supported by this identity server
try:
- hash_details = yield self.blacklisting_http_client.get_json(
+ hash_details = await self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
{"access_token": id_access_token},
)
@@ -718,7 +703,7 @@ class IdentityHandler(BaseHandler):
headers = {"Authorization": create_id_access_token_header(id_access_token)}
try:
- lookup_results = yield self.blacklisting_http_client.post_json_get_json(
+ lookup_results = await self.blacklisting_http_client.post_json_get_json(
"%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server),
{
"addresses": [lookup_value],
@@ -746,13 +731,12 @@ class IdentityHandler(BaseHandler):
mxid = lookup_results["mappings"].get(lookup_value)
return mxid
- @defer.inlineCallbacks
- def _verify_any_signature(self, data, server_hostname):
+ async def _verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
try:
- key_data = yield self.blacklisting_http_client.get_json(
+ key_data = await self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s"
% (id_server_scheme, server_hostname, key_name)
)
@@ -771,8 +755,7 @@ class IdentityHandler(BaseHandler):
)
return
- @defer.inlineCallbacks
- def ask_id_server_for_third_party_invite(
+ async def ask_id_server_for_third_party_invite(
self,
requester,
id_server,
@@ -845,7 +828,7 @@ class IdentityHandler(BaseHandler):
# Attempt a v2 lookup
url = base_url + "/v2/store-invite"
try:
- data = yield self.blacklisting_http_client.post_json_get_json(
+ data = await self.blacklisting_http_client.post_json_get_json(
url,
invite_config,
{"Authorization": create_id_access_token_header(id_access_token)},
@@ -865,7 +848,7 @@ class IdentityHandler(BaseHandler):
url = base_url + "/api/v1/store-invite"
try:
- data = yield self.blacklisting_http_client.post_json_get_json(
+ data = await self.blacklisting_http_client.post_json_get_json(
url, invite_config
)
except TimeoutError:
@@ -883,7 +866,7 @@ class IdentityHandler(BaseHandler):
# types. This is especially true with old instances of Sydent, see
# https://github.com/matrix-org/sydent/pull/170
try:
- data = yield self.blacklisting_http_client.post_urlencoded_get_json(
+ data = await self.blacklisting_http_client.post_urlencoded_get_json(
url, invite_config
)
except HttpResponseException as e:
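`_lookup_3pid_v2` first asks the identity server for its `/hash_details` and then posts hashed addresses to `/_matrix/identity/v2/lookup`. Under the `sha256` algorithm, the lookup value is (per the v2 identity-service lookup scheme) the URL-safe unpadded base64 of `sha256("<address> <medium> <pepper>")`. A stdlib-only sketch of that computation (the address and pepper are made up; the real pepper comes from `/hash_details`):

    import base64
    import hashlib

    address, medium, pepper = "alice@example.com", "email", "matrixrocks"

    digest = hashlib.sha256(" ".join((address, medium, pepper)).encode("utf-8")).digest()
    lookup_value = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")

    # This is the kind of value that goes into the "addresses" list of the
    # v2 lookup request body shown above.
    print(lookup_value)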
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index a622a600..649ca1f0 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Optional
+from typing import Optional, Tuple
from six import iteritems, itervalues, string_types
@@ -42,6 +42,7 @@ from synapse.api.errors import (
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.api.urls import ConsentURIBuilder
+from synapse.events import EventBase
from synapse.events.validator import EventValidator
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -72,7 +73,6 @@ class MessageHandler(object):
self.state_store = self.storage.state
self._event_serializer = hs.get_event_client_serializer()
self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
- self._is_worker_app = bool(hs.config.worker_app)
# The scheduled call to self._expire_event. None if no call is currently
# scheduled.
@@ -260,7 +260,6 @@ class MessageHandler(object):
Args:
event (EventBase): The event to schedule the expiry of.
"""
- assert not self._is_worker_app
expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
if not isinstance(expiry_ts, int) or event.is_state():
@@ -363,14 +362,16 @@ class EventCreationHandler(object):
self.profile_handler = hs.get_profile_handler()
self.event_builder_factory = hs.get_event_builder_factory()
self.server_name = hs.hostname
- self.ratelimiter = hs.get_ratelimiter()
self.notifier = hs.get_notifier()
self.config = hs.config
self.require_membership_for_aliases = hs.config.require_membership_for_aliases
+ self._is_event_writer = (
+ self.config.worker.writers.events == hs.get_instance_name()
+ )
self.room_invite_state_types = self.hs.config.room_invite_state_types
- self.send_event_to_master = ReplicationSendEventRestServlet.make_client(hs)
+ self.send_event = ReplicationSendEventRestServlet.make_client(hs)
# This is only used to get at ratelimit function, and maybe_kick_guest_users
self.base_handler = BaseHandler(hs)
@@ -486,9 +487,13 @@ class EventCreationHandler(object):
try:
if "displayname" not in content:
- content["displayname"] = yield profile.get_displayname(target)
+ displayname = yield profile.get_displayname(target)
+ if displayname is not None:
+ content["displayname"] = displayname
if "avatar_url" not in content:
- content["avatar_url"] = yield profile.get_avatar_url(target)
+ avatar_url = yield profile.get_avatar_url(target)
+ if avatar_url is not None:
+ content["avatar_url"] = avatar_url
except Exception as e:
logger.info(
"Failed to get profile information for %r: %s", target, e
@@ -628,7 +633,9 @@ class EventCreationHandler(object):
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
- async def send_nonmember_event(self, requester, event, context, ratelimit=True):
+ async def send_nonmember_event(
+ self, requester, event, context, ratelimit=True
+ ) -> int:
"""
Persists and notifies local clients and federation of an event.
@@ -637,6 +644,9 @@ class EventCreationHandler(object):
context (Context) the context of the event.
ratelimit (bool): Whether to rate limit this send.
is_guest (bool): Whether the sender is a guest.
+
+ Return:
+ The stream_id of the persisted event.
"""
if event.type == EventTypes.Member:
raise SynapseError(
@@ -657,7 +667,7 @@ class EventCreationHandler(object):
)
return prev_state
- await self.handle_new_client_event(
+ return await self.handle_new_client_event(
requester=requester, event=event, context=context, ratelimit=ratelimit
)
@@ -686,7 +696,7 @@ class EventCreationHandler(object):
async def create_and_send_nonmember_event(
self, requester, event_dict, ratelimit=True, txn_id=None
- ):
+ ) -> Tuple[EventBase, int]:
"""
Creates an event, then sends it.
@@ -709,10 +719,10 @@ class EventCreationHandler(object):
spam_error = "Spam is not permitted here"
raise SynapseError(403, spam_error, Codes.FORBIDDEN)
- await self.send_nonmember_event(
+ stream_id = await self.send_nonmember_event(
requester, event, context, ratelimit=ratelimit
)
- return event
+ return event, stream_id
@measure_func("create_new_client_event")
@defer.inlineCallbacks
@@ -772,7 +782,7 @@ class EventCreationHandler(object):
@measure_func("handle_new_client_event")
async def handle_new_client_event(
self, requester, event, context, ratelimit=True, extra_users=[]
- ):
+ ) -> int:
"""Processes a new event. This includes checking auth, persisting it,
notifying users, sending to remote servers, etc.
@@ -785,6 +795,9 @@ class EventCreationHandler(object):
context (EventContext)
ratelimit (bool)
extra_users (list(UserID)): Any extra users to notify about event
+
+ Return:
+ The stream_id of the persisted event.
"""
if event.is_state() and (event.type, event.state_key) == (
@@ -824,8 +837,9 @@ class EventCreationHandler(object):
success = False
try:
# If we're a worker we need to hit out to the master.
- if self.config.worker_app:
- await self.send_event_to_master(
+ if not self._is_event_writer:
+ result = await self.send_event(
+ instance_name=self.config.worker.writers.events,
event_id=event.event_id,
store=self.store,
requester=requester,
@@ -834,14 +848,17 @@ class EventCreationHandler(object):
ratelimit=ratelimit,
extra_users=extra_users,
)
+ stream_id = result["stream_id"]
+ event.internal_metadata.stream_ordering = stream_id
success = True
- return
+ return stream_id
- await self.persist_and_notify_client_event(
+ stream_id = await self.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
success = True
+ return stream_id
finally:
if not success:
# Ensure that we actually remove the entries in the push actions
@@ -884,13 +901,13 @@ class EventCreationHandler(object):
async def persist_and_notify_client_event(
self, requester, event, context, ratelimit=True, extra_users=[]
- ):
+ ) -> int:
"""Called when we have fully built the event, have already
calculated the push actions for the event, and checked auth.
- This should only be run on master.
+ This should only be run on the instance in charge of persisting events.
"""
- assert not self.config.worker_app
+ assert self._is_event_writer
if ratelimit:
# We check if this is a room admin redacting an event so that we
@@ -1074,6 +1091,8 @@ class EventCreationHandler(object):
# matters as sometimes presence code can take a while.
run_in_background(self._bump_active_time, requester.user)
+ return event_stream_id
+
async def _bump_active_time(self, user):
try:
presence = self.hs.get_presence_handler()
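With the event writer now configurable, `handle_new_client_event` either persists locally (when this instance is the writer) or forwards the event to the writer over replication, and returns the resulting stream ID in both branches. A condensed sketch of that branching (the callables passed in stand in for the replication client and `persist_and_notify_client_event`):

    async def handle_event_sketch(is_event_writer, writer_name, send_event, persist_locally, event) -> int:
        if not is_event_writer:
            # We're a worker that doesn't write events: hand it to the writer
            # and adopt the stream ordering it reports back.
            result = await send_event(instance_name=writer_name, event=event)
            stream_id = result["stream_id"]
            event.internal_metadata.stream_ordering = stream_id
            return stream_id

        # We are the event writer: persist and notify directly.
        return await persist_locally(event)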
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
new file mode 100644
index 00000000..9c08eb53
--- /dev/null
+++ b/synapse/handlers/oidc_handler.py
@@ -0,0 +1,1049 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import logging
+from typing import Dict, Generic, List, Optional, Tuple, TypeVar
+from urllib.parse import urlencode
+
+import attr
+import pymacaroons
+from authlib.common.security import generate_token
+from authlib.jose import JsonWebToken
+from authlib.oauth2.auth import ClientAuth
+from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
+from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
+from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
+from jinja2 import Environment, Template
+from pymacaroons.exceptions import (
+ MacaroonDeserializationException,
+ MacaroonInvalidSignatureException,
+)
+from typing_extensions import TypedDict
+
+from twisted.web.client import readBody
+
+from synapse.config import ConfigError
+from synapse.http.server import finish_request
+from synapse.http.site import SynapseRequest
+from synapse.logging.context import make_deferred_yieldable
+from synapse.push.mailer import load_jinja2_templates
+from synapse.server import HomeServer
+from synapse.types import UserID, map_username_to_mxid_localpart
+
+logger = logging.getLogger(__name__)
+
+SESSION_COOKIE_NAME = b"oidc_session"
+
+#: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
+#: OpenID.Core sec 3.1.3.3.
+Token = TypedDict(
+ "Token",
+ {
+ "access_token": str,
+ "token_type": str,
+ "id_token": Optional[str],
+ "refresh_token": Optional[str],
+ "expires_in": int,
+ "scope": Optional[str],
+ },
+)
+
+#: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but
+#: there is no real point in doing so in our case.
+JWK = Dict[str, str]
+
+#: A JWK Set, as per RFC7517 sec 5.
+JWKS = TypedDict("JWKS", {"keys": List[JWK]})
+
+
+class OidcError(Exception):
+ """Used to catch errors when calling the token_endpoint
+ """
+
+ def __init__(self, error, error_description=None):
+ self.error = error
+ self.error_description = error_description
+
+ def __str__(self):
+ if self.error_description:
+ return "{}: {}".format(self.error, self.error_description)
+ return self.error
+
+
+class MappingException(Exception):
+ """Used to catch errors when mapping the UserInfo object
+ """
+
+
+class OidcHandler:
+ """Handles requests related to the OpenID Connect login flow.
+ """
+
+ def __init__(self, hs: HomeServer):
+ self._callback_url = hs.config.oidc_callback_url # type: str
+ self._scopes = hs.config.oidc_scopes # type: List[str]
+ self._client_auth = ClientAuth(
+ hs.config.oidc_client_id,
+ hs.config.oidc_client_secret,
+ hs.config.oidc_client_auth_method,
+ ) # type: ClientAuth
+ self._client_auth_method = hs.config.oidc_client_auth_method # type: str
+ self._provider_metadata = OpenIDProviderMetadata(
+ issuer=hs.config.oidc_issuer,
+ authorization_endpoint=hs.config.oidc_authorization_endpoint,
+ token_endpoint=hs.config.oidc_token_endpoint,
+ userinfo_endpoint=hs.config.oidc_userinfo_endpoint,
+ jwks_uri=hs.config.oidc_jwks_uri,
+ ) # type: OpenIDProviderMetadata
+ self._provider_needs_discovery = hs.config.oidc_discover # type: bool
+ self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class(
+ hs.config.oidc_user_mapping_provider_config
+ ) # type: OidcMappingProvider
+ self._skip_verification = hs.config.oidc_skip_verification # type: bool
+
+ self._http_client = hs.get_proxied_http_client()
+ self._auth_handler = hs.get_auth_handler()
+ self._registration_handler = hs.get_registration_handler()
+ self._datastore = hs.get_datastore()
+ self._clock = hs.get_clock()
+ self._hostname = hs.hostname # type: str
+ self._server_name = hs.config.server_name # type: str
+ self._macaroon_secret_key = hs.config.macaroon_secret_key
+ self._error_template = load_jinja2_templates(
+ hs.config.sso_template_dir, ["sso_error.html"]
+ )[0]
+
+ # identifier for the external_ids table
+ self._auth_provider_id = "oidc"
+
+ def _render_error(
+ self, request, error: str, error_description: Optional[str] = None
+ ) -> None:
+ """Renders the error template and respond with it.
+
+ This is used to show errors to the user. The template of this page can
+ be found under ``synapse/res/templates/sso_error.html``.
+
+ Args:
+ request: The incoming request from the browser.
+ We'll respond with an HTML page describing the error.
+ error: A technical identifier for this error. Those include
+ well-known OAuth2/OIDC error types like invalid_request or
+ access_denied.
+ error_description: A human-readable description of the error.
+ """
+ html_bytes = self._error_template.render(
+ error=error, error_description=error_description
+ ).encode("utf-8")
+
+ request.setResponseCode(400)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Content-Length", b"%i" % len(html_bytes))
+ request.write(html_bytes)
+ finish_request(request)
+
+ def _validate_metadata(self):
+ """Verifies the provider metadata.
+
+ This checks the validity of the currently loaded provider. Not
+ everything is checked, only:
+
+ - ``issuer``
+ - ``authorization_endpoint``
+ - ``token_endpoint``
+ - ``response_types_supported`` (checks if "code" is in it)
+ - ``jwks_uri``
+
+ Raises:
+ ValueError: if something in the provider is not valid
+ """
+ # Skip verification to allow non-compliant providers (e.g. issuers not running on a secure origin)
+ if self._skip_verification is True:
+ return
+
+ m = self._provider_metadata
+ m.validate_issuer()
+ m.validate_authorization_endpoint()
+ m.validate_token_endpoint()
+
+ if m.get("token_endpoint_auth_methods_supported") is not None:
+ m.validate_token_endpoint_auth_methods_supported()
+ if (
+ self._client_auth_method
+ not in m["token_endpoint_auth_methods_supported"]
+ ):
+ raise ValueError(
+ '"{auth_method}" not in "token_endpoint_auth_methods_supported" ({supported!r})'.format(
+ auth_method=self._client_auth_method,
+ supported=m["token_endpoint_auth_methods_supported"],
+ )
+ )
+
+ if m.get("response_types_supported") is not None:
+ m.validate_response_types_supported()
+
+ if "code" not in m["response_types_supported"]:
+ raise ValueError(
+ '"code" not in "response_types_supported" (%r)'
+ % (m["response_types_supported"],)
+ )
+
+ # If the openid scope was not requested, we need a userinfo endpoint to fetch user information
+ if self._uses_userinfo:
+ if m.get("userinfo_endpoint") is None:
+ raise ValueError(
+ 'provider has no "userinfo_endpoint", even though it is required because the "openid" scope is not requested'
+ )
+ else:
+ # If we're not using userinfo, we need a valid jwks to validate the ID token
+ if m.get("jwks") is None:
+ if m.get("jwks_uri") is not None:
+ m.validate_jwks_uri()
+ else:
+ raise ValueError('"jwks_uri" must be set')
+
+ @property
+ def _uses_userinfo(self) -> bool:
+ """Returns True if the ``userinfo_endpoint`` should be used.
+
+ This is based on the requested scopes: if the scopes include
+ ``openid``, the provider should give use an ID token containing the
+ user informations. If not, we should fetch them using the
+ ``access_token`` with the ``userinfo_endpoint``.
+ """
+
+ # Maybe that should be user-configurable and not inferred?
+ return "openid" not in self._scopes
+
+ async def load_metadata(self) -> OpenIDProviderMetadata:
+ """Load and validate the provider metadata.
+
+ The provider metadata is discovered if ``oidc_config.discovery`` is
+ ``True`` and then cached.
+
+ Raises:
+ ValueError: if something in the provider is not valid
+
+ Returns:
+ The provider's metadata.
+ """
+ # If we are using OpenID Discovery, the provider metadata needs to be loaded once
+ # FIXME: should there be a lock here?
+ if self._provider_needs_discovery:
+ url = get_well_known_url(self._provider_metadata["issuer"], external=True)
+ metadata_response = await self._http_client.get_json(url)
+ # TODO: maybe update the other way around to let user override some values?
+ self._provider_metadata.update(metadata_response)
+ self._provider_needs_discovery = False
+
+ self._validate_metadata()
+
+ return self._provider_metadata
+
+ async def load_jwks(self, force: bool = False) -> JWKS:
+ """Load the JSON Web Key Set used to sign ID tokens.
+
+ If we're not using the ``userinfo_endpoint``, the user information is extracted
+ from the ID token, which is a JWT signed by keys given by the provider.
+ The keys are then cached.
+
+ Args:
+ force: Force reloading the keys.
+
+ Returns:
+ The key set
+
+ Looks like this::
+
+ {
+ 'keys': [
+ {
+ 'kid': 'abcdef',
+ 'kty': 'RSA',
+ 'alg': 'RS256',
+ 'use': 'sig',
+ 'e': 'XXXX',
+ 'n': 'XXXX',
+ }
+ ]
+ }
+ """
+ if self._uses_userinfo:
+ # We're not using JWT signing, so return an empty JWK set
+ return {"keys": []}
+
+        # First check if the JWKS is already loaded in the provider metadata.
+        # This can happen if the provider includes its JWKS directly in the
+        # discovery document, or if it was already loaded once.
+ metadata = await self.load_metadata()
+ jwk_set = metadata.get("jwks")
+ if jwk_set is not None and not force:
+ return jwk_set
+
+ # Loading the JWKS using the `jwks_uri` metadata
+ uri = metadata.get("jwks_uri")
+ if not uri:
+ raise RuntimeError('Missing "jwks_uri" in metadata')
+
+ jwk_set = await self._http_client.get_json(uri)
+
+ # Caching the JWKS in the provider's metadata
+ self._provider_metadata["jwks"] = jwk_set
+ return jwk_set
+
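Fetching the key set from ``jwks_uri`` is an unauthenticated JSON GET; a standalone sketch with ``requests`` and a hypothetical URI::

    import requests

    def fetch_jwks_sketch(jwks_uri: str) -> dict:
        resp = requests.get(jwks_uri, timeout=10)
        resp.raise_for_status()
        # Shaped like the "keys" example in the docstring above.
        return resp.json()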
+ async def _exchange_code(self, code: str) -> Token:
+ """Exchange an authorization code for a token.
+
+ This calls the ``token_endpoint`` with the authorization code we
+ received in the callback to exchange it for a token. The call uses the
+        ``ClientAuth`` to authenticate the client with its ID and secret.
+
+ See:
+ https://tools.ietf.org/html/rfc6749#section-3.2
+ https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
+
+ Args:
+ code: The authorization code we got from the callback.
+
+ Returns:
+ A dict containing various tokens.
+
+ May look like this::
+
+ {
+ 'token_type': 'bearer',
+ 'access_token': 'abcdef',
+ 'expires_in': 3599,
+ 'id_token': 'ghijkl',
+ 'refresh_token': 'mnopqr',
+ }
+
+ Raises:
+ OidcError: when the ``token_endpoint`` returned an error.
+ """
+ metadata = await self.load_metadata()
+ token_endpoint = metadata.get("token_endpoint")
+ headers = {
+ "Content-Type": "application/x-www-form-urlencoded",
+ "User-Agent": self._http_client.user_agent,
+ "Accept": "application/json",
+ }
+
+ args = {
+ "grant_type": "authorization_code",
+ "code": code,
+ "redirect_uri": self._callback_url,
+ }
+ body = urlencode(args, True)
+
+ # Fill the body/headers with credentials
+ uri, headers, body = self._client_auth.prepare(
+ method="POST", uri=token_endpoint, headers=headers, body=body
+ )
+ headers = {k: [v] for (k, v) in headers.items()}
+
+ # Do the actual request
+ # We're not using the SimpleHttpClient util methods as we don't want to
+        # check the HTTP status code, and we do the body encoding ourselves.
+ response = await self._http_client.request(
+ method="POST", uri=uri, data=body.encode("utf-8"), headers=headers,
+ )
+
+ # This is used in multiple error messages below
+ status = "{code} {phrase}".format(
+ code=response.code, phrase=response.phrase.decode("utf-8")
+ )
+
+ resp_body = await make_deferred_yieldable(readBody(response))
+
+ if response.code >= 500:
+ # In case of a server error, we should first try to decode the body
+ # and check for an error field. If not, we respond with a generic
+ # error message.
+ try:
+ resp = json.loads(resp_body.decode("utf-8"))
+ error = resp["error"]
+ description = resp.get("error_description", error)
+ except (ValueError, KeyError):
+ # Catch ValueError for the JSON decoding and KeyError for the "error" field
+ error = "server_error"
+                description = (
+                    'Authorization server responded with a "{status}" error '
+                    "while exchanging the authorization code."
+                ).format(status=status)
+
+ raise OidcError(error, description)
+
+        # Since it is not a 5xx code, the body should be valid JSON. It will
+        # raise if not.
+ resp = json.loads(resp_body.decode("utf-8"))
+
+ if "error" in resp:
+ error = resp["error"]
+ # In case the authorization server responded with an error field,
+ # it should be a 4xx code. If not, warn about it but don't do
+ # anything special and report the original error message.
+ if response.code < 400:
+ logger.debug(
+ "Invalid response from the authorization server: "
+ 'responded with a "{status}" '
+ "but body has an error field: {error!r}".format(
+ status=status, error=resp["error"]
+ )
+ )
+
+ description = resp.get("error_description", error)
+ raise OidcError(error, description)
+
+ # Now, this should not be an error. According to RFC6749 sec 5.1, it
+ # should be a 200 code. We're a bit more flexible than that, and will
+ # only throw on a 4xx code.
+ if response.code >= 400:
+ description = (
+ 'Authorization server responded with a "{status}" error '
+ 'but did not include an "error" field in its response.'.format(
+ status=status
+ )
+ )
+ logger.warning(description)
+ # Body was still valid JSON. Might be useful to log it for debugging.
+ logger.warning("Code exchange response: {resp!r}".format(resp=resp))
+ raise OidcError("server_error", description)
+
+ return resp
+
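Stripped of Synapse's HTTP plumbing, the token request built by ``_exchange_code`` is an ordinary ``application/x-www-form-urlencoded`` POST authenticated with the client credentials (RFC 6749 section 4.1.3). A rough sketch with ``requests`` and placeholder values, assuming ``client_secret_basic`` auth::

    import requests

    def exchange_code_sketch(
        token_endpoint: str, code: str, redirect_uri: str,
        client_id: str, client_secret: str,
    ) -> dict:
        # client_secret_basic: credentials go in the Authorization header.
        resp = requests.post(
            token_endpoint,
            data={
                "grant_type": "authorization_code",
                "code": code,
                "redirect_uri": redirect_uri,
            },
            auth=(client_id, client_secret),
            headers={"Accept": "application/json"},
            timeout=10,
        )
        # On success the JSON body contains access_token, token_type and
        # usually id_token / refresh_token, as in the docstring example above.
        return resp.json()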
+ async def _fetch_userinfo(self, token: Token) -> UserInfo:
+        """Fetch user information from the ``userinfo_endpoint``.
+
+ Args:
+ token: the token given by the ``token_endpoint``.
+ Must include an ``access_token`` field.
+
+ Returns:
+ UserInfo: an object representing the user.
+ """
+ metadata = await self.load_metadata()
+
+ resp = await self._http_client.get_json(
+ metadata["userinfo_endpoint"],
+ headers={"Authorization": ["Bearer {}".format(token["access_token"])]},
+ )
+
+ return UserInfo(resp)
+
+ async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
+ """Return an instance of UserInfo from token's ``id_token``.
+
+ Args:
+ token: the token given by the ``token_endpoint``.
+ Must include an ``id_token`` field.
+ nonce: the nonce value originally sent in the initial authorization
+ request. This value should match the one inside the token.
+
+ Returns:
+ An object representing the user.
+ """
+ metadata = await self.load_metadata()
+ claims_params = {
+ "nonce": nonce,
+ "client_id": self._client_auth.client_id,
+ }
+ if "access_token" in token:
+ # If we got an `access_token`, there should be an `at_hash` claim
+ # in the `id_token` that we can check against.
+ claims_params["access_token"] = token["access_token"]
+ claims_cls = CodeIDToken
+ else:
+ claims_cls = ImplicitIDToken
+
+ alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
+
+ jwt = JsonWebToken(alg_values)
+
+ claim_options = {"iss": {"values": [metadata["issuer"]]}}
+
+        # Try to decode the token using the cached keys first; if that fails,
+        # retry after forcing the keys to be reloaded
+ jwk_set = await self.load_jwks()
+ try:
+ claims = jwt.decode(
+ token["id_token"],
+ key=jwk_set,
+ claims_cls=claims_cls,
+ claims_options=claim_options,
+ claims_params=claims_params,
+ )
+ except ValueError:
+ logger.info("Reloading JWKS after decode error")
+ jwk_set = await self.load_jwks(force=True) # try reloading the jwks
+ claims = jwt.decode(
+ token["id_token"],
+ key=jwk_set,
+ claims_cls=claims_cls,
+ claims_options=claim_options,
+ claims_params=claims_params,
+ )
+
+ claims.validate(leeway=120) # allows 2 min of clock skew
+ return UserInfo(claims)
+
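The validation above leans on authlib's JOSE support; outside of Synapse, the same check looks roughly like this (``id_token``, ``jwk_set``, ``issuer``, ``client_id`` and ``nonce`` are placeholders)::

    from authlib.jose import JsonWebToken
    from authlib.oidc.core import CodeIDToken

    def validate_id_token_sketch(
        id_token: str, jwk_set: dict, issuer: str, client_id: str, nonce: str
    ) -> dict:
        jwt = JsonWebToken(["RS256"])
        claims = jwt.decode(
            id_token,
            key=jwk_set,
            claims_cls=CodeIDToken,
            claims_options={"iss": {"values": [issuer]}},
            claims_params={"nonce": nonce, "client_id": client_id},
        )
        claims.validate(leeway=120)  # tolerate two minutes of clock skew
        return dict(claims)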
+ async def handle_redirect_request(
+ self,
+ request: SynapseRequest,
+ client_redirect_url: bytes,
+ ui_auth_session_id: Optional[str] = None,
+ ) -> str:
+ """Handle an incoming request to /login/sso/redirect
+
+ It returns a redirect to the authorization endpoint with a few
+ parameters:
+
+ - ``client_id``: the client ID set in ``oidc_config.client_id``
+ - ``response_type``: ``code``
+        - ``redirect_uri``: the callback URL; ``{base url}/_synapse/oidc/callback``
+ - ``scope``: the list of scopes set in ``oidc_config.scopes``
+ - ``state``: a random string
+ - ``nonce``: a random string
+
+        In addition to generating a redirect URL, we set a cookie with
+ a signed macaroon token containing the state, the nonce and the
+ client_redirect_url params. Those are then checked when the client
+ comes back from the provider.
+
+ Args:
+ request: the incoming request from the browser.
+ We'll respond to it with a redirect and a cookie.
+ client_redirect_url: the URL that we should redirect the client to
+ when everything is done
+ ui_auth_session_id: The session ID of the ongoing UI Auth (or
+ None if this is a login).
+
+ Returns:
+ The redirect URL to the authorization endpoint.
+
+ """
+
+ state = generate_token()
+ nonce = generate_token()
+
+ cookie = self._generate_oidc_session_token(
+ state=state,
+ nonce=nonce,
+ client_redirect_url=client_redirect_url.decode(),
+ ui_auth_session_id=ui_auth_session_id,
+ )
+ request.addCookie(
+ SESSION_COOKIE_NAME,
+ cookie,
+ path="/_synapse/oidc",
+ max_age="3600",
+ httpOnly=True,
+ sameSite="lax",
+ )
+
+ metadata = await self.load_metadata()
+ authorization_endpoint = metadata.get("authorization_endpoint")
+ return prepare_grant_uri(
+ authorization_endpoint,
+ client_id=self._client_auth.client_id,
+ response_type="code",
+ redirect_uri=self._callback_url,
+ scope=self._scopes,
+ state=state,
+ nonce=nonce,
+ )
+
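The redirect URL assembled here is simply the provider's authorization endpoint with the query parameters listed in the docstring. A hand-rolled equivalent of ``prepare_grant_uri``, with placeholder values throughout::

    from urllib.parse import urlencode

    def build_authorization_url_sketch(
        authorization_endpoint: str, client_id: str, redirect_uri: str,
        scopes: list, state: str, nonce: str,
    ) -> str:
        params = {
            "response_type": "code",
            "client_id": client_id,
            "redirect_uri": redirect_uri,
            "scope": " ".join(scopes),  # scopes are space-separated
            "state": state,
            "nonce": nonce,
        }
        return authorization_endpoint + "?" + urlencode(params)

    # build_authorization_url_sketch(
    #     "https://accounts.example.com/authorize", "my-client-id",
    #     "https://matrix.example.com/_synapse/oidc/callback",
    #     ["openid", "profile"], "<random state>", "<random nonce>")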
+ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
+ """Handle an incoming request to /_synapse/oidc/callback
+
+ Since we might want to display OIDC-related errors in a user-friendly
+ way, we don't raise SynapseError from here. Instead, we call
+ ``self._render_error`` which displays an HTML page for the error.
+
+ Most of the OpenID Connect logic happens here:
+
+ - first, we check if there was any error returned by the provider and
+ display it
+ - then we fetch the session cookie, decode and verify it
+ - the ``state`` query parameter should match with the one stored in the
+ session cookie
+      - once we know this session is legit, exchange the code with the
+ provider using the ``token_endpoint`` (see ``_exchange_code``)
+ - once we have the token, use it to either extract the UserInfo from
+ the ``id_token`` (``_parse_id_token``), or use the ``access_token``
+ to fetch UserInfo from the ``userinfo_endpoint``
+ (``_fetch_userinfo``)
+ - map those UserInfo to a Matrix user (``_map_userinfo_to_user``) and
+ finish the login
+
+ Args:
+ request: the incoming request from the browser.
+ """
+
+ # The provider might redirect with an error.
+ # In that case, just display it as-is.
+ if b"error" in request.args:
+ # error response from the auth server. see:
+ # https://tools.ietf.org/html/rfc6749#section-4.1.2.1
+ # https://openid.net/specs/openid-connect-core-1_0.html#AuthError
+ error = request.args[b"error"][0].decode()
+ description = request.args.get(b"error_description", [b""])[0].decode()
+
+            # Most of the errors returned by the provider could be due to
+            # either the provider misbehaving or Synapse being misconfigured.
+            # The only exception to that is "access_denied", where the user
+            # probably cancelled the login flow. In other cases, log those errors.
+ if error != "access_denied":
+ logger.error("Error from the OIDC provider: %s %s", error, description)
+
+ self._render_error(request, error, description)
+ return
+
+ # otherwise, it is presumably a successful response. see:
+ # https://tools.ietf.org/html/rfc6749#section-4.1.2
+
+ # Fetch the session cookie
+ session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
+ if session is None:
+ logger.info("No session cookie found")
+ self._render_error(request, "missing_session", "No session cookie found")
+ return
+
+ # Remove the cookie. There is a good chance that if the callback failed
+ # once, it will fail next time and the code will already be exchanged.
+ # Removing it early avoids spamming the provider with token requests.
+ request.addCookie(
+ SESSION_COOKIE_NAME,
+ b"",
+ path="/_synapse/oidc",
+ expires="Thu, Jan 01 1970 00:00:00 UTC",
+ httpOnly=True,
+ sameSite="lax",
+ )
+
+ # Check for the state query parameter
+ if b"state" not in request.args:
+ logger.info("State parameter is missing")
+ self._render_error(request, "invalid_request", "State parameter is missing")
+ return
+
+ state = request.args[b"state"][0].decode()
+
+ # Deserialize the session token and verify it.
+ try:
+ (
+ nonce,
+ client_redirect_url,
+ ui_auth_session_id,
+ ) = self._verify_oidc_session_token(session, state)
+ except MacaroonDeserializationException as e:
+ logger.exception("Invalid session")
+ self._render_error(request, "invalid_session", str(e))
+ return
+ except MacaroonInvalidSignatureException as e:
+ logger.exception("Could not verify session")
+ self._render_error(request, "mismatching_session", str(e))
+ return
+
+ # Exchange the code with the provider
+ if b"code" not in request.args:
+ logger.info("Code parameter is missing")
+ self._render_error(request, "invalid_request", "Code parameter is missing")
+ return
+
+ logger.debug("Exchanging code")
+ code = request.args[b"code"][0].decode()
+ try:
+ token = await self._exchange_code(code)
+ except OidcError as e:
+ logger.exception("Could not exchange code")
+ self._render_error(request, e.error, e.error_description)
+ return
+
+ logger.debug("Successfully obtained OAuth2 access token")
+
+ # Now that we have a token, get the userinfo, either by decoding the
+ # `id_token` or by fetching the `userinfo_endpoint`.
+ if self._uses_userinfo:
+ logger.debug("Fetching userinfo")
+ try:
+ userinfo = await self._fetch_userinfo(token)
+ except Exception as e:
+ logger.exception("Could not fetch userinfo")
+ self._render_error(request, "fetch_error", str(e))
+ return
+ else:
+ logger.debug("Extracting userinfo from id_token")
+ try:
+ userinfo = await self._parse_id_token(token, nonce=nonce)
+ except Exception as e:
+ logger.exception("Invalid id_token")
+ self._render_error(request, "invalid_token", str(e))
+ return
+
+ # Call the mapper to register/login the user
+ try:
+ user_id = await self._map_userinfo_to_user(userinfo, token)
+ except MappingException as e:
+ logger.exception("Could not map user")
+ self._render_error(request, "mapping_error", str(e))
+ return
+
+ # and finally complete the login
+ if ui_auth_session_id:
+ await self._auth_handler.complete_sso_ui_auth(
+ user_id, ui_auth_session_id, request
+ )
+ else:
+ await self._auth_handler.complete_sso_login(
+ user_id, request, client_redirect_url
+ )
+
+ def _generate_oidc_session_token(
+ self,
+ state: str,
+ nonce: str,
+ client_redirect_url: str,
+ ui_auth_session_id: Optional[str],
+ duration_in_ms: int = (60 * 60 * 1000),
+ ) -> str:
+ """Generates a signed token storing data about an OIDC session.
+
+ When Synapse initiates an authorization flow, it creates a random state
+ and a random nonce. Those parameters are given to the provider and
+ should be verified when the client comes back from the provider.
+        The token also stores the client_redirect_url, which is used to
+        complete the SSO login flow.
+
+ Args:
+ state: The ``state`` parameter passed to the OIDC provider.
+ nonce: The ``nonce`` parameter passed to the OIDC provider.
+ client_redirect_url: The URL the client gave when it initiated the
+ flow.
+ ui_auth_session_id: The session ID of the ongoing UI Auth (or
+ None if this is a login).
+ duration_in_ms: An optional duration for the token in milliseconds.
+ Defaults to an hour.
+
+ Returns:
+            A signed macaroon token with the session information.
+ """
+ macaroon = pymacaroons.Macaroon(
+ location=self._server_name, identifier="key", key=self._macaroon_secret_key,
+ )
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = session")
+ macaroon.add_first_party_caveat("state = %s" % (state,))
+ macaroon.add_first_party_caveat("nonce = %s" % (nonce,))
+ macaroon.add_first_party_caveat(
+ "client_redirect_url = %s" % (client_redirect_url,)
+ )
+ if ui_auth_session_id:
+ macaroon.add_first_party_caveat(
+ "ui_auth_session_id = %s" % (ui_auth_session_id,)
+ )
+ now = self._clock.time_msec()
+ expiry = now + duration_in_ms
+ macaroon.add_first_party_caveat("time < %d" % (expiry,))
+
+ return macaroon.serialize()
+
+ def _verify_oidc_session_token(
+ self, session: bytes, state: str
+ ) -> Tuple[str, str, Optional[str]]:
+        """Verifies an OIDC session token and extracts its contents.
+
+        This verifies that a given session token was issued by this homeserver
+        and extracts the nonce and client_redirect_url caveats.
+
+ Args:
+ session: The session token to verify
+ state: The state the OIDC provider gave back
+
+ Returns:
+ The nonce, client_redirect_url, and ui_auth_session_id for this session
+ """
+ macaroon = pymacaroons.Macaroon.deserialize(session)
+
+ v = pymacaroons.Verifier()
+ v.satisfy_exact("gen = 1")
+ v.satisfy_exact("type = session")
+ v.satisfy_exact("state = %s" % (state,))
+ v.satisfy_general(lambda c: c.startswith("nonce = "))
+ v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
+        # A UI auth session ID is only sometimes present; it is OK to always
+        # attempt to satisfy this caveat.
+ v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
+ v.satisfy_general(self._verify_expiry)
+
+ v.verify(macaroon, self._macaroon_secret_key)
+
+ # Extract the `nonce`, `client_redirect_url`, and maybe the
+ # `ui_auth_session_id` from the token.
+ nonce = self._get_value_from_macaroon(macaroon, "nonce")
+ client_redirect_url = self._get_value_from_macaroon(
+ macaroon, "client_redirect_url"
+ )
+ try:
+ ui_auth_session_id = self._get_value_from_macaroon(
+ macaroon, "ui_auth_session_id"
+ ) # type: Optional[str]
+ except ValueError:
+ ui_auth_session_id = None
+
+ return nonce, client_redirect_url, ui_auth_session_id
+
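The session token handled by the two methods above is a plain first-party macaroon. A self-contained round trip with pymacaroons, using made-up key and caveat values::

    import pymacaroons

    key = "macaroon-secret-key"  # placeholder for the homeserver's macaroon_secret_key

    # Mint a token carrying the session data as first-party caveats.
    m = pymacaroons.Macaroon(location="example.com", identifier="key", key=key)
    m.add_first_party_caveat("gen = 1")
    m.add_first_party_caveat("type = session")
    m.add_first_party_caveat("state = abc123")
    m.add_first_party_caveat("nonce = xyz789")
    serialized = m.serialize()

    # Later, verify it against the state echoed back by the provider.
    v = pymacaroons.Verifier()
    v.satisfy_exact("gen = 1")
    v.satisfy_exact("type = session")
    v.satisfy_exact("state = abc123")
    v.satisfy_general(lambda c: c.startswith("nonce = "))
    v.verify(pymacaroons.Macaroon.deserialize(serialized), key)  # raises if invalid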
+ def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
+ """Extracts a caveat value from a macaroon token.
+
+ Args:
+ macaroon: the token
+ key: the key of the caveat to extract
+
+ Returns:
+ The extracted value
+
+ Raises:
+            ValueError: if the caveat was not in the macaroon
+ """
+ prefix = key + " = "
+ for caveat in macaroon.caveats:
+ if caveat.caveat_id.startswith(prefix):
+ return caveat.caveat_id[len(prefix) :]
+ raise ValueError("No %s caveat in macaroon" % (key,))
+
+ def _verify_expiry(self, caveat: str) -> bool:
+ prefix = "time < "
+ if not caveat.startswith(prefix):
+ return False
+ expiry = int(caveat[len(prefix) :])
+ now = self._clock.time_msec()
+ return now < expiry
+
+ async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str:
+ """Maps a UserInfo object to a mxid.
+
+ UserInfo should have a claim that uniquely identifies users. This claim
+ is usually `sub`, but can be configured with `oidc_config.subject_claim`.
+ It is then used as an `external_id`.
+
+ If we don't find the user that way, we should register the user,
+ mapping the localpart and the display name from the UserInfo.
+
+ If a user already exists with the mxid we've mapped, raise an exception.
+
+ Args:
+ userinfo: an object representing the user
+ token: a dict with the tokens obtained from the provider
+
+ Raises:
+ MappingException: if there was an error while mapping some properties
+
+ Returns:
+ The mxid of the user
+ """
+ try:
+ remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
+ except Exception as e:
+ raise MappingException(
+ "Failed to extract subject from OIDC response: %s" % (e,)
+ )
+
+ logger.info(
+ "Looking for existing mapping for user %s:%s",
+ self._auth_provider_id,
+ remote_user_id,
+ )
+
+ registered_user_id = await self._datastore.get_user_by_external_id(
+ self._auth_provider_id, remote_user_id,
+ )
+
+ if registered_user_id is not None:
+ logger.info("Found existing mapping %s", registered_user_id)
+ return registered_user_id
+
+ try:
+ attributes = await self._user_mapping_provider.map_user_attributes(
+ userinfo, token
+ )
+ except Exception as e:
+ raise MappingException(
+ "Could not extract user attributes from OIDC response: " + str(e)
+ )
+
+ logger.debug(
+ "Retrieved user attributes from user mapping provider: %r", attributes
+ )
+
+ if not attributes["localpart"]:
+ raise MappingException("localpart is empty")
+
+ localpart = map_username_to_mxid_localpart(attributes["localpart"])
+
+ user_id = UserID(localpart, self._hostname)
+ if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()):
+ # This mxid is taken
+ raise MappingException(
+ "mxid '{}' is already taken".format(user_id.to_string())
+ )
+
+ # It's the first time this user is logging in and the mapped mxid was
+        # not taken, so register the user
+ registered_user_id = await self._registration_handler.register_user(
+ localpart=localpart, default_display_name=attributes["display_name"],
+ )
+
+ await self._datastore.record_user_external_id(
+ self._auth_provider_id, remote_user_id, registered_user_id,
+ )
+ return registered_user_id
+
+
+UserAttribute = TypedDict(
+ "UserAttribute", {"localpart": str, "display_name": Optional[str]}
+)
+C = TypeVar("C")
+
+
+class OidcMappingProvider(Generic[C]):
+ """A mapping provider maps a UserInfo object to user attributes.
+
+ It should provide the API described by this class.
+ """
+
+ def __init__(self, config: C):
+ """
+ Args:
+ config: A custom config object from this module, parsed by ``parse_config()``
+ """
+
+ @staticmethod
+ def parse_config(config: dict) -> C:
+ """Parse the dict provided by the homeserver's config
+
+ Args:
+ config: A dictionary containing configuration options for this provider
+
+ Returns:
+ A custom config object for this module
+ """
+ raise NotImplementedError()
+
+ def get_remote_user_id(self, userinfo: UserInfo) -> str:
+ """Get a unique user ID for this user.
+
+ Usually, in an OIDC-compliant scenario, it should be the ``sub`` claim from the UserInfo object.
+
+ Args:
+ userinfo: An object representing the user given by the OIDC provider
+
+ Returns:
+ A unique user ID
+ """
+ raise NotImplementedError()
+
+ async def map_user_attributes(
+ self, userinfo: UserInfo, token: Token
+ ) -> UserAttribute:
+        """Map a ``UserInfo`` object into user attributes.
+
+ Args:
+ userinfo: An object representing the user given by the OIDC provider
+ token: A dict with the tokens returned by the provider
+
+ Returns:
+ A dict containing the ``localpart`` and (optionally) the ``display_name``
+ """
+ raise NotImplementedError()
+
+
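For illustration, a minimal custom provider implementing this interface could map a hypothetical ``username`` claim to the localpart. Only the method names and the types reused from this module come from the interface above; the class, the claim names and the config handling are invented::

    class UsernameMappingProvider(OidcMappingProvider[dict]):
        """Example provider: localpart from a hypothetical "username" claim."""

        @staticmethod
        def parse_config(config: dict) -> dict:
            # No options needed for this example; pass the dict through.
            return config

        def get_remote_user_id(self, userinfo: UserInfo) -> str:
            return userinfo["sub"]

        async def map_user_attributes(
            self, userinfo: UserInfo, token: Token
        ) -> UserAttribute:
            return UserAttribute(
                localpart=userinfo["username"].lower(),
                display_name=userinfo.get("name"),
            )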
+# Used to clear out "None" values in templates
+def jinja_finalize(thing):
+ return thing if thing is not None else ""
+
+
+env = Environment(finalize=jinja_finalize)
+
+
+@attr.s
+class JinjaOidcMappingConfig:
+ subject_claim = attr.ib() # type: str
+ localpart_template = attr.ib() # type: Template
+ display_name_template = attr.ib() # type: Optional[Template]
+
+
+class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
+ """An implementation of a mapping provider based on Jinja templates.
+
+ This is the default mapping provider.
+ """
+
+ def __init__(self, config: JinjaOidcMappingConfig):
+ self._config = config
+
+ @staticmethod
+ def parse_config(config: dict) -> JinjaOidcMappingConfig:
+ subject_claim = config.get("subject_claim", "sub")
+
+ if "localpart_template" not in config:
+ raise ConfigError(
+ "missing key: oidc_config.user_mapping_provider.config.localpart_template"
+ )
+
+ try:
+ localpart_template = env.from_string(config["localpart_template"])
+ except Exception as e:
+ raise ConfigError(
+ "invalid jinja template for oidc_config.user_mapping_provider.config.localpart_template: %r"
+ % (e,)
+ )
+
+ display_name_template = None # type: Optional[Template]
+ if "display_name_template" in config:
+ try:
+ display_name_template = env.from_string(config["display_name_template"])
+ except Exception as e:
+ raise ConfigError(
+ "invalid jinja template for oidc_config.user_mapping_provider.config.display_name_template: %r"
+ % (e,)
+ )
+
+ return JinjaOidcMappingConfig(
+ subject_claim=subject_claim,
+ localpart_template=localpart_template,
+ display_name_template=display_name_template,
+ )
+
+ def get_remote_user_id(self, userinfo: UserInfo) -> str:
+ return userinfo[self._config.subject_claim]
+
+ async def map_user_attributes(
+ self, userinfo: UserInfo, token: Token
+ ) -> UserAttribute:
+ localpart = self._config.localpart_template.render(user=userinfo).strip()
+
+ display_name = None # type: Optional[str]
+ if self._config.display_name_template is not None:
+ display_name = self._config.display_name_template.render(
+ user=userinfo
+ ).strip()
+
+ if display_name == "":
+ display_name = None
+
+ return UserAttribute(localpart=localpart, display_name=display_name)
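What the default provider does with its templates is plain Jinja rendering over the userinfo claims. A quick sketch with invented claim values::

    from jinja2 import Environment

    env = Environment(finalize=lambda v: v if v is not None else "")

    userinfo = {"sub": "abc123", "preferred_username": "alice", "name": "Alice"}

    localpart = env.from_string("{{ user.preferred_username }}").render(user=userinfo).strip()
    display_name = env.from_string("{{ user.name }}").render(user=userinfo).strip()

    print(localpart)     # alice
    print(display_name)  # Alice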
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 5cbefae1..3594f3b0 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -193,6 +193,12 @@ class BasePresenceHandler(abc.ABC):
) -> None:
"""Set the presence state of the user. """
+ @abc.abstractmethod
+ async def bump_presence_active_time(self, user: UserID):
+ """We've seen the user do something that indicates they're interacting
+ with the app.
+ """
+
class PresenceHandler(BasePresenceHandler):
def __init__(self, hs: "synapse.server.HomeServer"):
@@ -204,6 +210,7 @@ class PresenceHandler(BasePresenceHandler):
self.notifier = hs.get_notifier()
self.federation = hs.get_federation_sender()
self.state = hs.get_state_handler()
+ self._presence_enabled = hs.config.use_presence
federation_registry = hs.get_federation_registry()
@@ -676,13 +683,14 @@ class PresenceHandler(BasePresenceHandler):
async def incoming_presence(self, origin, content):
"""Called when we receive a `m.presence` EDU from a remote server.
"""
+ if not self._presence_enabled:
+ return
+
now = self.clock.time_msec()
updates = []
for push in content.get("push", []):
# A "push" contains a list of presence that we are probably interested
# in.
- # TODO: Actually check if we're interested, rather than blindly
- # accepting presence updates.
user_id = push.get("user_id", None)
if not user_id:
logger.info(
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index a6178e74..51979ea4 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -16,8 +16,6 @@
"""Contains functions for registering clients."""
import logging
-from twisted.internet import defer
-
from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, LoginType
from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
@@ -75,8 +73,9 @@ class RegistrationHandler(BaseHandler):
self.session_lifetime = hs.config.session_lifetime
- @defer.inlineCallbacks
- def check_username(self, localpart, guest_access_token=None, assigned_user_id=None):
+ async def check_username(
+ self, localpart, guest_access_token=None, assigned_user_id=None
+ ):
if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
400,
@@ -113,13 +112,13 @@ class RegistrationHandler(BaseHandler):
Codes.INVALID_USERNAME,
)
- users = yield self.store.get_users_by_id_case_insensitive(user_id)
+ users = await self.store.get_users_by_id_case_insensitive(user_id)
if users:
if not guest_access_token:
raise SynapseError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
)
- user_data = yield self.auth.get_user_by_access_token(guest_access_token)
+ user_data = await self.auth.get_user_by_access_token(guest_access_token)
if not user_data["is_guest"] or user_data["user"].localpart != localpart:
raise AuthError(
403,
@@ -128,8 +127,16 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.FORBIDDEN,
)
- @defer.inlineCallbacks
- def register_user(
+ if guest_access_token is None:
+ try:
+ int(localpart)
+ raise SynapseError(
+ 400, "Numeric user IDs are reserved for guest users."
+ )
+ except ValueError:
+ pass
+
+ async def register_user(
self,
localpart=None,
password_hash=None,
@@ -141,6 +148,7 @@ class RegistrationHandler(BaseHandler):
default_display_name=None,
address=None,
bind_emails=[],
+ by_admin=False,
):
"""Registers a new client on the server.
@@ -156,29 +164,24 @@ class RegistrationHandler(BaseHandler):
will be set to this. Defaults to 'localpart'.
address (str|None): the IP address used to perform the registration.
bind_emails (List[str]): list of emails to bind to this account.
+ by_admin (bool): True if this registration is being made via the
+ admin api, otherwise False.
Returns:
- Deferred[str]: user_id
+ str: user_id
Raises:
SynapseError if there was a problem registering.
"""
- yield self.check_registration_ratelimit(address)
+ self.check_registration_ratelimit(address)
- yield self.auth.check_auth_blocking(threepid=threepid)
+ # do not check_auth_blocking if the call is coming through the Admin API
+ if not by_admin:
+ await self.auth.check_auth_blocking(threepid=threepid)
if localpart is not None:
- yield self.check_username(localpart, guest_access_token=guest_access_token)
+ await self.check_username(localpart, guest_access_token=guest_access_token)
was_guest = guest_access_token is not None
- if not was_guest:
- try:
- int(localpart)
- raise SynapseError(
- 400, "Numeric user IDs are reserved for guest users."
- )
- except ValueError:
- pass
-
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
@@ -189,7 +192,7 @@ class RegistrationHandler(BaseHandler):
elif default_display_name is None:
default_display_name = localpart
- yield self.register_with_store(
+ await self.register_with_store(
user_id=user_id,
password_hash=password_hash,
was_guest=was_guest,
@@ -201,8 +204,8 @@ class RegistrationHandler(BaseHandler):
)
if self.hs.config.user_directory_search_all_users:
- profile = yield self.store.get_profileinfo(localpart)
- yield self.user_directory_handler.handle_local_profile_change(
+ profile = await self.store.get_profileinfo(localpart)
+ await self.user_directory_handler.handle_local_profile_change(
user_id, profile
)
@@ -215,14 +218,14 @@ class RegistrationHandler(BaseHandler):
if fail_count > 10:
raise SynapseError(500, "Unable to find a suitable guest user ID")
- localpart = yield self._generate_user_id()
+ localpart = await self._generate_user_id()
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
- yield self.check_user_id_not_appservice_exclusive(user_id)
+ self.check_user_id_not_appservice_exclusive(user_id)
if default_display_name is None:
default_display_name = localpart
try:
- yield self.register_with_store(
+ await self.register_with_store(
user_id=user_id,
password_hash=password_hash,
make_guest=make_guest,
@@ -239,7 +242,13 @@ class RegistrationHandler(BaseHandler):
fail_count += 1
if not self.hs.config.user_consent_at_registration:
- yield defer.ensureDeferred(self._auto_join_rooms(user_id))
+ if not self.hs.config.auto_join_rooms_for_guests and make_guest:
+ logger.info(
+ "Skipping auto-join for %s because auto-join for guests is disabled",
+ user_id,
+ )
+ else:
+ await self._auto_join_rooms(user_id)
else:
logger.info(
"Skipping auto-join for %s because consent is required at registration",
@@ -257,7 +266,7 @@ class RegistrationHandler(BaseHandler):
}
# Bind email to new account
- yield self._register_email_threepid(user_id, threepid_dict, None)
+ await self._register_email_threepid(user_id, threepid_dict, None)
return user_id
@@ -322,8 +331,7 @@ class RegistrationHandler(BaseHandler):
"""
await self._auto_join_rooms(user_id)
- @defer.inlineCallbacks
- def appservice_register(self, user_localpart, as_token):
+ async def appservice_register(self, user_localpart, as_token):
user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string()
service = self.store.get_app_service_by_token(as_token)
@@ -338,11 +346,9 @@ class RegistrationHandler(BaseHandler):
service_id = service.id if service.is_exclusive_user(user_id) else None
- yield self.check_user_id_not_appservice_exclusive(
- user_id, allowed_appservice=service
- )
+ self.check_user_id_not_appservice_exclusive(user_id, allowed_appservice=service)
- yield self.register_with_store(
+ await self.register_with_store(
user_id=user_id,
password_hash="",
appservice_id=service_id,
@@ -374,13 +380,12 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE,
)
- @defer.inlineCallbacks
- def _generate_user_id(self):
+ async def _generate_user_id(self):
if self._next_generated_user_id is None:
- with (yield self._generate_user_id_linearizer.queue(())):
+ with await self._generate_user_id_linearizer.queue(()):
if self._next_generated_user_id is None:
self._next_generated_user_id = (
- yield self.store.find_next_generated_user_id_localpart()
+ await self.store.find_next_generated_user_id_localpart()
)
id = self._next_generated_user_id
@@ -425,14 +430,7 @@ class RegistrationHandler(BaseHandler):
if not address:
return
- time_now = self.clock.time()
-
- self.ratelimiter.ratelimit(
- address,
- time_now_s=time_now,
- rate_hz=self.hs.config.rc_registration.per_second,
- burst_count=self.hs.config.rc_registration.burst_count,
- )
+ self.ratelimiter.ratelimit(address)
def register_with_store(
self,
@@ -490,8 +488,9 @@ class RegistrationHandler(BaseHandler):
user_type=user_type,
)
- @defer.inlineCallbacks
- def register_device(self, user_id, device_id, initial_display_name, is_guest=False):
+ async def register_device(
+ self, user_id, device_id, initial_display_name, is_guest=False
+ ):
"""Register a device for a user and generate an access token.
The access token will be limited by the homeserver's session_lifetime config.
@@ -505,11 +504,11 @@ class RegistrationHandler(BaseHandler):
is_guest (bool): Whether this is a guest account
Returns:
- defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
+ tuple[str, str]: Tuple of device ID and access token
"""
if self.hs.config.worker_app:
- r = yield self._register_device_client(
+ r = await self._register_device_client(
user_id=user_id,
device_id=device_id,
initial_display_name=initial_display_name,
@@ -525,7 +524,7 @@ class RegistrationHandler(BaseHandler):
)
valid_until_ms = self.clock.time_msec() + self.session_lifetime
- device_id = yield self.device_handler.check_device_registered(
+ device_id = await self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
if is_guest:
@@ -534,10 +533,8 @@ class RegistrationHandler(BaseHandler):
user_id, ["guest = true"]
)
else:
- access_token = yield defer.ensureDeferred(
- self._auth_handler.get_access_token_for_user_id(
- user_id, device_id=device_id, valid_until_ms=valid_until_ms
- )
+ access_token = await self._auth_handler.get_access_token_for_user_id(
+ user_id, device_id=device_id, valid_until_ms=valid_until_ms
)
return (device_id, access_token)
@@ -588,8 +585,7 @@ class RegistrationHandler(BaseHandler):
await self.store.user_set_consent_version(user_id, consent_version)
await self.post_consent_actions(user_id)
- @defer.inlineCallbacks
- def _register_email_threepid(self, user_id, threepid, token):
+ async def _register_email_threepid(self, user_id, threepid, token):
"""Add an email address as a 3pid identifier
Also adds an email pusher for the email address, if configured in the
@@ -602,8 +598,6 @@ class RegistrationHandler(BaseHandler):
threepid (object): m.login.email.identity auth response
token (str|None): access_token for the user, or None if not logged
in.
- Returns:
- defer.Deferred:
"""
reqd = ("medium", "address", "validated_at")
if any(x not in threepid for x in reqd):
@@ -611,13 +605,8 @@ class RegistrationHandler(BaseHandler):
logger.info("Can't add incomplete 3pid")
return
- yield defer.ensureDeferred(
- self._auth_handler.add_threepid(
- user_id,
- threepid["medium"],
- threepid["address"],
- threepid["validated_at"],
- )
+ await self._auth_handler.add_threepid(
+ user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
)
# And we add an email pusher for them by default, but only
@@ -633,10 +622,10 @@ class RegistrationHandler(BaseHandler):
# It would really make more sense for this to be passed
# up when the access token is saved, but that's quite an
# invasive change I'd rather do separately.
- user_tuple = yield self.store.get_user_by_access_token(token)
+ user_tuple = await self.store.get_user_by_access_token(token)
token_id = user_tuple["token_id"]
- yield self.pusher_pool.add_pusher(
+ await self.pusher_pool.add_pusher(
user_id=user_id,
access_token=token_id,
kind="email",
@@ -648,8 +637,7 @@ class RegistrationHandler(BaseHandler):
data={},
)
- @defer.inlineCallbacks
- def _register_msisdn_threepid(self, user_id, threepid):
+ async def _register_msisdn_threepid(self, user_id, threepid):
"""Add a phone number as a 3pid identifier
Must be called on master.
@@ -657,8 +645,6 @@ class RegistrationHandler(BaseHandler):
Args:
user_id (str): id of user
threepid (object): m.login.msisdn auth response
- Returns:
- defer.Deferred:
"""
try:
assert_params_in_dict(threepid, ["medium", "address", "validated_at"])
@@ -669,11 +655,6 @@ class RegistrationHandler(BaseHandler):
return None
raise
- yield defer.ensureDeferred(
- self._auth_handler.add_threepid(
- user_id,
- threepid["medium"],
- threepid["address"],
- threepid["validated_at"],
- )
+ await self._auth_handler.add_threepid(
+ user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index da12df7f..61db3ccc 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -22,11 +22,10 @@ import logging
import math
import string
from collections import OrderedDict
+from typing import Tuple
from six import iteritems, string_types
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@@ -90,6 +89,8 @@ class RoomCreationHandler(BaseHandler):
self.room_member_handler = hs.get_room_member_handler()
self.config = hs.config
+ self._replication = hs.get_replication_data_handler()
+
# linearizer to stop two upgrades happening at once
self._upgrade_linearizer = Linearizer("room_upgrade_linearizer")
@@ -103,8 +104,7 @@ class RoomCreationHandler(BaseHandler):
self.third_party_event_rules = hs.get_third_party_event_rules()
- @defer.inlineCallbacks
- def upgrade_room(
+ async def upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
):
"""Replace a room with a new room with a different version
@@ -117,7 +117,7 @@ class RoomCreationHandler(BaseHandler):
Returns:
Deferred[unicode]: the new room id
"""
- yield self.ratelimit(requester)
+ await self.ratelimit(requester)
user_id = requester.user.to_string()
@@ -138,7 +138,7 @@ class RoomCreationHandler(BaseHandler):
# If this user has sent multiple upgrade requests for the same room
# and one of them is not complete yet, cache the response and
# return it to all subsequent requests
- ret = yield self._upgrade_response_cache.wrap(
+ ret = await self._upgrade_response_cache.wrap(
(old_room_id, user_id),
self._upgrade_room,
requester,
@@ -442,73 +442,78 @@ class RoomCreationHandler(BaseHandler):
new_room_id: str,
old_room_state: StateMap[str],
):
- directory_handler = self.hs.get_handlers().directory_handler
-
- aliases = await self.store.get_aliases_for_room(old_room_id)
-
# check to see if we have a canonical alias.
canonical_alias_event = None
canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, ""))
if canonical_alias_event_id:
canonical_alias_event = await self.store.get_event(canonical_alias_event_id)
- # first we try to remove the aliases from the old room (we suppress sending
- # the room_aliases event until the end).
- #
- # Note that we'll only be able to remove aliases that (a) aren't owned by an AS,
- # and (b) unless the user is a server admin, which the user created.
- #
- # This is probably correct - given we don't allow such aliases to be deleted
- # normally, it would be odd to allow it in the case of doing a room upgrade -
- # but it makes the upgrade less effective, and you have to wonder why a room
- # admin can't remove aliases that point to that room anyway.
- # (cf https://github.com/matrix-org/synapse/issues/2360)
- #
- removed_aliases = []
- for alias_str in aliases:
- alias = RoomAlias.from_string(alias_str)
- try:
- await directory_handler.delete_association(requester, alias)
- removed_aliases.append(alias_str)
- except SynapseError as e:
- logger.warning("Unable to remove alias %s from old room: %s", alias, e)
-
- # if we didn't find any aliases, or couldn't remove anyway, we can skip the rest
- # of this.
- if not removed_aliases:
+ await self.store.update_aliases_for_room(old_room_id, new_room_id)
+
+ if not canonical_alias_event:
return
- # we can now add any aliases we successfully removed to the new room.
- for alias in removed_aliases:
- try:
- await directory_handler.create_association(
- requester,
- RoomAlias.from_string(alias),
- new_room_id,
- servers=(self.hs.hostname,),
- check_membership=False,
- )
- logger.info("Moved alias %s to new room", alias)
- except SynapseError as e:
- # I'm not really expecting this to happen, but it could if the spam
- # checking module decides it shouldn't, or similar.
- logger.error("Error adding alias %s to new room: %s", alias, e)
+ # If there is a canonical alias we need to update the one in the old
+ # room and set one in the new one.
+ old_canonical_alias_content = dict(canonical_alias_event.content)
+ new_canonical_alias_content = {}
+
+ canonical = canonical_alias_event.content.get("alias")
+ if canonical and self.hs.is_mine_id(canonical):
+ new_canonical_alias_content["alias"] = canonical
+ old_canonical_alias_content.pop("alias", None)
+
+ # We convert to a list as it will be a Tuple.
+ old_alt_aliases = list(old_canonical_alias_content.get("alt_aliases", []))
+ if old_alt_aliases:
+ old_canonical_alias_content["alt_aliases"] = old_alt_aliases
+ new_alt_aliases = new_canonical_alias_content.setdefault("alt_aliases", [])
+ for alias in canonical_alias_event.content.get("alt_aliases", []):
+ try:
+ if self.hs.is_mine_id(alias):
+ new_alt_aliases.append(alias)
+ old_alt_aliases.remove(alias)
+ except Exception:
+ logger.info(
+ "Invalid alias %s in canonical alias event %s",
+ alias,
+ canonical_alias_event_id,
+ )
+
+ if not old_alt_aliases:
+ old_canonical_alias_content.pop("alt_aliases")
# If a canonical alias event existed for the old room, fire a canonical
# alias event for the new room with a copy of the information.
try:
- if canonical_alias_event:
- await self.event_creation_handler.create_and_send_nonmember_event(
- requester,
- {
- "type": EventTypes.CanonicalAlias,
- "state_key": "",
- "room_id": new_room_id,
- "sender": requester.user.to_string(),
- "content": canonical_alias_event.content,
- },
- ratelimit=False,
- )
+ await self.event_creation_handler.create_and_send_nonmember_event(
+ requester,
+ {
+ "type": EventTypes.CanonicalAlias,
+ "state_key": "",
+ "room_id": old_room_id,
+ "sender": requester.user.to_string(),
+ "content": old_canonical_alias_content,
+ },
+ ratelimit=False,
+ )
+ except SynapseError as e:
+ # again I'm not really expecting this to fail, but if it does, I'd rather
+ # we returned the new room to the client at this point.
+ logger.error("Unable to send updated alias events in old room: %s", e)
+
+ try:
+ await self.event_creation_handler.create_and_send_nonmember_event(
+ requester,
+ {
+ "type": EventTypes.CanonicalAlias,
+ "state_key": "",
+ "room_id": new_room_id,
+ "sender": requester.user.to_string(),
+ "content": new_canonical_alias_content,
+ },
+ ratelimit=False,
+ )
except SynapseError as e:
# again I'm not really expecting this to fail, but if it does, I'd rather
# we returned the new room to the client at this point.
@@ -516,7 +521,7 @@ class RoomCreationHandler(BaseHandler):
async def create_room(
self, requester, config, ratelimit=True, creator_join_profile=None
- ):
+ ) -> Tuple[dict, int]:
""" Creates a new room.
Args:
@@ -533,9 +538,9 @@ class RoomCreationHandler(BaseHandler):
`avatar_url` and/or `displayname`.
Returns:
- Deferred[dict]:
- a dict containing the keys `room_id` and, if an alias was
- requested, `room_alias`.
+ First, a dict containing the keys `room_id` and, if an alias
+            was requested, `room_alias`. Secondly, the stream_id of the
+ last persisted event.
Raises:
SynapseError if the room ID couldn't be stored, or something went
horribly wrong.
@@ -667,7 +672,7 @@ class RoomCreationHandler(BaseHandler):
# override any attempt to set room versions via the creation_content
creation_content["room_version"] = room_version.identifier
- await self._send_events_for_new_room(
+ last_stream_id = await self._send_events_for_new_room(
requester,
room_id,
preset_config=preset_config,
@@ -681,7 +686,10 @@ class RoomCreationHandler(BaseHandler):
if "name" in config:
name = config["name"]
- await self.event_creation_handler.create_and_send_nonmember_event(
+ (
+ _,
+ last_stream_id,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Name,
@@ -695,7 +703,10 @@ class RoomCreationHandler(BaseHandler):
if "topic" in config:
topic = config["topic"]
- await self.event_creation_handler.create_and_send_nonmember_event(
+ (
+ _,
+ last_stream_id,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Topic,
@@ -713,7 +724,7 @@ class RoomCreationHandler(BaseHandler):
if is_direct:
content["is_direct"] = is_direct
- await self.room_member_handler.update_membership(
+ _, last_stream_id = await self.room_member_handler.update_membership(
requester,
UserID.from_string(invitee),
room_id,
@@ -727,7 +738,7 @@ class RoomCreationHandler(BaseHandler):
id_access_token = invite_3pid.get("id_access_token") # optional
address = invite_3pid["address"]
medium = invite_3pid["medium"]
- await self.hs.get_room_member_handler().do_3pid_invite(
+ last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite(
room_id,
requester.user,
medium,
@@ -743,7 +754,12 @@ class RoomCreationHandler(BaseHandler):
if room_alias:
result["room_alias"] = room_alias.to_string()
- return result
+        # Always wait for room creation to propagate before returning
+ await self._replication.wait_for_stream_position(
+ self.hs.config.worker.writers.events, "events", last_stream_id
+ )
+
+ return result, last_stream_id
async def _send_events_for_new_room(
self,
@@ -756,7 +772,13 @@ class RoomCreationHandler(BaseHandler):
room_alias=None,
power_level_content_override=None, # Doesn't apply when initial state has power level state event content
creator_join_profile=None,
- ):
+ ) -> int:
+ """Sends the initial events into a new room.
+
+ Returns:
+ The stream_id of the last event persisted.
+ """
+
def create(etype, content, **kwargs):
e = {"type": etype, "content": content}
@@ -765,12 +787,16 @@ class RoomCreationHandler(BaseHandler):
return e
- async def send(etype, content, **kwargs):
+ async def send(etype, content, **kwargs) -> int:
event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype)
- await self.event_creation_handler.create_and_send_nonmember_event(
+ (
+ _,
+ last_stream_id,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
creator, event, ratelimit=False
)
+ return last_stream_id
config = RoomCreationHandler.PRESETS_DICT[preset_config]
@@ -795,7 +821,9 @@ class RoomCreationHandler(BaseHandler):
# of the first events that get sent into a room.
pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
if pl_content is not None:
- await send(etype=EventTypes.PowerLevels, content=pl_content)
+ last_sent_stream_id = await send(
+ etype=EventTypes.PowerLevels, content=pl_content
+ )
else:
power_level_content = {
"users": {creator_id: 100},
@@ -828,36 +856,41 @@ class RoomCreationHandler(BaseHandler):
if power_level_content_override:
power_level_content.update(power_level_content_override)
- await send(etype=EventTypes.PowerLevels, content=power_level_content)
+ last_sent_stream_id = await send(
+ etype=EventTypes.PowerLevels, content=power_level_content
+ )
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
- await send(
+ last_sent_stream_id = await send(
etype=EventTypes.CanonicalAlias,
content={"alias": room_alias.to_string()},
)
if (EventTypes.JoinRules, "") not in initial_state:
- await send(
+ last_sent_stream_id = await send(
etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]}
)
if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
- await send(
+ last_sent_stream_id = await send(
etype=EventTypes.RoomHistoryVisibility,
content={"history_visibility": config["history_visibility"]},
)
if config["guest_can_join"]:
if (EventTypes.GuestAccess, "") not in initial_state:
- await send(
+ last_sent_stream_id = await send(
etype=EventTypes.GuestAccess, content={"guest_access": "can_join"}
)
for (etype, state_key), content in initial_state.items():
- await send(etype=etype, state_key=state_key, content=content)
+ last_sent_stream_id = await send(
+ etype=etype, state_key=state_key, content=content
+ )
+
+ return last_sent_stream_id
- @defer.inlineCallbacks
- def _generate_room_id(
+ async def _generate_room_id(
self, creator_id: str, is_public: str, room_version: RoomVersion,
):
# autogen room IDs and try to create it. We may clash, so just
@@ -869,7 +902,7 @@ class RoomCreationHandler(BaseHandler):
gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
if isinstance(gen_room_id, bytes):
gen_room_id = gen_room_id.decode("utf-8")
- yield self.store.store_room(
+ await self.store.store_room(
room_id=gen_room_id,
room_creator_user_id=creator_id,
is_public=is_public,
@@ -888,8 +921,7 @@ class RoomContextHandler(object):
self.storage = hs.get_storage()
self.state_store = self.storage.state
- @defer.inlineCallbacks
- def get_event_context(self, user, room_id, event_id, limit, event_filter):
+ async def get_event_context(self, user, room_id, event_id, limit, event_filter):
"""Retrieves events, pagination tokens and state around a given event
in a room.
@@ -908,7 +940,7 @@ class RoomContextHandler(object):
before_limit = math.floor(limit / 2.0)
after_limit = limit - before_limit
- users = yield self.store.get_users_in_room(room_id)
+ users = await self.store.get_users_in_room(room_id)
is_peeking = user.to_string() not in users
def filter_evts(events):
@@ -916,17 +948,17 @@ class RoomContextHandler(object):
self.storage, user.to_string(), events, is_peeking=is_peeking
)
- event = yield self.store.get_event(
+ event = await self.store.get_event(
event_id, get_prev_content=True, allow_none=True
)
if not event:
return None
- filtered = yield (filter_evts([event]))
+ filtered = await filter_evts([event])
if not filtered:
raise AuthError(403, "You don't have permission to access that event.")
- results = yield self.store.get_events_around(
+ results = await self.store.get_events_around(
room_id, event_id, before_limit, after_limit, event_filter
)
@@ -934,8 +966,8 @@ class RoomContextHandler(object):
results["events_before"] = event_filter.filter(results["events_before"])
results["events_after"] = event_filter.filter(results["events_after"])
- results["events_before"] = yield filter_evts(results["events_before"])
- results["events_after"] = yield filter_evts(results["events_after"])
+ results["events_before"] = await filter_evts(results["events_before"])
+ results["events_after"] = await filter_evts(results["events_after"])
# filter_evts can return a pruned event in case the user is allowed to see that
# there's something there but not see the content, so use the event that's in
# `filtered` rather than the event we retrieved from the datastore.
@@ -962,7 +994,7 @@ class RoomContextHandler(object):
# first? Shouldn't we be consistent with /sync?
# https://github.com/matrix-org/matrix-doc/issues/687
- state = yield self.state_store.get_state_for_events(
+ state = await self.state_store.get_state_for_events(
[last_event_id], state_filter=state_filter
)
@@ -970,7 +1002,7 @@ class RoomContextHandler(object):
if event_filter:
state_events = event_filter.filter(state_events)
- results["state"] = yield filter_evts(state_events)
+ results["state"] = await filter_evts(state_events)
# We use a dummy token here as we only care about the room portion of
# the token, which we replace.
@@ -989,13 +1021,12 @@ class RoomEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def get_new_events(
+ async def get_new_events(
self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
):
# We just ignore the key for now.
- to_key = yield self.get_current_key()
+ to_key = await self.get_current_key()
from_token = RoomStreamToken.parse(from_key)
if from_token.topological:
@@ -1008,11 +1039,11 @@ class RoomEventSource(object):
# See https://github.com/matrix-org/matrix-doc/issues/1144
raise NotImplementedError()
else:
- room_events = yield self.store.get_membership_changes_for_user(
+ room_events = await self.store.get_membership_changes_for_user(
user.to_string(), from_key, to_key
)
- room_to_events = yield self.store.get_room_events_stream_for_rooms(
+ room_to_events = await self.store.get_room_events_stream_for_rooms(
room_ids=room_ids,
from_key=from_key,
to_key=to_key,
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index e75dabcd..4cbc02b0 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -253,10 +253,21 @@ class RoomListHandler(BaseHandler):
"""
result = {"room_id": room_id, "num_joined_members": num_joined_users}
+ if with_alias:
+ aliases = yield self.store.get_aliases_for_room(
+ room_id, on_invalidate=cache_context.invalidate
+ )
+ if aliases:
+ result["aliases"] = aliases
+
current_state_ids = yield self.store.get_current_state_ids(
room_id, on_invalidate=cache_context.invalidate
)
+ if not current_state_ids:
+ # We're not in the room, so may as well bail out here.
+ return result
+
event_map = yield self.store.get_events(
[
event_id
@@ -289,14 +300,7 @@ class RoomListHandler(BaseHandler):
create_event = current_state.get((EventTypes.Create, ""))
result["m.federate"] = create_event.content.get("m.federate", True)
- if with_alias:
- aliases = yield self.store.get_aliases_for_room(
- room_id, on_invalidate=cache_context.invalidate
- )
- if aliases:
- result["aliases"] = aliases
-
- name_event = yield current_state.get((EventTypes.Name, ""))
+ name_event = current_state.get((EventTypes.Name, ""))
if name_event:
name = name_event.content.get("name", None)
if name:
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 53b49bc1..0f7af982 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -17,15 +17,19 @@
import abc
import logging
+from typing import Dict, Iterable, List, Optional, Tuple
from six.moves import http_client
-from twisted.internet import defer
-
from synapse import types
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError
-from synapse.types import Collection, RoomID, UserID
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
+from synapse.replication.http.membership import (
+ ReplicationLocallyRejectInviteRestServlet,
+)
+from synapse.types import Collection, Requester, RoomAlias, RoomID, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
@@ -43,11 +47,6 @@ class RoomMemberHandler(object):
__metaclass__ = abc.ABCMeta
def __init__(self, hs):
- """
-
- Args:
- hs (synapse.server.HomeServer):
- """
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
@@ -70,90 +69,117 @@ class RoomMemberHandler(object):
self._enable_lookup = hs.config.enable_3pid_lookup
self.allow_per_room_profiles = self.config.allow_per_room_profiles
+ self._event_stream_writer_instance = hs.config.worker.writers.events
+ self._is_on_event_persistence_instance = (
+ self._event_stream_writer_instance == hs.get_instance_name()
+ )
+ if self._is_on_event_persistence_instance:
+ self.persist_event_storage = hs.get_storage().persistence
+ else:
+ self._locally_reject_client = ReplicationLocallyRejectInviteRestServlet.make_client(
+ hs
+ )
+
# This is only used to get at ratelimit function, and
# maybe_kick_guest_users. It's fine there are multiple of these as
# it doesn't store state.
self.base_handler = BaseHandler(hs)
@abc.abstractmethod
- def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+ async def _remote_join(
+ self,
+ requester: Requester,
+ remote_room_hosts: List[str],
+ room_id: str,
+ user: UserID,
+ content: dict,
+ ) -> Tuple[str, int]:
"""Try and join a room that this server is not in
Args:
- requester (Requester)
- remote_room_hosts (list[str]): List of servers that can be used
- to join via.
- room_id (str): Room that we are trying to join
- user (UserID): User who is trying to join
- content (dict): A dict that should be used as the content of the
- join event.
-
- Returns:
- Deferred
+ requester
+ remote_room_hosts: List of servers that can be used to join via.
+ room_id: Room that we are trying to join
+ user: User who is trying to join
+ content: A dict that should be used as the content of the join event.
"""
raise NotImplementedError()
@abc.abstractmethod
- def _remote_reject_invite(
- self, requester, remote_room_hosts, room_id, target, content
- ):
+ async def _remote_reject_invite(
+ self,
+ requester: Requester,
+ remote_room_hosts: List[str],
+ room_id: str,
+ target: UserID,
+ content: dict,
+ ) -> Tuple[Optional[str], int]:
"""Attempt to reject an invite for a room this server is not in. If we
fail to do so we locally mark the invite as rejected.
Args:
- requester (Requester)
- remote_room_hosts (list[str]): List of servers to use to try and
- reject invite
- room_id (str)
- target (UserID): The user rejecting the invite
- content (dict): The content for the rejection event
+ requester
+ remote_room_hosts: List of servers to use to try and reject invite
+ room_id
+ target: The user rejecting the invite
+ content: The content for the rejection event
Returns:
- Deferred[dict]: A dictionary to be returned to the client, may
+ A dictionary to be returned to the client, may
include event_id etc, or nothing if we locally rejected
"""
raise NotImplementedError()
+ async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
+ """Mark the invite has having been rejected even though we failed to
+ create a leave event for it.
+ """
+ if self._is_on_event_persistence_instance:
+ return await self.persist_event_storage.locally_reject_invite(
+ user_id, room_id
+ )
+ else:
+ result = await self._locally_reject_client(
+ instance_name=self._event_stream_writer_instance,
+ user_id=user_id,
+ room_id=room_id,
+ )
+ return result["stream_id"]
+
@abc.abstractmethod
- def _user_joined_room(self, target, room_id):
+ async def _user_joined_room(self, target: UserID, room_id: str) -> None:
"""Notifies distributor on master process that the user has joined the
room.
Args:
- target (UserID)
- room_id (str)
-
- Returns:
- Deferred|None
+ target
+ room_id
"""
raise NotImplementedError()
@abc.abstractmethod
- def _user_left_room(self, target, room_id):
+ async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Notifies distributor on master process that the user has left the
room.
Args:
- target (UserID)
- room_id (str)
-
- Returns:
- Deferred|None
+ target
+ room_id
"""
raise NotImplementedError()
async def _local_membership_update(
self,
- requester,
- target,
- room_id,
- membership,
+ requester: Requester,
+ target: UserID,
+ room_id: str,
+ membership: str,
prev_event_ids: Collection[str],
- txn_id=None,
- ratelimit=True,
- content=None,
- require_consent=True,
- ):
+ txn_id: Optional[str] = None,
+ ratelimit: bool = True,
+ content: Optional[dict] = None,
+ require_consent: bool = True,
+ ) -> Tuple[str, int]:
user_id = target.to_string()
if content is None:
@@ -186,9 +212,10 @@ class RoomMemberHandler(object):
)
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
- return duplicate
+ _, stream_id = await self.store.get_event_ordering(duplicate.event_id)
+ return duplicate.event_id, stream_id
- await self.event_creation_handler.handle_new_client_event(
+ stream_id = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target], ratelimit=ratelimit
)
@@ -212,22 +239,20 @@ class RoomMemberHandler(object):
if prev_member_event.membership == Membership.JOIN:
await self._user_left_room(target, room_id)
- return event
+ return event.event_id, stream_id
- @defer.inlineCallbacks
- def copy_room_tags_and_direct_to_room(self, old_room_id, new_room_id, user_id):
+ async def copy_room_tags_and_direct_to_room(
+ self, old_room_id, new_room_id, user_id
+ ) -> None:
"""Copies the tags and direct room state from one room to another.
Args:
- old_room_id (str)
- new_room_id (str)
- user_id (str)
-
- Returns:
- Deferred[None]
+ old_room_id: The room ID of the old room.
+ new_room_id: The room ID of the new room.
+ user_id: The user's ID.
"""
# Retrieve user account data for predecessor room
- user_account_data, _ = yield self.store.get_account_data_for_user(user_id)
+ user_account_data, _ = await self.store.get_account_data_for_user(user_id)
# Copy direct message state if applicable
direct_rooms = user_account_data.get("m.direct", {})
@@ -240,31 +265,31 @@ class RoomMemberHandler(object):
direct_rooms[key].append(new_room_id)
# Save back to user's m.direct account data
- yield self.store.add_account_data_for_user(
+ await self.store.add_account_data_for_user(
user_id, "m.direct", direct_rooms
)
break
# Copy room tags if applicable
- room_tags = yield self.store.get_tags_for_room(user_id, old_room_id)
+ room_tags = await self.store.get_tags_for_room(user_id, old_room_id)
# Copy each room tag to the new room
for tag, tag_content in room_tags.items():
- yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
+ await self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
async def update_membership(
self,
- requester,
- target,
- room_id,
- action,
- txn_id=None,
- remote_room_hosts=None,
- third_party_signed=None,
- ratelimit=True,
- content=None,
- require_consent=True,
- ):
+ requester: Requester,
+ target: UserID,
+ room_id: str,
+ action: str,
+ txn_id: Optional[str] = None,
+ remote_room_hosts: Optional[List[str]] = None,
+ third_party_signed: Optional[dict] = None,
+ ratelimit: bool = True,
+ content: Optional[dict] = None,
+ require_consent: bool = True,
+ ) -> Tuple[Optional[str], int]:
key = (room_id,)
with (await self.member_linearizer.queue(key)):
@@ -285,17 +310,17 @@ class RoomMemberHandler(object):
async def _update_membership(
self,
- requester,
- target,
- room_id,
- action,
- txn_id=None,
- remote_room_hosts=None,
- third_party_signed=None,
- ratelimit=True,
- content=None,
- require_consent=True,
- ):
+ requester: Requester,
+ target: UserID,
+ room_id: str,
+ action: str,
+ txn_id: Optional[str] = None,
+ remote_room_hosts: Optional[List[str]] = None,
+ third_party_signed: Optional[dict] = None,
+ ratelimit: bool = True,
+ content: Optional[dict] = None,
+ require_consent: bool = True,
+ ) -> Tuple[Optional[str], int]:
content_specified = bool(content)
if content is None:
content = {}
@@ -399,7 +424,13 @@ class RoomMemberHandler(object):
same_membership = old_membership == effective_membership_state
same_sender = requester.user.to_string() == old_state.sender
if same_sender and same_membership and same_content:
- return old_state
+ _, stream_id = await self.store.get_event_ordering(
+ old_state.event_id
+ )
+ return (
+ old_state.event_id,
+ stream_id,
+ )
if old_membership in ["ban", "leave"] and action == "kick":
raise AuthError(403, "The target user is not in the room")
@@ -469,12 +500,11 @@ class RoomMemberHandler(object):
else:
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
- res = await self._remote_reject_invite(
+ return await self._remote_reject_invite(
requester, remote_room_hosts, room_id, target, content,
)
- return res
- res = await self._local_membership_update(
+ return await self._local_membership_update(
requester=requester,
target=target,
room_id=room_id,
@@ -485,10 +515,10 @@ class RoomMemberHandler(object):
content=content,
require_consent=require_consent,
)
- return res
- @defer.inlineCallbacks
- def transfer_room_state_on_room_upgrade(self, old_room_id, room_id):
+ async def transfer_room_state_on_room_upgrade(
+ self, old_room_id: str, room_id: str
+ ) -> None:
"""Upon our server becoming aware of an upgraded room, either by upgrading a room
ourselves or joining one, we can transfer over information from the previous room.
@@ -496,50 +526,44 @@ class RoomMemberHandler(object):
well as migrating the room directory state.
Args:
- old_room_id (str): The ID of the old room
-
- room_id (str): The ID of the new room
-
- Returns:
- Deferred
+ old_room_id: The ID of the old room
+ room_id: The ID of the new room
"""
logger.info("Transferring room state from %s to %s", old_room_id, room_id)
# Find all local users that were in the old room and copy over each user's state
- users = yield self.store.get_users_in_room(old_room_id)
- yield self.copy_user_state_on_room_upgrade(old_room_id, room_id, users)
+ users = await self.store.get_users_in_room(old_room_id)
+ await self.copy_user_state_on_room_upgrade(old_room_id, room_id, users)
# Add new room to the room directory if the old room was there
# Remove old room from the room directory
- old_room = yield self.store.get_room(old_room_id)
+ old_room = await self.store.get_room(old_room_id)
if old_room and old_room["is_public"]:
- yield self.store.set_room_is_public(old_room_id, False)
- yield self.store.set_room_is_public(room_id, True)
+ await self.store.set_room_is_public(old_room_id, False)
+ await self.store.set_room_is_public(room_id, True)
# Transfer alias mappings in the room directory
- yield self.store.update_aliases_for_room(old_room_id, room_id)
+ await self.store.update_aliases_for_room(old_room_id, room_id)
# Check if any groups we own contain the predecessor room
- local_group_ids = yield self.store.get_local_groups_for_room(old_room_id)
+ local_group_ids = await self.store.get_local_groups_for_room(old_room_id)
for group_id in local_group_ids:
# Add the new room to those groups
- yield self.store.add_room_to_group(group_id, room_id, old_room["is_public"])
+ await self.store.add_room_to_group(group_id, room_id, old_room["is_public"])
# Remove the old room from those groups
- yield self.store.remove_room_from_group(group_id, old_room_id)
+ await self.store.remove_room_from_group(group_id, old_room_id)
- @defer.inlineCallbacks
- def copy_user_state_on_room_upgrade(self, old_room_id, new_room_id, user_ids):
+ async def copy_user_state_on_room_upgrade(
+ self, old_room_id: str, new_room_id: str, user_ids: Iterable[str]
+ ) -> None:
"""Copy user-specific information when they join a new room when that new room is the
result of a room upgrade
Args:
- old_room_id (str): The ID of upgraded room
- new_room_id (str): The ID of the new room
- user_ids (Iterable[str]): User IDs to copy state for
-
- Returns:
- Deferred
+ old_room_id: The ID of the upgraded room
+ new_room_id: The ID of the new room
+ user_ids: User IDs to copy state for
"""
logger.debug(
@@ -552,11 +576,11 @@ class RoomMemberHandler(object):
for user_id in user_ids:
try:
# It is an upgraded room. Copy over old tags
- yield self.copy_room_tags_and_direct_to_room(
+ await self.copy_room_tags_and_direct_to_room(
old_room_id, new_room_id, user_id
)
# Copy over push rules
- yield self.store.copy_push_rules_from_room_to_room_for_user(
+ await self.store.copy_push_rules_from_room_to_room_for_user(
old_room_id, new_room_id, user_id
)
except Exception:
@@ -569,17 +593,23 @@ class RoomMemberHandler(object):
)
continue
- async def send_membership_event(self, requester, event, context, ratelimit=True):
+ async def send_membership_event(
+ self,
+ requester: Requester,
+ event: EventBase,
+ context: EventContext,
+ ratelimit: bool = True,
+ ):
"""
Change the membership status of a user in a room.
Args:
- requester (Requester): The local user who requested the membership
+ requester: The local user who requested the membership
event. If None, certain checks, like whether this homeserver can
act as the sender, will be skipped.
- event (SynapseEvent): The membership event.
+ event: The membership event.
context: The context of the event.
- ratelimit (bool): Whether to rate limit this request.
+ ratelimit: Whether to rate limit this request.
Raises:
SynapseError if there was a problem changing the membership.
"""
@@ -639,8 +669,9 @@ class RoomMemberHandler(object):
if prev_member_event.membership == Membership.JOIN:
await self._user_left_room(target_user, room_id)
- @defer.inlineCallbacks
- def _can_guest_join(self, current_state_ids):
+ async def _can_guest_join(
+ self, current_state_ids: Dict[Tuple[str, str], str]
+ ) -> bool:
"""
Returns whether a guest can join a room based on its current state.
"""
@@ -648,7 +679,7 @@ class RoomMemberHandler(object):
if not guest_access_id:
return False
- guest_access = yield self.store.get_event(guest_access_id)
+ guest_access = await self.store.get_event(guest_access_id)
return (
guest_access
@@ -657,13 +688,14 @@ class RoomMemberHandler(object):
and guest_access.content["guest_access"] == "can_join"
)
- @defer.inlineCallbacks
- def lookup_room_alias(self, room_alias):
+ async def lookup_room_alias(
+ self, room_alias: RoomAlias
+ ) -> Tuple[RoomID, List[str]]:
"""
Get the room ID associated with a room alias.
Args:
- room_alias (RoomAlias): The alias to look up.
+ room_alias: The alias to look up.
Returns:
A tuple of:
The room ID as a RoomID object.
@@ -672,7 +704,7 @@ class RoomMemberHandler(object):
SynapseError if room alias could not be found.
"""
directory_handler = self.directory_handler
- mapping = yield directory_handler.get_association(room_alias)
+ mapping = await directory_handler.get_association(room_alias)
if not mapping:
raise SynapseError(404, "No such room alias")
@@ -687,25 +719,25 @@ class RoomMemberHandler(object):
return RoomID.from_string(room_id), servers
- @defer.inlineCallbacks
- def _get_inviter(self, user_id, room_id):
- invite = yield self.store.get_invite_for_local_user_in_room(
+ async def _get_inviter(self, user_id: str, room_id: str) -> Optional[UserID]:
+ invite = await self.store.get_invite_for_local_user_in_room(
user_id=user_id, room_id=room_id
)
if invite:
return UserID.from_string(invite.sender)
+ return None
async def do_3pid_invite(
self,
- room_id,
- inviter,
- medium,
- address,
- id_server,
- requester,
- txn_id,
- id_access_token=None,
- ):
+ room_id: str,
+ inviter: UserID,
+ medium: str,
+ address: str,
+ id_server: str,
+ requester: Requester,
+ txn_id: Optional[str],
+ id_access_token: Optional[str] = None,
+ ) -> int:
if self.config.block_non_admin_invites:
is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin:
@@ -737,11 +769,11 @@ class RoomMemberHandler(object):
)
if invitee:
- await self.update_membership(
+ _, stream_id = await self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
)
else:
- await self._make_and_store_3pid_invite(
+ stream_id = await self._make_and_store_3pid_invite(
requester,
id_server,
medium,
@@ -752,17 +784,19 @@ class RoomMemberHandler(object):
id_access_token=id_access_token,
)
+ return stream_id
+
async def _make_and_store_3pid_invite(
self,
- requester,
- id_server,
- medium,
- address,
- room_id,
- user,
- txn_id,
- id_access_token=None,
- ):
+ requester: Requester,
+ id_server: str,
+ medium: str,
+ address: str,
+ room_id: str,
+ user: UserID,
+ txn_id: Optional[str],
+ id_access_token: Optional[str] = None,
+ ) -> int:
room_state = await self.state_handler.get_current_state(room_id)
inviter_display_name = ""
@@ -817,7 +851,10 @@ class RoomMemberHandler(object):
id_access_token=id_access_token,
)
- await self.event_creation_handler.create_and_send_nonmember_event(
+ (
+ event,
+ stream_id,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
@@ -835,9 +872,11 @@ class RoomMemberHandler(object):
ratelimit=False,
txn_id=txn_id,
)
+ return stream_id
- @defer.inlineCallbacks
- def _is_host_in_room(self, current_state_ids):
+ async def _is_host_in_room(
+ self, current_state_ids: Dict[Tuple[str, str], str]
+ ) -> bool:
# Have we just created the room, and is this about to be the very
# first member event?
create_event_id = current_state_ids.get(("m.room.create", ""))
@@ -850,7 +889,7 @@ class RoomMemberHandler(object):
continue
event_id = current_state_ids[(etype, state_key)]
- event = yield self.store.get_event(event_id, allow_none=True)
+ event = await self.store.get_event(event_id, allow_none=True)
if not event:
continue
@@ -859,11 +898,10 @@ class RoomMemberHandler(object):
return False
- @defer.inlineCallbacks
- def _is_server_notice_room(self, room_id):
+ async def _is_server_notice_room(self, room_id: str) -> bool:
if self._server_notices_mxid is None:
return False
- user_ids = yield self.store.get_users_in_room(room_id)
+ user_ids = await self.store.get_users_in_room(room_id)
return self._server_notices_mxid in user_ids
@@ -875,20 +913,21 @@ class RoomMemberMasterHandler(RoomMemberHandler):
self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room")
- @defer.inlineCallbacks
- def _is_remote_room_too_complex(self, room_id, remote_room_hosts):
+ async def _is_remote_room_too_complex(
+ self, room_id: str, remote_room_hosts: List[str]
+ ) -> Optional[bool]:
"""
Check if complexity of a remote room is too great.
Args:
- room_id (str)
- remote_room_hosts (list[str])
+ room_id
+ remote_room_hosts
Returns: bool of whether the complexity is too great, or None
if unable to be fetched
"""
max_complexity = self.hs.config.limit_remote_rooms.complexity
- complexity = yield self.federation_handler.get_room_complexity(
+ complexity = await self.federation_handler.get_room_complexity(
remote_room_hosts, room_id
)
@@ -896,22 +935,26 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return complexity["v1"] > max_complexity
return None
- @defer.inlineCallbacks
- def _is_local_room_too_complex(self, room_id):
+ async def _is_local_room_too_complex(self, room_id: str) -> bool:
"""
Check if the complexity of a local room is too great.
Args:
- room_id (str)
-
- Returns: bool
+ room_id: The room ID to check for complexity.
"""
max_complexity = self.hs.config.limit_remote_rooms.complexity
- complexity = yield self.store.get_room_complexity(room_id)
+ complexity = await self.store.get_room_complexity(room_id)
return complexity["v1"] > max_complexity
- async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+ async def _remote_join(
+ self,
+ requester: Requester,
+ remote_room_hosts: List[str],
+ room_id: str,
+ user: UserID,
+ content: dict,
+ ) -> Tuple[str, int]:
"""Implements RoomMemberHandler._remote_join
"""
# filter ourselves out of remote_room_hosts: do_invite_join ignores it
@@ -940,7 +983,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
- await self.federation_handler.do_invite_join(
+ event_id, stream_id = await self.federation_handler.do_invite_join(
remote_room_hosts, room_id, user.to_string(), content
)
await self._user_joined_room(user, room_id)
@@ -950,14 +993,14 @@ class RoomMemberMasterHandler(RoomMemberHandler):
if self.hs.config.limit_remote_rooms.enabled:
if too_complex is False:
# We checked, and we're under the limit.
- return
+ return event_id, stream_id
# Check again, but with the local state events
too_complex = await self._is_local_room_too_complex(room_id)
if too_complex is False:
# We're under the limit.
- return
+ return event_id, stream_id
# The room is too large. Leave.
requester = types.create_requester(user, None, False, None)
@@ -970,20 +1013,24 @@ class RoomMemberMasterHandler(RoomMemberHandler):
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
)
- @defer.inlineCallbacks
- def _remote_reject_invite(
- self, requester, remote_room_hosts, room_id, target, content
- ):
+ return event_id, stream_id
+
+ async def _remote_reject_invite(
+ self,
+ requester: Requester,
+ remote_room_hosts: List[str],
+ room_id: str,
+ target: UserID,
+ content: dict,
+ ) -> Tuple[Optional[str], int]:
"""Implements RoomMemberHandler._remote_reject_invite
"""
fed_handler = self.federation_handler
try:
- ret = yield defer.ensureDeferred(
- fed_handler.do_remotely_reject_invite(
- remote_room_hosts, room_id, target.to_string(), content=content,
- )
+ event, stream_id = await fed_handler.do_remotely_reject_invite(
+ remote_room_hosts, room_id, target.to_string(), content=content,
)
- return ret
+ return event.event_id, stream_id
except Exception as e:
# if we were unable to reject the invite, just mark
# it as rejected on our end and plough ahead.
@@ -993,24 +1040,23 @@ class RoomMemberMasterHandler(RoomMemberHandler):
#
logger.warning("Failed to reject invite: %s", e)
- yield self.store.locally_reject_invite(target.to_string(), room_id)
- return {}
+ stream_id = await self.locally_reject_invite(target.to_string(), room_id)
+ return None, stream_id
- def _user_joined_room(self, target, room_id):
+ async def _user_joined_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_joined_room
"""
- return defer.succeed(user_joined_room(self.distributor, target, room_id))
+ user_joined_room(self.distributor, target, room_id)
- def _user_left_room(self, target, room_id):
+ async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room
"""
- return defer.succeed(user_left_room(self.distributor, target, room_id))
+ user_left_room(self.distributor, target, room_id)
- @defer.inlineCallbacks
- def forget(self, user, room_id):
+ async def forget(self, user: UserID, room_id: str) -> None:
user_id = user.to_string()
- member = yield self.state_handler.get_current_state(
+ member = await self.state_handler.get_current_state(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
membership = member.membership if member else None
@@ -1022,4 +1068,4 @@ class RoomMemberMasterHandler(RoomMemberHandler):
raise SynapseError(400, "User %s in room %s" % (user_id, room_id))
if membership:
- yield self.store.forget(user_id, room_id)
+ await self.store.forget(user_id, room_id)
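The membership methods above now return an (event_id, stream_id) pair instead of a bare event, so callers can tell which stream position must be reached before the change is visible. A minimal sketch of that calling convention, using a stand-in handler and made-up identifiers rather than the real RoomMemberHandler:

    import asyncio
    from typing import Optional, Tuple

    class StubRoomMemberHandler:
        """Stand-in for RoomMemberHandler; only the return shape matters here."""

        async def update_membership(
            self, user_id: str, room_id: str, action: str
        ) -> Tuple[Optional[str], int]:
            # A real handler would create and persist a membership event.
            return "$membership_event:example.org", 42

    async def main() -> None:
        handler = StubRoomMemberHandler()
        event_id, stream_id = await handler.update_membership(
            "@alice:example.org", "!room:example.org", "join"
        )
        # stream_id tells the caller which stream position to wait for.
        print(event_id, stream_id)

    asyncio.run(main())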
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 69be8689..02e0c410 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -14,8 +14,7 @@
# limitations under the License.
import logging
-
-from twisted.internet import defer
+from typing import List, Optional, Tuple
from synapse.api.errors import SynapseError
from synapse.handlers.room_member import RoomMemberHandler
@@ -24,6 +23,7 @@ from synapse.replication.http.membership import (
ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite,
ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft,
)
+from synapse.types import Requester, UserID
logger = logging.getLogger(__name__)
@@ -36,14 +36,20 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
self._remote_reject_client = ReplRejectInvite.make_client(hs)
self._notify_change_client = ReplJoinedLeft.make_client(hs)
- @defer.inlineCallbacks
- def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+ async def _remote_join(
+ self,
+ requester: Requester,
+ remote_room_hosts: List[str],
+ room_id: str,
+ user: UserID,
+ content: dict,
+ ) -> Tuple[str, int]:
"""Implements RoomMemberHandler._remote_join
"""
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
- ret = yield self._remote_join_client(
+ ret = await self._remote_join_client(
requester=requester,
remote_room_hosts=remote_room_hosts,
room_id=room_id,
@@ -51,33 +57,39 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
content=content,
)
- yield self._user_joined_room(user, room_id)
+ await self._user_joined_room(user, room_id)
- return ret
+ return ret["event_id"], ret["stream_id"]
- def _remote_reject_invite(
- self, requester, remote_room_hosts, room_id, target, content
- ):
+ async def _remote_reject_invite(
+ self,
+ requester: Requester,
+ remote_room_hosts: List[str],
+ room_id: str,
+ target: UserID,
+ content: dict,
+ ) -> Tuple[Optional[str], int]:
"""Implements RoomMemberHandler._remote_reject_invite
"""
- return self._remote_reject_client(
+ ret = await self._remote_reject_client(
requester=requester,
remote_room_hosts=remote_room_hosts,
room_id=room_id,
user_id=target.to_string(),
content=content,
)
+ return ret["event_id"], ret["stream_id"]
- def _user_joined_room(self, target, room_id):
+ async def _user_joined_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_joined_room
"""
- return self._notify_change_client(
+ await self._notify_change_client(
user_id=target.to_string(), room_id=room_id, change="joined"
)
- def _user_left_room(self, target, room_id):
+ async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room
"""
- return self._notify_change_client(
+ await self._notify_change_client(
user_id=target.to_string(), room_id=room_id, change="left"
)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 96f2dd36..abecaa83 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import re
-from typing import Optional, Tuple
+from typing import Callable, Dict, Optional, Set, Tuple
import attr
import saml2
@@ -23,10 +23,9 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
-from synapse.http.server import finish_request
from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
-from synapse.module_api.errors import RedirectException
from synapse.types import (
UserID,
map_username_to_mxid_localpart,
@@ -79,19 +78,19 @@ class SamlHandler:
# a lock on the mappings
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
- self._error_html_content = hs.config.saml2_error_html_content
-
- def handle_redirect_request(self, client_redirect_url, ui_auth_session_id=None):
+ def handle_redirect_request(
+ self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
+ ) -> bytes:
"""Handle an incoming request to /login/sso/redirect
Args:
- client_redirect_url (bytes): the URL that we should redirect the
+ client_redirect_url: the URL that we should redirect the
client to when everything is done
- ui_auth_session_id (Optional[str]): The session ID of the ongoing UI Auth (or
+ ui_auth_session_id: The session ID of the ongoing UI Auth (or
None if this is a login).
Returns:
- bytes: URL to redirect to
+ URL to redirect to
"""
reqid, info = self._saml_client.prepare_for_authenticate(
relay_state=client_redirect_url
@@ -109,15 +108,15 @@ class SamlHandler:
# this shouldn't happen!
raise Exception("prepare_for_authenticate didn't return a Location header")
- async def handle_saml_response(self, request):
+ async def handle_saml_response(self, request: SynapseRequest) -> None:
"""Handle an incoming request to /_matrix/saml2/authn_response
Args:
- request (SynapseRequest): the incoming request from the browser. We'll
+ request: the incoming request from the browser. We'll
respond to it with a redirect.
Returns:
- Deferred[none]: Completes once we have handled the request.
+ Completes once we have handled the request.
"""
resp_bytes = parse_string(request, "SAMLResponse", required=True)
relay_state = parse_string(request, "RelayState", required=True)
@@ -126,26 +125,9 @@ class SamlHandler:
# the dict.
self.expire_sessions()
- try:
- user_id, current_session = await self._map_saml_response_to_user(
- resp_bytes, relay_state
- )
- except RedirectException:
- # Raise the exception as per the wishes of the SAML module response
- raise
- except Exception as e:
- # If decoding the response or mapping it to a user failed, then log the
- # error and tell the user that something went wrong.
- logger.error(e)
-
- request.setResponseCode(400)
- request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(
- b"Content-Length", b"%d" % (len(self._error_html_content),)
- )
- request.write(self._error_html_content.encode("utf8"))
- finish_request(request)
- return
+ user_id, current_session = await self._map_saml_response_to_user(
+ resp_bytes, relay_state
+ )
# Complete the interactive auth session or the login.
if current_session and current_session.ui_auth_session_id:
@@ -168,6 +150,11 @@ class SamlHandler:
Returns:
Tuple of the user ID and SAML session associated with this response.
+
+ Raises:
+ SynapseError if there was a problem with the response.
+ RedirectException: some mapping providers may raise this if they need
+ to redirect to an interstitial page.
"""
try:
saml2_auth = self._saml_client.parse_authn_request_response(
@@ -176,11 +163,9 @@ class SamlHandler:
outstanding=self._outstanding_requests_dict,
)
except Exception as e:
- logger.warning("Exception parsing SAML2 response: %s", e)
raise SynapseError(400, "Unable to parse SAML2 response: %s" % (e,))
if saml2_auth.not_signed:
- logger.warning("SAML2 response was not signed")
raise SynapseError(400, "SAML2 response was not signed")
logger.debug("SAML2 response: %s", saml2_auth.origxml)
@@ -261,13 +246,13 @@ class SamlHandler:
localpart = attribute_dict.get("mxid_localpart")
if not localpart:
- logger.error(
- "SAML mapping provider plugin did not return a "
- "mxid_localpart object"
+ raise Exception(
+ "Error parsing SAML2 response: SAML mapping provider plugin "
+ "did not return a mxid_localpart value"
)
- raise SynapseError(500, "Error parsing SAML2 response")
displayname = attribute_dict.get("displayname")
+ emails = attribute_dict.get("emails", [])
# Check if this mxid already exists
if not await self._datastore.get_users_by_id_case_insensitive(
@@ -285,7 +270,9 @@ class SamlHandler:
logger.info("Mapped SAML user to local part %s", localpart)
registered_user_id = await self._registration_handler.register_user(
- localpart=localpart, default_display_name=displayname
+ localpart=localpart,
+ default_display_name=displayname,
+ bind_emails=emails,
)
await self._datastore.record_user_external_id(
@@ -310,6 +297,7 @@ DOT_REPLACE_PATTERN = re.compile(
def dot_replace_for_mxid(username: str) -> str:
+ """Replace any characters which are not allowed in Matrix IDs with a dot."""
username = username.lower()
username = DOT_REPLACE_PATTERN.sub(".", username)
@@ -321,7 +309,7 @@ def dot_replace_for_mxid(username: str) -> str:
MXID_MAPPER_MAP = {
"hexencode": map_username_to_mxid_localpart,
"dotreplace": dot_replace_for_mxid,
-}
+} # type: Dict[str, Callable[[str], str]]
@attr.s
@@ -349,7 +337,7 @@ class DefaultSamlMappingProvider(object):
def get_remote_user_id(
self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str
- ):
+ ) -> str:
"""Extracts the remote user id from the SAML response"""
try:
return saml_response.ava["uid"][0]
@@ -377,6 +365,7 @@ class DefaultSamlMappingProvider(object):
dict: A dict containing new user attributes. Possible keys:
* mxid_localpart (str): Required. The localpart of the user's mxid
* displayname (str): The displayname of the user
+ * emails (list[str]): Any emails for the user
"""
try:
mxid_source = saml_response.ava[self._mxid_source_attribute][0]
@@ -399,9 +388,13 @@ class DefaultSamlMappingProvider(object):
# If displayname is None, the mxid_localpart will be used instead
displayname = saml_response.ava.get("displayName", [None])[0]
+ # Retrieve any emails present in the saml response
+ emails = saml_response.ava.get("email", [])
+
return {
"mxid_localpart": localpart,
"displayname": displayname,
+ "emails": emails,
}
@staticmethod
@@ -428,16 +421,16 @@ class DefaultSamlMappingProvider(object):
return SamlConfig(mxid_source_attribute, mxid_mapper)
@staticmethod
- def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
+ def get_saml_attributes(config: SamlConfig) -> Tuple[Set[str], Set[str]]:
"""Returns the required attributes of a SAML
Args:
config: A SamlConfig object containing configuration params for this provider
Returns:
- tuple[set,set]: The first set equates to the saml auth response
+ The first set equates to the saml auth response
attributes that are required for the module to function, whereas the
second set consists of those attributes which can be used if
available, but are not necessary
"""
- return {"uid", config.mxid_source_attribute}, {"displayName"}
+ return {"uid", config.mxid_source_attribute}, {"displayName", "email"}
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index ec1542d4..4d40d3ac 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -18,8 +18,6 @@ import logging
from unpaddedbase64 import decode_base64, encode_base64
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
@@ -39,8 +37,7 @@ class SearchHandler(BaseHandler):
self.state_store = self.storage.state
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def get_old_rooms_from_upgraded_room(self, room_id):
+ async def get_old_rooms_from_upgraded_room(self, room_id):
"""Retrieves room IDs of old rooms in the history of an upgraded room.
We do so by checking the m.room.create event of the room for a
@@ -60,7 +57,7 @@ class SearchHandler(BaseHandler):
historical_room_ids = []
# The initial room must have been known for us to get this far
- predecessor = yield self.store.get_room_predecessor(room_id)
+ predecessor = await self.store.get_room_predecessor(room_id)
while True:
if not predecessor:
@@ -75,7 +72,7 @@ class SearchHandler(BaseHandler):
# Don't add it to the list until we have checked that we are in the room
try:
- next_predecessor_room = yield self.store.get_room_predecessor(
+ next_predecessor_room = await self.store.get_room_predecessor(
predecessor_room_id
)
except NotFoundError:
@@ -89,8 +86,7 @@ class SearchHandler(BaseHandler):
return historical_room_ids
- @defer.inlineCallbacks
- def search(self, user, content, batch=None):
+ async def search(self, user, content, batch=None):
"""Performs a full text search for a user.
Args:
@@ -179,7 +175,7 @@ class SearchHandler(BaseHandler):
search_filter = Filter(filter_dict)
# TODO: Search through left rooms too
- rooms = yield self.store.get_rooms_for_local_user_where_membership_is(
+ rooms = await self.store.get_rooms_for_local_user_where_membership_is(
user.to_string(),
membership_list=[Membership.JOIN],
# membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
@@ -192,7 +188,7 @@ class SearchHandler(BaseHandler):
historical_room_ids = []
for room_id in search_filter.rooms:
# Add any previous rooms to the search if they exist
- ids = yield self.get_old_rooms_from_upgraded_room(room_id)
+ ids = await self.get_old_rooms_from_upgraded_room(room_id)
historical_room_ids += ids
# Prevent any historical events from being filtered
@@ -223,7 +219,7 @@ class SearchHandler(BaseHandler):
count = None
if order_by == "rank":
- search_result = yield self.store.search_msgs(room_ids, search_term, keys)
+ search_result = await self.store.search_msgs(room_ids, search_term, keys)
count = search_result["count"]
@@ -238,7 +234,7 @@ class SearchHandler(BaseHandler):
filtered_events = search_filter.filter([r["event"] for r in results])
- events = yield filter_events_for_client(
+ events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events
)
@@ -267,7 +263,7 @@ class SearchHandler(BaseHandler):
# But only go around 5 times since otherwise synapse will be sad.
while len(room_events) < search_filter.limit() and i < 5:
i += 1
- search_result = yield self.store.search_rooms(
+ search_result = await self.store.search_rooms(
room_ids,
search_term,
keys,
@@ -288,7 +284,7 @@ class SearchHandler(BaseHandler):
filtered_events = search_filter.filter([r["event"] for r in results])
- events = yield filter_events_for_client(
+ events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events
)
@@ -343,11 +339,11 @@ class SearchHandler(BaseHandler):
# If client has asked for "context" for each event (i.e. some surrounding
# events and state), fetch that
if event_context is not None:
- now_token = yield self.hs.get_event_sources().get_current_token()
+ now_token = await self.hs.get_event_sources().get_current_token()
contexts = {}
for event in allowed_events:
- res = yield self.store.get_events_around(
+ res = await self.store.get_events_around(
event.room_id, event.event_id, before_limit, after_limit
)
@@ -357,11 +353,11 @@ class SearchHandler(BaseHandler):
len(res["events_after"]),
)
- res["events_before"] = yield filter_events_for_client(
+ res["events_before"] = await filter_events_for_client(
self.storage, user.to_string(), res["events_before"]
)
- res["events_after"] = yield filter_events_for_client(
+ res["events_after"] = await filter_events_for_client(
self.storage, user.to_string(), res["events_after"]
)
@@ -390,7 +386,7 @@ class SearchHandler(BaseHandler):
[(EventTypes.Member, sender) for sender in senders]
)
- state = yield self.state_store.get_state_for_event(
+ state = await self.state_store.get_state_for_event(
last_event_id, state_filter
)
@@ -412,10 +408,10 @@ class SearchHandler(BaseHandler):
time_now = self.clock.time_msec()
for context in contexts.values():
- context["events_before"] = yield self._event_serializer.serialize_events(
+ context["events_before"] = await self._event_serializer.serialize_events(
context["events_before"], time_now
)
- context["events_after"] = yield self._event_serializer.serialize_events(
+ context["events_after"] = await self._event_serializer.serialize_events(
context["events_after"], time_now
)
@@ -423,7 +419,7 @@ class SearchHandler(BaseHandler):
if include_state:
rooms = {e.room_id for e in allowed_events}
for room_id in rooms:
- state = yield self.state_handler.get_current_state(room_id)
+ state = await self.state_handler.get_current_state(room_id)
state_results[room_id] = list(state.values())
state_results.values()
@@ -437,7 +433,7 @@ class SearchHandler(BaseHandler):
{
"rank": rank_map[e.event_id],
"result": (
- yield self._event_serializer.serialize_event(e, time_now)
+ await self._event_serializer.serialize_event(e, time_now)
),
"context": contexts.get(e.event_id, {}),
}
@@ -452,7 +448,7 @@ class SearchHandler(BaseHandler):
if state_results:
s = {}
for room_id, state in state_results.items():
- s[room_id] = yield self._event_serializer.serialize_events(
+ s[room_id] = await self._event_serializer.serialize_events(
state, time_now
)
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 63d8f9aa..4d245b61 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -35,16 +35,13 @@ class SetPasswordHandler(BaseHandler):
async def set_password(
self,
user_id: str,
- new_password: str,
+ password_hash: str,
logout_devices: bool,
requester: Optional[Requester] = None,
):
if not self.hs.config.password_localdb_enabled:
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
- self._password_policy_handler.validate_password(new_password)
- password_hash = await self._auth_handler.hash(new_password)
-
try:
await self.store.user_set_password_hash(user_id, password_hash)
except StoreError as e:
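After this change set_password expects a pre-computed password hash, so validation and hashing move to the caller. A rough sketch of that caller-side flow; the real code path hashes with bcrypt via the auth handler, and sha256 is used below only to keep the example dependency-free:

    import asyncio
    import hashlib

    async def stub_set_password(user_id: str, password_hash: str, logout_devices: bool) -> None:
        # The handler now just stores whatever hash it is given.
        print("storing hash for", user_id, "->", password_hash[:16], "logout:", logout_devices)

    async def change_password(user_id: str, new_password: str) -> None:
        # Password-policy validation would also happen here in the real caller.
        password_hash = hashlib.sha256(new_password.encode("utf-8")).hexdigest()
        await stub_set_password(user_id, password_hash, logout_devices=True)

    asyncio.run(change_password("@alice:example.org", "correct horse battery staple"))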
diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
index f065970c..8590c1ef 100644
--- a/synapse/handlers/state_deltas.py
+++ b/synapse/handlers/state_deltas.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
logger = logging.getLogger(__name__)
@@ -24,8 +22,7 @@ class StateDeltasHandler(object):
def __init__(self, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
+ async def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
"""Given two events check if the `key_name` field in content changed
from not matching `public_value` to doing so.
@@ -41,10 +38,10 @@ class StateDeltasHandler(object):
prev_event = None
event = None
if prev_event_id:
- prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+ prev_event = await self.store.get_event(prev_event_id, allow_none=True)
if event_id:
- event = yield self.store.get_event(event_id, allow_none=True)
+ event = await self.store.get_event(event_id, allow_none=True)
if not event and not prev_event:
logger.debug("Neither event exists: %r %r", prev_event_id, event_id)
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index d93a2766..149f8612 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -16,17 +16,14 @@
import logging
from collections import Counter
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
-from synapse.handlers.state_deltas import StateDeltasHandler
from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process
logger = logging.getLogger(__name__)
-class StatsHandler(StateDeltasHandler):
+class StatsHandler:
"""Handles keeping the *_stats tables updated with a simple time-series of
information about the users, rooms and media on the server, such that admins
have some idea of who is consuming their resources.
@@ -35,7 +32,6 @@ class StatsHandler(StateDeltasHandler):
"""
def __init__(self, hs):
- super(StatsHandler, self).__init__(hs)
self.hs = hs
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
@@ -68,20 +64,18 @@ class StatsHandler(StateDeltasHandler):
self._is_processing = True
- @defer.inlineCallbacks
- def process():
+ async def process():
try:
- yield self._unsafe_process()
+ await self._unsafe_process()
finally:
self._is_processing = False
run_as_background_process("stats.notify_new_event", process)
- @defer.inlineCallbacks
- def _unsafe_process(self):
+ async def _unsafe_process(self):
# If self.pos is None then it means we haven't fetched it from the DB
if self.pos is None:
- self.pos = yield self.store.get_stats_positions()
+ self.pos = await self.store.get_stats_positions()
# Loop round handling deltas until we're up to date
@@ -96,13 +90,13 @@ class StatsHandler(StateDeltasHandler):
logger.debug(
"Processing room stats %s->%s", self.pos, room_max_stream_ordering
)
- max_pos, deltas = yield self.store.get_current_state_deltas(
+ max_pos, deltas = await self.store.get_current_state_deltas(
self.pos, room_max_stream_ordering
)
if deltas:
logger.debug("Handling %d state deltas", len(deltas))
- room_deltas, user_deltas = yield self._handle_deltas(deltas)
+ room_deltas, user_deltas = await self._handle_deltas(deltas)
else:
room_deltas = {}
user_deltas = {}
@@ -111,7 +105,7 @@ class StatsHandler(StateDeltasHandler):
(
room_count,
user_count,
- ) = yield self.store.get_changes_room_total_events_and_bytes(
+ ) = await self.store.get_changes_room_total_events_and_bytes(
self.pos, max_pos
)
@@ -125,7 +119,7 @@ class StatsHandler(StateDeltasHandler):
logger.debug("user_deltas: %s", user_deltas)
# Always call this so that we update the stats position.
- yield self.store.bulk_update_stats_delta(
+ await self.store.bulk_update_stats_delta(
self.clock.time_msec(),
updates={"room": room_deltas, "user": user_deltas},
stream_id=max_pos,
@@ -137,13 +131,12 @@ class StatsHandler(StateDeltasHandler):
self.pos = max_pos
- @defer.inlineCallbacks
- def _handle_deltas(self, deltas):
+ async def _handle_deltas(self, deltas):
"""Called with the state deltas to process
Returns:
- Deferred[tuple[dict[str, Counter], dict[str, counter]]]
- Resovles to two dicts, the room deltas and the user deltas,
+ tuple[dict[str, Counter], dict[str, Counter]]
+ Two dicts: the room deltas and the user deltas,
mapping from room/user ID to changes in the various fields.
"""
@@ -162,7 +155,7 @@ class StatsHandler(StateDeltasHandler):
logger.debug("Handling: %r, %r %r, %s", room_id, typ, state_key, event_id)
- token = yield self.store.get_earliest_token_for_stats("room", room_id)
+ token = await self.store.get_earliest_token_for_stats("room", room_id)
# If the earliest token to begin from is larger than our current
# stream ID, skip processing this delta.
@@ -184,7 +177,7 @@ class StatsHandler(StateDeltasHandler):
sender = None
if event_id is not None:
- event = yield self.store.get_event(event_id, allow_none=True)
+ event = await self.store.get_event(event_id, allow_none=True)
if event:
event_content = event.content or {}
sender = event.sender
@@ -200,16 +193,16 @@ class StatsHandler(StateDeltasHandler):
room_stats_delta["current_state_events"] += 1
if typ == EventTypes.Member:
- # we could use _get_key_change here but it's a bit inefficient
- # given we're not testing for a specific result; might as well
- # just grab the prev_membership and membership strings and
- # compare them.
+ # we could use StateDeltasHandler._get_key_change here but it's
+ # a bit inefficient given we're not testing for a specific
+ # result; might as well just grab the prev_membership and
+ # membership strings and compare them.
# We take None rather than leave as a previous membership
# in the absence of a previous event because we do not want to
# reduce the leave count when a new-to-the-room user joins.
prev_membership = None
if prev_event_id is not None:
- prev_event = yield self.store.get_event(
+ prev_event = await self.store.get_event(
prev_event_id, allow_none=True
)
if prev_event:
@@ -301,6 +294,6 @@ class StatsHandler(StateDeltasHandler):
for room_id, state in room_to_state_updates.items():
logger.debug("Updating room_stats_state for %s: %s", room_id, state)
- yield self.store.update_room_state(room_id, state)
+ await self.store.update_room_state(room_id, state)
return room_to_stats_deltas, user_to_stats_deltas
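The notify/process shape used by the stats handler above: a new-event notification schedules at most one background run, guarded by an in-progress flag. A self-contained sketch of that pattern, with asyncio.create_task standing in for Synapse's run_as_background_process:

    import asyncio

    class StubStatsHandler:
        def __init__(self) -> None:
            self._is_processing = False

        def notify_new_event(self) -> None:
            if self._is_processing:
                return  # a run is already in flight; it will pick up the new data
            self._is_processing = True

            async def process() -> None:
                try:
                    await self._unsafe_process()
                finally:
                    self._is_processing = False

            # run_as_background_process in Synapse; a plain task keeps this self-contained.
            asyncio.get_running_loop().create_task(process())

        async def _unsafe_process(self) -> None:
            print("processing state deltas up to the latest stream position")

    async def main() -> None:
        handler = StubStatsHandler()
        handler.notify_new_event()
        handler.notify_new_event()  # ignored: a run is already scheduled
        await asyncio.sleep(0)      # give the scheduled task a chance to run

    asyncio.run(main())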
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 00718d7f..6bdb24ba 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1370,7 +1370,7 @@ class SyncHandler(object):
sync_result_builder.now_token = now_token
# We check up front if anything has changed, if it hasn't then there is
- # no point in going futher.
+ # no point in going further.
since_token = sync_result_builder.since_token
if not sync_result_builder.full_state:
if since_token and not ephemeral_by_room and not account_data_by_room:
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 8363d887..8b24a733 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -138,8 +138,7 @@ class _BaseThreepidAuthChecker:
self.hs = hs
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def _check_threepid(self, medium, authdict):
+ async def _check_threepid(self, medium, authdict):
if "threepid_creds" not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
@@ -155,18 +154,18 @@ class _BaseThreepidAuthChecker:
raise SynapseError(
400, "Phone number verification is not enabled on this homeserver"
)
- threepid = yield identity_handler.threepid_from_creds(
+ threepid = await identity_handler.threepid_from_creds(
self.hs.config.account_threepid_delegate_msisdn, threepid_creds
)
elif medium == "email":
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
assert self.hs.config.account_threepid_delegate_email
- threepid = yield identity_handler.threepid_from_creds(
+ threepid = await identity_handler.threepid_from_creds(
self.hs.config.account_threepid_delegate_email, threepid_creds
)
elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
threepid = None
- row = yield self.store.get_threepid_validation_session(
+ row = await self.store.get_threepid_validation_session(
medium,
threepid_creds["client_secret"],
sid=threepid_creds["sid"],
@@ -181,7 +180,7 @@ class _BaseThreepidAuthChecker:
}
# Valid threepid returned, delete from the db
- yield self.store.delete_threepid_session(threepid_creds["sid"])
+ await self.store.delete_threepid_session(threepid_creds["sid"])
else:
raise SynapseError(
400, "Email address verification is not enabled on this homeserver"
@@ -220,7 +219,7 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec
)
def check_auth(self, authdict, clientip):
- return self._check_threepid("email", authdict)
+ return defer.ensureDeferred(self._check_threepid("email", authdict))
class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
@@ -234,7 +233,7 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
return bool(self.hs.config.account_threepid_delegate_msisdn)
def check_auth(self, authdict, clientip):
- return self._check_threepid("msisdn", authdict)
+ return defer.ensureDeferred(self._check_threepid("msisdn", authdict))
INTERACTIVE_AUTH_CHECKERS = [
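The threepid checkers are now coroutines internally, while the UI-auth interface still expects check_auth to return a Deferred, hence the defer.ensureDeferred wrappers. A simplified sketch of that bridge (requires Twisted; the checker and its result are illustrative):

    from twisted.internet import defer

    class StubEmailChecker:
        async def _check_threepid(self, medium: str, authdict: dict) -> dict:
            # The real checker validates threepid_creds against the store or an
            # identity server; here the input is simply echoed back.
            return {"medium": medium, "address": authdict["address"]}

        def check_auth(self, authdict: dict, clientip: str) -> defer.Deferred:
            # Bridge the coroutine back onto the Deferred-based interface.
            return defer.ensureDeferred(self._check_threepid("email", authdict))

    d = StubEmailChecker().check_auth({"address": "alice@example.org"}, "10.0.0.1")
    d.addCallback(print)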
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 722760c5..12423b90 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -17,14 +17,11 @@ import logging
from six import iteritems, iterkeys
-from twisted.internet import defer
-
import synapse.metrics
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.handlers.state_deltas import StateDeltasHandler
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.roommember import ProfileInfo
-from synapse.types import get_localpart_from_id
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -103,43 +100,39 @@ class UserDirectoryHandler(StateDeltasHandler):
if self._is_processing:
return
- @defer.inlineCallbacks
- def process():
+ async def process():
try:
- yield self._unsafe_process()
+ await self._unsafe_process()
finally:
self._is_processing = False
self._is_processing = True
run_as_background_process("user_directory.notify_new_event", process)
- @defer.inlineCallbacks
- def handle_local_profile_change(self, user_id, profile):
+ async def handle_local_profile_change(self, user_id, profile):
"""Called to update index of our local user profiles when they change
irrespective of any rooms the user may be in.
"""
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
- is_support = yield self.store.is_support_user(user_id)
+ is_support = await self.store.is_support_user(user_id)
# Support users are for diagnostics and should not appear in the user directory.
if not is_support:
- yield self.store.update_profile_in_user_dir(
+ await self.store.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)
- @defer.inlineCallbacks
- def handle_user_deactivated(self, user_id):
+ async def handle_user_deactivated(self, user_id):
"""Called when a user ID is deactivated
"""
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
- yield self.store.remove_from_user_dir(user_id)
+ await self.store.remove_from_user_dir(user_id)
- @defer.inlineCallbacks
- def _unsafe_process(self):
+ async def _unsafe_process(self):
# If self.pos is None then it means we haven't fetched it from the DB
if self.pos is None:
- self.pos = yield self.store.get_user_directory_stream_pos()
+ self.pos = await self.store.get_user_directory_stream_pos()
# If still None then the initial background update hasn't happened yet
if self.pos is None:
@@ -155,12 +148,12 @@ class UserDirectoryHandler(StateDeltasHandler):
logger.debug(
"Processing user stats %s->%s", self.pos, room_max_stream_ordering
)
- max_pos, deltas = yield self.store.get_current_state_deltas(
+ max_pos, deltas = await self.store.get_current_state_deltas(
self.pos, room_max_stream_ordering
)
logger.debug("Handling %d state deltas", len(deltas))
- yield self._handle_deltas(deltas)
+ await self._handle_deltas(deltas)
self.pos = max_pos
@@ -169,10 +162,9 @@ class UserDirectoryHandler(StateDeltasHandler):
max_pos
)
- yield self.store.update_user_directory_stream_pos(max_pos)
+ await self.store.update_user_directory_stream_pos(max_pos)
- @defer.inlineCallbacks
- def _handle_deltas(self, deltas):
+ async def _handle_deltas(self, deltas):
"""Called with the state deltas to process
"""
for delta in deltas:
@@ -187,11 +179,11 @@ class UserDirectoryHandler(StateDeltasHandler):
# For join rule and visibility changes we need to check if the room
# may have become public or not and add/remove the users in said room
if typ in (EventTypes.RoomHistoryVisibility, EventTypes.JoinRules):
- yield self._handle_room_publicity_change(
+ await self._handle_room_publicity_change(
room_id, prev_event_id, event_id, typ
)
elif typ == EventTypes.Member:
- change = yield self._get_key_change(
+ change = await self._get_key_change(
prev_event_id,
event_id,
key_name="membership",
@@ -201,7 +193,7 @@ class UserDirectoryHandler(StateDeltasHandler):
if change is False:
# Need to check if the server left the room entirely, if so
# we might need to remove all the users in that room
- is_in_room = yield self.store.is_host_joined(
+ is_in_room = await self.store.is_host_joined(
room_id, self.server_name
)
if not is_in_room:
@@ -209,40 +201,41 @@ class UserDirectoryHandler(StateDeltasHandler):
# Fetch all the users that we marked as being in user
# directory due to being in the room and then check if
# need to remove those users or not
- user_ids = yield self.store.get_users_in_dir_due_to_room(
+ user_ids = await self.store.get_users_in_dir_due_to_room(
room_id
)
for user_id in user_ids:
- yield self._handle_remove_user(room_id, user_id)
+ await self._handle_remove_user(room_id, user_id)
return
else:
logger.debug("Server is still in room: %r", room_id)
- is_support = yield self.store.is_support_user(state_key)
+ is_support = await self.store.is_support_user(state_key)
if not is_support:
if change is None:
# Handle any profile changes
- yield self._handle_profile_change(
+ await self._handle_profile_change(
state_key, room_id, prev_event_id, event_id
)
continue
if change: # The user joined
- event = yield self.store.get_event(event_id, allow_none=True)
+ event = await self.store.get_event(event_id, allow_none=True)
profile = ProfileInfo(
avatar_url=event.content.get("avatar_url"),
display_name=event.content.get("displayname"),
)
- yield self._handle_new_user(room_id, state_key, profile)
+ await self._handle_new_user(room_id, state_key, profile)
else: # The user left
- yield self._handle_remove_user(room_id, state_key)
+ await self._handle_remove_user(room_id, state_key)
else:
logger.debug("Ignoring irrelevant type: %r", typ)
- @defer.inlineCallbacks
- def _handle_room_publicity_change(self, room_id, prev_event_id, event_id, typ):
+ async def _handle_room_publicity_change(
+ self, room_id, prev_event_id, event_id, typ
+ ):
"""Handle a room having potentially changed from/to world_readable/publically
joinable.
@@ -255,14 +248,14 @@ class UserDirectoryHandler(StateDeltasHandler):
logger.debug("Handling change for %s: %s", typ, room_id)
if typ == EventTypes.RoomHistoryVisibility:
- change = yield self._get_key_change(
+ change = await self._get_key_change(
prev_event_id,
event_id,
key_name="history_visibility",
public_value="world_readable",
)
elif typ == EventTypes.JoinRules:
- change = yield self._get_key_change(
+ change = await self._get_key_change(
prev_event_id,
event_id,
key_name="join_rule",
@@ -278,7 +271,7 @@ class UserDirectoryHandler(StateDeltasHandler):
# There's been a change to or from being world readable.
- is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
+ is_public = await self.store.is_room_world_readable_or_publicly_joinable(
room_id
)
@@ -293,11 +286,11 @@ class UserDirectoryHandler(StateDeltasHandler):
# ignore the change
return
- users_with_profile = yield self.state.get_current_users_in_room(room_id)
+ users_with_profile = await self.state.get_current_users_in_room(room_id)
# Remove every user from the sharing tables for that room.
for user_id in iterkeys(users_with_profile):
- yield self.store.remove_user_who_share_room(user_id, room_id)
+ await self.store.remove_user_who_share_room(user_id, room_id)
# Then, re-add them to the tables.
# NOTE: this is not the most efficient method, as handle_new_user sets
@@ -306,26 +299,9 @@ class UserDirectoryHandler(StateDeltasHandler):
# being added multiple times. The batching upserts shouldn't make this
# too bad, though.
for user_id, profile in iteritems(users_with_profile):
- yield self._handle_new_user(room_id, user_id, profile)
-
- @defer.inlineCallbacks
- def _handle_local_user(self, user_id):
- """Adds a new local roomless user into the user_directory_search table.
- Used to populate up the user index when we have an
- user_directory_search_all_users specified.
- """
- logger.debug("Adding new local user to dir, %r", user_id)
-
- profile = yield self.store.get_profileinfo(get_localpart_from_id(user_id))
-
- row = yield self.store.get_user_in_directory(user_id)
- if not row:
- yield self.store.update_profile_in_user_dir(
- user_id, profile.display_name, profile.avatar_url
- )
+ await self._handle_new_user(room_id, user_id, profile)
- @defer.inlineCallbacks
- def _handle_new_user(self, room_id, user_id, profile):
+ async def _handle_new_user(self, room_id, user_id, profile):
"""Called when we might need to add user to directory
Args:
@@ -334,18 +310,18 @@ class UserDirectoryHandler(StateDeltasHandler):
"""
logger.debug("Adding new user to dir, %r", user_id)
- yield self.store.update_profile_in_user_dir(
+ await self.store.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)
- is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
+ is_public = await self.store.is_room_world_readable_or_publicly_joinable(
room_id
)
# Now we update users who share rooms with users.
- users_with_profile = yield self.state.get_current_users_in_room(room_id)
+ users_with_profile = await self.state.get_current_users_in_room(room_id)
if is_public:
- yield self.store.add_users_in_public_rooms(room_id, (user_id,))
+ await self.store.add_users_in_public_rooms(room_id, (user_id,))
else:
to_insert = set()
@@ -376,10 +352,9 @@ class UserDirectoryHandler(StateDeltasHandler):
to_insert.add((other_user_id, user_id))
if to_insert:
- yield self.store.add_users_who_share_private_room(room_id, to_insert)
+ await self.store.add_users_who_share_private_room(room_id, to_insert)
- @defer.inlineCallbacks
- def _handle_remove_user(self, room_id, user_id):
+ async def _handle_remove_user(self, room_id, user_id):
"""Called when we might need to remove user from directory
Args:
@@ -389,24 +364,23 @@ class UserDirectoryHandler(StateDeltasHandler):
logger.debug("Removing user %r", user_id)
# Remove user from sharing tables
- yield self.store.remove_user_who_share_room(user_id, room_id)
+ await self.store.remove_user_who_share_room(user_id, room_id)
# Are they still in any rooms? If not, remove them entirely.
- rooms_user_is_in = yield self.store.get_user_dir_rooms_user_is_in(user_id)
+ rooms_user_is_in = await self.store.get_user_dir_rooms_user_is_in(user_id)
if len(rooms_user_is_in) == 0:
- yield self.store.remove_from_user_dir(user_id)
+ await self.store.remove_from_user_dir(user_id)
- @defer.inlineCallbacks
- def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id):
+ async def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id):
"""Check member event changes for any profile changes and update the
database if there are.
"""
if not prev_event_id or not event_id:
return
- prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
- event = yield self.store.get_event(event_id, allow_none=True)
+ prev_event = await self.store.get_event(prev_event_id, allow_none=True)
+ event = await self.store.get_event(event_id, allow_none=True)
if not prev_event or not event:
return
@@ -421,4 +395,4 @@ class UserDirectoryHandler(StateDeltasHandler):
new_avatar = event.content.get("avatar_url")
if prev_name != new_name or prev_avatar != new_avatar:
- yield self.store.update_profile_in_user_dir(user_id, new_name, new_avatar)
+ await self.store.update_profile_in_user_dir(user_id, new_name, new_avatar)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 37975458..3cef747a 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -49,7 +49,6 @@ from synapse.http.proxyagent import ProxyAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.util.async_helpers import timeout_deferred
-from synapse.util.caches import CACHE_SIZE_FACTOR
logger = logging.getLogger(__name__)
@@ -241,7 +240,10 @@ class SimpleHttpClient(object):
# tends to do so in batches, so we need to allow the pool to keep
# lots of idle connections around.
pool = HTTPConnectionPool(self.reactor)
- pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
+ # XXX: The justification for using the cache factor here is that larger instances
+ # will need both more cache and more connections.
+        # Still, this should probably be a separate dial.
+ pool.maxPersistentPerHost = max((100 * hs.config.caches.global_factor, 5))
pool.cachedConnectionTimeout = 2 * 60
self.agent = ProxyAgent(
@@ -359,6 +361,7 @@ class SimpleHttpClient(object):
actual_headers = {
b"Content-Type": [b"application/x-www-form-urlencoded"],
b"User-Agent": [self.user_agent],
+ b"Accept": [b"application/json"],
}
if headers:
actual_headers.update(headers)
@@ -399,6 +402,7 @@ class SimpleHttpClient(object):
actual_headers = {
b"Content-Type": [b"application/json"],
b"User-Agent": [self.user_agent],
+ b"Accept": [b"application/json"],
}
if headers:
actual_headers.update(headers)
@@ -434,6 +438,10 @@ class SimpleHttpClient(object):
ValueError: if the response was not JSON
"""
+ actual_headers = {b"Accept": [b"application/json"]}
+ if headers:
+ actual_headers.update(headers)
+
-        body = yield self.get_raw(uri, args, headers=headers)
+        body = yield self.get_raw(uri, args, headers=actual_headers)
return json.loads(body)
@@ -467,6 +475,7 @@ class SimpleHttpClient(object):
actual_headers = {
b"Content-Type": [b"application/json"],
b"User-Agent": [self.user_agent],
+ b"Accept": [b"application/json"],
}
if headers:
actual_headers.update(headers)
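The JSON helpers above now send a default `Accept: application/json` header, but callers can still override it because the caller-supplied headers are merged in last. A minimal standalone sketch of that merge order (the helper name is hypothetical, not part of Synapse):

```python
from typing import Dict, List, Optional

def merge_headers(
    defaults: Dict[bytes, List[bytes]],
    overrides: Optional[Dict[bytes, List[bytes]]],
) -> Dict[bytes, List[bytes]]:
    """Layer caller-supplied headers on top of the defaults."""
    actual = dict(defaults)
    if overrides:
        actual.update(overrides)  # caller values win on key collisions
    return actual

merged = merge_headers(
    {b"Accept": [b"application/json"], b"User-Agent": [b"Synapse"]},
    {b"Accept": [b"text/html"]},  # hypothetical caller override
)
assert merged[b"Accept"] == [b"text/html"]
assert merged[b"User-Agent"] == [b"Synapse"]
```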
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 6b0a532c..2d47b9ea 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -19,7 +19,7 @@ import random
import sys
from io import BytesIO
-from six import PY3, raise_from, string_types
+from six import raise_from, string_types
from six.moves import urllib
import attr
@@ -70,11 +70,7 @@ incoming_responses_counter = Counter(
MAX_LONG_RETRIES = 10
MAX_SHORT_RETRIES = 3
-
-if PY3:
- MAXINT = sys.maxsize
-else:
- MAXINT = sys.maxint
+MAXINT = sys.maxsize
_next_id = 1
@@ -148,6 +144,11 @@ def _handle_json_response(reactor, timeout_sec, request, response):
d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
body = yield make_deferred_yieldable(d)
+ except TimeoutError as e:
+ logger.warning(
+ "{%s} [%s] Timed out reading response", request.txn_id, request.destination,
+ )
+ raise RequestSendFailed(e, can_retry=True) from e
except Exception as e:
logger.warning(
"{%s} [%s] Error reading response: %s",
@@ -408,7 +409,7 @@ class MatrixFederationHttpClient(object):
_sec_timeout,
)
- outgoing_requests_counter.labels(method_bytes).inc()
+ outgoing_requests_counter.labels(request.method).inc()
try:
with Measure(self.clock, "outbound_request"):
@@ -428,13 +429,17 @@ class MatrixFederationHttpClient(object):
)
response = yield request_deferred
+ except TimeoutError as e:
+ raise RequestSendFailed(e, can_retry=True) from e
except DNSLookupError as e:
raise_from(RequestSendFailed(e, can_retry=retry_on_dns_fail), e)
except Exception as e:
logger.info("Failed to send request: %s", e)
raise_from(RequestSendFailed(e, can_retry=True), e)
- incoming_responses_counter.labels(method_bytes, response.code).inc()
+ incoming_responses_counter.labels(
+ request.method, response.code
+ ).inc()
set_tag(tags.HTTP_STATUS_CODE, response.code)
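Read timeouts are now translated into `RequestSendFailed(e, can_retry=True)` via `raise ... from e`, so the retry machinery can treat them like any other transient send failure while the original `TimeoutError` stays attached as `__cause__`. A minimal sketch of that chaining, using a stand-in exception class rather than Synapse's real one:

```python
class RequestSendFailed(Exception):
    """Stand-in for the real synapse.api.errors.RequestSendFailed."""

    def __init__(self, inner: Exception, can_retry: bool):
        super().__init__(str(inner))
        self.can_retry = can_retry


def send_once():
    raise TimeoutError("timed out reading response")


try:
    try:
        send_once()
    except TimeoutError as e:
        raise RequestSendFailed(e, can_retry=True) from e
except RequestSendFailed as f:
    assert f.can_retry
    assert isinstance(f.__cause__, TimeoutError)  # original error is preserved
```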
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 042a6051..2487a721 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -21,13 +21,15 @@ import logging
import types
import urllib
from io import BytesIO
+from typing import Awaitable, Callable, TypeVar, Union
+import jinja2
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
from twisted.internet import defer
from twisted.python import failure
from twisted.web import resource
-from twisted.web.server import NOT_DONE_YET
+from twisted.web.server import NOT_DONE_YET, Request
from twisted.web.static import NoRangeStaticProducer
from twisted.web.util import redirectTo
@@ -40,6 +42,7 @@ from synapse.api.errors import (
SynapseError,
UnrecognizedRequestError,
)
+from synapse.http.site import SynapseRequest
from synapse.logging.context import preserve_fn
from synapse.logging.opentracing import trace_servlet
from synapse.util.caches import intern_dict
@@ -130,7 +133,12 @@ def wrap_json_request_handler(h):
return wrap_async_request_handler(wrapped_request_handler)
-def wrap_html_request_handler(h):
+TV = TypeVar("TV")
+
+
+def wrap_html_request_handler(
+ h: Callable[[TV, SynapseRequest], Awaitable]
+) -> Callable[[TV, SynapseRequest], Awaitable[None]]:
"""Wraps a request handler method with exception handling.
Also does the wrapping with request.processing as per wrap_async_request_handler.
@@ -141,20 +149,26 @@ def wrap_html_request_handler(h):
async def wrapped_request_handler(self, request):
try:
- return await h(self, request)
+ await h(self, request)
except Exception:
f = failure.Failure()
- return _return_html_error(f, request)
+ return_html_error(f, request, HTML_ERROR_TEMPLATE)
return wrap_async_request_handler(wrapped_request_handler)
-def _return_html_error(f, request):
- """Sends an HTML error page corresponding to the given failure
+def return_html_error(
+ f: failure.Failure, request: Request, error_template: Union[str, jinja2.Template],
+) -> None:
+ """Sends an HTML error page corresponding to the given failure.
+
+ Handles RedirectException and other CodeMessageExceptions (such as SynapseError)
Args:
- f (twisted.python.failure.Failure):
- request (twisted.web.server.Request):
+ f: the error to report
+ request: the failing request
+ error_template: the HTML template. Can be either a string (with `{code}`,
+ `{msg}` placeholders), or a jinja2 template
"""
if f.check(CodeMessageException):
cme = f.value
@@ -174,7 +188,7 @@ def _return_html_error(f, request):
exc_info=(f.type, f.value, f.getTracebackObject()),
)
else:
- code = http.client.INTERNAL_SERVER_ERROR
+ code = http.HTTPStatus.INTERNAL_SERVER_ERROR
msg = "Internal server error"
logger.error(
@@ -183,11 +197,16 @@ def _return_html_error(f, request):
exc_info=(f.type, f.value, f.getTracebackObject()),
)
- body = HTML_ERROR_TEMPLATE.format(code=code, msg=html.escape(msg)).encode("utf-8")
+ if isinstance(error_template, str):
+ body = error_template.format(code=code, msg=html.escape(msg))
+ else:
+ body = error_template.render(code=code, msg=msg)
+
+ body_bytes = body.encode("utf-8")
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%i" % (len(body),))
- request.write(body)
+ request.setHeader(b"Content-Length", b"%i" % (len(body_bytes),))
+ request.write(body_bytes)
finish_request(request)
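`return_html_error` accepts either a plain format string or a jinja2 template, and the two branches escape differently: the string path always HTML-escapes the message, while a jinja2 template applies whatever escaping it was configured with. A small standalone sketch of the two paths (jinja2 is already imported by this module):

```python
import html

import jinja2


def render_error(error_template, code, msg):
    # Mirrors the branch above: str templates get an escaped message,
    # jinja2 templates render according to their own configuration.
    if isinstance(error_template, str):
        return error_template.format(code=code, msg=html.escape(msg))
    return error_template.render(code=code, msg=msg)


plain = "<h1>{code}</h1><p>{msg}</p>"
templated = jinja2.Template("<h1>{{ code }}</h1><p>{{ msg }}</p>")

print(render_error(plain, 500, "Internal server error"))
print(render_error(templated, 500, "Internal server error"))
```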
@@ -350,9 +369,6 @@ class JsonResource(HttpServer, resource.Resource):
register_paths, so will return (possibly via Deferred) either
None, or a tuple of (http code, response body).
"""
- if request.method == b"OPTIONS":
- return _options_handler, "options_request_handler", {}
-
request_path = request.path.decode("ascii")
# Loop through all the registered callbacks to check if the method
@@ -448,6 +464,26 @@ class RootRedirect(resource.Resource):
return resource.Resource.getChild(self, name, request)
+class OptionsResource(resource.Resource):
+    """Responds to OPTIONS requests for itself and all children."""
+
+ def render_OPTIONS(self, request):
+ code, response_json_object = _options_handler(request)
+
+ return respond_with_json(
+ request, code, response_json_object, send_cors=True, canonical_json=False,
+ )
+
+ def getChildWithDefault(self, path, request):
+ if request.method == b"OPTIONS":
+ return self # select ourselves as the child to render
+ return resource.Resource.getChildWithDefault(self, path, request)
+
+
+class RootOptionsRedirectResource(OptionsResource, RootRedirect):
+ pass
+
+
def respond_with_json(
request,
code,
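The base-class order in `RootOptionsRedirectResource(OptionsResource, RootRedirect)` matters: Python's method resolution order finds `OptionsResource.getChildWithDefault` first, so OPTIONS requests are intercepted while everything else falls through to the redirect behaviour. A schematic illustration of that lookup order (plain Python, not real Twisted resources):

```python
class OptionsIntercept:
    def get_child(self, method):
        if method == "OPTIONS":
            return "handled by OptionsIntercept"
        return super().get_child(method)  # fall through to the next base


class Redirector:
    def get_child(self, method):
        return "redirected by Redirector"


class Root(OptionsIntercept, Redirector):
    pass


root = Root()
print(root.get_child("OPTIONS"))  # handled by OptionsIntercept
print(root.get_child("GET"))      # redirected by Redirector
```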
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 32feb0d9..167293c4 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,7 +14,9 @@
import contextlib
import logging
import time
+from typing import Optional
+from twisted.python.failure import Failure
from twisted.web.server import Request, Site
from synapse.http import redact_uri
@@ -44,7 +46,7 @@ class SynapseRequest(Request):
request even after the client has disconnected.
Attributes:
- logcontext(LoggingContext) : the log context for this request
+ logcontext: the log context for this request
"""
def __init__(self, channel, *args, **kw):
@@ -52,10 +54,10 @@ class SynapseRequest(Request):
self.site = channel.site
self._channel = channel # this is used by the tests
self.authenticated_entity = None
- self.start_time = 0
+ self.start_time = 0.0
# we can't yet create the logcontext, as we don't know the method.
- self.logcontext = None
+ self.logcontext = None # type: Optional[LoggingContext]
global _next_request_seq
self.request_seq = _next_request_seq
@@ -181,6 +183,7 @@ class SynapseRequest(Request):
self.finish_time = time.time()
Request.finish(self)
if not self._is_processing:
+ assert self.logcontext is not None
with PreserveLoggingContext(self.logcontext):
self._finished_processing()
@@ -190,6 +193,12 @@ class SynapseRequest(Request):
Overrides twisted.web.server.Request.connectionLost to record the finish time and
do logging.
"""
+ # There is a bug in Twisted where reason is not wrapped in a Failure object
+ # Detect this and wrap it manually as a workaround
+ # More information: https://github.com/matrix-org/synapse/issues/7441
+ if not isinstance(reason, Failure):
+ reason = Failure(reason)
+
self.finish_time = time.time()
Request.connectionLost(self, reason)
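The `connectionLost` workaround simply normalises the disconnect reason: if Twisted hands over a bare exception instead of a `Failure`, it is wrapped so downstream code that expects a `Failure` (e.g. `reason.check(...)`) keeps working. A tiny sketch of that normalisation (assumes Twisted is installed):

```python
from twisted.python.failure import Failure


def normalise_reason(reason):
    # Wrap a bare exception in a Failure, as the workaround above does.
    if not isinstance(reason, Failure):
        reason = Failure(reason)
    return reason


try:
    raise ConnectionError("client went away")
except ConnectionError as e:
    wrapped = normalise_reason(e)

print(wrapped.check(ConnectionError) is ConnectionError)  # True
```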
@@ -242,6 +251,7 @@ class SynapseRequest(Request):
def _finished_processing(self):
"""Log the completion of this request and update the metrics
"""
+ assert self.logcontext is not None
usage = self.logcontext.get_resource_usage()
if self._processing_finished_time is None:
diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py
index 0c2527bd..99049bb5 100644
--- a/synapse/logging/utils.py
+++ b/synapse/logging/utils.py
@@ -20,8 +20,6 @@ import time
from functools import wraps
from inspect import getcallargs
-from six import PY3
-
_TIME_FUNC_ID = 0
@@ -30,12 +28,8 @@ def _log_debug_as_f(f, msg, msg_args):
logger = logging.getLogger(name)
if logger.isEnabledFor(logging.DEBUG):
- if PY3:
- lineno = f.__code__.co_firstlineno
- pathname = f.__code__.co_filename
- else:
- lineno = f.func_code.co_firstlineno
- pathname = f.func_code.co_filename
+ lineno = f.__code__.co_firstlineno
+ pathname = f.__code__.co_filename
record = logging.LogRecord(
name=name,
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index d2fd29ac..9cf31f96 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -26,7 +26,12 @@ import six
import attr
from prometheus_client import Counter, Gauge, Histogram
-from prometheus_client.core import REGISTRY, GaugeMetricFamily, HistogramMetricFamily
+from prometheus_client.core import (
+ REGISTRY,
+ CounterMetricFamily,
+ GaugeMetricFamily,
+ HistogramMetricFamily,
+)
from twisted.internet import reactor
@@ -338,6 +343,78 @@ class GCCounts(object):
if not running_on_pypy:
REGISTRY.register(GCCounts())
+
+#
+# PyPy GC / memory metrics
+#
+
+
+class PyPyGCStats(object):
+ def collect(self):
+
+ # @stats is a pretty-printer object with __str__() returning a nice table,
+ # plus some fields that contain data from that table.
+        # unfortunately, fields are pretty-printed themselves (i.e. '4.5MB').
+ stats = gc.get_stats(memory_pressure=False) # type: ignore
+ # @s contains same fields as @stats, but as actual integers.
+ s = stats._s # type: ignore
+
+ # also note that field naming is completely braindead
+ # and only vaguely correlates with the pretty-printed table.
+ # >>>> gc.get_stats(False)
+ # Total memory consumed:
+ # GC used: 8.7MB (peak: 39.0MB) # s.total_gc_memory, s.peak_memory
+ # in arenas: 3.0MB # s.total_arena_memory
+ # rawmalloced: 1.7MB # s.total_rawmalloced_memory
+ # nursery: 4.0MB # s.nursery_size
+ # raw assembler used: 31.0kB # s.jit_backend_used
+ # -----------------------------
+ # Total: 8.8MB # stats.memory_used_sum
+ #
+ # Total memory allocated:
+ # GC allocated: 38.7MB (peak: 41.1MB) # s.total_allocated_memory, s.peak_allocated_memory
+ # in arenas: 30.9MB # s.peak_arena_memory
+ # rawmalloced: 4.1MB # s.peak_rawmalloced_memory
+ # nursery: 4.0MB # s.nursery_size
+ # raw assembler allocated: 1.0MB # s.jit_backend_allocated
+ # -----------------------------
+ # Total: 39.7MB # stats.memory_allocated_sum
+ #
+ # Total time spent in GC: 0.073 # s.total_gc_time
+
+ pypy_gc_time = CounterMetricFamily(
+ "pypy_gc_time_seconds_total", "Total time spent in PyPy GC", labels=[],
+ )
+ pypy_gc_time.add_metric([], s.total_gc_time / 1000)
+ yield pypy_gc_time
+
+ pypy_mem = GaugeMetricFamily(
+ "pypy_memory_bytes",
+ "Memory tracked by PyPy allocator",
+ labels=["state", "class", "kind"],
+ )
+ # memory used by JIT assembler
+ pypy_mem.add_metric(["used", "", "jit"], s.jit_backend_used)
+ pypy_mem.add_metric(["allocated", "", "jit"], s.jit_backend_allocated)
+ # memory used by GCed objects
+ pypy_mem.add_metric(["used", "", "arenas"], s.total_arena_memory)
+ pypy_mem.add_metric(["allocated", "", "arenas"], s.peak_arena_memory)
+ pypy_mem.add_metric(["used", "", "rawmalloced"], s.total_rawmalloced_memory)
+ pypy_mem.add_metric(["allocated", "", "rawmalloced"], s.peak_rawmalloced_memory)
+ pypy_mem.add_metric(["used", "", "nursery"], s.nursery_size)
+ pypy_mem.add_metric(["allocated", "", "nursery"], s.nursery_size)
+ # totals
+ pypy_mem.add_metric(["used", "totals", "gc"], s.total_gc_memory)
+ pypy_mem.add_metric(["allocated", "totals", "gc"], s.total_allocated_memory)
+ pypy_mem.add_metric(["used", "totals", "gc_peak"], s.peak_memory)
+ pypy_mem.add_metric(["allocated", "totals", "gc_peak"], s.peak_allocated_memory)
+ yield pypy_mem
+
+
+if running_on_pypy:
+ REGISTRY.register(PyPyGCStats())
+
+
#
# Twisted reactor metrics
#
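`PyPyGCStats` follows the standard prometheus_client custom-collector pattern: any object with a `collect()` generator that yields metric families can be registered with the global `REGISTRY` and will be invoked on every scrape. A minimal sketch of that pattern with made-up metric names:

```python
from prometheus_client import REGISTRY, generate_latest
from prometheus_client.core import GaugeMetricFamily


class ExampleStats:
    def collect(self):
        # Called on every scrape; yield fully-built metric families.
        gauge = GaugeMetricFamily(
            "example_answer", "An illustrative gauge", labels=["kind"]
        )
        gauge.add_metric(["static"], 42)
        yield gauge


REGISTRY.register(ExampleStats())
print(generate_latest(REGISTRY).decode("ascii"))
```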
diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py
index a2481031..ab7f948e 100644
--- a/synapse/metrics/_exposition.py
+++ b/synapse/metrics/_exposition.py
@@ -33,6 +33,8 @@ from prometheus_client import REGISTRY
from twisted.web.resource import Resource
+from synapse.util import caches
+
try:
from prometheus_client.samples import Sample
except ImportError:
@@ -103,13 +105,15 @@ def nameify_sample(sample):
def generate_latest(registry, emit_help=False):
- output = []
- for metric in registry.collect():
+    # Trigger the cache metrics to be rescraped; these collectors update the
+    # common metrics but do not produce any metrics themselves
+ for collector in caches.collectors_by_name.values():
+ collector.collect()
- if metric.name.startswith("__unused"):
- continue
+ output = []
+ for metric in registry.collect():
if not metric.samples:
# No samples, don't bother.
continue
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 8449ef82..13785038 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -17,16 +17,18 @@ import logging
import threading
from asyncio import iscoroutine
from functools import wraps
-from typing import Dict, Set
+from typing import TYPE_CHECKING, Dict, Optional, Set
-import six
-
-from prometheus_client.core import REGISTRY, Counter, GaugeMetricFamily
+from prometheus_client.core import REGISTRY, Counter, Gauge
from twisted.internet import defer
from synapse.logging.context import LoggingContext, PreserveLoggingContext
+if TYPE_CHECKING:
+ import resource
+
+
logger = logging.getLogger(__name__)
@@ -36,6 +38,12 @@ _background_process_start_count = Counter(
["name"],
)
+_background_process_in_flight_count = Gauge(
+ "synapse_background_process_in_flight_count",
+ "Number of background processes in flight",
+ labelnames=["name"],
+)
+
# we set registry=None in all of these to stop them getting registered with
# the default registry. Instead we collect them all via the CustomCollector,
# which ensures that we can update them before they are collected.
@@ -83,13 +91,17 @@ _background_process_db_sched_duration = Counter(
# it's much simpler to do so than to try to combine them.)
_background_process_counts = {} # type: Dict[str, int]
-# map from description to the currently running background processes.
+# Set of all running background processes that became active since the
+# last time metrics were scraped (i.e. background processes that performed some
+# work since the last scrape).
#
-# it's kept as a dict of sets rather than a big set so that we can keep track
-# of process descriptions that no longer have any active processes.
-_background_processes = {} # type: Dict[str, Set[_BackgroundProcess]]
+# We do it like this to handle the case where we have a large number of
+# background processes stacking up behind a lock or linearizer, where we then
+# only need to iterate over and update metrics for the processes that have
+# actually been active and can ignore the idle ones.
+_background_processes_active_since_last_scrape = set() # type: Set[_BackgroundProcess]
-# A lock that covers the above dicts
+# A lock that covers the above set and dict
_bg_metrics_lock = threading.Lock()
@@ -101,25 +113,16 @@ class _Collector(object):
"""
def collect(self):
- background_process_in_flight_count = GaugeMetricFamily(
- "synapse_background_process_in_flight_count",
- "Number of background processes in flight",
- labels=["name"],
- )
+ global _background_processes_active_since_last_scrape
- # We copy the dict so that it doesn't change from underneath us.
- # We also copy the process lists as that can also change
+ # We swap out the _background_processes set with an empty one so that
+ # we can safely iterate over the set without holding the lock.
with _bg_metrics_lock:
- _background_processes_copy = {
- k: list(v) for k, v in six.iteritems(_background_processes)
- }
+ _background_processes_copy = _background_processes_active_since_last_scrape
+ _background_processes_active_since_last_scrape = set()
- for desc, processes in six.iteritems(_background_processes_copy):
- background_process_in_flight_count.add_metric((desc,), len(processes))
- for process in processes:
- process.update_metrics()
-
- yield background_process_in_flight_count
+ for process in _background_processes_copy:
+ process.update_metrics()
# now we need to run collect() over each of the static Counters, and
# yield each metric they return.
@@ -191,13 +194,10 @@ def run_as_background_process(desc, func, *args, **kwargs):
_background_process_counts[desc] = count + 1
_background_process_start_count.labels(desc).inc()
+ _background_process_in_flight_count.labels(desc).inc()
- with LoggingContext(desc) as context:
+ with BackgroundProcessLoggingContext(desc) as context:
context.request = "%s-%i" % (desc, count)
- proc = _BackgroundProcess(desc, context)
-
- with _bg_metrics_lock:
- _background_processes.setdefault(desc, set()).add(proc)
try:
result = func(*args, **kwargs)
@@ -214,10 +214,7 @@ def run_as_background_process(desc, func, *args, **kwargs):
except Exception:
logger.exception("Background process '%s' threw an exception", desc)
finally:
- proc.update_metrics()
-
- with _bg_metrics_lock:
- _background_processes[desc].remove(proc)
+ _background_process_in_flight_count.labels(desc).dec()
with PreserveLoggingContext():
return run()
@@ -238,3 +235,42 @@ def wrap_as_background_process(desc):
return wrap_as_background_process_inner_2
return wrap_as_background_process_inner
+
+
+class BackgroundProcessLoggingContext(LoggingContext):
+ """A logging context that tracks in flight metrics for background
+ processes.
+ """
+
+ __slots__ = ["_proc"]
+
+ def __init__(self, name: str):
+ super().__init__(name)
+
+ self._proc = _BackgroundProcess(name, self)
+
+ def start(self, rusage: "Optional[resource._RUsage]"):
+ """Log context has started running (again).
+ """
+
+ super().start(rusage)
+
+ # We've become active again so we make sure we're in the list of active
+ # procs. (Note that "start" here means we've become active, as opposed
+ # to starting for the first time.)
+ with _bg_metrics_lock:
+ _background_processes_active_since_last_scrape.add(self._proc)
+
+ def __exit__(self, type, value, traceback) -> None:
+ """Log context has finished.
+ """
+
+ super().__exit__(type, value, traceback)
+
+        # The background process has finished. We explicitly remove and manually
+ # update the metrics here so that if nothing is scraping metrics the set
+ # doesn't infinitely grow.
+ with _bg_metrics_lock:
+ _background_processes_active_since_last_scrape.discard(self._proc)
+
+ self._proc.update_metrics()
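The new `_Collector.collect()` relies on a swap-under-lock idiom: the shared set is exchanged for an empty one while the lock is held, and the drained contents are then processed without the lock, so writers are never blocked for long. A standalone sketch of the idiom with hypothetical names:

```python
import threading

_lock = threading.Lock()
_active = set()


def note_active(name: str) -> None:
    with _lock:
        _active.add(name)


def drain_active() -> set:
    global _active
    with _lock:
        drained, _active = _active, set()  # swap while holding the lock
    # The lock is released; iterate the drained set at leisure.
    return drained


note_active("proc-1")
note_active("proc-2")
print(sorted(drain_active()))  # ['proc-1', 'proc-2']
print(drain_active())          # set()
```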
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index d678c0eb..ecdf1ad6 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -128,8 +128,12 @@ class ModuleApi(object):
Returns:
Deferred[str]: user_id
"""
- return self._hs.get_registration_handler().register_user(
- localpart=localpart, default_display_name=displayname, bind_emails=emails
+ return defer.ensureDeferred(
+ self._hs.get_registration_handler().register_user(
+ localpart=localpart,
+ default_display_name=displayname,
+ bind_emails=emails,
+ )
)
def register_device(self, user_id, device_id=None, initial_display_name=None):
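`register_user` is now a coroutine, but the module API promises callers a Deferred, so the coroutine is wrapped with `defer.ensureDeferred` before being handed back. A minimal sketch of that bridging (assumes Twisted is installed; the names are illustrative, not Synapse's):

```python
from twisted.internet import defer, task


async def register_user(localpart: str) -> str:
    # Stand-in for the real registration handler coroutine.
    return "@%s:example.com" % (localpart,)


def module_register(localpart: str) -> "defer.Deferred":
    # Coroutines are not Deferreds; ensureDeferred bridges the two.
    return defer.ensureDeferred(register_user(localpart))


def main(reactor):
    d = module_register("alice")
    d.addCallback(print)  # @alice:example.com
    return d


task.react(main)
```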
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 433ca2f4..e75d964a 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -51,6 +51,7 @@ push_rules_delta_state_cache_metric = register_cache(
"cache",
"push_rules_delta_state_cache_metric",
cache=[], # Meaningless size, as this isn't a cache that stores values
+ resizable=False,
)
@@ -67,7 +68,8 @@ class BulkPushRuleEvaluator(object):
self.room_push_rule_cache_metrics = register_cache(
"cache",
"room_push_rule_cache",
- cache=[], # Meaningless size, as this isn't a cache that stores values
+            cache=[],  # Meaningless size, as this isn't a cache that stores values
+ resizable=False,
)
@defer.inlineCallbacks
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index ba4551d6..568c13ea 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -15,7 +15,6 @@
import logging
-from twisted.internet import defer
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -132,8 +131,7 @@ class EmailPusher(object):
self._is_processing = False
self._start_processing()
- @defer.inlineCallbacks
- def _process(self):
+ async def _process(self):
# we should never get here if we are already processing
assert not self._is_processing
@@ -142,7 +140,7 @@ class EmailPusher(object):
if self.throttle_params is None:
# this is our first loop: load up the throttle params
- self.throttle_params = yield self.store.get_throttle_params_by_room(
+ self.throttle_params = await self.store.get_throttle_params_by_room(
self.pusher_id
)
@@ -151,7 +149,7 @@ class EmailPusher(object):
while True:
starting_max_ordering = self.max_stream_ordering
try:
- yield self._unsafe_process()
+ await self._unsafe_process()
except Exception:
logger.exception("Exception processing notifs")
if self.max_stream_ordering == starting_max_ordering:
@@ -159,8 +157,7 @@ class EmailPusher(object):
finally:
self._is_processing = False
- @defer.inlineCallbacks
- def _unsafe_process(self):
+ async def _unsafe_process(self):
"""
Main logic of the push loop without the wrapper function that sets
up logging, measures and guards against multiple instances of it
@@ -168,12 +165,12 @@ class EmailPusher(object):
"""
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
fn = self.store.get_unread_push_actions_for_user_in_range_for_email
- unprocessed = yield fn(self.user_id, start, self.max_stream_ordering)
+ unprocessed = await fn(self.user_id, start, self.max_stream_ordering)
soonest_due_at = None
if not unprocessed:
- yield self.save_last_stream_ordering_and_success(self.max_stream_ordering)
+ await self.save_last_stream_ordering_and_success(self.max_stream_ordering)
return
for push_action in unprocessed:
@@ -201,15 +198,15 @@ class EmailPusher(object):
"throttle_ms": self.get_room_throttle_ms(push_action["room_id"]),
}
- yield self.send_notification(unprocessed, reason)
+ await self.send_notification(unprocessed, reason)
- yield self.save_last_stream_ordering_and_success(
+ await self.save_last_stream_ordering_and_success(
max(ea["stream_ordering"] for ea in unprocessed)
)
# we update the throttle on all the possible unprocessed push actions
for ea in unprocessed:
- yield self.sent_notif_update_throttle(ea["room_id"], ea)
+ await self.sent_notif_update_throttle(ea["room_id"], ea)
break
else:
if soonest_due_at is None or should_notify_at < soonest_due_at:
@@ -227,14 +224,13 @@ class EmailPusher(object):
self.seconds_until(soonest_due_at), self.on_timer
)
- @defer.inlineCallbacks
- def save_last_stream_ordering_and_success(self, last_stream_ordering):
+ async def save_last_stream_ordering_and_success(self, last_stream_ordering):
if last_stream_ordering is None:
# This happens if we haven't yet processed anything
return
self.last_stream_ordering = last_stream_ordering
- pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success(
+ pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
self.app_id,
self.email,
self.user_id,
@@ -275,13 +271,12 @@ class EmailPusher(object):
may_send_at = last_sent_ts + throttle_ms
return may_send_at
- @defer.inlineCallbacks
- def sent_notif_update_throttle(self, room_id, notified_push_action):
+ async def sent_notif_update_throttle(self, room_id, notified_push_action):
# We have sent a notification, so update the throttle accordingly.
# If the event that triggered the notif happened more than
# THROTTLE_RESET_AFTER_MS after the previous one that triggered a
# notif, we release the throttle. Otherwise, the throttle is increased.
- time_of_previous_notifs = yield self.store.get_time_of_last_push_action_before(
+ time_of_previous_notifs = await self.store.get_time_of_last_push_action_before(
notified_push_action["stream_ordering"]
)
@@ -310,14 +305,13 @@ class EmailPusher(object):
"last_sent_ts": self.clock.time_msec(),
"throttle_ms": new_throttle_ms,
}
- yield self.store.set_throttle_params(
+ await self.store.set_throttle_params(
self.pusher_id, room_id, self.throttle_params[room_id]
)
- @defer.inlineCallbacks
- def send_notification(self, push_actions, reason):
+ async def send_notification(self, push_actions, reason):
logger.info("Sending notif email for user %r", self.user_id)
- yield self.mailer.send_notification_mail(
+ await self.mailer.send_notification_mail(
self.app_id, self.user_id, self.email, push_actions, reason
)
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 5bb17d12..eaaa7afc 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -15,8 +15,6 @@
# limitations under the License.
import logging
-import six
-
from prometheus_client import Counter
from twisted.internet import defer
@@ -28,9 +26,6 @@ from synapse.push import PusherConfigException
from . import push_rule_evaluator, push_tools
-if six.PY3:
- long = int
-
logger = logging.getLogger(__name__)
http_push_processed_counter = Counter(
@@ -318,7 +313,7 @@ class HttpPusher(object):
{
"app_id": self.app_id,
"pushkey": self.pushkey,
- "pushkey_ts": long(self.pushkey_ts / 1000),
+ "pushkey_ts": int(self.pushkey_ts / 1000),
"data": self.data_minus_url,
}
],
@@ -347,7 +342,7 @@ class HttpPusher(object):
{
"app_id": self.app_id,
"pushkey": self.pushkey,
- "pushkey_ts": long(self.pushkey_ts / 1000),
+ "pushkey_ts": int(self.pushkey_ts / 1000),
"data": self.data_minus_url,
"tweaks": tweaks,
}
@@ -409,7 +404,7 @@ class HttpPusher(object):
{
"app_id": self.app_id,
"pushkey": self.pushkey,
- "pushkey_ts": long(self.pushkey_ts / 1000),
+ "pushkey_ts": int(self.pushkey_ts / 1000),
"data": self.data_minus_url,
}
],
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index ab33abbe..d57a66a6 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -26,8 +26,6 @@ from six.moves import urllib
import bleach
import jinja2
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
from synapse.logging.context import make_deferred_yieldable
@@ -127,8 +125,7 @@ class Mailer(object):
logger.info("Created Mailer for app_name %s" % app_name)
- @defer.inlineCallbacks
- def send_password_reset_mail(self, email_address, token, client_secret, sid):
+ async def send_password_reset_mail(self, email_address, token, client_secret, sid):
"""Send an email with a password reset link to a user
Args:
@@ -149,14 +146,13 @@ class Mailer(object):
template_vars = {"link": link}
- yield self.send_email(
+ await self.send_email(
email_address,
"[%s] Password Reset" % self.hs.config.server_name,
template_vars,
)
- @defer.inlineCallbacks
- def send_registration_mail(self, email_address, token, client_secret, sid):
+ async def send_registration_mail(self, email_address, token, client_secret, sid):
"""Send an email with a registration confirmation link to a user
Args:
@@ -177,14 +173,13 @@ class Mailer(object):
template_vars = {"link": link}
- yield self.send_email(
+ await self.send_email(
email_address,
"[%s] Register your Email Address" % self.hs.config.server_name,
template_vars,
)
- @defer.inlineCallbacks
- def send_add_threepid_mail(self, email_address, token, client_secret, sid):
+ async def send_add_threepid_mail(self, email_address, token, client_secret, sid):
"""Send an email with a validation link to a user for adding a 3pid to their account
Args:
@@ -206,20 +201,19 @@ class Mailer(object):
template_vars = {"link": link}
- yield self.send_email(
+ await self.send_email(
email_address,
"[%s] Validate Your Email" % self.hs.config.server_name,
template_vars,
)
- @defer.inlineCallbacks
- def send_notification_mail(
+ async def send_notification_mail(
self, app_id, user_id, email_address, push_actions, reason
):
"""Send email regarding a user's room notifications"""
rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions])
- notif_events = yield self.store.get_events(
+ notif_events = await self.store.get_events(
[pa["event_id"] for pa in push_actions]
)
@@ -232,7 +226,7 @@ class Mailer(object):
state_by_room = {}
try:
- user_display_name = yield self.store.get_profile_displayname(
+ user_display_name = await self.store.get_profile_displayname(
UserID.from_string(user_id).localpart
)
if user_display_name is None:
@@ -240,14 +234,13 @@ class Mailer(object):
except StoreError:
user_display_name = user_id
- @defer.inlineCallbacks
- def _fetch_room_state(room_id):
- room_state = yield self.store.get_current_state_ids(room_id)
+ async def _fetch_room_state(room_id):
+ room_state = await self.store.get_current_state_ids(room_id)
state_by_room[room_id] = room_state
# Run at most 3 of these at once: sync does 10 at a time but email
# notifs are much less realtime than sync so we can afford to wait a bit.
- yield concurrently_execute(_fetch_room_state, rooms_in_order, 3)
+ await concurrently_execute(_fetch_room_state, rooms_in_order, 3)
# actually sort our so-called rooms_in_order list, most recent room first
rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0))
@@ -255,19 +248,19 @@ class Mailer(object):
rooms = []
for r in rooms_in_order:
- roomvars = yield self.get_room_vars(
+ roomvars = await self.get_room_vars(
r, user_id, notifs_by_room[r], notif_events, state_by_room[r]
)
rooms.append(roomvars)
- reason["room_name"] = yield calculate_room_name(
+ reason["room_name"] = await calculate_room_name(
self.store,
state_by_room[reason["room_id"]],
user_id,
fallback_to_members=True,
)
- summary_text = yield self.make_summary_text(
+ summary_text = await self.make_summary_text(
notifs_by_room, state_by_room, notif_events, user_id, reason
)
@@ -282,12 +275,11 @@ class Mailer(object):
"reason": reason,
}
- yield self.send_email(
+ await self.send_email(
email_address, "[%s] %s" % (self.app_name, summary_text), template_vars
)
- @defer.inlineCallbacks
- def send_email(self, email_address, subject, template_vars):
+ async def send_email(self, email_address, subject, template_vars):
"""Send an email with the given information and template text"""
try:
from_string = self.hs.config.email_notif_from % {"app": self.app_name}
@@ -317,7 +309,7 @@ class Mailer(object):
logger.info("Sending email to %s" % email_address)
- yield make_deferred_yieldable(
+ await make_deferred_yieldable(
self.sendmail(
self.hs.config.email_smtp_host,
raw_from,
@@ -332,13 +324,14 @@ class Mailer(object):
)
)
- @defer.inlineCallbacks
- def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state_ids):
+ async def get_room_vars(
+ self, room_id, user_id, notifs, notif_events, room_state_ids
+ ):
my_member_event_id = room_state_ids[("m.room.member", user_id)]
- my_member_event = yield self.store.get_event(my_member_event_id)
+ my_member_event = await self.store.get_event(my_member_event_id)
is_invite = my_member_event.content["membership"] == "invite"
- room_name = yield calculate_room_name(self.store, room_state_ids, user_id)
+ room_name = await calculate_room_name(self.store, room_state_ids, user_id)
room_vars = {
"title": room_name,
@@ -350,7 +343,7 @@ class Mailer(object):
if not is_invite:
for n in notifs:
- notifvars = yield self.get_notif_vars(
+ notifvars = await self.get_notif_vars(
n, user_id, notif_events[n["event_id"]], room_state_ids
)
@@ -377,9 +370,8 @@ class Mailer(object):
return room_vars
- @defer.inlineCallbacks
- def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):
- results = yield self.store.get_events_around(
+ async def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):
+ results = await self.store.get_events_around(
notif["room_id"],
notif["event_id"],
before_limit=CONTEXT_BEFORE,
@@ -392,25 +384,24 @@ class Mailer(object):
"messages": [],
}
- the_events = yield filter_events_for_client(
+ the_events = await filter_events_for_client(
self.storage, user_id, results["events_before"]
)
the_events.append(notif_event)
for event in the_events:
- messagevars = yield self.get_message_vars(notif, event, room_state_ids)
+ messagevars = await self.get_message_vars(notif, event, room_state_ids)
if messagevars is not None:
ret["messages"].append(messagevars)
return ret
- @defer.inlineCallbacks
- def get_message_vars(self, notif, event, room_state_ids):
+ async def get_message_vars(self, notif, event, room_state_ids):
if event.type != EventTypes.Message:
return
sender_state_event_id = room_state_ids[("m.room.member", event.sender)]
- sender_state_event = yield self.store.get_event(sender_state_event_id)
+ sender_state_event = await self.store.get_event(sender_state_event_id)
sender_name = name_from_member_event(sender_state_event)
sender_avatar_url = sender_state_event.content.get("avatar_url")
@@ -460,8 +451,7 @@ class Mailer(object):
return messagevars
- @defer.inlineCallbacks
- def make_summary_text(
+ async def make_summary_text(
self, notifs_by_room, room_state_ids, notif_events, user_id, reason
):
if len(notifs_by_room) == 1:
@@ -471,17 +461,17 @@ class Mailer(object):
# If the room has some kind of name, use it, but we don't
# want the generated-from-names one here otherwise we'll
# end up with, "new message from Bob in the Bob room"
- room_name = yield calculate_room_name(
+ room_name = await calculate_room_name(
self.store, room_state_ids[room_id], user_id, fallback_to_members=False
)
my_member_event_id = room_state_ids[room_id][("m.room.member", user_id)]
- my_member_event = yield self.store.get_event(my_member_event_id)
+ my_member_event = await self.store.get_event(my_member_event_id)
if my_member_event.content["membership"] == "invite":
inviter_member_event_id = room_state_ids[room_id][
("m.room.member", my_member_event.sender)
]
- inviter_member_event = yield self.store.get_event(
+ inviter_member_event = await self.store.get_event(
inviter_member_event_id
)
inviter_name = name_from_member_event(inviter_member_event)
@@ -506,7 +496,7 @@ class Mailer(object):
state_event_id = room_state_ids[room_id][
("m.room.member", event.sender)
]
- state_event = yield self.store.get_event(state_event_id)
+ state_event = await self.store.get_event(state_event_id)
sender_name = name_from_member_event(state_event)
if sender_name is not None and room_name is not None:
@@ -535,7 +525,7 @@ class Mailer(object):
}
)
- member_events = yield self.store.get_events(
+ member_events = await self.store.get_events(
[
room_state_ids[room_id][("m.room.member", s)]
for s in sender_ids
@@ -567,7 +557,7 @@ class Mailer(object):
}
)
- member_events = yield self.store.get_events(
+ member_events = await self.store.get_events(
[room_state_ids[room_id][("m.room.member", s)] for s in sender_ids]
)
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 4cd702b5..11032491 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -22,7 +22,7 @@ from six import string_types
from synapse.events import EventBase
from synapse.types import UserID
-from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
+from synapse.util.caches import register_cache
from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
@@ -165,7 +165,7 @@ class PushRuleEvaluatorForEvent(object):
# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
-regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR)
+regex_cache = LruCache(50000)
register_cache("cache", "regex_push_cache", regex_cache)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 39c99a28..8b4312e5 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -92,6 +92,7 @@ CONDITIONAL_REQUIREMENTS = {
'eliot<1.8.0;python_version<"3.5.3"',
],
"saml2": ["pysaml2>=4.5.0"],
+ "oidc": ["authlib>=0.14.0"],
"systemd": ["systemd-python>=231"],
"url_preview": ["lxml>=3.5.0"],
"test": ["mock>=2.0", "parameterized"],
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index 4613b253..19b69e0e 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -19,6 +19,7 @@ from synapse.replication.http import (
federation,
login,
membership,
+ presence,
register,
send_event,
streams,
@@ -34,9 +35,13 @@ class ReplicationRestResource(JsonResource):
def register_servlets(self, hs):
send_event.register_servlets(hs, self)
- membership.register_servlets(hs, self)
federation.register_servlets(hs, self)
- login.register_servlets(hs, self)
- register.register_servlets(hs, self)
- devices.register_servlets(hs, self)
- streams.register_servlets(hs, self)
+ presence.register_servlets(hs, self)
+ membership.register_servlets(hs, self)
+
+ # The following can't currently be instantiated on workers.
+ if hs.config.worker.worker_app is None:
+ login.register_servlets(hs, self)
+ register.register_servlets(hs, self)
+ devices.register_servlets(hs, self)
+ streams.register_servlets(hs, self)
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index f88c80ae..793cef6c 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -141,17 +141,29 @@ class ReplicationEndpoint(object):
Returns a callable that accepts the same parameters as `_serialize_payload`.
"""
clock = hs.get_clock()
- host = hs.config.worker_replication_host
- port = hs.config.worker_replication_http_port
-
client = hs.get_simple_http_client()
+ local_instance_name = hs.get_instance_name()
+
+ master_host = hs.config.worker_replication_host
+ master_port = hs.config.worker_replication_http_port
+
+ instance_map = hs.config.worker.instance_map
@trace(opname="outgoing_replication_request")
@defer.inlineCallbacks
def send_request(instance_name="master", **kwargs):
- # Currently we only support sending requests to master process.
- if instance_name != "master":
- raise Exception("Unknown instance")
+ if instance_name == local_instance_name:
+ raise Exception("Trying to send HTTP request to self")
+ if instance_name == "master":
+ host = master_host
+ port = master_port
+ elif instance_name in instance_map:
+ host = instance_map[instance_name].host
+ port = instance_map[instance_name].port
+ else:
+ raise Exception(
+ "Instance %r not in 'instance_map' config" % (instance_name,)
+ )
data = yield cls._serialize_payload(**kwargs)
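Replication HTTP requests can now be routed to any worker named in the `instance_map`, with `master` keeping its dedicated host/port and a request addressed to the local instance rejected outright. A standalone sketch of that selection logic with hypothetical instance names:

```python
from typing import Dict, NamedTuple, Tuple


class InstanceLocation(NamedTuple):
    host: str
    port: int


def pick_destination(
    instance_name: str,
    local_instance_name: str,
    master: InstanceLocation,
    instance_map: Dict[str, InstanceLocation],
) -> Tuple[str, int]:
    if instance_name == local_instance_name:
        raise Exception("Trying to send HTTP request to self")
    if instance_name == "master":
        return master.host, master.port
    if instance_name in instance_map:
        return instance_map[instance_name].host, instance_map[instance_name].port
    raise Exception("Instance %r not in 'instance_map' config" % (instance_name,))


print(
    pick_destination(
        "event_persister1",
        "frontend_proxy1",
        InstanceLocation("127.0.0.1", 9093),
        {"event_persister1": InstanceLocation("127.0.0.1", 9091)},
    )
)  # ('127.0.0.1', 9091)
```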
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 7e23b565..c287c4e2 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
"""Handles events newly received from federation, including persisting and
- notifying.
+ notifying. Returns the maximum stream ID of the persisted events.
The API looks like:
@@ -46,6 +46,13 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
"context": { .. serialized event context .. },
}],
"backfilled": false
+ }
+
+ 200 OK
+
+ {
+ "max_stream_id": 32443,
+ }
"""
NAME = "fed_send_events"
@@ -115,11 +122,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
logger.info("Got %d events from federation", len(event_and_contexts))
- await self.federation_handler.persist_events_and_notify(
+ max_stream_id = await self.federation_handler.persist_events_and_notify(
event_and_contexts, backfilled
)
- return 200, {}
+ return 200, {"max_stream_id": max_stream_id}
class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 3577611f..a7174c4a 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -14,12 +14,16 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import Requester, UserID
from synapse.util.distributor import user_joined_room, user_left_room
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -76,11 +80,11 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
logger.info("remote_join: %s into room: %s", user_id, room_id)
- await self.federation_handler.do_invite_join(
+ event_id, stream_id = await self.federation_handler.do_invite_join(
remote_room_hosts, room_id, user_id, event_content
)
- return 200, {}
+ return 200, {"event_id": event_id, "stream_id": stream_id}
class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
@@ -106,6 +110,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
self.federation_handler = hs.get_handlers().federation_handler
self.store = hs.get_datastore()
self.clock = hs.get_clock()
+ self.member_handler = hs.get_room_member_handler()
@staticmethod
def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
@@ -136,10 +141,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id)
try:
- event = await self.federation_handler.do_remotely_reject_invite(
+ event, stream_id = await self.federation_handler.do_remotely_reject_invite(
remote_room_hosts, room_id, user_id, event_content,
)
- ret = event.get_pdu_json()
+ event_id = event.event_id
except Exception as e:
            # if we were unable to reject the invite, just mark
# it as rejected on our end and plough ahead.
@@ -149,10 +154,42 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
#
logger.warning("Failed to reject invite: %s", e)
- await self.store.locally_reject_invite(user_id, room_id)
- ret = {}
+ stream_id = await self.member_handler.locally_reject_invite(
+ user_id, room_id
+ )
+ event_id = None
+
+ return 200, {"event_id": event_id, "stream_id": stream_id}
+
+
+class ReplicationLocallyRejectInviteRestServlet(ReplicationEndpoint):
+ """Rejects the invite for the user and room locally.
+
+ Request format:
+
+ POST /_synapse/replication/locally_reject_invite/:room_id/:user_id
+
+ {}
+ """
+
+ NAME = "locally_reject_invite"
+ PATH_ARGS = ("room_id", "user_id")
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+
+ self.member_handler = hs.get_room_member_handler()
+
+ @staticmethod
+ def _serialize_payload(room_id, user_id):
+ return {}
+
+ async def _handle_request(self, request, room_id, user_id):
+ logger.info("locally_reject_invite: %s out of room: %s", user_id, room_id)
+
+ stream_id = await self.member_handler.locally_reject_invite(user_id, room_id)
- return 200, ret
+ return 200, {"stream_id": stream_id}
class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
@@ -208,3 +245,4 @@ def register_servlets(hs, http_server):
ReplicationRemoteJoinRestServlet(hs).register(http_server)
ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)
+ ReplicationLocallyRejectInviteRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py
new file mode 100644
index 00000000..ea1b3333
--- /dev/null
+++ b/synapse/replication/http/presence.py
@@ -0,0 +1,116 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
+from synapse.types import UserID
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
+ """We've seen the user do something that indicates they're interacting
+ with the app.
+
+ The POST looks like:
+
+ POST /_synapse/replication/bump_presence_active_time/<user_id>
+
+ 200 OK
+
+ {}
+ """
+
+ NAME = "bump_presence_active_time"
+ PATH_ARGS = ("user_id",)
+ METHOD = "POST"
+ CACHE = False
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+
+ self._presence_handler = hs.get_presence_handler()
+
+ @staticmethod
+ def _serialize_payload(user_id):
+ return {}
+
+ async def _handle_request(self, request, user_id):
+ await self._presence_handler.bump_presence_active_time(
+ UserID.from_string(user_id)
+ )
+
+ return (
+ 200,
+ {},
+ )
+
+
+class ReplicationPresenceSetState(ReplicationEndpoint):
+ """Set the presence state for a user.
+
+ The POST looks like:
+
+ POST /_synapse/replication/presence_set_state/<user_id>
+
+ {
+ "state": { ... },
+ "ignore_status_msg": false,
+ }
+
+ 200 OK
+
+ {}
+ """
+
+ NAME = "presence_set_state"
+ PATH_ARGS = ("user_id",)
+ METHOD = "POST"
+ CACHE = False
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+
+ self._presence_handler = hs.get_presence_handler()
+
+ @staticmethod
+ def _serialize_payload(user_id, state, ignore_status_msg=False):
+ return {
+ "state": state,
+ "ignore_status_msg": ignore_status_msg,
+ }
+
+ async def _handle_request(self, request, user_id):
+ content = parse_json_object_from_request(request)
+
+ await self._presence_handler.set_state(
+ UserID.from_string(user_id), content["state"], content["ignore_status_msg"]
+ )
+
+ return (
+ 200,
+ {},
+ )
+
+
+def register_servlets(hs, http_server):
+ ReplicationBumpPresenceActiveTime(hs).register(http_server)
+ ReplicationPresenceSetState(hs).register(http_server)
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index b74b088f..c981723c 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -119,11 +119,11 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
"Got event to send with ID: %s into room: %s", event.event_id, event.room_id
)
- await self.event_creation_handler.persist_and_notify_client_event(
+ stream_id = await self.event_creation_handler.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
- return 200, {}
+ return 200, {"stream_id": stream_id}
def register_servlets(hs, http_server):
diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
index 0459f582..bde97eef 100644
--- a/synapse/replication/http/streams.py
+++ b/synapse/replication/http/streams.py
@@ -51,10 +51,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
super().__init__(hs)
self._instance_name = hs.get_instance_name()
-
- # We pull the streams from the replication steamer (if we try and make
- # them ourselves we end up in an import loop).
- self.streams = hs.get_replication_streamer().get_streams()
+ self.streams = hs.get_replication_streams()
@staticmethod
def _serialize_payload(stream_name, from_token, upto_token):
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 5d7c8871..f9e2533e 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -16,65 +16,28 @@
import logging
from typing import Optional
-import six
-
-from synapse.storage.data_stores.main.cache import (
- CURRENT_STATE_CACHE_NAME,
- CacheInvalidationWorkerStore,
-)
+from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
-
-from ._slaved_id_tracker import SlavedIdTracker
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
logger = logging.getLogger(__name__)
-def __func__(inp):
- if six.PY3:
- return inp
- else:
- return inp.__func__
-
-
class BaseSlavedStore(CacheInvalidationWorkerStore):
def __init__(self, database: Database, db_conn, hs):
super(BaseSlavedStore, self).__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
- self._cache_id_gen = SlavedIdTracker(
- db_conn, "cache_invalidation_stream", "stream_id"
- ) # type: Optional[SlavedIdTracker]
+ self._cache_id_gen = MultiWriterIdGenerator(
+ db_conn,
+ database,
+ instance_name=hs.get_instance_name(),
+ table="cache_invalidation_stream_by_instance",
+ instance_column="instance_name",
+ id_column="stream_id",
+ sequence_name="cache_invalidation_stream_seq",
+ ) # type: Optional[MultiWriterIdGenerator]
else:
self._cache_id_gen = None
self.hs = hs
-
- def get_cache_stream_token(self):
- if self._cache_id_gen:
- return self._cache_id_gen.get_current_token()
- else:
- return 0
-
- def process_replication_rows(self, stream_name, token, rows):
- if stream_name == "caches":
- if self._cache_id_gen:
- self._cache_id_gen.advance(token)
- for row in rows:
- if row.cache_func == CURRENT_STATE_CACHE_NAME:
- if row.keys is None:
- raise Exception(
- "Can't send an 'invalidate all' for current state cache"
- )
-
- room_id = row.keys[0]
- members_changed = set(row.keys[1:])
- self._invalidate_state_caches(room_id, members_changed)
- else:
- self._attempt_to_invalidate_cache(row.cache_func, row.keys)
-
- def _invalidate_cache_and_stream(self, txn, cache_func, keys):
- txn.call_after(cache_func.invalidate, keys)
- txn.call_after(self._send_invalidation_poke, cache_func, keys)
-
- def _send_invalidation_poke(self, cache_func, keys):
- self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys)
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 65e54b1c..9db6c62b 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -24,7 +24,13 @@ from synapse.storage.database import Database
class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs):
self._account_data_id_gen = SlavedIdTracker(
- db_conn, "account_data_max_stream_id", "stream_id"
+ db_conn,
+ "account_data",
+ "stream_id",
+ extra_tables=[
+ ("room_account_data", "stream_id"),
+ ("room_tags_revisions", "stream_id"),
+ ],
)
super(SlavedAccountDataStore, self).__init__(database, db_conn, hs)
@@ -32,7 +38,7 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()
- def process_replication_rows(self, stream_name, token, rows):
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "tag_account_data":
self._account_data_id_gen.advance(token)
for row in rows:
@@ -51,6 +57,4 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
(row.user_id, row.room_id, row.data_type)
)
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
- return super(SlavedAccountDataStore, self).process_replication_rows(
- stream_name, token, rows
- )
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index fbf996e3..1a38f53d 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -15,7 +15,6 @@
from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.storage.database import Database
-from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.descriptors import Cache
from ._base import BaseSlavedStore
@@ -26,7 +25,7 @@ class SlavedClientIpStore(BaseSlavedStore):
super(SlavedClientIpStore, self).__init__(database, db_conn, hs)
self.client_ip_last_seen = Cache(
- name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
+ name="client_ip_last_seen", keylen=4, max_entries=50000
)
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index c923751e..6e7fd259 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -43,7 +43,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
expiry_ms=30 * 60 * 1000,
)
- def process_replication_rows(self, stream_name, token, rows):
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "to_device":
self._device_inbox_id_gen.advance(token)
for row in rows:
@@ -55,6 +55,4 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
self._device_federation_outbox_stream_cache.entity_has_changed(
row.entity, token
)
- return super(SlavedDeviceInboxStore, self).process_replication_rows(
- stream_name, token, rows
- )
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 58fb0eaa..9d806734 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -48,7 +48,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
"DeviceListFederationStreamChangeCache", device_list_max
)
- def process_replication_rows(self, stream_name, token, rows):
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)
self._invalidate_caches_for_devices(token, rows)
@@ -56,9 +56,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
self._device_list_id_gen.advance(token)
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
- return super(SlavedDeviceStore, self).process_replication_rows(
- stream_name, token, rows
- )
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
def _invalidate_caches_for_devices(self, token, rows):
for row in rows:
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 15011259..1a1a50a2 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -15,11 +15,6 @@
# limitations under the License.
import logging
-from synapse.api.constants import EventTypes
-from synapse.replication.tcp.streams.events import (
- EventsStreamCurrentStateRow,
- EventsStreamEventRow,
-)
from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore
from synapse.storage.data_stores.main.event_push_actions import (
EventPushActionsWorkerStore,
@@ -35,7 +30,6 @@ from synapse.storage.database import Database
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
logger = logging.getLogger(__name__)
@@ -62,11 +56,6 @@ class SlavedEventStore(
BaseSlavedStore,
):
def __init__(self, database: Database, db_conn, hs):
- self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
- self._backfill_id_gen = SlavedIdTracker(
- db_conn, "events", "stream_ordering", step=-1
- )
-
super(SlavedEventStore, self).__init__(database, db_conn, hs)
events_max = self._stream_id_gen.get_current_token()
@@ -92,83 +81,3 @@ class SlavedEventStore(
def get_room_min_stream_ordering(self):
return self._backfill_id_gen.get_current_token()
-
- def process_replication_rows(self, stream_name, token, rows):
- if stream_name == "events":
- self._stream_id_gen.advance(token)
- for row in rows:
- self._process_event_stream_row(token, row)
- elif stream_name == "backfill":
- self._backfill_id_gen.advance(-token)
- for row in rows:
- self.invalidate_caches_for_event(
- -token,
- row.event_id,
- row.room_id,
- row.type,
- row.state_key,
- row.redacts,
- row.relates_to,
- backfilled=True,
- )
- return super(SlavedEventStore, self).process_replication_rows(
- stream_name, token, rows
- )
-
- def _process_event_stream_row(self, token, row):
- data = row.data
-
- if row.type == EventsStreamEventRow.TypeId:
- self.invalidate_caches_for_event(
- token,
- data.event_id,
- data.room_id,
- data.type,
- data.state_key,
- data.redacts,
- data.relates_to,
- backfilled=False,
- )
- elif row.type == EventsStreamCurrentStateRow.TypeId:
- self._curr_state_delta_stream_cache.entity_has_changed(
- row.data.room_id, token
- )
-
- if data.type == EventTypes.Member:
- self.get_rooms_for_user_with_stream_ordering.invalidate(
- (data.state_key,)
- )
- else:
- raise Exception("Unknown events stream row type %s" % (row.type,))
-
- def invalidate_caches_for_event(
- self,
- stream_ordering,
- event_id,
- room_id,
- etype,
- state_key,
- redacts,
- relates_to,
- backfilled,
- ):
- self._invalidate_get_event_cache(event_id)
-
- self.get_latest_event_ids_in_room.invalidate((room_id,))
-
- self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
-
- if not backfilled:
- self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
-
- if redacts:
- self._invalidate_get_event_cache(redacts)
-
- if etype == EventTypes.Member:
- self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
- self.get_invited_rooms_for_local_user.invalidate((state_key,))
-
- if relates_to:
- self.get_relations_for_event.invalidate_many((relates_to,))
- self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
- self.get_applicable_edit.invalidate((relates_to,))
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 01bcf0e8..1851e7d5 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -37,12 +37,10 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token()
- def process_replication_rows(self, stream_name, token, rows):
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "groups":
self._group_updates_id_gen.advance(token)
for row in rows:
self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
- return super(SlavedGroupServerStore, self).process_replication_rows(
- stream_name, token, rows
- )
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index fae31250..4e012484 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -18,7 +18,7 @@ from synapse.storage.data_stores.main.presence import PresenceStore
from synapse.storage.database import Database
from synapse.util.caches.stream_change_cache import StreamChangeCache
-from ._base import BaseSlavedStore, __func__
+from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
@@ -27,26 +27,24 @@ class SlavedPresenceStore(BaseSlavedStore):
super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
- self._presence_on_startup = self._get_active_presence(db_conn)
+ self._presence_on_startup = self._get_active_presence(db_conn) # type: ignore
self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
)
- _get_active_presence = __func__(DataStore._get_active_presence)
- take_presence_startup_info = __func__(DataStore.take_presence_startup_info)
+ _get_active_presence = DataStore._get_active_presence
+ take_presence_startup_info = DataStore.take_presence_startup_info
_get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"]
get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"]
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
- def process_replication_rows(self, stream_name, token, rows):
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "presence":
self._presence_id_gen.advance(token)
for row in rows:
self.presence_stream_cache.entity_has_changed(row.user_id, token)
self._get_presence_for_user.invalidate((row.user_id,))
- return super(SlavedPresenceStore, self).process_replication_rows(
- stream_name, token, rows
- )
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
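
The removal of the __func__ shim follows from dropping Python 2 support: on Python 2, a function looked up on another class was an unbound method that had to be unwrapped before it could be reassigned, whereas on Python 3 it is a plain function and can be borrowed directly, as the DataStore methods are here. A minimal self-contained illustration:

    # On Python 3, a function looked up on a class is a plain function, so it
    # can be re-attached to another class without any __func__ unwrapping.
    class Source:
        def greet(self):
            return "hello from %s" % type(self).__name__


    class Borrower:
        greet = Source.greet  # no __func__ needed on Python 3


    assert Borrower().greet() == "hello from Borrower"
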
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 6138796d..6adb1946 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -15,19 +15,11 @@
# limitations under the License.
from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
-from synapse.storage.database import Database
-from ._slaved_id_tracker import SlavedIdTracker
from .events import SlavedEventStore
class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
- def __init__(self, database: Database, db_conn, hs):
- self._push_rules_stream_id_gen = SlavedIdTracker(
- db_conn, "push_rules_stream", "stream_id"
- )
- super(SlavedPushRuleStore, self).__init__(database, db_conn, hs)
-
def get_push_rules_stream_token(self):
return (
self._push_rules_stream_id_gen.get_current_token(),
@@ -37,13 +29,11 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token()
- def process_replication_rows(self, stream_name, token, rows):
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "push_rules":
self._push_rules_stream_id_gen.advance(token)
for row in rows:
self.get_push_rules_for_user.invalidate((row.user_id,))
self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
- return super(SlavedPushRuleStore, self).process_replication_rows(
- stream_name, token, rows
- )
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 67be3379..cb78b49a 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -31,9 +31,7 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token()
- def process_replication_rows(self, stream_name, token, rows):
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "pushers":
self._pushers_id_gen.advance(token)
- return super(SlavedPusherStore, self).process_replication_rows(
- stream_name, token, rows
- )
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 993432ed..be716cc5 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -51,7 +51,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
self.get_receipts_for_room.invalidate((room_id, receipt_type))
- def process_replication_rows(self, stream_name, token, rows):
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "receipts":
self._receipts_id_gen.advance(token)
for row in rows:
@@ -60,6 +60,4 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
)
self._receipts_stream_cache.entity_has_changed(row.room_id, token)
- return super(SlavedReceiptsStore, self).process_replication_rows(
- stream_name, token, rows
- )
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 10dda870..8873bf37 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -30,8 +30,8 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
- def process_replication_rows(self, stream_name, token, rows):
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "public_rooms":
self._public_room_id_gen.advance(token)
- return super(RoomStore, self).process_replication_rows(stream_name, token, rows)
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 3bbf3c35..df29732f 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -14,14 +14,23 @@
# limitations under the License.
"""A replication client for use by synapse workers.
"""
-
+import heapq
import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Dict, List, Tuple
+from twisted.internet.defer import Deferred
from twisted.internet.protocol import ReconnectingClientFactory
-from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.api.constants import EventTypes
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
+from synapse.replication.tcp.streams.events import (
+ EventsStream,
+ EventsStreamEventRow,
+ EventsStreamRow,
+)
+from synapse.util.async_helpers import timeout_deferred
+from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -30,6 +39,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+# How long we allow callers to wait for replication updates before timing out.
+_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 30
+
+
class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
"""Factory for building connections to the master. Will reconnect if the
connection is lost.
@@ -83,8 +96,20 @@ class ReplicationDataHandler:
to handle updates in additional ways.
"""
- def __init__(self, store: BaseSlavedStore):
- self.store = store
+ def __init__(self, hs: "HomeServer"):
+ self.store = hs.get_datastore()
+ self.pusher_pool = hs.get_pusherpool()
+ self.notifier = hs.get_notifier()
+ self._reactor = hs.get_reactor()
+ self._clock = hs.get_clock()
+ self._streams = hs.get_replication_streams()
+ self._instance_name = hs.get_instance_name()
+
+ # Map from stream to list of deferreds waiting for the stream to
+ # arrive at a particular position. The lists are sorted by stream position.
+ self._streams_to_waiters = (
+ {}
+ ) # type: Dict[str, List[Tuple[int, Deferred[None]]]]
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
@@ -100,10 +125,100 @@ class ReplicationDataHandler:
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
- self.store.process_replication_rows(stream_name, token, rows)
-
- async def on_position(self, stream_name: str, token: int):
- self.store.process_replication_rows(stream_name, token, [])
+ self.store.process_replication_rows(stream_name, instance_name, token, rows)
+
+ if stream_name == EventsStream.NAME:
+ # We shouldn't get multiple rows per token for events stream, so
+ # we don't need to optimise this for multiple rows.
+ for row in rows:
+ if row.type != EventsStreamEventRow.TypeId:
+ continue
+ assert isinstance(row, EventsStreamRow)
+
+ event = await self.store.get_event(
+ row.data.event_id, allow_rejected=True
+ )
+ if event.rejected_reason:
+ continue
+
+ extra_users = () # type: Tuple[str, ...]
+ if event.type == EventTypes.Member:
+ extra_users = (event.state_key,)
+ max_token = self.store.get_room_max_stream_ordering()
+ self.notifier.on_new_room_event(event, token, max_token, extra_users)
+
+ await self.pusher_pool.on_new_notifications(token, token)
+
+ # Notify any waiting deferreds. The list is ordered by position so we
+ # just iterate through the list until we reach a position that is
+ # greater than the received row position.
+ waiting_list = self._streams_to_waiters.get(stream_name, [])
+
+        # Index of the first item with a position after the current token,
+        # i.e. we have called all deferreds before this index. If the loop
+        # below does not overwrite it, then either a) the list is empty, so
+        # this is a no-op, or b) every item in the list was called and the
+        # list should be cleared. Setting it to `len(list)` works for both.
+ index_of_first_deferred_not_called = len(waiting_list)
+
+ for idx, (position, deferred) in enumerate(waiting_list):
+ if position <= token:
+ try:
+ with PreserveLoggingContext():
+ deferred.callback(None)
+ except Exception:
+ # The deferred has been cancelled or timed out.
+ pass
+ else:
+ # The list is sorted by position so we don't need to continue
+ # checking any further entries in the list.
+ index_of_first_deferred_not_called = idx
+ break
+
+ # Drop all entries in the waiting list that were called in the above
+ # loop. (This maintains the order so no need to resort)
+ waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
+
+ async def on_position(self, stream_name: str, instance_name: str, token: int):
+ self.store.process_replication_rows(stream_name, instance_name, token, [])
def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command."""
+
+ async def wait_for_stream_position(
+ self, instance_name: str, stream_name: str, position: int
+ ):
+ """Wait until this instance has received updates up to and including
+ the given stream position.
+ """
+
+ if instance_name == self._instance_name:
+ # We don't get told about updates written by this process, and
+ # anyway in that case we don't need to wait.
+ return
+
+ current_position = self._streams[stream_name].current_token(self._instance_name)
+ if position <= current_position:
+ # We're already past the position
+ return
+
+ # Create a new deferred that times out after N seconds, as we don't want
+ # to wedge here forever.
+ deferred = Deferred()
+ deferred = timeout_deferred(
+ deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor
+ )
+
+ waiting_list = self._streams_to_waiters.setdefault(stream_name, [])
+
+ # We insert into the list using heapq as it is more efficient than
+ # pushing then resorting each time.
+ heapq.heappush(waiting_list, (position, deferred))
+
+ # We measure here to get in flight counts and average waiting time.
+ with Measure(self._clock, "repl.wait_for_stream_position"):
+ logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
+ await make_deferred_yieldable(deferred)
+ logger.info(
+ "Finished waiting for repl stream %r to reach %s", stream_name, position
+ )
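
wait_for_stream_position parks a deferred (with a timeout) in a per-stream list keyed by the position it is waiting for, and on_rdata wakes every waiter whose position has been reached. The same bookkeeping is easier to see with plain callables standing in for Deferreds; the sketch below is a simplification (it keeps the list fully sorted with bisect, where the code above inserts with heapq), not the actual implementation:

    # Simplified sketch of the waiter bookkeeping added above: each stream keeps
    # a list of (position, waiter) pairs sorted by position, and a newly received
    # token wakes every waiter at or below that token.
    import bisect
    import itertools
    from typing import Callable, Dict, List, Tuple


    class StreamWaiters:
        def __init__(self) -> None:
            self._seq = itertools.count()  # tie-breaker so tuples always compare
            self._waiters: Dict[str, List[Tuple[int, int, Callable[[], None]]]] = {}

        def wait_for(self, stream: str, position: int, cb: Callable[[], None]) -> None:
            # bisect keeps the list ordered by position without a full re-sort.
            bisect.insort(
                self._waiters.setdefault(stream, []), (position, next(self._seq), cb)
            )

        def on_new_token(self, stream: str, token: int) -> None:
            waiting = self._waiters.get(stream, [])
            # The list is sorted, so stop at the first position beyond the token.
            first_not_called = len(waiting)
            for idx, (position, _, cb) in enumerate(waiting):
                if position <= token:
                    cb()
                else:
                    first_not_called = idx
                    break
            # Drop the waiters that have just been called; order is preserved.
            waiting[:] = waiting[first_not_called:]


    waiters = StreamWaiters()
    waiters.wait_for("events", 5, lambda: print("reached 5"))
    waiters.wait_for("events", 9, lambda: print("reached 9"))
    waiters.on_new_token("events", 7)  # wakes only the position-5 waiter
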
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index f58e384d..c04f6228 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -341,37 +341,6 @@ class RemovePusherCommand(Command):
return " ".join((self.app_id, self.push_key, self.user_id))
-class InvalidateCacheCommand(Command):
- """Sent by the client to invalidate an upstream cache.
-
- THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE
- NOT DISASTROUS IF WE DROP ON THE FLOOR.
-
- Mainly used to invalidate destination retry timing caches.
-
- Format::
-
- INVALIDATE_CACHE <cache_func> <keys_json>
-
- Where <keys_json> is a json list.
- """
-
- NAME = "INVALIDATE_CACHE"
-
- def __init__(self, cache_func, keys):
- self.cache_func = cache_func
- self.keys = keys
-
- @classmethod
- def from_line(cls, line):
- cache_func, keys_json = line.split(" ", 1)
-
- return cls(cache_func, json.loads(keys_json))
-
- def to_line(self):
- return " ".join((self.cache_func, _json_encoder.encode(self.keys)))
-
-
class UserIpCommand(Command):
"""Sent periodically when a worker sees activity from a client.
@@ -439,7 +408,6 @@ _COMMANDS = (
UserSyncCommand,
FederationAckCommand,
RemovePusherCommand,
- InvalidateCacheCommand,
UserIpCommand,
RemoteServerUpCommand,
ClearUserSyncsCommand,
@@ -467,7 +435,6 @@ VALID_CLIENT_COMMANDS = (
ClearUserSyncsCommand.NAME,
FederationAckCommand.NAME,
RemovePusherCommand.NAME,
- InvalidateCacheCommand.NAME,
UserIpCommand.NAME,
ErrorCommand.NAME,
RemoteServerUpCommand.NAME,
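
With INVALIDATE_CACHE removed, a worker that needs to invalidate an upstream cache no longer pokes the master over the replication protocol. Instead, per the storage and stream changes earlier in this diff, it records the invalidation in the cache invalidation stream under its own instance name, and every other instance applies the row when it reads the "caches" stream. The sketch below is a self-contained, in-memory analogy of that flow; the names are illustrative, not Synapse's API:

    # Hedged, in-memory sketch of the replacement for INVALIDATE_CACHE: the
    # writer appends an invalidation row to a shared stream and every instance
    # applies rows written by others as it reads the stream back.
    from typing import Dict, List, Tuple

    invalidation_stream: List[Tuple[int, str, str, Tuple[str, ...]]] = []


    class InstanceCaches:
        def __init__(self, name: str) -> None:
            self.name = name
            self.cache: Dict[Tuple[str, ...], str] = {}
            self.read_upto = 0

        def invalidate_and_stream(self, keys: Tuple[str, ...]) -> None:
            self.cache.pop(keys, None)  # invalidate locally first
            stream_id = len(invalidation_stream) + 1
            invalidation_stream.append((stream_id, self.name, "my_cache", keys))

        def process_caches_stream(self) -> None:
            # Apply any rows written by *other* instances since we last looked.
            for stream_id, writer, _cache_name, keys in invalidation_stream:
                if stream_id > self.read_upto and writer != self.name:
                    self.cache.pop(keys, None)
                self.read_upto = max(self.read_upto, stream_id)


    master = InstanceCaches("master")
    worker = InstanceCaches("worker1")
    master.cache[("@alice:example.com",)] = "cached value"
    worker.cache[("@alice:example.com",)] = "cached value"
    worker.invalidate_and_stream(("@alice:example.com",))  # worker writes the row
    master.process_caches_stream()                         # master picks it up
    assert ("@alice:example.com",) not in master.cache
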
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 4328b38e..cbcf46f3 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -15,18 +15,7 @@
# limitations under the License.
import logging
-from typing import (
- Any,
- Callable,
- Dict,
- Iterable,
- Iterator,
- List,
- Optional,
- Set,
- Tuple,
- TypeVar,
-)
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
from prometheus_client import Counter
@@ -38,7 +27,6 @@ from synapse.replication.tcp.commands import (
ClearUserSyncsCommand,
Command,
FederationAckCommand,
- InvalidateCacheCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
@@ -48,7 +36,14 @@ from synapse.replication.tcp.commands import (
UserSyncCommand,
)
from synapse.replication.tcp.protocol import AbstractConnection
-from synapse.replication.tcp.streams import STREAMS_MAP, Stream
+from synapse.replication.tcp.streams import (
+ STREAMS_MAP,
+ BackfillStream,
+ CachesStream,
+ EventsStream,
+ FederationStream,
+ Stream,
+)
from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
@@ -85,6 +80,34 @@ class ReplicationCommandHandler:
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream]
+ # List of streams that this instance is the source of
+ self._streams_to_replicate = [] # type: List[Stream]
+
+ for stream in self._streams.values():
+ if stream.NAME == CachesStream.NAME:
+ # All workers can write to the cache invalidation stream.
+ self._streams_to_replicate.append(stream)
+ continue
+
+ if isinstance(stream, (EventsStream, BackfillStream)):
+            # Only add EventsStream and BackfillStream as a source on the
+ # instance in charge of event persistence.
+ if hs.config.worker.writers.events == hs.get_instance_name():
+ self._streams_to_replicate.append(stream)
+
+ continue
+
+ # Only add any other streams if we're on master.
+ if hs.config.worker_app is not None:
+ continue
+
+ if stream.NAME == FederationStream.NAME and hs.config.send_federation:
+ # We only support federation stream if federation sending
+ # has been disabled on the master.
+ continue
+
+ self._streams_to_replicate.append(stream)
+
self._position_linearizer = Linearizer(
"replication_position", clock=self._clock
)
@@ -136,6 +159,9 @@ class ReplicationCommandHandler:
hs.config.redis_port,
)
+ # First let's ensure that we have a ReplicationStreamer started.
+ hs.get_replication_streamer()
+
# We need two connections to redis, one for the subscription stream and
# one to send commands to (as you can't send further redis commands to a
# connection after SUBSCRIBE is called).
@@ -162,16 +188,33 @@ class ReplicationCommandHandler:
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self._factory)
+ def get_streams(self) -> Dict[str, Stream]:
+ """Get a map from stream name to all streams.
+ """
+ return self._streams
+
+ def get_streams_to_replicate(self) -> List[Stream]:
+        """Get a list of streams that this instance replicates.
+ """
+ return self._streams_to_replicate
+
async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
- # We only want to announce positions by the writer of the streams.
- # Currently this is just the master process.
- if not self._is_master:
- return
+ self.send_positions_to_connection(conn)
+
+ def send_positions_to_connection(self, conn: AbstractConnection):
+        """Send the current position of every stream that this process is a
+        source of to the given connection.
+ """
- for stream_name, stream in self._streams.items():
- current_token = stream.current_token()
+ # We respond with current position of all streams this instance
+ # replicates.
+ for stream in self.get_streams_to_replicate():
self.send_command(
- PositionCommand(stream_name, self._instance_name, current_token)
+ PositionCommand(
+ stream.NAME,
+ self._instance_name,
+ stream.current_token(self._instance_name),
+ )
)
async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
@@ -208,18 +251,6 @@ class ReplicationCommandHandler:
self._notifier.on_new_replication_data()
- async def on_INVALIDATE_CACHE(
- self, conn: AbstractConnection, cmd: InvalidateCacheCommand
- ):
- invalidate_cache_counter.inc()
-
- if self._is_master:
- # We invalidate the cache locally, but then also stream that to other
- # workers.
- await self._store.invalidate_cache_and_stream(
- cmd.cache_func, tuple(cmd.keys)
- )
-
async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
user_ip_cache_counter.inc()
@@ -293,7 +324,7 @@ class ReplicationCommandHandler:
rows: a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
- logger.debug("Received rdata %s -> %s", stream_name, token)
+ logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token)
await self._replication_data_handler.on_rdata(
stream_name, instance_name, token, rows
)
@@ -324,7 +355,7 @@ class ReplicationCommandHandler:
self._pending_batches.pop(stream_name, [])
# Find where we previously streamed up to.
- current_token = stream.current_token()
+ current_token = stream.current_token(cmd.instance_name)
# If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates
@@ -361,7 +392,9 @@ class ReplicationCommandHandler:
logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
# We've now caught up to position sent to us, notify handler.
- await self._replication_data_handler.on_position(stream_name, cmd.token)
+ await self._replication_data_handler.on_position(
+ cmd.stream_name, cmd.instance_name, cmd.token
+ )
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
@@ -489,12 +522,6 @@ class ReplicationCommandHandler:
cmd = RemovePusherCommand(app_id, push_key, user_id)
self.send_command(cmd)
- def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
- """Poke the master to invalidate a cache.
- """
- cmd = InvalidateCacheCommand(cache_func.__name__, keys)
- self.send_command(cmd)
-
def send_user_ip(
self,
user_id: str,
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 55bfa71d..e776b631 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -70,7 +70,6 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
logger.info("Connected to redis")
super().connectionMade()
run_as_background_process("subscribe-replication", self._send_subscribe)
- self.handler.new_connection(self)
async def _send_subscribe(self):
# it's important to make sure that we only send the REPLICATE command once we
@@ -81,9 +80,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
logger.info(
"Successfully subscribed to redis stream, sending REPLICATE command"
)
+ self.handler.new_connection(self)
await self._async_send_command(ReplicateCommand())
logger.info("REPLICATE successfully sent")
+ # We send out our positions when there is a new connection in case the
+ # other side missed updates. We do this for Redis connections as the
+        # other side won't know we've connected and so won't issue a REPLICATE.
+ self.handler.send_positions_to_connection(self)
+
def messageReceived(self, pattern: str, channel: str, message: str):
"""Received a message from redis.
"""
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 33d2f589..41569305 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,7 +17,6 @@
import logging
import random
-from typing import Dict, List
from prometheus_client import Counter
@@ -25,7 +24,6 @@ from twisted.internet.protocol import Factory
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
-from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream
from synapse.util.metrics import Measure
stream_updates_counter = Counter(
@@ -71,26 +69,11 @@ class ReplicationStreamer(object):
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
+ self._instance_name = hs.get_instance_name()
self._replication_torture_level = hs.config.replication_torture_level
- # Work out list of streams that this instance is the source of.
- self.streams = [] # type: List[Stream]
- if hs.config.worker_app is None:
- for stream in STREAMS_MAP.values():
- if stream == FederationStream and hs.config.send_federation:
- # We only support federation stream if federation sending
- # hase been disabled on the master.
- continue
-
- self.streams.append(stream(hs))
-
- self.streams_by_name = {stream.NAME: stream for stream in self.streams}
-
- # Only bother registering the notifier callback if we have streams to
- # publish.
- if self.streams:
- self.notifier.add_replication_callback(self.on_notifier_poke)
+ self.notifier.add_replication_callback(self.on_notifier_poke)
# Keeps track of whether we are currently checking for updates
self.is_looping = False
@@ -98,10 +81,8 @@ class ReplicationStreamer(object):
self.command_handler = hs.get_tcp_replication()
- def get_streams(self) -> Dict[str, Stream]:
- """Get a mapp from stream name to stream instance.
- """
- return self.streams_by_name
+ # Set of streams to replicate.
+ self.streams = self.command_handler.get_streams_to_replicate()
def on_notifier_poke(self):
"""Checks if there is actually any new data and sends it to the
@@ -145,7 +126,9 @@ class ReplicationStreamer(object):
random.shuffle(all_streams)
for stream in all_streams:
- if stream.last_token == stream.current_token():
+ if stream.last_token == stream.current_token(
+ self._instance_name
+ ):
continue
if self._replication_torture_level:
@@ -157,7 +140,7 @@ class ReplicationStreamer(object):
"Getting stream: %s: %s -> %s",
stream.NAME,
stream.last_token,
- stream.current_token(),
+ stream.current_token(self._instance_name),
)
try:
updates, current_token, limited = await stream.get_updates()
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b0f87c36..4acefc8a 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,14 +14,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import heapq
import logging
from collections import namedtuple
-from typing import Any, Awaitable, Callable, List, Optional, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ List,
+ Optional,
+ Tuple,
+ TypeVar,
+)
import attr
from synapse.replication.http.streams import ReplicationGetStreamUpdates
+if TYPE_CHECKING:
+ import synapse.server
+
logger = logging.getLogger(__name__)
# the number of rows to request from an update_function.
@@ -37,7 +50,7 @@ Token = int
# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
# just a row from a database query, though this is dependent on the stream in question.
#
-StreamRow = Tuple
+StreamRow = TypeVar("StreamRow", bound=Tuple)
# The type returned by the update_function of a stream, as well as get_updates(),
# get_updates_since, etc.
@@ -95,19 +108,25 @@ class Stream(object):
def __init__(
self,
local_instance_name: str,
- current_token_function: Callable[[], Token],
+ current_token_function: Callable[[str], Token],
update_function: UpdateFunction,
):
"""Instantiate a Stream
- current_token_function and update_function are callbacks which should be
- implemented by subclasses.
+ `current_token_function` and `update_function` are callbacks which
+ should be implemented by subclasses.
- current_token_function is called to get the current token of the underlying
- stream.
+ `current_token_function` takes an instance name, which is a writer to
+ the stream, and returns the position in the stream of the writer (as
+ viewed from the current process). On the writer process this is where
+ the writer has successfully written up to, whereas on other processes
+ this is the position which we have received updates up to over
+ replication. (Note that most streams have a single writer and so their
+ implementations ignore the instance name passed in).
- update_function is called to get updates for this stream between a pair of
- stream tokens. See the UpdateFunction type definition for more info.
+ `update_function` is called to get updates for this stream between a
+ pair of stream tokens. See the `UpdateFunction` type definition for more
+ info.
Args:
local_instance_name: The instance name of the current process
@@ -119,13 +138,13 @@ class Stream(object):
self.update_function = update_function
# The token from which we last asked for updates
- self.last_token = self.current_token()
+ self.last_token = self.current_token(self.local_instance_name)
def discard_updates_and_advance(self):
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
"""
- self.last_token = self.current_token()
+ self.last_token = self.current_token(self.local_instance_name)
async def get_updates(self) -> StreamUpdateResult:
"""Gets all updates since the last time this function was called (or
@@ -137,7 +156,7 @@ class Stream(object):
position in stream, and `limited` is whether there are more updates
to fetch.
"""
- current_token = self.current_token()
+ current_token = self.current_token(self.local_instance_name)
updates, current_token, limited = await self.get_updates_since(
self.local_instance_name, self.last_token, current_token
)
@@ -169,6 +188,16 @@ class Stream(object):
return updates, upto_token, limited
+def current_token_without_instance(
+ current_token: Callable[[], int]
+) -> Callable[[str], int]:
+ """Takes a current token callback function for a single writer stream
+ that doesn't take an instance name parameter and wraps it in a function that
+ does accept an instance name parameter but ignores it.
+ """
+ return lambda instance_name: current_token()
+
+
def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> UpdateFunction:
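
Because current_token callbacks now take the writer's instance name, the many streams that still have exactly one writer need an adapter for their existing zero-argument getters; that is what current_token_without_instance provides, as the later hunks for BackfillStream, ReceiptsStream and friends show. A tiny standalone illustration (the store method is a stand-in):

    # Minimal illustration of the adapter defined above: wrap a zero-argument
    # token getter so it satisfies the new Callable[[str], int] signature while
    # ignoring the instance name.
    from typing import Callable


    def current_token_without_instance(
        current_token: Callable[[], int]
    ) -> Callable[[str], int]:
        return lambda instance_name: current_token()


    def get_max_receipt_stream_id() -> int:  # stand-in for a store method
        return 42


    current_token = current_token_without_instance(get_max_receipt_stream_id)
    assert current_token("worker1") == 42  # instance name is accepted but ignored
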
@@ -234,7 +263,7 @@ class BackfillStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_current_backfill_token,
+ current_token_without_instance(store.get_current_backfill_token),
db_query_to_update_function(store.get_all_new_backfill_event_rows),
)
@@ -270,7 +299,9 @@ class PresenceStream(Stream):
update_function = make_http_update_function(hs, self.NAME)
super().__init__(
- hs.get_instance_name(), store.get_current_presence_token, update_function
+ hs.get_instance_name(),
+ current_token_without_instance(store.get_current_presence_token),
+ update_function,
)
@@ -295,7 +326,9 @@ class TypingStream(Stream):
update_function = make_http_update_function(hs, self.NAME)
super().__init__(
- hs.get_instance_name(), typing_handler.get_current_token, update_function
+ hs.get_instance_name(),
+ current_token_without_instance(typing_handler.get_current_token),
+ update_function,
)
@@ -318,7 +351,7 @@ class ReceiptsStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_max_receipt_stream_id,
+ current_token_without_instance(store.get_max_receipt_stream_id),
db_query_to_update_function(store.get_all_updated_receipts),
)
@@ -338,7 +371,7 @@ class PushRulesStream(Stream):
hs.get_instance_name(), self._current_token, self._update_function
)
- def _current_token(self) -> int:
+ def _current_token(self, instance_name: str) -> int:
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token
@@ -372,7 +405,7 @@ class PushersStream(Stream):
super().__init__(
hs.get_instance_name(),
- store.get_pushers_stream_token,
+ current_token_without_instance(store.get_pushers_stream_token),
db_query_to_update_function(store.get_all_updated_pushers_rows),
)
@@ -401,13 +434,27 @@ class CachesStream(Stream):
ROW_TYPE = CachesStreamRow
def __init__(self, hs):
- store = hs.get_datastore()
+ self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_cache_stream_token,
- db_query_to_update_function(store.get_all_updated_caches),
+ self.store.get_cache_stream_token,
+ self._update_function,
)
+ async def _update_function(
+ self, instance_name: str, from_token: int, upto_token: int, limit: int
+ ):
+ rows = await self.store.get_all_updated_caches(
+ instance_name, from_token, upto_token, limit
+ )
+ updates = [(row[0], row[1:]) for row in rows]
+ limited = False
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
+
class PublicRoomsStream(Stream):
"""The public rooms list changed
@@ -430,7 +477,7 @@ class PublicRoomsStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_current_public_room_stream_id,
+ current_token_without_instance(store.get_current_public_room_stream_id),
db_query_to_update_function(store.get_all_new_public_rooms),
)
@@ -451,7 +498,7 @@ class DeviceListsStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_device_stream_token,
+ current_token_without_instance(store.get_device_stream_token),
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
)
@@ -469,7 +516,7 @@ class ToDeviceStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_to_device_stream_token,
+ current_token_without_instance(store.get_to_device_stream_token),
db_query_to_update_function(store.get_all_new_device_messages),
)
@@ -489,7 +536,7 @@ class TagAccountDataStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_max_account_data_stream_id,
+ current_token_without_instance(store.get_max_account_data_stream_id),
db_query_to_update_function(store.get_all_updated_tags),
)
@@ -499,32 +546,69 @@ class AccountDataStream(Stream):
"""
AccountDataStreamRow = namedtuple(
- "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
+ "AccountDataStream",
+ ("user_id", "room_id", "data_type"), # str # Optional[str] # str
)
NAME = "account_data"
ROW_TYPE = AccountDataStreamRow
- def __init__(self, hs):
+ def __init__(self, hs: "synapse.server.HomeServer"):
self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- self.store.get_max_account_data_stream_id,
- db_query_to_update_function(self._update_function),
+ current_token_without_instance(self.store.get_max_account_data_stream_id),
+ self._update_function,
)
- async def _update_function(self, from_token, to_token, limit):
- global_results, room_results = await self.store.get_all_updated_account_data(
- from_token, from_token, to_token, limit
+ async def _update_function(
+ self, instance_name: str, from_token: int, to_token: int, limit: int
+ ) -> StreamUpdateResult:
+ limited = False
+ global_results = await self.store.get_updated_global_account_data(
+ from_token, to_token, limit
)
- results = list(room_results)
- results.extend(
- (stream_id, user_id, None, account_data_type)
+ # if the global results hit the limit, we'll need to limit the room results to
+ # the same stream token.
+ if len(global_results) >= limit:
+ to_token = global_results[-1][0]
+ limited = True
+
+ room_results = await self.store.get_updated_room_account_data(
+ from_token, to_token, limit
+ )
+
+ # likewise, if the room results hit the limit, limit the global results to
+ # the same stream token.
+ if len(room_results) >= limit:
+ to_token = room_results[-1][0]
+ limited = True
+
+ # convert the global results to the right format, and limit them to the to_token
+ # at the same time
+ global_rows = (
+ (stream_id, (user_id, None, account_data_type))
for stream_id, user_id, account_data_type in global_results
+ if stream_id <= to_token
+ )
+
+ # we know that the room_results are already limited to `to_token` so no need
+ # for a check on `stream_id` here.
+ room_rows = (
+ (stream_id, (user_id, room_id, account_data_type))
+ for stream_id, user_id, room_id, account_data_type in room_results
)
- return results
+ # We need to return a sorted list, so merge them together.
+ #
+ # Note: We order only by the stream ID to work around a bug where the
+ # same stream ID could appear in both `global_rows` and `room_rows`,
+ # leading to a comparison between the data tuples. The comparison could
+ # fail due to attempting to compare the `room_id` which results in a
+ # `TypeError` from comparing a `str` vs `None`.
+ updates = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0]))
+ return updates, to_token, limited
class GroupServerStream(Stream):
@@ -540,7 +624,7 @@ class GroupServerStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_group_stream_token,
+ current_token_without_instance(store.get_group_stream_token),
db_query_to_update_function(store.get_all_groups_changes),
)
@@ -558,7 +642,7 @@ class UserSignatureStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_device_stream_token,
+ current_token_without_instance(store.get_device_stream_token),
db_query_to_update_function(
store.get_all_user_signature_changes_for_remotes
),
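
The ordering note in the AccountDataStream hunk above is easy to reproduce: without a key, heapq.merge compares whole rows, and when two rows share a stream ID the comparison falls through to the data tuples, where a None room_id (global account data) meets a str room_id (room account data) and raises TypeError. Keying the merge on the stream ID alone never compares the data tuples:

    # Demonstrates the comparison problem described in the AccountDataStream
    # comment above: merging on whole rows ends up comparing a None room_id
    # with a str room_id when stream IDs collide.
    import heapq

    global_rows = [(7, ("@alice:example.com", None, "m.push_rules"))]
    room_rows = [(7, ("@alice:example.com", "!room:example.com", "m.tag"))]

    try:
        list(heapq.merge(room_rows, global_rows))
    except TypeError as e:
        print("without key=:", e)  # comparing None with a str is not supported

    merged = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0]))
    print("with key=:", [stream_id for stream_id, _ in merged])  # [7, 7]
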
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 890e75d8..f3703903 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -20,7 +20,7 @@ from typing import List, Tuple, Type
import attr
-from ._base import Stream, StreamUpdateResult, Token
+from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
"""Handling of the 'events' replication stream
@@ -119,7 +119,7 @@ class EventsStream(Stream):
self._store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- self._store.get_current_events_token,
+ current_token_without_instance(self._store.get_current_events_token),
self._update_function,
)
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index e8bd52e3..9bcd13b0 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,7 +15,11 @@
# limitations under the License.
from collections import namedtuple
-from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
+from synapse.replication.tcp.streams._base import (
+ Stream,
+ current_token_without_instance,
+ make_http_update_function,
+)
class FederationStream(Stream):
@@ -35,21 +39,35 @@ class FederationStream(Stream):
ROW_TYPE = FederationStreamRow
def __init__(self, hs):
- # Not all synapse instances will have a federation sender instance,
- # whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
- # so we stub the stream out when that is the case.
- if hs.config.worker_app is None or hs.should_send_federation():
+ if hs.config.worker_app is None:
+ # master process: get updates from the FederationRemoteSendQueue.
+ # (if the master is configured to send federation itself, federation_sender
+ # will be a real FederationSender, which has stubs for current_token and
+ # get_replication_rows.)
federation_sender = hs.get_federation_sender()
- current_token = federation_sender.get_current_token
- update_function = db_query_to_update_function(
- federation_sender.get_replication_rows
+ current_token = current_token_without_instance(
+ federation_sender.get_current_token
)
+ update_function = federation_sender.get_replication_rows
+
+ elif hs.should_send_federation():
+ # federation sender: Query master process
+ update_function = make_http_update_function(hs, self.NAME)
+ current_token = self._stub_current_token
+
else:
- current_token = lambda: 0
+ # other worker: stub out the update function (we're not interested in
+ # any updates so when we get a POSITION we do nothing)
update_function = self._stub_update_function
+ current_token = self._stub_current_token
super().__init__(hs.get_instance_name(), current_token, update_function)
@staticmethod
+ def _stub_current_token(instance_name: str) -> int:
+ # dummy current-token method for use on workers
+ return 0
+
+ @staticmethod
async def _stub_update_function(instance_name, from_token, upto_token, limit):
return [], upto_token, False
diff --git a/synapse/res/templates/notice_expiry.html b/synapse/res/templates/notice_expiry.html
index f0d7c66e..6b94d8c3 100644
--- a/synapse/res/templates/notice_expiry.html
+++ b/synapse/res/templates/notice_expiry.html
@@ -30,7 +30,7 @@
<tr>
<td colspan="2">
<div class="noticetext">Your account will expire on {{ expiration_ts|format_ts("%d-%m-%Y") }}. This means that you will lose access to your account after this date.</div>
- <div class="noticetext">To extend the validity of your account, please click on the link bellow (or copy and paste it into a new browser tab):</div>
+ <div class="noticetext">To extend the validity of your account, please click on the link below (or copy and paste it into a new browser tab):</div>
<div class="noticetext"><a href="{{ url }}">{{ url }}</a></div>
</td>
</tr>
diff --git a/synapse/res/templates/notice_expiry.txt b/synapse/res/templates/notice_expiry.txt
index 41f1c427..4ec27e88 100644
--- a/synapse/res/templates/notice_expiry.txt
+++ b/synapse/res/templates/notice_expiry.txt
@@ -2,6 +2,6 @@ Hi {{ display_name }},
Your account will expire on {{ expiration_ts|format_ts("%d-%m-%Y") }}. This means that you will lose access to your account after this date.
-To extend the validity of your account, please click on the link bellow (or copy and paste it to a new browser tab):
+To extend the validity of your account, please click on the link below (or copy and paste it to a new browser tab):
{{ url }}
diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html
new file mode 100644
index 00000000..43a21138
--- /dev/null
+++ b/synapse/res/templates/sso_error.html
@@ -0,0 +1,18 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+ <meta charset="UTF-8">
+ <title>SSO error</title>
+</head>
+<body>
+ <p>Oops! Something went wrong during authentication.</p>
+ <p>
+ Try logging in again from your Matrix client and if the problem persists
+ please contact the server's administrator.
+ </p>
+ <p>Error: <code>{{ error }}</code></p>
+ {% if error_description %}
+ <pre><code>{{ error_description }}</code></pre>
+ {% endif %}
+</body>
+</html>
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index ed70d448..9eda592d 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -26,12 +26,18 @@ from synapse.rest.admin._base import (
assert_requester_is_admin,
historical_admin_path_patterns,
)
+from synapse.rest.admin.devices import (
+ DeleteDevicesRestServlet,
+ DeviceRestServlet,
+ DevicesRestServlet,
+)
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
from synapse.rest.admin.rooms import (
JoinRoomAliasServlet,
ListRoomRestServlet,
+ RoomRestServlet,
ShutdownRoomRestServlet,
)
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
@@ -193,6 +199,7 @@ def register_servlets(hs, http_server):
"""
register_servlets_for_client_rest_resource(hs, http_server)
ListRoomRestServlet(hs).register(http_server)
+ RoomRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
PurgeRoomServlet(hs).register(http_server)
SendServerNoticeServlet(hs).register(http_server)
@@ -200,6 +207,9 @@ def register_servlets(hs, http_server):
UserAdminServlet(hs).register(http_server)
UserRestServletV2(hs).register(http_server)
UsersRestServletV2(hs).register(http_server)
+ DeviceRestServlet(hs).register(http_server)
+ DevicesRestServlet(hs).register(http_server)
+ DeleteDevicesRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index a96f75ce..d82eaf5e 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -15,7 +15,11 @@
import re
+import twisted.web.server
+
+import synapse.api.auth
from synapse.api.errors import AuthError
+from synapse.types import UserID
def historical_admin_path_patterns(path_regex):
@@ -55,41 +59,32 @@ def admin_patterns(path_regex: str):
return patterns
-async def assert_requester_is_admin(auth, request):
+async def assert_requester_is_admin(
+ auth: synapse.api.auth.Auth, request: twisted.web.server.Request
+) -> None:
"""Verify that the requester is an admin user
- WARNING: MAKE SURE YOU YIELD ON THE RESULT!
-
Args:
- auth (synapse.api.auth.Auth):
- request (twisted.web.server.Request): incoming request
-
- Returns:
- Deferred
+ auth: api.auth.Auth singleton
+ request: incoming request
Raises:
- AuthError if the requester is not an admin
+ AuthError if the requester is not a server admin
"""
requester = await auth.get_user_by_req(request)
await assert_user_is_admin(auth, requester.user)
-async def assert_user_is_admin(auth, user_id):
+async def assert_user_is_admin(auth: synapse.api.auth.Auth, user_id: UserID) -> None:
"""Verify that the given user is an admin user
- WARNING: MAKE SURE YOU YIELD ON THE RESULT!
-
Args:
- auth (synapse.api.auth.Auth):
- user_id (UserID):
-
- Returns:
- Deferred
+ auth: api.auth.Auth singleton
+ user_id: user to check
Raises:
- AuthError if the user is not an admin
+ AuthError if the user is not a server admin
"""
-
is_admin = await auth.is_server_admin(user_id)
if not is_admin:
raise AuthError(403, "You are not a server admin")
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
new file mode 100644
index 00000000..8d326773
--- /dev/null
+++ b/synapse/rest/admin/devices.py
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import re
+
+from synapse.api.errors import NotFoundError, SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
+from synapse.rest.admin._base import assert_requester_is_admin
+from synapse.types import UserID
+
+logger = logging.getLogger(__name__)
+
+
+class DeviceRestServlet(RestServlet):
+ """
+ Get, update or delete the given user's device
+ """
+
+ PATTERNS = (
+ re.compile(
+ "^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$"
+ ),
+ )
+
+ def __init__(self, hs):
+ super(DeviceRestServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.device_handler = hs.get_device_handler()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request, user_id, device_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ target_user = UserID.from_string(user_id)
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only lookup local users")
+
+ u = await self.store.get_user_by_id(target_user.to_string())
+ if u is None:
+ raise NotFoundError("Unknown user")
+
+ device = await self.device_handler.get_device(
+ target_user.to_string(), device_id
+ )
+ return 200, device
+
+ async def on_DELETE(self, request, user_id, device_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ target_user = UserID.from_string(user_id)
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only lookup local users")
+
+ u = await self.store.get_user_by_id(target_user.to_string())
+ if u is None:
+ raise NotFoundError("Unknown user")
+
+ await self.device_handler.delete_device(target_user.to_string(), device_id)
+ return 200, {}
+
+ async def on_PUT(self, request, user_id, device_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ target_user = UserID.from_string(user_id)
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only lookup local users")
+
+ u = await self.store.get_user_by_id(target_user.to_string())
+ if u is None:
+ raise NotFoundError("Unknown user")
+
+ body = parse_json_object_from_request(request, allow_empty_body=True)
+ await self.device_handler.update_device(
+ target_user.to_string(), device_id, body
+ )
+ return 200, {}
+
+
+class DevicesRestServlet(RestServlet):
+ """
+ Retrieve the given user's devices
+ """
+
+ PATTERNS = (re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/devices$"),)
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.device_handler = hs.get_device_handler()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request, user_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ target_user = UserID.from_string(user_id)
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only lookup local users")
+
+ u = await self.store.get_user_by_id(target_user.to_string())
+ if u is None:
+ raise NotFoundError("Unknown user")
+
+ devices = await self.device_handler.get_devices_by_user(target_user.to_string())
+ return 200, {"devices": devices}
+
+
+class DeleteDevicesRestServlet(RestServlet):
+ """
+ API for bulk deletion of devices. Accepts a JSON object with a devices
+ key which lists the device_ids to delete.
+ """
+
+ PATTERNS = (
+ re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/delete_devices$"),
+ )
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.device_handler = hs.get_device_handler()
+ self.store = hs.get_datastore()
+
+ async def on_POST(self, request, user_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ target_user = UserID.from_string(user_id)
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only lookup local users")
+
+ u = await self.store.get_user_by_id(target_user.to_string())
+ if u is None:
+ raise NotFoundError("Unknown user")
+
+ body = parse_json_object_from_request(request, allow_empty_body=False)
+ assert_params_in_dict(body, ["devices"])
+
+ await self.device_handler.delete_devices(
+ target_user.to_string(), body["devices"]
+ )
+ return 200, {}
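
The new servlets expose per-user device management under /_synapse/admin/v2/users/<user_id>/..., authenticated with a server admin's access token. The sketch below shows how they might be called; the homeserver URL and token are placeholders, and the PUT body is assumed to mirror the client-server device update (a display_name field):

    # Hedged usage sketch for the admin device endpoints registered above. The
    # base URL and access token are placeholders; the paths come from the
    # servlet PATTERNS in this file.
    import requests

    BASE = "https://homeserver.example.com"  # hypothetical homeserver
    HEADERS = {"Authorization": "Bearer <admin_access_token>"}  # admin's token
    USER = "@alice:example.com"

    # List a user's devices.
    r = requests.get(f"{BASE}/_synapse/admin/v2/users/{USER}/devices", headers=HEADERS)
    print(r.json())  # e.g. {"devices": [...]}

    # Rename one device (body assumed to match the client-server device API).
    requests.put(
        f"{BASE}/_synapse/admin/v2/users/{USER}/devices/ABCDEFGHIJ",
        headers=HEADERS,
        json={"display_name": "Alice's laptop"},
    )

    # Bulk-delete devices.
    requests.post(
        f"{BASE}/_synapse/admin/v2/users/{USER}/delete_devices",
        headers=HEADERS,
        json={"devices": ["ABCDEFGHIJ", "KLMNOPQRST"]},
    )
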
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index d1bdb641..8173baef 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -26,6 +26,7 @@ from synapse.http.servlet import (
)
from synapse.rest.admin._base import (
admin_patterns,
+ assert_requester_is_admin,
assert_user_is_admin,
historical_admin_path_patterns,
)
@@ -58,6 +59,7 @@ class ShutdownRoomRestServlet(RestServlet):
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
+ self._replication = hs.get_replication_data_handler()
async def on_POST(self, request, room_id):
requester = await self.auth.get_user_by_req(request)
@@ -72,7 +74,7 @@ class ShutdownRoomRestServlet(RestServlet):
message = content.get("message", self.DEFAULT_MESSAGE)
room_name = content.get("room_name", "Content Violation Notification")
- info = await self._room_creation_handler.create_room(
+ info, stream_id = await self._room_creation_handler.create_room(
room_creator_requester,
config={
"preset": "public_chat",
@@ -93,6 +95,15 @@ class ShutdownRoomRestServlet(RestServlet):
# desirable in case the first attempt at blocking the room failed below.
await self.store.block_room(room_id, requester_user_id)
+ # We now wait for the create room to come back in via replication so
+ # that we can assume that all the joins/invites have propogated before
+ # we try and auto join below.
+ #
+ # TODO: Currently the events stream is written to from master
+ await self._replication.wait_for_stream_position(
+ self.hs.config.worker.writers.events, "events", stream_id
+ )
+
users = await self.state.get_current_users_in_room(room_id)
kicked_users = []
failed_to_kick_users = []
@@ -104,7 +115,7 @@ class ShutdownRoomRestServlet(RestServlet):
try:
target_requester = create_requester(user_id)
- await self.room_member_handler.update_membership(
+ _, stream_id = await self.room_member_handler.update_membership(
requester=target_requester,
target=target_requester.user,
room_id=room_id,
@@ -114,6 +125,11 @@ class ShutdownRoomRestServlet(RestServlet):
require_consent=False,
)
+ # Wait for leave to come in over replication before trying to forget.
+ await self._replication.wait_for_stream_position(
+ self.hs.config.worker.writers.events, "events", stream_id
+ )
+
await self.room_member_handler.forget(target_requester.user, room_id)
await self.room_member_handler.update_membership(
@@ -169,7 +185,7 @@ class ListRoomRestServlet(RestServlet):
in a dictionary containing room information. Supports pagination.
"""
- PATTERNS = admin_patterns("/rooms")
+ PATTERNS = admin_patterns("/rooms$")
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -253,6 +269,29 @@ class ListRoomRestServlet(RestServlet):
return 200, response
+class RoomRestServlet(RestServlet):
+ """Get room details.
+
+ TODO: Add on_POST to allow room creation without joining the room
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request, room_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ ret = await self.store.get_room_with_stats(room_id)
+ if not ret:
+ raise NotFoundError("Room not found")
+
+ return 200, ret
+
+
class JoinRoomAliasServlet(RestServlet):
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
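
The new RoomRestServlet answers GET requests for a single room's details via get_room_with_stats. A minimal sketch of calling it, assuming admin_patterns mounts it under /_synapse/admin/v1 and using placeholder host, token and room ID:

    import requests

    # Placeholder host, token and room ID; the /_synapse/admin/v1 prefix is assumed.
    resp = requests.get(
        "http://localhost:8008/_synapse/admin/v1/rooms/!636q39766251:example.com",
        headers={"Authorization": "Bearer ADMIN_ACCESS_TOKEN"},
    )
    if resp.status_code == 404:
        print("Room not found")   # NotFoundError raised by the servlet
    else:
        print(resp.json())        # room details returned by get_room_with_stats
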
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 326682fb..fefc8f71 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -142,6 +142,7 @@ class UserRestServletV2(RestServlet):
self.set_password_handler = hs.get_set_password_handler()
self.deactivate_account_handler = hs.get_deactivate_account_handler()
self.registration_handler = hs.get_registration_handler()
+ self.pusher_pool = hs.get_pusherpool()
async def on_GET(self, request, user_id):
await assert_requester_is_admin(self.auth, request)
@@ -222,8 +223,14 @@ class UserRestServletV2(RestServlet):
else:
new_password = body["password"]
logout_devices = True
+
+ new_password_hash = await self.auth_handler.hash(new_password)
+
await self.set_password_handler.set_password(
- target_user.to_string(), new_password, logout_devices, requester
+ target_user.to_string(),
+ new_password_hash,
+ logout_devices,
+ requester,
)
if "deactivated" in body:
@@ -263,6 +270,7 @@ class UserRestServletV2(RestServlet):
admin=bool(admin),
default_display_name=displayname,
user_type=user_type,
+ by_admin=True,
)
if "threepids" in body:
@@ -275,6 +283,21 @@ class UserRestServletV2(RestServlet):
await self.auth_handler.add_threepid(
user_id, threepid["medium"], threepid["address"], current_time
)
+ if (
+ self.hs.config.email_enable_notifs
+ and self.hs.config.email_notif_for_new_users
+ ):
+ await self.pusher_pool.add_pusher(
+ user_id=user_id,
+ access_token=None,
+ kind="email",
+ app_id="m.email",
+ app_display_name="Email Notifications",
+ device_display_name=threepid["address"],
+ pushkey=threepid["address"],
+ lang=None, # We don't know a user's language here
+ data={},
+ )
if "avatar_url" in body and type(body["avatar_url"]) == str:
await self.profile_handler.set_avatar_url(
@@ -410,6 +433,7 @@ class UserRegisterServlet(RestServlet):
password_hash=password_hash,
admin=bool(admin),
user_type=user_type,
+ by_admin=True,
)
result = await register._create_registration_details(user_id, body)
@@ -523,6 +547,7 @@ class ResetPasswordRestServlet(RestServlet):
self.store = hs.get_datastore()
self.hs = hs
self.auth = hs.get_auth()
+ self.auth_handler = hs.get_auth_handler()
self._set_password_handler = hs.get_set_password_handler()
async def on_POST(self, request, target_user_id):
@@ -539,8 +564,10 @@ class ResetPasswordRestServlet(RestServlet):
new_password = params["new_password"]
logout_devices = params.get("logout_devices", True)
+ new_password_hash = await self.auth_handler.hash(new_password)
+
await self._set_password_handler.set_password(
- target_user_id, new_password, logout_devices, requester
+ target_user_id, new_password_hash, logout_devices, requester
)
return 200, {}
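
Both admin password endpoints now hash the plaintext with the auth handler before calling the set-password handler, so only the hash travels further. A minimal sketch of that ordering, with hypothetical local names standing in for the servlet attributes used above:

    # Sketch only: auth_handler, set_password_handler and requester are stand-ins
    # for the attributes used by the servlets above.
    async def reset_password(auth_handler, set_password_handler, user_id, new_password, requester):
        # Hash first; downstream code only ever sees the hash, never the
        # plaintext password.
        new_password_hash = await auth_handler.hash(new_password)
        await set_password_handler.set_password(
            user_id, new_password_hash, True, requester  # True = logout_devices
        )
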
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 4de2f97d..dceb2792 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -83,32 +83,42 @@ class LoginRestServlet(RestServlet):
self.jwt_algorithm = hs.config.jwt_algorithm
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
+ self.oidc_enabled = hs.config.oidc_enabled
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self.handlers = hs.get_handlers()
- self._clock = hs.get_clock()
self._well_known_builder = WellKnownBuilder(hs)
- self._address_ratelimiter = Ratelimiter()
- self._account_ratelimiter = Ratelimiter()
- self._failed_attempts_ratelimiter = Ratelimiter()
+ self._address_ratelimiter = Ratelimiter(
+ clock=hs.get_clock(),
+ rate_hz=self.hs.config.rc_login_address.per_second,
+ burst_count=self.hs.config.rc_login_address.burst_count,
+ )
+ self._account_ratelimiter = Ratelimiter(
+ clock=hs.get_clock(),
+ rate_hz=self.hs.config.rc_login_account.per_second,
+ burst_count=self.hs.config.rc_login_account.burst_count,
+ )
+ self._failed_attempts_ratelimiter = Ratelimiter(
+ clock=hs.get_clock(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ )
def on_GET(self, request):
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
- if self.saml2_enabled:
- flows.append({"type": LoginRestServlet.SSO_TYPE})
- flows.append({"type": LoginRestServlet.TOKEN_TYPE})
- if self.cas_enabled:
- flows.append({"type": LoginRestServlet.SSO_TYPE})
+ if self.cas_enabled:
# we advertise CAS for backwards compat, though MSC1721 renamed it
# to SSO.
flows.append({"type": LoginRestServlet.CAS_TYPE})
+ if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
+ flows.append({"type": LoginRestServlet.SSO_TYPE})
# While its valid for us to advertise this login type generally,
# synapse currently only gives out these tokens as part of the
- # CAS login flow.
+ # SSO login flow.
# Generally we don't want to advertise login flows that clients
# don't know how to implement, since they (currently) will always
# fall back to the fallback API if they don't understand one of the
@@ -125,13 +135,7 @@ class LoginRestServlet(RestServlet):
return 200, {}
async def on_POST(self, request):
- self._address_ratelimiter.ratelimit(
- request.getClientIP(),
- time_now_s=self.hs.clock.time(),
- rate_hz=self.hs.config.rc_login_address.per_second,
- burst_count=self.hs.config.rc_login_address.burst_count,
- update=True,
- )
+ self._address_ratelimiter.ratelimit(request.getClientIP())
login_submission = parse_json_object_from_request(request)
try:
@@ -199,13 +203,7 @@ class LoginRestServlet(RestServlet):
# We also apply account rate limiting using the 3PID as a key, as
# otherwise using 3PID bypasses the ratelimiting based on user ID.
- self._failed_attempts_ratelimiter.ratelimit(
- (medium, address),
- time_now_s=self._clock.time(),
- rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
- burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
- update=False,
- )
+ self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False)
# Check for login providers that support 3pid login types
(
@@ -239,13 +237,7 @@ class LoginRestServlet(RestServlet):
# If it returned None but the 3PID was bound then we won't hit
# this code path, which is fine as then the per-user ratelimit
# will kick in below.
- self._failed_attempts_ratelimiter.can_do_action(
- (medium, address),
- time_now_s=self._clock.time(),
- rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
- burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
- update=True,
- )
+ self._failed_attempts_ratelimiter.can_do_action((medium, address))
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
identifier = {"type": "m.id.user", "user": user_id}
@@ -264,11 +256,7 @@ class LoginRestServlet(RestServlet):
# Check if we've hit the failed ratelimit (but don't update it)
self._failed_attempts_ratelimiter.ratelimit(
- qualified_user_id.lower(),
- time_now_s=self._clock.time(),
- rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
- burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
- update=False,
+ qualified_user_id.lower(), update=False
)
try:
@@ -280,13 +268,7 @@ class LoginRestServlet(RestServlet):
# limiter. Using `can_do_action` avoids us raising a ratelimit
# exception and masking the LoginError. The actual ratelimiting
# should have happened above.
- self._failed_attempts_ratelimiter.can_do_action(
- qualified_user_id.lower(),
- time_now_s=self._clock.time(),
- rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
- burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
- update=True,
- )
+ self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower())
raise
result = await self._complete_login(
@@ -295,7 +277,7 @@ class LoginRestServlet(RestServlet):
return result
async def _complete_login(
- self, user_id, login_submission, callback=None, create_non_existant_users=False
+ self, user_id, login_submission, callback=None, create_non_existent_users=False
):
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
@@ -308,7 +290,7 @@ class LoginRestServlet(RestServlet):
user_id (str): ID of the user to register.
login_submission (dict): Dictionary of login information.
callback (func|None): Callback function to run after registration.
- create_non_existant_users (bool): Whether to create the user if
+ create_non_existent_users (bool): Whether to create the user if
they don't exist. Defaults to False.
Returns:
@@ -319,20 +301,15 @@ class LoginRestServlet(RestServlet):
# Before we actually log them in we check if they've already logged in
# too often. This happens here rather than before as we don't
# necessarily know the user before now.
- self._account_ratelimiter.ratelimit(
- user_id.lower(),
- time_now_s=self._clock.time(),
- rate_hz=self.hs.config.rc_login_account.per_second,
- burst_count=self.hs.config.rc_login_account.burst_count,
- update=True,
- )
+ self._account_ratelimiter.ratelimit(user_id.lower())
- if create_non_existant_users:
- user_id = await self.auth_handler.check_user_exists(user_id)
- if not user_id:
- user_id = await self.registration_handler.register_user(
+ if create_non_existent_users:
+ canonical_uid = await self.auth_handler.check_user_exists(user_id)
+ if not canonical_uid:
+ canonical_uid = await self.registration_handler.register_user(
localpart=UserID.from_string(user_id).localpart
)
+ user_id = canonical_uid
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
@@ -387,7 +364,7 @@ class LoginRestServlet(RestServlet):
user_id = UserID(user, self.hs.hostname).to_string()
result = await self._complete_login(
- user_id, login_submission, create_non_existant_users=True
+ user_id, login_submission, create_non_existent_users=True
)
return result
@@ -397,19 +374,22 @@ class BaseSSORedirectServlet(RestServlet):
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
- def on_GET(self, request: SynapseRequest):
+ async def on_GET(self, request: SynapseRequest):
args = request.args
if b"redirectUrl" not in args:
return 400, "Redirect URL not specified for SSO auth"
client_redirect_url = args[b"redirectUrl"][0]
- sso_url = self.get_sso_url(client_redirect_url)
+ sso_url = await self.get_sso_url(request, client_redirect_url)
request.redirect(sso_url)
finish_request(request)
- def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+ async def get_sso_url(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> bytes:
"""Get the URL to redirect to, to perform SSO auth
Args:
+ request: The client request to redirect.
client_redirect_url: the URL that we should redirect the
client to when everything is done
@@ -424,7 +404,9 @@ class CasRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
self._cas_handler = hs.get_cas_handler()
- def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+ async def get_sso_url(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> bytes:
return self._cas_handler.get_redirect_url(
{"redirectUrl": client_redirect_url}
).encode("ascii")
@@ -461,10 +443,28 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
self._saml_handler = hs.get_saml_handler()
- def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+ async def get_sso_url(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> bytes:
return self._saml_handler.handle_redirect_request(client_redirect_url)
+class OIDCRedirectServlet(BaseSSORedirectServlet):
+ """Implementation for /login/sso/redirect for the OIDC login flow."""
+
+ PATTERNS = client_patterns("/login/sso/redirect", v1=True)
+
+ def __init__(self, hs):
+ self._oidc_handler = hs.get_oidc_handler()
+
+ async def get_sso_url(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> bytes:
+ return await self._oidc_handler.handle_redirect_request(
+ request, client_redirect_url
+ )
+
+
def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
if hs.config.cas_enabled:
@@ -472,3 +472,5 @@ def register_servlets(hs, http_server):
CasTicketServlet(hs).register(http_server)
elif hs.config.saml2_enabled:
SAMLRedirectServlet(hs).register(http_server)
+ elif hs.config.oidc_enabled:
+ OIDCRedirectServlet(hs).register(http_server)
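
The Ratelimiter is now configured once with a clock, rate and burst count, so call sites only supply the key being limited (plus update=False to check without consuming an action). A standalone sketch of the new convention; the rate, burst and key values below are illustrative rather than taken from the configuration:

    from synapse.api.ratelimiting import Ratelimiter

    def build_login_limiter(hs):
        # Illustrative rate/burst values; the real ones come from hs.config.rc_login_*.
        return Ratelimiter(clock=hs.get_clock(), rate_hz=0.17, burst_count=3)

    def check_client(limiter, client_ip):
        # Consumes an action and raises LimitExceededError once the burst is exhausted.
        limiter.ratelimit(client_ip)
        # A read-only check, as the failed-attempts limiter does before authenticating.
        limiter.ratelimit(("email", "alice@example.com"), update=False)
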
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index 1cf3caf8..b0c30b65 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -34,10 +34,10 @@ class LogoutRestServlet(RestServlet):
return 200, {}
async def on_POST(self, request):
- requester = await self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request, allow_expired=True)
if requester.device_id is None:
- # the acccess token wasn't associated with a device.
+ # The access token wasn't associated with a device.
# Just delete the access token
access_token = self.auth.get_access_token_from_request(request)
await self._auth_handler.delete_access_token(access_token)
@@ -62,7 +62,7 @@ class LogoutAllRestServlet(RestServlet):
return 200, {}
async def on_POST(self, request):
- requester = await self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string()
# first delete all of the user's devices
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 6b5830cc..105e0cf4 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -93,7 +93,7 @@ class RoomCreateRestServlet(TransactionRestServlet):
async def on_POST(self, request):
requester = await self.auth.get_user_by_req(request)
- info = await self._room_creation_handler.create_room(
+ info, _ = await self._room_creation_handler.create_room(
requester, self.get_room_config(request)
)
@@ -202,7 +202,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
if event_type == EventTypes.Member:
membership = content.get("membership", None)
- event = await self.room_member_handler.update_membership(
+ event_id, _ = await self.room_member_handler.update_membership(
requester,
target=UserID.from_string(state_key),
room_id=room_id,
@@ -210,14 +210,18 @@ class RoomStateEventRestServlet(TransactionRestServlet):
content=content,
)
else:
- event = await self.event_creation_handler.create_and_send_nonmember_event(
+ (
+ event,
+ _,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id
)
+ event_id = event.event_id
ret = {} # type: dict
- if event:
- set_tag("event_id", event.event_id)
- ret = {"event_id": event.event_id}
+ if event_id:
+ set_tag("event_id", event_id)
+ ret = {"event_id": event_id}
return 200, ret
@@ -247,7 +251,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
if b"ts" in request.args and requester.app_service:
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
- event = await self.event_creation_handler.create_and_send_nonmember_event(
+ event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id
)
@@ -781,7 +785,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
- event = await self.event_creation_handler.create_and_send_nonmember_event(
+ event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Redaction,
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 1bd02347..d4f721b6 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -220,12 +220,27 @@ class PasswordRestServlet(RestServlet):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore()
+ self.password_policy_handler = hs.get_password_policy_handler()
self._set_password_handler = hs.get_set_password_handler()
@interactive_auth_handler
async def on_POST(self, request):
body = parse_json_object_from_request(request)
+ # we do basic sanity checks here because the auth layer will store these
+ # in sessions. Pull out the new password provided to us.
+ if "new_password" in body:
+ new_password = body.pop("new_password")
+ if not isinstance(new_password, str) or len(new_password) > 512:
+ raise SynapseError(400, "Invalid password")
+ self.password_policy_handler.validate_password(new_password)
+
+ # If the password is valid, hash it and store it back on the body.
+ # This ensures that only the hashed password is handled everywhere.
+ if "new_password_hash" in body:
+ raise SynapseError(400, "Unexpected property: new_password_hash")
+ body["new_password_hash"] = await self.auth_handler.hash(new_password)
+
# there are two possibilities here. Either the user does not have an
# access token, and needs to do a password reset; or they have one and
# need to validate their identity.
@@ -276,12 +291,12 @@ class PasswordRestServlet(RestServlet):
logger.error("Auth succeeded but no known type! %r", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
- assert_params_in_dict(params, ["new_password"])
- new_password = params["new_password"]
+ assert_params_in_dict(params, ["new_password_hash"])
+ new_password_hash = params["new_password_hash"]
logout_devices = params.get("logout_devices", True)
await self._set_password_handler.set_password(
- user_id, new_password, logout_devices, requester
+ user_id, new_password_hash, logout_devices, requester
)
return 200, {}
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 24dd3d3e..75590eba 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -131,14 +131,19 @@ class AuthRestServlet(RestServlet):
self.registration_handler = hs.get_registration_handler()
# SSO configuration.
- self._saml_enabled = hs.config.saml2_enabled
- if self._saml_enabled:
- self._saml_handler = hs.get_saml_handler()
self._cas_enabled = hs.config.cas_enabled
if self._cas_enabled:
self._cas_handler = hs.get_cas_handler()
self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url
+ self._saml_enabled = hs.config.saml2_enabled
+ if self._saml_enabled:
+ self._saml_handler = hs.get_saml_handler()
+ self._oidc_enabled = hs.config.oidc_enabled
+ if self._oidc_enabled:
+ self._oidc_handler = hs.get_oidc_handler()
+ self._cas_server_url = hs.config.cas_server_url
+ self._cas_service_url = hs.config.cas_service_url
async def on_GET(self, request, stagetype):
session = parse_string(request, "session")
@@ -172,11 +177,20 @@ class AuthRestServlet(RestServlet):
)
elif self._saml_enabled:
- client_redirect_url = ""
+ # Some SAML identity providers (e.g. Google) require a
+ # RelayState parameter on requests. It is not necessary here, so
+ # pass in a dummy redirect URL (which will never get used).
+ client_redirect_url = b"unused"
sso_redirect_url = self._saml_handler.handle_redirect_request(
client_redirect_url, session
)
+ elif self._oidc_enabled:
+ client_redirect_url = b""
+ sso_redirect_url = await self._oidc_handler.handle_redirect_request(
+ request, client_redirect_url, session
+ )
+
else:
raise SynapseError(400, "Homeserver not configured for SSO.")
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 8f41a3ed..24bb0908 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -42,7 +42,7 @@ class KeyUploadServlet(RestServlet):
"device_id": "<device_id>",
"valid_until_ts": <millisecond_timestamp>,
"algorithms": [
- "m.olm.curve25519-aes-sha256",
+ "m.olm.curve25519-aes-sha2",
]
"keys": {
"<algorithm>:<device_id>": "<key_base64>",
@@ -124,7 +124,7 @@ class KeyQueryServlet(RestServlet):
"device_id": "<device_id>", // Duplicated to be signed
"valid_until_ts": <millisecond_timestamp>,
"algorithms": [ // List of supported algorithms
- "m.olm.curve25519-aes-sha256",
+ "m.olm.curve25519-aes-sha2",
],
"keys": { // Must include a ed25519 signing key
"<algorithm>:<key_id>": "<key_base64>",
@@ -285,8 +285,8 @@ class SignaturesUploadServlet(RestServlet):
"user_id": "<user_id>",
"device_id": "<device_id>",
"algorithms": [
- "m.olm.curve25519-aes-sha256",
- "m.megolm.v1.aes-sha"
+ "m.olm.curve25519-aes-sha2",
+ "m.megolm.v1.aes-sha2"
],
"keys": {
"<algorithm>:<device_id>": "<key_base64>",
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index c26927f2..b9ffe86b 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -26,7 +26,6 @@ import synapse.types
from synapse.api.constants import LoginType
from synapse.api.errors import (
Codes,
- LimitExceededError,
SynapseError,
ThreepidValidationError,
UnrecognizedRequestError,
@@ -396,20 +395,7 @@ class RegisterRestServlet(RestServlet):
client_addr = request.getClientIP()
- time_now = self.clock.time()
-
- allowed, time_allowed = self.ratelimiter.can_do_action(
- client_addr,
- time_now_s=time_now,
- rate_hz=self.hs.config.rc_registration.per_second,
- burst_count=self.hs.config.rc_registration.burst_count,
- update=False,
- )
-
- if not allowed:
- raise LimitExceededError(
- retry_after_ms=int(1000 * (time_allowed - time_now))
- )
+ self.ratelimiter.ratelimit(client_addr, update=False)
kind = b"user"
if b"kind" in request.args:
@@ -431,8 +417,8 @@ class RegisterRestServlet(RestServlet):
raise SynapseError(400, "Invalid password")
self.password_policy_handler.validate_password(password)
- # If the password is valid, hash it and store it back on the request.
- # This ensures the hashed password is handled everywhere.
+ # If the password is valid, hash it and store it back on the body.
+ # This ensures that only the hashed password is handled everywhere.
if "password_hash" in body:
raise SynapseError(400, "Unexpected property: password_hash")
body["password_hash"] = await self.auth_handler.hash(password)
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index 63f07b63..89002ffb 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -111,7 +111,7 @@ class RelationSendServlet(RestServlet):
"sender": requester.user.to_string(),
}
- event = await self.event_creation_handler.create_and_send_nonmember_event(
+ event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict=event_dict, txn_id=txn_id
)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index d90a6a89..0d668df0 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -49,24 +49,10 @@ class VersionsRestServlet(RestServlet):
"r0.3.0",
"r0.4.0",
"r0.5.0",
+ "r0.6.0",
],
# as per MSC1497:
"unstable_features": {
- # as per MSC2190, as amended by MSC2264
- # to be removed in r0.6.0
- "m.id_access_token": True,
- # Advertise to clients that they need not include an `id_server`
- # parameter during registration or password reset, as Synapse now decides
- # itself which identity server to use (or none at all).
- #
- # This is also used by a client when they wish to bind a 3PID to their
- # account, but not bind it to an identity server, the endpoint for which
- # also requires `id_server`. If the homeserver is handling 3PID
- # verification itself, there is no need to ask the user for `id_server` to
- # be supplied.
- "m.require_identity_server": False,
- # as per MSC2290
- "m.separate_add_and_bind": True,
# Implements support for label-based filtering as described in
# MSC2326.
"org.matrix.label_based_filtering": True,
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 503f2bed..36897772 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -17,7 +17,6 @@
import logging
import os
-from six import PY3
from six.moves import urllib
from twisted.internet import defer
@@ -324,23 +323,15 @@ def get_filename_from_headers(headers):
upload_name_utf8 = upload_name_utf8[7:]
# We have a filename*= section. This MUST be ASCII, and any UTF-8
# bytes are %-quoted.
- if PY3:
- try:
- # Once it is decoded, we can then unquote the %-encoded
- # parts strictly into a unicode string.
- upload_name = urllib.parse.unquote(
- upload_name_utf8.decode("ascii"), errors="strict"
- )
- except UnicodeDecodeError:
- # Incorrect UTF-8.
- pass
- else:
- # On Python 2, we first unquote the %-encoded parts and then
- # decode it strictly using UTF-8.
- try:
- upload_name = urllib.parse.unquote(upload_name_utf8).decode("utf8")
- except UnicodeDecodeError:
- pass
+ try:
+ # Once it is decoded, we can then unquote the %-encoded
+ # parts strictly into a unicode string.
+ upload_name = urllib.parse.unquote(
+ upload_name_utf8.decode("ascii"), errors="strict"
+ )
+ except UnicodeDecodeError:
+ # Incorrect UTF-8.
+ pass
# If there isn't check for an ascii name.
if not upload_name:
diff --git a/synapse/rest/oidc/__init__.py b/synapse/rest/oidc/__init__.py
new file mode 100644
index 00000000..d958dd65
--- /dev/null
+++ b/synapse/rest/oidc/__init__.py
@@ -0,0 +1,27 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.web.resource import Resource
+
+from synapse.rest.oidc.callback_resource import OIDCCallbackResource
+
+logger = logging.getLogger(__name__)
+
+
+class OIDCResource(Resource):
+ def __init__(self, hs):
+ Resource.__init__(self)
+ self.putChild(b"callback", OIDCCallbackResource(hs))
diff --git a/synapse/rest/oidc/callback_resource.py b/synapse/rest/oidc/callback_resource.py
new file mode 100644
index 00000000..c03194f0
--- /dev/null
+++ b/synapse/rest/oidc/callback_resource.py
@@ -0,0 +1,31 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.http.server import DirectServeResource, wrap_html_request_handler
+
+logger = logging.getLogger(__name__)
+
+
+class OIDCCallbackResource(DirectServeResource):
+ isLeaf = 1
+
+ def __init__(self, hs):
+ super().__init__()
+ self._oidc_handler = hs.get_oidc_handler()
+
+ @wrap_html_request_handler
+ async def _async_render_GET(self, request):
+ return await self._oidc_handler.handle_oidc_callback(request)
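
Together, OIDCResource and OIDCCallbackResource make up the server side of the new OIDC flow: the redirect servlet sends the client to the identity provider, which sends it back to the callback. A sketch of wiring the resource into a Twisted resource tree; the /_synapse mount point is an assumption here rather than something these files define:

    from twisted.web.resource import Resource

    from synapse.rest.oidc import OIDCResource

    def build_oidc_tree(hs):
        # hs is the HomeServer instance; the /_synapse prefix is assumed.
        root = Resource()
        synapse_dir = Resource()
        root.putChild(b"_synapse", synapse_dir)
        synapse_dir.putChild(b"oidc", OIDCResource(hs))  # serves /_synapse/oidc/callback
        return root
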
diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py
index a545c13d..75e58043 100644
--- a/synapse/rest/saml2/response_resource.py
+++ b/synapse/rest/saml2/response_resource.py
@@ -13,12 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.python import failure
-from synapse.http.server import (
- DirectServeResource,
- finish_request,
- wrap_html_request_handler,
-)
+from synapse.api.errors import SynapseError
+from synapse.http.server import DirectServeResource, return_html_error
class SAML2ResponseResource(DirectServeResource):
@@ -28,20 +26,22 @@ class SAML2ResponseResource(DirectServeResource):
def __init__(self, hs):
super().__init__()
- self._error_html_content = hs.config.saml2_error_html_content
self._saml_handler = hs.get_saml_handler()
+ self._error_html_template = hs.config.saml2.saml2_error_html_template
async def _async_render_GET(self, request):
# We're not expecting any GET request on that resource if everything goes right,
# but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
# In this case, just tell the user that something went wrong and they should
# try to authenticate again.
- request.setResponseCode(400)
- request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%d" % (len(self._error_html_content),))
- request.write(self._error_html_content.encode("utf8"))
- finish_request(request)
+ f = failure.Failure(
+ SynapseError(400, "Unexpected GET request on /saml2/authn_response")
+ )
+ return_html_error(f, request, self._error_html_template)
- @wrap_html_request_handler
async def _async_render_POST(self, request):
- return await self._saml_handler.handle_saml_response(request)
+ try:
+ await self._saml_handler.handle_saml_response(request)
+ except Exception:
+ f = failure.Failure()
+ return_html_error(f, request, self._error_html_template)
diff --git a/synapse/server.py b/synapse/server.py
index bf97a16c..fe94836a 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -90,6 +90,7 @@ from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.resource import ReplicationStreamer
+from synapse.replication.tcp.streams import STREAMS_MAP
from synapse.rest.media.v1.media_repository import (
MediaRepository,
MediaRepositoryResource,
@@ -204,11 +205,13 @@ class HomeServer(object):
"account_validity_handler",
"cas_handler",
"saml_handler",
+ "oidc_handler",
"event_client_serializer",
"password_policy_handler",
"storage",
"replication_streamer",
"replication_data_handler",
+ "replication_streams",
]
REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
@@ -239,9 +242,12 @@ class HomeServer(object):
self.clock = Clock(reactor)
self.distributor = Distributor()
- self.ratelimiter = Ratelimiter()
- self.admin_redaction_ratelimiter = Ratelimiter()
- self.registration_ratelimiter = Ratelimiter()
+
+ self.registration_ratelimiter = Ratelimiter(
+ clock=self.clock,
+ rate_hz=config.rc_registration.per_second,
+ burst_count=config.rc_registration.burst_count,
+ )
self.datastores = None
@@ -311,15 +317,9 @@ class HomeServer(object):
def get_distributor(self):
return self.distributor
- def get_ratelimiter(self):
- return self.ratelimiter
-
- def get_registration_ratelimiter(self):
+ def get_registration_ratelimiter(self) -> Ratelimiter:
return self.registration_ratelimiter
- def get_admin_redaction_ratelimiter(self):
- return self.admin_redaction_ratelimiter
-
def build_federation_client(self):
return FederationClient(self)
@@ -562,6 +562,11 @@ class HomeServer(object):
return SamlHandler(self)
+ def build_oidc_handler(self):
+ from synapse.handlers.oidc_handler import OidcHandler
+
+ return OidcHandler(self)
+
def build_event_client_serializer(self):
return EventClientSerializer(self)
@@ -575,7 +580,10 @@ class HomeServer(object):
return ReplicationStreamer(self)
def build_replication_data_handler(self):
- return ReplicationDataHandler(self.get_datastore())
+ return ReplicationDataHandler(self)
+
+ def build_replication_streams(self):
+ return {stream.NAME: stream(self) for stream in STREAMS_MAP.values()}
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 18043a25..fe8024d2 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -1,3 +1,5 @@
+from typing import Dict
+
import twisted.internet
import synapse.api.auth
@@ -13,11 +15,13 @@ import synapse.handlers.device
import synapse.handlers.e2e_keys
import synapse.handlers.message
import synapse.handlers.presence
+import synapse.handlers.register
import synapse.handlers.room
import synapse.handlers.room_member
import synapse.handlers.set_password
import synapse.http.client
import synapse.notifier
+import synapse.push.pusherpool
import synapse.replication.tcp.client
import synapse.replication.tcp.handler
import synapse.rest.media.v1.media_repository
@@ -26,6 +30,7 @@ import synapse.server_notices.server_notices_sender
import synapse.state
import synapse.storage
from synapse.events.builder import EventBuilderFactory
+from synapse.replication.tcp.streams import Stream
class HomeServer(object):
@property
@@ -128,3 +133,11 @@ class HomeServer(object):
pass
def get_storage(self) -> synapse.storage.Storage:
pass
+ def get_registration_handler(self) -> synapse.handlers.register.RegistrationHandler:
+ pass
+ def get_macaroon_generator(self) -> synapse.handlers.auth.MacaroonGenerator:
+ pass
+ def get_pusherpool(self) -> synapse.push.pusherpool.PusherPool:
+ pass
+ def get_replication_streams(self) -> Dict[str, Stream]:
+ pass
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index d9716635..73f2cedb 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -48,6 +48,12 @@ class ResourceLimitsServerNotices(object):
self._notifier = hs.get_notifier()
+ self._enabled = (
+ hs.config.limit_usage_by_mau
+ and self._server_notices_manager.is_enabled()
+ and not hs.config.hs_disabled
+ )
+
async def maybe_send_server_notice_to_user(self, user_id):
"""Check if we need to send a notice to this user, this will be true in
two cases.
@@ -61,14 +67,7 @@ class ResourceLimitsServerNotices(object):
Returns:
Deferred
"""
- if self._config.hs_disabled is True:
- return
-
- if self._config.limit_usage_by_mau is False:
- return
-
- if not self._server_notices_manager.is_enabled():
- # Don't try and send server notices unless they've been enabled
+ if not self._enabled:
return
timestamp = await self._store.user_last_seen_monthly_active(user_id)
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index 999c621b..bf2454c0 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -83,10 +83,10 @@ class ServerNoticesManager(object):
if state_key is not None:
event_dict["state_key"] = state_key
- res = await self._event_creation_handler.create_and_send_nonmember_event(
+ event, _ = await self._event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, ratelimit=False
)
- return res
+ return event
@cached()
async def get_or_create_notice_room_for_user(self, user_id):
@@ -143,7 +143,7 @@ class ServerNoticesManager(object):
}
requester = create_requester(self.server_notices_mxid)
- info = await self._room_creation_handler.create_room(
+ info, _ = await self._room_creation_handler.create_room(
requester,
config={
"preset": RoomCreationPreset.PRIVATE_CHAT,
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 4afefc6b..2fa529fc 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -35,7 +35,6 @@ from synapse.state import v1, v2
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.types import StateMap
from synapse.util.async_helpers import Linearizer
-from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure, measure_func
@@ -53,7 +52,6 @@ state_groups_histogram = Histogram(
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
-SIZE_OF_CACHE = 100000 * get_cache_factor_for("state_cache")
EVICTION_TIMEOUT_SECONDS = 60 * 60
@@ -447,7 +445,7 @@ class StateResolutionHandler(object):
self._state_cache = ExpiringCache(
cache_name="state_cache",
clock=self.clock,
- max_len=SIZE_OF_CACHE,
+ max_len=100000,
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
iterable=True,
reset_expiry_on_get=True,
diff --git a/synapse/static/client/login/js/login.js b/synapse/static/client/login/js/login.js
index 611df36d..ba8048b2 100644
--- a/synapse/static/client/login/js/login.js
+++ b/synapse/static/client/login/js/login.js
@@ -4,30 +4,41 @@ window.matrixLogin = {
serverAcceptsSso: false,
};
-var title_pre_auth = "Log in with one of the following methods";
-var title_post_auth = "Logging in...";
-
-var submitPassword = function(user, pwd) {
- console.log("Logging in with password...");
- set_title(title_post_auth);
- var data = {
- type: "m.login.password",
- user: user,
- password: pwd,
- };
- $.post(matrixLogin.endpoint, JSON.stringify(data), function(response) {
- matrixLogin.onLogin(response);
- }).fail(errorFunc);
-};
+// Titles get updated through the process to give users feedback.
+var TITLE_PRE_AUTH = "Log in with one of the following methods";
+var TITLE_POST_AUTH = "Logging in...";
+
+// The cookie used to store the original query parameters when using SSO.
+var COOKIE_KEY = "synapse_login_fallback_qs";
+
+/*
+ * Submit a login request.
+ *
+ * type: The login type as a string (e.g. "m.login.foo").
+ * data: An object of data specific to the login type.
+ * extra: (Optional) An object to search for extra information to send with the
+ * login request, e.g. device_id.
+ * callback: (Optional) Function to call on successful login.
+ */
+var submitLogin = function(type, data, extra, callback) {
+ console.log("Logging in with " + type);
+ set_title(TITLE_POST_AUTH);
+
+ // Add the login type.
+ data.type = type;
+
+ // Add the device information, if it was provided.
+ if (extra.device_id) {
+ data.device_id = extra.device_id;
+ }
+ if (extra.initial_device_display_name) {
+ data.initial_device_display_name = extra.initial_device_display_name;
+ }
-var submitToken = function(loginToken) {
- console.log("Logging in with login token...");
- set_title(title_post_auth);
- var data = {
- type: "m.login.token",
- token: loginToken
- };
$.post(matrixLogin.endpoint, JSON.stringify(data), function(response) {
+ if (callback) {
+ callback();
+ }
matrixLogin.onLogin(response);
}).fail(errorFunc);
};
@@ -50,12 +61,19 @@ var setFeedbackString = function(text) {
};
var show_login = function(inhibit_redirect) {
+ // Set the redirect to come back to this page, a login token will get added
+ // and handled after the redirect.
var this_page = window.location.origin + window.location.pathname;
$("#sso_redirect_url").val(this_page);
- // If inhibit_redirect is false, and SSO is the only supported login method, we can
- // redirect straight to the SSO page
+ // If inhibit_redirect is false, and SSO is the only supported login method,
+ // we can redirect straight to the SSO page.
if (matrixLogin.serverAcceptsSso) {
+ // Before submitting SSO, set the current query parameters into a cookie
+ // for retrieval later.
+ var qs = parseQsFromUrl();
+ setCookie(COOKIE_KEY, JSON.stringify(qs));
+
if (!inhibit_redirect && !matrixLogin.serverAcceptsPassword) {
$("#sso_form").submit();
return;
@@ -73,7 +91,7 @@ var show_login = function(inhibit_redirect) {
$("#no_login_types").show();
}
- set_title(title_pre_auth);
+ set_title(TITLE_PRE_AUTH);
$("#loading").hide();
};
@@ -123,7 +141,10 @@ matrixLogin.password_login = function() {
setFeedbackString("");
show_spinner();
- submitPassword(user, pwd);
+ submitLogin(
+ "m.login.password",
+ {user: user, password: pwd},
+ parseQsFromUrl());
};
matrixLogin.onLogin = function(response) {
@@ -131,7 +152,16 @@ matrixLogin.onLogin = function(response) {
console.warn("onLogin - This function should be replaced to proceed.");
};
-var parseQsFromUrl = function(query) {
+/*
+ * Process the query parameters from the current URL into an object.
+ */
+var parseQsFromUrl = function() {
+ var pos = window.location.href.indexOf("?");
+ if (pos == -1) {
+ return {};
+ }
+ var query = window.location.href.substr(pos + 1);
+
var result = {};
query.split("&").forEach(function(part) {
var item = part.split("=");
@@ -141,25 +171,80 @@ var parseQsFromUrl = function(query) {
if (val) {
val = decodeURIComponent(val);
}
- result[key] = val
+ result[key] = val;
});
return result;
};
+/*
+ * Process the cookies and return an object.
+ */
+var parseCookies = function() {
+ var allCookies = document.cookie;
+ var result = {};
+ allCookies.split(";").forEach(function(part) {
+ var item = part.split("=");
+ // Cookies might have arbitrary whitespace between them.
+ var key = item[0].trim();
+ // You can end up with a broken cookie that doesn't have an equals sign
+ // in it. Set to an empty value.
+ var val = (item[1] || "").trim();
+ // Values might be URI encoded.
+ if (val) {
+ val = decodeURIComponent(val);
+ }
+ result[key] = val;
+ });
+ return result;
+};
+
+/*
+ * Set a cookie that is valid for 1 hour.
+ */
+var setCookie = function(key, value) {
+ // The maximum age is set in seconds.
+ var maxAge = 60 * 60;
+ // Set the cookie, this defaults to the current domain and path.
+ document.cookie = key + "=" + encodeURIComponent(value) + ";max-age=" + maxAge + ";sameSite=lax";
+};
+
+/*
+ * Removes a cookie by key.
+ */
+var deleteCookie = function(key) {
+ // Delete a cookie by setting the expiration to 0. (Note that the value
+ // doesn't matter.)
+ document.cookie = key + "=deleted;expires=0";
+};
+
+/*
+ * Submits the login token if one is found in the query parameters. Returns a
+ * boolean of whether the login token was found or not.
+ */
var try_token = function() {
- var pos = window.location.href.indexOf("?");
- if (pos == -1) {
- return false;
- }
- var qs = parseQsFromUrl(window.location.href.substr(pos+1));
+ // Check if the login token is in the query parameters.
+ var qs = parseQsFromUrl();
var loginToken = qs.loginToken;
-
if (!loginToken) {
return false;
}
- submitToken(loginToken);
+ // Retrieve the original query parameters (from before the SSO redirect).
+ // They are stored as JSON in a cookie.
+ var cookies = parseCookies();
+ var original_query_params = JSON.parse(cookies[COOKIE_KEY] || "{}")
+
+ // If the login is successful, delete the cookie.
+ var callback = function() {
+ deleteCookie(COOKIE_KEY);
+ }
+
+ submitLogin(
+ "m.login.token",
+ {token: loginToken},
+ original_query_params,
+ callback);
return true;
};
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 13de5f1f..bfce541c 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -19,9 +19,6 @@ import random
from abc import ABCMeta
from typing import Any, Optional
-from six import PY2
-from six.moves import builtins
-
from canonicaljson import json
from synapse.storage.database import LoggingTransaction # noqa: F401
@@ -47,6 +44,9 @@ class SQLBaseStore(metaclass=ABCMeta):
self.db = database
self.rand = random.SystemRandom()
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
+ pass
+
def _invalidate_state_caches(self, room_id, members_changed):
"""Invalidates caches that are based on the current state, but does
not stream invalidations down replication.
@@ -100,11 +100,6 @@ def db_to_json(db_content):
if isinstance(db_content, memoryview):
db_content = db_content.tobytes()
- # psycopg2 on Python 2 returns buffer objects, which we need to cast to
- # bytes to decode
- if PY2 and isinstance(db_content, builtins.buffer):
- db_content = bytes(db_content)
-
# Decode it to a Unicode string before feeding it to json.loads, so we
# consistenty get a Unicode-containing object out.
if isinstance(db_content, (bytes, bytearray)):
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
index e1d03429..599ee470 100644
--- a/synapse/storage/data_stores/__init__.py
+++ b/synapse/storage/data_stores/__init__.py
@@ -15,6 +15,7 @@
import logging
+from synapse.storage.data_stores.main.events import PersistEventsStore
from synapse.storage.data_stores.state import StateGroupDataStore
from synapse.storage.database import Database, make_conn
from synapse.storage.engines import create_engine
@@ -39,6 +40,7 @@ class DataStores(object):
self.databases = []
self.main = None
self.state = None
+ self.persist_events = None
for database_config in hs.config.database.databases:
db_name = database_config.name
@@ -64,6 +66,13 @@ class DataStores(object):
self.main = main_store_class(database, db_conn, hs)
+ # If we're on a process that can persist events also
+ # instantiate a `PersistEventsStore`
+ if hs.config.worker.writers.events == hs.get_instance_name():
+ self.persist_events = PersistEventsStore(
+ hs, database, self.main
+ )
+
if "state" in database_config.data_stores:
logger.info("Starting 'state' data store")
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index ceba1088..4b4763c7 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -24,15 +24,16 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
- ChainedIdGenerator,
IdGenerator,
+ MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.util.caches.stream_change_cache import StreamChangeCache
from .account_data import AccountDataStore
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
-from .cache import CacheInvalidationStore
+from .cache import CacheInvalidationWorkerStore
+from .censor_events import CensorEventsStore
from .client_ips import ClientIpStore
from .deviceinbox import DeviceInboxStore
from .devices import DeviceStore
@@ -41,16 +42,17 @@ from .e2e_room_keys import EndToEndRoomKeyStore
from .end_to_end_keys import EndToEndKeyStore
from .event_federation import EventFederationStore
from .event_push_actions import EventPushActionsStore
-from .events import EventsStore
from .events_bg_updates import EventsBackgroundUpdatesStore
from .filtering import FilteringStore
from .group_server import GroupServerStore
from .keys import KeyStore
from .media_repository import MediaRepositoryStore
+from .metrics import ServerMetricsStore
from .monthly_active_users import MonthlyActiveUsersStore
from .openid import OpenIdStore
from .presence import PresenceStore, UserPresenceState
from .profile import ProfileStore
+from .purge_events import PurgeEventsStore
from .push_rule import PushRuleStore
from .pusher import PusherStore
from .receipts import ReceiptsStore
@@ -87,7 +89,7 @@ class DataStore(
StateStore,
SignatureStore,
ApplicationServiceStore,
- EventsStore,
+ PurgeEventsStore,
EventFederationStore,
MediaRepositoryStore,
RejectionsStore,
@@ -112,27 +114,16 @@ class DataStore(
MonthlyActiveUsersStore,
StatsStore,
RelationsStore,
- CacheInvalidationStore,
+ CensorEventsStore,
UIAuthStore,
+ CacheInvalidationWorkerStore,
+ ServerMetricsStore,
):
def __init__(self, database: Database, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine
- self._stream_id_gen = StreamIdGenerator(
- db_conn,
- "events",
- "stream_ordering",
- extra_tables=[("local_invites", "stream_id")],
- )
- self._backfill_id_gen = StreamIdGenerator(
- db_conn,
- "events",
- "stream_ordering",
- step=-1,
- extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
- )
self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id"
)
@@ -159,9 +150,6 @@ class DataStore(
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
- self._push_rules_stream_id_gen = ChainedIdGenerator(
- self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
- )
self._pushers_id_gen = StreamIdGenerator(
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
@@ -170,8 +158,14 @@ class DataStore(
)
if isinstance(self.database_engine, PostgresEngine):
- self._cache_id_gen = StreamIdGenerator(
- db_conn, "cache_invalidation_stream", "stream_id"
+ self._cache_id_gen = MultiWriterIdGenerator(
+ db_conn,
+ database,
+ instance_name="master",
+ table="cache_invalidation_stream_by_instance",
+ instance_column="instance_name",
+ id_column="stream_id",
+ sequence_name="cache_invalidation_stream_seq",
)
else:
self._cache_id_gen = None
diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py
index 46b494b3..b58f04d0 100644
--- a/synapse/storage/data_stores/main/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -16,6 +16,7 @@
import abc
import logging
+from typing import List, Tuple
from canonicaljson import json
@@ -175,41 +176,64 @@ class AccountDataWorkerStore(SQLBaseStore):
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
)
- def get_all_updated_account_data(
- self, last_global_id, last_room_id, current_id, limit
- ):
- """Get all the client account_data that has changed on the server
+ async def get_updated_global_account_data(
+ self, last_id: int, current_id: int, limit: int
+ ) -> List[Tuple[int, str, str]]:
+ """Get the global account_data that has changed, for the account_data stream
+
Args:
- last_global_id(int): The position to fetch from for top level data
- last_room_id(int): The position to fetch from for per room data
- current_id(int): The position to fetch up to.
+ last_id: the last stream_id from the previous batch.
+ current_id: the maximum stream_id to return up to
+ limit: the maximum number of rows to return
+
Returns:
- A deferred pair of lists of tuples of stream_id int, user_id string,
- room_id string, and type string.
+ A list of tuples of stream_id int, user_id string,
+ and type string.
"""
- if last_room_id == current_id and last_global_id == current_id:
- return defer.succeed(([], []))
+ if last_id == current_id:
+ return []
- def get_updated_account_data_txn(txn):
+ def get_updated_global_account_data_txn(txn):
sql = (
"SELECT stream_id, user_id, account_data_type"
" FROM account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
- txn.execute(sql, (last_global_id, current_id, limit))
- global_results = txn.fetchall()
+ txn.execute(sql, (last_id, current_id, limit))
+ return txn.fetchall()
+
+ return await self.db.runInteraction(
+ "get_updated_global_account_data", get_updated_global_account_data_txn
+ )
+
+ async def get_updated_room_account_data(
+ self, last_id: int, current_id: int, limit: int
+ ) -> List[Tuple[int, str, str, str]]:
+ """Get the global account_data that has changed, for the account_data stream
+
+ Args:
+ last_id: the last stream_id from the previous batch.
+ current_id: the maximum stream_id to return up to
+ limit: the maximum number of rows to return
+
+ Returns:
+ A list of tuples of stream_id int, user_id string,
+ room_id string and type string.
+ """
+ if last_id == current_id:
+ return []
+ def get_updated_room_account_data_txn(txn):
sql = (
"SELECT stream_id, user_id, room_id, account_data_type"
" FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
- txn.execute(sql, (last_room_id, current_id, limit))
- room_results = txn.fetchall()
- return global_results, room_results
+ txn.execute(sql, (last_id, current_id, limit))
+ return txn.fetchall()
- return self.db.runInteraction(
- "get_all_updated_account_data_txn", get_updated_account_data_txn
+ return await self.db.runInteraction(
+ "get_updated_room_account_data", get_updated_room_account_data_txn
)
def get_updated_account_data_for_user(self, user_id, stream_id):
@@ -273,7 +297,13 @@ class AccountDataWorkerStore(SQLBaseStore):
class AccountDataStore(AccountDataWorkerStore):
def __init__(self, database: Database, db_conn, hs):
self._account_data_id_gen = StreamIdGenerator(
- db_conn, "account_data_max_stream_id", "stream_id"
+ db_conn,
+ "account_data_max_stream_id",
+ "stream_id",
+ extra_tables=[
+ ("room_account_data", "stream_id"),
+ ("room_tags_revisions", "stream_id"),
+ ],
)
super(AccountDataStore, self).__init__(database, db_conn, hs)
@@ -363,6 +393,10 @@ class AccountDataStore(AccountDataWorkerStore):
# doesn't sound any worse than the whole update getting lost,
# which is what would happen if we combined the two into one
# transaction.
+ #
+ # Note: This is only here for backwards compat to allow admins to
+ # roll back to a previous Synapse version. Next time we update the
+ # database version we can remove this table.
yield self._update_max_stream_id(next_id)
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
@@ -381,6 +415,10 @@ class AccountDataStore(AccountDataWorkerStore):
next_id(int): The the revision to advance to.
"""
+ # Note: This is only here for backwards compat to allow admins to
+ # roll back to a previous Synapse version. Next time we update the
+ # database version we can remove this table.
+
def _update(txn):
update_max_id_sql = (
"UPDATE account_data_max_stream_id"
diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py
index efbc06c7..7a1fe8cd 100644
--- a/synapse/storage/data_stores/main/appservice.py
+++ b/synapse/storage/data_stores/main/appservice.py
@@ -30,12 +30,12 @@ logger = logging.getLogger(__name__)
def _make_exclusive_regex(services_cache):
- # We precompie a regex constructed from all the regexes that the AS's
+ # We precompile a regex constructed from all the regexes that the AS's
# have registered for exclusive users.
exclusive_user_regexes = [
regex.pattern
for service in services_cache
- for regex in service.get_exlusive_user_regexes()
+ for regex in service.get_exclusive_user_regexes()
]
if exclusive_user_regexes:
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index 4dc5da3f..eac5a4e5 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -18,9 +18,13 @@ import itertools
import logging
from typing import Any, Iterable, Optional, Tuple
-from twisted.internet import defer
-
+from synapse.api.constants import EventTypes
+from synapse.replication.tcp.streams.events import (
+ EventsStreamCurrentStateRow,
+ EventsStreamEventRow,
+)
from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter
@@ -33,28 +37,132 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
class CacheInvalidationWorkerStore(SQLBaseStore):
- def get_all_updated_caches(self, last_id, current_id, limit):
+ def __init__(self, database: Database, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ self._instance_name = hs.get_instance_name()
+
+ async def get_all_updated_caches(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ):
+ """Fetches cache invalidation rows between the two given IDs written
+ by the given instance. Returns at most `limit` rows.
+ """
+
if last_id == current_id:
- return defer.succeed([])
+ return []
def get_all_updated_caches_txn(txn):
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
- sql = (
- "SELECT stream_id, cache_func, keys, invalidation_ts"
- " FROM cache_invalidation_stream"
- " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
- )
- txn.execute(sql, (last_id, limit))
+ sql = """
+ SELECT stream_id, cache_func, keys, invalidation_ts
+ FROM cache_invalidation_stream_by_instance
+ WHERE stream_id > ? AND instance_name = ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+ txn.execute(sql, (last_id, instance_name, limit))
return txn.fetchall()
- return self.db.runInteraction(
+ return await self.db.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn
)
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
+ if stream_name == "events":
+ for row in rows:
+ self._process_event_stream_row(token, row)
+ elif stream_name == "backfill":
+ for row in rows:
+ self._invalidate_caches_for_event(
+ -token,
+ row.event_id,
+ row.room_id,
+ row.type,
+ row.state_key,
+ row.redacts,
+ row.relates_to,
+ backfilled=True,
+ )
+ elif stream_name == "caches":
+ if self._cache_id_gen:
+ self._cache_id_gen.advance(instance_name, token)
+
+ for row in rows:
+ if row.cache_func == CURRENT_STATE_CACHE_NAME:
+ if row.keys is None:
+ raise Exception(
+ "Can't send an 'invalidate all' for current state cache"
+ )
+
+ room_id = row.keys[0]
+ members_changed = set(row.keys[1:])
+ self._invalidate_state_caches(room_id, members_changed)
+ else:
+ self._attempt_to_invalidate_cache(row.cache_func, row.keys)
+
+ super().process_replication_rows(stream_name, instance_name, token, rows)
+
+ def _process_event_stream_row(self, token, row):
+ data = row.data
+
+ if row.type == EventsStreamEventRow.TypeId:
+ self._invalidate_caches_for_event(
+ token,
+ data.event_id,
+ data.room_id,
+ data.type,
+ data.state_key,
+ data.redacts,
+ data.relates_to,
+ backfilled=False,
+ )
+ elif row.type == EventsStreamCurrentStateRow.TypeId:
+ self._curr_state_delta_stream_cache.entity_has_changed(
+ row.data.room_id, token
+ )
+
+ if data.type == EventTypes.Member:
+ self.get_rooms_for_user_with_stream_ordering.invalidate(
+ (data.state_key,)
+ )
+ else:
+ raise Exception("Unknown events stream row type %s" % (row.type,))
+
+ def _invalidate_caches_for_event(
+ self,
+ stream_ordering,
+ event_id,
+ room_id,
+ etype,
+ state_key,
+ redacts,
+ relates_to,
+ backfilled,
+ ):
+ self._invalidate_get_event_cache(event_id)
+
+ self.get_latest_event_ids_in_room.invalidate((room_id,))
+
+ self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
+
+ if not backfilled:
+ self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
+
+ if redacts:
+ self._invalidate_get_event_cache(redacts)
+
+ if etype == EventTypes.Member:
+ self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
+ self.get_invited_rooms_for_local_user.invalidate((state_key,))
+
+ if relates_to:
+ self.get_relations_for_event.invalidate_many((relates_to,))
+ self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
+ self.get_applicable_edit.invalidate((relates_to,))
-class CacheInvalidationStore(CacheInvalidationWorkerStore):
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@@ -68,7 +176,7 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
return
cache_func.invalidate(keys)
- await self.runInteraction(
+ await self.db.runInteraction(
"invalidate_cache_and_stream",
self._send_invalidation_to_replication,
cache_func.__name__,
@@ -147,10 +255,7 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
# the transaction. However, we want to only get an ID when we want
# to use it, here, so we need to call __enter__ manually, and have
# __exit__ called after the transaction finishes.
- ctx = self._cache_id_gen.get_next()
- stream_id = ctx.__enter__()
- txn.call_on_exception(ctx.__exit__, None, None, None)
- txn.call_after(ctx.__exit__, None, None, None)
+ stream_id = self._cache_id_gen.get_next_txn(txn)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
if keys is not None:
@@ -158,17 +263,18 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
self.db.simple_insert_txn(
txn,
- table="cache_invalidation_stream",
+ table="cache_invalidation_stream_by_instance",
values={
"stream_id": stream_id,
+ "instance_name": self._instance_name,
"cache_func": cache_name,
"keys": keys,
"invalidation_ts": self.clock.time_msec(),
},
)
- def get_cache_stream_token(self):
+ def get_cache_stream_token(self, instance_name):
if self._cache_id_gen:
- return self._cache_id_gen.get_current_token()
+ return self._cache_id_gen.get_current_token(instance_name)
else:
return 0
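
A minimal sketch of how a worker might apply one row from the (now per-instance) caches stream, mirroring the `process_replication_rows` branch above; the `CacheRow` tuple and the `store` handle are stand-ins for this example:

    from collections import namedtuple

    CacheRow = namedtuple("CacheRow", ["cache_func", "keys"])
    CURRENT_STATE_CACHE_NAME = "cs_cache_fake"

    def apply_cache_row(store, row):
        # Mirrors the "caches" branch of process_replication_rows above.
        if row.cache_func == CURRENT_STATE_CACHE_NAME:
            if row.keys is None:
                raise Exception("Can't send an 'invalidate all' for current state cache")
            room_id, members_changed = row.keys[0], set(row.keys[1:])
            store._invalidate_state_caches(room_id, members_changed)
        else:
            store._attempt_to_invalidate_cache(row.cache_func, row.keys)
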
diff --git a/synapse/storage/data_stores/main/censor_events.py b/synapse/storage/data_stores/main/censor_events.py
new file mode 100644
index 00000000..2d482617
--- /dev/null
+++ b/synapse/storage/data_stores/main/censor_events.py
@@ -0,0 +1,208 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+from twisted.internet import defer
+
+from synapse.events.utils import prune_event_dict
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.data_stores.main.events import encode_json
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+logger = logging.getLogger(__name__)
+
+
+class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore):
+ def __init__(self, database: Database, db_conn, hs: "HomeServer"):
+ super().__init__(database, db_conn, hs)
+
+ def _censor_redactions():
+ return run_as_background_process(
+ "_censor_redactions", self._censor_redactions
+ )
+
+ if self.hs.config.redaction_retention_period is not None:
+ hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
+
+ async def _censor_redactions(self):
+ """Censors all redactions older than the configured period that haven't
+ been censored yet.
+
+        By "censoring" we mean updating the event_json table with the redacted event.
+ """
+
+ if self.hs.config.redaction_retention_period is None:
+ return
+
+ if not (
+ await self.db.updates.has_completed_background_update(
+ "redactions_have_censored_ts_idx"
+ )
+ ):
+ # We don't want to run this until the appropriate index has been
+ # created.
+ return
+
+ before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
+
+ # We fetch all redactions that:
+ # 1. point to an event we have,
+ # 2. has a received_ts from before the cut off, and
+ # 3. we haven't yet censored.
+ #
+ # This is limited to 100 events to ensure that we don't try and do too
+ # much at once. We'll get called again so this should eventually catch
+ # up.
+ sql = """
+ SELECT redactions.event_id, redacts FROM redactions
+ LEFT JOIN events AS original_event ON (
+ redacts = original_event.event_id
+ )
+ WHERE NOT have_censored
+ AND redactions.received_ts <= ?
+ ORDER BY redactions.received_ts ASC
+ LIMIT ?
+ """
+
+ rows = await self.db.execute(
+ "_censor_redactions_fetch", None, sql, before_ts, 100
+ )
+
+ updates = []
+
+ for redaction_id, event_id in rows:
+ redaction_event = await self.get_event(redaction_id, allow_none=True)
+ original_event = await self.get_event(
+ event_id, allow_rejected=True, allow_none=True
+ )
+
+ # The SQL above ensures that we have both the redaction and
+ # original event, so if the `get_event` calls return None it
+ # means that the redaction wasn't allowed. Either way we know that
+ # the result won't change so we mark the fact that we've checked.
+ if (
+ redaction_event
+ and original_event
+ and original_event.internal_metadata.is_redacted()
+ ):
+ # Redaction was allowed
+ pruned_json = encode_json(
+ prune_event_dict(
+ original_event.room_version, original_event.get_dict()
+ )
+ )
+ else:
+ # Redaction wasn't allowed
+ pruned_json = None
+
+ updates.append((redaction_id, event_id, pruned_json))
+
+ def _update_censor_txn(txn):
+ for redaction_id, event_id, pruned_json in updates:
+ if pruned_json:
+ self._censor_event_txn(txn, event_id, pruned_json)
+
+ self.db.simple_update_one_txn(
+ txn,
+ table="redactions",
+ keyvalues={"event_id": redaction_id},
+ updatevalues={"have_censored": True},
+ )
+
+ await self.db.runInteraction("_update_censor_txn", _update_censor_txn)
+
+ def _censor_event_txn(self, txn, event_id, pruned_json):
+ """Censor an event by replacing its JSON in the event_json table with the
+ provided pruned JSON.
+
+ Args:
+ txn (LoggingTransaction): The database transaction.
+ event_id (str): The ID of the event to censor.
+ pruned_json (str): The pruned JSON
+ """
+ self.db.simple_update_one_txn(
+ txn,
+ table="event_json",
+ keyvalues={"event_id": event_id},
+ updatevalues={"json": pruned_json},
+ )
+
+ @defer.inlineCallbacks
+ def expire_event(self, event_id):
+ """Retrieve and expire an event that has expired, and delete its associated
+ expiry timestamp. If the event can't be retrieved, delete its associated
+ timestamp so we don't try to expire it again in the future.
+
+ Args:
+ event_id (str): The ID of the event to delete.
+ """
+ # Try to retrieve the event's content from the database or the event cache.
+ event = yield self.get_event(event_id)
+
+ def delete_expired_event_txn(txn):
+ # Delete the expiry timestamp associated with this event from the database.
+ self._delete_event_expiry_txn(txn, event_id)
+
+ if not event:
+ # If we can't find the event, log a warning and delete the expiry date
+ # from the database so that we don't try to expire it again in the
+ # future.
+ logger.warning(
+ "Can't expire event %s because we don't have it.", event_id
+ )
+ return
+
+ # Prune the event's dict then convert it to JSON.
+ pruned_json = encode_json(
+ prune_event_dict(event.room_version, event.get_dict())
+ )
+
+ # Update the event_json table to replace the event's JSON with the pruned
+ # JSON.
+ self._censor_event_txn(txn, event.event_id, pruned_json)
+
+ # We need to invalidate the event cache entry for this event because we
+ # changed its content in the database. We can't call
+ # self._invalidate_cache_and_stream because self.get_event_cache isn't of the
+ # right type.
+ txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
+ # Send that invalidation to replication so that other workers also invalidate
+ # the event cache.
+ self._send_invalidation_to_replication(
+ txn, "_get_event_cache", (event.event_id,)
+ )
+
+ yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn)
+
+ def _delete_event_expiry_txn(self, txn, event_id):
+ """Delete the expiry timestamp associated with an event ID without deleting the
+ actual event.
+
+ Args:
+ txn (LoggingTransaction): The transaction to use to perform the deletion.
+ event_id (str): The event ID to delete the associated expiry timestamp of.
+ """
+ return self.db.simple_delete_txn(
+ txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
+ )
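
At the SQL level, "censoring" simply replaces the stored event JSON with a pruned copy. A self-contained toy demonstration of what `_censor_event_txn` boils down to (toy schema and event IDs, not Synapse's actual schema management):

    import json
    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE event_json (event_id TEXT PRIMARY KEY, json TEXT)")
    conn.execute(
        "INSERT INTO event_json VALUES (?, ?)",
        ("$ev:example.com",
         json.dumps({"type": "m.room.message", "content": {"body": "secret"}})),
    )

    # Overwrite the stored JSON with the pruned version.
    pruned = json.dumps({"type": "m.room.message", "content": {}})
    conn.execute(
        "UPDATE event_json SET json = ? WHERE event_id = ?", (pruned, "$ev:example.com")
    )

    row = conn.execute(
        "SELECT json FROM event_json WHERE event_id = ?", ("$ev:example.com",)
    ).fetchone()
    assert json.loads(row[0])["content"] == {}
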
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index 92bc0691..71f8d43a 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -22,7 +22,6 @@ from twisted.internet import defer
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database, make_tuple_comparison_clause
-from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.descriptors import Cache
logger = logging.getLogger(__name__)
@@ -361,7 +360,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
def __init__(self, database: Database, db_conn, hs):
self.client_ip_last_seen = Cache(
- name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
+ name="client_ip_last_seen", keylen=4, max_entries=50000
)
super(ClientIpStore, self).__init__(database, db_conn, hs)
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 03f5141e..fb9f798e 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List, Tuple
+from typing import List, Optional, Set, Tuple
from six import iteritems
@@ -342,32 +342,23 @@ class DeviceWorkerStore(SQLBaseStore):
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
# We update the device_lists_outbound_last_success with the successfully
- # poked users. We do the join to see which users need to be inserted and
- # which updated.
+ # poked users.
sql = """
- SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
+ SELECT user_id, coalesce(max(o.stream_id), 0)
FROM device_lists_outbound_pokes as o
- LEFT JOIN device_lists_outbound_last_success as s
- USING (destination, user_id)
WHERE destination = ? AND o.stream_id <= ?
GROUP BY user_id
"""
txn.execute(sql, (destination, stream_id))
rows = txn.fetchall()
- sql = """
- UPDATE device_lists_outbound_last_success
- SET stream_id = ?
- WHERE destination = ? AND user_id = ?
- """
- txn.executemany(sql, ((row[1], destination, row[0]) for row in rows if row[2]))
-
- sql = """
- INSERT INTO device_lists_outbound_last_success
- (destination, user_id, stream_id) VALUES (?, ?, ?)
- """
- txn.executemany(
- sql, ((destination, row[0], row[1]) for row in rows if not row[2])
+ self.db.simple_upsert_many_txn(
+ txn=txn,
+ table="device_lists_outbound_last_success",
+ key_names=("destination", "user_id"),
+ key_values=((destination, user_id) for user_id, _ in rows),
+ value_names=("stream_id",),
+ value_values=((stream_id,) for _, stream_id in rows),
)
# Delete all sent outbound pokes
@@ -654,21 +645,31 @@ class DeviceWorkerStore(SQLBaseStore):
return results
@defer.inlineCallbacks
- def get_user_ids_requiring_device_list_resync(self, user_ids: Collection[str]):
+ def get_user_ids_requiring_device_list_resync(
+ self, user_ids: Optional[Collection[str]] = None,
+ ) -> Set[str]:
"""Given a list of remote users return the list of users that we
- should resync the device lists for.
+ should resync the device lists for. If None is given instead of a list,
+ return every user that we should resync the device lists for.
Returns:
- Deferred[Set[str]]
+ The IDs of users whose device lists need resync.
"""
-
- rows = yield self.db.simple_select_many_batch(
- table="device_lists_remote_resync",
- column="user_id",
- iterable=user_ids,
- retcols=("user_id",),
- desc="get_user_ids_requiring_device_list_resync",
- )
+ if user_ids:
+ rows = yield self.db.simple_select_many_batch(
+ table="device_lists_remote_resync",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id",),
+ desc="get_user_ids_requiring_device_list_resync_with_iterable",
+ )
+ else:
+ rows = yield self.db.simple_select_list(
+ table="device_lists_remote_resync",
+ keyvalues=None,
+ retcols=("user_id",),
+ desc="get_user_ids_requiring_device_list_resync",
+ )
return {row["user_id"] for row in rows}
@@ -684,6 +685,25 @@ class DeviceWorkerStore(SQLBaseStore):
desc="make_remote_user_device_cache_as_stale",
)
+ def mark_remote_user_device_list_as_unsubscribed(self, user_id):
+ """Mark that we no longer track device lists for remote user.
+ """
+
+ def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
+ self.db.simple_delete_txn(
+ txn,
+ table="device_lists_remote_extremeties",
+ keyvalues={"user_id": user_id},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
+ )
+
+ return self.db.runInteraction(
+ "mark_remote_user_device_list_as_unsubscribed",
+ _mark_remote_user_device_list_as_unsubscribed_txn,
+ )
+
class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs):
@@ -725,6 +745,15 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
)
+ # a pair of background updates that were added during the 1.14 release cycle,
+ # but replaced with 58/06dlols_unique_idx.py
+ self.db.updates.register_noop_background_update(
+ "device_lists_outbound_last_success_unique_idx",
+ )
+ self.db.updates.register_noop_background_update(
+ "drop_device_lists_outbound_last_success_non_unique_idx",
+ )
+
@defer.inlineCallbacks
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def f(conn):
@@ -935,17 +964,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
desc="update_device",
)
- @defer.inlineCallbacks
- def mark_remote_user_device_list_as_unsubscribed(self, user_id):
- """Mark that we no longer track device lists for remote user.
- """
- yield self.db.simple_delete(
- table="device_lists_remote_extremeties",
- keyvalues={"user_id": user_id},
- desc="mark_remote_user_device_list_as_unsubscribed",
- )
- self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
-
def update_remote_device_list_cache_entry(
self, user_id, device_id, content, stream_id
):
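
The `simple_upsert_many_txn` call above replaces the old UPDATE-then-INSERT pair. Roughly, on engines with native upsert support it collapses to something like the sketch below; the exact SQL Synapse emits may differ, and `txn` here is any DB-API cursor:

    def mark_sent(txn, destination, rows):
        # rows: [(user_id, max_stream_id), ...] as returned by the SELECT above.
        txn.executemany(
            """
            INSERT INTO device_lists_outbound_last_success (destination, user_id, stream_id)
            VALUES (?, ?, ?)
            ON CONFLICT (destination, user_id) DO UPDATE SET stream_id = EXCLUDED.stream_id
            """,
            [(destination, user_id, stream_id) for user_id, stream_id in rows],
        )
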
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index bcf746b7..20698bfd 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -25,7 +25,9 @@ from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import make_in_list_sql_clause
from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
class EndToEndKeyWorkerStore(SQLBaseStore):
@@ -268,53 +270,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
- def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None):
- """Returns a user's cross-signing key.
-
- Args:
- txn (twisted.enterprise.adbapi.Connection): db connection
- user_id (str): the user whose key is being requested
- key_type (str): the type of key that is being requested: either 'master'
- for a master key, 'self_signing' for a self-signing key, or
- 'user_signing' for a user-signing key
- from_user_id (str): if specified, signatures made by this user on
- the key will be included in the result
-
- Returns:
- dict of the key data or None if not found
- """
- sql = (
- "SELECT keydata "
- " FROM e2e_cross_signing_keys "
- " WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1"
- )
- txn.execute(sql, (user_id, key_type))
- row = txn.fetchone()
- if not row:
- return None
- key = json.loads(row[0])
-
- device_id = None
- for k in key["keys"].values():
- device_id = k
-
- if from_user_id is not None:
- sql = (
- "SELECT key_id, signature "
- " FROM e2e_cross_signing_signatures "
- " WHERE user_id = ? "
- " AND target_user_id = ? "
- " AND target_device_id = ? "
- )
- txn.execute(sql, (from_user_id, user_id, device_id))
- row = txn.fetchone()
- if row:
- key.setdefault("signatures", {}).setdefault(from_user_id, {})[
- row[0]
- ] = row[1]
-
- return key
-
+ @defer.inlineCallbacks
def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
"""Returns a user's cross-signing key.
@@ -329,13 +285,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
Returns:
dict of the key data or None if not found
"""
- return self.db.runInteraction(
- "get_e2e_cross_signing_key",
- self._get_e2e_cross_signing_key_txn,
- user_id,
- key_type,
- from_user_id,
- )
+ res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
+ user_keys = res.get(user_id)
+ if not user_keys:
+ return None
+ return user_keys.get(key_type)
@cached(num_args=1)
def _get_bare_e2e_cross_signing_keys(self, user_id):
@@ -391,26 +345,24 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"""
result = {}
- batch_size = 100
- chunks = [
- user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size)
- ]
- for user_chunk in chunks:
- sql = """
+ for user_chunk in batch_iter(user_ids, 100):
+ clause, params = make_in_list_sql_clause(
+ txn.database_engine, "k.user_id", user_chunk
+ )
+ sql = (
+ """
SELECT k.user_id, k.keytype, k.keydata, k.stream_id
FROM e2e_cross_signing_keys k
INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
FROM e2e_cross_signing_keys
GROUP BY user_id, keytype) s
USING (user_id, stream_id, keytype)
- WHERE k.user_id IN (%s)
- """ % (
- ",".join("?" for u in user_chunk),
+ WHERE
+ """
+ + clause
)
- query_params = []
- query_params.extend(user_chunk)
- txn.execute(sql, query_params)
+ txn.execute(sql, params)
rows = self.db.cursor_to_dict(txn)
for row in rows:
@@ -453,15 +405,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
device_id = k
devices[(user_id, device_id)] = key_type
- device_list = list(devices)
-
- # split into batches
- batch_size = 100
- chunks = [
- device_list[i : i + batch_size]
- for i in range(0, len(device_list), batch_size)
- ]
- for user_chunk in chunks:
+ for batch in batch_iter(devices.keys(), size=100):
sql = """
SELECT target_user_id, target_device_id, key_id, signature
FROM e2e_cross_signing_signatures
@@ -469,11 +413,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
AND (%s)
""" % (
" OR ".join(
- "(target_user_id = ? AND target_device_id = ?)" for d in devices
+ "(target_user_id = ? AND target_device_id = ?)" for _ in batch
)
)
query_params = [from_user_id]
- for item in devices:
+ for item in batch:
# item is a (user_id, device_id) tuple
query_params.extend(item)
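
The batching pattern above swaps hand-rolled chunking for `batch_iter` plus `make_in_list_sql_clause`. Simplified stand-ins (not the real helpers, which live in synapse.util.iterutils and synapse.storage.database and may differ in detail) to show the shape of the pattern:

    from itertools import islice

    def batch_iter(iterable, size):
        # Yield tuples of at most `size` items from `iterable`.
        it = iter(iterable)
        while True:
            chunk = tuple(islice(it, size))
            if not chunk:
                return
            yield chunk

    def make_in_list_clause(column, values):
        values = list(values)
        return "%s IN (%s)" % (column, ",".join("?" for _ in values)), values

    for chunk in batch_iter(["@a:hs", "@b:hs", "@c:hs"], 2):
        clause, params = make_in_list_clause("k.user_id", chunk)
        # e.g. clause == "k.user_id IN (?,?)", params == ["@a:hs", "@b:hs"]
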
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index b99439cc..24ce8c43 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -640,89 +640,6 @@ class EventFederationStore(EventFederationWorkerStore):
self._delete_old_forward_extrem_cache, 60 * 60 * 1000
)
- def _update_min_depth_for_room_txn(self, txn, room_id, depth):
- min_depth = self._get_min_depth_interaction(txn, room_id)
-
- if min_depth is not None and depth >= min_depth:
- return
-
- self.db.simple_upsert_txn(
- txn,
- table="room_depth",
- keyvalues={"room_id": room_id},
- values={"min_depth": depth},
- )
-
- def _handle_mult_prev_events(self, txn, events):
- """
- For the given event, update the event edges table and forward and
- backward extremities tables.
- """
- self.db.simple_insert_many_txn(
- txn,
- table="event_edges",
- values=[
- {
- "event_id": ev.event_id,
- "prev_event_id": e_id,
- "room_id": ev.room_id,
- "is_state": False,
- }
- for ev in events
- for e_id in ev.prev_event_ids()
- ],
- )
-
- self._update_backward_extremeties(txn, events)
-
- def _update_backward_extremeties(self, txn, events):
- """Updates the event_backward_extremities tables based on the new/updated
- events being persisted.
-
- This is called for new events *and* for events that were outliers, but
- are now being persisted as non-outliers.
-
- Forward extremities are handled when we first start persisting the events.
- """
- events_by_room = {}
- for ev in events:
- events_by_room.setdefault(ev.room_id, []).append(ev)
-
- query = (
- "INSERT INTO event_backward_extremities (event_id, room_id)"
- " SELECT ?, ? WHERE NOT EXISTS ("
- " SELECT 1 FROM event_backward_extremities"
- " WHERE event_id = ? AND room_id = ?"
- " )"
- " AND NOT EXISTS ("
- " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
- " AND outlier = ?"
- " )"
- )
-
- txn.executemany(
- query,
- [
- (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
- for ev in events
- for e_id in ev.prev_event_ids()
- if not ev.internal_metadata.is_outlier()
- ],
- )
-
- query = (
- "DELETE FROM event_backward_extremities"
- " WHERE event_id = ? AND room_id = ?"
- )
- txn.executemany(
- query,
- [
- (ev.event_id, ev.room_id)
- for ev in events
- if not ev.internal_metadata.is_outlier()
- ],
- )
-
def _delete_old_forward_extrem_cache(self):
def _delete_old_forward_extrem_cache_txn(txn):
# Delete entries older than a month, while making sure we don't delete
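
The helpers removed above appear to move into the new PersistEventsStore introduced later in this diff. For reference, the min-depth rule they implement is simply "only ever lower the stored minimum"; a plain-Python stand-in:

    def next_min_depth(current_min_depth, new_depth):
        # Keep the existing value unless the new event is shallower.
        if current_min_depth is not None and new_depth >= current_min_depth:
            return current_min_depth
        return new_depth
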
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index 8eed5909..0321274d 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -652,69 +652,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
self._start_rotate_notifs, 30 * 60 * 1000
)
- def _set_push_actions_for_event_and_users_txn(
- self, txn, events_and_contexts, all_events_and_contexts
- ):
- """Handles moving push actions from staging table to main
- event_push_actions table for all events in `events_and_contexts`.
-
- Also ensures that all events in `all_events_and_contexts` are removed
- from the push action staging area.
-
- Args:
- events_and_contexts (list[(EventBase, EventContext)]): events
- we are persisting
- all_events_and_contexts (list[(EventBase, EventContext)]): all
- events that we were going to persist. This includes events
- we've already persisted, etc, that wouldn't appear in
- events_and_context.
- """
-
- sql = """
- INSERT INTO event_push_actions (
- room_id, event_id, user_id, actions, stream_ordering,
- topological_ordering, notif, highlight
- )
- SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
- FROM event_push_actions_staging
- WHERE event_id = ?
- """
-
- if events_and_contexts:
- txn.executemany(
- sql,
- (
- (
- event.room_id,
- event.internal_metadata.stream_ordering,
- event.depth,
- event.event_id,
- )
- for event, _ in events_and_contexts
- ),
- )
-
- for event, _ in events_and_contexts:
- user_ids = self.db.simple_select_onecol_txn(
- txn,
- table="event_push_actions_staging",
- keyvalues={"event_id": event.event_id},
- retcol="user_id",
- )
-
- for uid in user_ids:
- txn.call_after(
- self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
- (event.room_id, uid),
- )
-
- # Now we delete the staging area for *all* events that were being
- # persisted.
- txn.executemany(
- "DELETE FROM event_push_actions_staging WHERE event_id = ?",
- ((event.event_id,) for event, _ in all_events_and_contexts),
- )
-
@defer.inlineCallbacks
def get_push_actions_for_user(
self, user_id, before=None, limit=50, only_highlight=False
@@ -763,17 +700,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
return result[0] or 0
- def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
- # Sad that we have to blow away the cache for the whole room here
- txn.call_after(
- self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
- (room_id,),
- )
- txn.execute(
- "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
- (room_id, event_id),
- )
-
def _remove_old_push_actions_before_txn(
self, txn, room_id, user_id, stream_ordering
):
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index e71c2354..a6572571 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -17,39 +17,44 @@
import itertools
import logging
-from collections import Counter as c_counter, OrderedDict, namedtuple
+from collections import OrderedDict, namedtuple
from functools import wraps
-from typing import Dict, List, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
-from six import iteritems, text_type
+from six import integer_types, iteritems, text_type
from six.moves import range
+import attr
from canonicaljson import json
from prometheus_client import Counter
from twisted.internet import defer
import synapse.metrics
-from synapse.api.constants import EventContentFields, EventTypes
-from synapse.api.errors import SynapseError
+from synapse.api.constants import (
+ EventContentFields,
+ EventTypes,
+ Membership,
+ RelationTypes,
+)
from synapse.api.room_versions import RoomVersions
+from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
-from synapse.events.utils import prune_event_dict
from synapse.logging.utils import log_function
-from synapse.metrics import BucketCollector
-from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import make_in_list_sql_clause
-from synapse.storage.data_stores.main.event_federation import EventFederationStore
-from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.data_stores.main.state import StateGroupWorkerStore
+from synapse.storage.data_stores.main.search import SearchEntry
from synapse.storage.database import Database, LoggingTransaction
-from synapse.storage.persist_events import DeltaState
-from synapse.types import RoomStreamToken, StateMap, get_domain_from_id
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.types import StateMap, get_domain_from_id
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.iterutils import batch_iter
+if TYPE_CHECKING:
+ from synapse.storage.data_stores.main import DataStore
+ from synapse.server import HomeServer
+
+
logger = logging.getLogger(__name__)
persist_event_counter = Counter("synapse_storage_events_persisted_events", "")
@@ -94,58 +99,49 @@ def _retry_on_integrity_error(func):
return f
-# inherits from EventFederationStore so that we can call _update_backward_extremities
-# and _handle_mult_prev_events (though arguably those could both be moved in here)
-class EventsStore(
- StateGroupWorkerStore, EventFederationStore, EventsWorkerStore,
-):
- def __init__(self, database: Database, db_conn, hs):
- super(EventsStore, self).__init__(database, db_conn, hs)
+@attr.s(slots=True)
+class DeltaState:
+ """Deltas to use to update the `current_state_events` table.
- # Collect metrics on the number of forward extremities that exist.
- # Counter of number of extremities to count
- self._current_forward_extremities_amount = c_counter()
+ Attributes:
+ to_delete: List of type/state_keys to delete from current state
+ to_insert: Map of state to upsert into current state
+        no_longer_in_room: The server is no longer in the room, so the room
+            should e.g. be removed from the `current_state_events` table.
+ """
- BucketCollector(
- "synapse_forward_extremities",
- lambda: self._current_forward_extremities_amount,
- buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"],
- )
+ to_delete = attr.ib(type=List[Tuple[str, str]])
+ to_insert = attr.ib(type=StateMap[str])
+ no_longer_in_room = attr.ib(type=bool, default=False)
- # Read the extrems every 60 minutes
- def read_forward_extremities():
- # run as a background process to make sure that the database transactions
- # have a logcontext to report to
- return run_as_background_process(
- "read_forward_extremities", self._read_forward_extremities
- )
- hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000)
+class PersistEventsStore:
+ """Contains all the functions for writing events to the database.
- def _censor_redactions():
- return run_as_background_process(
- "_censor_redactions", self._censor_redactions
- )
+ Should only be instantiated on one process (when using a worker mode setup).
+
+ Note: This is not part of the `DataStore` mixin.
+ """
- if self.hs.config.redaction_retention_period is not None:
- hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
+ def __init__(self, hs: "HomeServer", db: Database, main_data_store: "DataStore"):
+ self.hs = hs
+ self.db = db
+ self.store = main_data_store
+ self.database_engine = db.engine
+ self._clock = hs.get_clock()
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
- @defer.inlineCallbacks
- def _read_forward_extremities(self):
- def fetch(txn):
- txn.execute(
- """
- select count(*) c from event_forward_extremities
- group by room_id
- """
- )
- return txn.fetchall()
+ # Ideally we'd move these ID gens here, unfortunately some other ID
+ # generators are chained off them so doing so is a bit of a PITA.
+ self._backfill_id_gen = self.store._backfill_id_gen # type: StreamIdGenerator
+ self._stream_id_gen = self.store._stream_id_gen # type: StreamIdGenerator
- res = yield self.db.runInteraction("read_forward_extremities", fetch)
- self._current_forward_extremities_amount = c_counter([x[0] for x in res])
+ # This should only exist on instances that are configured to write
+ assert (
+ hs.config.worker.writers.events == hs.get_instance_name()
+ ), "Can only instantiate EventsStore on master"
@_retry_on_integrity_error
@defer.inlineCallbacks
@@ -237,10 +233,10 @@ class EventsStore(
event_counter.labels(event.type, origin_type, origin_entity).inc()
for room_id, new_state in iteritems(current_state_for_room):
- self.get_current_state_ids.prefill((room_id,), new_state)
+ self.store.get_current_state_ids.prefill((room_id,), new_state)
for room_id, latest_event_ids in iteritems(new_forward_extremeties):
- self.get_latest_event_ids_in_room.prefill(
+ self.store.get_latest_event_ids_in_room.prefill(
(room_id,), list(latest_event_ids)
)
@@ -586,7 +582,7 @@ class EventsStore(
)
txn.call_after(
- self._curr_state_delta_stream_cache.entity_has_changed,
+ self.store._curr_state_delta_stream_cache.entity_has_changed,
room_id,
stream_id,
)
@@ -606,10 +602,13 @@ class EventsStore(
for member in members_changed:
txn.call_after(
- self.get_rooms_for_user_with_stream_ordering.invalidate, (member,)
+ self.store.get_rooms_for_user_with_stream_ordering.invalidate,
+ (member,),
)
- self._invalidate_state_caches_and_stream(txn, room_id, members_changed)
+ self.store._invalidate_state_caches_and_stream(
+ txn, room_id, members_changed
+ )
def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str):
"""Update the room version in the database based off current state
@@ -647,7 +646,9 @@ class EventsStore(
self.db.simple_delete_txn(
txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
)
- txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
+ txn.call_after(
+ self.store.get_latest_event_ids_in_room.invalidate, (room_id,)
+ )
self.db.simple_insert_many_txn(
txn,
@@ -713,10 +714,10 @@ class EventsStore(
depth_updates = {}
for event, context in events_and_contexts:
            # Remove any existing cache entries for the event_ids
- txn.call_after(self._invalidate_get_event_cache, event.event_id)
+ txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
if not backfilled:
txn.call_after(
- self._events_stream_cache.entity_has_changed,
+ self.store._events_stream_cache.entity_has_changed,
event.room_id,
event.internal_metadata.stream_ordering,
)
@@ -1088,13 +1089,15 @@ class EventsStore(
def prefill():
for cache_entry in to_prefill:
- self._get_event_cache.prefill((cache_entry[0].event_id,), cache_entry)
+ self.store._get_event_cache.prefill(
+ (cache_entry[0].event_id,), cache_entry
+ )
txn.call_after(prefill)
def _store_redaction(self, txn, event):
# invalidate the cache for the redacted event
- txn.call_after(self._invalidate_get_event_cache, event.redacts)
+ txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
self.db.simple_insert_txn(
txn,
@@ -1106,783 +1109,512 @@ class EventsStore(
},
)
- async def _censor_redactions(self):
- """Censors all redactions older than the configured period that haven't
- been censored yet.
+ def insert_labels_for_event_txn(
+ self, txn, event_id, labels, room_id, topological_ordering
+ ):
+ """Store the mapping between an event's ID and its labels, with one row per
+ (event_id, label) tuple.
- By censor we mean update the event_json table with the redacted event.
+ Args:
+ txn (LoggingTransaction): The transaction to execute.
+ event_id (str): The event's ID.
+ labels (list[str]): A list of text labels.
+ room_id (str): The ID of the room the event was sent to.
+ topological_ordering (int): The position of the event in the room's topology.
"""
+ return self.db.simple_insert_many_txn(
+ txn=txn,
+ table="event_labels",
+ values=[
+ {
+ "event_id": event_id,
+ "label": label,
+ "room_id": room_id,
+ "topological_ordering": topological_ordering,
+ }
+ for label in labels
+ ],
+ )
- if self.hs.config.redaction_retention_period is None:
- return
-
- if not (
- await self.db.updates.has_completed_background_update(
- "redactions_have_censored_ts_idx"
- )
- ):
- # We don't want to run this until the appropriate index has been
- # created.
- return
-
- before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
+ def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
+ """Save the expiry timestamp associated with a given event ID.
- # We fetch all redactions that:
- # 1. point to an event we have,
- # 2. has a received_ts from before the cut off, and
- # 3. we haven't yet censored.
- #
- # This is limited to 100 events to ensure that we don't try and do too
- # much at once. We'll get called again so this should eventually catch
- # up.
- sql = """
- SELECT redactions.event_id, redacts FROM redactions
- LEFT JOIN events AS original_event ON (
- redacts = original_event.event_id
- )
- WHERE NOT have_censored
- AND redactions.received_ts <= ?
- ORDER BY redactions.received_ts ASC
- LIMIT ?
+ Args:
+ txn (LoggingTransaction): The database transaction to use.
+ event_id (str): The event ID the expiry timestamp is associated with.
+ expiry_ts (int): The timestamp at which to expire (delete) the event.
"""
-
- rows = await self.db.execute(
- "_censor_redactions_fetch", None, sql, before_ts, 100
+ return self.db.simple_insert_txn(
+ txn=txn,
+ table="event_expiry",
+ values={"event_id": event_id, "expiry_ts": expiry_ts},
)
- updates = []
+ def _store_event_reference_hashes_txn(self, txn, events):
+ """Store a hash for a PDU
+ Args:
+ txn (cursor):
+ events (list): list of Events.
+ """
- for redaction_id, event_id in rows:
- redaction_event = await self.get_event(redaction_id, allow_none=True)
- original_event = await self.get_event(
- event_id, allow_rejected=True, allow_none=True
+ vals = []
+ for event in events:
+ ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
+ vals.append(
+ {
+ "event_id": event.event_id,
+ "algorithm": ref_alg,
+ "hash": memoryview(ref_hash_bytes),
+ }
)
- # The SQL above ensures that we have both the redaction and
- # original event, so if the `get_event` calls return None it
- # means that the redaction wasn't allowed. Either way we know that
- # the result won't change so we mark the fact that we've checked.
- if (
- redaction_event
- and original_event
- and original_event.internal_metadata.is_redacted()
- ):
- # Redaction was allowed
- pruned_json = encode_json(
- prune_event_dict(
- original_event.room_version, original_event.get_dict()
- )
- )
- else:
- # Redaction wasn't allowed
- pruned_json = None
-
- updates.append((redaction_id, event_id, pruned_json))
-
- def _update_censor_txn(txn):
- for redaction_id, event_id, pruned_json in updates:
- if pruned_json:
- self._censor_event_txn(txn, event_id, pruned_json)
-
- self.db.simple_update_one_txn(
- txn,
- table="redactions",
- keyvalues={"event_id": redaction_id},
- updatevalues={"have_censored": True},
- )
-
- await self.db.runInteraction("_update_censor_txn", _update_censor_txn)
+ self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
- def _censor_event_txn(self, txn, event_id, pruned_json):
- """Censor an event by replacing its JSON in the event_json table with the
- provided pruned JSON.
-
- Args:
- txn (LoggingTransaction): The database transaction.
- event_id (str): The ID of the event to censor.
- pruned_json (str): The pruned JSON
+ def _store_room_members_txn(self, txn, events, backfilled):
+ """Store a room member in the database.
"""
- self.db.simple_update_one_txn(
+ self.db.simple_insert_many_txn(
txn,
- table="event_json",
- keyvalues={"event_id": event_id},
- updatevalues={"json": pruned_json},
+ table="room_memberships",
+ values=[
+ {
+ "event_id": event.event_id,
+ "user_id": event.state_key,
+ "sender": event.user_id,
+ "room_id": event.room_id,
+ "membership": event.membership,
+ "display_name": event.content.get("displayname", None),
+ "avatar_url": event.content.get("avatar_url", None),
+ }
+ for event in events
+ ],
)
- @defer.inlineCallbacks
- def count_daily_messages(self):
- """
- Returns an estimate of the number of messages sent in the last day.
-
- If it has been significantly less or more than one day since the last
- call to this function, it will return None.
- """
-
- def _count_messages(txn):
- sql = """
- SELECT COALESCE(COUNT(*), 0) FROM events
- WHERE type = 'm.room.message'
- AND stream_ordering > ?
- """
- txn.execute(sql, (self.stream_ordering_day_ago,))
- (count,) = txn.fetchone()
- return count
-
- ret = yield self.db.runInteraction("count_messages", _count_messages)
- return ret
-
- @defer.inlineCallbacks
- def count_daily_sent_messages(self):
- def _count_messages(txn):
- # This is good enough as if you have silly characters in your own
- # hostname then thats your own fault.
- like_clause = "%:" + self.hs.hostname
-
- sql = """
- SELECT COALESCE(COUNT(*), 0) FROM events
- WHERE type = 'm.room.message'
- AND sender LIKE ?
- AND stream_ordering > ?
- """
-
- txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
- (count,) = txn.fetchone()
- return count
-
- ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages)
- return ret
-
- @defer.inlineCallbacks
- def count_daily_active_rooms(self):
- def _count(txn):
- sql = """
- SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
- WHERE type = 'm.room.message'
- AND stream_ordering > ?
- """
- txn.execute(sql, (self.stream_ordering_day_ago,))
- (count,) = txn.fetchone()
- return count
-
- ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
- return ret
-
- @cached(num_args=5, max_entries=10)
- def get_all_new_events(
- self,
- last_backfill_id,
- last_forward_id,
- current_backfill_id,
- current_forward_id,
- limit,
- ):
- """Get all the new events that have arrived at the server either as
- new events or as backfilled events"""
- have_backfill_events = last_backfill_id != current_backfill_id
- have_forward_events = last_forward_id != current_forward_id
-
- if not have_backfill_events and not have_forward_events:
- return defer.succeed(AllNewEventsResult([], [], [], [], []))
-
- def get_all_new_events_txn(txn):
- sql = (
- "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
- " FROM events AS e"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " WHERE ? < stream_ordering AND stream_ordering <= ?"
- " ORDER BY stream_ordering ASC"
- " LIMIT ?"
+ for event in events:
+ txn.call_after(
+ self.store._membership_stream_cache.entity_has_changed,
+ event.state_key,
+ event.internal_metadata.stream_ordering,
)
- if have_forward_events:
- txn.execute(sql, (last_forward_id, current_forward_id, limit))
- new_forward_events = txn.fetchall()
-
- if len(new_forward_events) == limit:
- upper_bound = new_forward_events[-1][0]
- else:
- upper_bound = current_forward_id
-
- sql = (
- "SELECT event_stream_ordering, event_id, state_group"
- " FROM ex_outlier_stream"
- " WHERE ? > event_stream_ordering"
- " AND event_stream_ordering >= ?"
- " ORDER BY event_stream_ordering DESC"
- )
- txn.execute(sql, (last_forward_id, upper_bound))
- forward_ex_outliers = txn.fetchall()
- else:
- new_forward_events = []
- forward_ex_outliers = []
-
- sql = (
- "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
- " FROM events AS e"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " WHERE ? > stream_ordering AND stream_ordering >= ?"
- " ORDER BY stream_ordering DESC"
- " LIMIT ?"
+ txn.call_after(
+ self.store.get_invited_rooms_for_local_user.invalidate,
+ (event.state_key,),
)
- if have_backfill_events:
- txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
- new_backfill_events = txn.fetchall()
- if len(new_backfill_events) == limit:
- upper_bound = new_backfill_events[-1][0]
- else:
- upper_bound = current_backfill_id
-
- sql = (
- "SELECT -event_stream_ordering, event_id, state_group"
- " FROM ex_outlier_stream"
- " WHERE ? > event_stream_ordering"
- " AND event_stream_ordering >= ?"
- " ORDER BY event_stream_ordering DESC"
- )
- txn.execute(sql, (-last_backfill_id, -upper_bound))
- backward_ex_outliers = txn.fetchall()
- else:
- new_backfill_events = []
- backward_ex_outliers = []
-
- return AllNewEventsResult(
- new_forward_events,
- new_backfill_events,
- forward_ex_outliers,
- backward_ex_outliers,
+ # We update the local_invites table only if the event is "current",
+            # i.e., it's something that has just happened. If the event is an
+            # outlier it is only current if it's an "out of band membership",
+ # like a remote invite or a rejection of a remote invite.
+ is_new_state = not backfilled and (
+ not event.internal_metadata.is_outlier()
+ or event.internal_metadata.is_out_of_band_membership()
)
+ is_mine = self.is_mine_id(event.state_key)
+ if is_new_state and is_mine:
+ if event.membership == Membership.INVITE:
+ self.db.simple_insert_txn(
+ txn,
+ table="local_invites",
+ values={
+ "event_id": event.event_id,
+ "invitee": event.state_key,
+ "inviter": event.sender,
+ "room_id": event.room_id,
+ "stream_id": event.internal_metadata.stream_ordering,
+ },
+ )
+ else:
+ sql = (
+ "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
+ " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+ " AND replaced_by is NULL"
+ )
- return self.db.runInteraction("get_all_new_events", get_all_new_events_txn)
-
- def purge_history(self, room_id, token, delete_local_events):
- """Deletes room history before a certain point
-
- Args:
- room_id (str):
+ txn.execute(
+ sql,
+ (
+ event.internal_metadata.stream_ordering,
+ event.event_id,
+ event.room_id,
+ event.state_key,
+ ),
+ )
- token (str): A topological token to delete events before
+ # We also update the `local_current_membership` table with
+ # latest invite info. This will usually get updated by the
+ # `current_state_events` handling, unless its an outlier.
+ if event.internal_metadata.is_outlier():
+ # This should only happen for out of band memberships, so
+ # we add a paranoia check.
+ assert event.internal_metadata.is_out_of_band_membership()
+
+ self.db.simple_upsert_txn(
+ txn,
+ table="local_current_membership",
+ keyvalues={
+ "room_id": event.room_id,
+ "user_id": event.state_key,
+ },
+ values={
+ "event_id": event.event_id,
+ "membership": event.membership,
+ },
+ )
- delete_local_events (bool):
- if True, we will delete local events as well as remote ones
- (instead of just marking them as outliers and deleting their
- state groups).
+ def _handle_event_relations(self, txn, event):
+ """Handles inserting relation data during peristence of events
- Returns:
- Deferred[set[int]]: The set of state groups that are referenced by
- deleted events.
+ Args:
+ txn
+ event (EventBase)
"""
+ relation = event.content.get("m.relates_to")
+ if not relation:
+ # No relations
+ return
- return self.db.runInteraction(
- "purge_history",
- self._purge_history_txn,
- room_id,
- token,
- delete_local_events,
- )
+ rel_type = relation.get("rel_type")
+ if rel_type not in (
+ RelationTypes.ANNOTATION,
+ RelationTypes.REFERENCE,
+ RelationTypes.REPLACE,
+ ):
+ # Unknown relation type
+ return
- def _purge_history_txn(self, txn, room_id, token_str, delete_local_events):
- token = RoomStreamToken.parse(token_str)
-
- # Tables that should be pruned:
- # event_auth
- # event_backward_extremities
- # event_edges
- # event_forward_extremities
- # event_json
- # event_push_actions
- # event_reference_hashes
- # event_search
- # event_to_state_groups
- # events
- # rejections
- # room_depth
- # state_groups
- # state_groups_state
-
- # we will build a temporary table listing the events so that we don't
- # have to keep shovelling the list back and forth across the
- # connection. Annoyingly the python sqlite driver commits the
- # transaction on CREATE, so let's do this first.
- #
- # furthermore, we might already have the table from a previous (failed)
- # purge attempt, so let's drop the table first.
+ parent_id = relation.get("event_id")
+ if not parent_id:
+ # Invalid relation
+ return
- txn.execute("DROP TABLE IF EXISTS events_to_purge")
+ aggregation_key = relation.get("key")
- txn.execute(
- "CREATE TEMPORARY TABLE events_to_purge ("
- " event_id TEXT NOT NULL,"
- " should_delete BOOLEAN NOT NULL"
- ")"
+ self.db.simple_insert_txn(
+ txn,
+ table="event_relations",
+ values={
+ "event_id": event.event_id,
+ "relates_to_id": parent_id,
+ "relation_type": rel_type,
+ "aggregation_key": aggregation_key,
+ },
)
- # First ensure that we're not about to delete all the forward extremeties
- txn.execute(
- "SELECT e.event_id, e.depth FROM events as e "
- "INNER JOIN event_forward_extremities as f "
- "ON e.event_id = f.event_id "
- "AND e.room_id = f.room_id "
- "WHERE f.room_id = ?",
- (room_id,),
+ txn.call_after(self.store.get_relations_for_event.invalidate_many, (parent_id,))
+ txn.call_after(
+ self.store.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
)
- rows = txn.fetchall()
- max_depth = max(row[1] for row in rows)
-
- if max_depth < token.topological:
- # We need to ensure we don't delete all the events from the database
- # otherwise we wouldn't be able to send any events (due to not
- # having any backwards extremeties)
- raise SynapseError(
- 400, "topological_ordering is greater than forward extremeties"
- )
-
- logger.info("[purge] looking for events to delete")
-
- should_delete_expr = "state_key IS NULL"
- should_delete_params = ()
- if not delete_local_events:
- should_delete_expr += " AND event_id NOT LIKE ?"
-
- # We include the parameter twice since we use the expression twice
- should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname)
- should_delete_params += (room_id, token.topological)
+ if rel_type == RelationTypes.REPLACE:
+ txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
- # Note that we insert events that are outliers and aren't going to be
- # deleted, as nothing will happen to them.
- txn.execute(
- "INSERT INTO events_to_purge"
- " SELECT event_id, %s"
- " FROM events AS e LEFT JOIN state_events USING (event_id)"
- " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?"
- % (should_delete_expr, should_delete_expr),
- should_delete_params,
- )
+ def _handle_redaction(self, txn, redacted_event_id):
+ """Handles receiving a redaction and checking whether we need to remove
+ any redacted relations from the database.
- # We create the indices *after* insertion as that's a lot faster.
+ Args:
+ txn
+ redacted_event_id (str): The event that was redacted.
+ """
- # create an index on should_delete because later we'll be looking for
- # the should_delete / shouldn't_delete subsets
- txn.execute(
- "CREATE INDEX events_to_purge_should_delete"
- " ON events_to_purge(should_delete)"
+ self.db.simple_delete_txn(
+ txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
)
- # We do joins against events_to_purge for e.g. calculating state
- # groups to purge, etc., so lets make an index.
- txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)")
-
- txn.execute("SELECT event_id, should_delete FROM events_to_purge")
- event_rows = txn.fetchall()
- logger.info(
- "[purge] found %i events before cutoff, of which %i can be deleted",
- len(event_rows),
- sum(1 for e in event_rows if e[1]),
- )
+ def _store_room_topic_txn(self, txn, event):
+ if hasattr(event, "content") and "topic" in event.content:
+ self.store_event_search_txn(
+ txn, event, "content.topic", event.content["topic"]
+ )
- logger.info("[purge] Finding new backward extremities")
+ def _store_room_name_txn(self, txn, event):
+ if hasattr(event, "content") and "name" in event.content:
+ self.store_event_search_txn(
+ txn, event, "content.name", event.content["name"]
+ )
- # We calculate the new entries for the backward extremeties by finding
- # events to be purged that are pointed to by events we're not going to
- # purge.
- txn.execute(
- "SELECT DISTINCT e.event_id FROM events_to_purge AS e"
- " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
- " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id"
- " WHERE ep2.event_id IS NULL"
- )
- new_backwards_extrems = txn.fetchall()
+ def _store_room_message_txn(self, txn, event):
+ if hasattr(event, "content") and "body" in event.content:
+ self.store_event_search_txn(
+ txn, event, "content.body", event.content["body"]
+ )
- logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
+ def _store_retention_policy_for_room_txn(self, txn, event):
+ if hasattr(event, "content") and (
+ "min_lifetime" in event.content or "max_lifetime" in event.content
+ ):
+ if (
+ "min_lifetime" in event.content
+ and not isinstance(event.content.get("min_lifetime"), integer_types)
+ ) or (
+ "max_lifetime" in event.content
+ and not isinstance(event.content.get("max_lifetime"), integer_types)
+ ):
+                # Ignore the event if one of the values isn't an integer.
+ return
- txn.execute(
- "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,)
- )
+ self.db.simple_insert_txn(
+ txn=txn,
+ table="room_retention",
+ values={
+ "room_id": event.room_id,
+ "event_id": event.event_id,
+ "min_lifetime": event.content.get("min_lifetime"),
+ "max_lifetime": event.content.get("max_lifetime"),
+ },
+ )
- # Update backward extremeties
- txn.executemany(
- "INSERT INTO event_backward_extremities (room_id, event_id)"
- " VALUES (?, ?)",
- [(room_id, event_id) for event_id, in new_backwards_extrems],
- )
+ self.store._invalidate_cache_and_stream(
+ txn, self.store.get_retention_policy_for_room, (event.room_id,)
+ )
- logger.info("[purge] finding state groups referenced by deleted events")
+ def store_event_search_txn(self, txn, event, key, value):
+ """Add event to the search table
- # Get all state groups that are referenced by events that are to be
- # deleted.
- txn.execute(
- """
- SELECT DISTINCT state_group FROM events_to_purge
- INNER JOIN event_to_state_groups USING (event_id)
+ Args:
+ txn (cursor):
+ event (EventBase):
+ key (str):
+ value (str):
"""
+ self.store.store_search_entries_txn(
+ txn,
+ (
+ SearchEntry(
+ key=key,
+ value=value,
+ event_id=event.event_id,
+ room_id=event.room_id,
+ stream_ordering=event.internal_metadata.stream_ordering,
+ origin_server_ts=event.origin_server_ts,
+ ),
+ ),
)
- referenced_state_groups = {sg for sg, in txn}
- logger.info(
- "[purge] found %i referenced state groups", len(referenced_state_groups)
- )
+ def _set_push_actions_for_event_and_users_txn(
+ self, txn, events_and_contexts, all_events_and_contexts
+ ):
+ """Handles moving push actions from staging table to main
+ event_push_actions table for all events in `events_and_contexts`.
- logger.info("[purge] removing events from event_to_state_groups")
- txn.execute(
- "DELETE FROM event_to_state_groups "
- "WHERE event_id IN (SELECT event_id from events_to_purge)"
- )
- for event_id, _ in event_rows:
- txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
+ Also ensures that all events in `all_events_and_contexts` are removed
+ from the push action staging area.
- # Delete all remote non-state events
- for table in (
- "events",
- "event_json",
- "event_auth",
- "event_edges",
- "event_forward_extremities",
- "event_reference_hashes",
- "event_search",
- "rejections",
- ):
- logger.info("[purge] removing events from %s", table)
+ Args:
+ events_and_contexts (list[(EventBase, EventContext)]): events
+ we are persisting
+ all_events_and_contexts (list[(EventBase, EventContext)]): all
+ events that we were going to persist. This includes events
+                we've already persisted, etc., that wouldn't appear in
+                events_and_contexts.
+ """
- txn.execute(
- "DELETE FROM %s WHERE event_id IN ("
- " SELECT event_id FROM events_to_purge WHERE should_delete"
- ")" % (table,)
+ sql = """
+ INSERT INTO event_push_actions (
+ room_id, event_id, user_id, actions, stream_ordering,
+ topological_ordering, notif, highlight
)
+ SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
+ FROM event_push_actions_staging
+ WHERE event_id = ?
+ """
- # event_push_actions lacks an index on event_id, and has one on
- # (room_id, event_id) instead.
- for table in ("event_push_actions",):
- logger.info("[purge] removing events from %s", table)
+ if events_and_contexts:
+ txn.executemany(
+ sql,
+ (
+ (
+ event.room_id,
+ event.internal_metadata.stream_ordering,
+ event.depth,
+ event.event_id,
+ )
+ for event, _ in events_and_contexts
+ ),
+ )
- txn.execute(
- "DELETE FROM %s WHERE room_id = ? AND event_id IN ("
- " SELECT event_id FROM events_to_purge WHERE should_delete"
- ")" % (table,),
- (room_id,),
+ for event, _ in events_and_contexts:
+ user_ids = self.db.simple_select_onecol_txn(
+ txn,
+ table="event_push_actions_staging",
+ keyvalues={"event_id": event.event_id},
+ retcol="user_id",
)
- # Mark all state and own events as outliers
- logger.info("[purge] marking remaining events as outliers")
- txn.execute(
- "UPDATE events SET outlier = ?"
- " WHERE event_id IN ("
- " SELECT event_id FROM events_to_purge "
- " WHERE NOT should_delete"
- ")",
- (True,),
+ for uid in user_ids:
+ txn.call_after(
+ self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+ (event.room_id, uid),
+ )
+
+ # Now we delete the staging area for *all* events that were being
+ # persisted.
+ txn.executemany(
+ "DELETE FROM event_push_actions_staging WHERE event_id = ?",
+ ((event.event_id,) for event, _ in all_events_and_contexts),
)
- # synapse tries to take out an exclusive lock on room_depth whenever it
- # persists events (because upsert), and once we run this update, we
- # will block that for the rest of our transaction.
- #
- # So, let's stick it at the end so that we don't block event
- # persistence.
- #
- # We do this by calculating the minimum depth of the backwards
- # extremities. However, the events in event_backward_extremities
- # are ones we don't have yet so we need to look at the events that
- # point to it via event_edges table.
- txn.execute(
- """
- SELECT COALESCE(MIN(depth), 0)
- FROM event_backward_extremities AS eb
- INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id
- INNER JOIN events AS e ON e.event_id = eg.event_id
- WHERE eb.room_id = ?
- """,
+ def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
+ # Sad that we have to blow away the cache for the whole room here
+ txn.call_after(
+ self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many,
(room_id,),
)
- (min_depth,) = txn.fetchone()
-
- logger.info("[purge] updating room_depth to %d", min_depth)
-
txn.execute(
- "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
- (min_depth, room_id),
+ "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
+ (room_id, event_id),
)
- # finally, drop the temp table. this will commit the txn in sqlite,
- # so make sure to keep this actually last.
- txn.execute("DROP TABLE events_to_purge")
-
- logger.info("[purge] done")
-
- return referenced_state_groups
-
- def purge_room(self, room_id):
- """Deletes all record of a room
+ def _store_rejections_txn(self, txn, event_id, reason):
+ self.db.simple_insert_txn(
+ txn,
+ table="rejections",
+ values={
+ "event_id": event_id,
+ "reason": reason,
+ "last_check": self._clock.time_msec(),
+ },
+ )
- Args:
- room_id (str)
+ def _store_event_state_mappings_txn(
+ self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
+ ):
+ state_groups = {}
+ for event, context in events_and_contexts:
+ if event.internal_metadata.is_outlier():
+ continue
- Returns:
- Deferred[List[int]]: The list of state groups to delete.
- """
+ # if the event was rejected, just give it the same state as its
+ # predecessor.
+ if context.rejected:
+ state_groups[event.event_id] = context.state_group_before_event
+ continue
- return self.db.runInteraction("purge_room", self._purge_room_txn, room_id)
+ state_groups[event.event_id] = context.state_group
- def _purge_room_txn(self, txn, room_id):
- # First we fetch all the state groups that should be deleted, before
- # we delete that information.
- txn.execute(
- """
- SELECT DISTINCT state_group FROM events
- INNER JOIN event_to_state_groups USING(event_id)
- WHERE events.room_id = ?
- """,
- (room_id,),
+ self.db.simple_insert_many_txn(
+ txn,
+ table="event_to_state_groups",
+ values=[
+ {"state_group": state_group_id, "event_id": event_id}
+ for event_id, state_group_id in iteritems(state_groups)
+ ],
)
- state_groups = [row[0] for row in txn]
-
- # Now we delete tables which lack an index on room_id but have one on event_id
- for table in (
- "event_auth",
- "event_edges",
- "event_push_actions_staging",
- "event_reference_hashes",
- "event_relations",
- "event_to_state_groups",
- "redactions",
- "rejections",
- "state_events",
- ):
- logger.info("[purge] removing %s from %s", room_id, table)
-
- txn.execute(
- """
- DELETE FROM %s WHERE event_id IN (
- SELECT event_id FROM events WHERE room_id=?
- )
- """
- % (table,),
- (room_id,),
+ for event_id, state_group_id in iteritems(state_groups):
+ txn.call_after(
+ self.store._get_state_group_for_event.prefill,
+ (event_id,),
+ state_group_id,
)
- # and finally, the tables with an index on room_id (or no useful index)
- for table in (
- "current_state_events",
- "event_backward_extremities",
- "event_forward_extremities",
- "event_json",
- "event_push_actions",
- "event_search",
- "events",
- "group_rooms",
- "public_room_list_stream",
- "receipts_graph",
- "receipts_linearized",
- "room_aliases",
- "room_depth",
- "room_memberships",
- "room_stats_state",
- "room_stats_current",
- "room_stats_historical",
- "room_stats_earliest_token",
- "rooms",
- "stream_ordering_to_exterm",
- "users_in_public_rooms",
- "users_who_share_private_rooms",
- # no useful index, but let's clear them anyway
- "appservice_room_list",
- "e2e_room_keys",
- "event_push_summary",
- "pusher_throttle",
- "group_summary_rooms",
- "local_invites",
- "room_account_data",
- "room_tags",
- "local_current_membership",
- ):
- logger.info("[purge] removing %s from %s", room_id, table)
- txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
-
- # Other tables we do NOT need to clear out:
- #
- # - blocked_rooms
- # This is important, to make sure that we don't accidentally rejoin a blocked
- # room after it was purged
- #
- # - user_directory
- # This has a room_id column, but it is unused
- #
-
- # Other tables that we might want to consider clearing out include:
- #
- # - event_reports
- # Given that these are intended for abuse management my initial
- # inclination is to leave them in place.
- #
- # - current_state_delta_stream
- # - ex_outlier_stream
- # - room_tags_revisions
- # The problem with these is that they are largeish and there is no room_id
- # index on them. In any case we should be clearing out 'stream' tables
- # periodically anyway (#5888)
-
- # TODO: we could probably usefully do a bunch of cache invalidation here
-
- logger.info("[purge] done")
-
- return state_groups
+ def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+ min_depth = self.store._get_min_depth_interaction(txn, room_id)
- async def is_event_after(self, event_id1, event_id2):
- """Returns True if event_id1 is after event_id2 in the stream
- """
- to_1, so_1 = await self._get_event_ordering(event_id1)
- to_2, so_2 = await self._get_event_ordering(event_id2)
- return (to_1, so_1) > (to_2, so_2)
+ if min_depth is not None and depth >= min_depth:
+ return
- @cachedInlineCallbacks(max_entries=5000)
- def _get_event_ordering(self, event_id):
- res = yield self.db.simple_select_one(
- table="events",
- retcols=["topological_ordering", "stream_ordering"],
- keyvalues={"event_id": event_id},
- allow_none=True,
+ self.db.simple_upsert_txn(
+ txn,
+ table="room_depth",
+ keyvalues={"room_id": room_id},
+ values={"min_depth": depth},
)
- if not res:
- raise SynapseError(404, "Could not find event %s" % (event_id,))
-
- return (int(res["topological_ordering"]), int(res["stream_ordering"]))
-
- def insert_labels_for_event_txn(
- self, txn, event_id, labels, room_id, topological_ordering
- ):
- """Store the mapping between an event's ID and its labels, with one row per
- (event_id, label) tuple.
-
- Args:
- txn (LoggingTransaction): The transaction to execute.
- event_id (str): The event's ID.
- labels (list[str]): A list of text labels.
- room_id (str): The ID of the room the event was sent to.
- topological_ordering (int): The position of the event in the room's topology.
+ def _handle_mult_prev_events(self, txn, events):
"""
- return self.db.simple_insert_many_txn(
- txn=txn,
- table="event_labels",
+        For the given events, update the event edges table and forward and
+ backward extremities tables.
+ """
+ self.db.simple_insert_many_txn(
+ txn,
+ table="event_edges",
values=[
{
- "event_id": event_id,
- "label": label,
- "room_id": room_id,
- "topological_ordering": topological_ordering,
+ "event_id": ev.event_id,
+ "prev_event_id": e_id,
+ "room_id": ev.room_id,
+ "is_state": False,
}
- for label in labels
+ for ev in events
+ for e_id in ev.prev_event_ids()
],
)
- def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
- """Save the expiry timestamp associated with a given event ID.
+ self._update_backward_extremeties(txn, events)
- Args:
- txn (LoggingTransaction): The database transaction to use.
- event_id (str): The event ID the expiry timestamp is associated with.
- expiry_ts (int): The timestamp at which to expire (delete) the event.
- """
- return self.db.simple_insert_txn(
- txn=txn,
- table="event_expiry",
- values={"event_id": event_id, "expiry_ts": expiry_ts},
- )
+ def _update_backward_extremeties(self, txn, events):
+ """Updates the event_backward_extremities tables based on the new/updated
+ events being persisted.
- @defer.inlineCallbacks
- def expire_event(self, event_id):
- """Retrieve and expire an event that has expired, and delete its associated
- expiry timestamp. If the event can't be retrieved, delete its associated
- timestamp so we don't try to expire it again in the future.
+ This is called for new events *and* for events that were outliers, but
+ are now being persisted as non-outliers.
- Args:
- event_id (str): The ID of the event to delete.
+ Forward extremities are handled when we first start persisting the events.
"""
- # Try to retrieve the event's content from the database or the event cache.
- event = yield self.get_event(event_id)
-
- def delete_expired_event_txn(txn):
- # Delete the expiry timestamp associated with this event from the database.
- self._delete_event_expiry_txn(txn, event_id)
-
- if not event:
- # If we can't find the event, log a warning and delete the expiry date
- # from the database so that we don't try to expire it again in the
- # future.
- logger.warning(
- "Can't expire event %s because we don't have it.", event_id
- )
- return
-
- # Prune the event's dict then convert it to JSON.
- pruned_json = encode_json(
- prune_event_dict(event.room_version, event.get_dict())
- )
-
- # Update the event_json table to replace the event's JSON with the pruned
- # JSON.
- self._censor_event_txn(txn, event.event_id, pruned_json)
-
- # We need to invalidate the event cache entry for this event because we
- # changed its content in the database. We can't call
- # self._invalidate_cache_and_stream because self.get_event_cache isn't of the
- # right type.
- txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
- # Send that invalidation to replication so that other workers also invalidate
- # the event cache.
- self._send_invalidation_to_replication(
- txn, "_get_event_cache", (event.event_id,)
- )
-
- yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn)
-
- def _delete_event_expiry_txn(self, txn, event_id):
- """Delete the expiry timestamp associated with an event ID without deleting the
- actual event.
+ events_by_room = {}
+ for ev in events:
+ events_by_room.setdefault(ev.room_id, []).append(ev)
+
+ query = (
+ "INSERT INTO event_backward_extremities (event_id, room_id)"
+ " SELECT ?, ? WHERE NOT EXISTS ("
+ " SELECT 1 FROM event_backward_extremities"
+ " WHERE event_id = ? AND room_id = ?"
+ " )"
+ " AND NOT EXISTS ("
+ " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
+ " AND outlier = ?"
+ " )"
+ )
- Args:
- txn (LoggingTransaction): The transaction to use to perform the deletion.
- event_id (str): The event ID to delete the associated expiry timestamp of.
- """
- return self.db.simple_delete_txn(
- txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
+ txn.executemany(
+ query,
+ [
+ (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
+ for ev in events
+ for e_id in ev.prev_event_ids()
+ if not ev.internal_metadata.is_outlier()
+ ],
)
- def get_next_event_to_expire(self):
- """Retrieve the entry with the lowest expiry timestamp in the event_expiry
- table, or None if there's no more event to expire.
+ query = (
+ "DELETE FROM event_backward_extremities"
+ " WHERE event_id = ? AND room_id = ?"
+ )
+ txn.executemany(
+ query,
+ [
+ (ev.event_id, ev.room_id)
+ for ev in events
+ if not ev.internal_metadata.is_outlier()
+ ],
+ )
- Returns: Deferred[Optional[Tuple[str, int]]]
- A tuple containing the event ID as its first element and an expiry timestamp
- as its second one, if there's at least one row in the event_expiry table.
- None otherwise.
+ async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
+        """Mark the invite as having been rejected even though we failed to
+ create a leave event for it.
"""
- def get_next_event_to_expire_txn(txn):
- txn.execute(
- """
- SELECT event_id, expiry_ts FROM event_expiry
- ORDER BY expiry_ts ASC LIMIT 1
- """
- )
+ sql = (
+ "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
+ " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+ " AND replaced_by is NULL"
+ )
- return txn.fetchone()
+ def f(txn, stream_ordering):
+ txn.execute(sql, (stream_ordering, True, room_id, user_id))
- return self.db.runInteraction(
- desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
- )
+ # We also clear this entry from `local_current_membership`.
+ # Ideally we'd point to a leave event, but we don't have one, so
+        # never mind.
+ self.db.simple_delete_txn(
+ txn,
+ table="local_current_membership",
+ keyvalues={"room_id": room_id, "user_id": user_id},
+ )
+ with self._stream_id_gen.get_next() as stream_ordering:
+ await self.db.runInteraction("locally_reject_invite", f, stream_ordering)
-AllNewEventsResult = namedtuple(
- "AllNewEventsResult",
- [
- "new_forward_events",
- "new_backfill_events",
- "forward_ex_outliers",
- "backward_ex_outliers",
- ],
-)
+ return stream_ordering
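
The conditional insert used by `_update_backward_extremeties` above only records a prev_event as a backward extremity when it is neither already listed nor already held as a non-outlier event. A minimal standalone sketch of that pattern against sqlite3; the cut-down schema and the event IDs are illustrative, not the real tables:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE events (event_id TEXT, room_id TEXT, outlier BOOLEAN);
        CREATE TABLE event_backward_extremities (event_id TEXT, room_id TEXT);
        """
    )

    # An event we already hold as a non-outlier must not become a backward extremity.
    conn.execute("INSERT INTO events VALUES ('$known', '!room', 0)")

    sql = (
        "INSERT INTO event_backward_extremities (event_id, room_id)"
        " SELECT ?, ? WHERE NOT EXISTS ("
        "  SELECT 1 FROM event_backward_extremities WHERE event_id = ? AND room_id = ?"
        " ) AND NOT EXISTS ("
        "  SELECT 1 FROM events WHERE event_id = ? AND room_id = ? AND outlier = ?"
        " )"
    )

    for prev_event_id in ("$known", "$missing"):
        # Seven parameters: (event_id, room_id) three times, then the outlier flag.
        conn.execute(sql, (prev_event_id, "!room") * 3 + (False,))

    # Only the event we have never seen ends up as a backward extremity.
    print(conn.execute("SELECT event_id FROM event_backward_extremities").fetchall())
    # [('$missing',)]
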
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 73df6b33..213d6910 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -27,7 +27,7 @@ from constantly import NamedConstant, Names
from twisted.internet import defer
from synapse.api.constants import EventTypes
-from synapse.api.errors import NotFoundError
+from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
@@ -37,10 +37,12 @@ from synapse.events import make_event_from_dict
from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import Database
+from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id
-from synapse.util.caches.descriptors import Cache
+from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -74,14 +76,50 @@ class EventsWorkerStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs):
super(EventsWorkerStore, self).__init__(database, db_conn, hs)
+ if hs.config.worker.writers.events == hs.get_instance_name():
+ # We are the process in charge of generating stream ids for events,
+ # so instantiate ID generators based on the database
+ self._stream_id_gen = StreamIdGenerator(
+ db_conn,
+ "events",
+ "stream_ordering",
+ extra_tables=[("local_invites", "stream_id")],
+ )
+ self._backfill_id_gen = StreamIdGenerator(
+ db_conn,
+ "events",
+ "stream_ordering",
+ step=-1,
+ extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+ )
+ else:
+ # Another process is in charge of persisting events and generating
+ # stream IDs: rely on the replication streams to let us know which
+ # IDs we can process.
+ self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
+ self._backfill_id_gen = SlavedIdTracker(
+ db_conn, "events", "stream_ordering", step=-1
+ )
+
self._get_event_cache = Cache(
- "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
+ "*getEvent*",
+ keylen=3,
+ max_entries=hs.config.caches.event_cache_size,
+ apply_cache_factor_from_config=False,
)
self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
self._event_fetch_ongoing = 0
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
+ if stream_name == "events":
+ self._stream_id_gen.advance(token)
+ elif stream_name == "backfill":
+ self._backfill_id_gen.advance(-token)
+
+ super().process_replication_rows(stream_name, instance_name, token, rows)
+
def get_received_ts(self, event_id):
"""Get received_ts (when it was persisted) for the event.
@@ -1154,4 +1192,152 @@ class EventsWorkerStore(SQLBaseStore):
rows = await self.db.runInteraction(
"get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token
)
+
return rows, to_token, True
+
+ @cached(num_args=5, max_entries=10)
+ def get_all_new_events(
+ self,
+ last_backfill_id,
+ last_forward_id,
+ current_backfill_id,
+ current_forward_id,
+ limit,
+ ):
+ """Get all the new events that have arrived at the server either as
+ new events or as backfilled events"""
+ have_backfill_events = last_backfill_id != current_backfill_id
+ have_forward_events = last_forward_id != current_forward_id
+
+ if not have_backfill_events and not have_forward_events:
+ return defer.succeed(AllNewEventsResult([], [], [], [], []))
+
+ def get_all_new_events_txn(txn):
+ sql = (
+ "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts"
+ " FROM events AS e"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " WHERE ? < stream_ordering AND stream_ordering <= ?"
+ " ORDER BY stream_ordering ASC"
+ " LIMIT ?"
+ )
+ if have_forward_events:
+ txn.execute(sql, (last_forward_id, current_forward_id, limit))
+ new_forward_events = txn.fetchall()
+
+ if len(new_forward_events) == limit:
+ upper_bound = new_forward_events[-1][0]
+ else:
+ upper_bound = current_forward_id
+
+ sql = (
+ "SELECT event_stream_ordering, event_id, state_group"
+ " FROM ex_outlier_stream"
+ " WHERE ? > event_stream_ordering"
+ " AND event_stream_ordering >= ?"
+ " ORDER BY event_stream_ordering DESC"
+ )
+ txn.execute(sql, (last_forward_id, upper_bound))
+ forward_ex_outliers = txn.fetchall()
+ else:
+ new_forward_events = []
+ forward_ex_outliers = []
+
+ sql = (
+ "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts"
+ " FROM events AS e"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " WHERE ? > stream_ordering AND stream_ordering >= ?"
+ " ORDER BY stream_ordering DESC"
+ " LIMIT ?"
+ )
+ if have_backfill_events:
+ txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
+ new_backfill_events = txn.fetchall()
+
+ if len(new_backfill_events) == limit:
+ upper_bound = new_backfill_events[-1][0]
+ else:
+ upper_bound = current_backfill_id
+
+ sql = (
+ "SELECT -event_stream_ordering, event_id, state_group"
+ " FROM ex_outlier_stream"
+ " WHERE ? > event_stream_ordering"
+ " AND event_stream_ordering >= ?"
+ " ORDER BY event_stream_ordering DESC"
+ )
+ txn.execute(sql, (-last_backfill_id, -upper_bound))
+ backward_ex_outliers = txn.fetchall()
+ else:
+ new_backfill_events = []
+ backward_ex_outliers = []
+
+ return AllNewEventsResult(
+ new_forward_events,
+ new_backfill_events,
+ forward_ex_outliers,
+ backward_ex_outliers,
+ )
+
+ return self.db.runInteraction("get_all_new_events", get_all_new_events_txn)
+
+ async def is_event_after(self, event_id1, event_id2):
+ """Returns True if event_id1 is after event_id2 in the stream
+ """
+ to_1, so_1 = await self.get_event_ordering(event_id1)
+ to_2, so_2 = await self.get_event_ordering(event_id2)
+ return (to_1, so_1) > (to_2, so_2)
+
+ @cachedInlineCallbacks(max_entries=5000)
+ def get_event_ordering(self, event_id):
+ res = yield self.db.simple_select_one(
+ table="events",
+ retcols=["topological_ordering", "stream_ordering"],
+ keyvalues={"event_id": event_id},
+ allow_none=True,
+ )
+
+ if not res:
+ raise SynapseError(404, "Could not find event %s" % (event_id,))
+
+ return (int(res["topological_ordering"]), int(res["stream_ordering"]))
+
+ def get_next_event_to_expire(self):
+ """Retrieve the entry with the lowest expiry timestamp in the event_expiry
+ table, or None if there's no more event to expire.
+
+ Returns: Deferred[Optional[Tuple[str, int]]]
+ A tuple containing the event ID as its first element and an expiry timestamp
+ as its second one, if there's at least one row in the event_expiry table.
+ None otherwise.
+ """
+
+ def get_next_event_to_expire_txn(txn):
+ txn.execute(
+ """
+ SELECT event_id, expiry_ts FROM event_expiry
+ ORDER BY expiry_ts ASC LIMIT 1
+ """
+ )
+
+ return txn.fetchone()
+
+ return self.db.runInteraction(
+ desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
+ )
+
+
+AllNewEventsResult = namedtuple(
+ "AllNewEventsResult",
+ [
+ "new_forward_events",
+ "new_backfill_events",
+ "forward_ex_outliers",
+ "backward_ex_outliers",
+ ],
+)
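
The `__init__` changes above split stream-ID handling by role: only the worker configured to write events allocates new stream orderings, while every other worker merely tracks the highest token it has seen over replication (which is what `process_replication_rows` advances). A simplified, self-contained illustration of that split; the two classes below are stand-ins, not the real StreamIdGenerator or SlavedIdTracker:

    class WriterIdGen:
        """Allocates new stream orderings; only the event writer holds one."""

        def __init__(self, current: int):
            self._current = current

        def get_next(self) -> int:
            self._current += 1
            return self._current

        def get_current_token(self) -> int:
            return self._current


    class ReaderIdTracker:
        """Follows the writer's position from tokens seen on replication."""

        def __init__(self, current: int):
            self._current = current

        def advance(self, token: int) -> None:
            # Tokens may be observed out of order; never move backwards.
            self._current = max(self._current, token)

        def get_current_token(self) -> int:
            return self._current


    writer = WriterIdGen(current=100)
    reader = ReaderIdTracker(current=100)
    token = writer.get_next()  # the writer persists an event at ordering 101
    reader.advance(token)      # the reader learns about it via replication
    assert reader.get_current_token() == writer.get_current_token() == 101
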
diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py
index 0963e6c2..fb1361f1 100644
--- a/synapse/storage/data_stores/main/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -68,24 +68,78 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_invited_users_in_group",
)
- def get_rooms_in_group(self, group_id, include_private=False):
+ def get_rooms_in_group(self, group_id: str, include_private: bool = False):
+ """Retrieve the rooms that belong to a given group. Does not return rooms that
+ lack members.
+
+ Args:
+ group_id: The ID of the group to query for rooms
+ include_private: Whether to return private rooms in results
+
+ Returns:
+ Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the
+ form of:
+
+ {
+ "room_id": "!a_room_id:example.com", # The ID of the room
+ "is_public": False # Whether this is a public room or not
+ }
+ """
# TODO: Pagination
- keyvalues = {"group_id": group_id}
- if not include_private:
- keyvalues["is_public"] = True
+ def _get_rooms_in_group_txn(txn):
+ sql = """
+ SELECT room_id, is_public FROM group_rooms
+ WHERE group_id = ?
+ AND room_id IN (
+ SELECT group_rooms.room_id FROM group_rooms
+ LEFT JOIN room_stats_current ON
+ group_rooms.room_id = room_stats_current.room_id
+ AND joined_members > 0
+ AND local_users_in_room > 0
+ LEFT JOIN rooms ON
+ group_rooms.room_id = rooms.room_id
+ AND (room_version <> '') = ?
+ )
+ """
+ args = [group_id, False]
- return self.db.simple_select_list(
- table="group_rooms",
- keyvalues=keyvalues,
- retcols=("room_id", "is_public"),
- desc="get_rooms_in_group",
- )
+ if not include_private:
+ sql += " AND is_public = ?"
+ args += [True]
+
+ txn.execute(sql, args)
+
+ return [
+ {"room_id": room_id, "is_public": is_public}
+ for room_id, is_public in txn
+ ]
- def get_rooms_for_summary_by_category(self, group_id, include_private=False):
+ return self.db.runInteraction("get_rooms_in_group", _get_rooms_in_group_txn)
+
+ def get_rooms_for_summary_by_category(
+ self, group_id: str, include_private: bool = False,
+ ):
"""Get the rooms and categories that should be included in a summary request
- Returns ([rooms], [categories])
+ Args:
+ group_id: The ID of the group to query the summary for
+ include_private: Whether to return private rooms in results
+
+ Returns:
+ Deferred[Tuple[List, Dict]]: A tuple containing:
+
+ * A list of dictionaries with the keys:
+ * "room_id": str, the room ID
+ * "is_public": bool, whether the room is public
+ * "category_id": str|None, the category ID if set, else None
+ * "order": int, the sort order of rooms
+
+ * A dictionary with the key:
+ * category_id (str): a dictionary with the keys:
+ * "is_public": bool, whether the category is public
+ * "profile": str, the category profile
+ * "order": int, the sort order of rooms in this category
"""
def _get_rooms_for_summary_txn(txn):
@@ -97,13 +151,23 @@ class GroupServerWorkerStore(SQLBaseStore):
SELECT room_id, is_public, category_id, room_order
FROM group_summary_rooms
WHERE group_id = ?
+ AND room_id IN (
+ SELECT group_rooms.room_id FROM group_rooms
+ LEFT JOIN room_stats_current ON
+ group_rooms.room_id = room_stats_current.room_id
+ AND joined_members > 0
+ AND local_users_in_room > 0
+ LEFT JOIN rooms ON
+ group_rooms.room_id = rooms.room_id
+ AND (room_version <> '') = ?
+ )
"""
if not include_private:
sql += " AND is_public = ?"
- txn.execute(sql, (group_id, True))
+ txn.execute(sql, (group_id, False, True))
else:
- txn.execute(sql, (group_id,))
+ txn.execute(sql, (group_id, False))
rooms = [
{
diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py
index ba89c68c..4e1642a2 100644
--- a/synapse/storage/data_stores/main/keys.py
+++ b/synapse/storage/data_stores/main/keys.py
@@ -17,8 +17,6 @@
import itertools
import logging
-import six
-
from signedjson.key import decode_verify_key_bytes
from synapse.storage._base import SQLBaseStore
@@ -28,12 +26,8 @@ from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
-# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
-# despite being deprecated and removed in favor of memoryview
-if six.PY2:
- db_binary_type = six.moves.builtins.buffer
-else:
- db_binary_type = memoryview
+
+db_binary_type = memoryview
class KeyStore(SQLBaseStore):
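
With Python 2 support gone, the binary column type is simply `memoryview`. A toy sqlite3 round trip showing the same idea; the single-table schema and key bytes below are a simplification for illustration:

    import sqlite3

    db_binary_type = memoryview

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE server_signature_keys (server_name TEXT, key_bytes BLOB)")

    key_bytes = b"\x01\x02\x03\x04"
    conn.execute(
        "INSERT INTO server_signature_keys VALUES (?, ?)",
        ("example.com", db_binary_type(key_bytes)),
    )

    (stored,) = conn.execute("SELECT key_bytes FROM server_signature_keys").fetchone()
    assert bytes(stored) == key_bytes  # stored and read back as raw bytes
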
diff --git a/synapse/storage/data_stores/main/metrics.py b/synapse/storage/data_stores/main/metrics.py
new file mode 100644
index 00000000..dad5bbc6
--- /dev/null
+++ b/synapse/storage/data_stores/main/metrics.py
@@ -0,0 +1,128 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import typing
+from collections import Counter
+
+from twisted.internet import defer
+
+from synapse.metrics import BucketCollector
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.event_push_actions import (
+ EventPushActionsWorkerStore,
+)
+from synapse.storage.database import Database
+
+
+class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
+    """Functions to pull various metrics from the DB, e.g. phone home
+    stats and Prometheus metrics.
+ """
+
+ def __init__(self, database: Database, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ # Collect metrics on the number of forward extremities that exist.
+ # Counter of number of extremities to count
+ self._current_forward_extremities_amount = (
+ Counter()
+ ) # type: typing.Counter[int]
+
+ BucketCollector(
+ "synapse_forward_extremities",
+ lambda: self._current_forward_extremities_amount,
+ buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"],
+ )
+
+ # Read the extrems every 60 minutes
+ def read_forward_extremities():
+ # run as a background process to make sure that the database transactions
+ # have a logcontext to report to
+ return run_as_background_process(
+ "read_forward_extremities", self._read_forward_extremities
+ )
+
+ hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000)
+
+ async def _read_forward_extremities(self):
+ def fetch(txn):
+ txn.execute(
+ """
+ select count(*) c from event_forward_extremities
+ group by room_id
+ """
+ )
+ return txn.fetchall()
+
+ res = await self.db.runInteraction("read_forward_extremities", fetch)
+ self._current_forward_extremities_amount = Counter([x[0] for x in res])
+
+ @defer.inlineCallbacks
+ def count_daily_messages(self):
+ """
+ Returns an estimate of the number of messages sent in the last day.
+
+ If it has been significantly less or more than one day since the last
+ call to this function, it will return None.
+ """
+
+ def _count_messages(txn):
+ sql = """
+ SELECT COALESCE(COUNT(*), 0) FROM events
+ WHERE type = 'm.room.message'
+ AND stream_ordering > ?
+ """
+ txn.execute(sql, (self.stream_ordering_day_ago,))
+ (count,) = txn.fetchone()
+ return count
+
+ ret = yield self.db.runInteraction("count_messages", _count_messages)
+ return ret
+
+ @defer.inlineCallbacks
+ def count_daily_sent_messages(self):
+ def _count_messages(txn):
+            # This is good enough: if you have silly characters in your own
+            # hostname then that's your own fault.
+ like_clause = "%:" + self.hs.hostname
+
+ sql = """
+ SELECT COALESCE(COUNT(*), 0) FROM events
+ WHERE type = 'm.room.message'
+ AND sender LIKE ?
+ AND stream_ordering > ?
+ """
+
+ txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
+ (count,) = txn.fetchone()
+ return count
+
+ ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages)
+ return ret
+
+ @defer.inlineCallbacks
+ def count_daily_active_rooms(self):
+ def _count(txn):
+ sql = """
+ SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+ WHERE type = 'm.room.message'
+ AND stream_ordering > ?
+ """
+ txn.execute(sql, (self.stream_ordering_day_ago,))
+ (count,) = txn.fetchone()
+ return count
+
+ ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
+ return ret
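
The extremity metric collected above is just a Counter mapping "number of forward extremities" to "number of rooms with that many"; BucketCollector then exposes it as a histogram. A rough standalone sketch of that shaping step, assuming cumulative buckets as Prometheus histograms use (BucketCollector itself is not reimplemented here):

    from collections import Counter

    # Pretend result of "select count(*) c from event_forward_extremities group by room_id"
    rows = [(3,), (1,), (7,), (1,)]
    current = Counter(x[0] for x in rows)  # two rooms with 1 extremity, one each with 3 and 7

    buckets = [1, 2, 3, 5, 7, 10]
    histogram = {
        b: sum(n for size, n in current.items() if size <= b) for b in buckets
    }
    print(histogram)  # {1: 2, 2: 2, 3: 3, 5: 3, 7: 4, 10: 4}
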
diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py
index 925bc569..e459cf49 100644
--- a/synapse/storage/data_stores/main/monthly_active_users.py
+++ b/synapse/storage/data_stores/main/monthly_active_users.py
@@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import List
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database
+from synapse.storage.database import Database, make_in_list_sql_clause
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -77,20 +78,19 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
return self.db.runInteraction("count_users_by_service", _count_users_by_service)
- @defer.inlineCallbacks
- def get_registered_reserved_users(self):
- """Of the reserved threepids defined in config, which are associated
- with registered users?
+ async def get_registered_reserved_users(self) -> List[str]:
+ """Of the reserved threepids defined in config, retrieve those that are associated
+ with registered users
Returns:
- Defered[list]: Real reserved users
+ User IDs of actual users that are reserved
"""
users = []
for tp in self.hs.config.mau_limits_reserved_threepids[
: self.hs.config.max_mau_value
]:
- user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+ user_id = await self.hs.get_datastore().get_user_id_by_threepid(
tp["medium"], tp["address"]
)
if user_id:
@@ -122,6 +122,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
def __init__(self, database: Database, db_conn, hs):
super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs)
+ self._limit_usage_by_mau = hs.config.limit_usage_by_mau
+ self._mau_stats_only = hs.config.mau_stats_only
+ self._max_mau_value = hs.config.max_mau_value
+
# Do not add more reserved users than the total allowable number
# cur = LoggingTransaction(
self.db.new_transaction(
@@ -130,7 +134,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
[],
[],
self._initialise_reserved_users,
- hs.config.mau_limits_reserved_threepids[: self.hs.config.max_mau_value],
+ hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
)
def _initialise_reserved_users(self, txn, threepids):
@@ -142,6 +146,15 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
threepids (list[dict]): List of threepid dicts to reserve
"""
+ # XXX what is this function trying to achieve? It upserts into
+ # monthly_active_users for each *registered* reserved mau user, but why?
+ #
+ # - shouldn't there already be an entry for each reserved user (at least
+ # if they have been active recently)?
+ #
+ # - if it's important that the timestamp is kept up to date, why do we only
+ # run this at startup?
+
for tp in threepids:
user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
@@ -158,13 +171,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
else:
logger.warning("mau limit reserved threepid %s not found in db" % tp)
- @defer.inlineCallbacks
- def reap_monthly_active_users(self):
+ async def reap_monthly_active_users(self):
"""Cleans out monthly active user table to ensure that no stale
entries exist.
-
- Returns:
- Deferred[]
"""
def _reap_users(txn, reserved_users):
@@ -174,76 +183,57 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"""
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
- query_args = [thirty_days_ago]
- base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?"
-
- # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
- # when len(reserved_users) == 0. Works fine on sqlite.
- if len(reserved_users) > 0:
- # questionmarks is a hack to overcome sqlite not supporting
- # tuples in 'WHERE IN %s'
- question_marks = ",".join("?" * len(reserved_users))
-
- query_args.extend(reserved_users)
- sql = base_sql + " AND user_id NOT IN ({})".format(question_marks)
- else:
- sql = base_sql
- txn.execute(sql, query_args)
+ in_clause, in_clause_args = make_in_list_sql_clause(
+ self.database_engine, "user_id", reserved_users
+ )
+
+ txn.execute(
+ "DELETE FROM monthly_active_users WHERE timestamp < ? AND NOT %s"
+ % (in_clause,),
+ [thirty_days_ago] + in_clause_args,
+ )
- max_mau_value = self.hs.config.max_mau_value
- if self.hs.config.limit_usage_by_mau:
+ if self._limit_usage_by_mau:
# If MAU user count still exceeds the MAU threshold, then delete on
# a least recently active basis.
# Note it is not possible to write this query using OFFSET due to
# incompatibilities in how sqlite and postgres support the feature.
- # sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present
- # While Postgres does not require 'LIMIT', but also does not support
+ # Sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present,
+ # while Postgres does not require 'LIMIT', but also does not support
# negative LIMIT values. So there is no way to write it that both can
# support
- if len(reserved_users) == 0:
- sql = """
- DELETE FROM monthly_active_users
- WHERE user_id NOT IN (
- SELECT user_id FROM monthly_active_users
- ORDER BY timestamp DESC
- LIMIT ?
- )
- """
- txn.execute(sql, (max_mau_value,))
- # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
- # when len(reserved_users) == 0. Works fine on sqlite.
- else:
- # Must be >= 0 for postgres
- num_of_non_reserved_users_to_remove = max(
- max_mau_value - len(reserved_users), 0
- )
- # It is important to filter reserved users twice to guard
- # against the case where the reserved user is present in the
- # SELECT, meaning that a legitmate mau is deleted.
- sql = """
- DELETE FROM monthly_active_users
- WHERE user_id NOT IN (
- SELECT user_id FROM monthly_active_users
- WHERE user_id NOT IN ({})
- ORDER BY timestamp DESC
- LIMIT ?
- )
- AND user_id NOT IN ({})
- """.format(
- question_marks, question_marks
+ # Limit must be >= 0 for postgres
+ num_of_non_reserved_users_to_remove = max(
+ self._max_mau_value - len(reserved_users), 0
+ )
+
+ # It is important to filter reserved users twice to guard
+ # against the case where the reserved user is present in the
+ # SELECT, meaning that a legitimate mau is deleted.
+ sql = """
+ DELETE FROM monthly_active_users
+ WHERE user_id NOT IN (
+ SELECT user_id FROM monthly_active_users
+ WHERE NOT %s
+ ORDER BY timestamp DESC
+ LIMIT ?
)
-
- query_args = [
- *reserved_users,
- num_of_non_reserved_users_to_remove,
- *reserved_users,
- ]
-
- txn.execute(sql, query_args)
-
- # It seems poor to invalidate the whole cache, Postgres supports
+ AND NOT %s
+ """ % (
+ in_clause,
+ in_clause,
+ )
+
+ query_args = (
+ in_clause_args
+ + [num_of_non_reserved_users_to_remove]
+ + in_clause_args
+ )
+ txn.execute(sql, query_args)
+
+ # It seems poor to invalidate the whole cache. Postgres supports
# 'Returning' which would allow me to invalidate only the
# specific users, but sqlite has no way to do this and instead
# I would need to SELECT and the DELETE which without locking
@@ -255,8 +245,8 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
)
self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
- reserved_users = yield self.get_registered_reserved_users()
- yield self.db.runInteraction(
+ reserved_users = await self.get_registered_reserved_users()
+ await self.db.runInteraction(
"reap_monthly_active_users", _reap_users, reserved_users
)
@@ -267,6 +257,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
Args:
user_id (str): user to add/update
+
+ Returns:
+ Deferred
"""
# Support user never to be included in MAU stats. Note I can't easily call this
# from upsert_monthly_active_user_txn because then I need a _txn form of
@@ -335,7 +328,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
Args:
user_id(str): the user_id to query
"""
- if self.hs.config.limit_usage_by_mau or self.hs.config.mau_stats_only:
+ if self._limit_usage_by_mau or self._mau_stats_only:
# Trial users and guests should not be included as part of MAU group
is_guest = yield self.is_guest(user_id)
if is_guest:
@@ -356,11 +349,11 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
# In the case where mau_stats_only is True and limit_usage_by_mau is
# False, there is no point in checking get_monthly_active_count - it
# adds no value and will break the logic if max_mau_value is exceeded.
- if not self.hs.config.limit_usage_by_mau:
+ if not self._limit_usage_by_mau:
yield self.upsert_monthly_active_user(user_id)
else:
count = yield self.get_monthly_active_count()
- if count < self.hs.config.max_mau_value:
+ if count < self._max_mau_value:
yield self.upsert_monthly_active_user(user_id)
elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
yield self.upsert_monthly_active_user(user_id)
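
The reaping rewrite above leans on `make_in_list_sql_clause`, which is why the old `len(reserved_users) == 0` special cases disappear. A simplified stand-in (not the real helper from synapse.storage.database) showing the idea: the generated clause is always a valid boolean expression, even for an empty list:

    from typing import Iterable, List, Tuple

    def make_in_list_clause(column: str, values: Iterable[str]) -> Tuple[str, List[str]]:
        values = list(values)
        if not values:
            # "user_id IN ()" is not valid SQL everywhere, so fall back to an
            # always-false expression that still composes with NOT.
            return "1 = 0", []
        placeholders = ", ".join("?" for _ in values)
        return "%s IN (%s)" % (column, placeholders), values

    in_clause, args = make_in_list_clause("user_id", [])
    sql = "DELETE FROM monthly_active_users WHERE timestamp < ? AND NOT %s" % (in_clause,)
    print(sql)             # ... AND NOT 1 = 0, i.e. no user is spared when none are reserved
    print([12345] + args)
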
diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py
index 2b52cf9c..bfc9369f 100644
--- a/synapse/storage/data_stores/main/profile.py
+++ b/synapse/storage/data_stores/main/profile.py
@@ -110,7 +110,7 @@ class ProfileStore(ProfileWorkerStore):
return self.db.simple_update(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
- values={
+ updatevalues={
"displayname": displayname,
"avatar_url": avatar_url,
"last_check": self._clock.time_msec(),
diff --git a/synapse/storage/data_stores/main/purge_events.py b/synapse/storage/data_stores/main/purge_events.py
new file mode 100644
index 00000000..a93e1ef1
--- /dev/null
+++ b/synapse/storage/data_stores/main/purge_events.py
@@ -0,0 +1,399 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Any, Tuple
+
+from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.state import StateGroupWorkerStore
+from synapse.types import RoomStreamToken
+
+logger = logging.getLogger(__name__)
+
+
+class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
+ def purge_history(self, room_id, token, delete_local_events):
+ """Deletes room history before a certain point
+
+ Args:
+ room_id (str):
+
+ token (str): A topological token to delete events before
+
+ delete_local_events (bool):
+ if True, we will delete local events as well as remote ones
+ (instead of just marking them as outliers and deleting their
+ state groups).
+
+ Returns:
+ Deferred[set[int]]: The set of state groups that are referenced by
+ deleted events.
+ """
+
+ return self.db.runInteraction(
+ "purge_history",
+ self._purge_history_txn,
+ room_id,
+ token,
+ delete_local_events,
+ )
+
+ def _purge_history_txn(self, txn, room_id, token_str, delete_local_events):
+ token = RoomStreamToken.parse(token_str)
+
+ # Tables that should be pruned:
+ # event_auth
+ # event_backward_extremities
+ # event_edges
+ # event_forward_extremities
+ # event_json
+ # event_push_actions
+ # event_reference_hashes
+ # event_search
+ # event_to_state_groups
+ # events
+ # rejections
+ # room_depth
+ # state_groups
+ # state_groups_state
+
+ # we will build a temporary table listing the events so that we don't
+ # have to keep shovelling the list back and forth across the
+ # connection. Annoyingly the python sqlite driver commits the
+ # transaction on CREATE, so let's do this first.
+ #
+ # furthermore, we might already have the table from a previous (failed)
+ # purge attempt, so let's drop the table first.
+
+ txn.execute("DROP TABLE IF EXISTS events_to_purge")
+
+ txn.execute(
+ "CREATE TEMPORARY TABLE events_to_purge ("
+ " event_id TEXT NOT NULL,"
+ " should_delete BOOLEAN NOT NULL"
+ ")"
+ )
+
+        # First ensure that we're not about to delete all the forward extremities
+ txn.execute(
+ "SELECT e.event_id, e.depth FROM events as e "
+ "INNER JOIN event_forward_extremities as f "
+ "ON e.event_id = f.event_id "
+ "AND e.room_id = f.room_id "
+ "WHERE f.room_id = ?",
+ (room_id,),
+ )
+ rows = txn.fetchall()
+ max_depth = max(row[1] for row in rows)
+
+ if max_depth < token.topological:
+            # We need to ensure we don't delete all the events from the database,
+            # otherwise we wouldn't be able to send any events (due to not
+            # having any backward extremities).
+ raise SynapseError(
+ 400, "topological_ordering is greater than forward extremeties"
+ )
+
+ logger.info("[purge] looking for events to delete")
+
+ should_delete_expr = "state_key IS NULL"
+ should_delete_params = () # type: Tuple[Any, ...]
+ if not delete_local_events:
+ should_delete_expr += " AND event_id NOT LIKE ?"
+
+ # We include the parameter twice since we use the expression twice
+ should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname)
+
+ should_delete_params += (room_id, token.topological)
+
+ # Note that we insert events that are outliers and aren't going to be
+ # deleted, as nothing will happen to them.
+ txn.execute(
+ "INSERT INTO events_to_purge"
+ " SELECT event_id, %s"
+ " FROM events AS e LEFT JOIN state_events USING (event_id)"
+ " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?"
+ % (should_delete_expr, should_delete_expr),
+ should_delete_params,
+ )
+
+ # We create the indices *after* insertion as that's a lot faster.
+
+ # create an index on should_delete because later we'll be looking for
+ # the should_delete / shouldn't_delete subsets
+ txn.execute(
+ "CREATE INDEX events_to_purge_should_delete"
+ " ON events_to_purge(should_delete)"
+ )
+
+ # We do joins against events_to_purge for e.g. calculating state
+ # groups to purge, etc., so lets make an index.
+ txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)")
+
+ txn.execute("SELECT event_id, should_delete FROM events_to_purge")
+ event_rows = txn.fetchall()
+ logger.info(
+ "[purge] found %i events before cutoff, of which %i can be deleted",
+ len(event_rows),
+ sum(1 for e in event_rows if e[1]),
+ )
+
+ logger.info("[purge] Finding new backward extremities")
+
+        # We calculate the new entries for the backward extremities by finding
+ # events to be purged that are pointed to by events we're not going to
+ # purge.
+ txn.execute(
+ "SELECT DISTINCT e.event_id FROM events_to_purge AS e"
+ " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
+ " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id"
+ " WHERE ep2.event_id IS NULL"
+ )
+ new_backwards_extrems = txn.fetchall()
+
+ logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
+
+ txn.execute(
+ "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,)
+ )
+
+        # Update backward extremities
+ txn.executemany(
+ "INSERT INTO event_backward_extremities (room_id, event_id)"
+ " VALUES (?, ?)",
+ [(room_id, event_id) for event_id, in new_backwards_extrems],
+ )
+
+ logger.info("[purge] finding state groups referenced by deleted events")
+
+ # Get all state groups that are referenced by events that are to be
+ # deleted.
+ txn.execute(
+ """
+ SELECT DISTINCT state_group FROM events_to_purge
+ INNER JOIN event_to_state_groups USING (event_id)
+ """
+ )
+
+ referenced_state_groups = {sg for sg, in txn}
+ logger.info(
+ "[purge] found %i referenced state groups", len(referenced_state_groups)
+ )
+
+ logger.info("[purge] removing events from event_to_state_groups")
+ txn.execute(
+ "DELETE FROM event_to_state_groups "
+ "WHERE event_id IN (SELECT event_id from events_to_purge)"
+ )
+ for event_id, _ in event_rows:
+ txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
+
+ # Delete all remote non-state events
+ for table in (
+ "events",
+ "event_json",
+ "event_auth",
+ "event_edges",
+ "event_forward_extremities",
+ "event_reference_hashes",
+ "event_search",
+ "rejections",
+ ):
+ logger.info("[purge] removing events from %s", table)
+
+ txn.execute(
+ "DELETE FROM %s WHERE event_id IN ("
+ " SELECT event_id FROM events_to_purge WHERE should_delete"
+ ")" % (table,)
+ )
+
+ # event_push_actions lacks an index on event_id, and has one on
+ # (room_id, event_id) instead.
+ for table in ("event_push_actions",):
+ logger.info("[purge] removing events from %s", table)
+
+ txn.execute(
+ "DELETE FROM %s WHERE room_id = ? AND event_id IN ("
+ " SELECT event_id FROM events_to_purge WHERE should_delete"
+ ")" % (table,),
+ (room_id,),
+ )
+
+ # Mark all state and own events as outliers
+ logger.info("[purge] marking remaining events as outliers")
+ txn.execute(
+ "UPDATE events SET outlier = ?"
+ " WHERE event_id IN ("
+ " SELECT event_id FROM events_to_purge "
+ " WHERE NOT should_delete"
+ ")",
+ (True,),
+ )
+
+ # synapse tries to take out an exclusive lock on room_depth whenever it
+ # persists events (because upsert), and once we run this update, we
+ # will block that for the rest of our transaction.
+ #
+ # So, let's stick it at the end so that we don't block event
+ # persistence.
+ #
+ # We do this by calculating the minimum depth of the backwards
+ # extremities. However, the events in event_backward_extremities
+ # are ones we don't have yet so we need to look at the events that
+        # point to them via the event_edges table.
+ txn.execute(
+ """
+ SELECT COALESCE(MIN(depth), 0)
+ FROM event_backward_extremities AS eb
+ INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id
+ INNER JOIN events AS e ON e.event_id = eg.event_id
+ WHERE eb.room_id = ?
+ """,
+ (room_id,),
+ )
+ (min_depth,) = txn.fetchone()
+
+ logger.info("[purge] updating room_depth to %d", min_depth)
+
+ txn.execute(
+ "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
+ (min_depth, room_id),
+ )
+
+ # finally, drop the temp table. this will commit the txn in sqlite,
+ # so make sure to keep this actually last.
+ txn.execute("DROP TABLE events_to_purge")
+
+ logger.info("[purge] done")
+
+ return referenced_state_groups
+
+ def purge_room(self, room_id):
+ """Deletes all record of a room
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[List[int]]: The list of state groups to delete.
+ """
+
+ return self.db.runInteraction("purge_room", self._purge_room_txn, room_id)
+
+ def _purge_room_txn(self, txn, room_id):
+ # First we fetch all the state groups that should be deleted, before
+ # we delete that information.
+ txn.execute(
+ """
+ SELECT DISTINCT state_group FROM events
+ INNER JOIN event_to_state_groups USING(event_id)
+ WHERE events.room_id = ?
+ """,
+ (room_id,),
+ )
+
+ state_groups = [row[0] for row in txn]
+
+ # Now we delete tables which lack an index on room_id but have one on event_id
+ for table in (
+ "event_auth",
+ "event_edges",
+ "event_push_actions_staging",
+ "event_reference_hashes",
+ "event_relations",
+ "event_to_state_groups",
+ "redactions",
+ "rejections",
+ "state_events",
+ ):
+ logger.info("[purge] removing %s from %s", room_id, table)
+
+ txn.execute(
+ """
+ DELETE FROM %s WHERE event_id IN (
+ SELECT event_id FROM events WHERE room_id=?
+ )
+ """
+ % (table,),
+ (room_id,),
+ )
+
+ # and finally, the tables with an index on room_id (or no useful index)
+ for table in (
+ "current_state_events",
+ "event_backward_extremities",
+ "event_forward_extremities",
+ "event_json",
+ "event_push_actions",
+ "event_search",
+ "events",
+ "group_rooms",
+ "public_room_list_stream",
+ "receipts_graph",
+ "receipts_linearized",
+ "room_aliases",
+ "room_depth",
+ "room_memberships",
+ "room_stats_state",
+ "room_stats_current",
+ "room_stats_historical",
+ "room_stats_earliest_token",
+ "rooms",
+ "stream_ordering_to_exterm",
+ "users_in_public_rooms",
+ "users_who_share_private_rooms",
+ # no useful index, but let's clear them anyway
+ "appservice_room_list",
+ "e2e_room_keys",
+ "event_push_summary",
+ "pusher_throttle",
+ "group_summary_rooms",
+ "local_invites",
+ "room_account_data",
+ "room_tags",
+ "local_current_membership",
+ ):
+ logger.info("[purge] removing %s from %s", room_id, table)
+ txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
+
+ # Other tables we do NOT need to clear out:
+ #
+ # - blocked_rooms
+ # This is important, to make sure that we don't accidentally rejoin a blocked
+ # room after it was purged
+ #
+ # - user_directory
+ # This has a room_id column, but it is unused
+ #
+
+ # Other tables that we might want to consider clearing out include:
+ #
+ # - event_reports
+ # Given that these are intended for abuse management my initial
+ # inclination is to leave them in place.
+ #
+ # - current_state_delta_stream
+ # - ex_outlier_stream
+ # - room_tags_revisions
+ # The problem with these is that they are largeish and there is no room_id
+ # index on them. In any case we should be clearing out 'stream' tables
+ # periodically anyway (#5888)
+
+ # TODO: we could probably usefully do a bunch of cache invalidation here
+
+ logger.info("[purge] done")
+
+ return state_groups
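
Under the hood, `_purge_history_txn` above first partitions everything before the cutoff into rows it may delete outright (remote, non-state events) and rows it must keep and merely mark as outliers. A toy sqlite3 run of that selection; the cut-down schema, hostname and event IDs are purely illustrative:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE events (event_id TEXT, room_id TEXT, outlier BOOLEAN,
                             topological_ordering INTEGER);
        CREATE TABLE state_events (event_id TEXT, state_key TEXT);
        CREATE TEMPORARY TABLE events_to_purge (event_id TEXT, should_delete BOOLEAN);
        """
    )
    conn.executemany(
        "INSERT INTO events VALUES (?, ?, 0, ?)",
        [
            ("$remote_msg:other.com", "!r", 1),
            ("$local_msg:example.com", "!r", 2),
            ("$state:other.com", "!r", 3),
        ],
    )
    conn.execute("INSERT INTO state_events VALUES ('$state:other.com', '')")

    # delete_local_events=False: keep anything sent from our own server.
    should_delete_expr = "state_key IS NULL AND event_id NOT LIKE ?"
    params = ("%:example.com", "%:example.com", "!r", 10)

    conn.execute(
        "INSERT INTO events_to_purge"
        " SELECT event_id, %s"
        " FROM events AS e LEFT JOIN state_events USING (event_id)"
        " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?"
        % (should_delete_expr, should_delete_expr),
        params,
    )

    print(conn.execute("SELECT * FROM events_to_purge ORDER BY event_id").fetchall())
    # [('$local_msg:example.com', 0), ('$remote_msg:other.com', 1), ('$state:other.com', 0)]
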
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index b3faafa0..ef8f4095 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -16,19 +16,23 @@
import abc
import logging
+from typing import Union
from canonicaljson import json
from twisted.internet import defer
from synapse.push.baserules import list_with_base_rules
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.pusher import PusherWorkerStore
from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
from synapse.storage.database import Database
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
+from synapse.storage.util.id_generators import ChainedIdGenerator
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -64,6 +68,7 @@ class PushRulesWorkerStore(
ReceiptsWorkerStore,
PusherWorkerStore,
RoomMemberWorkerStore,
+ EventsWorkerStore,
SQLBaseStore,
):
"""This is an abstract base class where subclasses must implement
@@ -77,6 +82,15 @@ class PushRulesWorkerStore(
def __init__(self, database: Database, db_conn, hs):
super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
+ if hs.config.worker.worker_app is None:
+ self._push_rules_stream_id_gen = ChainedIdGenerator(
+ self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
+ ) # type: Union[ChainedIdGenerator, SlavedIdTracker]
+ else:
+ self._push_rules_stream_id_gen = SlavedIdTracker(
+ db_conn, "push_rules_stream", "stream_id"
+ )
+
push_rules_prefill, push_rules_id = self.db.get_cache_dict(
db_conn,
"push_rules_stream",
diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py
index 0d932a06..cebdcd40 100644
--- a/synapse/storage/data_stores/main/receipts.py
+++ b/synapse/storage/data_stores/main/receipts.py
@@ -391,7 +391,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
(user_id, room_id, receipt_type),
)
- self.db.simple_delete_txn(
+ self.db.simple_upsert_txn(
txn,
table="receipts_linearized",
keyvalues={
@@ -399,19 +399,14 @@ class ReceiptsStore(ReceiptsWorkerStore):
"receipt_type": receipt_type,
"user_id": user_id,
},
- )
-
- self.db.simple_insert_txn(
- txn,
- table="receipts_linearized",
values={
"stream_id": stream_id,
- "room_id": room_id,
- "receipt_type": receipt_type,
- "user_id": user_id,
"event_id": event_id,
"data": json.dumps(data),
},
+ # receipts_linearized has a unique constraint on
+ # (user_id, room_id, receipt_type), so no need to lock
+ lock=False,
)
if receipt_type == "m.read" and stream_ordering is not None:
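
The delete-plus-insert above collapses into a single upsert, and lock=False is safe because receipts_linearized has a unique constraint on (user_id, room_id, receipt_type). On Postgres, an upsert like this can typically resolve to a native INSERT ... ON CONFLICT statement; a rough, illustrative sketch of what that looks like (the exact SQL Synapse generates may differ, and the placeholders are illustrative):

    # Illustrative only: approximately the statement a native upsert resolves to.
    UPSERT_SQL = """
    INSERT INTO receipts_linearized
        (room_id, receipt_type, user_id, stream_id, event_id, data)
    VALUES (?, ?, ?, ?, ?, ?)
    ON CONFLICT (user_id, room_id, receipt_type)
    DO UPDATE SET stream_id = EXCLUDED.stream_id,
                  event_id  = EXCLUDED.event_id,
                  data      = EXCLUDED.data
    """
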
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index efcdd210..97689818 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -17,6 +17,7 @@
import logging
import re
+from typing import Optional
from six import iterkeys
@@ -342,7 +343,7 @@ class RegistrationWorkerStore(SQLBaseStore):
)
return res
- @cachedInlineCallbacks()
+ @cached()
def is_support_user(self, user_id):
"""Determines if the user is of type UserTypes.SUPPORT
@@ -352,10 +353,9 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred[bool]: True if user is of type UserTypes.SUPPORT
"""
- res = yield self.db.runInteraction(
+ return self.db.runInteraction(
"is_support_user", self.is_support_user_txn, user_id
)
- return res
def is_real_user_txn(self, txn, user_id):
res = self.db.simple_select_one_onecol_txn(
@@ -516,18 +516,17 @@ class RegistrationWorkerStore(SQLBaseStore):
)
)
- @defer.inlineCallbacks
- def get_user_id_by_threepid(self, medium, address):
+ async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]:
"""Returns user id from threepid
Args:
- medium (str): threepid medium e.g. email
- address (str): threepid address e.g. me@example.com
+ medium: threepid medium e.g. email
+ address: threepid address e.g. me@example.com
Returns:
- Deferred[str|None]: user id or None if no user id/threepid mapping exists
+ The user ID or None if no user id/threepid mapping exists
"""
- user_id = yield self.db.runInteraction(
+ user_id = await self.db.runInteraction(
"get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
)
return user_id
@@ -993,7 +992,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
Args:
user_id (str): The desired user ID to register.
- password_hash (str): Optional. The password hash for this user.
+ password_hash (str|None): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
make_guest (boolean): True if the the new user should be guest,
@@ -1007,6 +1006,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
Raises:
StoreError if the user_id could not be registered.
+
+ Returns:
+ Deferred
"""
return self.db.runInteraction(
"register_user",
diff --git a/synapse/storage/data_stores/main/rejections.py b/synapse/storage/data_stores/main/rejections.py
index 1c07c7a4..27e5a208 100644
--- a/synapse/storage/data_stores/main/rejections.py
+++ b/synapse/storage/data_stores/main/rejections.py
@@ -21,17 +21,6 @@ logger = logging.getLogger(__name__)
class RejectionsStore(SQLBaseStore):
- def _store_rejections_txn(self, txn, event_id, reason):
- self.db.simple_insert_txn(
- txn,
- table="rejections",
- values={
- "event_id": event_id,
- "reason": reason,
- "last_check": self._clock.time_msec(),
- },
- )
-
def get_rejection_reason(self, event_id):
return self.db.simple_select_one_onecol(
table="rejections",
diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py
index 046c2b48..7d477f8d 100644
--- a/synapse/storage/data_stores/main/relations.py
+++ b/synapse/storage/data_stores/main/relations.py
@@ -324,62 +324,4 @@ class RelationsWorkerStore(SQLBaseStore):
class RelationsStore(RelationsWorkerStore):
- def _handle_event_relations(self, txn, event):
- """Handles inserting relation data during peristence of events
-
- Args:
- txn
- event (EventBase)
- """
- relation = event.content.get("m.relates_to")
- if not relation:
- # No relations
- return
-
- rel_type = relation.get("rel_type")
- if rel_type not in (
- RelationTypes.ANNOTATION,
- RelationTypes.REFERENCE,
- RelationTypes.REPLACE,
- ):
- # Unknown relation type
- return
-
- parent_id = relation.get("event_id")
- if not parent_id:
- # Invalid relation
- return
-
- aggregation_key = relation.get("key")
-
- self.db.simple_insert_txn(
- txn,
- table="event_relations",
- values={
- "event_id": event.event_id,
- "relates_to_id": parent_id,
- "relation_type": rel_type,
- "aggregation_key": aggregation_key,
- },
- )
-
- txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,))
- txn.call_after(
- self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
- )
-
- if rel_type == RelationTypes.REPLACE:
- txn.call_after(self.get_applicable_edit.invalidate, (parent_id,))
-
- def _handle_redaction(self, txn, redacted_event_id):
- """Handles receiving a redaction and checking whether we need to remove
- any redacted relations from the database.
-
- Args:
- txn
- redacted_event_id (str): The event that was redacted.
- """
-
- self.db.simple_delete_txn(
- txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
- )
+ pass
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index 147eba1d..46f643c6 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -21,8 +21,6 @@ from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
-from six import integer_types
-
from canonicaljson import json
from twisted.internet import defer
@@ -98,6 +96,37 @@ class RoomWorkerStore(SQLBaseStore):
allow_none=True,
)
+ def get_room_with_stats(self, room_id: str):
+ """Retrieve room with statistics.
+
+ Args:
+ room_id: The ID of the room to retrieve.
+ Returns:
+ A dict containing the room information, or None if the room is unknown.
+ """
+
+ def get_room_with_stats_txn(txn, room_id):
+ sql = """
+ SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
+ curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
+ rooms.creator, state.encryption, state.is_federatable AS federatable,
+ rooms.is_public AS public, state.join_rules, state.guest_access,
+ state.history_visibility, curr.current_state_events AS state_events
+ FROM rooms
+ LEFT JOIN room_stats_state state USING (room_id)
+ LEFT JOIN room_stats_current curr USING (room_id)
+ WHERE room_id = ?
+ """
+ txn.execute(sql, [room_id])
+ res = self.db.cursor_to_dict(txn)[0]
+ res["federatable"] = bool(res["federatable"])
+ res["public"] = bool(res["public"])
+ return res
+
+ return self.db.runInteraction(
+ "get_room_with_stats", get_room_with_stats_txn, room_id
+ )
+
def get_public_room_ids(self):
return self.db.simple_select_onecol(
table="rooms",
@@ -1271,53 +1300,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
return self.db.runInteraction("get_rooms", f)
- def _store_room_topic_txn(self, txn, event):
- if hasattr(event, "content") and "topic" in event.content:
- self.store_event_search_txn(
- txn, event, "content.topic", event.content["topic"]
- )
-
- def _store_room_name_txn(self, txn, event):
- if hasattr(event, "content") and "name" in event.content:
- self.store_event_search_txn(
- txn, event, "content.name", event.content["name"]
- )
-
- def _store_room_message_txn(self, txn, event):
- if hasattr(event, "content") and "body" in event.content:
- self.store_event_search_txn(
- txn, event, "content.body", event.content["body"]
- )
-
- def _store_retention_policy_for_room_txn(self, txn, event):
- if hasattr(event, "content") and (
- "min_lifetime" in event.content or "max_lifetime" in event.content
- ):
- if (
- "min_lifetime" in event.content
- and not isinstance(event.content.get("min_lifetime"), integer_types)
- ) or (
- "max_lifetime" in event.content
- and not isinstance(event.content.get("max_lifetime"), integer_types)
- ):
- # Ignore the event if one of the value isn't an integer.
- return
-
- self.db.simple_insert_txn(
- txn=txn,
- table="room_retention",
- values={
- "room_id": event.room_id,
- "event_id": event.event_id,
- "min_lifetime": event.content.get("min_lifetime"),
- "max_lifetime": event.content.get("max_lifetime"),
- },
- )
-
- self._invalidate_cache_and_stream(
- txn, self.get_retention_policy_for_room, (event.room_id,)
- )
-
def add_event_report(
self, room_id, event_id, user_id, reason, content, received_ts
):
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index e626b7f6..137ebac8 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -45,7 +45,6 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.metrics import Measure
-from synapse.util.stringutils import to_ascii
logger = logging.getLogger(__name__)
@@ -153,16 +152,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
self._check_safe_current_state_events_membership_updated_txn,
)
- @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
- def get_hosts_in_room(self, room_id, cache_context):
- """Returns the set of all hosts currently in the room
- """
- user_ids = yield self.get_users_in_room(
- room_id, on_invalidate=cache_context.invalidate
- )
- hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
- return hosts
-
@cached(max_entries=100000, iterable=True)
def get_users_in_room(self, room_id):
return self.db.runInteraction(
@@ -189,7 +178,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""
txn.execute(sql, (room_id, Membership.JOIN))
- return [to_ascii(r[0]) for r in txn]
+ return [r[0] for r in txn]
@cached(max_entries=100000)
def get_room_summary(self, room_id):
@@ -233,7 +222,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, (room_id,))
res = {}
for count, membership in txn:
- summary = res.setdefault(to_ascii(membership), MemberSummary([], count))
+ summary = res.setdefault(membership, MemberSummary([], count))
# we order by membership and then fairly arbitrarily by event_id so
# heroes are consistent
@@ -265,11 +254,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
for user_id, membership, event_id in txn:
- summary = res[to_ascii(membership)]
+ summary = res[membership]
# we will always have a summary for this membership type at this
# point given the summary currently contains the counts.
members = summary.members
- members.append((to_ascii(user_id), to_ascii(event_id)))
+ members.append((user_id, event_id))
return res
@@ -594,13 +583,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
ev_entry = event_map.get(event_id)
if ev_entry:
if ev_entry.event.membership == Membership.JOIN:
- users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo(
- display_name=to_ascii(
- ev_entry.event.content.get("displayname", None)
- ),
- avatar_url=to_ascii(
- ev_entry.event.content.get("avatar_url", None)
- ),
+ users_in_room[ev_entry.event.state_key] = ProfileInfo(
+ display_name=ev_entry.event.content.get("displayname", None),
+ avatar_url=ev_entry.event.content.get("avatar_url", None),
)
else:
missing_member_event_ids.append(event_id)
@@ -614,9 +599,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
if event is not None and event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
if event.event_id in member_event_ids:
- users_in_room[to_ascii(event.state_key)] = ProfileInfo(
- display_name=to_ascii(event.content.get("displayname", None)),
- avatar_url=to_ascii(event.content.get("avatar_url", None)),
+ users_in_room[event.state_key] = ProfileInfo(
+ display_name=event.content.get("displayname", None),
+ avatar_url=event.content.get("avatar_url", None),
)
return users_in_room
@@ -1061,119 +1046,6 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, database: Database, db_conn, hs):
super(RoomMemberStore, self).__init__(database, db_conn, hs)
- def _store_room_members_txn(self, txn, events, backfilled):
- """Store a room member in the database.
- """
- self.db.simple_insert_many_txn(
- txn,
- table="room_memberships",
- values=[
- {
- "event_id": event.event_id,
- "user_id": event.state_key,
- "sender": event.user_id,
- "room_id": event.room_id,
- "membership": event.membership,
- "display_name": event.content.get("displayname", None),
- "avatar_url": event.content.get("avatar_url", None),
- }
- for event in events
- ],
- )
-
- for event in events:
- txn.call_after(
- self._membership_stream_cache.entity_has_changed,
- event.state_key,
- event.internal_metadata.stream_ordering,
- )
- txn.call_after(
- self.get_invited_rooms_for_local_user.invalidate, (event.state_key,)
- )
-
- # We update the local_invites table only if the event is "current",
- # i.e., its something that has just happened. If the event is an
- # outlier it is only current if its an "out of band membership",
- # like a remote invite or a rejection of a remote invite.
- is_new_state = not backfilled and (
- not event.internal_metadata.is_outlier()
- or event.internal_metadata.is_out_of_band_membership()
- )
- is_mine = self.hs.is_mine_id(event.state_key)
- if is_new_state and is_mine:
- if event.membership == Membership.INVITE:
- self.db.simple_insert_txn(
- txn,
- table="local_invites",
- values={
- "event_id": event.event_id,
- "invitee": event.state_key,
- "inviter": event.sender,
- "room_id": event.room_id,
- "stream_id": event.internal_metadata.stream_ordering,
- },
- )
- else:
- sql = (
- "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
- " room_id = ? AND invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
-
- txn.execute(
- sql,
- (
- event.internal_metadata.stream_ordering,
- event.event_id,
- event.room_id,
- event.state_key,
- ),
- )
-
- # We also update the `local_current_membership` table with
- # latest invite info. This will usually get updated by the
- # `current_state_events` handling, unless its an outlier.
- if event.internal_metadata.is_outlier():
- # This should only happen for out of band memberships, so
- # we add a paranoia check.
- assert event.internal_metadata.is_out_of_band_membership()
-
- self.db.simple_upsert_txn(
- txn,
- table="local_current_membership",
- keyvalues={
- "room_id": event.room_id,
- "user_id": event.state_key,
- },
- values={
- "event_id": event.event_id,
- "membership": event.membership,
- },
- )
-
- @defer.inlineCallbacks
- def locally_reject_invite(self, user_id, room_id):
- sql = (
- "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
- " room_id = ? AND invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
-
- def f(txn, stream_ordering):
- txn.execute(sql, (stream_ordering, True, room_id, user_id))
-
- # We also clear this entry from `local_current_membership`.
- # Ideally we'd point to a leave event, but we don't have one, so
- # nevermind.
- self.db.simple_delete_txn(
- txn,
- table="local_current_membership",
- keyvalues={"room_id": room_id, "user_id": user_id},
- )
-
- with self._stream_id_gen.get_next() as stream_ordering:
- yield self.db.runInteraction("locally_reject_invite", f, stream_ordering)
-
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
diff --git a/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres
new file mode 100644
index 00000000..aa46eb0e
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres
@@ -0,0 +1,30 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We keep the old table here to enable us to roll back. It doesn't matter
+-- that we have dropped all the data here.
+TRUNCATE cache_invalidation_stream;
+
+CREATE TABLE cache_invalidation_stream_by_instance (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ cache_func TEXT NOT NULL,
+ keys TEXT[],
+ invalidation_ts BIGINT
+);
+
+CREATE UNIQUE INDEX cache_invalidation_stream_by_instance_id ON cache_invalidation_stream_by_instance(stream_id);
+
+CREATE SEQUENCE cache_invalidation_stream_seq;
diff --git a/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py b/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py
new file mode 100644
index 00000000..d353f2bc
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py
@@ -0,0 +1,80 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This migration rebuilds the device_lists_outbound_last_success table without duplicate
+entries, and with a UNIQUE index.
+"""
+
+import logging
+from io import StringIO
+
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.prepare_database import execute_statements_from_stream
+from synapse.storage.types import Cursor
+
+logger = logging.getLogger(__name__)
+
+
+def run_upgrade(*args, **kwargs):
+ pass
+
+
+def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
+ # some instances might already have this index, in which case we can skip this
+ if isinstance(database_engine, PostgresEngine):
+ cur.execute(
+ """
+ SELECT 1 FROM pg_class WHERE relkind = 'i'
+ AND relname = 'device_lists_outbound_last_success_unique_idx'
+ """
+ )
+
+ if cur.rowcount:
+ logger.info(
+ "Unique index exists on device_lists_outbound_last_success: "
+ "skipping rebuild"
+ )
+ return
+
+ logger.info("Rebuilding device_lists_outbound_last_success with unique index")
+ execute_statements_from_stream(cur, StringIO(_rebuild_commands))
+
+
+# there might be duplicates, so the easiest way to achieve this is to create a new
+# table with the right data, and rename it into place
+
+_rebuild_commands = """
+DROP TABLE IF EXISTS device_lists_outbound_last_success_new;
+
+CREATE TABLE device_lists_outbound_last_success_new (
+ destination TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ stream_id BIGINT NOT NULL
+);
+
+-- this took about 30 seconds on matrix.org's 16 million rows.
+INSERT INTO device_lists_outbound_last_success_new
+ SELECT destination, user_id, MAX(stream_id) FROM device_lists_outbound_last_success
+ GROUP BY destination, user_id;
+
+-- and this another 30 seconds.
+CREATE UNIQUE INDEX device_lists_outbound_last_success_unique_idx
+ ON device_lists_outbound_last_success_new (destination, user_id);
+
+DROP TABLE device_lists_outbound_last_success;
+
+ALTER TABLE device_lists_outbound_last_success_new
+ RENAME TO device_lists_outbound_last_success;
+"""
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index 47ebb8a2..13f49d80 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -37,7 +37,55 @@ SearchEntry = namedtuple(
)
-class SearchBackgroundUpdateStore(SQLBaseStore):
+class SearchWorkerStore(SQLBaseStore):
+ def store_search_entries_txn(self, txn, entries):
+ """Add entries to the search table
+
+ Args:
+ txn (cursor):
+ entries (iterable[SearchEntry]):
+ entries to be added to the table
+ """
+ if not self.hs.config.enable_search:
+ return
+ if isinstance(self.database_engine, PostgresEngine):
+ sql = (
+ "INSERT INTO event_search"
+ " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
+ " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
+ )
+
+ args = (
+ (
+ entry.event_id,
+ entry.room_id,
+ entry.key,
+ entry.value,
+ entry.stream_ordering,
+ entry.origin_server_ts,
+ )
+ for entry in entries
+ )
+
+ txn.executemany(sql, args)
+
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ sql = (
+ "INSERT INTO event_search (event_id, room_id, key, value)"
+ " VALUES (?,?,?,?)"
+ )
+ args = (
+ (entry.event_id, entry.room_id, entry.key, entry.value)
+ for entry in entries
+ )
+
+ txn.executemany(sql, args)
+ else:
+ # This should be unreachable.
+ raise Exception("Unrecognized database engine")
+
+
+class SearchBackgroundUpdateStore(SearchWorkerStore):
EVENT_SEARCH_UPDATE_NAME = "event_search"
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
@@ -296,80 +344,11 @@ class SearchBackgroundUpdateStore(SQLBaseStore):
return num_rows
- def store_search_entries_txn(self, txn, entries):
- """Add entries to the search table
-
- Args:
- txn (cursor):
- entries (iterable[SearchEntry]):
- entries to be added to the table
- """
- if not self.hs.config.enable_search:
- return
- if isinstance(self.database_engine, PostgresEngine):
- sql = (
- "INSERT INTO event_search"
- " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
- " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
- )
-
- args = (
- (
- entry.event_id,
- entry.room_id,
- entry.key,
- entry.value,
- entry.stream_ordering,
- entry.origin_server_ts,
- )
- for entry in entries
- )
-
- txn.executemany(sql, args)
-
- elif isinstance(self.database_engine, Sqlite3Engine):
- sql = (
- "INSERT INTO event_search (event_id, room_id, key, value)"
- " VALUES (?,?,?,?)"
- )
- args = (
- (entry.event_id, entry.room_id, entry.key, entry.value)
- for entry in entries
- )
-
- txn.executemany(sql, args)
- else:
- # This should be unreachable.
- raise Exception("Unrecognized database engine")
-
class SearchStore(SearchBackgroundUpdateStore):
def __init__(self, database: Database, db_conn, hs):
super(SearchStore, self).__init__(database, db_conn, hs)
- def store_event_search_txn(self, txn, event, key, value):
- """Add event to the search table
-
- Args:
- txn (cursor):
- event (EventBase):
- key (str):
- value (str):
- """
- self.store_search_entries_txn(
- txn,
- (
- SearchEntry(
- key=key,
- value=value,
- event_id=event.event_id,
- room_id=event.room_id,
- stream_ordering=event.internal_metadata.stream_ordering,
- origin_server_ts=event.origin_server_ts,
- ),
- ),
- )
-
@defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.
diff --git a/synapse/storage/data_stores/main/signatures.py b/synapse/storage/data_stores/main/signatures.py
index 563216b6..36244d9f 100644
--- a/synapse/storage/data_stores/main/signatures.py
+++ b/synapse/storage/data_stores/main/signatures.py
@@ -13,23 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import six
-
from unpaddedbase64 import encode_base64
from twisted.internet import defer
-from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
-# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
-# despite being deprecated and removed in favor of memoryview
-if six.PY2:
- db_binary_type = six.moves.builtins.buffer
-else:
- db_binary_type = memoryview
-
class SignatureWorkerStore(SQLBaseStore):
@cached()
@@ -79,23 +69,3 @@ class SignatureWorkerStore(SQLBaseStore):
class SignatureStore(SignatureWorkerStore):
"""Persistence for event signatures and hashes"""
-
- def _store_event_reference_hashes_txn(self, txn, events):
- """Store a hash for a PDU
- Args:
- txn (cursor):
- events (list): list of Events.
- """
-
- vals = []
- for event in events:
- ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
- vals.append(
- {
- "event_id": event.event_id,
- "algorithm": ref_alg,
- "hash": db_binary_type(ref_hash_bytes),
- }
- )
-
- self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 3a3b9a8e..347cc507 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -16,17 +16,12 @@
import collections.abc
import logging
from collections import namedtuple
-from typing import Iterable, Tuple
-
-from six import iteritems
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
-from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
@@ -34,7 +29,6 @@ from synapse.storage.database import Database
from synapse.storage.state import StateFilter
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
-from synapse.util.stringutils import to_ascii
logger = logging.getLogger(__name__)
@@ -190,9 +184,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
(room_id,),
)
- return {
- (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
- }
+ return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
return self.db.runInteraction(
"get_current_state_ids", _get_current_state_ids_txn
@@ -473,33 +465,3 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
def __init__(self, database: Database, db_conn, hs):
super(StateStore, self).__init__(database, db_conn, hs)
-
- def _store_event_state_mappings_txn(
- self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
- ):
- state_groups = {}
- for event, context in events_and_contexts:
- if event.internal_metadata.is_outlier():
- continue
-
- # if the event was rejected, just give it the same state as its
- # predecessor.
- if context.rejected:
- state_groups[event.event_id] = context.state_group_before_event
- continue
-
- state_groups[event.event_id] = context.state_group
-
- self.db.simple_insert_many_txn(
- txn,
- table="event_to_state_groups",
- values=[
- {"state_group": state_group_id, "event_id": event_id}
- for event_id, state_group_id in iteritems(state_groups)
- ],
- )
-
- for event_id, state_group_id in iteritems(state_groups):
- txn.call_after(
- self._get_state_group_for_event.prefill, (event_id,), state_group_id
- )
diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py
index 2aa1bafd..42190183 100644
--- a/synapse/storage/data_stores/main/tags.py
+++ b/synapse/storage/data_stores/main/tags.py
@@ -233,6 +233,9 @@ class TagsStore(TagsWorkerStore):
self._account_data_stream_cache.entity_has_changed, user_id, next_id
)
+ # Note: This is only here for backwards compat to allow admins to
+ # roll back to a previous Synapse version. Next time we update the
+ # database version we can remove this table.
update_max_id_sql = (
"UPDATE account_data_max_stream_id"
" SET stream_id = ?"
diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py
index 5b07c2fb..a9bf4579 100644
--- a/synapse/storage/data_stores/main/transactions.py
+++ b/synapse/storage/data_stores/main/transactions.py
@@ -16,8 +16,6 @@
import logging
from collections import namedtuple
-import six
-
from canonicaljson import encode_canonical_json
from twisted.internet import defer
@@ -27,12 +25,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import Database
from synapse.util.caches.expiringcache import ExpiringCache
-# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
-# despite being deprecated and removed in favor of memoryview
-if six.PY2:
- db_binary_type = six.moves.builtins.buffer
-else:
- db_binary_type = memoryview
+db_binary_type = memoryview
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/data_stores/state/bg_updates.py b/synapse/storage/data_stores/state/bg_updates.py
index e8edaf9f..ff000bc9 100644
--- a/synapse/storage/data_stores/state/bg_updates.py
+++ b/synapse/storage/data_stores/state/bg_updates.py
@@ -109,20 +109,20 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
- SELECT DISTINCT type, state_key, last_value(event_id) OVER (
- PARTITION BY type, state_key ORDER BY state_group ASC
- ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
- ) AS event_id FROM state_groups_state
+ SELECT DISTINCT ON (type, state_key)
+ type, state_key, event_id
+ FROM state_groups_state
WHERE state_group IN (
SELECT state_group FROM state
- )
+ ) %s
+ ORDER BY type, state_key, state_group DESC
"""
for group in groups:
args = [group]
args.extend(where_args)
- txn.execute(sql + where_clause, args)
+ txn.execute(sql % (where_clause,), args)
for row in txn:
typ, state_key, event_id = row
key = (typ, state_key)
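
The rewritten query leans on Postgres DISTINCT ON: ordered by type, state_key, state_group DESC, it keeps one row per (type, state_key) pair, namely the one from the highest state group, which is what the previous window-function formulation computed. The selection rule itself, sketched in plain Python on made-up rows:

    from itertools import groupby

    # (type, state_key, state_group, event_id) -- illustrative data only
    rows = [
        ("m.room.member", "@a:hs", 10, "$old"),
        ("m.room.member", "@a:hs", 12, "$new"),
        ("m.room.name", "", 11, "$name"),
    ]

    # ORDER BY type, state_key, state_group DESC
    rows.sort(key=lambda r: (r[0], r[1], -r[2]))

    # DISTINCT ON (type, state_key): the first row of each group wins
    latest = {key: next(group)[3] for key, group in groupby(rows, key=lambda r: (r[0], r[1]))}
    print(latest)  # {('m.room.member', '@a:hs'): '$new', ('m.room.name', ''): '$name'}
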
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
index 57a52676..f3ad1e43 100644
--- a/synapse/storage/data_stores/state/store.py
+++ b/synapse/storage/data_stores/state/store.py
@@ -28,7 +28,6 @@ from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateSt
from synapse.storage.database import Database
from synapse.storage.state import StateFilter
from synapse.types import StateMap
-from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
@@ -90,11 +89,10 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
self._state_group_cache = DictionaryCache(
"*stateGroupCache*",
# TODO: this hasn't been tuned yet
- 50000 * get_cache_factor_for("stateGroupCache"),
+ 50000,
)
self._state_group_members_cache = DictionaryCache(
- "*stateGroupMembersCache*",
- 500000 * get_cache_factor_for("stateGroupMembersCache"),
+ "*stateGroupMembersCache*", 500000,
)
@cached(max_entries=10000, iterable=True)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 50f475bf..b112ff3d 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -49,7 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
-from synapse.util.stringutils import exception_to_unicode
+from synapse.types import Collection
logger = logging.getLogger(__name__)
@@ -422,20 +422,14 @@ class Database(object):
# This can happen if the database disappears mid
# transaction.
logger.warning(
- "[TXN OPERROR] {%s} %s %d/%d",
- name,
- exception_to_unicode(e),
- i,
- N,
+ "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N,
)
if i < N:
i += 1
try:
conn.rollback()
except self.engine.module.Error as e1:
- logger.warning(
- "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
- )
+ logger.warning("[TXN EROLL] {%s} %s", name, e1)
continue
raise
except self.engine.module.DatabaseError as e:
@@ -447,9 +441,7 @@ class Database(object):
conn.rollback()
except self.engine.module.Error as e1:
logger.warning(
- "[TXN EROLL] {%s} %s",
- name,
- exception_to_unicode(e1),
+ "[TXN EROLL] {%s} %s", name, e1,
)
continue
raise
@@ -889,20 +881,24 @@ class Database(object):
txn.execute(sql, list(allvalues.values()))
def simple_upsert_many_txn(
- self, txn, table, key_names, key_values, value_names, value_values
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ key_names: Collection[str],
+ key_values: Collection[Iterable[Any]],
+ value_names: Collection[str],
+ value_values: Iterable[Iterable[str]],
+ ) -> None:
"""
Upsert, many times.
Args:
- table (str): The table to upsert into
- key_names (list[str]): The key column names.
- key_values (list[list]): A list of each row's key column values.
- value_names (list[str]): The value column names. If empty, no
- values will be used, even if value_values is provided.
- value_values (list[list]): A list of each row's value column values.
- Returns:
- None
+ table: The table to upsert into
+ key_names: The key column names.
+ key_values: A list of each row's key column values.
+ value_names: The value column names
+ value_values: A list of each row's value column values.
+ Ignored if value_names is empty.
"""
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
return self.simple_upsert_many_txn_native_upsert(
@@ -914,20 +910,24 @@ class Database(object):
)
def simple_upsert_many_txn_emulated(
- self, txn, table, key_names, key_values, value_names, value_values
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ key_names: Iterable[str],
+ key_values: Collection[Iterable[Any]],
+ value_names: Collection[str],
+ value_values: Iterable[Iterable[str]],
+ ) -> None:
"""
Upsert, many times, but without native UPSERT support or batching.
Args:
- table (str): The table to upsert into
- key_names (list[str]): The key column names.
- key_values (list[list]): A list of each row's key column values.
- value_names (list[str]): The value column names. If empty, no
- values will be used, even if value_values is provided.
- value_values (list[list]): A list of each row's value column values.
- Returns:
- None
+ table: The table to upsert into
+ key_names: The key column names.
+ key_values: A list of each row's key column values.
+ value_names: The value column names
+ value_values: A list of each row's value column values.
+ Ignored if value_names is empty.
"""
# No value columns, therefore make a blank list so that the following
# zip() works correctly.
@@ -941,20 +941,24 @@ class Database(object):
self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
def simple_upsert_many_txn_native_upsert(
- self, txn, table, key_names, key_values, value_names, value_values
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ key_names: Collection[str],
+ key_values: Collection[Iterable[Any]],
+ value_names: Collection[str],
+ value_values: Iterable[Iterable[Any]],
+ ) -> None:
"""
Upsert, many times, using batching where possible.
Args:
- table (str): The table to upsert into
- key_names (list[str]): The key column names.
- key_values (list[list]): A list of each row's key column values.
- value_names (list[str]): The value column names. If empty, no
- values will be used, even if value_values is provided.
- value_values (list[list]): A list of each row's value column values.
- Returns:
- None
+ table: The table to upsert into
+ key_names: The key column names.
+ key_values: A list of each row's key column values.
+ value_names: The value column names
+ value_values: A list of each row's value column values.
+ Ignored if value_names is empty.
"""
allnames = [] # type: List[str]
allnames.extend(key_names)
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 0f9ac1cf..f159400a 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -23,7 +23,6 @@ from typing import Iterable, List, Optional, Set, Tuple
from six import iteritems
from six.moves import range
-import attr
from prometheus_client import Counter, Histogram
from twisted.internet import defer
@@ -35,6 +34,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateResolutionStore
from synapse.storage.data_stores import DataStores
+from synapse.storage.data_stores.main.events import DeltaState
from synapse.types import StateMap
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.metrics import Measure
@@ -73,22 +73,6 @@ stale_forward_extremities_counter = Histogram(
)
-@attr.s(slots=True)
-class DeltaState:
- """Deltas to use to update the `current_state_events` table.
-
- Attributes:
- to_delete: List of type/state_keys to delete from current state
- to_insert: Map of state to upsert into current state
- no_longer_in_room: The server is not longer in the room, so the room
- should e.g. be removed from `current_state_events` table.
- """
-
- to_delete = attr.ib(type=List[Tuple[str, str]])
- to_insert = attr.ib(type=StateMap[str])
- no_longer_in_room = attr.ib(type=bool, default=False)
-
-
class _EventPeristenceQueue(object):
"""Queues up events so that they can be persisted in bulk with only one
concurrent transaction per room.
@@ -205,6 +189,7 @@ class EventsPersistenceStorage(object):
# store for now.
self.main_store = stores.main
self.state_store = stores.state
+ self.persist_events_store = stores.persist_events
self._clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id
@@ -445,7 +430,7 @@ class EventsPersistenceStorage(object):
if current_state is not None:
current_state_for_room[room_id] = current_state
- await self.main_store._persist_events_and_state_updates(
+ await self.persist_events_store._persist_events_and_state_updates(
chunk,
current_state_for_room=current_state_for_room,
state_delta_for_room=state_delta_for_room,
@@ -491,13 +476,15 @@ class EventsPersistenceStorage(object):
)
# Remove any events which are prev_events of any existing events.
- existing_prevs = await self.main_store._get_events_which_are_prevs(result)
+ existing_prevs = await self.persist_events_store._get_events_which_are_prevs(
+ result
+ )
result.difference_update(existing_prevs)
# Finally handle the case where the new events have soft-failed prev
# events. If they do we need to remove them and their prev events,
# otherwise we end up with dangling extremities.
- existing_prevs = await self.main_store._get_prevs_before_rejected(
+ existing_prevs = await self.persist_events_store._get_prevs_before_rejected(
e_id for event in new_events for e_id in event.prev_event_ids()
)
result.difference_update(existing_prevs)
@@ -753,8 +740,8 @@ class EventsPersistenceStorage(object):
# whose state has changed as we've already their new state above.
users_to_ignore = [
state_key
- for _, state_key in itertools.chain(delta.to_insert, delta.to_delete)
- if self.is_mine_id(state_key)
+ for typ, state_key in itertools.chain(delta.to_insert, delta.to_delete)
+ if typ == EventTypes.Member and self.is_mine_id(state_key)
]
if await self.main_store.is_local_host_in_room_ignoring_users(
@@ -799,3 +786,9 @@ class EventsPersistenceStorage(object):
for user_id in left_users:
await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id)
+
+ async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
+ """Mark the invite has having been rejected even though we failed to
+ create a leave event for it.
+ """
+ return await self.persist_events_store.locally_reject_invite(user_id, room_id)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 1712932f..9cc3b51f 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -19,16 +19,21 @@ import logging
import os
import re
from collections import Counter
+from typing import TextIO
import attr
from synapse.storage.engines.postgres import PostgresEngine
+from synapse.storage.types import Cursor
logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
+# XXX: If you're about to bump this to 59 (or higher) please create an update
+# that drops the unused `cache_invalidation_stream` table, as per #7436!
+# XXX: Also add an update to drop `account_data_max_stream_id` as per #7656!
SCHEMA_VERSION = 58
dir_path = os.path.abspath(os.path.dirname(__file__))
@@ -362,9 +367,8 @@ def _upgrade_existing_database(
if duplicates:
# We don't support using the same file name in the same delta version.
raise PrepareDatabaseException(
- "Found multiple delta files with the same name in v%d: %s",
- v,
- duplicates,
+ "Found multiple delta files with the same name in v%d: %s"
+ % (v, duplicates,)
)
# We sort to ensure that we apply the delta files in a consistent
@@ -477,8 +481,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
)
logger.info("applying schema %s for %s", name, modname)
- for statement in get_statements(stream):
- cur.execute(statement)
+ execute_statements_from_stream(cur, stream)
# Mark as done.
cur.execute(
@@ -536,8 +539,12 @@ def get_statements(f):
def executescript(txn, schema_path):
with open(schema_path, "r") as f:
- for statement in get_statements(f):
- txn.execute(statement)
+ execute_statements_from_stream(txn, f)
+
+
+def execute_statements_from_stream(cur: Cursor, f: TextIO):
+ for statement in get_statements(f):
+ cur.execute(statement)
def _get_or_create_schema_state(txn, database_engine):
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 9d851bea..f89ce0be 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -16,6 +16,11 @@
import contextlib
import threading
from collections import deque
+from typing import Dict, Set, Tuple
+
+from typing_extensions import Deque
+
+from synapse.storage.database import Database, LoggingTransaction
class IdGenerator(object):
@@ -87,7 +92,7 @@ class StreamIdGenerator(object):
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
)
- self._unfinished_ids = deque()
+ self._unfinished_ids = deque() # type: Deque[int]
def get_next(self):
"""
@@ -161,9 +166,10 @@ class ChainedIdGenerator(object):
def __init__(self, chained_generator, db_conn, table, column):
self.chained_generator = chained_generator
+ self._table = table
self._lock = threading.Lock()
self._current_max = _load_current_id(db_conn, table, column)
- self._unfinished_ids = deque()
+ self._unfinished_ids = deque() # type: Deque[Tuple[int, int]]
def get_next(self):
"""
@@ -198,3 +204,173 @@ class ChainedIdGenerator(object):
return stream_id - 1, chained_id
return self._current_max, self.chained_generator.get_current_token()
+
+ def advance(self, token: int):
+ """Stub implementation for advancing the token when receiving updates
+ over replication; raises an exception as this instance should be the
+ only source of updates.
+ """
+
+ raise Exception(
+ "Attempted to advance token on source for table %r", self._table
+ )
+
+
+class MultiWriterIdGenerator:
+ """An ID generator that tracks a stream that can have multiple writers.
+
+ Uses a Postgres sequence to coordinate ID assignment, but positions of other
+ writers will only get updated when `advance` is called (by replication).
+
+ Note: Only works with Postgres.
+
+ Args:
+ db_conn
+ db
+ instance_name: The name of this instance.
+ table: Database table associated with stream.
+ instance_column: Column that stores the row's writer's instance name
+ id_column: Column that stores the stream ID.
+ sequence_name: The name of the postgres sequence used to generate new
+ IDs.
+ """
+
+ def __init__(
+ self,
+ db_conn,
+ db: Database,
+ instance_name: str,
+ table: str,
+ instance_column: str,
+ id_column: str,
+ sequence_name: str,
+ ):
+ self._db = db
+ self._instance_name = instance_name
+ self._sequence_name = sequence_name
+
+ # We lock as some functions may be called from DB threads.
+ self._lock = threading.Lock()
+
+ self._current_positions = self._load_current_ids(
+ db_conn, table, instance_column, id_column
+ )
+
+ # Set of local IDs that we're still processing. The current position
+ # should be less than the minimum of this set (if not empty).
+ self._unfinished_ids = set() # type: Set[int]
+
+ def _load_current_ids(
+ self, db_conn, table: str, instance_column: str, id_column: str
+ ) -> Dict[str, int]:
+ sql = """
+ SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
+ GROUP BY %(instance)s
+ """ % {
+ "instance": instance_column,
+ "id": id_column,
+ "table": table,
+ }
+
+ cur = db_conn.cursor()
+ cur.execute(sql)
+
+ # `cur` is an iterable over returned rows, which are 2-tuples.
+ current_positions = dict(cur)
+
+ cur.close()
+
+ return current_positions
+
+ def _load_next_id_txn(self, txn):
+ txn.execute("SELECT nextval(?)", (self._sequence_name,))
+ (next_id,) = txn.fetchone()
+ return next_id
+
+ async def get_next(self):
+ """
+ Usage:
+ with await stream_id_gen.get_next() as stream_id:
+ # ... persist event ...
+ """
+ next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
+
+ # Assert the fetched ID is actually greater than what we currently
+ # believe the ID to be. If not, then the sequence and table have got
+ # out of sync somehow.
+ assert self.get_current_token() < next_id
+
+ with self._lock:
+ self._unfinished_ids.add(next_id)
+
+ @contextlib.contextmanager
+ def manager():
+ try:
+ yield next_id
+ finally:
+ self._mark_id_as_finished(next_id)
+
+ return manager()
+
+ def get_next_txn(self, txn: LoggingTransaction):
+ """
+ Usage:
+
+ stream_id = stream_id_gen.get_next(txn)
+ # ... persist event ...
+ """
+
+ next_id = self._load_next_id_txn(txn)
+
+ with self._lock:
+ self._unfinished_ids.add(next_id)
+
+ txn.call_after(self._mark_id_as_finished, next_id)
+ txn.call_on_exception(self._mark_id_as_finished, next_id)
+
+ return next_id
+
+ def _mark_id_as_finished(self, next_id: int):
+ """The ID has finished being processed so we should advance the
+ current position if possible.
+ """
+
+ with self._lock:
+ self._unfinished_ids.discard(next_id)
+
+ # Figure out if it's safe to advance the position by checking there
+ # aren't any lower allocated IDs that are yet to finish.
+ if all(c > next_id for c in self._unfinished_ids):
+ curr = self._current_positions.get(self._instance_name, 0)
+ self._current_positions[self._instance_name] = max(curr, next_id)
+
+ def get_current_token(self, instance_name: str = None) -> int:
+ """Gets the current position of a named writer (defaults to current
+ instance).
+
+ Returns 0 if we don't have a position for the named writer (likely due
+ to it being a new writer).
+ """
+
+ if instance_name is None:
+ instance_name = self._instance_name
+
+ with self._lock:
+ return self._current_positions.get(instance_name, 0)
+
+ def get_positions(self) -> Dict[str, int]:
+ """Get a copy of the current positon map.
+ """
+
+ with self._lock:
+ return dict(self._current_positions)
+
+ def advance(self, instance_name: str, new_id: int):
+ """Advance the postion of the named writer to the given ID, if greater
+ than existing entry.
+ """
+
+ with self._lock:
+ self._current_positions[instance_name] = max(
+ new_id, self._current_positions.get(instance_name, 0)
+ )
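
A writer's position only moves forward once every lower allocated ID has finished, so other processes never read past an ID that is still in flight. The rule in isolation, as a minimal sketch (the names here are illustrative, not the class above):

    unfinished = set()
    position = 0

    def allocate(next_id):
        unfinished.add(next_id)

    def finish(next_id):
        global position
        unfinished.discard(next_id)
        # only advance when no lower ID is still being processed
        if all(other > next_id for other in unfinished):
            position = max(position, next_id)

    allocate(1)
    allocate(2)
    finish(2)
    print(position)  # 0 -- ID 1 is still outstanding, so we cannot advance yet
    finish(1)
    print(position)  # 1 -- advances now that nothing lower remains unfinished
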
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 581dffd8..f7af2bca 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -225,6 +225,18 @@ class Linearizer(object):
{}
) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
+ def is_queued(self, key) -> bool:
+ """Checks whether there is a process queued up waiting
+ """
+ entry = self.key_to_defer.get(key)
+ if not entry:
+ # No entry so nothing is waiting.
+ return False
+
+ # There are waiting deferreds only if the OrderedDict of deferreds is
+ # non-empty.
+ return bool(entry[1])
+
def queue(self, key):
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
# (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index da5077b4..dd356bf1 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019, 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,28 +15,16 @@
# limitations under the License.
import logging
-import os
-from typing import Dict
+from sys import intern
+from typing import Callable, Dict, Optional
-import six
-from six.moves import intern
+import attr
+from prometheus_client.core import Gauge
-from prometheus_client.core import REGISTRY, Gauge, GaugeMetricFamily
+from synapse.config.cache import add_resizable_cache
logger = logging.getLogger(__name__)
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5))
-
-
-def get_cache_factor_for(cache_name):
- env_var = "SYNAPSE_CACHE_FACTOR_" + cache_name.upper()
- factor = os.environ.get(env_var)
- if factor:
- return float(factor)
-
- return CACHE_SIZE_FACTOR
-
-
caches_by_name = {}
collectors_by_name = {} # type: Dict
@@ -44,6 +32,7 @@ cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name"])
cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"])
+cache_max_size = Gauge("synapse_util_caches_cache_max_size", "", ["name"])
response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"])
response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"])
@@ -53,67 +42,82 @@ response_cache_evicted = Gauge(
response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"])
-def register_cache(cache_type, cache_name, cache, collect_callback=None):
- """Register a cache object for metric collection.
+@attr.s
+class CacheMetric(object):
+
+ _cache = attr.ib()
+ _cache_type = attr.ib(type=str)
+ _cache_name = attr.ib(type=str)
+ _collect_callback = attr.ib(type=Optional[Callable])
+
+ hits = attr.ib(default=0)
+ misses = attr.ib(default=0)
+ evicted_size = attr.ib(default=0)
+
+ def inc_hits(self):
+ self.hits += 1
+
+ def inc_misses(self):
+ self.misses += 1
+
+ def inc_evictions(self, size=1):
+ self.evicted_size += size
+
+ def describe(self):
+ return []
+
+ def collect(self):
+ try:
+ if self._cache_type == "response_cache":
+ response_cache_size.labels(self._cache_name).set(len(self._cache))
+ response_cache_hits.labels(self._cache_name).set(self.hits)
+ response_cache_evicted.labels(self._cache_name).set(self.evicted_size)
+ response_cache_total.labels(self._cache_name).set(
+ self.hits + self.misses
+ )
+ else:
+ cache_size.labels(self._cache_name).set(len(self._cache))
+ cache_hits.labels(self._cache_name).set(self.hits)
+ cache_evicted.labels(self._cache_name).set(self.evicted_size)
+ cache_total.labels(self._cache_name).set(self.hits + self.misses)
+ if getattr(self._cache, "max_size", None):
+ cache_max_size.labels(self._cache_name).set(self._cache.max_size)
+ if self._collect_callback:
+ self._collect_callback()
+ except Exception as e:
+ logger.warning("Error calculating metrics for %s: %s", self._cache_name, e)
+ raise
+
+
+def register_cache(
+ cache_type: str,
+ cache_name: str,
+ cache,
+ collect_callback: Optional[Callable] = None,
+ resizable: bool = True,
+ resize_callback: Optional[Callable] = None,
+) -> CacheMetric:
+ """Register a cache object for metric collection and resizing.
Args:
- cache_type (str):
- cache_name (str): name of the cache
- cache (object): cache itself
- collect_callback (callable|None): if not None, a function which is called during
- metric collection to update additional metrics.
+ cache_type
+ cache_name: name of the cache
+ cache: cache itself
+ collect_callback: If given, a function which is called during metric
+ collection to update additional metrics.
+ resizable: Whether this cache supports being resized.
+ resize_callback: A function which can be called to resize the cache.
Returns:
CacheMetric: an object which provides inc_{hits,misses,evictions} methods
"""
+ if resizable:
+ if not resize_callback:
+ resize_callback = getattr(cache, "set_cache_factor")
+ add_resizable_cache(cache_name, resize_callback)
- # Check if the metric is already registered. Unregister it, if so.
- # This usually happens during tests, as at runtime these caches are
- # effectively singletons.
+ metric = CacheMetric(cache, cache_type, cache_name, collect_callback)
metric_name = "cache_%s_%s" % (cache_type, cache_name)
- if metric_name in collectors_by_name.keys():
- REGISTRY.unregister(collectors_by_name[metric_name])
-
- class CacheMetric(object):
-
- hits = 0
- misses = 0
- evicted_size = 0
-
- def inc_hits(self):
- self.hits += 1
-
- def inc_misses(self):
- self.misses += 1
-
- def inc_evictions(self, size=1):
- self.evicted_size += size
-
- def describe(self):
- return []
-
- def collect(self):
- try:
- if cache_type == "response_cache":
- response_cache_size.labels(cache_name).set(len(cache))
- response_cache_hits.labels(cache_name).set(self.hits)
- response_cache_evicted.labels(cache_name).set(self.evicted_size)
- response_cache_total.labels(cache_name).set(self.hits + self.misses)
- else:
- cache_size.labels(cache_name).set(len(cache))
- cache_hits.labels(cache_name).set(self.hits)
- cache_evicted.labels(cache_name).set(self.evicted_size)
- cache_total.labels(cache_name).set(self.hits + self.misses)
- if collect_callback:
- collect_callback()
- except Exception as e:
- logger.warning("Error calculating metrics for %s: %s", cache_name, e)
- raise
-
- yield GaugeMetricFamily("__unused", "")
-
- metric = CacheMetric()
- REGISTRY.register(metric)
caches_by_name[cache_name] = cache
collectors_by_name[metric_name] = metric
return metric
@@ -148,9 +152,6 @@ def intern_string(string):
return None
try:
- if six.PY2:
- string = string.encode("ascii")
-
return intern(string)
except UnicodeEncodeError:
return string
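
register_cache now hooks each cache into the cache-factor machinery: unless told otherwise it takes the cache's own set_cache_factor method and hands it to add_resizable_cache, so a config-driven resize can fan out to every registered cache. A toy sketch of that callback pattern (the registry and cache class here are stand-ins, not Synapse's):

    _resize_callbacks = {}  # stand-in for the real registry in synapse.config.cache

    def add_resizable_cache(cache_name, resize_callback):
        _resize_callbacks[cache_name] = resize_callback

    class ToyCache:
        def __init__(self, base_max_size):
            self.base_max_size = base_max_size
            self.max_size = base_max_size

        def set_cache_factor(self, factor):
            # scale the configured base size by the factor, as the real caches do
            self.max_size = int(self.base_max_size * factor)

    cache = ToyCache(1000)
    add_resizable_cache("toy_cache", cache.set_cache_factor)

    # later, when the configured cache factor changes:
    for resize in _resize_callbacks.values():
        resize(0.5)
    print(cache.max_size)  # 500
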
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 2e8f6543..cd482624 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import functools
import inspect
import logging
@@ -30,7 +31,6 @@ from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
@@ -81,7 +81,6 @@ class CacheEntry(object):
class Cache(object):
__slots__ = (
"cache",
- "max_entries",
"name",
"keylen",
"thread",
@@ -89,7 +88,29 @@ class Cache(object):
"_pending_deferred_cache",
)
- def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False):
+ def __init__(
+ self,
+ name: str,
+ max_entries: int = 1000,
+ keylen: int = 1,
+ tree: bool = False,
+ iterable: bool = False,
+ apply_cache_factor_from_config: bool = True,
+ ):
+ """
+ Args:
+ name: The name of the cache
+ max_entries: Maximum number of entries that the cache will hold
+ keylen: The length of the tuple used as the cache key
+ tree: Use a TreeCache instead of a dict as the underlying cache type
+ iterable: If True, count each item in the cached object as an entry,
+ rather than each cached object
+ apply_cache_factor_from_config: Whether cache factors specified in the
+ config file affect `max_entries`
+
+ Returns:
+ Cache
+ """
cache_type = TreeCache if tree else dict
self._pending_deferred_cache = cache_type()
@@ -99,6 +120,7 @@ class Cache(object):
cache_type=cache_type,
size_callback=(lambda d: len(d)) if iterable else None,
evicted_callback=self._on_evicted,
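+ # Tell the underlying LruCache whether the global cache factor should
+ # scale max_entries for this cache.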
+ apply_cache_factor_from_config=apply_cache_factor_from_config,
)
self.name = name
@@ -111,6 +133,10 @@ class Cache(object):
collect_callback=self._metrics_collection_callback,
)
+ @property
+ def max_entries(self):
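+ # Expose the underlying LruCache's current max_size, which may have been
+ # rescaled by a cache factor.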
+ return self.cache.max_size
+
def _on_evicted(self, evicted_count):
self.metrics.inc_evictions(evicted_count)
@@ -370,13 +396,11 @@ class CacheDescriptor(_CacheDescriptorBase):
cache_context=cache_context,
)
- max_entries = int(max_entries * get_cache_factor_for(orig.__name__))
-
self.max_entries = max_entries
self.tree = tree
self.iterable = iterable
- def __get__(self, obj, objtype=None):
+ def __get__(self, obj, owner):
cache = Cache(
name=self.orig.__name__,
max_entries=self.max_entries,
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index cddf1ed5..2726b67b 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -18,6 +18,7 @@ from collections import OrderedDict
from six import iteritems, itervalues
+from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches import register_cache
@@ -51,15 +52,16 @@ class ExpiringCache(object):
an item on access. Defaults to False.
iterable (bool): If true, the size is calculated by summing the
sizes of all entries, rather than the number of entries.
-
"""
self._cache_name = cache_name
+ self._original_max_size = max_len
+
+ self._max_size = int(max_len * cache_config.properties.default_factor_size)
+
self._clock = clock
- self._max_len = max_len
self._expiry_ms = expiry_ms
-
self._reset_expiry_on_get = reset_expiry_on_get
self._cache = OrderedDict()
@@ -82,9 +84,11 @@ class ExpiringCache(object):
def __setitem__(self, key, value):
now = self._clock.time_msec()
self._cache[key] = _CacheEntry(now, value)
+ self.evict()
+ def evict(self):
# Evict if there are now too many items
- while self._max_len and len(self) > self._max_len:
+ while self._max_size and len(self) > self._max_size:
_key, value = self._cache.popitem(last=False)
if self.iterable:
self.metrics.inc_evictions(len(value.value))
@@ -170,6 +174,23 @@ class ExpiringCache(object):
else:
return len(self._cache)
+ def set_cache_factor(self, factor: float) -> bool:
+ """
+ Set the cache factor for this individual cache.
+
+ This will trigger a resize if it changes, which may require evicting
+ items from the cache.
+
+ Returns:
+ bool: Whether the cache changed size or not.
+ """
+ new_size = int(self._original_max_size * factor)
+ if new_size != self._max_size:
+ self._max_size = new_size
+ self.evict()
+ return True
+ return False
+
class _CacheEntry(object):
__slots__ = ["time", "value"]
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 1536cb64..df4ea590 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import threading
from functools import wraps
+from typing import Callable, Optional, Type, Union
+from synapse.config import cache as cache_config
from synapse.util.caches.treecache import TreeCache
@@ -52,17 +53,18 @@ class LruCache(object):
def __init__(
self,
- max_size,
- keylen=1,
- cache_type=dict,
- size_callback=None,
- evicted_callback=None,
+ max_size: int,
+ keylen: int = 1,
+ cache_type: Type[Union[dict, TreeCache]] = dict,
+ size_callback: Optional[Callable] = None,
+ evicted_callback: Optional[Callable] = None,
+ apply_cache_factor_from_config: bool = True,
):
"""
Args:
- max_size (int):
+ max_size: The maximum number of entries the cache can hold
- keylen (int):
+ keylen: The length of the tuple used as the cache key
cache_type (type):
type of underlying cache to be used. Typically one of dict
@@ -73,9 +75,24 @@ class LruCache(object):
evicted_callback (func(int)|None):
if not None, called on eviction with the size of the evicted
entry
+
+ apply_cache_factor_from_config (bool): If true, `max_size` will be
+ multiplied by a cache factor derived from the homeserver config
"""
cache = cache_type()
self.cache = cache # Used for introspection.
+ self.apply_cache_factor_from_config = apply_cache_factor_from_config
+
+ # Save the original max size, and apply the default size factor.
+ self._original_max_size = max_size
+ # We previously didn't apply the cache factor here, and as such some caches were
+ # not affected by the global cache factor. Add an option here to disable applying
+ # the cache factor when a cache is created
+ if apply_cache_factor_from_config:
+ self.max_size = int(max_size * cache_config.properties.default_factor_size)
+ else:
+ self.max_size = int(max_size)
+
list_root = _Node(None, None, None, None)
list_root.next_node = list_root
list_root.prev_node = list_root
@@ -83,7 +100,7 @@ class LruCache(object):
lock = threading.Lock()
def evict():
- while cache_len() > max_size:
+ while cache_len() > self.max_size:
todelete = list_root.prev_node
evicted_len = delete_node(todelete)
cache.pop(todelete.key, None)
@@ -236,6 +253,7 @@ class LruCache(object):
return key in cache
self.sentinel = object()
+ self._on_resize = evict
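+ # Keep a reference to the evict closure so set_cache_factor() can trim
+ # the cache after shrinking.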
self.get = cache_get
self.set = cache_set
self.setdefault = cache_set_default
@@ -266,3 +284,23 @@ class LruCache(object):
def __contains__(self, key):
return self.contains(key)
+
+ def set_cache_factor(self, factor: float) -> bool:
+ """
+ Set the cache factor for this individual cache.
+
+ This will trigger a resize if it changes, which may require evicting
+ items from the cache.
+
+ Returns:
+ bool: Whether the cache changed size or not.
+ """
+ if not self.apply_cache_factor_from_config:
+ return False
+
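+ # Resize relative to the original max size; shrinking triggers an
+ # immediate evict via _on_resize().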
+ new_size = int(self._original_max_size * factor)
+ if new_size != self.max_size:
+ self.max_size = new_size
+ self._on_resize()
+ return True
+ return False
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index b68f9fe0..a6c60888 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -38,7 +38,7 @@ class ResponseCache(object):
self.timeout_sec = timeout_ms / 1000.0
self._name = name
- self._metrics = register_cache("response_cache", name, self)
+ self._metrics = register_cache("response_cache", name, self, resizable=False)
def size(self):
return len(self.pending_result_cache)
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index e54f80d7..2a161bf2 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+import math
from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Union
from six import integer_types
@@ -46,7 +47,8 @@ class StreamChangeCache:
max_size=10000,
prefilled_cache: Optional[Mapping[EntityType, int]] = None,
):
- self._max_size = int(max_size * caches.CACHE_SIZE_FACTOR)
+ self._original_max_size = max_size
+ self._max_size = math.floor(max_size)
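+ # The cache factor is applied later, via the resize callback registered
+ # with register_cache below, rather than at construction time.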
self._entity_to_key = {} # type: Dict[EntityType, int]
# map from stream id to the a set of entities which changed at that stream id.
@@ -58,12 +60,31 @@ class StreamChangeCache:
#
self._earliest_known_stream_pos = current_stream_pos
self.name = name
- self.metrics = caches.register_cache("cache", self.name, self._cache)
+ self.metrics = caches.register_cache(
+ "cache", self.name, self._cache, resize_callback=self.set_cache_factor
+ )
if prefilled_cache:
for entity, stream_pos in prefilled_cache.items():
self.entity_has_changed(entity, stream_pos)
+ def set_cache_factor(self, factor: float) -> bool:
+ """
+ Set the cache factor for this individual cache.
+
+ This will trigger a resize if it changes, which may require evicting
+ items from the cache.
+
+ Returns:
+ bool: Whether the cache changed size or not.
+ """
+ new_size = math.floor(self._original_max_size * factor)
+ if new_size != self._max_size:
+ self._max_size = new_size
+ self._evict()
+ return True
+ return False
+
def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool:
"""Returns True if the entity may have been updated since stream_pos
"""
@@ -171,6 +192,7 @@ class StreamChangeCache:
e1 = self._cache[stream_pos] = set()
e1.add(entity)
self._entity_to_key[entity] = stream_pos
+ self._evict()
# if the cache is too big, remove entries
while len(self._cache) > self._max_size:
@@ -179,6 +201,13 @@ class StreamChangeCache:
for entity in r:
del self._entity_to_key[entity]
+ def _evict(self):
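+ # Drop the oldest stream positions until we are back under the size limit,
+ # advancing the earliest known stream position accordingly.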
+ while len(self._cache) > self._max_size:
+ k, r = self._cache.popitem(0)
+ self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
+ for entity in r:
+ self._entity_to_key.pop(entity, None)
+
def get_max_pos_of_last_change(self, entity: EntityType) -> int:
"""Returns an upper bound of the stream id of the last change to an
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index 99646c7c..6437aa90 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -38,7 +38,7 @@ class TTLCache(object):
self._timer = timer
- self._metrics = register_cache("ttl", cache_name, self)
+ self._metrics = register_cache("ttl", cache_name, self, resizable=False)
def set(self, key, value, ttl):
"""Add/update an entry in the cache
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index f2ccd5e7..9815bb86 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -65,5 +65,5 @@ def _handle_frozendict(obj):
)
-# A JSONEncoder which is capable of encoding frozendics without barfing
+# A JSONEncoder which is capable of encoding frozendicts without barfing
frozendict_json_encoder = json.JSONEncoder(default=_handle_frozendict)
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index fdff1957..2605f3c6 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -186,10 +186,15 @@ def _check_yield_points(f: Callable, changes: List[str]):
)
raise Exception(err)
+ # the wrapped function yielded a Deferred: yield it back up to the parent
+ # inlineCallbacks().
try:
result = yield d
- except Exception as e:
- result = Failure(e)
+ except Exception:
+ # this will fish an earlier Failure out of the stack where possible, and
+ # thus is preferable to passing in an exception to the Failure
+ # constructor, since it results in less stack-mangling.
+ result = Failure()
if current_context() != expected_context:
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 5ca4521c..e5efdfcd 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -43,7 +43,7 @@ class FederationRateLimiter(object):
self.ratelimiters = collections.defaultdict(new_limiter)
def ratelimit(self, host):
- """Used to ratelimit an incoming request from given host
+ """Used to ratelimit an incoming request from a given host
Example usage:
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 6899bcb7..08c86e92 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -19,10 +19,6 @@ import re
import string
from collections import Iterable
-import six
-from six import PY2, PY3
-from six.moves import range
-
from synapse.api.errors import Codes, SynapseError
_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
@@ -47,80 +43,16 @@ def random_string_with_symbols(length):
def is_ascii(s):
-
- if PY3:
- if isinstance(s, bytes):
- try:
- s.decode("ascii").encode("ascii")
- except UnicodeDecodeError:
- return False
- except UnicodeEncodeError:
- return False
- return True
-
- try:
- s.encode("ascii")
- except UnicodeEncodeError:
- return False
- except UnicodeDecodeError:
- return False
- else:
+ if isinstance(s, bytes):
+ try:
+ s.decode("ascii").encode("ascii")
+ except UnicodeDecodeError:
+ return False
+ except UnicodeEncodeError:
+ return False
return True
-def to_ascii(s):
- """Converts a string to ascii if it is ascii, otherwise leave it alone.
-
- If given None then will return None.
- """
- if PY3:
- return s
-
- if s is None:
- return None
-
- try:
- return s.encode("ascii")
- except UnicodeEncodeError:
- return s
-
-
-def exception_to_unicode(e):
- """Helper function to extract the text of an exception as a unicode string
-
- Args:
- e (Exception): exception to be stringified
-
- Returns:
- unicode
- """
- # urgh, this is a mess. The basic problem here is that psycopg2 constructs its
- # exceptions with PyErr_SetString, with a (possibly non-ascii) argument. str() will
- # then produce the raw byte sequence. Under Python 2, this will then cause another
- # error if it gets mixed with a `unicode` object, as per
- # https://github.com/matrix-org/synapse/issues/4252
-
- # First of all, if we're under python3, everything is fine because it will sort this
- # nonsense out for us.
- if not PY2:
- return str(e)
-
- # otherwise let's have a stab at decoding the exception message. We'll circumvent
- # Exception.__str__(), which would explode if someone raised Exception(u'non-ascii')
- # and instead look at what is in the args member.
-
- if len(e.args) == 0:
- return ""
- elif len(e.args) > 1:
- return six.text_type(repr(e.args))
-
- msg = e.args[0]
- if isinstance(msg, bytes):
- return msg.decode("utf-8", errors="replace")
- else:
- return msg
-
-
def assert_valid_client_secret(client_secret):
"""Validate that a given string matches the client_secret regex defined by the spec"""
if client_secret_regex.match(client_secret) is None:
diff --git a/synctl b/synctl
index bbccd052..960fd357 100755
--- a/synctl
+++ b/synctl
@@ -142,12 +142,23 @@ def start_worker(app: str, configfile: str, worker_configfile: str) -> bool:
return False
-def stop(pidfile, app):
+def stop(pidfile: str, app: str) -> bool:
+ """Attempts to kill a synapse worker from the pidfile.
+ Args:
+ pidfile: path to file containing worker's pid
+ app: name of the worker's appservice
+
+ Returns:
+ True if the process stopped successfully
+ False if the process was already stopped or an error occurred
+ """
+
if os.path.exists(pidfile):
pid = int(open(pidfile).read())
try:
os.kill(pid, signal.SIGTERM)
write("stopped %s" % (app,), colour=GREEN)
+ return True
except OSError as err:
if err.errno == errno.ESRCH:
write("%s not running" % (app,), colour=YELLOW)
@@ -155,6 +166,14 @@ def stop(pidfile, app):
abort("Cannot stop %s: Operation not permitted" % (app,))
else:
abort("Cannot stop %s: Unknown error" % (app,))
+ return False
+ else:
+ write(
+ "No running worker of %s found (from %s)\nThe process might be managed by another controller (e.g. systemd)"
+ % (app, pidfile),
+ colour=YELLOW,
+ )
+ return False
Worker = collections.namedtuple(
@@ -300,11 +319,17 @@ def main():
action = options.action
if action == "stop" or action == "restart":
+ has_stopped = True
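+ # Track whether every requested process actually stopped, so that "stop"
+ # can exit non-zero on failure.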
for worker in workers:
- stop(worker.pidfile, worker.app)
+ if not stop(worker.pidfile, worker.app):
+ # A worker could not be stopped.
+ has_stopped = False
if start_stop_synapse:
- stop(pidfile, "synapse.app.homeserver")
+ if not stop(pidfile, "synapse.app.homeserver"):
+ has_stopped = False
+ if not has_stopped and action == "stop":
+ sys.exit(1)
# Wait for synapse to actually shutdown before starting it again
if action == "restart":
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index dbdd427c..d580e729 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -1,39 +1,97 @@
-from synapse.api.ratelimiting import Ratelimiter
+from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
from tests import unittest
class TestRatelimiter(unittest.TestCase):
- def test_allowed(self):
- limiter = Ratelimiter()
- allowed, time_allowed = limiter.can_do_action(
- key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1
- )
+ def test_allowed_via_can_do_action(self):
+ limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0)
self.assertTrue(allowed)
self.assertEquals(10.0, time_allowed)
- allowed, time_allowed = limiter.can_do_action(
- key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1
- )
+ allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5)
self.assertFalse(allowed)
self.assertEquals(10.0, time_allowed)
- allowed, time_allowed = limiter.can_do_action(
- key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1
- )
+ allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10)
self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed)
- def test_pruning(self):
- limiter = Ratelimiter()
+ def test_allowed_via_ratelimit(self):
+ limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+
+ # Shouldn't raise
+ limiter.ratelimit(key="test_id", _time_now_s=0)
+
+ # Should raise
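+ # At 0.1 actions/sec the next action is allowed at t=10, so at t=5 the
+ # limiter reports a 5000ms wait.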
+ with self.assertRaises(LimitExceededError) as context:
+ limiter.ratelimit(key="test_id", _time_now_s=5)
+ self.assertEqual(context.exception.retry_after_ms, 5000)
+
+ # Shouldn't raise
+ limiter.ratelimit(key="test_id", _time_now_s=10)
+
+ def test_allowed_via_can_do_action_and_overriding_parameters(self):
+ """Test that we can override options of can_do_action that would otherwise fail
+ an action
+ """
+ # Create a Ratelimiter with a very low allowed rate_hz and burst_count
+ limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+
+ # First attempt should be allowed
+ allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=0,)
+ self.assertTrue(allowed)
+ self.assertEqual(10.0, time_allowed)
+
+ # Second attempt, 1s later, will fail
+ allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=1,)
+ self.assertFalse(allowed)
+ self.assertEqual(10.0, time_allowed)
+
+ # But, if we allow 10 actions/sec for this request, we should be allowed
+ # to continue.
allowed, time_allowed = limiter.can_do_action(
- key="test_id_1", time_now_s=0, rate_hz=0.1, burst_count=1
+ ("test_id",), _time_now_s=1, rate_hz=10.0
)
+ self.assertTrue(allowed)
+ self.assertEqual(1.1, time_allowed)
- self.assertIn("test_id_1", limiter.message_counts)
-
+ # Similarly if we allow a burst of 10 actions
allowed, time_allowed = limiter.can_do_action(
- key="test_id_2", time_now_s=10, rate_hz=0.1, burst_count=1
+ ("test_id",), _time_now_s=1, burst_count=10
)
+ self.assertTrue(allowed)
+ self.assertEqual(1.0, time_allowed)
+
+ def test_allowed_via_ratelimit_and_overriding_parameters(self):
+ """Test that we can override options of the ratelimit method that would otherwise
+ fail an action
+ """
+ # Create a Ratelimiter with a very low allowed rate_hz and burst_count
+ limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+
+ # First attempt should be allowed
+ limiter.ratelimit(key=("test_id",), _time_now_s=0)
+
+ # Second attempt, 1s later, will fail
+ with self.assertRaises(LimitExceededError) as context:
+ limiter.ratelimit(key=("test_id",), _time_now_s=1)
+ self.assertEqual(context.exception.retry_after_ms, 9000)
+
+ # But, if we allow 10 actions/sec for this request, we should be allowed
+ # to continue.
+ limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0)
+
+ # Similarly if we allow a burst of 10 actions
+ limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10)
+
+ def test_pruning(self):
+ limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ limiter.can_do_action(key="test_id_1", _time_now_s=0)
+
+ self.assertIn("test_id_1", limiter.actions)
+
+ limiter.can_do_action(key="test_id_2", _time_now_s=10)
- self.assertNotIn("test_id_1", limiter.message_counts)
+ self.assertNotIn("test_id_1", limiter.actions)
diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py
new file mode 100644
index 00000000..d3ec24c9
--- /dev/null
+++ b/tests/config/test_cache.py
@@ -0,0 +1,171 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.config._base import Config, RootConfig
+from synapse.config.cache import CacheConfig, add_resizable_cache
+from synapse.util.caches.lrucache import LruCache
+
+from tests.unittest import TestCase
+
+
+class FakeServer(Config):
+ section = "server"
+
+
+class TestConfig(RootConfig):
+ config_classes = [FakeServer, CacheConfig]
+
+
+class CacheConfigTests(TestCase):
+ def setUp(self):
+ # Reset caches before each test
+ TestConfig().caches.reset()
+
+ def test_individual_caches_from_environ(self):
+ """
+ Individual cache factors will be loaded from the environment.
+ """
+ config = {}
+ t = TestConfig()
+ t.caches._environ = {
+ "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
+ "SYNAPSE_NOT_CACHE": "BLAH",
+ }
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ self.assertEqual(dict(t.caches.cache_factors), {"something_or_other": 2.0})
+
+ def test_config_overrides_environ(self):
+ """
+ Individual cache factors defined in the environment will take precedence
+ over those in the config.
+ """
+ config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}}
+ t = TestConfig()
+ t.caches._environ = {
+ "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
+ "SYNAPSE_CACHE_FACTOR_FOO": 1,
+ }
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ self.assertEqual(
+ dict(t.caches.cache_factors),
+ {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0},
+ )
+
+ def test_individual_instantiated_before_config_load(self):
+ """
+ If a cache is instantiated before the config is read, it will be given
+ the default cache size in the interim, and then resized once the config
+ is loaded.
+ """
+ cache = LruCache(100)
+
+ add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
+ self.assertEqual(cache.max_size, 50)
+
+ config = {"caches": {"per_cache_factors": {"foo": 3}}}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ self.assertEqual(cache.max_size, 300)
+
+ def test_individual_instantiated_after_config_load(self):
+ """
+ If a cache is instantiated after the config is read, it will be
+ immediately resized to the correct size given the per_cache_factor if
+ there is one.
+ """
+ config = {"caches": {"per_cache_factors": {"foo": 2}}}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ cache = LruCache(100)
+ add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
+ self.assertEqual(cache.max_size, 200)
+
+ def test_global_instantiated_before_config_load(self):
+ """
+ If a cache is instantiated before the config is read, it will be given
+ the default cache size in the interim, and then resized to the new
+ default cache size once the config is loaded.
+ """
+ cache = LruCache(100)
+ add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
+ self.assertEqual(cache.max_size, 50)
+
+ config = {"caches": {"global_factor": 4}}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ self.assertEqual(cache.max_size, 400)
+
+ def test_global_instantiated_after_config_load(self):
+ """
+ If a cache is instantiated after the config is read, it will be
+ immediately resized to the correct size given the global factor if there
+ is no per-cache factor.
+ """
+ config = {"caches": {"global_factor": 1.5}}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ cache = LruCache(100)
+ add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
+ self.assertEqual(cache.max_size, 150)
+
+ def test_cache_with_asterisk_in_name(self):
+ """Some caches have asterisks in their name, test that they are set correctly.
+ """
+
+ config = {
+ "caches": {
+ "per_cache_factors": {"*cache_a*": 5, "cache_b": 6, "cache_c": 2}
+ }
+ }
+ t = TestConfig()
+ t.caches._environ = {
+ "SYNAPSE_CACHE_FACTOR_CACHE_A": "2",
+ "SYNAPSE_CACHE_FACTOR_CACHE_B": 3,
+ }
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ cache_a = LruCache(100)
+ add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor)
+ self.assertEqual(cache_a.max_size, 200)
+
+ cache_b = LruCache(100)
+ add_resizable_cache("*Cache_b*", cache_resize_callback=cache_b.set_cache_factor)
+ self.assertEqual(cache_b.max_size, 300)
+
+ cache_c = LruCache(100)
+ add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor)
+ self.assertEqual(cache_c.max_size, 200)
+
+ def test_apply_cache_factor_from_config(self):
+ """Caches can disable applying cache factor updates, mainly used by
+ event cache size.
+ """
+
+ config = {"caches": {"event_cache_size": "10k"}}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ cache = LruCache(
+ max_size=t.caches.event_cache_size, apply_cache_factor_from_config=False,
+ )
+ add_resizable_cache("event_cache", cache_resize_callback=cache.set_cache_factor)
+
+ self.assertEqual(cache.max_size, 10240)
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index ab5f5ac5..c1274c14 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -156,7 +156,7 @@ class PruneEventTestCase(unittest.TestCase):
"signatures": {},
"unsigned": {},
},
- room_version=RoomVersions.MSC2432_DEV,
+ room_version=RoomVersions.V6,
)
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 94980733..0c9987be 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -79,7 +79,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
- handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1))
+ handler.federation_handler.do_invite_join = Mock(
+ return_value=defer.succeed(("", 1))
+ )
d = handler._remote_join(
None,
@@ -115,7 +117,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed(None))
- handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1))
+ handler.federation_handler.do_invite_join = Mock(
+ return_value=defer.succeed(("", 1))
+ )
# Artificially raise the complexity
self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed(
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 33105576..ff125390 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -536,7 +536,7 @@ def build_device_dict(user_id: str, device_id: str, sk: SigningKey):
return {
"user_id": user_id,
"device_id": device_id,
- "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"],
+ "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
"keys": {
"curve25519:" + device_id: "curve25519+key",
key_id(sk): encode_pubkey(sk),
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 854eb6c0..e1e144b2 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -222,7 +222,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_key_1 = {
"user_id": local_user,
"device_id": "abc",
- "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"],
+ "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
"keys": {
"ed25519:abc": "base64+ed25519+key",
"curve25519:abc": "base64+curve25519+key",
@@ -232,7 +232,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_key_2 = {
"user_id": local_user,
"device_id": "def",
- "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"],
+ "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
"keys": {
"ed25519:def": "base64+ed25519+key",
"curve25519:def": "base64+curve25519+key",
@@ -315,7 +315,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_key = {
"user_id": local_user,
"device_id": device_id,
- "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"],
+ "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
"keys": {"curve25519:xyz": "curve25519+key", "ed25519:xyz": device_pubkey},
"signatures": {local_user: {"ed25519:xyz": "something"}},
}
@@ -391,8 +391,8 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"user_id": local_user,
"device_id": device_id,
"algorithms": [
- "m.olm.curve25519-aes-sha256",
- "m.megolm.v1.aes-sha",
+ "m.olm.curve25519-aes-sha2",
+ "m.megolm.v1.aes-sha2",
],
"keys": {
"curve25519:xyz": "curve25519+key",
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 132e3565..96fea586 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -13,9 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from unittest import TestCase
from synapse.api.constants import EventTypes
-from synapse.api.errors import AuthError, Codes
+from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase
from synapse.federation.federation_base import event_from_pdu_json
from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
@@ -207,3 +210,65 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id)
return join_event
+
+
+class EventFromPduTestCase(TestCase):
+ def test_valid_json(self):
+ """Valid JSON should be turned into an event."""
+ ev = event_from_pdu_json(
+ {
+ "type": EventTypes.Message,
+ "content": {"bool": True, "null": None, "int": 1, "str": "foobar"},
+ "room_id": "!room:test",
+ "sender": "@user:test",
+ "depth": 1,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": 1234,
+ },
+ RoomVersions.V6,
+ )
+
+ self.assertIsInstance(ev, EventBase)
+
+ def test_invalid_numbers(self):
+ """Invalid values for an integer should be rejected, all floats should be rejected."""
+ for value in [
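+ # Matrix's canonical JSON only permits integers in [-(2**53)+1, 2**53-1];
+ # values outside that range, and all floats, must be rejected.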
+ -(2 ** 53),
+ 2 ** 53,
+ 1.0,
+ float("inf"),
+ float("-inf"),
+ float("nan"),
+ ]:
+ with self.assertRaises(SynapseError):
+ event_from_pdu_json(
+ {
+ "type": EventTypes.Message,
+ "content": {"foo": value},
+ "room_id": "!room:test",
+ "sender": "@user:test",
+ "depth": 1,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": 1234,
+ },
+ RoomVersions.V6,
+ )
+
+ def test_invalid_nested(self):
+ """List and dictionaries are recursively searched."""
+ with self.assertRaises(SynapseError):
+ event_from_pdu_json(
+ {
+ "type": EventTypes.Message,
+ "content": {"foo": [{"bar": 2 ** 56}]},
+ "room_id": "!room:test",
+ "sender": "@user:test",
+ "depth": 1,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": 1234,
+ },
+ RoomVersions.V6,
+ )
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
new file mode 100644
index 00000000..1bb25ab6
--- /dev/null
+++ b/tests/handlers/test_oidc.py
@@ -0,0 +1,570 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from urllib.parse import parse_qs, urlparse
+
+from mock import Mock, patch
+
+import attr
+import pymacaroons
+
+from twisted.internet import defer
+from twisted.python.failure import Failure
+from twisted.web._newclient import ResponseDone
+
+from synapse.handlers.oidc_handler import (
+ MappingException,
+ OidcError,
+ OidcHandler,
+ OidcMappingProvider,
+)
+from synapse.types import UserID
+
+from tests.unittest import HomeserverTestCase, override_config
+
+
+@attr.s
+class FakeResponse:
+ code = attr.ib()
+ body = attr.ib()
+ phrase = attr.ib()
+
+ def deliverBody(self, protocol):
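+ # Feed the canned body to the consuming protocol and signal completion,
+ # mimicking how a real twisted response is delivered.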
+ protocol.dataReceived(self.body)
+ protocol.connectionLost(Failure(ResponseDone()))
+
+
+# These are a few constants that are used as config parameters in the tests.
+ISSUER = "https://issuer/"
+CLIENT_ID = "test-client-id"
+CLIENT_SECRET = "test-client-secret"
+BASE_URL = "https://synapse/"
+CALLBACK_URL = BASE_URL + "_synapse/oidc/callback"
+SCOPES = ["openid"]
+
+AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
+TOKEN_ENDPOINT = ISSUER + "token"
+USERINFO_ENDPOINT = ISSUER + "userinfo"
+WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
+JWKS_URI = ISSUER + ".well-known/jwks.json"
+
+# config for common cases
+COMMON_CONFIG = {
+ "discover": False,
+ "authorization_endpoint": AUTHORIZATION_ENDPOINT,
+ "token_endpoint": TOKEN_ENDPOINT,
+ "jwks_uri": JWKS_URI,
+}
+
+
+ # The cookie name and path don't really matter; they just have to be consistent
+ # between the callback & redirect handlers.
+COOKIE_NAME = b"oidc_session"
+COOKIE_PATH = "/_synapse/oidc"
+
+MockedMappingProvider = Mock(OidcMappingProvider)
+
+
+def simple_async_mock(return_value=None, raises=None):
+ # AsyncMock is not available in Python 3.5; this mimics part of its behaviour
+ async def cb(*args, **kwargs):
+ if raises:
+ raise raises
+ return return_value
+
+ return Mock(side_effect=cb)
+
+
+async def get_json(url):
+ # Mock get_json calls to handle jwks & oidc discovery endpoints
+ if url == WELL_KNOWN:
+ # Minimal discovery document, as defined in OpenID.Discovery
+ # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
+ return {
+ "issuer": ISSUER,
+ "authorization_endpoint": AUTHORIZATION_ENDPOINT,
+ "token_endpoint": TOKEN_ENDPOINT,
+ "jwks_uri": JWKS_URI,
+ "userinfo_endpoint": USERINFO_ENDPOINT,
+ "response_types_supported": ["code"],
+ "subject_types_supported": ["public"],
+ "id_token_signing_alg_values_supported": ["RS256"],
+ }
+ elif url == JWKS_URI:
+ return {"keys": []}
+
+
+class OidcHandlerTestCase(HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+
+ self.http_client = Mock(spec=["get_json"])
+ self.http_client.get_json.side_effect = get_json
+ self.http_client.user_agent = "Synapse Test"
+
+ config = self.default_config()
+ config["public_baseurl"] = BASE_URL
+ oidc_config = config.get("oidc_config", {})
+ oidc_config["enabled"] = True
+ oidc_config["client_id"] = CLIENT_ID
+ oidc_config["client_secret"] = CLIENT_SECRET
+ oidc_config["issuer"] = ISSUER
+ oidc_config["scopes"] = SCOPES
+ oidc_config["user_mapping_provider"] = {
+ "module": __name__ + ".MockedMappingProvider"
+ }
+ config["oidc_config"] = oidc_config
+
+ hs = self.setup_test_homeserver(
+ http_client=self.http_client,
+ proxied_http_client=self.http_client,
+ config=config,
+ )
+
+ self.handler = OidcHandler(hs)
+
+ return hs
+
+ def metadata_edit(self, values):
+ return patch.dict(self.handler._provider_metadata, values)
+
+ def assertRenderedError(self, error, error_description=None):
+ args = self.handler._render_error.call_args[0]
+ self.assertEqual(args[1], error)
+ if error_description is not None:
+ self.assertEqual(args[2], error_description)
+ # Reset the render_error mock
+ self.handler._render_error.reset_mock()
+
+ def test_config(self):
+ """Basic config correctly sets up the callback URL and client auth correctly."""
+ self.assertEqual(self.handler._callback_url, CALLBACK_URL)
+ self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID)
+ self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
+
+ @override_config({"oidc_config": {"discover": True}})
+ @defer.inlineCallbacks
+ def test_discovery(self):
+ """The handler should discover the endpoints from OIDC discovery document."""
+ # This would throw if some metadata were invalid
+ metadata = yield defer.ensureDeferred(self.handler.load_metadata())
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+
+ self.assertEqual(metadata.issuer, ISSUER)
+ self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT)
+ self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT)
+ self.assertEqual(metadata.jwks_uri, JWKS_URI)
+ # FIXME: it seems like authlib does not have that defined in its metadata models
+ # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT)
+
+ # subsequent calls should be cached
+ self.http_client.reset_mock()
+ yield defer.ensureDeferred(self.handler.load_metadata())
+ self.http_client.get_json.assert_not_called()
+
+ @override_config({"oidc_config": COMMON_CONFIG})
+ @defer.inlineCallbacks
+ def test_no_discovery(self):
+ """When discovery is disabled, it should not try to load from discovery document."""
+ yield defer.ensureDeferred(self.handler.load_metadata())
+ self.http_client.get_json.assert_not_called()
+
+ @override_config({"oidc_config": COMMON_CONFIG})
+ @defer.inlineCallbacks
+ def test_load_jwks(self):
+ """JWKS loading is done once (then cached) if used."""
+ jwks = yield defer.ensureDeferred(self.handler.load_jwks())
+ self.http_client.get_json.assert_called_once_with(JWKS_URI)
+ self.assertEqual(jwks, {"keys": []})
+
+ # subsequent calls should be cached…
+ self.http_client.reset_mock()
+ yield defer.ensureDeferred(self.handler.load_jwks())
+ self.http_client.get_json.assert_not_called()
+
+ # …unless forced
+ self.http_client.reset_mock()
+ yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+ self.http_client.get_json.assert_called_once_with(JWKS_URI)
+
+ # Throw if the JWKS uri is missing
+ with self.metadata_edit({"jwks_uri": None}):
+ with self.assertRaises(RuntimeError):
+ yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+
+ # Return empty key set if JWKS are not used
+ self.handler._scopes = [] # not asking the openid scope
+ self.http_client.get_json.reset_mock()
+ jwks = yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+ self.http_client.get_json.assert_not_called()
+ self.assertEqual(jwks, {"keys": []})
+
+ @override_config({"oidc_config": COMMON_CONFIG})
+ def test_validate_config(self):
+ """Provider metadatas are extensively validated."""
+ h = self.handler
+
+ # Default test config does not throw
+ h._validate_metadata()
+
+ with self.metadata_edit({"issuer": None}):
+ self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+
+ with self.metadata_edit({"issuer": "http://insecure/"}):
+ self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+
+ with self.metadata_edit({"issuer": "https://invalid/?because=query"}):
+ self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+
+ with self.metadata_edit({"authorization_endpoint": None}):
+ self.assertRaisesRegex(
+ ValueError, "authorization_endpoint", h._validate_metadata
+ )
+
+ with self.metadata_edit({"authorization_endpoint": "http://insecure/auth"}):
+ self.assertRaisesRegex(
+ ValueError, "authorization_endpoint", h._validate_metadata
+ )
+
+ with self.metadata_edit({"token_endpoint": None}):
+ self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
+
+ with self.metadata_edit({"token_endpoint": "http://insecure/token"}):
+ self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
+
+ with self.metadata_edit({"jwks_uri": None}):
+ self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
+
+ with self.metadata_edit({"jwks_uri": "http://insecure/jwks.json"}):
+ self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
+
+ with self.metadata_edit({"response_types_supported": ["id_token"]}):
+ self.assertRaisesRegex(
+ ValueError, "response_types_supported", h._validate_metadata
+ )
+
+ with self.metadata_edit(
+ {"token_endpoint_auth_methods_supported": ["client_secret_basic"]}
+ ):
+ # should not throw, as client_secret_basic is the default auth method
+ h._validate_metadata()
+
+ with self.metadata_edit(
+ {"token_endpoint_auth_methods_supported": ["client_secret_post"]}
+ ):
+ self.assertRaisesRegex(
+ ValueError,
+ "token_endpoint_auth_methods_supported",
+ h._validate_metadata,
+ )
+
+ # Tests for configs that require the userinfo endpoint
+ self.assertFalse(h._uses_userinfo)
+ h._scopes = [] # do not request the openid scope
+ self.assertTrue(h._uses_userinfo)
+ self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata)
+
+ with self.metadata_edit(
+ {"userinfo_endpoint": USERINFO_ENDPOINT, "jwks_uri": None}
+ ):
+ # Shouldn't raise with a valid userinfo endpoint, even without a jwks_uri
+ h._validate_metadata()
+
+ @override_config({"oidc_config": {"skip_verification": True}})
+ def test_skip_verification(self):
+ """Provider metadata validation can be disabled by config."""
+ with self.metadata_edit({"issuer": "http://insecure"}):
+ # This should not throw
+ self.handler._validate_metadata()
+
+ @defer.inlineCallbacks
+ def test_redirect_request(self):
+ """The redirect request has the right arguments & generates a valid session cookie."""
+ req = Mock(spec=["addCookie"])
+ url = yield defer.ensureDeferred(
+ self.handler.handle_redirect_request(req, b"http://client/redirect")
+ )
+ url = urlparse(url)
+ auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
+
+ self.assertEqual(url.scheme, auth_endpoint.scheme)
+ self.assertEqual(url.netloc, auth_endpoint.netloc)
+ self.assertEqual(url.path, auth_endpoint.path)
+
+ params = parse_qs(url.query)
+ self.assertEqual(params["redirect_uri"], [CALLBACK_URL])
+ self.assertEqual(params["response_type"], ["code"])
+ self.assertEqual(params["scope"], [" ".join(SCOPES)])
+ self.assertEqual(params["client_id"], [CLIENT_ID])
+ self.assertEqual(len(params["state"]), 1)
+ self.assertEqual(len(params["nonce"]), 1)
+
+ # Check what is in the cookie
+ # note: python3.5 mock does not have the .called_once() method
+ calls = req.addCookie.call_args_list
+ self.assertEqual(len(calls), 1) # called once
+ # For some reason, call.args does not work with python3.5
+ args = calls[0][0]
+ kwargs = calls[0][1]
+ self.assertEqual(args[0], COOKIE_NAME)
+ self.assertEqual(kwargs["path"], COOKIE_PATH)
+ cookie = args[1]
+
+ macaroon = pymacaroons.Macaroon.deserialize(cookie)
+ state = self.handler._get_value_from_macaroon(macaroon, "state")
+ nonce = self.handler._get_value_from_macaroon(macaroon, "nonce")
+ redirect = self.handler._get_value_from_macaroon(
+ macaroon, "client_redirect_url"
+ )
+
+ self.assertEqual(params["state"], [state])
+ self.assertEqual(params["nonce"], [nonce])
+ self.assertEqual(redirect, "http://client/redirect")
+
+ @defer.inlineCallbacks
+ def test_callback_error(self):
+ """Errors from the provider returned in the callback are displayed."""
+ self.handler._render_error = Mock()
+ request = Mock(args={})
+ request.args[b"error"] = [b"invalid_client"]
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_client", "")
+
+ request.args[b"error_description"] = [b"some description"]
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_client", "some description")
+
+ @defer.inlineCallbacks
+ def test_callback(self):
+ """Code callback works and display errors if something went wrong.
+
+ A lot of scenarios are tested here:
+ - when the callback works, with userinfo from ID token
+ - when the user mapping fails
+ - when ID token verification fails
+ - when the callback works, with userinfo fetched from the userinfo endpoint
+ - when the userinfo fetching fails
+ - when the code exchange fails
+ """
+ token = {
+ "type": "bearer",
+ "id_token": "id_token",
+ "access_token": "access_token",
+ }
+ userinfo = {
+ "sub": "foo",
+ "preferred_username": "bar",
+ }
+ user_id = UserID("foo", "domain.org")
+ self.handler._render_error = Mock(return_value=None)
+ self.handler._exchange_code = simple_async_mock(return_value=token)
+ self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
+ self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
+ self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+ self.handler._auth_handler.complete_sso_login = simple_async_mock()
+ request = Mock(spec=["args", "getCookie", "addCookie"])
+
+ code = "code"
+ state = "state"
+ nonce = "nonce"
+ client_redirect_url = "http://client/redirect"
+ session = self.handler._generate_oidc_session_token(
+ state=state,
+ nonce=nonce,
+ client_redirect_url=client_redirect_url,
+ ui_auth_session_id=None,
+ )
+ request.getCookie.return_value = session
+
+ request.args = {}
+ request.args[b"code"] = [code.encode("utf-8")]
+ request.args[b"state"] = [state.encode("utf-8")]
+
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+
+ self.handler._auth_handler.complete_sso_login.assert_called_once_with(
+ user_id, request, client_redirect_url,
+ )
+ self.handler._exchange_code.assert_called_once_with(code)
+ self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
+ self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._fetch_userinfo.assert_not_called()
+ self.handler._render_error.assert_not_called()
+
+ # Handle mapping errors
+ self.handler._map_userinfo_to_user = simple_async_mock(
+ raises=MappingException()
+ )
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("mapping_error")
+ self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+
+ # Handle ID token errors
+ self.handler._parse_id_token = simple_async_mock(raises=Exception())
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_token")
+
+ self.handler._auth_handler.complete_sso_login.reset_mock()
+ self.handler._exchange_code.reset_mock()
+ self.handler._parse_id_token.reset_mock()
+ self.handler._map_userinfo_to_user.reset_mock()
+ self.handler._fetch_userinfo.reset_mock()
+
+ # With userinfo fetching
+ self.handler._scopes = [] # do not ask the "openid" scope
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+
+ self.handler._auth_handler.complete_sso_login.assert_called_once_with(
+ user_id, request, client_redirect_url,
+ )
+ self.handler._exchange_code.assert_called_once_with(code)
+ self.handler._parse_id_token.assert_not_called()
+ self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._fetch_userinfo.assert_called_once_with(token)
+ self.handler._render_error.assert_not_called()
+
+ # Handle userinfo fetching error
+ self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("fetch_error")
+
+ # Handle code exchange failure
+ self.handler._exchange_code = simple_async_mock(
+ raises=OidcError("invalid_request")
+ )
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_request")
+
+ @defer.inlineCallbacks
+ def test_callback_session(self):
+ """The callback verifies the session presence and validity"""
+ self.handler._render_error = Mock(return_value=None)
+ request = Mock(spec=["args", "getCookie", "addCookie"])
+
+ # Missing cookie
+ request.args = {}
+ request.getCookie.return_value = None
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("missing_session", "No session cookie found")
+
+ # Missing session parameter
+ request.args = {}
+ request.getCookie.return_value = "session"
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_request", "State parameter is missing")
+
+ # Invalid cookie
+ request.args = {}
+ request.args[b"state"] = [b"state"]
+ request.getCookie.return_value = "session"
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_session")
+
+ # Mismatching session
+ session = self.handler._generate_oidc_session_token(
+ state="state",
+ nonce="nonce",
+ client_redirect_url="http://client/redirect",
+ ui_auth_session_id=None,
+ )
+ request.args = {}
+ request.args[b"state"] = [b"mismatching state"]
+ request.getCookie.return_value = session
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("mismatching_session")
+
+ # Valid session
+ request.args = {}
+ request.args[b"state"] = [b"state"]
+ request.getCookie.return_value = session
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_request")
+
+ @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
+ @defer.inlineCallbacks
+ def test_exchange_code(self):
+ """Code exchange behaves correctly and handles various error scenarios."""
+ token = {"type": "bearer"}
+ token_json = json.dumps(token).encode("utf-8")
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
+ )
+ code = "code"
+ ret = yield defer.ensureDeferred(self.handler._exchange_code(code))
+ kwargs = self.http_client.request.call_args[1]
+
+ self.assertEqual(ret, token)
+ self.assertEqual(kwargs["method"], "POST")
+ self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+
+ args = parse_qs(kwargs["data"].decode("utf-8"))
+ self.assertEqual(args["grant_type"], ["authorization_code"])
+ self.assertEqual(args["code"], [code])
+ self.assertEqual(args["client_id"], [CLIENT_ID])
+ self.assertEqual(args["client_secret"], [CLIENT_SECRET])
+ self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
+
+ # Test error handling
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=400,
+ phrase=b"Bad Request",
+ body=b'{"error": "foo", "error_description": "bar"}',
+ )
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "foo")
+ self.assertEqual(exc.exception.error_description, "bar")
+
+ # Internal server error with no JSON body
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=500, phrase=b"Internal Server Error", body=b"Not JSON",
+ )
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "server_error")
+
+ # Internal server error with JSON body
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=500,
+ phrase=b"Internal Server Error",
+ body=b'{"error": "internal_server_error"}',
+ )
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "internal_server_error")
+
+ # 4xx error without "error" field
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "server_error")
+
+ # 2xx error with "error" field
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=200, phrase=b"OK", body=b'{"error": "some_error"}',
+ )
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "some_error")
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 8aa56f14..29dd7d9c 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -14,7 +14,7 @@
# limitations under the License.
-from mock import Mock, NonCallableMock
+from mock import Mock
from twisted.internet import defer
@@ -55,12 +55,8 @@ class ProfileTestCase(unittest.TestCase):
federation_client=self.mock_federation,
federation_server=Mock(),
federation_registry=self.mock_registry,
- ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
)
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.can_do_action.return_value = (True, 0)
-
self.store = hs.get_datastore()
self.frank = UserID.from_string("@1234ABCD:test")
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 1b7935ce..ca32f993 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -135,6 +135,16 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart="local_part"), ResourceLimitError
)
+ def test_auto_join_rooms_for_guests(self):
+ room_alias_str = "#room:test"
+ self.hs.config.auto_join_rooms = [room_alias_str]
+ self.hs.config.auto_join_rooms_for_guests = False
+ user_id = self.get_success(
+ self.handler.register_user(localpart="jeff", make_guest=True),
+ )
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertEqual(len(rooms), 0)
+
def test_auto_create_auto_join_rooms(self):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 51e2b372..2fa8d473 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -86,7 +86,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
reactor.pump((1000,))
hs = self.setup_test_homeserver(
- notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring
+ notifier=Mock(),
+ http_client=mock_federation_client,
+ keyring=mock_keyring,
+ replication_streams={},
)
hs.datastores = datastores
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 7b92bdbc..c15bce5b 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -14,6 +14,8 @@
# limitations under the License.
from mock import Mock
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.api.constants import UserTypes
from synapse.rest.client.v1 import login, room
@@ -75,18 +77,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
- self.store.remove_from_user_dir = Mock()
- self.store.remove_from_user_in_public_room = Mock()
+ self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None))
self.get_success(self.handler.handle_user_deactivated(s_user_id))
self.store.remove_from_user_dir.not_called()
- self.store.remove_from_user_in_public_room.not_called()
def test_handle_user_deactivated_regular_user(self):
r_user_id = "@regular:test"
self.get_success(
self.store.register_user(user_id=r_user_id, password_hash=None)
)
- self.store.remove_from_user_dir = Mock()
+ self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None))
self.get_success(self.handler.handle_user_deactivated(r_user_id))
self.store.remove_from_user_dir.called_once_with(r_user_id)
@@ -185,7 +185,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Allow all users.
return False
- spam_checker.spam_checker = AllowAll()
+ spam_checker.spam_checkers = [AllowAll()]
# The results do not change:
# We get one search result when searching for user2 by user1.
@@ -198,7 +198,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# All users are spammy.
return True
- spam_checker.spam_checker = BlockAll()
+ spam_checker.spam_checkers = [BlockAll()]
# User1 now gets no search results for any of the other users.
s = self.get_success(self.handler.search_users(u1, "user2", 10))
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/_base.py
index 7b56d202..9d4f0bbe 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/_base.py
@@ -27,6 +27,7 @@ from synapse.app.generic_worker import (
GenericWorkerServer,
)
from synapse.http.site import SynapseRequest
+from synapse.replication.http import streams
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -42,6 +43,10 @@ logger = logging.getLogger(__name__)
class BaseStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests of the replication streams"""
+ servlets = [
+ streams.register_servlets,
+ ]
+
def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
@@ -49,17 +54,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.server = server_factory.buildProtocol(None)
# Make a new HomeServer object for the worker
- config = self.default_config()
- config["worker_app"] = "synapse.app.generic_worker"
- config["worker_replication_host"] = "testserv"
- config["worker_replication_http_port"] = "8765"
-
self.reactor.lookups["testserv"] = "1.2.3.4"
-
self.worker_hs = self.setup_test_homeserver(
http_client=None,
homeserverToUse=GenericWorkerServer,
- config=config,
+ config=self._get_worker_hs_config(),
reactor=self.reactor,
)
@@ -78,6 +77,13 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._client_transport = None
self._server_transport = None
+ def _get_worker_hs_config(self) -> dict:
+ config = self.default_config()
+ config["worker_app"] = "synapse.app.generic_worker"
+ config["worker_replication_host"] = "testserv"
+ config["worker_replication_http_port"] = "8765"
+ return config
+
def _build_replication_data_handler(self):
return TestReplicationDataHandler(self.worker_hs)
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 1615dfab..56497b84 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -13,67 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock, NonCallableMock
+from mock import Mock
-from synapse.replication.tcp.client import (
- DirectTcpReplicationClientFactory,
- ReplicationDataHandler,
-)
-from synapse.replication.tcp.handler import ReplicationCommandHandler
-from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
-from synapse.storage.database import make_conn
+from tests.replication._base import BaseStreamTestCase
-from tests import unittest
-from tests.server import FakeTransport
-
-class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
+class BaseSlavedStoreTestCase(BaseStreamTestCase):
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver(
- "blue",
- federation_client=Mock(),
- ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
- )
-
- hs.get_ratelimiter().can_do_action.return_value = (True, 0)
+ hs = self.setup_test_homeserver(federation_client=Mock())
return hs
def prepare(self, reactor, clock, hs):
+ super().prepare(reactor, clock, hs)
- db_config = hs.config.database.get_single_database()
- self.master_store = self.hs.get_datastore()
- self.storage = hs.get_storage()
- database = hs.get_datastores().databases[0]
- self.slaved_store = self.STORE_TYPE(
- database, make_conn(db_config, database.engine), self.hs
- )
- self.event_id = 0
-
- server_factory = ReplicationStreamProtocolFactory(self.hs)
- self.streamer = hs.get_replication_streamer()
+ self.reconnect()
- # We now do some gut wrenching so that we have a client that is based
- # off of the slave store rather than the main store.
- self.replication_handler = ReplicationCommandHandler(self.hs)
- self.replication_handler._instance_name = "worker"
- self.replication_handler._replication_data_handler = ReplicationDataHandler(
- self.slaved_store
- )
-
- client_factory = DirectTcpReplicationClientFactory(
- self.hs, "client_name", self.replication_handler
- )
- client_factory.handler = self.replication_handler
-
- server = server_factory.buildProtocol(None)
- client = client_factory.buildProtocol(None)
-
- client.makeConnection(FakeTransport(server, reactor))
-
- self.server_to_client_transport = FakeTransport(client, reactor)
- server.makeConnection(self.server_to_client_transport)
+ self.master_store = hs.get_datastore()
+ self.slaved_store = self.worker_hs.get_datastore()
+ self.storage = hs.get_storage()
def replicate(self):
"""Tell the master side of replication that something has happened, and then
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index f0561b30..1a88c7fb 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -17,17 +17,18 @@ from canonicaljson import encode_canonical_json
from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
-from synapse.events.snapshot import EventContext
from synapse.handlers.room import RoomEventSource
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.storage.roommember import RoomsForUser
+from tests.server import FakeTransport
+
from ._base import BaseSlavedStoreTestCase
-USER_ID = "@feeling:blue"
-USER_ID_2 = "@bright:blue"
+USER_ID = "@feeling:test"
+USER_ID_2 = "@bright:test"
OUTLIER = {"outlier": True}
-ROOM_ID = "!room:blue"
+ROOM_ID = "!room:test"
logger = logging.getLogger(__name__)
@@ -239,7 +240,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())
# limit the replication rate
- repl_transport = self.server_to_client_transport
+ repl_transport = self._server_transport
+ assert isinstance(repl_transport, FakeTransport)
repl_transport.autoflush = False
# build the join and message events and persist them in the same batch.
@@ -322,7 +324,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
type="m.room.message",
key=None,
internal={},
- state=None,
depth=None,
prev_events=[],
auth_events=[],
@@ -362,15 +363,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
event = make_event_from_dict(event_dict, internal_metadata_dict=internal)
self.event_id += 1
-
- if state is not None:
- state_ids = {key: e.event_id for key, e in state.items()}
- context = EventContext.with_state(
- state_group=None, current_state_ids=state_ids, prev_state_ids=state_ids
- )
- else:
- state_handler = self.hs.get_state_handler()
- context = self.get_success(state_handler.compute_event_context(event))
+ state_handler = self.hs.get_state_handler()
+ context = self.get_success(state_handler.compute_event_context(event))
self.master_store.add_push_actions_to_staging(
event.event_id, {user_id: actions for user_id, actions in push_actions}
diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py
new file mode 100644
index 00000000..6a5116dd
--- /dev/null
+++ b/tests/replication/tcp/streams/test_account_data.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.replication.tcp.streams._base import (
+ _STREAM_UPDATE_TARGET_ROW_COUNT,
+ AccountDataStream,
+)
+
+from tests.replication._base import BaseStreamTestCase
+
+
+class AccountDataStreamTestCase(BaseStreamTestCase):
+ def test_update_function_room_account_data_limit(self):
+ """Test replication with many room account data updates
+ """
+ store = self.hs.get_datastore()
+
+ # generate lots of account data updates
+ updates = []
+ for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
+ update = "m.test_type.%i" % (i,)
+ self.get_success(
+ store.add_account_data_to_room("test_user", "test_room", update, {})
+ )
+ updates.append(update)
+
+ # also one global update
+ self.get_success(store.add_account_data_for_user("test_user", "m.global", {}))
+
+ # tell the notifier to catch up to avoid duplicate rows.
+ # workaround for https://github.com/matrix-org/synapse/issues/7360
+ # FIXME remove this when the above is fixed
+ self.replicate()
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual([], self.test_handler.received_rdata_rows)
+
+ # now reconnect to pull the updates
+ self.reconnect()
+ self.replicate()
+
+ # we should have received all the expected rows in the right order
+ received_rows = self.test_handler.received_rdata_rows
+
+ for t in updates:
+ (stream_name, token, row) = received_rows.pop(0)
+ self.assertEqual(stream_name, AccountDataStream.NAME)
+ self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+ self.assertEqual(row.data_type, t)
+ self.assertEqual(row.room_id, "test_room")
+
+ (stream_name, token, row) = received_rows.pop(0)
+ self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+ self.assertEqual(row.data_type, "m.global")
+ self.assertIsNone(row.room_id)
+
+ self.assertEqual([], received_rows)
+
+ def test_update_function_global_account_data_limit(self):
+ """Test replication with many global account data updates
+ """
+ store = self.hs.get_datastore()
+
+ # generate lots of account data updates
+ updates = []
+ for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
+ update = "m.test_type.%i" % (i,)
+ self.get_success(store.add_account_data_for_user("test_user", update, {}))
+ updates.append(update)
+
+ # also one per-room update
+ self.get_success(
+ store.add_account_data_to_room("test_user", "test_room", "m.per_room", {})
+ )
+
+ # tell the notifier to catch up to avoid duplicate rows.
+ # workaround for https://github.com/matrix-org/synapse/issues/7360
+ # FIXME remove this when the above is fixed
+ self.replicate()
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual([], self.test_handler.received_rdata_rows)
+
+ # now reconnect to pull the updates
+ self.reconnect()
+ self.replicate()
+
+ # we should have received all the expected rows in the right order
+ received_rows = self.test_handler.received_rdata_rows
+
+ for t in updates:
+ (stream_name, token, row) = received_rows.pop(0)
+ self.assertEqual(stream_name, AccountDataStream.NAME)
+ self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+ self.assertEqual(row.data_type, t)
+ self.assertIsNone(row.room_id)
+
+ (stream_name, token, row) = received_rows.pop(0)
+ self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+ self.assertEqual(row.data_type, "m.per_room")
+ self.assertEqual(row.room_id, "test_room")
+
+ self.assertEqual([], received_rows)
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 8bd67bb9..51bf0ef4 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -26,7 +26,7 @@ from synapse.replication.tcp.streams.events import (
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
-from tests.replication.tcp.streams._base import BaseStreamTestCase
+from tests.replication._base import BaseStreamTestCase
from tests.test_utils.event_injection import inject_event, inject_member_event
diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py
new file mode 100644
index 00000000..2babea4e
--- /dev/null
+++ b/tests/replication/tcp/streams/test_federation.py
@@ -0,0 +1,81 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.federation.send_queue import EduRow
+from synapse.replication.tcp.streams.federation import FederationStream
+
+from tests.replication._base import BaseStreamTestCase
+
+
+class FederationStreamTestCase(BaseStreamTestCase):
+ def _get_worker_hs_config(self) -> dict:
+ # enable federation sending on the worker
+ config = super()._get_worker_hs_config()
+ # TODO: make it so we don't need both of these
+ config["send_federation"] = True
+ config["worker_app"] = "synapse.app.federation_sender"
+ return config
+
+ def test_catchup(self):
+ """Basic test of catchup on reconnect
+
+ Makes sure that updates sent while we are offline are received later.
+ """
+ fed_sender = self.hs.get_federation_sender()
+ received_rows = self.test_handler.received_rdata_rows
+
+ fed_sender.build_and_send_edu("testdest", "m.test_edu", {"a": "b"})
+
+ self.reconnect()
+ self.reactor.advance(0)
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual(received_rows, [])
+
+ # We should now see an attempt to connect to the master
+ request = self.handle_http_replication_attempt()
+ self.assert_request_is_get_repl_stream_updates(request, "federation")
+
+ # we should have received an update row
+ stream_name, token, row = received_rows.pop()
+ self.assertEqual(stream_name, "federation")
+ self.assertIsInstance(row, FederationStream.FederationStreamRow)
+ self.assertEqual(row.type, EduRow.TypeId)
+ edurow = EduRow.from_data(row.data)
+ self.assertEqual(edurow.edu.edu_type, "m.test_edu")
+ self.assertEqual(edurow.edu.origin, self.hs.hostname)
+ self.assertEqual(edurow.edu.destination, "testdest")
+ self.assertEqual(edurow.edu.content, {"a": "b"})
+
+ self.assertEqual(received_rows, [])
+
+ # additional updates should be transferred without an HTTP hit
+ fed_sender.build_and_send_edu("testdest", "m.test1", {"c": "d"})
+ self.reactor.advance(0)
+ # there should be no http hit
+ self.assertEqual(len(self.reactor.tcpClients), 0)
+ # ... but we should have a row
+ self.assertEqual(len(received_rows), 1)
+
+ stream_name, token, row = received_rows.pop()
+ self.assertEqual(stream_name, "federation")
+ self.assertIsInstance(row, FederationStream.FederationStreamRow)
+ self.assertEqual(row.type, EduRow.TypeId)
+ edurow = EduRow.from_data(row.data)
+ self.assertEqual(edurow.edu.edu_type, "m.test1")
+ self.assertEqual(edurow.edu.origin, self.hs.hostname)
+ self.assertEqual(edurow.edu.destination, "testdest")
+ self.assertEqual(edurow.edu.content, {"c": "d"})
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index 5853314f..56b062ec 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -19,7 +19,7 @@ from mock import Mock
from synapse.replication.tcp.streams._base import ReceiptsStream
-from tests.replication.tcp.streams._base import BaseStreamTestCase
+from tests.replication._base import BaseStreamTestCase
USER_ID = "@feeling:blue"
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index d25a7b19..fd62b263 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -15,19 +15,14 @@
from mock import Mock
from synapse.handlers.typing import RoomMember
-from synapse.replication.http import streams
from synapse.replication.tcp.streams import TypingStream
-from tests.replication.tcp.streams._base import BaseStreamTestCase
+from tests.replication._base import BaseStreamTestCase
USER_ID = "@feeling:blue"
class TypingStreamTestCase(BaseStreamTestCase):
- servlets = [
- streams.register_servlets,
- ]
-
def _build_replication_data_handler(self):
return Mock(wraps=super()._build_replication_data_handler())
diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py
index 7ddfd0a7..60c10a44 100644
--- a/tests/replication/tcp/test_commands.py
+++ b/tests/replication/tcp/test_commands.py
@@ -30,7 +30,7 @@ class ParseCommandTestCase(TestCase):
def test_parse_rdata(self):
line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
cmd = parse_command_from_line(line)
- self.assertIsInstance(cmd, RdataCommand)
+ assert isinstance(cmd, RdataCommand)
self.assertEqual(cmd.stream_name, "events")
self.assertEqual(cmd.instance_name, "master")
self.assertEqual(cmd.token, 6287863)
@@ -38,7 +38,7 @@ class ParseCommandTestCase(TestCase):
def test_parse_rdata_batch(self):
line = 'RDATA presence master batch ["@foo:example.com", "online"]'
cmd = parse_command_from_line(line)
- self.assertIsInstance(cmd, RdataCommand)
+ assert isinstance(cmd, RdataCommand)
self.assertEqual(cmd.stream_name, "presence")
self.assertEqual(cmd.instance_name, "master")
self.assertIsNone(cmd.token)
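
The switch from `self.assertIsInstance(cmd, RdataCommand)` to a bare `assert isinstance(...)` in the two tests above is presumably for the benefit of mypy: a plain `assert isinstance` narrows the variable's type for the checker, whereas `assertIsInstance` does not, so attribute access such as `cmd.stream_name` type-checks without a cast. A small illustration, using simplified stand-ins for the real command classes and parser:

    from typing import Union

    class RdataCommand:
        def __init__(self, stream_name: str) -> None:
            self.stream_name = stream_name

    class PositionCommand:
        pass

    def parse(line: str) -> Union[RdataCommand, PositionCommand]:
        return RdataCommand("events")

    cmd = parse("RDATA events ...")
    # assertIsInstance would not narrow the Union for mypy;
    # a bare assert does, so the attribute access below type-checks
    assert isinstance(cmd, RdataCommand)
    print(cmd.stream_name)
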
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
new file mode 100644
index 00000000..5448d9f0
--- /dev/null
+++ b/tests/replication/test_federation_ack.py
@@ -0,0 +1,71 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import mock
+
+from synapse.app.generic_worker import GenericWorkerServer
+from synapse.replication.tcp.commands import FederationAckCommand
+from synapse.replication.tcp.protocol import AbstractConnection
+from synapse.replication.tcp.streams.federation import FederationStream
+
+from tests.unittest import HomeserverTestCase
+
+
+class FederationAckTestCase(HomeserverTestCase):
+ def default_config(self) -> dict:
+ config = super().default_config()
+ config["worker_app"] = "synapse.app.federation_sender"
+ config["send_federation"] = True
+ return config
+
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver(homeserverToUse=GenericWorkerServer)
+ return hs
+
+ def test_federation_ack_sent(self):
+ """A FEDERATION_ACK should be sent back after each RDATA federation
+
+ This test checks that the federation sender is correctly sending back
+ FEDERATION_ACK messages. The test works by spinning up a federation_sender
+ worker server, and then fishing out its ReplicationCommandHandler. We wire
+ the RCH up to a mock connection (so that we can observe the command being sent)
+ and then poke in an RDATA row.
+
+ XXX: it might be nice to do this by pretending to be a synapse master worker
+ (or a redis server), and having the worker connect to us via a mocked-up TCP
+ transport, rather than assuming that the implementation has a
+ ReplicationCommandHandler.
+ """
+ rch = self.hs.get_tcp_replication()
+
+ # wire up the ReplicationCommandHandler to a mock connection
+ mock_connection = mock.Mock(spec=AbstractConnection)
+ rch.new_connection(mock_connection)
+
+ # tell it it received an RDATA row
+ self.get_success(
+ rch.on_rdata(
+ "federation",
+ "master",
+ token=10,
+ rows=[FederationStream.FederationStreamRow(type="x", data=[1, 2, 3])],
+ )
+ )
+
+ # now check that the FEDERATION_ACK was sent
+ mock_connection.send_command.assert_called_once()
+ cmd = mock_connection.send_command.call_args[0][0]
+ assert isinstance(cmd, FederationAckCommand)
+ self.assertEqual(cmd.token, 10)
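
As the docstring above notes, the test observes the outgoing command by wiring the handler to a `mock.Mock(spec=AbstractConnection)` and then inspecting `send_command.call_args`. A generic, self-contained sketch of that mock pattern (all names here are illustrative, not Synapse's own):

    from mock import Mock

    class Connection:
        def send_command(self, cmd):
            raise NotImplementedError

    class AckCommand:
        def __init__(self, token):
            self.token = token

    def handle_rdata(conn, token):
        # after processing a batch of rows, acknowledge it on the connection
        conn.send_command(AckCommand(token))

    conn = Mock(spec=Connection)  # spec= rejects calls to attributes the real class lacks
    handle_rdata(conn, 10)

    conn.send_command.assert_called_once()
    cmd = conn.send_command.call_args[0][0]  # first positional argument
    assert isinstance(cmd, AckCommand) and cmd.token == 10
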
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
new file mode 100644
index 00000000..faa7f381
--- /dev/null
+++ b/tests/rest/admin/test_device.py
@@ -0,0 +1,541 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import urllib.parse
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import login
+
+from tests import unittest
+
+
+class DeviceRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.handler = hs.get_device_handler()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_token = self.login("user", "pass")
+ res = self.get_success(self.handler.get_devices_by_user(self.other_user))
+ self.other_user_device_id = res[0]["device_id"]
+
+ self.url = "/_synapse/admin/v2/users/%s/devices/%s" % (
+ urllib.parse.quote(self.other_user),
+ self.other_user_device_id,
+ )
+
+ def test_no_auth(self):
+ """
+ Try to get a device of a user without authentication.
+ """
+ request, channel = self.make_request("GET", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ request, channel = self.make_request("PUT", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ request, channel = self.make_request("DELETE", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ request, channel = self.make_request(
+ "PUT", self.url, access_token=self.other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ request, channel = self.make_request(
+ "DELETE", self.url, access_token=self.other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ url = (
+ "/_synapse/admin/v2/users/@unknown_person:test/devices/%s"
+ % self.other_user_device_id
+ )
+
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ request, channel = self.make_request(
+ "PUT", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ request, channel = self.make_request(
+ "DELETE", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not local returns a 400
+ """
+ url = (
+ "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices/%s"
+ % self.other_user_device_id
+ )
+
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ request, channel = self.make_request(
+ "PUT", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ request, channel = self.make_request(
+ "DELETE", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ def test_unknown_device(self):
+ """
+ Tests that a lookup for a device that does not exist returns either 404 or 200.
+ """
+ url = "/_synapse/admin/v2/users/%s/devices/unknown_device" % urllib.parse.quote(
+ self.other_user
+ )
+
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ request, channel = self.make_request(
+ "PUT", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ request, channel = self.make_request(
+ "DELETE", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ # Delete unknown device returns status 200
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ def test_update_device_too_long_display_name(self):
+ """
+ Update a device with a display name that is invalid (too long).
+ """
+ # Set initial display name.
+ update = {"display_name": "new display"}
+ self.get_success(
+ self.handler.update_device(
+ self.other_user, self.other_user_device_id, update
+ )
+ )
+
+ # Request to update a device display name with a new value that is longer than allowed.
+ update = {
+ "display_name": "a"
+ * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1)
+ }
+
+ body = json.dumps(update)
+ request, channel = self.make_request(
+ "PUT",
+ self.url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+ # Ensure the display name was not updated.
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("new display", channel.json_body["display_name"])
+
+ def test_update_no_display_name(self):
+ """
+ Tests that an update of a device without JSON returns a 200
+ """
+ # Set initial display name.
+ update = {"display_name": "new display"}
+ self.get_success(
+ self.handler.update_device(
+ self.other_user, self.other_user_device_id, update
+ )
+ )
+
+ request, channel = self.make_request(
+ "PUT", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Ensure the display name was not updated.
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("new display", channel.json_body["display_name"])
+
+ def test_update_display_name(self):
+ """
+ Tests a normal successful update of display name
+ """
+ # Set new display_name
+ body = json.dumps({"display_name": "new displayname"})
+ request, channel = self.make_request(
+ "PUT",
+ self.url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Check new display_name
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("new displayname", channel.json_body["display_name"])
+
+ def test_get_device(self):
+ """
+ Tests that a normal lookup for a device is successful
+ """
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(self.other_user, channel.json_body["user_id"])
+ # Check that all fields are available
+ self.assertIn("user_id", channel.json_body)
+ self.assertIn("device_id", channel.json_body)
+ self.assertIn("display_name", channel.json_body)
+ self.assertIn("last_seen_ip", channel.json_body)
+ self.assertIn("last_seen_ts", channel.json_body)
+
+ def test_delete_device(self):
+ """
+ Tests that removing a device is successful
+ """
+ # Count number of devices of a user.
+ res = self.get_success(self.handler.get_devices_by_user(self.other_user))
+ number_devices = len(res)
+ self.assertEqual(1, number_devices)
+
+ # Delete device
+ request, channel = self.make_request(
+ "DELETE", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Ensure that the number of devices is decreased
+ res = self.get_success(self.handler.get_devices_by_user(self.other_user))
+ self.assertEqual(number_devices - 1, len(res))
+
+
+class DevicesRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+
+ self.url = "/_synapse/admin/v2/users/%s/devices" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def test_no_auth(self):
+ """
+ Try to list devices of a user without authentication.
+ """
+ request, channel = self.make_request("GET", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ other_user_token = self.login("user", "pass")
+
+ request, channel = self.make_request(
+ "GET", self.url, access_token=other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ url = "/_synapse/admin/v2/users/@unknown_person:test/devices"
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not local returns a 400
+ """
+ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices"
+
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ def test_get_devices(self):
+ """
+ Tests that a normal lookup for devices is successful
+ """
+ # Create devices
+ number_devices = 5
+ for n in range(number_devices):
+ self.login("user", "pass")
+
+ # Get devices
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(number_devices, len(channel.json_body["devices"]))
+ self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"])
+ # Check that all fields are available
+ for d in channel.json_body["devices"]:
+ self.assertIn("user_id", d)
+ self.assertIn("device_id", d)
+ self.assertIn("display_name", d)
+ self.assertIn("last_seen_ip", d)
+ self.assertIn("last_seen_ts", d)
+
+
+class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.handler = hs.get_device_handler()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+
+ self.url = "/_synapse/admin/v2/users/%s/delete_devices" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def test_no_auth(self):
+ """
+ Try to delete devices of a user without authentication.
+ """
+ request, channel = self.make_request("POST", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ other_user_token = self.login("user", "pass")
+
+ request, channel = self.make_request(
+ "POST", self.url, access_token=other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices"
+ request, channel = self.make_request(
+ "POST", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not local returns a 400
+ """
+ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices"
+
+ request, channel = self.make_request(
+ "POST", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ def test_unknown_devices(self):
+ """
+ Tests that removing a device that does not exist returns 200.
+ """
+ body = json.dumps({"devices": ["unknown_device1", "unknown_device2"]})
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ # Delete unknown devices returns status 200
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ def test_delete_devices(self):
+ """
+ Tests that removing devices is successful
+ """
+
+ # Create devices
+ number_devices = 5
+ for n in range(number_devices):
+ self.login("user", "pass")
+
+ # Get devices
+ res = self.get_success(self.handler.get_devices_by_user(self.other_user))
+ self.assertEqual(number_devices, len(res))
+
+ # Create list of device IDs
+ device_ids = []
+ for d in res:
+ device_ids.append(str(d["device_id"]))
+
+ # Delete devices
+ body = json.dumps({"devices": device_ids})
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ res = self.get_success(self.handler.get_devices_by_user(self.other_user))
+ self.assertEqual(0, len(res))
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 249c9372..54cd24bf 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -701,6 +701,47 @@ class RoomTestCase(unittest.HomeserverTestCase):
_search_test(None, "bar")
_search_test(None, "", expected_http_code=400)
+ def test_single_room(self):
+ """Test that a single room can be requested correctly"""
+ # Create two test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ room_name_1 = "something"
+ room_name_2 = "else"
+
+ # Set the name for each room
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+ )
+
+ url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ self.assertIn("room_id", channel.json_body)
+ self.assertIn("name", channel.json_body)
+ self.assertIn("canonical_alias", channel.json_body)
+ self.assertIn("joined_members", channel.json_body)
+ self.assertIn("joined_local_members", channel.json_body)
+ self.assertIn("version", channel.json_body)
+ self.assertIn("creator", channel.json_body)
+ self.assertIn("encryption", channel.json_body)
+ self.assertIn("federatable", channel.json_body)
+ self.assertIn("public", channel.json_body)
+ self.assertIn("join_rules", channel.json_body)
+ self.assertIn("guest_access", channel.json_body)
+ self.assertIn("history_visibility", channel.json_body)
+ self.assertIn("state_events", channel.json_body)
+
+ self.assertEqual(room_id_1, channel.json_body["room_id"])
+
class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 6c88ab06..cca5f548 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -22,9 +22,12 @@ from mock import Mock
import synapse.rest.admin
from synapse.api.constants import UserTypes
+from synapse.api.errors import HttpResponseException, ResourceLimitError
from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import sync
from tests import unittest
+from tests.unittest import override_config
class UserRegisterTestCase(unittest.HomeserverTestCase):
@@ -320,6 +323,52 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid user type", channel.json_body["error"])
+ @override_config(
+ {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
+ )
+ def test_register_mau_limit_reached(self):
+ """
+ Check we can register a user via the shared secret registration API
+ even if the MAU limit is reached.
+ """
+ handler = self.hs.get_registration_handler()
+ store = self.hs.get_datastore()
+
+ # Set monthly active users to the limit
+ store.get_monthly_active_count = Mock(return_value=self.hs.config.max_mau_value)
+ # Check that the blocking of monthly active users is working as expected
+ # The registration of a new user fails due to the limit
+ self.get_failure(
+ handler.register_user(localpart="local_part"), ResourceLimitError
+ )
+
+ # Register new user with admin API
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
+ nonce = channel.json_body["nonce"]
+
+ want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+ want_mac.update(
+ nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support"
+ )
+ want_mac = want_mac.hexdigest()
+
+ body = json.dumps(
+ {
+ "nonce": nonce,
+ "username": "bob",
+ "password": "abc123",
+ "admin": True,
+ "user_type": UserTypes.SUPPORT,
+ "mac": want_mac,
+ }
+ )
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["user_id"])
+
class UsersListTestCase(unittest.HomeserverTestCase):
@@ -368,6 +417,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
+ sync.register_servlets,
]
def prepare(self, reactor, clock, hs):
@@ -386,7 +436,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
If the user is not a server admin, an error is returned.
"""
- self.hs.config.registration_shared_secret = None
url = "/_synapse/admin/v2/users/@bob:test"
request, channel = self.make_request(
@@ -409,7 +458,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
Tests that a lookup for a user that does not exist returns a 404
"""
- self.hs.config.registration_shared_secret = None
request, channel = self.make_request(
"GET",
@@ -425,7 +473,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
Check that a new admin user is created successfully.
"""
- self.hs.config.registration_shared_secret = None
url = "/_synapse/admin/v2/users/@bob:test"
# Create user (server admin)
@@ -473,7 +520,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
Check that a new regular user is created successfully.
"""
- self.hs.config.registration_shared_secret = None
url = "/_synapse/admin/v2/users/@bob:test"
# Create user
@@ -516,11 +562,192 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(False, channel.json_body["is_guest"])
self.assertEqual(False, channel.json_body["deactivated"])
+ @override_config(
+ {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
+ )
+ def test_create_user_mau_limit_reached_active_admin(self):
+ """
+ Check that an admin can register a new user via the admin API
+ even if the MAU limit is reached.
+ Admin user was active before creating user.
+ """
+
+ handler = self.hs.get_registration_handler()
+
+ # Sync to set admin user to active
+ # before limit of monthly active users is reached
+ request, channel = self.make_request(
+ "GET", "/sync", access_token=self.admin_user_tok
+ )
+ self.render(request)
+
+ if channel.code != 200:
+ raise HttpResponseException(
+ channel.code, channel.result["reason"], channel.result["body"]
+ )
+
+ # Set monthly active users to the limit
+ self.store.get_monthly_active_count = Mock(
+ return_value=self.hs.config.max_mau_value
+ )
+ # Check that the blocking of monthly active users is working as expected
+ # The registration of a new user fails due to the limit
+ self.get_failure(
+ handler.register_user(localpart="local_part"), ResourceLimitError
+ )
+
+ # Register new user with admin API
+ url = "/_synapse/admin/v2/users/@bob:test"
+
+ # Create user
+ body = json.dumps({"password": "abc123", "admin": False})
+
+ request, channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual(False, channel.json_body["admin"])
+
+ @override_config(
+ {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
+ )
+ def test_create_user_mau_limit_reached_passive_admin(self):
+ """
+ Check that an admin can register a new user via the admin API
+ even if the MAU limit is reached.
+ Admin user was not active before creating user.
+ """
+
+ handler = self.hs.get_registration_handler()
+
+ # Set monthly active users to the limit
+ self.store.get_monthly_active_count = Mock(
+ return_value=self.hs.config.max_mau_value
+ )
+ # Check that the blocking of monthly active users is working as expected
+ # The registration of a new user fails due to the limit
+ self.get_failure(
+ handler.register_user(localpart="local_part"), ResourceLimitError
+ )
+
+ # Register new user with admin API
+ url = "/_synapse/admin/v2/users/@bob:test"
+
+ # Create user
+ body = json.dumps({"password": "abc123", "admin": False})
+
+ request, channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ # Admin user is not blocked by mau anymore
+ self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual(False, channel.json_body["admin"])
+
+ @override_config(
+ {
+ "email": {
+ "enable_notifs": True,
+ "notif_for_new_users": True,
+ "notif_from": "test@example.com",
+ },
+ "public_baseurl": "https://example.com",
+ }
+ )
+ def test_create_user_email_notif_for_new_users(self):
+ """
+ Check that a new regular user is created successfully and
+ gets an email pusher.
+ """
+ url = "/_synapse/admin/v2/users/@bob:test"
+
+ # Create user
+ body = json.dumps(
+ {
+ "password": "abc123",
+ "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+ }
+ )
+
+ request, channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
+
+ pushers = self.get_success(
+ self.store.get_pushers_by({"user_name": "@bob:test"})
+ )
+ pushers = list(pushers)
+ self.assertEqual(len(pushers), 1)
+ self.assertEqual("@bob:test", pushers[0]["user_name"])
+
+ @override_config(
+ {
+ "email": {
+ "enable_notifs": False,
+ "notif_for_new_users": False,
+ "notif_from": "test@example.com",
+ },
+ "public_baseurl": "https://example.com",
+ }
+ )
+ def test_create_user_email_no_notif_for_new_users(self):
+ """
+ Check that a new regular user is created successfully and
+ does not get an email pusher.
+ """
+ url = "/_synapse/admin/v2/users/@bob:test"
+
+ # Create user
+ body = json.dumps(
+ {
+ "password": "abc123",
+ "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+ }
+ )
+
+ request, channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
+
+ pushers = self.get_success(
+ self.store.get_pushers_by({"user_name": "@bob:test"})
+ )
+ pushers = list(pushers)
+ self.assertEqual(len(pushers), 0)
+
def test_set_password(self):
"""
Test setting a new password for another user.
"""
- self.hs.config.registration_shared_secret = None
# Change password
body = json.dumps({"password": "hahaha"})
@@ -539,7 +766,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
Test setting the displayname of another user.
"""
- self.hs.config.registration_shared_secret = None
# Modify user
body = json.dumps({"displayname": "foobar"})
@@ -570,7 +796,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
Test setting threepid for an other user.
"""
- self.hs.config.registration_shared_secret = None
# Delete old and add new threepid to user
body = json.dumps(
@@ -636,7 +861,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
Test setting the admin flag on a user.
"""
- self.hs.config.registration_shared_secret = None
# Set a user as an admin
body = json.dumps({"admin": True})
@@ -668,7 +892,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
Ensure an account can't accidentally be deactivated by using a str value
for the deactivated body parameter
"""
- self.hs.config.registration_shared_secret = None
url = "/_synapse/admin/v2/users/@bob:test"
# Create user
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index b54b0648..f7552087 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -15,7 +15,7 @@
""" Tests REST events for /events paths."""
-from mock import Mock, NonCallableMock
+from mock import Mock
import synapse.rest.admin
from synapse.rest.client.v1 import events, login, room
@@ -40,11 +40,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
config["enable_registration"] = True
config["auto_join_rooms"] = []
- hs = self.setup_test_homeserver(
- config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"])
- )
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.can_do_action.return_value = (True, 0)
+ hs = self.setup_test_homeserver(config=config)
hs.get_handlers().federation_handler = Mock()
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 1856c7ff..9033f09f 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -1,10 +1,13 @@
import json
+import time
import urllib.parse
from mock import Mock
+import jwt
+
import synapse.rest.admin
-from synapse.rest.client.v1 import login
+from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import devices
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
@@ -20,12 +23,12 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
+ logout.register_servlets,
devices.register_servlets,
lambda hs, http_server: WhoamiRestServlet(hs).register(http_server),
]
def make_homeserver(self, reactor, clock):
-
self.hs = self.setup_test_homeserver()
self.hs.config.enable_registration = True
self.hs.config.registrations_require_3pid = []
@@ -34,10 +37,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
return self.hs
+ @override_config(
+ {
+ "rc_login": {
+ "address": {"per_second": 0.17, "burst_count": 5},
+ # Prevent the account login ratelimiter from raising first
+ #
+ # This is normally covered by the default test homeserver config
+ # which sets these values to 10000, but as we're overriding the entire
+ # rc_login dict here, we need to set this manually as well
+ "account": {"per_second": 10000, "burst_count": 10000},
+ }
+ }
+ )
def test_POST_ratelimiting_per_address(self):
- self.hs.config.rc_login_address.burst_count = 5
- self.hs.config.rc_login_address.per_second = 0.17
-
# Create different users so we're sure not to be bothered by the per-user
# ratelimiter.
for i in range(0, 6):
@@ -76,10 +89,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+ @override_config(
+ {
+ "rc_login": {
+ "account": {"per_second": 0.17, "burst_count": 5},
+ # Prevent the address login ratelimiter from raising first
+ #
+ # This is normally covered by the default test homeserver config
+ # which sets these values to 10000, but as we're overriding the entire
+ # rc_login dict here, we need to set this manually as well
+ "address": {"per_second": 10000, "burst_count": 10000},
+ }
+ }
+ )
def test_POST_ratelimiting_per_account(self):
- self.hs.config.rc_login_account.burst_count = 5
- self.hs.config.rc_login_account.per_second = 0.17
-
self.register_user("kermit", "monkey")
for i in range(0, 6):
@@ -115,10 +138,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+ @override_config(
+ {
+ "rc_login": {
+ # Prevent the address login ratelimiter from raising first
+ #
+ # This is normally covered by the default test homeserver config
+ # which sets these values to 10000, but as we're overriding the entire
+ # rc_login dict here, we need to set this manually as well
+ "address": {"per_second": 10000, "burst_count": 10000},
+ "failed_attempts": {"per_second": 0.17, "burst_count": 5},
+ }
+ }
+ )
def test_POST_ratelimiting_per_account_failed_attempts(self):
- self.hs.config.rc_login_failed_attempts.burst_count = 5
- self.hs.config.rc_login_failed_attempts.per_second = 0.17
-
self.register_user("kermit", "monkey")
for i in range(0, 6):
@@ -256,6 +289,72 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(channel.code, 200, channel.result)
+ @override_config({"session_lifetime": "24h"})
+ def test_session_can_hard_logout_after_being_soft_logged_out(self):
+ self.register_user("kermit", "monkey")
+
+ # log in as normal
+ access_token = self.login("kermit", "monkey")
+
+ # we should now be able to make requests with the access token
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 200, channel.result)
+
+ # time passes
+ self.reactor.advance(24 * 3600)
+
+ # ... and we should be soft-logged out
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 401, channel.result)
+ self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
+ self.assertEquals(channel.json_body["soft_logout"], True)
+
+ # Now try to hard logout this session
+ request, channel = self.make_request(
+ b"POST", "/logout", access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ @override_config({"session_lifetime": "24h"})
+ def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self):
+ self.register_user("kermit", "monkey")
+
+ # log in as normal
+ access_token = self.login("kermit", "monkey")
+
+ # we should now be able to make requests with the access token
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 200, channel.result)
+
+ # time passes
+ self.reactor.advance(24 * 3600)
+
+ # ... and we should be soft-logged out
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 401, channel.result)
+ self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
+ self.assertEquals(channel.json_body["soft_logout"], True)
+
+ # Now try to hard log out all of the user's sessions
+ request, channel = self.make_request(
+ b"POST", "/logout/all", access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
class CASTestCase(unittest.HomeserverTestCase):
@@ -406,3 +505,153 @@ class CASTestCase(unittest.HomeserverTestCase):
# Because the user is deactivated they are served an error template.
self.assertEqual(channel.code, 403)
self.assertIn(b"SSO account deactivated", channel.result["body"])
+
+
+class JWTTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ ]
+
+ jwt_secret = "secret"
+
+ def make_homeserver(self, reactor, clock):
+ self.hs = self.setup_test_homeserver()
+ self.hs.config.jwt_enabled = True
+ self.hs.config.jwt_secret = self.jwt_secret
+ self.hs.config.jwt_algorithm = "HS256"
+ return self.hs
+
+ def jwt_encode(self, token, secret=jwt_secret):
+ return jwt.encode(token, secret, "HS256").decode("ascii")
+
+ def jwt_login(self, *args):
+ params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ self.render(request)
+ return channel
+
+ def test_login_jwt_valid_registered(self):
+ self.register_user("kermit", "monkey")
+ channel = self.jwt_login({"sub": "kermit"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ def test_login_jwt_valid_unregistered(self):
+ channel = self.jwt_login({"sub": "frog"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@frog:test")
+
+ def test_login_jwt_invalid_signature(self):
+ channel = self.jwt_login({"sub": "frog"}, "notsecret")
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.json_body["error"], "Invalid JWT")
+
+ def test_login_jwt_expired(self):
+ channel = self.jwt_login({"sub": "frog", "exp": 864000})
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.json_body["error"], "JWT expired")
+
+ def test_login_jwt_not_before(self):
+ now = int(time.time())
+ channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.json_body["error"], "Invalid JWT")
+
+ def test_login_no_sub(self):
+ channel = self.jwt_login({"username": "root"})
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.json_body["error"], "Invalid JWT")
+
+ def test_login_no_token(self):
+ params = json.dumps({"type": "m.login.jwt"})
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
+
+
+# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
+# RS256, with a public key configured in synapse as "jwt_secret", and tokens
+# signed by the private key.
+class JWTPubKeyTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ ]
+
+ # This key's pubkey is used as the jwt_secret setting of synapse. Valid
+ # tokens are signed by this and validated using the pubkey. It is generated
+ # with `openssl genrsa 512` (not a secure way to generate real keys, but
+ # good enough for tests!)
+ jwt_privatekey = "\n".join(
+ [
+ "-----BEGIN RSA PRIVATE KEY-----",
+ "MIIBPAIBAAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7TKO1vSEWdq7u9x8SMFiB",
+ "492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQJAUv7OOSOtiU+wzJq82rnk",
+ "yR4NHqt7XX8BvkZPM7/+EjBRanmZNSp5kYZzKVaZ/gTOM9+9MwlmhidrUOweKfB/",
+ "kQIhAPZwHazbjo7dYlJs7wPQz1vd+aHSEH+3uQKIysebkmm3AiEA1nc6mDdmgiUq",
+ "TpIN8A4MBKmfZMWTLq6z05y/qjKyxb0CIQDYJxCwTEenIaEa4PdoJl+qmXFasVDN",
+ "ZU0+XtNV7yul0wIhAMI9IhiStIjS2EppBa6RSlk+t1oxh2gUWlIh+YVQfZGRAiEA",
+ "tqBR7qLZGJ5CVKxWmNhJZGt1QHoUtOch8t9C4IdOZ2g=",
+ "-----END RSA PRIVATE KEY-----",
+ ]
+ )
+
+ # Generated with `openssl rsa -in foo.key -pubout`, with the above
+ # private key placed in foo.key (jwt_privatekey).
+ jwt_pubkey = "\n".join(
+ [
+ "-----BEGIN PUBLIC KEY-----",
+ "MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7",
+ "TKO1vSEWdq7u9x8SMFiB492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQ==",
+ "-----END PUBLIC KEY-----",
+ ]
+ )
+
+ # This key is used to sign tokens that shouldn't be accepted by synapse.
+ # Generated just like jwt_privatekey.
+ bad_privatekey = "\n".join(
+ [
+ "-----BEGIN RSA PRIVATE KEY-----",
+ "MIIBOgIBAAJBAL//SQrKpKbjCCnv/FlasJCv+t3k/MPsZfniJe4DVFhsktF2lwQv",
+ "gLjmQD3jBUTz+/FndLSBvr3F4OHtGL9O/osCAwEAAQJAJqH0jZJW7Smzo9ShP02L",
+ "R6HRZcLExZuUrWI+5ZSP7TaZ1uwJzGFspDrunqaVoPobndw/8VsP8HFyKtceC7vY",
+ "uQIhAPdYInDDSJ8rFKGiy3Ajv5KWISBicjevWHF9dbotmNO9AiEAxrdRJVU+EI9I",
+ "eB4qRZpY6n4pnwyP0p8f/A3NBaQPG+cCIFlj08aW/PbxNdqYoBdeBA0xDrXKfmbb",
+ "iwYxBkwL0JCtAiBYmsi94sJn09u2Y4zpuCbJeDPKzWkbuwQh+W1fhIWQJQIhAKR0",
+ "KydN6cRLvphNQ9c/vBTdlzWxzcSxREpguC7F1J1m",
+ "-----END RSA PRIVATE KEY-----",
+ ]
+ )
+
+ def make_homeserver(self, reactor, clock):
+ self.hs = self.setup_test_homeserver()
+ self.hs.config.jwt_enabled = True
+ self.hs.config.jwt_secret = self.jwt_pubkey
+ self.hs.config.jwt_algorithm = "RS256"
+ return self.hs
+
+ def jwt_encode(self, token, secret=jwt_privatekey):
+ return jwt.encode(token, secret, "RS256").decode("ascii")
+
+ def jwt_login(self, *args):
+ params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ self.render(request)
+ return channel
+
+ def test_login_jwt_valid(self):
+ channel = self.jwt_login({"sub": "kermit"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ def test_login_jwt_invalid_signature(self):
+ channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.json_body["error"], "Invalid JWT")
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 7dd86d0c..4886bbb4 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -20,7 +20,7 @@
import json
-from mock import Mock, NonCallableMock
+from mock import Mock
from six.moves.urllib import parse as urlparse
from twisted.internet import defer
@@ -46,13 +46,8 @@ class RoomBase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver(
- "red",
- http_client=None,
- federation_client=Mock(),
- ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
+ "red", http_client=None, federation_client=Mock(),
)
- self.ratelimiter = self.hs.get_ratelimiter()
- self.ratelimiter.can_do_action.return_value = (True, 0)
self.hs.get_federation_handler = Mock(return_value=Mock())
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 4bc3aaf0..18260bb9 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -16,7 +16,7 @@
"""Tests REST events for /rooms paths."""
-from mock import Mock, NonCallableMock
+from mock import Mock
from twisted.internet import defer
@@ -39,17 +39,11 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- "red",
- http_client=None,
- federation_client=Mock(),
- ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
+ "red", http_client=None, federation_client=Mock(),
)
self.event_source = hs.get_event_sources().sources["typing"]
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.can_do_action.return_value = (True, 0)
-
hs.get_handlers().federation_handler = Mock()
def get_user_by_access_token(token=None, allow_guest=False):
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 0d6936fd..3ab611f6 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -46,7 +46,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Email config.
self.email_attempts = []
- def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs):
+ async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs):
self.email_attempts.append(msg)
return
@@ -358,7 +358,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
# Email config.
self.email_attempts = []
- def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs):
+ async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs):
self.email_attempts.append(msg)
config["email"] = {
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index a68a96f6..7deaf5b2 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -25,10 +25,11 @@ import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
-from synapse.rest.client.v1 import login
+from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import account, account_validity, register, sync
from tests import unittest
+from tests.unittest import override_config
class RegisterRestServletTestCase(unittest.HomeserverTestCase):
@@ -146,10 +147,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Guest access is disabled")
+ @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
def test_POST_ratelimiting_guest(self):
- self.hs.config.rc_registration.burst_count = 5
- self.hs.config.rc_registration.per_second = 0.17
-
for i in range(0, 6):
url = self.url + b"?kind=guest"
request, channel = self.make_request(b"POST", url, b"{}")
@@ -168,10 +167,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+ @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
def test_POST_ratelimiting(self):
- self.hs.config.rc_registration.burst_count = 5
- self.hs.config.rc_registration.per_second = 0.17
-
for i in range(0, 6):
params = {
"username": "kermit" + str(i),
@@ -313,6 +310,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
sync.register_servlets,
+ logout.register_servlets,
account_validity.register_servlets,
]
@@ -405,6 +403,39 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
)
+ def test_logging_out_expired_user(self):
+ user_id = self.register_user("kermit", "monkey")
+ tok = self.login("kermit", "monkey")
+
+ self.register_user("admin", "adminpassword", admin=True)
+ admin_tok = self.login("admin", "adminpassword")
+
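+ # Use the admin API to mark kermit's account as already expired.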
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 0,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Try to log the user out
+ request, channel = self.make_request(b"POST", "/logout", access_token=tok)
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Log the user in again (allowed for expired accounts)
+ tok = self.login("kermit", "monkey")
+
+ # Try to log out all of the user's sessions
+ request, channel = self.make_request(b"POST", "/logout/all", access_token=tok)
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 1809ceb8..1ca648ef 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -18,10 +18,16 @@ import os
import shutil
import tempfile
from binascii import unhexlify
+from io import BytesIO
+from typing import Optional
from mock import Mock
from six.moves.urllib import parse
+import attr
+import PIL.Image as Image
+from parameterized import parameterized_class
+
from twisted.internet.defer import Deferred
from synapse.logging.context import make_deferred_yieldable
@@ -94,6 +100,68 @@ class MediaStorageTests(unittest.HomeserverTestCase):
self.assertEqual(test_body, body)
+@attr.s
+class _TestImage:
+ """An image for testing thumbnailing with the expected results
+
+ Attributes:
+ data: The raw image to thumbnail
+ content_type: The type of the image as a content type, e.g. "image/png"
+ extension: The extension associated with the format, e.g. ".png"
+ expected_cropped: The expected bytes from cropped thumbnailing, or None if
+ the test should just check for success.
+ expected_scaled: The expected bytes from scaled thumbnailing, or None if
+ the test should just check that a valid image is returned.
+ """
+
+ data = attr.ib(type=bytes)
+ content_type = attr.ib(type=bytes)
+ extension = attr.ib(type=bytes)
+ expected_cropped = attr.ib(type=Optional[bytes])
+ expected_scaled = attr.ib(type=Optional[bytes])
+
+
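+# Run MediaRepoTests once per test image below; parameterized_class exposes the
+# corresponding _TestImage instance to each generated class as self.test_image.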
+@parameterized_class(
+ ("test_image",),
+ [
+ # smol png
+ (
+ _TestImage(
+ unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ ),
+ b"image/png",
+ b".png",
+ unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000020000000200806"
+ b"000000737a7af40000001a49444154789cedc101010000008220"
+ b"ffaf6e484001000000ef0610200001194334ee0000000049454e"
+ b"44ae426082"
+ ),
+ unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000d49444154789c636060606000000005"
+ b"0001a5f645400000000049454e44ae426082"
+ ),
+ ),
+ ),
+ # small lossless webp
+ (
+ _TestImage(
+ unhexlify(
+ b"524946461a000000574542505650384c0d0000002f0000001007"
+ b"1011118888fe0700"
+ ),
+ b"image/webp",
+ b".webp",
+ None,
+ None,
+ ),
+ ),
+ ],
+)
class MediaRepoTests(unittest.HomeserverTestCase):
hijack_auth = True
@@ -151,13 +219,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.download_resource = self.media_repo.children[b"download"]
self.thumbnail_resource = self.media_repo.children[b"thumbnail"]
- # smol png
- self.end_content = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- )
-
self.media_id = "example.com/12345"
def _req(self, content_disposition):
@@ -176,14 +237,14 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.assertEqual(self.fetches[0][3], {"allow_remote": "false"})
headers = {
- b"Content-Length": [b"%d" % (len(self.end_content))],
- b"Content-Type": [b"image/png"],
+ b"Content-Length": [b"%d" % (len(self.test_image.data))],
+ b"Content-Type": [self.test_image.content_type],
}
if content_disposition:
headers[b"Content-Disposition"] = [content_disposition]
self.fetches[0][0].callback(
- (self.end_content, (len(self.end_content), headers))
+ (self.test_image.data, (len(self.test_image.data), headers))
)
self.pump()
@@ -196,12 +257,15 @@ class MediaRepoTests(unittest.HomeserverTestCase):
If the filename is filename=<ascii> then Synapse will decode it as an
ASCII string, and use filename= in the response.
"""
- channel = self._req(b"inline; filename=out.png")
+ channel = self._req(b"inline; filename=out" + self.test_image.extension)
headers = channel.headers
- self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
self.assertEqual(
- headers.getRawHeaders(b"Content-Disposition"), [b"inline; filename=out.png"]
+ headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
+ )
+ self.assertEqual(
+ headers.getRawHeaders(b"Content-Disposition"),
+ [b"inline; filename=out" + self.test_image.extension],
)
def test_disposition_filenamestar_utf8escaped(self):
@@ -211,13 +275,17 @@ class MediaRepoTests(unittest.HomeserverTestCase):
response.
"""
filename = parse.quote("\u2603".encode("utf8")).encode("ascii")
- channel = self._req(b"inline; filename*=utf-8''" + filename + b".png")
+ channel = self._req(
+ b"inline; filename*=utf-8''" + filename + self.test_image.extension
+ )
headers = channel.headers
- self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
+ self.assertEqual(
+ headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
+ )
self.assertEqual(
headers.getRawHeaders(b"Content-Disposition"),
- [b"inline; filename*=utf-8''" + filename + b".png"],
+ [b"inline; filename*=utf-8''" + filename + self.test_image.extension],
)
def test_disposition_none(self):
@@ -228,27 +296,16 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel = self._req(None)
headers = channel.headers
- self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
+ self.assertEqual(
+ headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
+ )
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
def test_thumbnail_crop(self):
- expected_body = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000020000000200806"
- b"000000737a7af40000001a49444154789cedc101010000008220"
- b"ffaf6e484001000000ef0610200001194334ee0000000049454e"
- b"44ae426082"
- )
-
- self._test_thumbnail("crop", expected_body)
+ self._test_thumbnail("crop", self.test_image.expected_cropped)
def test_thumbnail_scale(self):
- expected_body = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000d49444154789c636060606000000005"
- b"0001a5f645400000000049454e44ae426082"
- )
-
- self._test_thumbnail("scale", expected_body)
+ self._test_thumbnail("scale", self.test_image.expected_scaled)
def _test_thumbnail(self, method, expected_body):
params = "?width=32&height=32&method=" + method
@@ -259,13 +316,19 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.pump()
headers = {
- b"Content-Length": [b"%d" % (len(self.end_content))],
- b"Content-Type": [b"image/png"],
+ b"Content-Length": [b"%d" % (len(self.test_image.data))],
+ b"Content-Type": [self.test_image.content_type],
}
self.fetches[0][0].callback(
- (self.end_content, (len(self.end_content), headers))
+ (self.test_image.data, (len(self.test_image.data), headers))
)
self.pump()
self.assertEqual(channel.code, 200)
- self.assertEqual(channel.result["body"], expected_body, channel.result["body"])
+ if expected_body is not None:
+ self.assertEqual(
+ channel.result["body"], expected_body, channel.result["body"]
+ )
+ else:
+ # ensure that the result is at least some valid image
+ Image.open(BytesIO(channel.result["body"]))
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 406f29a7..99908edb 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -27,20 +27,33 @@ from synapse.server_notices.resource_limits_server_notices import (
)
from tests import unittest
+from tests.unittest import override_config
+from tests.utils import default_config
class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
- hs_config = self.default_config()
- hs_config["server_notices"] = {
- "system_mxid_localpart": "server",
- "system_mxid_display_name": "test display name",
- "system_mxid_avatar_url": None,
- "room_name": "Server Notices",
- }
+ def default_config(self):
+ config = default_config("test")
+
+ config.update(
+ {
+ "admin_contact": "mailto:user@test.com",
+ "limit_usage_by_mau": True,
+ "server_notices": {
+ "system_mxid_localpart": "server",
+ "system_mxid_display_name": "test display name",
+ "system_mxid_avatar_url": None,
+ "room_name": "Server Notices",
+ },
+ }
+ )
+
+ # apply any additional config which was specified via the override_config
+ # decorator.
+ if self._extra_config is not None:
+ config.update(self._extra_config)
- hs = self.setup_test_homeserver(config=hs_config)
- return hs
+ return config
def prepare(self, reactor, clock, hs):
self.server_notices_sender = self.hs.get_server_notices_sender()
@@ -60,7 +73,6 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
)
self._send_notice = self._rlsn._server_notices_manager.send_notice
- self.hs.config.limit_usage_by_mau = True
self.user_id = "@user_id:test"
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
@@ -68,21 +80,17 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
)
self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({}))
- self.hs.config.admin_contact = "mailto:user@test.com"
-
- def test_maybe_send_server_notice_to_user_flag_off(self):
- """Tests cases where the flags indicate nothing to do"""
- # test hs disabled case
- self.hs.config.hs_disabled = True
+ @override_config({"hs_disabled": True})
+ def test_maybe_send_server_notice_disabled_hs(self):
+ """If the HS is disabled, we should not send notices"""
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
-
self._send_notice.assert_not_called()
- # Test when mau limiting disabled
- self.hs.config.hs_disabled = False
- self.hs.config.limit_usage_by_mau = False
- self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
+ @override_config({"limit_usage_by_mau": False})
+ def test_maybe_send_server_notice_to_user_flag_off(self):
+ """If mau limiting is disabled, we should not send notices"""
+ self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
@@ -153,13 +161,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._send_notice.assert_not_called()
+ @override_config({"mau_limit_alerting": False})
def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self):
"""
Test that when server is over MAU limit and alerting is suppressed, then
an alert message is not sent into the room
"""
- self.hs.config.mau_limit_alerting = False
-
self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
side_effect=ResourceLimitError(
@@ -170,12 +177,11 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self.assertEqual(self._send_notice.call_count, 0)
+ @override_config({"mau_limit_alerting": False})
def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self):
"""
Test that when a server is disabled, that MAU limit alerting is ignored.
"""
- self.hs.config.mau_limit_alerting = False
-
self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
side_effect=ResourceLimitError(
@@ -187,12 +193,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
# Would be better to check contents, but 2 calls == set blocking event
self.assertEqual(self._send_notice.call_count, 2)
+ @override_config({"mau_limit_alerting": False})
def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self):
"""
When the room is already in a blocked state, test that when alerting
is suppressed that the room is returned to an unblocked state.
"""
- self.hs.config.mau_limit_alerting = False
self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
side_effect=ResourceLimitError(
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index e37260a8..5a50e4fd 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -25,8 +25,8 @@ from synapse.util.caches.descriptors import Cache, cached
from tests import unittest
-class CacheTestCase(unittest.TestCase):
- def setUp(self):
+class CacheTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, homeserver):
self.cache = Cache("test")
def test_empty(self):
@@ -96,7 +96,7 @@ class CacheTestCase(unittest.TestCase):
cache.get(3)
-class CacheDecoratorTestCase(unittest.TestCase):
+class CacheDecoratorTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def test_passthrough(self):
class A(object):
@@ -239,7 +239,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
callcount2 = [0]
class A(object):
- @cached(max_entries=4) # HACK: This makes it 2 due to cache factor
+ @cached(max_entries=2)
def func(self, key):
callcount[0] += 1
return key
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 31710949..ef296e7d 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -43,7 +43,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
)
hs.config.app_service_config_files = self.as_yaml_files
- hs.config.event_cache_size = 1
+ hs.config.caches.event_cache_size = 1
hs.config.password_providers = []
self.as_token = "token1"
@@ -110,7 +110,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
)
hs.config.app_service_config_files = self.as_yaml_files
- hs.config.event_cache_size = 1
+ hs.config.caches.event_cache_size = 1
hs.config.password_providers = []
self.as_list = [
@@ -422,7 +422,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
)
hs.config.app_service_config_files = [f1, f2]
- hs.config.event_cache_size = 1
+ hs.config.caches.event_cache_size = 1
hs.config.password_providers = []
database = hs.get_datastores().databases[0]
@@ -440,7 +440,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
)
hs.config.app_service_config_files = [f1, f2]
- hs.config.event_cache_size = 1
+ hs.config.caches.event_cache_size = 1
hs.config.password_providers = []
with self.assertRaises(ConfigError) as cm:
@@ -464,7 +464,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
)
hs.config.app_service_config_files = [f1, f2]
- hs.config.event_cache_size = 1
+ hs.config.caches.event_cache_size = 1
hs.config.password_providers = []
with self.assertRaises(ConfigError) as cm:
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index cdee0a9e..278961c3 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -51,7 +51,8 @@ class SQLBaseStoreTestCase(unittest.TestCase):
config = Mock()
config._disable_native_upserts = True
- config.event_cache_size = 1
+ config.caches = Mock()
+ config.caches.event_cache_size = 1
hs = TestHomeServer("test", config=config)
sqlite_config = {"name": "sqlite3"}
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 0e04b2cf..43425c96 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -39,7 +39,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID("alice", "test")
self.requester = Requester(self.user, None, False, None, None)
- info = self.get_success(self.room_creator.create_room(self.requester, {}))
+ info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
def run_background_update(self):
@@ -261,7 +261,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.user = UserID.from_string(self.register_user("user1", "password"))
self.token1 = self.login("user1", "password")
self.requester = Requester(self.user, None, False, None, None)
- info = self.get_success(self.room_creator.create_room(self.requester, {}))
+ info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
self.event_creator = homeserver.get_event_creation_handler()
homeserver.config.user_consent_version = self.CONSENT_VERSION
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index bf674dd1..3b483bc7 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -23,6 +23,7 @@ from synapse.http.site import XForwardedForRequest
from synapse.rest.client.v1 import login
from tests import unittest
+from tests.unittest import override_config
class ClientIpStoreTestCase(unittest.HomeserverTestCase):
@@ -137,9 +138,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
],
)
+ @override_config({"limit_usage_by_mau": False, "max_mau_value": 50})
def test_disabled_monthly_active_user(self):
- self.hs.config.limit_usage_by_mau = False
- self.hs.config.max_mau_value = 50
user_id = "@user:server"
self.get_success(
self.store.insert_client_ip(
@@ -149,9 +149,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)
+ @override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
def test_adding_monthly_active_user_when_full(self):
- self.hs.config.limit_usage_by_mau = True
- self.hs.config.max_mau_value = 50
lots_of_users = 100
user_id = "@user:server"
@@ -166,9 +165,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)
+ @override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
def test_adding_monthly_active_user_when_space(self):
- self.hs.config.limit_usage_by_mau = True
- self.hs.config.max_mau_value = 50
user_id = "@user:server"
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)
@@ -184,9 +182,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertTrue(active)
+ @override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
def test_updating_monthly_active_user_when_space(self):
- self.hs.config.limit_usage_by_mau = True
- self.hs.config.max_mau_value = 50
user_id = "@user:server"
self.get_success(self.store.register_user(user_id=user_id, password_hash=None))
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index a7b7fd36..a7b85004 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -33,7 +33,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
events = [(3, 2), (6, 2), (4, 6)]
for event_count, extrems in events:
- info = self.get_success(room_creator.create_room(requester, {}))
+ info, _ = self.get_success(room_creator.create_room(requester, {}))
room_id = info["room_id"]
last_event = None
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index d4bcf182..b45bc9c1 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -35,6 +35,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
def setUp(self):
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore()
+ self.persist_events_store = hs.get_datastores().persist_events
@defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_http(self):
@@ -76,7 +77,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
yield self.store.db.runInteraction(
"",
- self.store._set_push_actions_for_event_and_users_txn,
+ self.persist_events_store._set_push_actions_for_event_and_users_txn,
[(event, None)],
[(event, None)],
)
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
new file mode 100644
index 00000000..55e9ecf2
--- /dev/null
+++ b/tests/storage/test_id_generators.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.storage.database import Database
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
+
+from tests.unittest import HomeserverTestCase
+from tests.utils import USE_POSTGRES_FOR_TESTS
+
+
+class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
+ if not USE_POSTGRES_FOR_TESTS:
+ skip = "Requires Postgres"
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.db = self.store.db # type: Database
+
+ self.get_success(self.db.runInteraction("_setup_db", self._setup_db))
+
+ def _setup_db(self, txn):
+ txn.execute("CREATE SEQUENCE foobar_seq")
+ txn.execute(
+ """
+ CREATE TABLE foobar (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ data TEXT
+ );
+ """
+ )
+
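+ # Builds a MultiWriterIdGenerator over the `foobar` table created above,
+ # keyed on the `instance_name` column and drawing new ids from `foobar_seq`.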
+ def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+ def _create(conn):
+ return MultiWriterIdGenerator(
+ conn,
+ self.db,
+ instance_name=instance_name,
+ table="foobar",
+ instance_column="instance_name",
+ id_column="stream_id",
+ sequence_name="foobar_seq",
+ )
+
+ return self.get_success(self.db.runWithConnection(_create))
+
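+ # Inserts `number` rows into `foobar` on behalf of the given instance, each
+ # taking its stream_id from the sequence.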
+ def _insert_rows(self, instance_name: str, number: int):
+ def _insert(txn):
+ for _ in range(number):
+ txn.execute(
+ "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
+ (instance_name,),
+ )
+
+ self.get_success(self.db.runInteraction("test_single_instance", _insert))
+
+ def test_empty(self):
+ """Test an ID generator against an empty database gives sensible
+ current positions.
+ """
+
+ id_gen = self._create_id_generator()
+
+ # The table is empty so we expect an empty map for positions
+ self.assertEqual(id_gen.get_positions(), {})
+
+ def test_single_instance(self):
+ """Test that reads and writes from a single process are handled
+ correctly.
+ """
+
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
+
+ id_gen = self._create_id_generator()
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ # Try allocating a new ID gen and check that we only see position
+ # advanced after we leave the context manager.
+
+ async def _get_next_async():
+ with await id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(id_gen.get_positions(), {"master": 8})
+ self.assertEqual(id_gen.get_current_token("master"), 8)
+
+ def test_multi_instance(self):
+ """Test that reads and writes from multiple processes are handled
+ correctly.
+ """
+ self._insert_rows("first", 3)
+ self._insert_rows("second", 4)
+
+ first_id_gen = self._create_id_generator("first")
+ second_id_gen = self._create_id_generator("second")
+
+ self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
+ self.assertEqual(first_id_gen.get_current_token("first"), 3)
+ self.assertEqual(first_id_gen.get_current_token("second"), 7)
+
+ # Try allocating a new ID gen and check that we only see position
+ # advanced after we leave the context manager.
+
+ async def _get_next_async():
+ with await first_id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(
+ first_id_gen.get_positions(), {"first": 3, "second": 7}
+ )
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7})
+
+ # However the ID gen on the second instance won't have seen the update
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+
+ # ... but calling `get_next` on the second instance should give a unique
+ # stream ID
+
+ async def _get_next_async():
+ with await second_id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 9)
+
+ self.assertEqual(
+ second_id_gen.get_positions(), {"first": 3, "second": 7}
+ )
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
+
+ # If the second ID gen gets told about the first, it correctly updates
+ second_id_gen.advance("first", 8)
+ self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
+
+ def test_get_next_txn(self):
+ """Test that the `get_next_txn` function works correctly.
+ """
+
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
+
+ id_gen = self._create_id_generator()
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ # Try allocating a new ID gen and check that we only see position
+ # advanced after we leave the context manager.
+
+ def _get_next_txn(txn):
+ stream_id = id_gen.get_next_txn(txn)
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ self.get_success(self.db.runInteraction("test", _get_next_txn))
+
+ self.assertEqual(id_gen.get_positions(), {"master": 8})
+ self.assertEqual(id_gen.get_current_token("master"), 8)
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index bc53bf09..9c04e925 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -19,158 +19,180 @@ from twisted.internet import defer
from synapse.api.constants import UserTypes
from tests import unittest
+from tests.unittest import default_config, override_config
FORTY_DAYS = 40 * 24 * 60 * 60
+def gen_3pids(count):
+ """Generate `count` threepids as a list."""
+ return [
+ {"medium": "email", "address": "user%i@matrix.org" % i} for i in range(count)
+ ]
+
+
class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def default_config(self):
+ config = default_config("test")
+
+ config.update({"limit_usage_by_mau": True, "max_mau_value": 50})
+
+ # apply any additional config which was specified via the override_config
+ # decorator.
+ if self._extra_config is not None:
+ config.update(self._extra_config)
- hs = self.setup_test_homeserver()
- self.store = hs.get_datastore()
- hs.config.limit_usage_by_mau = True
- hs.config.max_mau_value = 50
+ return config
+ def prepare(self, reactor, clock, homeserver):
+ self.store = homeserver.get_datastore()
# Advance the clock a bit
reactor.advance(FORTY_DAYS)
- return hs
-
+ @override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)})
def test_initialise_reserved_users(self):
- self.hs.config.max_mau_value = 5
+ threepids = self.hs.config.mau_limits_reserved_threepids
+
+ # register three users, of which two have reserved 3pids, and a third
+ # which is a support user.
user1 = "@user1:server"
- user1_email = "user1@matrix.org"
+ user1_email = threepids[0]["address"]
user2 = "@user2:server"
- user2_email = "user2@matrix.org"
+ user2_email = threepids[1]["address"]
user3 = "@user3:server"
- user3_email = "user3@matrix.org"
- threepids = [
- {"medium": "email", "address": user1_email},
- {"medium": "email", "address": user2_email},
- {"medium": "email", "address": user3_email},
- ]
- self.hs.config.mau_limits_reserved_threepids = threepids
- # -1 because user3 is a support user and does not count
- user_num = len(threepids) - 1
-
- self.store.register_user(user_id=user1, password_hash=None)
- self.store.register_user(user_id=user2, password_hash=None)
- self.store.register_user(
- user_id=user3, password_hash=None, user_type=UserTypes.SUPPORT
+ self.get_success(self.store.register_user(user_id=user1))
+ self.get_success(self.store.register_user(user_id=user2))
+ self.get_success(
+ self.store.register_user(user_id=user3, user_type=UserTypes.SUPPORT)
)
- self.pump()
now = int(self.hs.get_clock().time_msec())
- self.store.user_add_threepid(user1, "email", user1_email, now, now)
- self.store.user_add_threepid(user2, "email", user2_email, now, now)
+ self.get_success(
+ self.store.user_add_threepid(user1, "email", user1_email, now, now)
+ )
+ self.get_success(
+ self.store.user_add_threepid(user2, "email", user2_email, now, now)
+ )
- self.store.db.runInteraction(
- "initialise", self.store._initialise_reserved_users, threepids
+ # XXX why are we doing this here? this function is only run at startup
+ # so it is odd to re-run it here.
+ self.get_success(
+ self.store.db.runInteraction(
+ "initialise", self.store._initialise_reserved_users, threepids
+ )
)
- self.pump()
- active_count = self.store.get_monthly_active_count()
+ # the number of users we expect will be counted against the mau limit
+ # -1 because user3 is a support user and does not count
+ user_num = len(threepids) - 1
- # Test total counts, ensure user3 (support user) is not counted
- self.assertEquals(self.get_success(active_count), user_num)
+ # Check the number of active users. Ensure user3 (support user) is not counted
+ active_count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(active_count, user_num)
- # Test user is marked as active
- timestamp = self.store.user_last_seen_monthly_active(user1)
- self.assertTrue(self.get_success(timestamp))
- timestamp = self.store.user_last_seen_monthly_active(user2)
- self.assertTrue(self.get_success(timestamp))
+ # Test each of the registered users is marked as active
+ timestamp = self.get_success(self.store.user_last_seen_monthly_active(user1))
+ self.assertGreater(timestamp, 0)
+ timestamp = self.get_success(self.store.user_last_seen_monthly_active(user2))
+ self.assertGreater(timestamp, 0)
- # Test that users are never removed from the db.
+ # Test that users with reserved 3pids are not removed from the MAU table
+ # XXX some of this is redundant. poking things into the config shouldn't
+ # work, and in any case it's not obvious what we expect to happen when
+ # we advance the reactor.
self.hs.config.max_mau_value = 0
-
self.reactor.advance(FORTY_DAYS)
self.hs.config.max_mau_value = 5
- self.store.reap_monthly_active_users()
- self.pump()
+ self.get_success(self.store.reap_monthly_active_users())
- active_count = self.store.get_monthly_active_count()
- self.assertEquals(self.get_success(active_count), user_num)
+ active_count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(active_count, user_num)
- # Test that regular users are removed from the db
+ # Add some more users and check they are counted as active
ru_count = 2
- self.store.upsert_monthly_active_user("@ru1:server")
- self.store.upsert_monthly_active_user("@ru2:server")
- self.pump()
- active_count = self.store.get_monthly_active_count()
- self.assertEqual(self.get_success(active_count), user_num + ru_count)
- self.hs.config.max_mau_value = user_num
- self.store.reap_monthly_active_users()
- self.pump()
+ self.get_success(self.store.upsert_monthly_active_user("@ru1:server"))
+ self.get_success(self.store.upsert_monthly_active_user("@ru2:server"))
+
+ active_count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(active_count, user_num + ru_count)
+
+ # now run the reaper and check that the number of active users is reduced
+ # to max_mau_value
+ self.get_success(self.store.reap_monthly_active_users())
- active_count = self.store.get_monthly_active_count()
- self.assertEquals(self.get_success(active_count), user_num)
+ active_count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(active_count, 3)
def test_can_insert_and_count_mau(self):
- count = self.store.get_monthly_active_count()
- self.assertEqual(0, self.get_success(count))
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 0)
- self.store.upsert_monthly_active_user("@user:server")
- self.pump()
+ d = self.store.upsert_monthly_active_user("@user:server")
+ self.get_success(d)
- count = self.store.get_monthly_active_count()
- self.assertEqual(1, self.get_success(count))
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 1)
def test_user_last_seen_monthly_active(self):
user_id1 = "@user1:server"
user_id2 = "@user2:server"
user_id3 = "@user3:server"
- result = self.store.user_last_seen_monthly_active(user_id1)
- self.assertFalse(self.get_success(result) == 0)
+ result = self.get_success(self.store.user_last_seen_monthly_active(user_id1))
+ self.assertNotEqual(result, 0)
- self.store.upsert_monthly_active_user(user_id1)
- self.store.upsert_monthly_active_user(user_id2)
- self.pump()
+ self.get_success(self.store.upsert_monthly_active_user(user_id1))
+ self.get_success(self.store.upsert_monthly_active_user(user_id2))
- result = self.store.user_last_seen_monthly_active(user_id1)
- self.assertGreater(self.get_success(result), 0)
+ result = self.get_success(self.store.user_last_seen_monthly_active(user_id1))
+ self.assertGreater(result, 0)
- result = self.store.user_last_seen_monthly_active(user_id3)
- self.assertNotEqual(self.get_success(result), 0)
+ result = self.get_success(self.store.user_last_seen_monthly_active(user_id3))
+ self.assertNotEqual(result, 0)
+ @override_config({"max_mau_value": 5})
def test_reap_monthly_active_users(self):
- self.hs.config.max_mau_value = 5
initial_users = 10
for i in range(initial_users):
- self.store.upsert_monthly_active_user("@user%d:server" % i)
- self.pump()
+ self.get_success(
+ self.store.upsert_monthly_active_user("@user%d:server" % i)
+ )
+
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, initial_users)
- count = self.store.get_monthly_active_count()
- self.assertTrue(self.get_success(count), initial_users)
+ d = self.store.reap_monthly_active_users()
+ self.get_success(d)
- self.store.reap_monthly_active_users()
- self.pump()
- count = self.store.get_monthly_active_count()
- self.assertEquals(self.get_success(count), self.hs.config.max_mau_value)
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, self.hs.config.max_mau_value)
self.reactor.advance(FORTY_DAYS)
- self.store.reap_monthly_active_users()
- self.pump()
- count = self.store.get_monthly_active_count()
- self.assertEquals(self.get_success(count), 0)
+ d = self.store.reap_monthly_active_users()
+ self.get_success(d)
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 0)
+
+ # Note that the config below says mau_limit (no 's'): that is the name of the
+ # config setting, although it gets stored on the config object as mau_limits.
+ @override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)})
def test_reap_monthly_active_users_reserved_users(self):
""" Tests that reaping correctly handles reaping where reserved users are
present"""
-
- self.hs.config.max_mau_value = 5
- initial_users = 5
+ threepids = self.hs.config.mau_limits_reserved_threepids
+ initial_users = len(threepids)
reserved_user_number = initial_users - 1
- threepids = []
for i in range(initial_users):
user = "@user%d:server" % i
- email = "user%d@example.com" % i
+ email = "user%d@matrix.org" % i
+
self.get_success(self.store.upsert_monthly_active_user(user))
- threepids.append({"medium": "email", "address": email})
+
# Need to ensure that the most recent entries in the
# monthly_active_users table are reserved
now = int(self.hs.get_clock().time_msec())
@@ -182,27 +204,37 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.user_add_threepid(user, "email", email, now, now)
)
- self.hs.config.mau_limits_reserved_threepids = threepids
- self.store.db.runInteraction(
+ d = self.store.db.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
- count = self.store.get_monthly_active_count()
- self.assertTrue(self.get_success(count), initial_users)
+ self.get_success(d)
- users = self.store.get_registered_reserved_users()
- self.assertEquals(len(self.get_success(users)), reserved_user_number)
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, initial_users)
- self.get_success(self.store.reap_monthly_active_users())
- count = self.store.get_monthly_active_count()
- self.assertEquals(self.get_success(count), self.hs.config.max_mau_value)
+ users = self.get_success(self.store.get_registered_reserved_users())
+ self.assertEqual(len(users), reserved_user_number)
+
+ d = self.store.reap_monthly_active_users()
+ self.get_success(d)
+
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, self.hs.config.max_mau_value)
def test_populate_monthly_users_is_guest(self):
# Test that guest users are not added to mau list
user_id = "@user_id:host"
- self.store.register_user(user_id=user_id, password_hash=None, make_guest=True)
+
+ d = self.store.register_user(
+ user_id=user_id, password_hash=None, make_guest=True
+ )
+ self.get_success(d)
+
self.store.upsert_monthly_active_user = Mock()
- self.store.populate_monthly_active_users(user_id)
- self.pump()
+
+ d = self.store.populate_monthly_active_users(user_id)
+ self.get_success(d)
+
self.store.upsert_monthly_active_user.assert_not_called()
def test_populate_monthly_users_should_update(self):
@@ -213,8 +245,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(None)
)
- self.store.populate_monthly_active_users("user_id")
- self.pump()
+ d = self.store.populate_monthly_active_users("user_id")
+ self.get_success(d)
+
self.store.upsert_monthly_active_user.assert_called_once()
def test_populate_monthly_users_should_not_update(self):
@@ -224,16 +257,18 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec())
)
- self.store.populate_monthly_active_users("user_id")
- self.pump()
+
+ d = self.store.populate_monthly_active_users("user_id")
+ self.get_success(d)
+
self.store.upsert_monthly_active_user.assert_not_called()
def test_get_reserved_real_user_account(self):
# Test no reserved users, or reserved threepids
users = self.get_success(self.store.get_registered_reserved_users())
- self.assertEquals(len(users), 0)
- # Test reserved users but no registered users
+ self.assertEqual(len(users), 0)
+ # Test reserved users but no registered users
user1 = "@user1:example.com"
user2 = "@user2:example.com"
@@ -243,64 +278,64 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
{"medium": "email", "address": user1_email},
{"medium": "email", "address": user2_email},
]
+
self.hs.config.mau_limits_reserved_threepids = threepids
- self.store.db.runInteraction(
+ d = self.store.db.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
+ self.get_success(d)
- self.pump()
users = self.get_success(self.store.get_registered_reserved_users())
- self.assertEquals(len(users), 0)
+ self.assertEqual(len(users), 0)
- # Test reserved registed users
- self.store.register_user(user_id=user1, password_hash=None)
- self.store.register_user(user_id=user2, password_hash=None)
- self.pump()
+ # Test reserved registered users
+ self.get_success(self.store.register_user(user_id=user1, password_hash=None))
+ self.get_success(self.store.register_user(user_id=user2, password_hash=None))
now = int(self.hs.get_clock().time_msec())
self.store.user_add_threepid(user1, "email", user1_email, now, now)
self.store.user_add_threepid(user2, "email", user2_email, now, now)
users = self.get_success(self.store.get_registered_reserved_users())
- self.assertEquals(len(users), len(threepids))
+ self.assertEqual(len(users), len(threepids))
def test_support_user_not_add_to_mau_limits(self):
support_user_id = "@support:test"
- count = self.store.get_monthly_active_count()
- self.pump()
- self.assertEqual(self.get_success(count), 0)
- self.store.register_user(
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 0)
+
+ d = self.store.register_user(
user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT
)
+ self.get_success(d)
- self.store.upsert_monthly_active_user(support_user_id)
- count = self.store.get_monthly_active_count()
- self.pump()
- self.assertEqual(self.get_success(count), 0)
+ d = self.store.upsert_monthly_active_user(support_user_id)
+ self.get_success(d)
- def test_track_monthly_users_without_cap(self):
- self.hs.config.limit_usage_by_mau = False
- self.hs.config.mau_stats_only = True
- self.hs.config.max_mau_value = 1 # should not matter
+ d = self.store.get_monthly_active_count()
+ count = self.get_success(d)
+ self.assertEqual(count, 0)
- count = self.store.get_monthly_active_count()
- self.assertEqual(0, self.get_success(count))
+ # Note that the max_mau_value setting should not matter.
+ @override_config(
+ {"limit_usage_by_mau": False, "mau_stats_only": True, "max_mau_value": 1}
+ )
+ def test_track_monthly_users_without_cap(self):
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(0, count)
- self.store.upsert_monthly_active_user("@user1:server")
- self.store.upsert_monthly_active_user("@user2:server")
- self.pump()
+ self.get_success(self.store.upsert_monthly_active_user("@user1:server"))
+ self.get_success(self.store.upsert_monthly_active_user("@user2:server"))
- count = self.store.get_monthly_active_count()
- self.assertEqual(2, self.get_success(count))
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(2, count)
+ @override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
def test_no_users_when_not_tracking(self):
- self.hs.config.limit_usage_by_mau = False
- self.hs.config.mau_stats_only = False
self.store.upsert_monthly_active_user = Mock()
- self.store.populate_monthly_active_users("@user:sever")
- self.pump()
+ self.get_success(self.store.populate_monthly_active_users("@user:sever"))
self.store.upsert_monthly_active_user.assert_not_called()
@@ -315,33 +350,39 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
service2 = "service2"
native = "native"
- self.store.register_user(
- user_id=appservice1_user1, password_hash=None, appservice_id=service1
+ self.get_success(
+ self.store.register_user(
+ user_id=appservice1_user1, password_hash=None, appservice_id=service1
+ )
+ )
+ self.get_success(
+ self.store.register_user(
+ user_id=appservice1_user2, password_hash=None, appservice_id=service1
+ )
)
- self.store.register_user(
- user_id=appservice1_user2, password_hash=None, appservice_id=service1
+ self.get_success(
+ self.store.register_user(
+ user_id=appservice2_user1, password_hash=None, appservice_id=service2
+ )
)
- self.store.register_user(
- user_id=appservice2_user1, password_hash=None, appservice_id=service2
+ self.get_success(
+ self.store.register_user(user_id=native_user1, password_hash=None)
)
- self.store.register_user(user_id=native_user1, password_hash=None)
- self.pump()
- count = self.store.get_monthly_active_count_by_service()
- self.assertEqual({}, self.get_success(count))
+ count = self.get_success(self.store.get_monthly_active_count_by_service())
+ self.assertEqual(count, {})
- self.store.upsert_monthly_active_user(native_user1)
- self.store.upsert_monthly_active_user(appservice1_user1)
- self.store.upsert_monthly_active_user(appservice1_user2)
- self.store.upsert_monthly_active_user(appservice2_user1)
- self.pump()
+ self.get_success(self.store.upsert_monthly_active_user(native_user1))
+ self.get_success(self.store.upsert_monthly_active_user(appservice1_user1))
+ self.get_success(self.store.upsert_monthly_active_user(appservice1_user2))
+ self.get_success(self.store.upsert_monthly_active_user(appservice2_user1))
- count = self.store.get_monthly_active_count()
- self.assertEqual(4, self.get_success(count))
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 4)
- count = self.store.get_monthly_active_count_by_service()
- result = self.get_success(count)
+ d = self.store.get_monthly_active_count_by_service()
+ result = self.get_success(d)
- self.assertEqual(2, result[service1])
- self.assertEqual(1, result[service2])
- self.assertEqual(1, result[native])
+ self.assertEqual(result[service1], 2)
+ self.assertEqual(result[service2], 1)
+ self.assertEqual(result[native], 1)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 086adeb8..3b78d488 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -55,6 +55,17 @@ class RoomStoreTestCase(unittest.TestCase):
(yield self.store.get_room(self.room.to_string())),
)
+ @defer.inlineCallbacks
+ def test_get_room_with_stats(self):
+ self.assertDictContainsSubset(
+ {
+ "room_id": self.room.to_string(),
+ "creator": self.u_creator.to_string(),
+ "public": True,
+ },
+ (yield self.store.get_room_with_stats(self.room.to_string())),
+ )
+
class RoomEventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 6c2351cf..69b4c5d6 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -136,21 +136,18 @@ class EventAuthTestCase(unittest.TestCase):
# creator should be able to send aliases
event_auth.check(
- RoomVersions.MSC2432_DEV,
- _alias_event(creator),
- auth_events,
- do_sig_check=False,
+ RoomVersions.V6, _alias_event(creator), auth_events, do_sig_check=False,
)
# No particular checks are done on the state key.
event_auth.check(
- RoomVersions.MSC2432_DEV,
+ RoomVersions.V6,
_alias_event(creator, state_key=""),
auth_events,
do_sig_check=False,
)
event_auth.check(
- RoomVersions.MSC2432_DEV,
+ RoomVersions.V6,
_alias_event(creator, state_key="test.com"),
auth_events,
do_sig_check=False,
@@ -159,8 +156,38 @@ class EventAuthTestCase(unittest.TestCase):
# Per standard auth rules, the member must be in the room.
with self.assertRaises(AuthError):
event_auth.check(
- RoomVersions.MSC2432_DEV,
- _alias_event(other),
+ RoomVersions.V6, _alias_event(other), auth_events, do_sig_check=False,
+ )
+
+ def test_msc2209(self):
+ """
+ Notifications power levels get checked due to MSC2209.
+ """
+ creator = "@creator:example.com"
+ pleb = "@joiner:example.com"
+
+ auth_events = {
+ ("m.room.create", ""): _create_event(creator),
+ ("m.room.member", creator): _join_event(creator),
+ ("m.room.power_levels", ""): _power_levels_event(
+ creator, {"state_default": "30", "users": {pleb: "30"}}
+ ),
+ ("m.room.member", pleb): _join_event(pleb),
+ }
+
+ # pleb should be able to modify the notifications power level.
+ event_auth.check(
+ RoomVersions.V1,
+ _power_levels_event(pleb, {"notifications": {"room": 100}}),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # But an MSC2209 room rejects this change.
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6,
+ _power_levels_event(pleb, {"notifications": {"room": 100}}),
auth_events,
do_sig_check=False,
)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index f297de95..c662195e 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -6,12 +6,13 @@ from synapse.events import make_event_from_dict
from synapse.logging.context import LoggingContext
from synapse.types import Requester, UserID
from synapse.util import Clock
+from synapse.util.retryutils import NotRetryingDestination
from tests import unittest
from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver
-class MessageAcceptTests(unittest.TestCase):
+class MessageAcceptTests(unittest.HomeserverTestCase):
def setUp(self):
self.http_client = Mock()
@@ -27,13 +28,13 @@ class MessageAcceptTests(unittest.TestCase):
user_id = UserID("us", "test")
our_user = Requester(user_id, None, False, None, None)
room_creator = self.homeserver.get_room_creation_handler()
- room = ensureDeferred(
+ room_deferred = ensureDeferred(
room_creator.create_room(
our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
)
)
self.reactor.advance(0.1)
- self.room_id = self.successResultOf(room)["room_id"]
+ self.room_id = self.successResultOf(room_deferred)[0]["room_id"]
self.store = self.homeserver.get_datastore()
@@ -145,3 +146,119 @@ class MessageAcceptTests(unittest.TestCase):
# Make sure the invalid event isn't there
extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
+
+ def test_retry_device_list_resync(self):
+ """Tests that device lists are marked as stale if they couldn't be synced, and
+ that stale device lists are retried periodically.
+ """
+ remote_user_id = "@john:test_remote"
+ remote_origin = "test_remote"
+
+ # Track the number of attempts to resync the user's device list.
+ self.resync_attempts = 0
+
+ # When this function is called, increment the number of resync attempts (only if
+ # we're querying devices for the right user ID), then raise a
+ # NotRetryingDestination error to fail the resync gracefully.
+ def query_user_devices(destination, user_id):
+ if user_id == remote_user_id:
+ self.resync_attempts += 1
+
+ raise NotRetryingDestination(0, 0, destination)
+
+ # Register the mock on the federation client.
+ federation_client = self.homeserver.get_federation_client()
+ federation_client.query_user_devices = Mock(side_effect=query_user_devices)
+
+ # Register a mock on the store so that the incoming update doesn't fail because
+ # we don't share a room with the user.
+ store = self.homeserver.get_datastore()
+ store.get_rooms_for_user = Mock(return_value=["!someroom:test"])
+
+ # Manually inject a fake device list update. We need this update to include at
+ # least one prev_id so that the user's device list will need to be retried.
+ device_list_updater = self.homeserver.get_device_handler().device_list_updater
+ self.get_success(
+ device_list_updater.incoming_device_list_update(
+ origin=remote_origin,
+ edu_content={
+ "deleted": False,
+ "device_display_name": "Mobile",
+ "device_id": "QBUAZIFURK",
+ "prev_id": [5],
+ "stream_id": 6,
+ "user_id": remote_user_id,
+ },
+ )
+ )
+
+ # Check that there was one resync attempt.
+ self.assertEqual(self.resync_attempts, 1)
+
+ # Check that the resync attempt failed and caused the user's device list to be
+ # marked as stale.
+ need_resync = self.get_success(
+ store.get_user_ids_requiring_device_list_resync()
+ )
+ self.assertIn(remote_user_id, need_resync)
+
+ # Check that waiting for 30 seconds caused Synapse to retry resyncing the device
+ # list.
+ self.reactor.advance(30)
+ self.assertEqual(self.resync_attempts, 2)
+
+ def test_cross_signing_keys_retry(self):
+ """Tests that resyncing a device list correctly processes cross-signing keys from
+ the remote server.
+ """
+ remote_user_id = "@john:test_remote"
+ remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
+ remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
+
+ # Register mock device list retrieval on the federation client.
+ federation_client = self.homeserver.get_federation_client()
+ federation_client.query_user_devices = Mock(
+ return_value={
+ "user_id": remote_user_id,
+ "stream_id": 1,
+ "devices": [],
+ "master_key": {
+ "user_id": remote_user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
+ },
+ "self_signing_key": {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:" + remote_self_signing_key: remote_self_signing_key
+ },
+ },
+ }
+ )
+
+ # Resync the device list.
+ device_handler = self.homeserver.get_device_handler()
+ self.get_success(
+ device_handler.device_list_updater.user_device_resync(remote_user_id),
+ )
+
+ # Retrieve the cross-signing keys for this user.
+ keys = self.get_success(
+ self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]),
+ )
+ self.assertTrue(remote_user_id in keys)
+
+ # Check that the master key is the one returned by the mock.
+ master_key = keys[remote_user_id]["master"]
+ self.assertEqual(len(master_key["keys"]), 1)
+ self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys())
+ self.assertTrue(remote_master_key in master_key["keys"].values())
+
+ # Check that the self-signing key is the one returned by the mock.
+ self_signing_key = keys[remote_user_id]["self_signing"]
+ self.assertEqual(len(self_signing_key["keys"]), 1)
+ self.assertTrue(
+ "ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(),
+ )
+ self.assertTrue(remote_self_signing_key in self_signing_key["keys"].values())
diff --git a/tests/test_mau.py b/tests/test_mau.py
index eb159e3b..49667ed7 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -17,47 +17,44 @@
import json
-from mock import Mock
-
-from synapse.api.auth_blocking import AuthBlocking
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.rest.client.v2_alpha import register, sync
from tests import unittest
+from tests.unittest import override_config
+from tests.utils import default_config
class TestMauLimit(unittest.HomeserverTestCase):
servlets = [register.register_servlets, sync.register_servlets]
- def make_homeserver(self, reactor, clock):
+ def default_config(self):
+ config = default_config("test")
- self.hs = self.setup_test_homeserver(
- "red", http_client=None, federation_client=Mock()
+ config.update(
+ {
+ "registrations_require_3pid": [],
+ "limit_usage_by_mau": True,
+ "max_mau_value": 2,
+ "mau_trial_days": 0,
+ "server_notices": {
+ "system_mxid_localpart": "server",
+ "room_name": "Test Server Notice Room",
+ },
+ }
)
- self.store = self.hs.get_datastore()
-
- self.hs.config.registrations_require_3pid = []
- self.hs.config.enable_registration_captcha = False
- self.hs.config.recaptcha_public_key = []
-
- self.hs.config.limit_usage_by_mau = True
- self.hs.config.hs_disabled = False
- self.hs.config.max_mau_value = 2
- self.hs.config.server_notices_mxid = "@server:red"
- self.hs.config.server_notices_mxid_display_name = None
- self.hs.config.server_notices_mxid_avatar_url = None
- self.hs.config.server_notices_room_name = "Test Server Notice Room"
- self.hs.config.mau_trial_days = 0
+ # apply any additional config which was specified via the override_config
+ # decorator.
+ if self._extra_config is not None:
+ config.update(self._extra_config)
- # AuthBlocking reads config options during hs creation. Recreate the
- # hs' copy of AuthBlocking after we've updated config values above
- self.auth_blocking = AuthBlocking(self.hs)
- self.hs.get_auth()._auth_blocking = self.auth_blocking
+ return config
- return self.hs
+ def prepare(self, reactor, clock, homeserver):
+ self.store = homeserver.get_datastore()
def test_simple_deny_mau(self):
# Create and sync so that the MAU counts get updated
@@ -66,6 +63,9 @@ class TestMauLimit(unittest.HomeserverTestCase):
token2 = self.create_user("kermit2")
self.do_sync_for_user(token2)
+ # check we're testing what we think we are: there should be two active users
+ self.assertEqual(self.get_success(self.store.get_monthly_active_count()), 2)
+
# We've created and activated two users, we shouldn't be able to
# register new users
with self.assertRaises(SynapseError) as cm:
@@ -85,7 +85,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
# Advance time by 31 days
self.reactor.advance(31 * 24 * 60 * 60)
- self.store.reap_monthly_active_users()
+ self.get_success(self.store.reap_monthly_active_users())
self.reactor.advance(0)
@@ -93,9 +93,8 @@ class TestMauLimit(unittest.HomeserverTestCase):
token3 = self.create_user("kermit3")
self.do_sync_for_user(token3)
+ @override_config({"mau_trial_days": 1})
def test_trial_delay(self):
- self.hs.config.mau_trial_days = 1
-
# We should be able to register more than the limit initially
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
@@ -127,8 +126,8 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ @override_config({"mau_trial_days": 1})
def test_trial_users_cant_come_back(self):
- self.auth_blocking._mau_trial_days = 1
self.hs.config.mau_trial_days = 1
# We should be able to register more than the limit initially
@@ -148,8 +147,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
# Advance by 2 months so everyone falls out of MAU
self.reactor.advance(60 * 24 * 60 * 60)
- self.store.reap_monthly_active_users()
- self.reactor.advance(0)
+ self.get_success(self.store.reap_monthly_active_users())
# We can create as many new users as we want
token4 = self.create_user("kermit4")
@@ -176,11 +174,11 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ @override_config(
+ # max_mau_value should not matter
+ {"max_mau_value": 1, "limit_usage_by_mau": False, "mau_stats_only": True}
+ )
def test_tracked_but_not_limited(self):
- self.auth_blocking._max_mau_value = 1 # should not matter
- self.auth_blocking._limit_usage_by_mau = False
- self.hs.config.mau_stats_only = True
-
# Simply being able to create 2 users indicates that the
# limit was not reached.
token1 = self.create_user("kermit1")
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 270f853d..f5f63d8e 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -15,6 +15,7 @@
# limitations under the License.
from synapse.metrics import REGISTRY, InFlightGauge, generate_latest
+from synapse.util.caches.descriptors import Cache
from tests import unittest
@@ -129,3 +130,36 @@ class BuildInfoTests(unittest.TestCase):
self.assertTrue(b"osversion=" in items[0])
self.assertTrue(b"pythonversion=" in items[0])
self.assertTrue(b"version=" in items[0])
+
+
+class CacheMetricsTests(unittest.HomeserverTestCase):
+ def test_cache_metric(self):
+ """
+ Caches produce metrics reflecting their state when scraped.
+ """
+ CACHE_NAME = "cache_metrics_test_fgjkbdfg"
+ cache = Cache(CACHE_NAME, max_entries=777)
+
+ items = {
+ x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii")
+ for x in filter(
+ lambda x: b"cache_metrics_test_fgjkbdfg" in x,
+ generate_latest(REGISTRY).split(b"\n"),
+ )
+ }
+
+ self.assertEqual(items["synapse_util_caches_cache_size"], "0.0")
+ self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0")
+
+ cache.prefill("1", "hi")
+
+ items = {
+ x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii")
+ for x in filter(
+ lambda x: b"cache_metrics_test_fgjkbdfg" in x,
+ generate_latest(REGISTRY).split(b"\n"),
+ )
+ }
+
+ self.assertEqual(items["synapse_util_caches_cache_size"], "1.0")
+ self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0")
diff --git a/tests/test_server.py b/tests/test_server.py
index 0d57eed2..e9a43b1e 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -27,6 +27,7 @@ from synapse.api.errors import Codes, RedirectException, SynapseError
from synapse.http.server import (
DirectServeResource,
JsonResource,
+ OptionsResource,
wrap_html_request_handler,
)
from synapse.http.site import SynapseSite, logger
@@ -168,6 +169,86 @@ class JsonResourceTests(unittest.TestCase):
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+class OptionsResourceTests(unittest.TestCase):
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+
+ class DummyResource(Resource):
+ isLeaf = True
+
+ def render(self, request):
+ return request.path
+
+ # Set up a resource with some children.
+ self.resource = OptionsResource()
+ self.resource.putChild(b"res", DummyResource())
+
+ def _make_request(self, method, path):
+ """Create a request from the method/path and return a channel with the response."""
+ request, channel = make_request(self.reactor, method, path, shorthand=False)
+ request.prepath = [] # This doesn't get set properly by make_request.
+
+ # Create a site and query for the resource.
+ site = SynapseSite("test", "site_tag", {}, self.resource, "1.0")
+ request.site = site
+ resource = site.getResourceFor(request)
+
+ # Finally, render the resource and return the channel.
+ render(request, resource, self.reactor)
+ return channel
+
+ def test_unknown_options_request(self):
+ """An OPTIONS requests to an unknown URL still returns 200 OK."""
+ channel = self._make_request(b"OPTIONS", b"/foo/")
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.result["body"], b"{}")
+
+ # Ensure the correct CORS headers have been added
+ self.assertTrue(
+ channel.headers.hasHeader(b"Access-Control-Allow-Origin"),
+ "has CORS Origin header",
+ )
+ self.assertTrue(
+ channel.headers.hasHeader(b"Access-Control-Allow-Methods"),
+ "has CORS Methods header",
+ )
+ self.assertTrue(
+ channel.headers.hasHeader(b"Access-Control-Allow-Headers"),
+ "has CORS Headers header",
+ )
+
+ def test_known_options_request(self):
+ """An OPTIONS requests to an known URL still returns 200 OK."""
+ channel = self._make_request(b"OPTIONS", b"/res/")
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.result["body"], b"{}")
+
+ # Ensure the correct CORS headers have been added
+ self.assertTrue(
+ channel.headers.hasHeader(b"Access-Control-Allow-Origin"),
+ "has CORS Origin header",
+ )
+ self.assertTrue(
+ channel.headers.hasHeader(b"Access-Control-Allow-Methods"),
+ "has CORS Methods header",
+ )
+ self.assertTrue(
+ channel.headers.hasHeader(b"Access-Control-Allow-Headers"),
+ "has CORS Headers header",
+ )
+
+ def test_unknown_request(self):
+ """A non-OPTIONS request to an unknown URL should 404."""
+ channel = self._make_request(b"GET", b"/foo/")
+ self.assertEqual(channel.result["code"], b"404")
+
+ def test_known_request(self):
+ """A non-OPTIONS request to an known URL should query the proper resource."""
+ channel = self._make_request(b"GET", b"/res/")
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.result["body"], b"/res/")
+
+
class WrapHtmlRequestHandlerTests(unittest.TestCase):
class TestResource(DirectServeResource):
callback = None
diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py
index 50bc7702..49ffeebd 100644
--- a/tests/util/test_expiring_cache.py
+++ b/tests/util/test_expiring_cache.py
@@ -21,7 +21,7 @@ from tests.utils import MockClock
from .. import unittest
-class ExpiringCacheTestCase(unittest.TestCase):
+class ExpiringCacheTestCase(unittest.HomeserverTestCase):
def test_get_set(self):
clock = MockClock()
cache = ExpiringCache("test", clock, max_len=1)
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index 852ef231..ca3858b1 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -45,6 +45,38 @@ class LinearizerTestCase(unittest.TestCase):
with (yield d2):
pass
+ @defer.inlineCallbacks
+ def test_linearizer_is_queued(self):
+ linearizer = Linearizer()
+
+ key = object()
+
+ d1 = linearizer.queue(key)
+ cm1 = yield d1
+
+ # Since d1 gets called immediately, "is_queued" should return false.
+ self.assertFalse(linearizer.is_queued(key))
+
+ d2 = linearizer.queue(key)
+ self.assertFalse(d2.called)
+
+ # Now d2 is queued up behind successful completion of cm1
+ self.assertTrue(linearizer.is_queued(key))
+
+ with cm1:
+ self.assertFalse(d2.called)
+
+ # cm1 still not done, so d2 still queued.
+ self.assertTrue(linearizer.is_queued(key))
+
+ # And now d2 is called and nothing is in the queue again
+ self.assertFalse(linearizer.is_queued(key))
+
+ with (yield d2):
+ self.assertFalse(linearizer.is_queued(key))
+
+ self.assertFalse(linearizer.is_queued(key))
+
def test_lots_of_queued_things(self):
# we have one slow thing, and lots of fast things queued up behind it.
# it should *not* explode the stack.
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 78694737..0adb2174 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -22,7 +22,7 @@ from synapse.util.caches.treecache import TreeCache
from .. import unittest
-class LruCacheTestCase(unittest.TestCase):
+class LruCacheTestCase(unittest.HomeserverTestCase):
def test_get_set(self):
cache = LruCache(1)
cache["key"] = "value"
@@ -84,7 +84,7 @@ class LruCacheTestCase(unittest.TestCase):
self.assertEquals(len(cache), 0)
-class LruCacheCallbacksTestCase(unittest.TestCase):
+class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
def test_get(self):
m = Mock()
cache = LruCache(1)
@@ -233,7 +233,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
self.assertEquals(m3.call_count, 1)
-class LruCacheSizedTestCase(unittest.TestCase):
+class LruCacheSizedTestCase(unittest.HomeserverTestCase):
def test_evict(self):
cache = LruCache(5, size_callback=len)
cache["key1"] = [0]
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index 68579335..13b753e3 100644
--- a/tests/util/test_stream_change_cache.py
+++ b/tests/util/test_stream_change_cache.py
@@ -1,11 +1,9 @@
-from mock import patch
-
from synapse.util.caches.stream_change_cache import StreamChangeCache
from tests import unittest
-class StreamChangeCacheTests(unittest.TestCase):
+class StreamChangeCacheTests(unittest.HomeserverTestCase):
"""
Tests for StreamChangeCache.
"""
@@ -54,7 +52,6 @@ class StreamChangeCacheTests(unittest.TestCase):
self.assertTrue(cache.has_entity_changed("user@foo.com", 0))
self.assertTrue(cache.has_entity_changed("not@here.website", 0))
- @patch("synapse.util.caches.CACHE_SIZE_FACTOR", 1.0)
def test_entity_has_changed_pops_off_start(self):
"""
StreamChangeCache.entity_has_changed will respect the max size and
diff --git a/tests/utils.py b/tests/utils.py
index f9be62b4..59c020a0 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -167,6 +167,7 @@ def default_config(name, parse=False):
# disable user directory updates, because they get done in the
# background, which upsets the test runner.
"update_user_directory": False,
+ "caches": {"global_factor": 1},
}
if parse:
diff --git a/tox.ini b/tox.ini
index 8aef5202..463a34d1 100644
--- a/tox.ini
+++ b/tox.ini
@@ -180,18 +180,21 @@ commands = mypy \
synapse/api \
synapse/appservice \
synapse/config \
+ synapse/event_auth.py \
synapse/events/spamcheck.py \
- synapse/federation/federation_base.py \
- synapse/federation/federation_client.py \
- synapse/federation/federation_server.py \
- synapse/federation/sender \
- synapse/federation/transport \
+ synapse/federation \
synapse/handlers/auth.py \
synapse/handlers/cas_handler.py \
synapse/handlers/directory.py \
+ synapse/handlers/oidc_handler.py \
synapse/handlers/presence.py \
+ synapse/handlers/room_member.py \
+ synapse/handlers/room_member_worker.py \
+ synapse/handlers/saml_handler.py \
synapse/handlers/sync.py \
synapse/handlers/ui_auth \
+ synapse/http/server.py \
+ synapse/http/site.py \
synapse/logging/ \
synapse/metrics \
synapse/module_api \
@@ -203,9 +206,10 @@ commands = mypy \
synapse/storage/data_stores/main/ui_auth.py \
synapse/storage/database.py \
synapse/storage/engines \
+ synapse/storage/util \
synapse/streams \
synapse/util/caches/stream_change_cache.py \
- tests/replication/tcp/streams \
+ tests/replication \
tests/test_utils \
tests/rest/client/v2_alpha/test_auth.py \
tests/util/test_stream_change_cache.py