-rw-r--r--  CHANGES.md | 135
-rw-r--r--  INSTALL.md | 12
-rw-r--r--  README.rst | 10
-rw-r--r--  UPGRADE.rst | 18
-rwxr-xr-x  contrib/cmdclient/console.py | 21
-rw-r--r--  contrib/cmdclient/http.py | 10
-rw-r--r--  contrib/experiments/test_messaging.py | 55
-rw-r--r--  contrib/grafana/synapse.json | 299
-rw-r--r--  contrib/graph/graph.py | 21
-rw-r--r--  contrib/graph/graph2.py | 11
-rw-r--r--  contrib/graph/graph3.py | 22
-rw-r--r--  contrib/jitsimeetbridge/jitsimeetbridge.py | 10
-rwxr-xr-x  contrib/scripts/kick_users.py | 6
-rw-r--r--  debian/changelog | 12
-rw-r--r--  docker/Dockerfile | 57
-rw-r--r--  docker/README.md | 15
-rwxr-xr-x  docker/start.py | 12
-rw-r--r--  docs/ACME.md | 5
-rw-r--r--  docs/admin_api/purge_room.md | 2
-rw-r--r--  docs/admin_api/rooms.md | 126
-rw-r--r--  docs/admin_api/shutdown_room.md | 2
-rw-r--r--  docs/admin_api/user_admin_api.rst | 6
-rw-r--r--  docs/jwt.md | 19
-rw-r--r--  docs/password_auth_providers.md | 187
-rw-r--r--  docs/reverse_proxy.md | 16
-rw-r--r--  docs/sample_config.yaml | 219
-rw-r--r--  docs/synctl_workers.md | 32
-rw-r--r--  docs/workers.md | 459
-rwxr-xr-x  scripts-dev/build_debian_packages | 1
-rwxr-xr-x  scripts-dev/lint.sh | 2
-rwxr-xr-x  scripts/synapse_port_db | 12
-rw-r--r--  stubs/txredisapi.pyi | 1
-rw-r--r--  synapse/__init__.py | 2
-rw-r--r--  synapse/api/auth.py | 12
-rw-r--r--  synapse/api/errors.py | 130
-rw-r--r--  synapse/app/generic_worker.py | 114
-rw-r--r--  synapse/app/homeserver.py | 25
-rw-r--r--  synapse/appservice/api.py | 20
-rw-r--r--  synapse/config/_base.py | 38
-rw-r--r--  synapse/config/_base.pyi | 5
-rw-r--r--  synapse/config/database.py | 2
-rw-r--r--  synapse/config/emailconfig.py | 118
-rw-r--r--  synapse/config/federation.py | 88
-rw-r--r--  synapse/config/homeserver.py | 5
-rw-r--r--  synapse/config/jwt_config.py | 28
-rw-r--r--  synapse/config/logger.py | 2
-rw-r--r--  synapse/config/push.py | 5
-rw-r--r--  synapse/config/redis.py | 23
-rw-r--r--  synapse/config/room.py | 7
-rw-r--r--  synapse/config/server.py | 74
-rw-r--r--  synapse/config/workers.py | 68
-rw-r--r--  synapse/event_auth.py | 10
-rw-r--r--  synapse/events/builder.py | 4
-rw-r--r--  synapse/events/utils.py | 6
-rw-r--r--  synapse/federation/federation_client.py | 40
-rw-r--r--  synapse/federation/federation_server.py | 142
-rw-r--r--  synapse/federation/send_queue.py | 14
-rw-r--r--  synapse/federation/sender/__init__.py | 56
-rw-r--r--  synapse/federation/sender/per_destination_queue.py | 22
-rw-r--r--  synapse/federation/sender/transaction_manager.py | 2
-rw-r--r--  synapse/federation/transport/server.py | 16
-rw-r--r--  synapse/handlers/_base.py | 17
-rw-r--r--  synapse/handlers/admin.py | 2
-rw-r--r--  synapse/handlers/auth.py | 7
-rw-r--r--  synapse/handlers/cas_handler.py | 2
-rw-r--r--  synapse/handlers/deactivate_account.py | 65
-rw-r--r--  synapse/handlers/device.py | 249
-rw-r--r--  synapse/handlers/e2e_keys.py | 167
-rw-r--r--  synapse/handlers/e2e_room_keys.py | 75
-rw-r--r--  synapse/handlers/federation.py | 115
-rw-r--r--  synapse/handlers/identity.py | 271
-rw-r--r--  synapse/handlers/message.py | 303
-rw-r--r--  synapse/handlers/presence.py | 47
-rw-r--r--  synapse/handlers/profile.py | 63
-rw-r--r--  synapse/handlers/receipts.py | 16
-rw-r--r--  synapse/handlers/register.py | 22
-rw-r--r--  synapse/handlers/room.py | 221
-rw-r--r--  synapse/handlers/room_list.py | 62
-rw-r--r--  synapse/handlers/search.py | 7
-rw-r--r--  synapse/handlers/sync.py | 21
-rw-r--r--  synapse/handlers/typing.py | 243
-rw-r--r--  synapse/handlers/ui_auth/checkers.py | 38
-rw-r--r--  synapse/http/client.py | 34
-rw-r--r--  synapse/http/federation/matrix_federation_agent.py | 16
-rw-r--r--  synapse/http/federation/srv_resolver.py | 10
-rw-r--r--  synapse/http/server.py | 99
-rw-r--r--  synapse/http/servlet.py | 10
-rw-r--r--  synapse/http/site.py | 4
-rw-r--r--  synapse/logging/context.py | 29
-rw-r--r--  synapse/logging/opentracing.py | 52
-rw-r--r--  synapse/logging/scopecontextmanager.py | 2
-rw-r--r--  synapse/logging/utils.py | 126
-rw-r--r--  synapse/push/bulk_push_rule_evaluator.py | 4
-rw-r--r--  synapse/push/mailer.py | 61
-rw-r--r--  synapse/push/pusherpool.py | 78
-rw-r--r--  synapse/replication/http/__init__.py | 2
-rw-r--r--  synapse/replication/slave/storage/deviceinbox.py | 2
-rw-r--r--  synapse/replication/tcp/client.py | 8
-rw-r--r--  synapse/replication/tcp/commands.py | 10
-rw-r--r--  synapse/replication/tcp/handler.py | 427
-rw-r--r--  synapse/replication/tcp/protocol.py | 58
-rw-r--r--  synapse/replication/tcp/redis.py | 59
-rw-r--r--  synapse/replication/tcp/streams/_base.py | 7
-rw-r--r--  synapse/replication/tcp/streams/events.py | 2
-rw-r--r--  synapse/res/templates/mail-Element.css | 7
-rw-r--r--  synapse/res/templates/notice_expiry.html | 2
-rw-r--r--  synapse/res/templates/notif_mail.html | 2
-rw-r--r--  synapse/rest/admin/__init__.py | 4
-rw-r--r--  synapse/rest/admin/rooms.py | 182
-rw-r--r--  synapse/rest/admin/users.py | 10
-rw-r--r--  synapse/rest/client/v1/login.py | 27
-rw-r--r--  synapse/rest/client/v1/room.py | 22
-rw-r--r--  synapse/rest/client/v2_alpha/_base.py | 58
-rw-r--r--  synapse/rest/client/v2_alpha/sync.py | 9
-rw-r--r--  synapse/rest/key/v2/remote_key_resource.py | 4
-rw-r--r--  synapse/rest/media/v1/_base.py | 15
-rw-r--r--  synapse/rest/media/v1/media_storage.py | 60
-rw-r--r--  synapse/rest/media/v1/preview_url_resource.py | 265
-rw-r--r--  synapse/server.py | 27
-rw-r--r--  synapse/server.pyi | 12
-rw-r--r--  synapse/state/__init__.py | 95
-rw-r--r--  synapse/state/v1.py | 15
-rw-r--r--  synapse/state/v2.py | 107
-rw-r--r--  synapse/storage/_base.py | 4
-rw-r--r--  synapse/storage/background_updates.py | 5
-rw-r--r--  synapse/storage/data_stores/main/__init__.py | 2
-rw-r--r--  synapse/storage/data_stores/main/account_data.py | 16
-rw-r--r--  synapse/storage/data_stores/main/appservice.py | 4
-rw-r--r--  synapse/storage/data_stores/main/deviceinbox.py | 9
-rw-r--r--  synapse/storage/data_stores/main/devices.py | 2
-rw-r--r--  synapse/storage/data_stores/main/e2e_room_keys.py | 10
-rw-r--r--  synapse/storage/data_stores/main/end_to_end_keys.py | 2
-rw-r--r--  synapse/storage/data_stores/main/event_push_actions.py | 4
-rw-r--r--  synapse/storage/data_stores/main/events.py | 76
-rw-r--r--  synapse/storage/data_stores/main/events_bg_updates.py | 14
-rw-r--r--  synapse/storage/data_stores/main/events_worker.py | 9
-rw-r--r--  synapse/storage/data_stores/main/group_server.py | 22
-rw-r--r--  synapse/storage/data_stores/main/push_rule.py | 8
-rw-r--r--  synapse/storage/data_stores/main/pusher.py | 6
-rw-r--r--  synapse/storage/data_stores/main/receipts.py | 8
-rw-r--r--  synapse/storage/data_stores/main/registration.py | 65
-rw-r--r--  synapse/storage/data_stores/main/room.py | 17
-rw-r--r--  synapse/storage/data_stores/main/roommember.py | 7
-rw-r--r--  synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql | 22
-rw-r--r--  synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py | 34
-rw-r--r--  synapse/storage/data_stores/main/search.py | 6
-rw-r--r--  synapse/storage/data_stores/main/state.py | 65
-rw-r--r--  synapse/storage/data_stores/main/stream.py | 97
-rw-r--r--  synapse/storage/data_stores/main/tags.py | 5
-rw-r--r--  synapse/storage/data_stores/main/ui_auth.py | 12
-rw-r--r--  synapse/storage/data_stores/main/user_directory.py | 4
-rw-r--r--  synapse/storage/data_stores/main/user_erasure_store.py | 26
-rw-r--r--  synapse/storage/data_stores/state/store.py | 12
-rw-r--r--  synapse/storage/engines/_base.py | 6
-rw-r--r--  synapse/storage/engines/postgres.py | 6
-rw-r--r--  synapse/storage/engines/sqlite.py | 13
-rw-r--r--  synapse/storage/persist_events.py | 5
-rw-r--r--  synapse/storage/util/id_generators.py | 8
-rw-r--r--  synapse/storage/util/sequence.py | 98
-rw-r--r--  synapse/util/distributor.py | 28
-rw-r--r--  synapse/util/stringutils.py | 2
-rw-r--r--  tests/events/test_snapshot.py | 36
-rw-r--r--  tests/federation/test_federation_sender.py | 19
-rw-r--r--  tests/handlers/test_device.py | 13
-rw-r--r--  tests/handlers/test_e2e_keys.py | 296
-rw-r--r--  tests/handlers/test_e2e_room_keys.py | 373
-rw-r--r--  tests/handlers/test_profile.py | 17
-rw-r--r--  tests/handlers/test_typing.py | 4
-rw-r--r--  tests/http/federation/test_matrix_federation_agent.py | 51
-rw-r--r--  tests/http/federation/test_srv_resolver.py | 26
-rw-r--r--  tests/replication/_base.py | 168
-rw-r--r--  tests/replication/tcp/streams/test_events.py | 76
-rw-r--r--  tests/replication/test_client_reader_shard.py | 96
-rw-r--r--  tests/replication/test_federation_ack.py | 1
-rw-r--r--  tests/replication/test_federation_sender_shard.py | 235
-rw-r--r--  tests/replication/test_pusher_shard.py | 193
-rw-r--r--  tests/rest/admin/test_room.py | 2453
-rw-r--r--  tests/rest/admin/test_user.py | 47
-rw-r--r--  tests/rest/client/v1/test_login.py | 133
-rw-r--r--  tests/rest/media/v1/test_media_storage.py | 5
-rw-r--r--  tests/rest/media/v1/test_url_preview.py | 142
-rw-r--r--  tests/server.py | 26
-rw-r--r--  tests/state/test_v2.py | 17
-rw-r--r--  tests/storage/test_room.py | 16
-rw-r--r--  tests/storage/test_roommember.py | 56
-rw-r--r--  tests/storage/test_state.py | 4
-rw-r--r--  tests/test_federation.py | 35
-rw-r--r--  tests/test_server.py | 71
-rw-r--r--  tests/test_state.py | 72
-rw-r--r--  tests/test_utils/__init__.py | 7
-rw-r--r--  tests/test_utils/event_injection.py | 28
-rw-r--r--  tests/test_visibility.py | 14
-rw-r--r--  tests/unittest.py | 4
-rw-r--r--  tests/utils.py | 4
-rw-r--r--  tox.ini | 3
195 files changed, 7927 insertions, 4482 deletions
diff --git a/CHANGES.md b/CHANGES.md
index 6d4bd23e..6c986808 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,138 @@
+Synapse 1.18.0 (2020-07-30)
+===========================
+
+Deprecation Warnings
+--------------------
+
+### Docker Tags with `-py3` Suffix
+
+From 10th August 2020, we will no longer publish Docker images with the `-py3` tag suffix. The images tagged with the `-py3` suffix have been identical to the non-suffixed tags since release 0.99.0, and the suffix is obsolete.
+
+On 10th August, we will remove the `latest-py3` tag. Existing per-release tags (such as `v1.18.0-py3`) will not be removed, but no new `-py3` tags will be added.
+
+Scripts relying on the `-py3` suffix will need to be updated.
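For deployments that template their image references, a one-off rewrite is usually all that is needed. A minimal sketch in Python; the compose file path and image name are assumptions, not part of this release:

```python
import re
from pathlib import Path

# Hypothetical helper: strip the obsolete "-py3" suffix from
# matrixdotorg/synapse image tags in a docker-compose file.
def strip_py3_suffix(compose_path: str) -> None:
    path = Path(compose_path)
    text = path.read_text()
    # "latest-py3" and "v1.18.0-py3" both become their unsuffixed twins.
    updated = re.sub(r"(matrixdotorg/synapse:[\w.]+)-py3", r"\1", text)
    path.write_text(updated)

strip_py3_suffix("docker-compose.yml")  # path is an assumption
```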
+
+
+### TCP-based Replication
+
+When setting up worker processes, we now recommend the use of a Redis server for replication. The old direct TCP connection method is deprecated and will be removed in a future release. See [docs/workers.md](https://github.com/matrix-org/synapse/blob/release-v1.18.0/docs/workers.md) for more details.
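Before cutting a worker deployment over, it can be worth verifying that the Redis instance is reachable. A minimal sketch using the redis-py client; the host and port here are Redis defaults, assumed rather than mandated by this release:

```python
import redis  # redis-py; a hypothetical pre-flight check, not something Synapse runs

# Assumes Redis on its default host/port; adjust to match the `redis`
# section of your homeserver.yaml.
client = redis.Redis(host="localhost", port=6379)
client.ping()  # raises redis.exceptions.ConnectionError if unreachable
```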
+
+
+Improved Documentation
+----------------------
+
+- Update worker docs with latest enhancements. ([\#7969](https://github.com/matrix-org/synapse/issues/7969))
+
+
+Synapse 1.18.0rc2 (2020-07-28)
+==============================
+
+Bugfixes
+--------
+
+- Fix an `AssertionError` exception introduced in v1.18.0rc1. ([\#7876](https://github.com/matrix-org/synapse/issues/7876))
+- Fix experimental support for moving typing off master when a worker is restarted, which was broken in v1.18.0rc1. ([\#7967](https://github.com/matrix-org/synapse/issues/7967))
+
+
+Internal Changes
+----------------
+
+- Further optimise queueing of inbound replication commands. ([\#7876](https://github.com/matrix-org/synapse/issues/7876))
+
+
+Synapse 1.18.0rc1 (2020-07-27)
+==============================
+
+Features
+--------
+
+- Include room states on invite events that are sent to application services. Contributed by @Sorunome. ([\#6455](https://github.com/matrix-org/synapse/issues/6455))
+- Add delete room admin endpoint (`POST /_synapse/admin/v1/rooms/<room_id>/delete`). Contributed by @dklimpel. ([\#7613](https://github.com/matrix-org/synapse/issues/7613), [\#7953](https://github.com/matrix-org/synapse/issues/7953))
+- Add experimental support for running multiple federation sender processes. ([\#7798](https://github.com/matrix-org/synapse/issues/7798))
+- Add the option to validate the `iss` and `aud` claims for JWT logins; a sketch of the claim check follows this list. ([\#7827](https://github.com/matrix-org/synapse/issues/7827))
+- Add support for handling registration requests across multiple client reader workers. ([\#7830](https://github.com/matrix-org/synapse/issues/7830))
+- Add an admin API to list the users in a room. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#7842](https://github.com/matrix-org/synapse/issues/7842))
+- Allow email subjects to be customised through Synapse's configuration. ([\#7846](https://github.com/matrix-org/synapse/issues/7846))
+- Add the ability to re-activate an account from the admin API. ([\#7847](https://github.com/matrix-org/synapse/issues/7847), [\#7908](https://github.com/matrix-org/synapse/issues/7908))
+- Add experimental support for running multiple pusher workers. ([\#7855](https://github.com/matrix-org/synapse/issues/7855))
+- Add experimental support for moving typing off master. ([\#7869](https://github.com/matrix-org/synapse/issues/7869), [\#7959](https://github.com/matrix-org/synapse/issues/7959))
+- Report CPU metrics to prometheus for time spent processing replication commands. ([\#7879](https://github.com/matrix-org/synapse/issues/7879))
+- Support oEmbed for media previews. ([\#7920](https://github.com/matrix-org/synapse/issues/7920))
+- Abort federation requests where the client disconnects before the ratelimiter expires. ([\#7930](https://github.com/matrix-org/synapse/issues/7930))
+- Cache responses to `/_matrix/federation/v1/state_ids` to reduce duplicated work. ([\#7931](https://github.com/matrix-org/synapse/issues/7931))
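As promised above, a sketch of what `iss`/`aud` claim validation means for a JWT login token. This uses PyJWT purely for illustration; the secret, issuer and audience values are made up, and Synapse performs the equivalent check internally:

```python
import jwt  # PyJWT, for illustration only

secret = "my-jwt-secret"  # assumed shared secret
token = jwt.encode(
    {"sub": "alice", "iss": "https://issuer.example.com", "aud": "synapse"},
    secret,
    algorithm="HS256",
)

# decode() rejects tokens whose `iss`/`aud` claims do not match the expected
# values, raising InvalidIssuerError / InvalidAudienceError respectively.
claims = jwt.decode(
    token,
    secret,
    algorithms=["HS256"],
    issuer="https://issuer.example.com",
    audience="synapse",
)
```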
+
+
+Bugfixes
+--------
+
+- Fix detection of out of sync remote device lists when receiving events from remote users. ([\#7815](https://github.com/matrix-org/synapse/issues/7815))
+- Fix bug where Synapse fails to process an incoming event over federation if the server is missing too much of the event's auth chain. ([\#7817](https://github.com/matrix-org/synapse/issues/7817))
+- Fix a bug causing Synapse to misinterpret the value `off` for `encryption_enabled_by_default_for_room_type` in its configuration file(s) if that value isn't surrounded by quotes (a sketch of the YAML pitfall follows this list). This bug was introduced in v1.16.0. ([\#7822](https://github.com/matrix-org/synapse/issues/7822))
+- Fix bug where we did not always pass in `app_name` or `server_name` to email templates, including e.g. for registration emails. ([\#7829](https://github.com/matrix-org/synapse/issues/7829))
+- Errors which occur while using the non-standard JWT login now return the proper error: `403 Forbidden` with an error code of `M_FORBIDDEN`. ([\#7844](https://github.com/matrix-org/synapse/issues/7844))
+- Fix "AttributeError: 'str' object has no attribute 'get'" error message when applying per-room message retention policies. The bug was introduced in Synapse 1.7.0. ([\#7850](https://github.com/matrix-org/synapse/issues/7850))
+- Fix a bug introduced in Synapse 1.10.0 which could cause a "no create event in auth events" error during room creation. ([\#7854](https://github.com/matrix-org/synapse/issues/7854))
+- Fix a bug which allowed empty rooms to be rejoined over federation. ([\#7859](https://github.com/matrix-org/synapse/issues/7859))
+- Fix 'Unable to find a suitable guest user ID' error when using multiple client_reader workers. ([\#7866](https://github.com/matrix-org/synapse/issues/7866))
+- Fix a long standing bug where the tracing of async functions with opentracing was broken. ([\#7872](https://github.com/matrix-org/synapse/issues/7872), [\#7961](https://github.com/matrix-org/synapse/issues/7961))
+- Fix "TypeError in `synapse.notifier`" exceptions. ([\#7880](https://github.com/matrix-org/synapse/issues/7880))
+- Fix deprecation warning due to invalid escape sequences. ([\#7895](https://github.com/matrix-org/synapse/issues/7895))
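The `off` misinterpretation above is a classic YAML 1.1 pitfall, easy to reproduce with PyYAML, which Synapse uses to load its config:

```python
import yaml  # PyYAML

# YAML 1.1 resolves the bare scalars off/on/no/yes as booleans, so an
# unquoted `off` reaches the application as False rather than the string "off".
print(yaml.safe_load("encryption_enabled_by_default_for_room_type: off"))
# {'encryption_enabled_by_default_for_room_type': False}

print(yaml.safe_load('encryption_enabled_by_default_for_room_type: "off"'))
# {'encryption_enabled_by_default_for_room_type': 'off'}
```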
+
+
+Updates to the Docker image
+---------------------------
+
+- Base docker image on Debian Buster rather than Alpine Linux. Contributed by @maquis196. ([\#7839](https://github.com/matrix-org/synapse/issues/7839))
+
+
+Improved Documentation
+----------------------
+
+- Provide instructions on using `register_new_matrix_user` via docker. ([\#7885](https://github.com/matrix-org/synapse/issues/7885))
+- Change the sample config postgres user section to use `synapse_user` instead of `synapse` to align with the documentation. ([\#7889](https://github.com/matrix-org/synapse/issues/7889))
+- Reorder database paragraphs to promote postgres over sqlite. ([\#7933](https://github.com/matrix-org/synapse/issues/7933))
+- Update the dates of ACME v1's end of life in [`ACME.md`](https://github.com/matrix-org/synapse/blob/master/docs/ACME.md). ([\#7934](https://github.com/matrix-org/synapse/issues/7934))
+
+
+Deprecations and Removals
+-------------------------
+
+- Remove unused `synapse_replication_tcp_resource_invalidate_cache` prometheus metric. ([\#7878](https://github.com/matrix-org/synapse/issues/7878))
+- Remove Ubuntu Eoan from the list of `.deb` packages that we build as it is now end-of-life. Contributed by @gary-kim. ([\#7888](https://github.com/matrix-org/synapse/issues/7888))
+
+
+Internal Changes
+----------------
+
+- Switch parts of the codebase from `simplejson` to the standard library `json`. ([\#7802](https://github.com/matrix-org/synapse/issues/7802))
+- Add type hints to the http server code and remove an unused parameter. ([\#7813](https://github.com/matrix-org/synapse/issues/7813))
+- Add type hints to synapse.api.errors module. ([\#7820](https://github.com/matrix-org/synapse/issues/7820))
+- Ensure that calls to `json.dumps` are compatible with the standard library json. ([\#7836](https://github.com/matrix-org/synapse/issues/7836))
+- Remove redundant `retry_on_integrity_error` wrapper for event persistence code. ([\#7848](https://github.com/matrix-org/synapse/issues/7848))
+- Consistently use `db_to_json` to convert from database values to JSON objects. ([\#7849](https://github.com/matrix-org/synapse/issues/7849))
+- Convert various parts of the codebase to async/await. ([\#7851](https://github.com/matrix-org/synapse/issues/7851), [\#7860](https://github.com/matrix-org/synapse/issues/7860), [\#7868](https://github.com/matrix-org/synapse/issues/7868), [\#7871](https://github.com/matrix-org/synapse/issues/7871), [\#7873](https://github.com/matrix-org/synapse/issues/7873), [\#7874](https://github.com/matrix-org/synapse/issues/7874), [\#7884](https://github.com/matrix-org/synapse/issues/7884), [\#7912](https://github.com/matrix-org/synapse/issues/7912), [\#7935](https://github.com/matrix-org/synapse/issues/7935), [\#7939](https://github.com/matrix-org/synapse/issues/7939), [\#7942](https://github.com/matrix-org/synapse/issues/7942), [\#7944](https://github.com/matrix-org/synapse/issues/7944))
+- Add support for handling registration requests across multiple client reader workers. ([\#7853](https://github.com/matrix-org/synapse/issues/7853))
+- Small performance improvement in typing processing. ([\#7856](https://github.com/matrix-org/synapse/issues/7856))
+- The default value of `filter_timeline_limit` was changed from -1 (no limit) to 100; a sketch of the client-visible effect follows this list. ([\#7858](https://github.com/matrix-org/synapse/issues/7858))
+- Optimise queueing of inbound replication commands. ([\#7861](https://github.com/matrix-org/synapse/issues/7861))
+- Add some type annotations to `HomeServer` and `BaseHandler`. ([\#7870](https://github.com/matrix-org/synapse/issues/7870))
+- Clean up `PreserveLoggingContext`. ([\#7877](https://github.com/matrix-org/synapse/issues/7877))
+- Change "unknown room version" logging from 'error' to 'warning'. ([\#7881](https://github.com/matrix-org/synapse/issues/7881))
+- Stop using `device_max_stream_id` table and just use `device_inbox.stream_id`. ([\#7882](https://github.com/matrix-org/synapse/issues/7882))
+- Return an empty body for OPTIONS requests. ([\#7886](https://github.com/matrix-org/synapse/issues/7886))
+- Fix typo in generated config file. Contributed by @ThiefMaster. ([\#7890](https://github.com/matrix-org/synapse/issues/7890))
+- Import ABC from `collections.abc` for Python 3.10 compatibility. ([\#7892](https://github.com/matrix-org/synapse/issues/7892))
+- Remove unused functions `time_function`, `trace_function`, `get_previous_frames`
+ and `get_previous_frame` from `synapse.logging.utils` module. ([\#7897](https://github.com/matrix-org/synapse/issues/7897))
+- Lint the `contrib/` directory in CI and linting scripts, add `synctl` to the linting script for consistency with CI. ([\#7914](https://github.com/matrix-org/synapse/issues/7914))
+- Use Element CSS and logo in notification emails when app name is Element. ([\#7919](https://github.com/matrix-org/synapse/issues/7919))
+- Optimisation to /sync handling: skip serializing the response if the client has already disconnected. ([\#7927](https://github.com/matrix-org/synapse/issues/7927))
+- When a client disconnects, don't log it as 'Error processing request'. ([\#7928](https://github.com/matrix-org/synapse/issues/7928))
+- Add debugging to `/sync` response generation (disabled by default). ([\#7929](https://github.com/matrix-org/synapse/issues/7929))
+- Update comments that refer to Deferreds for async functions. ([\#7945](https://github.com/matrix-org/synapse/issues/7945))
+- Simplify error handling in federation handler. ([\#7950](https://github.com/matrix-org/synapse/issues/7950))
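For the `filter_timeline_limit` change flagged above, a client-side illustration. The URL and token are placeholders, and the cap is applied server-side, so this is only a sketch of the behaviour as we understand it:

```python
import json

import requests

# A sync filter asking for up to 500 timeline events per room; with the new
# default, the server caps the effective limit at `filter_timeline_limit` (100).
sync_filter = {"room": {"timeline": {"limit": 500}}}
resp = requests.get(
    "https://matrix.example.com/_matrix/client/r0/sync",
    params={"filter": json.dumps(sync_filter)},
    headers={"Authorization": "Bearer <access_token>"},
)
```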
+
+
Synapse 1.17.0 (2020-07-13)
===========================
diff --git a/INSTALL.md b/INSTALL.md
index ef80a26c..b507de74 100644
--- a/INSTALL.md
+++ b/INSTALL.md
@@ -405,13 +405,11 @@ so, you will need to edit `homeserver.yaml`, as follows:
```
* You will also need to uncomment the `tls_certificate_path` and
- `tls_private_key_path` lines under the `TLS` section. You can either
- point these settings at an existing certificate and key, or you can
- enable Synapse's built-in ACME (Let's Encrypt) support. Instructions
- for having Synapse automatically provision and renew federation
- certificates through ACME can be found at [ACME.md](docs/ACME.md).
- Note that, as pointed out in that document, this feature will not
- work with installs set up after November 2019.
+ `tls_private_key_path` lines under the `TLS` section. You will need to manage
+ provisioning of these certificates yourself — Synapse had built-in ACME
+ support, but the ACMEv1 protocol Synapse implements is deprecated, not
+ allowed by LetsEncrypt for new sites, and will break for existing sites in
+ late 2020. See [ACME.md](docs/ACME.md).
If you are using your own certificate, be sure to use a `.pem` file that
includes the full certificate chain including any intermediate certificates
diff --git a/README.rst b/README.rst
index 38376e23..f7116b34 100644
--- a/README.rst
+++ b/README.rst
@@ -188,12 +188,8 @@ Using PostgreSQL
================
Synapse offers two database engines:
- * `SQLite <https://sqlite.org/>`_
* `PostgreSQL <https://www.postgresql.org>`_
-
-By default Synapse uses SQLite in and doing so trades performance for convenience.
-SQLite is only recommended in Synapse for testing purposes or for servers with
-light workloads.
+ * `SQLite <https://sqlite.org/>`_
Almost all installations should opt to use PostgreSQL. Advantages include:
@@ -207,6 +203,10 @@ Almost all installations should opt to use PostgreSQL. Advantages include:
For information on how to install and use PostgreSQL, please see
`docs/postgres.md <docs/postgres.md>`_.
+By default Synapse uses SQLite and in doing so trades performance for convenience.
+SQLite is only recommended in Synapse for testing purposes or for servers with
+light workloads.
+
.. _reverse-proxy:
Using a reverse proxy with Synapse
diff --git a/UPGRADE.rst b/UPGRADE.rst
index 3b5627e8..6492fa01 100644
--- a/UPGRADE.rst
+++ b/UPGRADE.rst
@@ -75,6 +75,24 @@ for example:
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
+Upgrading to v1.18.0
+====================
+
+Docker `-py3` suffix will be removed in future versions
+-------------------------------------------------------
+
+From 10th August 2020, we will no longer publish Docker images with the `-py3` tag suffix. The images tagged with the `-py3` suffix have been identical to the non-suffixed tags since release 0.99.0, and the suffix is obsolete.
+
+On 10th August, we will remove the `latest-py3` tag. Existing per-release tags (such as `v1.18.0-py3`) will not be removed, but no new `-py3` tags will be added.
+
+Scripts relying on the `-py3` suffix will need to be updated.
+
+Redis replication is now recommended in lieu of TCP replication
+---------------------------------------------------------------
+
+When setting up worker processes, we now recommend the use of a Redis server for replication. **The old direct TCP connection method is deprecated and will be removed in a future release.**
+See `docs/workers.md <docs/workers.md>`_ for more details.
+
Upgrading to v1.14.0
====================
diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py
index 48da410d..77422f5e 100755
--- a/contrib/cmdclient/console.py
+++ b/contrib/cmdclient/console.py
@@ -17,9 +17,6 @@
""" Starts a synapse client console. """
from __future__ import print_function
-from twisted.internet import reactor, defer, threads
-from http import TwistedHttpClient
-
import argparse
import cmd
import getpass
@@ -28,12 +25,14 @@ import shlex
import sys
import time
import urllib
-import urlparse
+from http import TwistedHttpClient
-import nacl.signing
import nacl.encoding
+import nacl.signing
+import urlparse
+from signedjson.sign import SignatureVerifyException, verify_signed_json
-from signedjson.sign import verify_signed_json, SignatureVerifyException
+from twisted.internet import defer, reactor, threads
CONFIG_JSON = "cmdclient_config.json"
@@ -493,7 +492,7 @@ class SynapseCmd(cmd.Cmd):
"list messages <roomid> from=END&to=START&limit=3"
"""
args = self._parse(line, ["type", "roomid", "qp"])
- if not "type" in args or not "roomid" in args:
+ if "type" not in args or "roomid" not in args:
print("Must specify type and room ID.")
return
if args["type"] not in ["members", "messages"]:
@@ -508,7 +507,7 @@ class SynapseCmd(cmd.Cmd):
try:
key_value = key_value_str.split("=")
qp[key_value[0]] = key_value[1]
- except:
+ except Exception:
print("Bad query param: %s" % key_value)
return
@@ -585,7 +584,7 @@ class SynapseCmd(cmd.Cmd):
parsed_url = urlparse.urlparse(args["path"])
qp.update(urlparse.parse_qs(parsed_url.query))
args["path"] = parsed_url.path
- except:
+ except Exception:
pass
reactor.callFromThread(
@@ -772,10 +771,10 @@ def main(server_url, identity_server_url, username, token, config_path):
syn_cmd.config = json.load(config)
try:
http_client.verbose = "on" == syn_cmd.config["verbose"]
- except:
+ except Exception:
pass
print("Loaded config from %s" % config_path)
- except:
+ except Exception:
pass
# Twisted-specific: Runs the command processor in Twisted's event loop
diff --git a/contrib/cmdclient/http.py b/contrib/cmdclient/http.py
index 0e101d2b..e2534ee5 100644
--- a/contrib/cmdclient/http.py
+++ b/contrib/cmdclient/http.py
@@ -14,14 +14,14 @@
# limitations under the License.
from __future__ import print_function
-from twisted.web.client import Agent, readBody
-from twisted.web.http_headers import Headers
-from twisted.internet import defer, reactor
-
-from pprint import pformat
import json
import urllib
+from pprint import pformat
+
+from twisted.internet import defer, reactor
+from twisted.web.client import Agent, readBody
+from twisted.web.http_headers import Headers
class HttpClient(object):
diff --git a/contrib/experiments/test_messaging.py b/contrib/experiments/test_messaging.py
index 3bbbcfa1..a84ec4ec 100644
--- a/contrib/experiments/test_messaging.py
+++ b/contrib/experiments/test_messaging.py
@@ -28,27 +28,24 @@ Currently assumes the local address is localhost:<port>
"""
-from synapse.federation import ReplicationHandler
-
-from synapse.federation.units import Pdu
-
-from synapse.util import origin_from_ucid
-
-from synapse.app.homeserver import SynapseHomeServer
-
-# from synapse.logging.utils import log_function
-
-from twisted.internet import reactor, defer
-from twisted.python import log
-
import argparse
+import curses.wrapper
import json
import logging
import os
import re
import cursesio
-import curses.wrapper
+
+from twisted.internet import defer, reactor
+from twisted.python import log
+
+from synapse.app.homeserver import SynapseHomeServer
+from synapse.federation import ReplicationHandler
+from synapse.federation.units import Pdu
+from synapse.util import origin_from_ucid
+
+# from synapse.logging.utils import log_function
logger = logging.getLogger("example")
@@ -75,7 +72,7 @@ class InputOutput(object):
"""
try:
- m = re.match("^join (\S+)$", line)
+ m = re.match(r"^join (\S+)$", line)
if m:
# The `sender` wants to join a room.
(room_name,) = m.groups()
@@ -84,7 +81,7 @@ class InputOutput(object):
# self.print_line("OK.")
return
- m = re.match("^invite (\S+) (\S+)$", line)
+ m = re.match(r"^invite (\S+) (\S+)$", line)
if m:
# `sender` wants to invite someone to a room
room_name, invitee = m.groups()
@@ -93,7 +90,7 @@ class InputOutput(object):
# self.print_line("OK.")
return
- m = re.match("^send (\S+) (.*)$", line)
+ m = re.match(r"^send (\S+) (.*)$", line)
if m:
# `sender` wants to message a room
room_name, body = m.groups()
@@ -102,7 +99,7 @@ class InputOutput(object):
# self.print_line("OK.")
return
- m = re.match("^backfill (\S+)$", line)
+ m = re.match(r"^backfill (\S+)$", line)
if m:
# we want to backfill a room
(room_name,) = m.groups()
@@ -201,16 +198,6 @@ class HomeServer(ReplicationHandler):
% (pdu.context, pdu.pdu_type, json.dumps(pdu.content))
)
- # def on_state_change(self, pdu):
- ##self.output.print_line("#%s (state) %s *** %s" %
- ##(pdu.context, pdu.state_key, pdu.pdu_type)
- ##)
-
- # if "joinee" in pdu.content:
- # self._on_join(pdu.context, pdu.content["joinee"])
- # elif "invitee" in pdu.content:
- # self._on_invite(pdu.origin, pdu.context, pdu.content["invitee"])
-
def _on_message(self, pdu):
""" We received a message
"""
@@ -314,7 +301,7 @@ class HomeServer(ReplicationHandler):
return self.replication_layer.backfill(dest, room_name, limit)
def _get_room_remote_servers(self, room_name):
- return [i for i in self.joined_rooms.setdefault(room_name).servers]
+ return list(self.joined_rooms.setdefault(room_name).servers)
def _get_or_create_room(self, room_name):
return self.joined_rooms.setdefault(room_name, Room(room_name))
@@ -334,7 +321,7 @@ def main(stdscr):
user = args.user
server_name = origin_from_ucid(user)
- ## Set up logging ##
+ # Set up logging
root_logger = logging.getLogger()
@@ -354,7 +341,7 @@ def main(stdscr):
observer = log.PythonLoggingObserver()
observer.start()
- ## Set up synapse server
+ # Set up synapse server
curses_stdio = cursesio.CursesStdIO(stdscr)
input_output = InputOutput(curses_stdio, user)
@@ -368,16 +355,16 @@ def main(stdscr):
input_output.set_home_server(hs)
- ## Add input_output logger
+ # Add input_output logger
io_logger = IOLoggerHandler(input_output)
io_logger.setFormatter(formatter)
root_logger.addHandler(io_logger)
- ## Start! ##
+ # Start!
try:
port = int(server_name.split(":")[1])
- except:
+ except Exception:
port = 12345
app_hs.get_http_server().start_listening(port)
diff --git a/contrib/grafana/synapse.json b/contrib/grafana/synapse.json
index 30a8681f..539569b5 100644
--- a/contrib/grafana/synapse.json
+++ b/contrib/grafana/synapse.json
@@ -1,7 +1,44 @@
{
+ "__inputs": [
+ {
+ "name": "DS_PROMETHEUS",
+ "label": "Prometheus",
+ "description": "",
+ "type": "datasource",
+ "pluginId": "prometheus",
+ "pluginName": "Prometheus"
+ }
+ ],
+ "__requires": [
+ {
+ "type": "grafana",
+ "id": "grafana",
+ "name": "Grafana",
+ "version": "6.7.4"
+ },
+ {
+ "type": "panel",
+ "id": "graph",
+ "name": "Graph",
+ "version": ""
+ },
+ {
+ "type": "panel",
+ "id": "heatmap",
+ "name": "Heatmap",
+ "version": ""
+ },
+ {
+ "type": "datasource",
+ "id": "prometheus",
+ "name": "Prometheus",
+ "version": "1.0.0"
+ }
+ ],
"annotations": {
"list": [
{
+ "$$hashKey": "object:76",
"builtIn": 1,
"datasource": "$datasource",
"enable": false,
@@ -17,8 +54,8 @@
"editable": true,
"gnetId": null,
"graphTooltip": 0,
- "id": 1,
- "iteration": 1591098104645,
+ "id": null,
+ "iteration": 1594646317221,
"links": [
{
"asDropdown": true,
@@ -34,7 +71,7 @@
"panels": [
{
"collapsed": false,
- "datasource": null,
+ "datasource": "${DS_PROMETHEUS}",
"gridPos": {
"h": 1,
"w": 24,
@@ -269,7 +306,6 @@
"show": false
},
"links": [],
- "options": {},
"reverseYBuckets": false,
"targets": [
{
@@ -559,7 +595,7 @@
},
{
"collapsed": true,
- "datasource": null,
+ "datasource": "${DS_PROMETHEUS}",
"gridPos": {
"h": 1,
"w": 24,
@@ -1423,7 +1459,7 @@
},
{
"collapsed": true,
- "datasource": null,
+ "datasource": "${DS_PROMETHEUS}",
"gridPos": {
"h": 1,
"w": 24,
@@ -1795,7 +1831,7 @@
},
{
"collapsed": true,
- "datasource": null,
+ "datasource": "${DS_PROMETHEUS}",
"gridPos": {
"h": 1,
"w": 24,
@@ -2531,7 +2567,7 @@
},
{
"collapsed": true,
- "datasource": null,
+ "datasource": "${DS_PROMETHEUS}",
"gridPos": {
"h": 1,
"w": 24,
@@ -2823,7 +2859,7 @@
},
{
"collapsed": true,
- "datasource": null,
+ "datasource": "${DS_PROMETHEUS}",
"gridPos": {
"h": 1,
"w": 24,
@@ -2844,7 +2880,7 @@
"h": 9,
"w": 12,
"x": 0,
- "y": 33
+ "y": 6
},
"hiddenSeries": false,
"id": 79,
@@ -2940,7 +2976,7 @@
"h": 9,
"w": 12,
"x": 12,
- "y": 33
+ "y": 6
},
"hiddenSeries": false,
"id": 83,
@@ -3038,7 +3074,7 @@
"h": 9,
"w": 12,
"x": 0,
- "y": 42
+ "y": 15
},
"hiddenSeries": false,
"id": 109,
@@ -3137,7 +3173,7 @@
"h": 9,
"w": 12,
"x": 12,
- "y": 42
+ "y": 15
},
"hiddenSeries": false,
"id": 111,
@@ -3223,14 +3259,14 @@
"dashLength": 10,
"dashes": false,
"datasource": "$datasource",
- "description": "",
+ "description": "Number of events queued up on the master process for processing by the federation sender",
"fill": 1,
"fillGradient": 0,
"gridPos": {
"h": 9,
"w": 12,
"x": 0,
- "y": 51
+ "y": 24
},
"hiddenSeries": false,
"id": 140,
@@ -3354,6 +3390,103 @@
"align": false,
"alignLevel": null
}
+ },
+ {
+ "aliasColors": {},
+ "bars": false,
+ "dashLength": 10,
+ "dashes": false,
+ "datasource": "${DS_PROMETHEUS}",
+ "description": "The number of events in the in-memory queues ",
+ "fill": 1,
+ "fillGradient": 0,
+ "gridPos": {
+ "h": 8,
+ "w": 12,
+ "x": 12,
+ "y": 24
+ },
+ "hiddenSeries": false,
+ "id": 142,
+ "legend": {
+ "avg": false,
+ "current": false,
+ "max": false,
+ "min": false,
+ "show": true,
+ "total": false,
+ "values": false
+ },
+ "lines": true,
+ "linewidth": 1,
+ "nullPointMode": "null",
+ "options": {
+ "dataLinks": []
+ },
+ "percentage": false,
+ "pointradius": 2,
+ "points": false,
+ "renderer": "flot",
+ "seriesOverrides": [],
+ "spaceLength": 10,
+ "stack": false,
+ "steppedLine": false,
+ "targets": [
+ {
+ "expr": "synapse_federation_transaction_queue_pending_pdus{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}",
+ "interval": "",
+ "legendFormat": "pending PDUs {{job}}-{{index}}",
+ "refId": "A"
+ },
+ {
+ "expr": "synapse_federation_transaction_queue_pending_edus{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}",
+ "interval": "",
+ "legendFormat": "pending EDUs {{job}}-{{index}}",
+ "refId": "B"
+ }
+ ],
+ "thresholds": [],
+ "timeFrom": null,
+ "timeRegions": [],
+ "timeShift": null,
+ "title": "In-memory federation transmission queues",
+ "tooltip": {
+ "shared": true,
+ "sort": 0,
+ "value_type": "individual"
+ },
+ "type": "graph",
+ "xaxis": {
+ "buckets": null,
+ "mode": "time",
+ "name": null,
+ "show": true,
+ "values": []
+ },
+ "yaxes": [
+ {
+ "$$hashKey": "object:317",
+ "format": "short",
+ "label": "events",
+ "logBase": 1,
+ "max": null,
+ "min": "0",
+ "show": true
+ },
+ {
+ "$$hashKey": "object:318",
+ "format": "short",
+ "label": "",
+ "logBase": 1,
+ "max": null,
+ "min": null,
+ "show": true
+ }
+ ],
+ "yaxis": {
+ "align": false,
+ "alignLevel": null
+ }
}
],
"title": "Federation",
@@ -3361,7 +3494,7 @@
},
{
"collapsed": true,
- "datasource": null,
+ "datasource": "${DS_PROMETHEUS}",
"gridPos": {
"h": 1,
"w": 24,
@@ -3567,7 +3700,7 @@
},
{
"collapsed": true,
- "datasource": null,
+ "datasource": "${DS_PROMETHEUS}",
"gridPos": {
"h": 1,
"w": 24,
@@ -3588,7 +3721,7 @@
"h": 7,
"w": 12,
"x": 0,
- "y": 52
+ "y": 79
},
"hiddenSeries": false,
"id": 48,
@@ -3682,7 +3815,7 @@
"h": 7,
"w": 12,
"x": 12,
- "y": 52
+ "y": 79
},
"hiddenSeries": false,
"id": 104,
@@ -3802,7 +3935,7 @@
"h": 7,
"w": 12,
"x": 0,
- "y": 59
+ "y": 86
},
"hiddenSeries": false,
"id": 10,
@@ -3898,7 +4031,7 @@
"h": 7,
"w": 12,
"x": 12,
- "y": 59
+ "y": 86
},
"hiddenSeries": false,
"id": 11,
@@ -3987,7 +4120,7 @@
},
{
"collapsed": true,
- "datasource": null,
+ "datasource": "${DS_PROMETHEUS}",
"gridPos": {
"h": 1,
"w": 24,
@@ -4011,7 +4144,7 @@
"h": 13,
"w": 12,
"x": 0,
- "y": 67
+ "y": 80
},
"hiddenSeries": false,
"id": 12,
@@ -4106,7 +4239,7 @@
"h": 13,
"w": 12,
"x": 12,
- "y": 67
+ "y": 80
},
"hiddenSeries": false,
"id": 26,
@@ -4201,7 +4334,7 @@
"h": 13,
"w": 12,
"x": 0,
- "y": 80
+ "y": 93
},
"hiddenSeries": false,
"id": 13,
@@ -4297,7 +4430,7 @@
"h": 13,
"w": 12,
"x": 12,
- "y": 80
+ "y": 93
},
"hiddenSeries": false,
"id": 27,
@@ -4392,7 +4525,7 @@
"h": 13,
"w": 12,
"x": 0,
- "y": 93
+ "y": 106
},
"hiddenSeries": false,
"id": 28,
@@ -4486,7 +4619,7 @@
"h": 13,
"w": 12,
"x": 12,
- "y": 93
+ "y": 106
},
"hiddenSeries": false,
"id": 25,
@@ -4572,7 +4705,7 @@
},
{
"collapsed": true,
- "datasource": null,
+ "datasource": "${DS_PROMETHEUS}",
"gridPos": {
"h": 1,
"w": 24,
@@ -5062,7 +5195,7 @@
},
{
"collapsed": true,
- "datasource": null,
+ "datasource": "${DS_PROMETHEUS}",
"gridPos": {
"h": 1,
"w": 24,
@@ -5083,7 +5216,7 @@
"h": 9,
"w": 12,
"x": 0,
- "y": 66
+ "y": 121
},
"hiddenSeries": false,
"id": 91,
@@ -5179,7 +5312,7 @@
"h": 9,
"w": 12,
"x": 12,
- "y": 66
+ "y": 121
},
"hiddenSeries": false,
"id": 21,
@@ -5271,7 +5404,7 @@
"h": 9,
"w": 12,
"x": 0,
- "y": 75
+ "y": 130
},
"hiddenSeries": false,
"id": 89,
@@ -5369,7 +5502,7 @@
"h": 9,
"w": 12,
"x": 12,
- "y": 75
+ "y": 130
},
"hiddenSeries": false,
"id": 93,
@@ -5459,7 +5592,7 @@
"h": 9,
"w": 12,
"x": 0,
- "y": 84
+ "y": 139
},
"hiddenSeries": false,
"id": 95,
@@ -5552,12 +5685,12 @@
"mode": "spectrum"
},
"dataFormat": "tsbuckets",
- "datasource": "Prometheus",
+ "datasource": "${DS_PROMETHEUS}",
"gridPos": {
"h": 9,
"w": 12,
"x": 12,
- "y": 84
+ "y": 139
},
"heatmap": {},
"hideZeroBuckets": true,
@@ -5567,7 +5700,6 @@
"show": true
},
"links": [],
- "options": {},
"reverseYBuckets": false,
"targets": [
{
@@ -5609,7 +5741,7 @@
},
{
"collapsed": true,
- "datasource": null,
+ "datasource": "${DS_PROMETHEUS}",
"gridPos": {
"h": 1,
"w": 24,
@@ -5630,7 +5762,7 @@
"h": 7,
"w": 12,
"x": 0,
- "y": 39
+ "y": 66
},
"hiddenSeries": false,
"id": 2,
@@ -5754,7 +5886,7 @@
"h": 7,
"w": 12,
"x": 12,
- "y": 39
+ "y": 66
},
"hiddenSeries": false,
"id": 41,
@@ -5847,7 +5979,7 @@
"h": 7,
"w": 12,
"x": 0,
- "y": 46
+ "y": 73
},
"hiddenSeries": false,
"id": 42,
@@ -5939,7 +6071,7 @@
"h": 7,
"w": 12,
"x": 12,
- "y": 46
+ "y": 73
},
"hiddenSeries": false,
"id": 43,
@@ -6031,7 +6163,7 @@
"h": 7,
"w": 12,
"x": 0,
- "y": 53
+ "y": 80
},
"hiddenSeries": false,
"id": 113,
@@ -6129,7 +6261,7 @@
"h": 7,
"w": 12,
"x": 12,
- "y": 53
+ "y": 80
},
"hiddenSeries": false,
"id": 115,
@@ -6215,7 +6347,7 @@
},
{
"collapsed": true,
- "datasource": null,
+ "datasource": "${DS_PROMETHEUS}",
"gridPos": {
"h": 1,
"w": 24,
@@ -6236,7 +6368,7 @@
"h": 9,
"w": 12,
"x": 0,
- "y": 58
+ "y": 40
},
"hiddenSeries": false,
"id": 67,
@@ -6267,7 +6399,7 @@
"steppedLine": false,
"targets": [
{
- "expr": " synapse_event_persisted_position{instance=\"$instance\",job=\"synapse\"} - ignoring(index, job, name) group_right() synapse_event_processing_positions{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}",
+ "expr": "max(synapse_event_persisted_position{instance=\"$instance\"}) - ignoring(instance,index, job, name) group_right() synapse_event_processing_positions{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}",
"format": "time_series",
"interval": "",
"intervalFactor": 1,
@@ -6328,7 +6460,7 @@
"h": 9,
"w": 12,
"x": 12,
- "y": 58
+ "y": 40
},
"hiddenSeries": false,
"id": 71,
@@ -6362,6 +6494,7 @@
"expr": "time()*1000-synapse_event_processing_last_ts{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}",
"format": "time_series",
"hide": false,
+ "interval": "",
"intervalFactor": 1,
"legendFormat": "{{job}}-{{index}} {{name}}",
"refId": "B"
@@ -6420,7 +6553,7 @@
"h": 9,
"w": 12,
"x": 0,
- "y": 67
+ "y": 49
},
"hiddenSeries": false,
"id": 121,
@@ -6509,7 +6642,7 @@
},
{
"collapsed": true,
- "datasource": null,
+ "datasource": "${DS_PROMETHEUS}",
"gridPos": {
"h": 1,
"w": 24,
@@ -6539,7 +6672,7 @@
"h": 8,
"w": 12,
"x": 0,
- "y": 41
+ "y": 86
},
"heatmap": {},
"hideZeroBuckets": true,
@@ -6549,7 +6682,6 @@
"show": true
},
"links": [],
- "options": {},
"reverseYBuckets": false,
"targets": [
{
@@ -6599,7 +6731,7 @@
"h": 8,
"w": 12,
"x": 12,
- "y": 41
+ "y": 86
},
"hiddenSeries": false,
"id": 124,
@@ -6700,7 +6832,7 @@
"h": 8,
"w": 12,
"x": 0,
- "y": 49
+ "y": 94
},
"heatmap": {},
"hideZeroBuckets": true,
@@ -6710,7 +6842,6 @@
"show": true
},
"links": [],
- "options": {},
"reverseYBuckets": false,
"targets": [
{
@@ -6760,7 +6891,7 @@
"h": 8,
"w": 12,
"x": 12,
- "y": 49
+ "y": 94
},
"hiddenSeries": false,
"id": 128,
@@ -6879,7 +7010,7 @@
"h": 8,
"w": 12,
"x": 0,
- "y": 57
+ "y": 102
},
"heatmap": {},
"hideZeroBuckets": true,
@@ -6889,7 +7020,6 @@
"show": true
},
"links": [],
- "options": {},
"reverseYBuckets": false,
"targets": [
{
@@ -6939,7 +7069,7 @@
"h": 8,
"w": 12,
"x": 12,
- "y": 57
+ "y": 102
},
"hiddenSeries": false,
"id": 130,
@@ -7058,7 +7188,7 @@
"h": 8,
"w": 12,
"x": 0,
- "y": 65
+ "y": 110
},
"heatmap": {},
"hideZeroBuckets": true,
@@ -7068,12 +7198,12 @@
"show": true
},
"links": [],
- "options": {},
"reverseYBuckets": false,
"targets": [
{
- "expr": "rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\"}[$bucket_size]) and on (index, instance, job) (synapse_storage_events_persisted_events > 0)",
+ "expr": "rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])",
"format": "heatmap",
+ "interval": "",
"intervalFactor": 1,
"legendFormat": "{{le}}",
"refId": "A"
@@ -7118,7 +7248,7 @@
"h": 8,
"w": 12,
"x": 12,
- "y": 65
+ "y": 110
},
"hiddenSeries": false,
"id": 132,
@@ -7149,29 +7279,33 @@
"steppedLine": false,
"targets": [
{
- "expr": "histogram_quantile(0.5, rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\"}[$bucket_size]) and on (index, instance, job) (synapse_storage_events_persisted_events > 0)) ",
+ "expr": "histogram_quantile(0.5, rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size]))",
"format": "time_series",
+ "interval": "",
"intervalFactor": 1,
"legendFormat": "50%",
"refId": "A"
},
{
- "expr": "histogram_quantile(0.75, rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\"}[$bucket_size]) and on (index, instance, job) (synapse_storage_events_persisted_events > 0))",
+ "expr": "histogram_quantile(0.75, rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size]))",
"format": "time_series",
+ "interval": "",
"intervalFactor": 1,
"legendFormat": "75%",
"refId": "B"
},
{
- "expr": "histogram_quantile(0.90, rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\"}[$bucket_size]) and on (index, instance, job) (synapse_storage_events_persisted_events > 0))",
+ "expr": "histogram_quantile(0.90, rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size]))",
"format": "time_series",
+ "interval": "",
"intervalFactor": 1,
"legendFormat": "90%",
"refId": "C"
},
{
- "expr": "histogram_quantile(0.99, rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\"}[$bucket_size]) and on (index, instance, job) (synapse_storage_events_persisted_events > 0))",
+ "expr": "histogram_quantile(0.99, rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size]))",
"format": "time_series",
+ "interval": "",
"intervalFactor": 1,
"legendFormat": "99%",
"refId": "D"
@@ -7181,7 +7315,7 @@
"timeFrom": null,
"timeRegions": [],
"timeShift": null,
- "title": "Number of state resolution performed, by number of state groups involved (quantiles)",
+ "title": "Number of state resolutions performed, by number of state groups involved (quantiles)",
"tooltip": {
"shared": true,
"sort": 0,
@@ -7233,6 +7367,7 @@
"list": [
{
"current": {
+ "selected": false,
"text": "Prometheus",
"value": "Prometheus"
},
@@ -7309,14 +7444,12 @@
},
{
"allValue": null,
- "current": {
- "text": "matrix.org",
- "value": "matrix.org"
- },
+ "current": {},
"datasource": "$datasource",
"definition": "",
"hide": 0,
"includeAll": false,
+ "index": -1,
"label": null,
"multi": false,
"name": "instance",
@@ -7335,17 +7468,13 @@
{
"allFormat": "regex wildcard",
"allValue": "",
- "current": {
- "text": "synapse",
- "value": [
- "synapse"
- ]
- },
+ "current": {},
"datasource": "$datasource",
"definition": "",
"hide": 0,
"hideLabel": false,
"includeAll": true,
+ "index": -1,
"label": "Job",
"multi": true,
"multiFormat": "regex values",
@@ -7366,16 +7495,13 @@
{
"allFormat": "regex wildcard",
"allValue": ".*",
- "current": {
- "selected": false,
- "text": "All",
- "value": "$__all"
- },
+ "current": {},
"datasource": "$datasource",
"definition": "",
"hide": 0,
"hideLabel": false,
"includeAll": true,
+ "index": -1,
"label": "",
"multi": true,
"multiFormat": "regex values",
@@ -7428,5 +7554,8 @@
"timezone": "",
"title": "Synapse",
"uid": "000000012",
- "version": 29
+ "variables": {
+ "list": []
+ },
+ "version": 32
}
\ No newline at end of file
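The `__inputs`/`__requires` stanza added at the top of this dashboard makes it portable: on import, Grafana binds the templated `DS_PROMETHEUS` input to a concrete datasource. A hedged sketch of a scripted import via Grafana's HTTP API, where the URL, API key and datasource name are placeholders:

```python
import json

import requests

with open("contrib/grafana/synapse.json") as f:
    dashboard = json.load(f)

# Map the dashboard's templated DS_PROMETHEUS input onto a real datasource.
requests.post(
    "https://grafana.example.com/api/dashboards/import",
    headers={"Authorization": "Bearer <grafana_api_key>"},
    json={
        "dashboard": dashboard,
        "overwrite": True,
        "inputs": [{
            "name": "DS_PROMETHEUS",
            "type": "datasource",
            "pluginId": "prometheus",
            "value": "Prometheus",  # the name of your Prometheus datasource
        }],
    },
)
```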
diff --git a/contrib/graph/graph.py b/contrib/graph/graph.py
index 92736480..de33fac1 100644
--- a/contrib/graph/graph.py
+++ b/contrib/graph/graph.py
@@ -1,5 +1,13 @@
from __future__ import print_function
+import argparse
+import cgi
+import datetime
+import json
+
+import pydot
+import urllib2
+
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,15 +23,6 @@ from __future__ import print_function
# limitations under the License.
-import sqlite3
-import pydot
-import cgi
-import json
-import datetime
-import argparse
-import urllib2
-
-
def make_name(pdu_id, origin):
return "%s@%s" % (pdu_id, origin)
@@ -33,7 +32,7 @@ def make_graph(pdus, room, filename_prefix):
node_map = {}
origins = set()
- colors = set(("red", "green", "blue", "yellow", "purple"))
+ colors = {"red", "green", "blue", "yellow", "purple"}
for pdu in pdus:
origins.add(pdu.get("origin"))
@@ -49,7 +48,7 @@ def make_graph(pdus, room, filename_prefix):
try:
c = colors.pop()
color_map[o] = c
- except:
+ except Exception:
print("Run out of colours!")
color_map[o] = "black"
diff --git a/contrib/graph/graph2.py b/contrib/graph/graph2.py
index 4619f0e3..0980231e 100644
--- a/contrib/graph/graph2.py
+++ b/contrib/graph/graph2.py
@@ -13,12 +13,13 @@
# limitations under the License.
-import sqlite3
-import pydot
+import argparse
import cgi
-import json
import datetime
-import argparse
+import json
+import sqlite3
+
+import pydot
from synapse.events import FrozenEvent
from synapse.util.frozenutils import unfreeze
@@ -98,7 +99,7 @@ def make_graph(db_name, room_id, file_prefix, limit):
for prev_id, _ in event.prev_events:
try:
end_node = node_map[prev_id]
- except:
+ except Exception:
end_node = pydot.Node(name=prev_id, label="<<b>%s</b>>" % (prev_id,))
node_map[prev_id] = end_node
diff --git a/contrib/graph/graph3.py b/contrib/graph/graph3.py
index 31546385..91db98e7 100644
--- a/contrib/graph/graph3.py
+++ b/contrib/graph/graph3.py
@@ -1,5 +1,15 @@
from __future__ import print_function
+import argparse
+import cgi
+import datetime
+
+import pydot
+import simplejson as json
+
+from synapse.events import FrozenEvent
+from synapse.util.frozenutils import unfreeze
+
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,16 +25,6 @@ from __future__ import print_function
# limitations under the License.
-import pydot
-import cgi
-import simplejson as json
-import datetime
-import argparse
-
-from synapse.events import FrozenEvent
-from synapse.util.frozenutils import unfreeze
-
-
def make_graph(file_name, room_id, file_prefix, limit):
print("Reading lines")
with open(file_name) as f:
@@ -106,7 +106,7 @@ def make_graph(file_name, room_id, file_prefix, limit):
for prev_id, _ in event.prev_events:
try:
end_node = node_map[prev_id]
- except:
+ except Exception:
end_node = pydot.Node(name=prev_id, label="<<b>%s</b>>" % (prev_id,))
node_map[prev_id] = end_node
diff --git a/contrib/jitsimeetbridge/jitsimeetbridge.py b/contrib/jitsimeetbridge/jitsimeetbridge.py
index 67fb2cd1..69aa74bd 100644
--- a/contrib/jitsimeetbridge/jitsimeetbridge.py
+++ b/contrib/jitsimeetbridge/jitsimeetbridge.py
@@ -12,15 +12,15 @@ npm install jquery jsdom
"""
from __future__ import print_function
-import gevent
-import grequests
-from BeautifulSoup import BeautifulSoup
import json
-import urllib
import subprocess
import time
-# ACCESS_TOKEN="" #
+import gevent
+import grequests
+from BeautifulSoup import BeautifulSoup
+
+ACCESS_TOKEN = ""
MATRIXBASE = "https://matrix.org/_matrix/client/api/v1/"
MYUSERNAME = "@davetest:matrix.org"
diff --git a/contrib/scripts/kick_users.py b/contrib/scripts/kick_users.py
index f57e6e7d..372dbd9e 100755
--- a/contrib/scripts/kick_users.py
+++ b/contrib/scripts/kick_users.py
@@ -1,10 +1,12 @@
#!/usr/bin/env python
from __future__ import print_function
-from argparse import ArgumentParser
+
import json
-import requests
import sys
import urllib
+from argparse import ArgumentParser
+
+import requests
try:
raw_input
diff --git a/debian/changelog b/debian/changelog
index 7fc0c711..b0c84527 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,15 @@
+matrix-synapse (1.18.0-1~bpo10+1) buster-backports; urgency=medium
+
+ * Rebuild for buster-backports.
+
+ -- Andrej Shadura <andrewsh@debian.org> Wed, 19 Aug 2020 21:15:54 +0200
+
+matrix-synapse (1.18.0-1) unstable; urgency=medium
+
+ * New upstream release.
+
+ -- Andrej Shadura <andrewsh@debian.org> Wed, 12 Aug 2020 09:05:45 +0200
+
matrix-synapse (1.17.0-1~bpo10+1) buster-backports; urgency=medium
* Rebuild for buster-backports.
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 093e89af..8b3a4246 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -16,35 +16,31 @@ ARG PYTHON_VERSION=3.7
###
### Stage 0: builder
###
-FROM docker.io/python:${PYTHON_VERSION}-alpine3.11 as builder
+FROM docker.io/python:${PYTHON_VERSION}-slim as builder
# install the OS build deps
-RUN apk add \
- build-base \
- libffi-dev \
- libjpeg-turbo-dev \
- libwebp-dev \
- libressl-dev \
- libxslt-dev \
- linux-headers \
- postgresql-dev \
- zlib-dev
-# build things which have slow build steps, before we copy synapse, so that
-# the layer can be cached.
-#
-# (we really just care about caching a wheel here, as the "pip install" below
-# will install them again.)
+RUN apt-get update && apt-get install -y \
+ build-essential \
+ libpq-dev \
+ && rm -rf /var/lib/apt/lists/*
+# Build dependencies that are not available as wheels, to speed up rebuilds
RUN pip install --prefix="/install" --no-warn-script-location \
- cryptography \
- msgpack-python \
- pillow \
- pynacl
+ frozendict \
+ jaeger-client \
+ opentracing \
+ prometheus-client \
+ psycopg2 \
+ pycparser \
+ pyrsistent \
+ pyyaml \
+ simplejson \
+ threadloop \
+ thrift
# now install synapse and all of the python deps to /install.
-
COPY synapse /synapse/synapse/
COPY scripts /synapse/scripts/
COPY MANIFEST.in README.rst setup.py synctl /synapse/
@@ -56,20 +52,13 @@ RUN pip install --prefix="/install" --no-warn-script-location \
### Stage 1: runtime
###
-FROM docker.io/python:${PYTHON_VERSION}-alpine3.11
+FROM docker.io/python:${PYTHON_VERSION}-slim
-# xmlsec is required for saml support
-RUN apk add --no-cache --virtual .runtime_deps \
- libffi \
- libjpeg-turbo \
- libwebp \
- libressl \
- libxslt \
- libpq \
- zlib \
- su-exec \
- tzdata \
- xmlsec
+RUN apt-get update && apt-get install -y \
+ libpq5 \
+ xmlsec1 \
+ gosu \
+ && rm -rf /var/lib/apt/lists/*
COPY --from=builder /install /usr/local
COPY ./docker/start.py /start.py
diff --git a/docker/README.md b/docker/README.md
index 8c337149..008a9ff7 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -94,6 +94,21 @@ The following environment variables are supported in run mode:
* `UID`, `GID`: the user and group id to run Synapse as. Defaults to `991`, `991`.
* `TZ`: the [timezone](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones) the container will run with. Defaults to `UTC`.
+## Generating an (admin) user
+
+After synapse is running, you may wish to create a user via `register_new_matrix_user`.
+
+This requires a `registration_shared_secret` to be set in your config file. Synapse
+must be restarted to pick up this change.
+
+You can then call the script:
+
+```
+docker exec -it synapse register_new_matrix_user http://localhost:8008 -c /data/homeserver.yaml --help
+```
+
+Remember to remove the `registration_shared_secret` and restart if you no longer need it.
+
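For scripted setups, the same flow can be driven directly against the shared-secret registration endpoint that `register_new_matrix_user` wraps. A sketch, in which the base URL, credentials and secret are placeholders; consult the admin API docs for the authoritative field list:

```python
import hashlib
import hmac

import requests

url = "http://localhost:8008/_synapse/admin/v1/register"
secret = b"<registration_shared_secret>"  # placeholder

# Fetch a single-use nonce, then authenticate the request with an
# HMAC-SHA1 over nonce, username, password and the admin flag.
nonce = requests.get(url).json()["nonce"]
mac = hmac.new(secret, digestmod=hashlib.sha1)
mac.update(b"\x00".join([nonce.encode(), b"alice", b"wonderland", b"admin"]))

resp = requests.post(url, json={
    "nonce": nonce,
    "username": "alice",
    "password": "wonderland",
    "admin": True,
    "mac": mac.hexdigest(),
})
print(resp.json())
```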
## TLS support
The default configuration exposes a single HTTP port: http://localhost:8008. It
diff --git a/docker/start.py b/docker/start.py
index 2a25c938..9f081341 100755
--- a/docker/start.py
+++ b/docker/start.py
@@ -120,7 +120,7 @@ def generate_config_from_template(config_dir, config_path, environ, ownership):
if ownership is not None:
subprocess.check_output(["chown", "-R", ownership, "/data"])
- args = ["su-exec", ownership] + args
+ args = ["gosu", ownership] + args
subprocess.check_output(args)
@@ -172,8 +172,8 @@ def run_generate_config(environ, ownership):
# make sure that synapse has perms to write to the data dir.
subprocess.check_output(["chown", ownership, data_dir])
- args = ["su-exec", ownership] + args
- os.execv("/sbin/su-exec", args)
+ args = ["gosu", ownership] + args
+ os.execv("/usr/sbin/gosu", args)
else:
os.execv("/usr/local/bin/python", args)
@@ -189,7 +189,7 @@ def main(args, environ):
ownership = "{}:{}".format(desired_uid, desired_gid)
if ownership is None:
- log("Will not perform chmod/su-exec as UserID already matches request")
+ log("Will not perform chmod/gosu as UserID already matches request")
# In generate mode, generate a configuration and missing keys, then exit
if mode == "generate":
@@ -236,8 +236,8 @@ running with 'migrate_config'. See the README for more details.
args = ["python", "-m", synapse_worker, "--config-path", config_path]
if ownership is not None:
- args = ["su-exec", ownership] + args
- os.execv("/sbin/su-exec", args)
+ args = ["gosu", ownership] + args
+ os.execv("/usr/sbin/gosu", args)
else:
os.execv("/usr/local/bin/python", args)
diff --git a/docs/ACME.md b/docs/ACME.md
index f4c47404..a7a498f5 100644
--- a/docs/ACME.md
+++ b/docs/ACME.md
@@ -12,13 +12,14 @@ introduced support for automatically provisioning certificates through
In [March 2019](https://community.letsencrypt.org/t/end-of-life-plan-for-acmev1/88430),
Let's Encrypt announced that they were deprecating version 1 of the ACME
protocol, with the plan to disable the use of it for new accounts in
-November 2019, and for existing accounts in June 2020.
+November 2019, for new domains in June 2020, and for existing accounts and
+domains in June 2021.
Synapse doesn't currently support version 2 of the ACME protocol, which
means that:
* for existing installs, Synapse's built-in ACME support will continue
- to work until June 2020.
+ to work until June 2021.
* for new installs, this feature will not work at all.
Either way, it is recommended to move from Synapse's ACME support
diff --git a/docs/admin_api/purge_room.md b/docs/admin_api/purge_room.md
index 64ea7b6a..ae01a543 100644
--- a/docs/admin_api/purge_room.md
+++ b/docs/admin_api/purge_room.md
@@ -5,6 +5,8 @@ This API will remove all trace of a room from your database.
All local users must have left the room before it can be removed.
+See also: [Delete Room API](rooms.md#delete-room-api)
+
The API is:
```
diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md
index 624e7745..15b83e98 100644
--- a/docs/admin_api/rooms.md
+++ b/docs/admin_api/rooms.md
@@ -318,3 +318,129 @@ Response:
"state_events": 93534
}
```
+
+# Room Members API
+
+The Room Members admin API allows server admins to get a list of all members of a room.
+
+The response includes the following fields:
+
+* `members` - A list of all the members that are present in the room, represented by their ids.
+* `total` - Total number of members in the room.
+
+## Usage
+
+A standard request:
+
+```
+GET /_synapse/admin/v1/rooms/<room_id>/members
+
+{}
+```
+
+Response:
+
+```
+{
+ "members": [
+ "@foo:matrix.org",
+ "@bar:matrix.org",
+        "@foobar:matrix.org"
+ ],
+ "total": 3
+}
+```
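A quick way to exercise this endpoint from Python; the homeserver URL, token and room ID are placeholders, and note that the room ID must be URL-encoded:

```python
import requests
from urllib.parse import quote

room_id = "!badroom:example.com"  # placeholder
resp = requests.get(
    f"https://matrix.example.com/_synapse/admin/v1/rooms/{quote(room_id)}/members",
    headers={"Authorization": "Bearer <admin_access_token>"},
)
data = resp.json()
print(data["total"], data["members"])
```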
+
+# Delete Room API
+
+The Delete Room admin API allows server admins to remove rooms from server
+and block these rooms.
+It is a combination and improvement of the "[Shutdown room](shutdown_room.md)"
+and "[Purge room](purge_room.md)" APIs.
+
+Shuts down a room. Moves all local users and room aliases automatically to a
+new room if `new_room_user_id` is set. Otherwise local users only
+leave the room, without any notification.
+
+The new room will be created with the user specified by the `new_room_user_id` parameter
+as room administrator and will contain a message explaining what happened. Users invited
+to the new room will have power level `-10` by default, and thus be unable to speak.
+
+If `block` is `true`, it prevents new joins to the old room.
+
+This API will remove all trace of the old room from your database after removing
+all local users.
+Depending on the amount of history being purged, a call to the API may take
+several minutes or longer.
+
+The local server will only have the power to move local users and room aliases to
+the new room. Users on other servers will be unaffected.
+
+The API is:
+
+```
+POST /_synapse/admin/v1/rooms/<room_id>/delete
+```
+
+with a body of:
+```json
+{
+ "new_room_user_id": "@someuser:example.com",
+ "room_name": "Content Violation Notification",
+ "message": "Bad Room has been shutdown due to content violations on this server. Please review our Terms of Service.",
+ "block": true
+}
+```
+
+To use it, you will need to authenticate by providing an `access_token` for a
+server admin: see [README.rst](README.rst).
+
+A response body like the following is returned:
+
+```json
+{
+ "kicked_users": [
+ "@foobar:example.com"
+ ],
+ "failed_to_kick_users": [],
+ "local_aliases": [
+ "#badroom:example.com",
+ "#evilsaloon:example.com"
+ ],
+ "new_room_id": "!newroomid:example.com"
+}
+```
+
+## Parameters
+
+The following parameters should be set in the URL:
+
+* `room_id` - The ID of the room.
+
+The following JSON body parameters are available:
+
+* `new_room_user_id` - Optional. If set, a new room will be created with this user ID
+ as the creator and admin, and all users in the old room will be moved into that
+ room. If not set, no new room will be created and the users will just be removed
+ from the old room. The user ID must be on the local server, but does not necessarily
+ have to belong to a registered user.
+* `room_name` - Optional. A string representing the name of the room that new users will be
+  invited to. Defaults to `Content Violation Notification`.
+* `message` - Optional. A string containing the first message that will be sent as
+ `new_room_user_id` in the new room. Ideally this will clearly convey why the
+ original room was shut down. Defaults to `Sharing illegal content on this server
+ is not permitted and rooms in violation will be blocked.`
+* `block` - Optional. If set to `true`, this room will be added to a blocking list, preventing future attempts to
+ join the room. Defaults to `false`.
+
+The JSON body must not be empty; it must contain at least `{}`.
+
+## Response
+
+The following fields are returned in the JSON response body:
+
+* `kicked_users` - An array of users (`user_id`) that were kicked.
+* `failed_to_kick_users` - An array of users (`user_id`) that were not kicked.
+* `local_aliases` - An array of strings representing the local aliases that were migrated from
+ the old room to the new.
+* `new_room_id` - A string representing the room ID of the new room.
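For example, a minimal sketch of driving the delete endpoint from Python with `requests`; the homeserver URL, room ID and token are placeholders:

```python
import requests

BASE = "https://matrix.example.com"   # placeholder homeserver URL
ROOM_ID = "!badroom:example.com"      # placeholder room ID to delete
ADMIN_TOKEN = "<admin_access_token>"  # placeholder admin token

resp = requests.post(
    f"{BASE}/_synapse/admin/v1/rooms/{ROOM_ID}/delete",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
    # All body parameters are optional, but the body must be at least {}.
    json={"new_room_user_id": "@admin:example.com", "block": True},
)
resp.raise_for_status()
print(resp.json())  # kicked_users, failed_to_kick_users, local_aliases, new_room_id
```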
diff --git a/docs/admin_api/shutdown_room.md b/docs/admin_api/shutdown_room.md
index 54ce1cd2..808caeec 100644
--- a/docs/admin_api/shutdown_room.md
+++ b/docs/admin_api/shutdown_room.md
@@ -10,6 +10,8 @@ disallow any further invites or joins.
The local server will only have the power to move local user and room aliases to
the new room. Users on other servers will be unaffected.
+See also: [Delete Room API](rooms.md#delete-room-api)
+
## API
You will need to authenticate with an access token for an admin user.
diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
index 7b030a62..be05128b 100644
--- a/docs/admin_api/user_admin_api.rst
+++ b/docs/admin_api/user_admin_api.rst
@@ -91,10 +91,14 @@ Body parameters:
- ``admin``, optional, defaults to ``false``.
-- ``deactivated``, optional, defaults to ``false``.
+- ``deactivated``, optional. If unspecified, deactivation state will be left
+ unchanged on existing accounts and set to ``false`` for new accounts.
If the user already exists then optional parameters default to the current value.
+In order to re-activate an account, ``deactivated`` must be set to ``false``. If
+users do not log in via single sign-on, a new ``password`` must be provided.
+
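For instance, a sketch of re-activating a (non-SSO) account with these body parameters; it assumes the v2 user admin endpoint ``PUT /_synapse/admin/v2/users/<user_id>``, with placeholders throughout:

```python
import requests

BASE = "https://matrix.example.com"   # placeholder homeserver URL
ADMIN_TOKEN = "<admin_access_token>"  # placeholder admin token

resp = requests.put(
    f"{BASE}/_synapse/admin/v2/users/@bob:example.com",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
    # Re-activate the account; a new password is required unless the
    # user logs in via single sign-on.
    json={"deactivated": False, "password": "a-new-secret-password"},
)
resp.raise_for_status()
```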
List Accounts
=============
diff --git a/docs/jwt.md b/docs/jwt.md
index 289d66b3..5be9fd26 100644
--- a/docs/jwt.md
+++ b/docs/jwt.md
@@ -20,12 +20,18 @@ follows:
Note that the login type of `m.login.jwt` is supported, but is deprecated. This
will be removed in a future version of Synapse.
-The `jwt` should encode the local part of the user ID as the standard `sub`
-claim. In the case that the token is not valid, the homeserver must respond with
-`401 Unauthorized` and an error code of `M_UNAUTHORIZED`.
+The `token` field should include the JSON web token with the following claims:
-(Note that this differs from the token based logins which return a
-`403 Forbidden` and an error code of `M_FORBIDDEN` if an error occurs.)
+* The `sub` (subject) claim is required and should encode the local part of the
+ user ID.
+* The expiration time (`exp`), not before time (`nbf`), and issued at (`iat`)
+ claims are optional, but validated if present.
+* The issuer (`iss`) claim is optional, but required and validated if configured.
+* The audience (`aud`) claim is optional, but required and validated if configured.
+ Providing the audience claim when not configured will cause validation to fail.
+
+In the case that the token is not valid, the homeserver must respond with
+`403 Forbidden` and an error code of `M_FORBIDDEN`.
As with other login types, there are additional fields (e.g. `device_id` and
`initial_device_display_name`) which can be included in the above request.
@@ -55,7 +61,8 @@ sample settings.
Although JSON Web Tokens are typically generated from an external server, the
examples below use [PyJWT](https://pyjwt.readthedocs.io/en/latest/) directly.
-1. Configure Synapse with JWT logins:
+1. Configure Synapse with JWT logins; note that this example uses a pre-shared
+   secret and the HS256 algorithm:
```yaml
jwt_config:
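To make the flow concrete, here is a sketch that mints a token with PyJWT and logs in with it; the shared secret must match the one in `jwt_config`, and the homeserver URL and localpart are placeholders:

```python
import jwt        # PyJWT (returns a str in PyJWT >= 2)
import requests

BASE = "https://matrix.example.com"  # placeholder homeserver URL
SECRET = "my-shared-secret"          # must match the secret in jwt_config

# Encode the localpart as the required "sub" claim, using HS256 as configured.
token = jwt.encode({"sub": "alice"}, SECRET, algorithm="HS256")

resp = requests.post(
    f"{BASE}/_matrix/client/r0/login",
    json={"type": "org.matrix.login.jwt", "token": token},
)
resp.raise_for_status()
print(resp.json()["access_token"])
```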
diff --git a/docs/password_auth_providers.md b/docs/password_auth_providers.md
index 5d9ae670..fef1d47e 100644
--- a/docs/password_auth_providers.md
+++ b/docs/password_auth_providers.md
@@ -19,102 +19,103 @@ password auth provider module implementations:
Password auth provider classes must provide the following methods:
-*class* `SomeProvider.parse_config`(*config*)
+* `parse_config(config)`
+ This method is passed the `config` object for this module from the
+ homeserver configuration file.
-> This method is passed the `config` object for this module from the
-> homeserver configuration file.
->
-> It should perform any appropriate sanity checks on the provided
-> configuration, and return an object which is then passed into
-> `__init__`.
+ It should perform any appropriate sanity checks on the provided
+   configuration, and return an object which is then passed into `__init__`.
-*class* `SomeProvider`(*config*, *account_handler*)
+   This method should have the `@staticmethod` decorator.
-> The constructor is passed the config object returned by
-> `parse_config`, and a `synapse.module_api.ModuleApi` object which
-> allows the password provider to check if accounts exist and/or create
-> new ones.
+* `__init__(self, config, account_handler)`
+
+ The constructor is passed the config object returned by
+ `parse_config`, and a `synapse.module_api.ModuleApi` object which
+ allows the password provider to check if accounts exist and/or create
+ new ones.
## Optional methods
-Password auth provider classes may optionally provide the following
-methods.
-
-*class* `SomeProvider.get_db_schema_files`()
-
-> This method, if implemented, should return an Iterable of
-> `(name, stream)` pairs of database schema files. Each file is applied
-> in turn at initialisation, and a record is then made in the database
-> so that it is not re-applied on the next start.
-
-`someprovider.get_supported_login_types`()
-
-> This method, if implemented, should return a `dict` mapping from a
-> login type identifier (such as `m.login.password`) to an iterable
-> giving the fields which must be provided by the user in the submission
-> to the `/login` api. These fields are passed in the `login_dict`
-> dictionary to `check_auth`.
->
-> For example, if a password auth provider wants to implement a custom
-> login type of `com.example.custom_login`, where the client is expected
-> to pass the fields `secret1` and `secret2`, the provider should
-> implement this method and return the following dict:
->
-> {"com.example.custom_login": ("secret1", "secret2")}
-
-`someprovider.check_auth`(*username*, *login_type*, *login_dict*)
-
-> This method is the one that does the real work. If implemented, it
-> will be called for each login attempt where the login type matches one
-> of the keys returned by `get_supported_login_types`.
->
-> It is passed the (possibly UNqualified) `user` provided by the client,
-> the login type, and a dictionary of login secrets passed by the
-> client.
->
-> The method should return a Twisted `Deferred` object, which resolves
-> to the canonical `@localpart:domain` user id if authentication is
-> successful, and `None` if not.
->
-> Alternatively, the `Deferred` can resolve to a `(str, func)` tuple, in
-> which case the second field is a callback which will be called with
-> the result from the `/login` call (including `access_token`,
-> `device_id`, etc.)
-
-`someprovider.check_3pid_auth`(*medium*, *address*, *password*)
-
-> This method, if implemented, is called when a user attempts to
-> register or log in with a third party identifier, such as email. It is
-> passed the medium (ex. "email"), an address (ex.
-> "<jdoe@example.com>") and the user's password.
->
-> The method should return a Twisted `Deferred` object, which resolves
-> to a `str` containing the user's (canonical) User ID if
-> authentication was successful, and `None` if not.
->
-> As with `check_auth`, the `Deferred` may alternatively resolve to a
-> `(user_id, callback)` tuple.
-
-`someprovider.check_password`(*user_id*, *password*)
-
-> This method provides a simpler interface than
-> `get_supported_login_types` and `check_auth` for password auth
-> providers that just want to provide a mechanism for validating
-> `m.login.password` logins.
->
-> Iif implemented, it will be called to check logins with an
-> `m.login.password` login type. It is passed a qualified
-> `@localpart:domain` user id, and the password provided by the user.
->
-> The method should return a Twisted `Deferred` object, which resolves
-> to `True` if authentication is successful, and `False` if not.
-
-`someprovider.on_logged_out`(*user_id*, *device_id*, *access_token*)
-
-> This method, if implemented, is called when a user logs out. It is
-> passed the qualified user ID, the ID of the deactivated device (if
-> any: access tokens are occasionally created without an associated
-> device ID), and the (now deactivated) access token.
->
-> It may return a Twisted `Deferred` object; the logout request will
-> wait for the deferred to complete but the result is ignored.
+Password auth provider classes may optionally provide the following methods:
+
+* `get_db_schema_files(self)`
+
+ This method, if implemented, should return an Iterable of
+ `(name, stream)` pairs of database schema files. Each file is applied
+ in turn at initialisation, and a record is then made in the database
+ so that it is not re-applied on the next start.
+
+* `get_supported_login_types(self)`
+
+ This method, if implemented, should return a `dict` mapping from a
+ login type identifier (such as `m.login.password`) to an iterable
+ giving the fields which must be provided by the user in the submission
+ to [the `/login` API](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-login).
+ These fields are passed in the `login_dict` dictionary to `check_auth`.
+
+ For example, if a password auth provider wants to implement a custom
+ login type of `com.example.custom_login`, where the client is expected
+ to pass the fields `secret1` and `secret2`, the provider should
+ implement this method and return the following dict:
+
+ ```python
+ {"com.example.custom_login": ("secret1", "secret2")}
+ ```
+
+* `check_auth(self, username, login_type, login_dict)`
+
+ This method does the real work. If implemented, it
+ will be called for each login attempt where the login type matches one
+ of the keys returned by `get_supported_login_types`.
+
+ It is passed the (possibly unqualified) `user` field provided by the client,
+ the login type, and a dictionary of login secrets passed by the
+ client.
+
+ The method should return an `Awaitable` object, which resolves
+ to the canonical `@localpart:domain` user ID if authentication is
+ successful, and `None` if not.
+
+ Alternatively, the `Awaitable` can resolve to a `(str, func)` tuple, in
+ which case the second field is a callback which will be called with
+ the result from the `/login` call (including `access_token`,
+ `device_id`, etc.)
+
+* `check_3pid_auth(self, medium, address, password)`
+
+ This method, if implemented, is called when a user attempts to
+ register or log in with a third party identifier, such as email. It is
+ passed the medium (ex. "email"), an address (ex.
+ "<jdoe@example.com>") and the user's password.
+
+ The method should return an `Awaitable` object, which resolves
+   to a `str` containing the user's (canonical) user ID if
+ authentication was successful, and `None` if not.
+
+ As with `check_auth`, the `Awaitable` may alternatively resolve to a
+ `(user_id, callback)` tuple.
+
+* `check_password(self, user_id, password)`
+
+ This method provides a simpler interface than
+ `get_supported_login_types` and `check_auth` for password auth
+ providers that just want to provide a mechanism for validating
+ `m.login.password` logins.
+
+ If implemented, it will be called to check logins with an
+ `m.login.password` login type. It is passed a qualified
+ `@localpart:domain` user id, and the password provided by the user.
+
+ The method should return an `Awaitable` object, which resolves
+ to `True` if authentication is successful, and `False` if not.
+
+* `on_logged_out(self, user_id, device_id, access_token)`
+
+ This method, if implemented, is called when a user logs out. It is
+ passed the qualified user ID, the ID of the deactivated device (if
+ any: access tokens are occasionally created without an associated
+ device ID), and the (now deactivated) access token.
+
+ It may return an `Awaitable` object; the logout request will
+ wait for the `Awaitable` to complete, but the result is ignored.
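Putting these pieces together, a minimal (hypothetical) provider skeleton might look like the following sketch; the class and config key names are illustrative only:

```python
class ExampleAuthProvider:
    """Sketch of a password auth provider implementing the interface above."""

    @staticmethod
    def parse_config(config):
        # Sanity-check this module's config section; the return value is
        # passed into __init__.
        if "shared_secret" not in config:
            raise Exception("example_auth: shared_secret is required")
        return config

    def __init__(self, config, account_handler):
        self.config = config
        # account_handler is a synapse.module_api.ModuleApi object.
        self.account_handler = account_handler

    async def check_password(self, user_id, password):
        # Called for m.login.password logins with a qualified user ID.
        # An async def returns an Awaitable; True accepts the login.
        return password == self.config["shared_secret"]
```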
diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md
index 13199000..7bfb96ef 100644
--- a/docs/reverse_proxy.md
+++ b/docs/reverse_proxy.md
@@ -38,6 +38,11 @@ the reverse proxy and the homeserver.
server {
listen 443 ssl;
listen [::]:443 ssl;
+
+ # For the federation port
+ listen 8448 ssl default_server;
+ listen [::]:8448 ssl default_server;
+
server_name matrix.example.com;
location /_matrix {
@@ -48,17 +53,6 @@ server {
client_max_body_size 10M;
}
}
-
-server {
- listen 8448 ssl default_server;
- listen [::]:8448 ssl default_server;
- server_name example.com;
-
- location / {
- proxy_pass http://localhost:8008;
- proxy_set_header X-Forwarded-For $remote_addr;
- }
-}
```
**NOTE**: Do not add a path after the port in `proxy_pass`, otherwise nginx will
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 164a1040..b21e36bb 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -102,7 +102,9 @@ pid_file: DATADIR/homeserver.pid
#gc_thresholds: [700, 10, 10]
# Set the limit on the returned events in the timeline in the get
-# and sync operations. The default value is -1, means no upper limit.
+# and sync operations. The default value is 100. -1 means no upper limit.
+#
+# Uncomment the following to increase the limit to 5000.
#
#filter_timeline_limit: 5000
@@ -118,38 +120,6 @@ pid_file: DATADIR/homeserver.pid
#
#enable_search: false
-# Restrict federation to the following whitelist of domains.
-# N.B. we recommend also firewalling your federation listener to limit
-# inbound federation traffic as early as possible, rather than relying
-# purely on this application-layer restriction. If not specified, the
-# default is to whitelist everything.
-#
-#federation_domain_whitelist:
-# - lon.example.com
-# - nyc.example.com
-# - syd.example.com
-
-# Prevent federation requests from being sent to the following
-# blacklist IP address CIDR ranges. If this option is not specified, or
-# specified with an empty list, no ip range blacklist will be enforced.
-#
-# As of Synapse v1.4.0 this option also affects any outbound requests to identity
-# servers provided by user input.
-#
-# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
-# listed here, since they correspond to unroutable addresses.)
-#
-federation_ip_range_blacklist:
- - '127.0.0.0/8'
- - '10.0.0.0/8'
- - '172.16.0.0/12'
- - '192.168.0.0/16'
- - '100.64.0.0/10'
- - '169.254.0.0/16'
- - '::1/128'
- - 'fe80::/64'
- - 'fc00::/7'
-
# List of ports that Synapse should listen on, their purpose and their
# configuration.
#
@@ -178,7 +148,7 @@ federation_ip_range_blacklist:
# names: a list of names of HTTP resources. See below for a list of
# valid resource names.
#
-# compress: set to true to enable HTTP comression for this resource.
+# compress: set to true to enable HTTP compression for this resource.
#
# additional_resources: Only valid for an 'http' listener. A map of
# additional endpoints which should be loaded via dynamic modules.
@@ -608,6 +578,39 @@ acme:
+# Restrict federation to the following whitelist of domains.
+# N.B. we recommend also firewalling your federation listener to limit
+# inbound federation traffic as early as possible, rather than relying
+# purely on this application-layer restriction. If not specified, the
+# default is to whitelist everything.
+#
+#federation_domain_whitelist:
+# - lon.example.com
+# - nyc.example.com
+# - syd.example.com
+
+# Prevent federation requests from being sent to the following
+# blacklist IP address CIDR ranges. If this option is not specified, or
+# specified with an empty list, no ip range blacklist will be enforced.
+#
+# As of Synapse v1.4.0 this option also affects any outbound requests to identity
+# servers provided by user input.
+#
+# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
+# listed here, since they correspond to unroutable addresses.)
+#
+federation_ip_range_blacklist:
+ - '127.0.0.0/8'
+ - '10.0.0.0/8'
+ - '172.16.0.0/12'
+ - '192.168.0.0/16'
+ - '100.64.0.0/10'
+ - '169.254.0.0/16'
+ - '::1/128'
+ - 'fe80::/64'
+ - 'fc00::/7'
+
+
## Caching ##
# Caching can be configured through the following options.
@@ -682,7 +685,7 @@ caches:
#database:
# name: psycopg2
# args:
-# user: synapse
+# user: synapse_user
# password: secretpassword
# database: synapse
# host: localhost
@@ -1811,6 +1814,9 @@ sso:
# Each JSON Web Token needs to contain a "sub" (subject) claim, which is
# used as the localpart of the mxid.
#
+# Additionally, the expiration time ("exp"), not before time ("nbf"),
+# and issued at ("iat") claims are validated if present.
+#
# Note that this is a non-standard login type and client support is
# expected to be non-existant.
#
@@ -1838,6 +1844,24 @@ sso:
#
#algorithm: "provided-by-your-issuer"
+ # The issuer to validate the "iss" claim against.
+ #
+ # Optional, if provided the "iss" claim will be required and
+ # validated for all JSON web tokens.
+ #
+ #issuer: "provided-by-your-issuer"
+
+ # A list of audiences to validate the "aud" claim against.
+ #
+ # Optional, if provided the "aud" claim will be required and
+ # validated for all JSON web tokens.
+ #
+ # Note that if the "aud" claim is included in a JSON web token then
+ # validation will fail without configuring audiences.
+ #
+ #audiences:
+ # - "provided-by-your-issuer"
+
password_config:
# Uncomment to disable password login
@@ -1927,8 +1951,8 @@ email:
#
#notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>"
- # app_name defines the default value for '%(app)s' in notif_from. It
- # defaults to 'Matrix'.
+ # app_name defines the default value for '%(app)s' in notif_from and email
+ # subjects. It defaults to 'Matrix'.
#
#app_name: my_branded_matrix_server
@@ -1997,6 +2021,73 @@ email:
#
#template_dir: "res/templates"
+ # Subjects to use when sending emails from Synapse.
+ #
+ # The placeholder '%(app)s' will be replaced with the value of the 'app_name'
+ # setting above, or by a value dictated by the Matrix client application.
+ #
+  # If a subject isn't overridden in this configuration file, the default value
+  # shown in its example will be used.
+ #
+ #subjects:
+
+ # Subjects for notification emails.
+ #
+ # On top of the '%(app)s' placeholder, these can use the following
+ # placeholders:
+ #
+ # * '%(person)s', which will be replaced by the display name of the user(s)
+ # that sent the message(s), e.g. "Alice and Bob".
+ # * '%(room)s', which will be replaced by the name of the room the
+ # message(s) have been sent to, e.g. "My super room".
+ #
+  # See the example provided for each setting to see which placeholders can be
+  # used and how to use them.
+ #
+ # Subject to use to notify about one message from one or more user(s) in a
+ # room which has a name.
+ #message_from_person_in_room: "[%(app)s] You have a message on %(app)s from %(person)s in the %(room)s room..."
+ #
+ # Subject to use to notify about one message from one or more user(s) in a
+ # room which doesn't have a name.
+ #message_from_person: "[%(app)s] You have a message on %(app)s from %(person)s..."
+ #
+ # Subject to use to notify about multiple messages from one or more users in
+ # a room which doesn't have a name.
+ #messages_from_person: "[%(app)s] You have messages on %(app)s from %(person)s..."
+ #
+ # Subject to use to notify about multiple messages in a room which has a
+ # name.
+ #messages_in_room: "[%(app)s] You have messages on %(app)s in the %(room)s room..."
+ #
+ # Subject to use to notify about multiple messages in multiple rooms.
+ #messages_in_room_and_others: "[%(app)s] You have messages on %(app)s in the %(room)s room and others..."
+ #
+ # Subject to use to notify about multiple messages from multiple persons in
+ # multiple rooms. This is similar to the setting above except it's used when
+ # the room in which the notification was triggered has no name.
+ #messages_from_person_and_others: "[%(app)s] You have messages on %(app)s from %(person)s and others..."
+ #
+ # Subject to use to notify about an invite to a room which has a name.
+ #invite_from_person_to_room: "[%(app)s] %(person)s has invited you to join the %(room)s room on %(app)s..."
+ #
+ # Subject to use to notify about an invite to a room which doesn't have a
+ # name.
+ #invite_from_person: "[%(app)s] %(person)s has invited you to chat on %(app)s..."
+
+  # Subjects for emails related to account administration.
+ #
+  # On top of the '%(app)s' placeholder, these can also use the
+ # '%(server_name)s' placeholder, which will be replaced by the value of the
+ # 'server_name' setting in your Synapse configuration.
+ #
+ # Subject to use when sending a password reset email.
+ #password_reset: "[%(server_name)s] Password reset"
+ #
+ # Subject to use when sending a verification email to assert an address's
+ # ownership.
+ #email_validation: "[%(server_name)s] Validate your email"
+
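The '%(name)s' syntax suggests standard Python %-interpolation; assuming that, a quick illustration of how a subject expands:

```python
subject = "[%(app)s] You have a message on %(app)s from %(person)s in the %(room)s room..."
print(subject % {"app": "Matrix", "person": "Alice", "room": "My super room"})
# [Matrix] You have a message on Matrix from Alice in the My super room room...
```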
# Password providers allow homeserver administrators to integrate
# their Synapse installation with existing authentication methods
@@ -2307,3 +2398,57 @@ opentracing:
#
# logging:
# false
+
+
+## Workers ##
+
+# Disables sending of outbound federation transactions on the main process.
+# Uncomment if using a federation sender worker.
+#
+#send_federation: false
+
+# It is possible to run multiple federation sender workers, in which case the
+# work is balanced across them.
+#
+# This configuration must be shared between all federation sender workers, and if
+# changed all federation sender workers must be stopped at the same time and then
+# started, to ensure that all instances are running with the same config (otherwise
+# events may be dropped).
+#
+#federation_sender_instances:
+# - federation_sender1
+
+# When using workers this should be a map from `worker_name` to the
+# HTTP replication listener of the worker, if configured.
+#
+#instance_map:
+# worker1:
+# host: localhost
+# port: 8034
+
+# Experimental: When using workers you can define which workers should
+# handle event persistence and typing notifications. Any worker
+# specified here must also be in the `instance_map`.
+#
+#stream_writers:
+# events: worker1
+# typing: worker1
+
+
+# Configuration for Redis when using workers. This *must* be enabled when
+# using workers (unless using old style direct TCP configuration).
+#
+redis:
+ # Uncomment the below to enable Redis support.
+ #
+ #enabled: true
+
+ # Optional host and port to use to connect to redis. Defaults to
+ # localhost and 6379
+ #
+ #host: localhost
+ #port: 6379
+
+ # Optional password if configured on the Redis instance
+ #
+ #password: <secret_password>
diff --git a/docs/synctl_workers.md b/docs/synctl_workers.md
new file mode 100644
index 00000000..8da4a318
--- /dev/null
+++ b/docs/synctl_workers.md
@@ -0,0 +1,32 @@
+### Using synctl with workers
+
+If you want to use `synctl` to manage your synapse processes, you will need to
+create an additional configuration file for the main synapse process. That
+configuration should look like this:
+
+```yaml
+worker_app: synapse.app.homeserver
+```
+
+Additionally, each worker app must be configured with the name of a "pid file",
+to which it will write its process ID when it starts. For example, for a
+synchrotron, you might write:
+
+```yaml
+worker_pid_file: /home/matrix/synapse/worker1.pid
+```
+
+Finally, to actually run your worker-based synapse, you must pass synctl the `-a`
+commandline option to tell it to operate on all the worker configurations found
+in the given directory, e.g.:
+
+ synctl -a $CONFIG/workers start
+
+Currently you should always restart all workers when restarting or upgrading
+synapse, unless you explicitly know it's safe not to. For instance, restarting
+synapse without restarting all the synchrotrons may result in broken typing
+notifications.
+
+To manipulate a specific worker, you pass the `-w` option to synctl:
+
+ synctl -w $CONFIG/workers/worker1.yaml restart
diff --git a/docs/workers.md b/docs/workers.md
index f4cbbc04..38bd758e 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -16,69 +16,106 @@ workers only work with PostgreSQL-based Synapse deployments. SQLite should only
be used for demo purposes and any admin considering workers should already be
running PostgreSQL.
-## Master/worker communication
+## Main process/worker communication
-The workers communicate with the master process via a Synapse-specific protocol
-called 'replication' (analogous to MySQL- or Postgres-style database
-replication) which feeds a stream of relevant data from the master to the
-workers so they can be kept in sync with the master process and database state.
+The processes communicate with each other via a Synapse-specific protocol called
+'replication' (analogous to MySQL- or Postgres-style database replication) which
+feeds streams of newly written data between processes so they can be kept in
+sync with the database state.
-Additionally, workers may make HTTP requests to the master, to send information
-in the other direction. Typically this is used for operations which need to
-wait for a reply - such as sending an event.
+Additionally, processes may make HTTP requests to each other. Typically this is
+used for operations which need to wait for a reply - such as sending an event.
-## Configuration
+As of Synapse v1.13.0, it is possible to configure Synapse to send replication
+via a [Redis pub/sub channel](https://redis.io/topics/pubsub); this is now the
+recommended way of configuring replication. It is an alternative to the old
+direct TCP connections to the main process: rather than all the workers
+connecting to the main process, all the workers and the main process connect to
+Redis, which relays replication commands between processes. This can give a
+significant cpu saving on the main process and will be a prerequisite for
+upcoming performance improvements.
+
+(See the [Architectural diagram](#architectural-diagram) section at the end for
+a visualisation of what this looks like.)
+
+
+## Setting up workers
+
+A Redis server is required to manage the communication between the processes.
+(The older direct TCP connections are now deprecated.) The Redis server
+should be installed following the normal procedure for your distribution (e.g.
+`apt install redis-server` on Debian). It is safe to use an existing Redis
+deployment if you have one.
+
+Once installed, check that Redis is running and accessible from the host running
+Synapse, for example by executing `echo PING | nc -q1 localhost 6379` and seeing
+a response of `+PONG`.
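The same check can be scripted; a small sketch using only the Python standard library, assuming the default host and port:

```python
import socket

# Equivalent to `echo PING | nc -q1 localhost 6379`; expect b"+PONG\r\n".
with socket.create_connection(("localhost", 6379), timeout=5) as sock:
    sock.sendall(b"PING\r\n")
    print(sock.recv(16))
```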
+
+The appropriate dependencies must also be installed for Synapse. If using a
+virtualenv, these can be installed with:
+
+```sh
+pip install matrix-synapse[redis]
+```
+
+Note that these dependencies are included when synapse is installed with `pip
+install matrix-synapse[all]`. They are also included in the debian packages from
+`matrix.org` and in the docker images at
+https://hub.docker.com/r/matrixdotorg/synapse/.
To make effective use of the workers, you will need to configure an HTTP
reverse-proxy such as nginx or haproxy, which will direct incoming requests to
-the correct worker, or to the main synapse instance. Note that this includes
-requests made to the federation port. See [reverse_proxy.md](reverse_proxy.md)
+the correct worker, or to the main synapse instance. See [reverse_proxy.md](reverse_proxy.md)
for information on setting up a reverse proxy.
-To enable workers, you need to add *two* replication listeners to the
-main Synapse configuration file (`homeserver.yaml`). For example:
+To enable workers you should create a configuration file for each worker
+process. Each worker configuration file inherits the configuration of the shared
+homeserver configuration file. You can then override configuration specific to
+that worker, e.g. the HTTP listener that it provides (if any); logging
+configuration; etc. You should minimise the number of overrides though to
+maintain a usable config.
+
+Next you need to add both an HTTP replication listener and Redis config to the
+shared Synapse configuration file (`homeserver.yaml`). For example:
```yaml
+# extend the existing `listeners` section. This defines the ports that the
+# main process will listen on.
listeners:
- # The TCP replication port
- - port: 9092
- bind_address: '127.0.0.1'
- type: replication
-
# The HTTP replication port
- port: 9093
bind_address: '127.0.0.1'
type: http
resources:
- names: [replication]
+
+redis:
+ enabled: true
```
-Under **no circumstances** should these replication API listeners be exposed to
-the public internet; they have no authentication and are unencrypted.
+See the sample config for the full documentation of each option.
-You should then create a set of configs for the various worker processes. Each
-worker configuration file inherits the configuration of the main homeserver
-configuration file. You can then override configuration specific to that
-worker, e.g. the HTTP listener that it provides (if any); logging
-configuration; etc. You should minimise the number of overrides though to
-maintain a usable config.
+Under **no circumstances** should the replication listener be exposed to the
+public internet; it has no authentication and is unencrypted.
In the config file for each worker, you must specify the type of worker
-application (`worker_app`). The currently available worker applications are
-listed below. You must also specify the replication endpoints that it should
-talk to on the main synapse process. `worker_replication_host` should specify
-the host of the main synapse, `worker_replication_port` should point to the TCP
-replication listener port and `worker_replication_http_port` should point to
-the HTTP replication port.
+application (`worker_app`), and you should specify a unique name for the worker
+(`worker_name`). The currently available worker applications are listed below.
+You must also specify the HTTP replication endpoint that it should talk to on
+the main synapse process. `worker_replication_host` should specify the host of
+the main synapse and `worker_replication_http_port` should point to the HTTP
+replication port. If the worker will handle HTTP requests then the
+`worker_listeners` option should be set with an `http` listener, in the same way
+as the `listeners` option in the shared config.
For example:
```yaml
-worker_app: synapse.app.synchrotron
+worker_app: synapse.app.generic_worker
+worker_name: worker1
-# The replication listener on the synapse to talk to.
+# The replication listener on the main synapse process.
worker_replication_host: 127.0.0.1
-worker_replication_port: 9092
worker_replication_http_port: 9093
worker_listeners:
@@ -87,13 +124,14 @@ worker_listeners:
resources:
- names:
- client
+ - federation
-worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml
+worker_log_config: /home/matrix/synapse/config/worker1_log_config.yaml
```
-...is a full configuration for a synchrotron worker instance, which will expose a
-plain HTTP `/sync` endpoint on port 8083 separately from the `/sync` endpoint provided
-by the main synapse.
+...is a full configuration for a generic worker instance, which will expose a
+plain HTTP endpoint on port 8083, separately serving various endpoints (such as
+`/sync`) which are listed below.
Obviously you should configure your reverse-proxy to route the relevant
endpoints to the worker (`localhost:8083` in the above example).
@@ -102,127 +140,24 @@ Finally, you need to start your worker processes. This can be done with either
`synctl` or your distribution's preferred service manager such as `systemd`. We
recommend the use of `systemd` where available: for information on setting up
`systemd` to start synapse workers, see
-[systemd-with-workers](systemd-with-workers). To use `synctl`, see below.
+[systemd-with-workers](systemd-with-workers). To use `synctl`, see
+[synctl_workers.md](synctl_workers.md).
-### **Experimental** support for replication over redis
-
-As of Synapse v1.13.0, it is possible to configure Synapse to send replication
-via a [Redis pub/sub channel](https://redis.io/topics/pubsub). This is an
-alternative to direct TCP connections to the master: rather than all the
-workers connecting to the master, all the workers and the master connect to
-Redis, which relays replication commands between processes. This can give a
-significant cpu saving on the master and will be a prerequisite for upcoming
-performance improvements.
-
-Note that this support is currently experimental; you may experience lost
-messages and similar problems! It is strongly recommended that admins setting
-up workers for the first time use direct TCP replication as above.
-
-To configure Synapse to use Redis:
-
-1. Install Redis following the normal procedure for your distribution - for
- example, on Debian, `apt install redis-server`. (It is safe to use an
- existing Redis deployment if you have one: we use a pub/sub stream named
- according to the `server_name` of your synapse server.)
-2. Check Redis is running and accessible: you should be able to `echo PING | nc -q1
- localhost 6379` and get a response of `+PONG`.
-3. Install the python prerequisites. If you installed synapse into a
- virtualenv, this can be done with:
- ```sh
- pip install matrix-synapse[redis]
- ```
- The debian packages from matrix.org already include the required
- dependencies.
-4. Add config to the shared configuration (`homeserver.yaml`):
- ```yaml
- redis:
- enabled: true
- ```
- Optional parameters which can go alongside `enabled` are `host`, `port`,
- `password`. Normally none of these are required.
-5. Restart master and all workers.
-
-Once redis replication is in use, `worker_replication_port` is redundant and
-can be removed from the worker configuration files. Similarly, the
-configuration for the `listener` for the TCP replication port can be removed
-from the main configuration file. Note that the HTTP replication port is
-still required.
-
-### Using synctl
-
-If you want to use `synctl` to manage your synapse processes, you will need to
-create an an additional configuration file for the master synapse process. That
-configuration should look like this:
-
-```yaml
-worker_app: synapse.app.homeserver
-```
-
-Additionally, each worker app must be configured with the name of a "pid file",
-to which it will write its process ID when it starts. For example, for a
-synchrotron, you might write:
-
-```yaml
-worker_pid_file: /home/matrix/synapse/synchrotron.pid
-```
-
-Finally, to actually run your worker-based synapse, you must pass synctl the `-a`
-commandline option to tell it to operate on all the worker configurations found
-in the given directory, e.g.:
-
- synctl -a $CONFIG/workers start
-
-Currently one should always restart all workers when restarting or upgrading
-synapse, unless you explicitly know it's safe not to. For instance, restarting
-synapse without restarting all the synchrotrons may result in broken typing
-notifications.
-
-To manipulate a specific worker, you pass the -w option to synctl:
-
- synctl -w $CONFIG/workers/synchrotron.yaml restart
## Available worker applications
-### `synapse.app.pusher`
-
-Handles sending push notifications to sygnal and email. Doesn't handle any
-REST endpoints itself, but you should set `start_pushers: False` in the
-shared configuration file to stop the main synapse sending these notifications.
-
-Note this worker cannot be load-balanced: only one instance should be active.
-
-### `synapse.app.synchrotron`
+### `synapse.app.generic_worker`
-The synchrotron handles `sync` requests from clients. In particular, it can
-handle REST endpoints matching the following regular expressions:
+This worker can handle API requests matching the following regular
+expressions:
+ # Sync requests
^/_matrix/client/(v2_alpha|r0)/sync$
^/_matrix/client/(api/v1|v2_alpha|r0)/events$
^/_matrix/client/(api/v1|r0)/initialSync$
^/_matrix/client/(api/v1|r0)/rooms/[^/]+/initialSync$
-The above endpoints should all be routed to the synchrotron worker by the
-reverse-proxy configuration.
-
-It is possible to run multiple instances of the synchrotron to scale
-horizontally. In this case the reverse-proxy should be configured to
-load-balance across the instances, though it will be more efficient if all
-requests from a particular user are routed to a single instance. Extracting
-a userid from the access token is currently left as an exercise for the reader.
-
-### `synapse.app.appservice`
-
-Handles sending output traffic to Application Services. Doesn't handle any
-REST endpoints itself, but you should set `notify_appservices: False` in the
-shared configuration file to stop the main synapse sending these notifications.
-
-Note this worker cannot be load-balanced: only one instance should be active.
-
-### `synapse.app.federation_reader`
-
-Handles a subset of federation endpoints. In particular, it can handle REST
-endpoints matching the following regular expressions:
-
+ # Federation requests
^/_matrix/federation/v1/event/
^/_matrix/federation/v1/state/
^/_matrix/federation/v1/state_ids/
@@ -242,40 +177,145 @@ endpoints matching the following regular expressions:
^/_matrix/federation/v1/event_auth/
^/_matrix/federation/v1/exchange_third_party_invite/
^/_matrix/federation/v1/user/devices/
- ^/_matrix/federation/v1/send/
^/_matrix/federation/v1/get_groups_publicised$
^/_matrix/key/v2/query
+ # Inbound federation transaction request
+ ^/_matrix/federation/v1/send/
+
+ # Client API requests
+ ^/_matrix/client/(api/v1|r0|unstable)/publicRooms$
+ ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/joined_members$
+ ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/context/.*$
+ ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/members$
+ ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state$
+ ^/_matrix/client/(api/v1|r0|unstable)/account/3pid$
+ ^/_matrix/client/(api/v1|r0|unstable)/keys/query$
+ ^/_matrix/client/(api/v1|r0|unstable)/keys/changes$
+ ^/_matrix/client/versions$
+ ^/_matrix/client/(api/v1|r0|unstable)/voip/turnServer$
+ ^/_matrix/client/(api/v1|r0|unstable)/joined_groups$
+ ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups$
+ ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups/
+
+ # Registration/login requests
+ ^/_matrix/client/(api/v1|r0|unstable)/login$
+ ^/_matrix/client/(r0|unstable)/register$
+ ^/_matrix/client/(r0|unstable)/auth/.*/fallback/web$
+
+ # Event sending requests
+ ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send
+ ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state/
+ ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$
+ ^/_matrix/client/(api/v1|r0|unstable)/join/
+ ^/_matrix/client/(api/v1|r0|unstable)/profile/
+
+
Additionally, the following REST endpoints can be handled for GET requests:
^/_matrix/federation/v1/groups/
-The above endpoints should all be routed to the federation_reader worker by the
-reverse-proxy configuration.
+Pagination requests can also be handled, but all requests for a given
+room must be routed to the same instance. Additionally, care must be taken to
+ensure that the purge history admin API is not used while pagination requests
+for the room are in flight:
+
+ ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/messages$
+
+Note that an HTTP listener with `client` and `federation` resources must be
+configured in the `worker_listeners` option in the worker config.
+
+
+#### Load balancing
+
+It is possible to run multiple instances of this worker app, with incoming requests
+being load-balanced between them by the reverse-proxy. However, different endpoints
+have different characteristics, so admins may wish to run several groups of
+workers, each handling a subset of endpoints, so that each group can be
+load-balanced in the way that suits it best.
+
+For `/sync` and `/initialSync` requests it will be more efficient if all
+requests from a particular user are routed to a single instance. Extracting a
+user ID from the access token or `Authorization` header is currently left as an
+exercise for the reader. Admins may additionally wish to separate out `/sync`
+requests that have a `since` query parameter from those that don't (and
+`/initialSync`): requests without `since` are "initial syncs", which happen when
+a user logs in on a new device and can be *very* resource intensive. Isolating
+these requests will stop them from interfering with other users' ongoing syncs.
+
+Federation and client requests can be balanced via simple round robin.
-The `^/_matrix/federation/v1/send/` endpoint must only be handled by a single
-instance.
+The inbound federation transaction request `^/_matrix/federation/v1/send/`
+should be balanced by source IP so that transactions from the same remote server
+go to the same process.
-Note that `federation` must be added to the listener resources in the worker config:
+Registration/login requests can be handled separately purely to help ensure that
+unexpected load doesn't affect new logins and sign-ups.
+
+Finally, event sending requests can be balanced by the room ID in the URI (or
+the full URI, or even just round robin); the room ID is the path component after
+`/rooms/` (see the sketch below). If there is a large bridge connected that is
+sending or may send lots of events, then a dedicated set of workers can be
+provisioned to limit the effects of bursts of events from that bridge on events
+sent by normal users.
+
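As a sketch of that last approach, a balancing decision could be keyed on the room ID path component; the worker pool names here are purely illustrative:

```python
import hashlib
import re

WORKERS = ["event_writer1", "event_writer2"]  # illustrative worker pool

def pick_worker(path: str) -> str:
    # The room ID is the path component after /rooms/ (it may be URL-encoded).
    m = re.search(r"/rooms/([^/]+)", path)
    key = m.group(1) if m else path  # fall back to hashing the full URI
    digest = hashlib.sha256(key.encode()).digest()
    return WORKERS[digest[0] % len(WORKERS)]

print(pick_worker("/_matrix/client/r0/rooms/!abc:example.com/send/m.room.message/1"))
```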
+#### Stream writers
+
+Additionally, there is *experimental* support for moving writing of specific
+streams (such as events) off the main process to a particular worker. (This
+is only supported with Redis-based replication.)
+
+The currently supported streams are `events` and `typing`.
+
+To enable this, the worker must have an HTTP replication listener configured,
+have a `worker_name` and be listed in the `instance_map` config. For example, to
+move event persistence off to a dedicated worker, the shared configuration would
+include:
```yaml
-worker_app: synapse.app.federation_reader
-...
-worker_listeners:
- - type: http
- port: <port>
- resources:
- - names:
- - federation
+instance_map:
+ event_persister1:
+ host: localhost
+ port: 8034
+
+stream_writers:
+ events: event_persister1
```
+
+### `synapse.app.pusher`
+
+Handles sending push notifications to sygnal and email. Doesn't handle any
+REST endpoints itself, but you should set `start_pushers: False` in the
+shared configuration file to stop the main synapse sending push notifications.
+
+Note this worker cannot be load-balanced: only one instance should be active.
+
+### `synapse.app.appservice`
+
+Handles sending output traffic to Application Services. Doesn't handle any
+REST endpoints itself, but you should set `notify_appservices: False` in the
+shared configuration file to stop the main synapse sending appservice notifications.
+
+Note this worker cannot be load-balanced: only one instance should be active.
+
+
### `synapse.app.federation_sender`
Handles sending federation traffic to other servers. Doesn't handle any
REST endpoints itself, but you should set `send_federation: False` in the
shared configuration file to stop the main synapse sending this traffic.
-Note this worker cannot be load-balanced: only one instance should be active.
+If running multiple federation senders then you must list each
+instance in the `federation_sender_instances` option by its `worker_name`.
+All instances must be stopped and started when adding or removing instances.
+For example:
+
+```yaml
+federation_sender_instances:
+ - federation_sender1
+ - federation_sender2
+```
### `synapse.app.media_repository`
@@ -314,46 +354,6 @@ and you must configure a single instance to run the background tasks, e.g.:
media_instance_running_background_jobs: "media-repository-1"
```
-### `synapse.app.client_reader`
-
-Handles client API endpoints. It can handle REST endpoints matching the
-following regular expressions:
-
- ^/_matrix/client/(api/v1|r0|unstable)/publicRooms$
- ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/joined_members$
- ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/context/.*$
- ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/members$
- ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state$
- ^/_matrix/client/(api/v1|r0|unstable)/login$
- ^/_matrix/client/(api/v1|r0|unstable)/account/3pid$
- ^/_matrix/client/(api/v1|r0|unstable)/keys/query$
- ^/_matrix/client/(api/v1|r0|unstable)/keys/changes$
- ^/_matrix/client/versions$
- ^/_matrix/client/(api/v1|r0|unstable)/voip/turnServer$
- ^/_matrix/client/(api/v1|r0|unstable)/joined_groups$
- ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups$
- ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups/
-
-Additionally, the following REST endpoints can be handled for GET requests:
-
- ^/_matrix/client/(api/v1|r0|unstable)/pushrules/.*$
- ^/_matrix/client/(api/v1|r0|unstable)/groups/.*$
- ^/_matrix/client/(api/v1|r0|unstable)/user/[^/]*/account_data/
- ^/_matrix/client/(api/v1|r0|unstable)/user/[^/]*/rooms/[^/]*/account_data/
-
-Additionally, the following REST endpoints can be handled, but all requests must
-be routed to the same instance:
-
- ^/_matrix/client/(r0|unstable)/register$
- ^/_matrix/client/(r0|unstable)/auth/.*/fallback/web$
-
-Pagination requests can also be handled, but all requests with the same path
-room must be routed to the same instance. Additionally, care must be taken to
-ensure that the purge history admin API is not used while pagination requests
-for the room are in flight:
-
- ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/messages$
-
### `synapse.app.user_dir`
Handles searches in the user directory. It can handle REST endpoints matching
@@ -388,15 +388,48 @@ file. For example:
worker_main_http_uri: http://127.0.0.1:8008
-### `synapse.app.event_creator`
+### Historical apps
-Handles some event creation. It can handle REST endpoints matching:
+*Note:* Historically there used to be more apps; however, they have been
+amalgamated into a single `synapse.app.generic_worker` app. The remaining apps
+are ones that do specific processing unrelated to requests, e.g. the `pusher`
+that handles sending out push notifications for new events. The intention is for
+all these to be folded into the `generic_worker` app and to use config to define
+which processes handle the various tasks, such as push notifications.
- ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send
- ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state/
- ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$
- ^/_matrix/client/(api/v1|r0|unstable)/join/
- ^/_matrix/client/(api/v1|r0|unstable)/profile/
-It will create events locally and then send them on to the main synapse
-instance to be persisted and handled.
+## Architectural diagram
+
+The following shows an example setup using Redis and a reverse proxy:
+
+```
+ Clients & Federation
+ |
+ v
+ +-----------+
+ | |
+ | Reverse |
+ | Proxy |
+ | |
+ +-----------+
+ | | |
+ | | | HTTP requests
+ +-------------------+ | +-----------+
+ | +---+ |
+ | | |
+ v v v
++--------------+ +--------------+ +--------------+ +--------------+
+| Main | | Generic | | Generic | | Event |
+| Process | | Worker 1 | | Worker 2 | | Persister |
++--------------+ +--------------+ +--------------+ +--------------+
+ ^ ^ | ^ | | ^ | ^ ^
+ | | | | | | | | | |
+ | | | | | HTTP | | | | |
+ | +----------+<--|---|---------+ | | | |
+ | | +-------------|-->+----------+ |
+ | | | |
+ | | | |
+ v v v v
+====================================================================
+ Redis pub/sub channel
+```
diff --git a/scripts-dev/build_debian_packages b/scripts-dev/build_debian_packages
index e6f4bd1d..d055cf32 100755
--- a/scripts-dev/build_debian_packages
+++ b/scripts-dev/build_debian_packages
@@ -24,7 +24,6 @@ DISTS = (
"debian:sid",
"ubuntu:xenial",
"ubuntu:bionic",
- "ubuntu:eoan",
"ubuntu:focal",
)
diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh
index 66b05688..06479936 100755
--- a/scripts-dev/lint.sh
+++ b/scripts-dev/lint.sh
@@ -11,7 +11,7 @@ if [ $# -ge 1 ]
then
files=$*
else
- files="synapse tests scripts-dev scripts"
+ files="synapse tests scripts-dev scripts contrib synctl"
fi
echo "Linting these locations: $files"
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 2eb79519..22a6abd7 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -48,6 +48,7 @@ from synapse.storage.data_stores.main.media_repository import (
)
from synapse.storage.data_stores.main.registration import (
RegistrationBackgroundUpdateStore,
+ find_max_generated_user_id_localpart,
)
from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore
from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore
@@ -622,8 +623,10 @@ class Porter(object):
)
)
- # Step 5. Do final post-processing
+ # Step 5. Set up sequences
+ self.progress.set_state("Setting up sequence generators")
await self._setup_state_group_id_seq()
+ await self._setup_user_id_seq()
self.progress.done()
except Exception as e:
@@ -793,6 +796,13 @@ class Porter(object):
return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r)
+ def _setup_user_id_seq(self):
+ def r(txn):
+ next_id = find_max_generated_user_id_localpart(txn) + 1
+ txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))
+
+ return self.postgres_store.db.runInteraction("setup_user_id_seq", r)
+
##############################################
# The following is simply UI stuff
diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi
index cac689d4..c66413f0 100644
--- a/stubs/txredisapi.pyi
+++ b/stubs/txredisapi.pyi
@@ -22,6 +22,7 @@ class RedisProtocol:
def publish(self, channel: str, message: bytes): ...
class SubscriberProtocol:
+ def __init__(self, *args, **kwargs): ...
password: Optional[str]
def subscribe(self, channels: Union[str, List[str]]): ...
def connectionMade(self): ...
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 8592dee1..5155e719 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -36,7 +36,7 @@ try:
except ImportError:
pass
-__version__ = "1.17.0"
+__version__ = "1.18.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 40dc62ef..b53e8451 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -127,8 +127,10 @@ class Auth(object):
if current_state:
member = current_state.get((EventTypes.Member, user_id), None)
else:
- member = yield self.state.get_current_state(
- room_id=room_id, event_type=EventTypes.Member, state_key=user_id
+ member = yield defer.ensureDeferred(
+ self.state.get_current_state(
+ room_id=room_id, event_type=EventTypes.Member, state_key=user_id
+ )
)
membership = member.membership if member else None
@@ -665,8 +667,10 @@ class Auth(object):
)
return member_event.membership, member_event.event_id
except AuthError:
- visibility = yield self.state.get_current_state(
- room_id, EventTypes.RoomHistoryVisibility, ""
+ visibility = yield defer.ensureDeferred(
+ self.state.get_current_state(
+ room_id, EventTypes.RoomHistoryVisibility, ""
+ )
)
if (
visibility
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 5305038c..b3bab1aa 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -17,13 +17,17 @@
"""Contains exceptions and error codes."""
import logging
+import typing
from http import HTTPStatus
-from typing import Dict, List
+from typing import Dict, List, Optional, Union
from canonicaljson import json
from twisted.web import http
+if typing.TYPE_CHECKING:
+ from synapse.types import JsonDict
+
logger = logging.getLogger(__name__)
@@ -78,11 +82,11 @@ class CodeMessageException(RuntimeError):
"""An exception with integer code and message string attributes.
Attributes:
- code (int): HTTP error code
- msg (str): string describing the error
+ code: HTTP error code
+ msg: string describing the error
"""
- def __init__(self, code, msg):
+ def __init__(self, code: Union[int, HTTPStatus], msg: str):
super(CodeMessageException, self).__init__("%d: %s" % (code, msg))
# Some calls to this method pass instances of http.HTTPStatus for `code`.
@@ -123,16 +127,16 @@ class SynapseError(CodeMessageException):
message (as well as an HTTP status code).
Attributes:
- errcode (str): Matrix error code e.g 'M_FORBIDDEN'
+ errcode: Matrix error code e.g 'M_FORBIDDEN'
"""
- def __init__(self, code, msg, errcode=Codes.UNKNOWN):
+ def __init__(self, code: int, msg: str, errcode: str = Codes.UNKNOWN):
"""Constructs a synapse error.
Args:
- code (int): The integer error code (an HTTP response code)
- msg (str): The human-readable error message.
- errcode (str): The matrix error code e.g 'M_FORBIDDEN'
+ code: The integer error code (an HTTP response code)
+ msg: The human-readable error message.
+ errcode: The matrix error code e.g 'M_FORBIDDEN'
"""
super(SynapseError, self).__init__(code, msg)
self.errcode = errcode
@@ -145,10 +149,16 @@ class ProxiedRequestError(SynapseError):
"""An error from a general matrix endpoint, eg. from a proxied Matrix API call.
Attributes:
- errcode (str): Matrix error code e.g 'M_FORBIDDEN'
+ errcode: Matrix error code e.g 'M_FORBIDDEN'
"""
- def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None):
+ def __init__(
+ self,
+ code: int,
+ msg: str,
+ errcode: str = Codes.UNKNOWN,
+ additional_fields: Optional[Dict] = None,
+ ):
super(ProxiedRequestError, self).__init__(code, msg, errcode)
if additional_fields is None:
self._additional_fields = {} # type: Dict
@@ -164,12 +174,12 @@ class ConsentNotGivenError(SynapseError):
privacy policy.
"""
- def __init__(self, msg, consent_uri):
+ def __init__(self, msg: str, consent_uri: str):
"""Constructs a ConsentNotGivenError
Args:
- msg (str): The human-readable error message
- consent_url (str): The URL where the user can give their consent
+ msg: The human-readable error message
+            consent_uri: The URL where the user can give their consent
"""
super(ConsentNotGivenError, self).__init__(
code=HTTPStatus.FORBIDDEN, msg=msg, errcode=Codes.CONSENT_NOT_GIVEN
@@ -185,11 +195,11 @@ class UserDeactivatedError(SynapseError):
authenticated endpoint, but the account has been deactivated.
"""
- def __init__(self, msg):
+ def __init__(self, msg: str):
"""Constructs a UserDeactivatedError
Args:
- msg (str): The human-readable error message
+ msg: The human-readable error message
"""
super(UserDeactivatedError, self).__init__(
code=HTTPStatus.FORBIDDEN, msg=msg, errcode=Codes.USER_DEACTIVATED
@@ -201,16 +211,16 @@ class FederationDeniedError(SynapseError):
is not on its federation whitelist.
Attributes:
- destination (str): The destination which has been denied
+ destination: The destination which has been denied
"""
- def __init__(self, destination):
+ def __init__(self, destination: Optional[str]):
"""Raised by federation client or server to indicate that we are
are deliberately not attempting to contact a given server because it is
not on our federation whitelist.
Args:
- destination (str): the domain in question
+ destination: the domain in question
"""
self.destination = destination
@@ -228,11 +238,11 @@ class InteractiveAuthIncompleteError(Exception):
(This indicates we should return a 401 with 'result' as the body)
Attributes:
- result (dict): the server response to the request, which should be
+ result: the server response to the request, which should be
passed back to the client
"""
- def __init__(self, result):
+ def __init__(self, result: "JsonDict"):
super(InteractiveAuthIncompleteError, self).__init__(
"Interactive auth not yet complete"
)
@@ -245,7 +255,6 @@ class UnrecognizedRequestError(SynapseError):
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.UNRECOGNIZED
- message = None
if len(args) == 0:
message = "Unrecognized request"
else:
@@ -256,7 +265,7 @@ class UnrecognizedRequestError(SynapseError):
class NotFoundError(SynapseError):
"""An error indicating we can't find the thing you asked for"""
- def __init__(self, msg="Not found", errcode=Codes.NOT_FOUND):
+ def __init__(self, msg: str = "Not found", errcode: str = Codes.NOT_FOUND):
super(NotFoundError, self).__init__(404, msg, errcode=errcode)
@@ -282,21 +291,23 @@ class InvalidClientCredentialsError(SynapseError):
M_UNKNOWN_TOKEN respectively.
"""
- def __init__(self, msg, errcode):
+ def __init__(self, msg: str, errcode: str):
super().__init__(code=401, msg=msg, errcode=errcode)
class MissingClientTokenError(InvalidClientCredentialsError):
"""Raised when we couldn't find the access token in a request"""
- def __init__(self, msg="Missing access token"):
+ def __init__(self, msg: str = "Missing access token"):
super().__init__(msg=msg, errcode="M_MISSING_TOKEN")
class InvalidClientTokenError(InvalidClientCredentialsError):
"""Raised when we didn't understand the access token in a request"""
- def __init__(self, msg="Unrecognised access token", soft_logout=False):
+ def __init__(
+ self, msg: str = "Unrecognised access token", soft_logout: bool = False
+ ):
super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN")
self._soft_logout = soft_logout
@@ -314,11 +325,11 @@ class ResourceLimitError(SynapseError):
def __init__(
self,
- code,
- msg,
- errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
- admin_contact=None,
- limit_type=None,
+ code: int,
+ msg: str,
+ errcode: str = Codes.RESOURCE_LIMIT_EXCEEDED,
+ admin_contact: Optional[str] = None,
+ limit_type: Optional[str] = None,
):
self.admin_contact = admin_contact
self.limit_type = limit_type
@@ -366,10 +377,10 @@ class StoreError(SynapseError):
class InvalidCaptchaError(SynapseError):
def __init__(
self,
- code=400,
- msg="Invalid captcha.",
- error_url=None,
- errcode=Codes.CAPTCHA_INVALID,
+ code: int = 400,
+ msg: str = "Invalid captcha.",
+ error_url: Optional[str] = None,
+ errcode: str = Codes.CAPTCHA_INVALID,
):
super(InvalidCaptchaError, self).__init__(code, msg, errcode)
self.error_url = error_url
@@ -384,10 +395,10 @@ class LimitExceededError(SynapseError):
def __init__(
self,
- code=429,
- msg="Too Many Requests",
- retry_after_ms=None,
- errcode=Codes.LIMIT_EXCEEDED,
+ code: int = 429,
+ msg: str = "Too Many Requests",
+ retry_after_ms: Optional[int] = None,
+ errcode: str = Codes.LIMIT_EXCEEDED,
):
super(LimitExceededError, self).__init__(code, msg, errcode)
self.retry_after_ms = retry_after_ms
@@ -400,10 +411,10 @@ class RoomKeysVersionError(SynapseError):
"""A client has tried to upload to a non-current version of the room_keys store
"""
- def __init__(self, current_version):
+ def __init__(self, current_version: str):
"""
Args:
- current_version (str): the current version of the store they should have used
+ current_version: the current version of the store they should have used
"""
super(RoomKeysVersionError, self).__init__(
403, "Wrong room_keys version", Codes.WRONG_ROOM_KEYS_VERSION
@@ -415,7 +426,7 @@ class UnsupportedRoomVersionError(SynapseError):
"""The client's request to create a room used a room version that the server does
not support."""
- def __init__(self, msg="Homeserver does not support this room version"):
+ def __init__(self, msg: str = "Homeserver does not support this room version"):
super(UnsupportedRoomVersionError, self).__init__(
code=400, msg=msg, errcode=Codes.UNSUPPORTED_ROOM_VERSION,
)
@@ -437,7 +448,7 @@ class IncompatibleRoomVersionError(SynapseError):
failing.
"""
- def __init__(self, room_version):
+ def __init__(self, room_version: str):
super(IncompatibleRoomVersionError, self).__init__(
code=400,
msg="Your homeserver does not support the features required to "
@@ -457,8 +468,8 @@ class PasswordRefusedError(SynapseError):
def __init__(
self,
- msg="This password doesn't comply with the server's policy",
- errcode=Codes.WEAK_PASSWORD,
+ msg: str = "This password doesn't comply with the server's policy",
+ errcode: str = Codes.WEAK_PASSWORD,
):
super(PasswordRefusedError, self).__init__(
code=400, msg=msg, errcode=errcode,
@@ -483,14 +494,14 @@ class RequestSendFailed(RuntimeError):
self.can_retry = can_retry
-def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
+def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs):
""" Utility method for constructing an error response for client-server
interactions.
Args:
- msg (str): The error message.
- code (str): The error code.
- kwargs : Additional keys to add to the response.
+ msg: The error message.
+ code: The error code.
+ kwargs: Additional keys to add to the response.
Returns:
A dict representing the error response JSON.
"""
@@ -512,7 +523,14 @@ class FederationError(RuntimeError):
is wrong (e.g., it referred to an invalid event)
"""
- def __init__(self, level, code, reason, affected, source=None):
+ def __init__(
+ self,
+ level: str,
+ code: int,
+ reason: str,
+ affected: str,
+ source: Optional[str] = None,
+ ):
if level not in ["FATAL", "ERROR", "WARN"]:
raise ValueError("Level is not valid: %s" % (level,))
self.level = level
@@ -539,16 +557,16 @@ class HttpResponseException(CodeMessageException):
Represents an HTTP-level failure of an outbound request
Attributes:
- response (bytes): body of response
+ response: body of response
"""
- def __init__(self, code, msg, response):
+ def __init__(self, code: int, msg: str, response: bytes):
"""
Args:
- code (int): HTTP status code
- msg (str): reason phrase from HTTP response status line
- response (bytes): body of response
+ code: HTTP status code
+ msg: reason phrase from HTTP response status line
+ response: body of response
"""
super(HttpResponseException, self).__init__(code, msg)
self.response = response
@@ -573,7 +591,7 @@ class HttpResponseException(CodeMessageException):
# try to parse the body as json, to get better errcode/msg, but
# default to M_UNKNOWN with the HTTP status as the error text
try:
- j = json.loads(self.response)
+ j = json.loads(self.response.decode("utf-8"))
except ValueError:
j = {}
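The last hunk decodes the response body before handing it to the JSON parser, so the fallback no longer depends on the decoder accepting raw bytes. A minimal standalone sketch of the same degrade-to-M_UNKNOWN behaviour, using the stdlib json module; the function name and sample bodies are illustrative, not part of the change:

```python
import json

def parse_error_body(body: bytes) -> dict:
    # Mirrors the fallback above: pull errcode/error out of the response
    # body, defaulting to M_UNKNOWN when it isn't valid JSON.
    try:
        j = json.loads(body.decode("utf-8"))
    except (ValueError, UnicodeDecodeError):
        j = {}
    if not isinstance(j, dict):
        j = {}
    return {
        "errcode": j.get("errcode", "M_UNKNOWN"),
        "error": j.get("error", "Unknown error"),
    }

print(parse_error_body(b'{"errcode": "M_FORBIDDEN", "error": "denied"}'))
print(parse_error_body(b"<html>gateway timeout</html>"))
```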
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index f6792d9f..5841454c 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -21,7 +21,7 @@ from typing import Dict, Iterable, Optional, Set
from typing_extensions import ContextManager
-from twisted.internet import address, defer, reactor
+from twisted.internet import address, reactor
import synapse
import synapse.events
@@ -87,7 +87,6 @@ from synapse.replication.tcp.streams import (
ReceiptsStream,
TagAccountDataStream,
ToDeviceStream,
- TypingStream,
)
from synapse.rest.admin import register_servlets_for_media_repo
from synapse.rest.client.v1 import events
@@ -111,6 +110,7 @@ from synapse.rest.client.v1.room import (
RoomSendEventRestServlet,
RoomStateEventRestServlet,
RoomStateRestServlet,
+ RoomTypingRestServlet,
)
from synapse.rest.client.v1.voip import VoipRestServlet
from synapse.rest.client.v2_alpha import groups, sync, user_directory
@@ -374,9 +374,8 @@ class GenericWorkerPresence(BasePresenceHandler):
return _user_syncing()
- @defer.inlineCallbacks
- def notify_from_replication(self, states, stream_id):
- parties = yield get_interested_parties(self.store, states)
+ async def notify_from_replication(self, states, stream_id):
+ parties = await get_interested_parties(self.store, states)
room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
@@ -386,8 +385,7 @@ class GenericWorkerPresence(BasePresenceHandler):
users=users_to_states.keys(),
)
- @defer.inlineCallbacks
- def process_replication_rows(self, token, rows):
+ async def process_replication_rows(self, token, rows):
states = [
UserPresenceState(
row.user_id,
@@ -405,7 +403,7 @@ class GenericWorkerPresence(BasePresenceHandler):
self.user_to_current_state[state.user_id] = state
stream_id = token
- yield self.notify_from_replication(states, stream_id)
+ await self.notify_from_replication(states, stream_id)
def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
return [
@@ -451,37 +449,6 @@ class GenericWorkerPresence(BasePresenceHandler):
await self._bump_active_client(user_id=user_id)
-class GenericWorkerTyping(object):
- def __init__(self, hs):
- self._latest_room_serial = 0
- self._reset()
-
- def _reset(self):
- """
- Reset the typing handler's data caches.
- """
- # map room IDs to serial numbers
- self._room_serials = {}
- # map room IDs to sets of users currently typing
- self._room_typing = {}
-
- def process_replication_rows(self, token, rows):
- if self._latest_room_serial > token:
- # The master has gone backwards. To prevent inconsistent data, just
- # clear everything.
- self._reset()
-
- # Set the latest serial token to whatever the server gave us.
- self._latest_room_serial = token
-
- for row in rows:
- self._room_serials[row.room_id] = token
- self._room_typing[row.room_id] = row.user_ids
-
- def get_current_token(self) -> int:
- return self._latest_room_serial
-
-
class GenericWorkerSlavedStore(
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
# rather than going via the correct worker.
@@ -511,25 +478,7 @@ class GenericWorkerSlavedStore(
SearchWorkerStore,
BaseSlavedStore,
):
- def __init__(self, database, db_conn, hs):
- super(GenericWorkerSlavedStore, self).__init__(database, db_conn, hs)
-
- # We pull out the current federation stream position now so that we
- # always have a known value for the federation position in memory so
- # that we don't have to bounce via a deferred once when we start the
- # replication streams.
- self.federation_out_pos_startup = self._get_federation_out_pos(db_conn)
-
- def _get_federation_out_pos(self, db_conn):
- sql = "SELECT stream_id FROM federation_stream_position WHERE type = ?"
- sql = self.database_engine.convert_param_style(sql)
-
- txn = db_conn.cursor()
- txn.execute(sql, ("federation",))
- rows = txn.fetchall()
- txn.close()
-
- return rows[0][0] if rows else -1
+ pass
class GenericWorkerServer(HomeServer):
@@ -576,6 +525,7 @@ class GenericWorkerServer(HomeServer):
KeyUploadServlet(self).register(resource)
AccountDataServlet(self).register(resource)
RoomAccountDataServlet(self).register(resource)
+ RoomTypingRestServlet(self).register(resource)
sync.register_servlets(self, resource)
events.register_servlets(self, resource)
@@ -687,16 +637,12 @@ class GenericWorkerServer(HomeServer):
def build_presence_handler(self):
return GenericWorkerPresence(self)
- def build_typing_handler(self):
- return GenericWorkerTyping(self)
-
class GenericWorkerReplicationHandler(ReplicationDataHandler):
def __init__(self, hs):
super(GenericWorkerReplicationHandler, self).__init__(hs)
self.store = hs.get_datastore()
- self.typing_handler = hs.get_typing_handler()
self.presence_handler = hs.get_presence_handler() # type: GenericWorkerPresence
self.notifier = hs.get_notifier()
@@ -733,11 +679,6 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
await self.pusher_pool.on_new_receipts(
token, token, {row.room_id for row in rows}
)
- elif stream_name == TypingStream.NAME:
- self.typing_handler.process_replication_rows(token, rows)
- self.notifier.on_new_event(
- "typing_key", token, rooms=[row.room_id for row in rows]
- )
elif stream_name == ToDeviceStream.NAME:
entities = [row.entity for row in rows if row.entity.startswith("@")]
if entities:
@@ -812,19 +753,11 @@ class FederationSenderHandler(object):
self.federation_sender = hs.get_federation_sender()
self._hs = hs
- # if the worker is restarted, we want to pick up where we left off in
- # the replication stream, so load the position from the database.
- #
- # XXX is this actually worthwhile? Whenever the master is restarted, we'll
- # drop some rows anyway (which is mostly fine because we're only dropping
- # typing and presence notifications). If the replication stream is
- # unreliable, why do we do all this hoop-jumping to store the position in the
- # database? See also https://github.com/matrix-org/synapse/issues/7535.
- #
- self.federation_position = self.store.federation_out_pos_startup
+ # Stores the latest position in the federation stream we've gotten up
+ # to. This is always set before we use it.
+ self.federation_position = None
self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
- self._last_ack = self.federation_position
def on_start(self):
# There may be some events that are persisted but haven't been sent,
@@ -932,7 +865,6 @@ class FederationSenderHandler(object):
# We ACK this token over replication so that the master can drop
# its in memory queues
self._hs.get_tcp_replication().send_federation_ack(current_position)
- self._last_ack = current_position
except Exception:
logger.exception("Error updating federation stream position")
@@ -960,7 +892,7 @@ def start(config_options):
)
if config.worker_app == "synapse.app.appservice":
- if config.notify_appservices:
+ if config.appservice.notify_appservices:
sys.stderr.write(
"\nThe appservices must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
@@ -970,13 +902,13 @@ def start(config_options):
sys.exit(1)
# Force the appservice to start since they will be disabled in the main config
- config.notify_appservices = True
+ config.appservice.notify_appservices = True
else:
# For other worker types we force this to off.
- config.notify_appservices = False
+ config.appservice.notify_appservices = False
if config.worker_app == "synapse.app.pusher":
- if config.start_pushers:
+ if config.server.start_pushers:
sys.stderr.write(
"\nThe pushers must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
@@ -986,13 +918,13 @@ def start(config_options):
sys.exit(1)
# Force the pushers to start since they will be disabled in the main config
- config.start_pushers = True
+ config.server.start_pushers = True
else:
# For other worker types we force this to off.
- config.start_pushers = False
+ config.server.start_pushers = False
if config.worker_app == "synapse.app.user_dir":
- if config.update_user_directory:
+ if config.server.update_user_directory:
sys.stderr.write(
"\nThe update_user_directory must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
@@ -1002,13 +934,13 @@ def start(config_options):
sys.exit(1)
# Force the user directory updates to start since they will be disabled in the main config
- config.update_user_directory = True
+ config.server.update_user_directory = True
else:
# For other worker types we force this to off.
- config.update_user_directory = False
+ config.server.update_user_directory = False
if config.worker_app == "synapse.app.federation_sender":
- if config.send_federation:
+ if config.worker.send_federation:
sys.stderr.write(
"\nThe send_federation must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
@@ -1018,10 +950,10 @@ def start(config_options):
sys.exit(1)
# Force the federation sender to start since it will be disabled in the main config
- config.send_federation = True
+ config.worker.send_federation = True
else:
# For other worker types we force this to off.
- config.send_federation = False
+ config.worker.send_federation = False
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 09291d86..ec7401f9 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -483,8 +483,7 @@ class SynapseService(service.Service):
_stats_process = []
-@defer.inlineCallbacks
-def phone_stats_home(hs, stats, stats_process=_stats_process):
+async def phone_stats_home(hs, stats, stats_process=_stats_process):
logger.info("Gathering stats for reporting")
now = int(hs.get_clock().time())
uptime = int(now - hs.start_time)
@@ -522,28 +521,28 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
stats["python_version"] = "{}.{}.{}".format(
version.major, version.minor, version.micro
)
- stats["total_users"] = yield hs.get_datastore().count_all_users()
+ stats["total_users"] = await hs.get_datastore().count_all_users()
- total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users()
+ total_nonbridged_users = await hs.get_datastore().count_nonbridged_users()
stats["total_nonbridged_users"] = total_nonbridged_users
- daily_user_type_results = yield hs.get_datastore().count_daily_user_type()
+ daily_user_type_results = await hs.get_datastore().count_daily_user_type()
for name, count in daily_user_type_results.items():
stats["daily_user_type_" + name] = count
- room_count = yield hs.get_datastore().get_room_count()
+ room_count = await hs.get_datastore().get_room_count()
stats["total_room_count"] = room_count
- stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
- stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users()
- stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms()
- stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
+ stats["daily_active_users"] = await hs.get_datastore().count_daily_users()
+ stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users()
+ stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms()
+ stats["daily_messages"] = await hs.get_datastore().count_daily_messages()
- r30_results = yield hs.get_datastore().count_r30_users()
+ r30_results = await hs.get_datastore().count_r30_users()
for name, count in r30_results.items():
stats["r30_users_" + name] = count
- daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
+ daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
stats["daily_sent_messages"] = daily_sent_messages
stats["cache_factor"] = hs.config.caches.global_factor
stats["event_cache_size"] = hs.config.caches.event_cache_size
@@ -558,7 +557,7 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
try:
- yield hs.get_proxied_http_client().put_json(
+ await hs.get_proxied_http_client().put_json(
hs.config.report_stats_endpoint, stats
)
except Exception as e:
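The phone_stats_home hunk is one instance of a conversion that recurs throughout this changeset: @defer.inlineCallbacks generators become native coroutines, and callers adapt at the boundary with defer.ensureDeferred. A minimal sketch of the pattern, with hypothetical names (FakeStore, total_users_*):

```python
from twisted.internet import defer

class FakeStore:
    # Hypothetical stand-in for the datastore used above.
    async def count_all_users(self) -> int:
        return 42

# Old style: a generator-based coroutine driven by Twisted.
@defer.inlineCallbacks
def total_users_old(store):
    count = yield defer.ensureDeferred(store.count_all_users())
    return count

# New style, as in phone_stats_home above: a native coroutine.
async def total_users_new(store) -> int:
    return await store.count_all_users()

# Callers that still need a Deferred wrap the coroutine explicitly.
d = defer.ensureDeferred(total_users_new(FakeStore()))
d.addCallback(print)  # -> 42
```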
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index f92bfb42..1e0e4d49 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -19,7 +19,7 @@ from prometheus_client import Counter
from twisted.internet import defer
-from synapse.api.constants import ThirdPartyEntityKind
+from synapse.api.constants import EventTypes, ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException
from synapse.events.utils import serialize_event
from synapse.http.client import SimpleHttpClient
@@ -207,7 +207,7 @@ class ApplicationServiceApi(SimpleHttpClient):
if service.url is None:
return True
- events = self._serialize(events)
+ events = self._serialize(service, events)
if txn_id is None:
logger.warning(
@@ -233,6 +233,18 @@ class ApplicationServiceApi(SimpleHttpClient):
failed_transactions_counter.labels(service.id).inc()
return False
- def _serialize(self, events):
+ def _serialize(self, service, events):
time_now = self.clock.time_msec()
- return [serialize_event(e, time_now, as_client_event=True) for e in events]
+ return [
+ serialize_event(
+ e,
+ time_now,
+ as_client_event=True,
+ is_invite=(
+ e.type == EventTypes.Member
+ and e.membership == "invite"
+ and service.is_interested_in_user(e.state_key)
+ ),
+ )
+ for e in events
+ ]
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index fca1d844..c32eb62a 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -19,9 +19,11 @@ import argparse
import errno
import os
from collections import OrderedDict
+from hashlib import sha256
from textwrap import dedent
-from typing import Any, MutableMapping, Optional
+from typing import Any, List, MutableMapping, Optional
+import attr
import yaml
@@ -727,4 +729,36 @@ def find_config_files(search_paths):
return config_files
-__all__ = ["Config", "RootConfig"]
+@attr.s
+class ShardedWorkerHandlingConfig:
+ """Algorithm for choosing which instance is responsible for handling some
+ sharded work.
+
+ For example, the federation senders use this to determine which instance
+ handles sending to a given destination (which is used as the `key`
+ below).
+ """
+
+ instances = attr.ib(type=List[str])
+
+ def should_handle(self, instance_name: str, key: str) -> bool:
+ """Whether this instance is responsible for handling the given key.
+ """
+
+ # If none or only one instance is configured, this instance handles everything.
+ if not self.instances or len(self.instances) == 1:
+ return True
+
+ # We shard by taking the hash, modulo it by the number of instances and
+ # then checking whether this instance matches the instance at that
+ # index.
+ #
+ # (Technically this introduces some bias and is not entirely uniform,
+ # but since the hash is so large the bias is ridiculously small).
+ dest_hash = sha256(key.encode("utf8")).digest()
+ dest_int = int.from_bytes(dest_hash, byteorder="little")
+ remainder = dest_int % (len(self.instances))
+ return self.instances[remainder] == instance_name
+
+
+__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 9e576060..eb911e8f 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -137,3 +137,8 @@ class Config:
def read_config_files(config_files: List[str]): ...
def find_config_files(search_paths: List[str]): ...
+
+class ShardedWorkerHandlingConfig:
+ instances: List[str]
+ def __init__(self, instances: List[str]) -> None: ...
+ def should_handle(self, instance_name: str, key: str) -> bool: ...
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 1064c269..62bccd9e 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -55,7 +55,7 @@ DEFAULT_CONFIG = """\
#database:
# name: psycopg2
# args:
-# user: synapse
+# user: synapse_user
# password: secretpassword
# database: synapse
# host: localhost
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index b1dc7ad5..a63acbdc 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -22,6 +22,7 @@ import os
from enum import Enum
from typing import Optional
+import attr
import pkg_resources
from ._base import Config, ConfigError
@@ -32,6 +33,33 @@ Password reset emails are enabled on this homeserver due to a partial
%s
"""
+DEFAULT_SUBJECTS = {
+ "message_from_person_in_room": "[%(app)s] You have a message on %(app)s from %(person)s in the %(room)s room...",
+ "message_from_person": "[%(app)s] You have a message on %(app)s from %(person)s...",
+ "messages_from_person": "[%(app)s] You have messages on %(app)s from %(person)s...",
+ "messages_in_room": "[%(app)s] You have messages on %(app)s in the %(room)s room...",
+ "messages_in_room_and_others": "[%(app)s] You have messages on %(app)s in the %(room)s room and others...",
+ "messages_from_person_and_others": "[%(app)s] You have messages on %(app)s from %(person)s and others...",
+ "invite_from_person": "[%(app)s] %(person)s has invited you to chat on %(app)s...",
+ "invite_from_person_to_room": "[%(app)s] %(person)s has invited you to join the %(room)s room on %(app)s...",
+ "password_reset": "[%(server_name)s] Password reset",
+ "email_validation": "[%(server_name)s] Validate your email",
+}
+
+
+@attr.s
+class EmailSubjectConfig:
+ message_from_person_in_room = attr.ib(type=str)
+ message_from_person = attr.ib(type=str)
+ messages_from_person = attr.ib(type=str)
+ messages_in_room = attr.ib(type=str)
+ messages_in_room_and_others = attr.ib(type=str)
+ messages_from_person_and_others = attr.ib(type=str)
+ invite_from_person = attr.ib(type=str)
+ invite_from_person_to_room = attr.ib(type=str)
+ password_reset = attr.ib(type=str)
+ email_validation = attr.ib(type=str)
+
class EmailConfig(Config):
section = "email"
@@ -294,8 +322,17 @@ class EmailConfig(Config):
if not os.path.isfile(p):
raise ConfigError("Unable to find email template file %s" % (p,))
+ subjects_config = email_config.get("subjects", {})
+ subjects = {}
+
+ for key, default in DEFAULT_SUBJECTS.items():
+ subjects[key] = subjects_config.get(key, default)
+
+ self.email_subjects = EmailSubjectConfig(**subjects)
+
def generate_config_section(self, config_dir_path, server_name, **kwargs):
- return """\
+ return (
+ """\
# Configuration for sending emails from Synapse.
#
email:
@@ -323,17 +360,17 @@ class EmailConfig(Config):
# notif_from defines the "From" address to use when sending emails.
# It must be set if email sending is enabled.
#
- # The placeholder '%(app)s' will be replaced by the application name,
+ # The placeholder '%%(app)s' will be replaced by the application name,
# which is normally 'app_name' (below), but may be overridden by the
# Matrix client application.
#
- # Note that the placeholder must be written '%(app)s', including the
+ # Note that the placeholder must be written '%%(app)s', including the
# trailing 's'.
#
- #notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>"
+ #notif_from: "Your Friendly %%(app)s homeserver <noreply@example.com>"
- # app_name defines the default value for '%(app)s' in notif_from. It
- # defaults to 'Matrix'.
+ # app_name defines the default value for '%%(app)s' in notif_from and email
+ # subjects. It defaults to 'Matrix'.
#
#app_name: my_branded_matrix_server
@@ -401,7 +438,76 @@ class EmailConfig(Config):
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#
#template_dir: "res/templates"
+
+ # Subjects to use when sending emails from Synapse.
+ #
+ # The placeholder '%%(app)s' will be replaced with the value of the 'app_name'
+ # setting above, or by a value dictated by the Matrix client application.
+ #
+ # If a subject isn't overridden in this configuration file, the default value
+ # shown in its example will be used.
+ #
+ #subjects:
+
+ # Subjects for notification emails.
+ #
+ # On top of the '%%(app)s' placeholder, these can use the following
+ # placeholders:
+ #
+ # * '%%(person)s', which will be replaced by the display name of the user(s)
+ # that sent the message(s), e.g. "Alice and Bob".
+ # * '%%(room)s', which will be replaced by the name of the room the
+ # message(s) have been sent to, e.g. "My super room".
+ #
+ # See the example provided for each setting to see which placeholders can be
+ # used and how to use them.
+ #
+ # Subject to use to notify about one message from one or more user(s) in a
+ # room which has a name.
+ #message_from_person_in_room: "%(message_from_person_in_room)s"
+ #
+ # Subject to use to notify about one message from one or more user(s) in a
+ # room which doesn't have a name.
+ #message_from_person: "%(message_from_person)s"
+ #
+ # Subject to use to notify about multiple messages from one or more users in
+ # a room which doesn't have a name.
+ #messages_from_person: "%(messages_from_person)s"
+ #
+ # Subject to use to notify about multiple messages in a room which has a
+ # name.
+ #messages_in_room: "%(messages_in_room)s"
+ #
+ # Subject to use to notify about multiple messages in multiple rooms.
+ #messages_in_room_and_others: "%(messages_in_room_and_others)s"
+ #
+ # Subject to use to notify about multiple messages from multiple persons in
+ # multiple rooms. This is similar to the setting above except it's used when
+ # the room in which the notification was triggered has no name.
+ #messages_from_person_and_others: "%(messages_from_person_and_others)s"
+ #
+ # Subject to use to notify about an invite to a room which has a name.
+ #invite_from_person_to_room: "%(invite_from_person_to_room)s"
+ #
+ # Subject to use to notify about an invite to a room which doesn't have a
+ # name.
+ #invite_from_person: "%(invite_from_person)s"
+
+ # Subjects for emails related to account administration.
+ #
+ # On top of the '%%(app)s' placeholder, these can also use the
+ # '%%(server_name)s' placeholder, which will be replaced by the value of the
+ # 'server_name' setting in your Synapse configuration.
+ #
+ # Subject to use when sending a password reset email.
+ #password_reset: "%(password_reset)s"
+ #
+ # Subject to use when sending a verification email to assert an address's
+ # ownership.
+ #email_validation: "%(email_validation)s"
"""
+ % DEFAULT_SUBJECTS
+ )
class ThreepidBehaviour(Enum):
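The subjects are classic %-style templates, which also explains the doubled '%%' escapes in the hunks above: the generated sample YAML is itself %-formatted against DEFAULT_SUBJECTS, so literal placeholders need one level of escaping. A minimal sketch of how a subject expands at send time, with hypothetical values:

```python
# The mailer supplies the keys at send time; values here are illustrative.
subject = "[%(app)s] You have a message on %(app)s from %(person)s..."
print(subject % {"app": "Matrix", "person": "Alice"})
# -> [Matrix] You have a message on Matrix from Alice...

reset = "[%(server_name)s] Password reset"
print(reset % {"server_name": "example.com"})
# -> [example.com] Password reset
```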
diff --git a/synapse/config/federation.py b/synapse/config/federation.py
new file mode 100644
index 00000000..2c77d8f8
--- /dev/null
+++ b/synapse/config/federation.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from netaddr import IPSet
+
+from ._base import Config, ConfigError
+
+
+class FederationConfig(Config):
+ section = "federation"
+
+ def read_config(self, config, **kwargs):
+ # FIXME: federation_domain_whitelist needs sytests
+ self.federation_domain_whitelist = None # type: Optional[dict]
+ federation_domain_whitelist = config.get("federation_domain_whitelist", None)
+
+ if federation_domain_whitelist is not None:
+ # turn the whitelist into a hash for speed of lookup
+ self.federation_domain_whitelist = {}
+
+ for domain in federation_domain_whitelist:
+ self.federation_domain_whitelist[domain] = True
+
+ self.federation_ip_range_blacklist = config.get(
+ "federation_ip_range_blacklist", []
+ )
+
+ # Attempt to create an IPSet from the given ranges
+ try:
+ self.federation_ip_range_blacklist = IPSet(
+ self.federation_ip_range_blacklist
+ )
+
+ # Always blacklist 0.0.0.0, ::
+ self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
+ except Exception as e:
+ raise ConfigError(
+ "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e
+ )
+
+ def generate_config_section(self, config_dir_path, server_name, **kwargs):
+ return """\
+ # Restrict federation to the following whitelist of domains.
+ # N.B. we recommend also firewalling your federation listener to limit
+ # inbound federation traffic as early as possible, rather than relying
+ # purely on this application-layer restriction. If not specified, the
+ # default is to whitelist everything.
+ #
+ #federation_domain_whitelist:
+ # - lon.example.com
+ # - nyc.example.com
+ # - syd.example.com
+
+ # Prevent federation requests from being sent to the following
+ # blacklist IP address CIDR ranges. If this option is not specified, or
+ # specified with an empty list, no ip range blacklist will be enforced.
+ #
+ # As of Synapse v1.4.0 this option also affects any outbound requests to identity
+ # servers provided by user input.
+ #
+ # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
+ # listed here, since they correspond to unroutable addresses.)
+ #
+ federation_ip_range_blacklist:
+ - '127.0.0.0/8'
+ - '10.0.0.0/8'
+ - '172.16.0.0/12'
+ - '192.168.0.0/16'
+ - '100.64.0.0/10'
+ - '169.254.0.0/16'
+ - '::1/128'
+ - 'fe80::/64'
+ - 'fc00::/7'
+ """
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 264c274c..556e2914 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -23,6 +23,7 @@ from .cas import CasConfig
from .consent_config import ConsentConfig
from .database import DatabaseConfig
from .emailconfig import EmailConfig
+from .federation import FederationConfig
from .groups import GroupsConfig
from .jwt_config import JWTConfig
from .key import KeyConfig
@@ -57,6 +58,7 @@ class HomeServerConfig(RootConfig):
config_classes = [
ServerConfig,
TlsConfig,
+ FederationConfig,
CacheConfig,
DatabaseConfig,
LoggingConfig,
@@ -76,7 +78,6 @@ class HomeServerConfig(RootConfig):
JWTConfig,
PasswordConfig,
EmailConfig,
- WorkerConfig,
PasswordAuthProviderConfig,
PushConfig,
SpamCheckerConfig,
@@ -89,5 +90,6 @@
RoomDirectoryConfig,
ThirdPartyRulesConfig,
TracerConfig,
+ WorkerConfig,
RedisConfig,
]
diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt_config.py
index fce96b4a..3252ad9e 100644
--- a/synapse/config/jwt_config.py
+++ b/synapse/config/jwt_config.py
@@ -32,6 +32,11 @@ class JWTConfig(Config):
self.jwt_secret = jwt_config["secret"]
self.jwt_algorithm = jwt_config["algorithm"]
+ # The issuer and audiences are optional; if provided, the corresponding
+ # claims are required on the JWT and validated.
+ self.jwt_issuer = jwt_config.get("issuer")
+ self.jwt_audiences = jwt_config.get("audiences")
+
try:
import jwt
@@ -42,6 +47,8 @@ class JWTConfig(Config):
self.jwt_enabled = False
self.jwt_secret = None
self.jwt_algorithm = None
+ self.jwt_issuer = None
+ self.jwt_audiences = None
def generate_config_section(self, **kwargs):
return """\
@@ -52,6 +59,9 @@ class JWTConfig(Config):
# Each JSON Web Token needs to contain a "sub" (subject) claim, which is
# used as the localpart of the mxid.
#
+ # Additionally, the expiration time ("exp"), not before time ("nbf"),
+ # and issued at ("iat") claims are validated if present.
+ #
# Note that this is a non-standard login type and client support is
# expected to be non-existent.
#
@@ -78,4 +88,22 @@ class JWTConfig(Config):
# Required if 'enabled' is true.
#
#algorithm: "provided-by-your-issuer"
+
+ # The issuer to validate the "iss" claim against.
+ #
+ # Optional, if provided the "iss" claim will be required and
+ # validated for all JSON web tokens.
+ #
+ #issuer: "provided-by-your-issuer"
+
+ # A list of audiences to validate the "aud" claim against.
+ #
+ # Optional, if provided the "aud" claim will be required and
+ # validated for all JSON web tokens.
+ #
+ # Note that if the "aud" claim is included in a JSON web token then
+ # validation will fail without configuring audiences.
+ #
+ #audiences:
+ # - "provided-by-your-issuer"
"""
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 49f6c32b..dd775a97 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -214,7 +214,7 @@ def setup_logging(
Set up the logging subsystem.
Args:
- config (LoggingConfig | synapse.config.workers.WorkerConfig):
+ config (LoggingConfig | synapse.config.worker.WorkerConfig):
configuration data
use_worker_options (bool): True to use the 'worker_log_config' option
diff --git a/synapse/config/push.py b/synapse/config/push.py
index 6f2b3a7f..a1f3752c 100644
--- a/synapse/config/push.py
+++ b/synapse/config/push.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import Config
+from ._base import Config, ShardedWorkerHandlingConfig
class PushConfig(Config):
@@ -24,6 +24,9 @@ class PushConfig(Config):
push_config = config.get("push", {})
self.push_include_content = push_config.get("include_content", True)
+ pusher_instances = config.get("pusher_instances") or []
+ self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
+
# There was a 'redact_content' setting, but it was mistakenly read from the
# 'email' section. Check for the flag in the 'push' section, and log,
# but do not honour it to avoid nasty surprises when people upgrade.
diff --git a/synapse/config/redis.py b/synapse/config/redis.py
index d5d3ca1c..13733023 100644
--- a/synapse/config/redis.py
+++ b/synapse/config/redis.py
@@ -21,7 +21,7 @@ class RedisConfig(Config):
section = "redis"
def read_config(self, config, **kwargs):
- redis_config = config.get("redis", {})
+ redis_config = config.get("redis") or {}
self.redis_enabled = redis_config.get("enabled", False)
if not self.redis_enabled:
@@ -32,3 +32,24 @@ class RedisConfig(Config):
self.redis_host = redis_config.get("host", "localhost")
self.redis_port = redis_config.get("port", 6379)
self.redis_password = redis_config.get("password")
+
+ def generate_config_section(self, config_dir_path, server_name, **kwargs):
+ return """\
+ # Configuration for Redis when using workers. This *must* be enabled when
+ # using workers (unless using old style direct TCP configuration).
+ #
+ redis:
+ # Uncomment the below to enable Redis support.
+ #
+ #enabled: true
+
+ # Optional host and port to use to connect to redis. Defaults to
+ # localhost and 6379
+ #
+ #host: localhost
+ #port: 6379
+
+ # Optional password if configured on the Redis instance
+ #
+ #password: <secret_password>
+ """
diff --git a/synapse/config/room.py b/synapse/config/room.py
index 6aa4de06..52cf0b62 100644
--- a/synapse/config/room.py
+++ b/synapse/config/room.py
@@ -50,7 +50,12 @@ class RoomConfig(Config):
RoomCreationPreset.PRIVATE_CHAT,
RoomCreationPreset.TRUSTED_PRIVATE_CHAT,
]
- elif encryption_for_room_type == RoomDefaultEncryptionTypes.OFF:
+ elif (
+ encryption_for_room_type == RoomDefaultEncryptionTypes.OFF
+ or encryption_for_room_type is False
+ ):
+ # PyYAML translates "off" into False if it's unquoted, so we also need to
+ # check for encryption_for_room_type being False.
self.encryption_enabled_by_default_for_room_presets = []
else:
raise ConfigError(
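The extra `is False` branch exists because PyYAML implements YAML 1.1 booleans, so an unquoted `off` in the config file never reaches the string comparison:

```python
import yaml

# Unquoted "off" is parsed as the boolean False under YAML 1.1...
print(yaml.safe_load("encryption_enabled_by_default_for_room_type: off"))
# -> {'encryption_enabled_by_default_for_room_type': False}

# ...while a quoted "off" survives as the string the comparison expects.
print(yaml.safe_load("encryption_enabled_by_default_for_room_type: 'off'"))
# -> {'encryption_enabled_by_default_for_room_type': 'off'}
```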
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 82046648..3747a01c 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -23,7 +23,6 @@ from typing import Any, Dict, Iterable, List, Optional
import attr
import yaml
-from netaddr import IPSet
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.endpoint import parse_and_validate_server_name
@@ -136,11 +135,6 @@ class ServerConfig(Config):
self.use_frozen_dicts = config.get("use_frozen_dicts", False)
self.public_baseurl = config.get("public_baseurl")
- # Whether to send federation traffic out in this process. This only
- # applies to some federation traffic, and so shouldn't be used to
- # "disable" federation
- self.send_federation = config.get("send_federation", True)
-
# Whether to enable user presence.
self.use_presence = config.get("use_presence", True)
@@ -213,7 +207,7 @@ class ServerConfig(Config):
# errors when attempting to search for messages.
self.enable_search = config.get("enable_search", True)
- self.filter_timeline_limit = config.get("filter_timeline_limit", -1)
+ self.filter_timeline_limit = config.get("filter_timeline_limit", 100)
# Whether we should block invites sent to users on this server
# (other than those sent by local server admins)
@@ -263,34 +257,6 @@ class ServerConfig(Config):
# due to resource constraints
self.admin_contact = config.get("admin_contact", None)
- # FIXME: federation_domain_whitelist needs sytests
- self.federation_domain_whitelist = None # type: Optional[dict]
- federation_domain_whitelist = config.get("federation_domain_whitelist", None)
-
- if federation_domain_whitelist is not None:
- # turn the whitelist into a hash for speed of lookup
- self.federation_domain_whitelist = {}
-
- for domain in federation_domain_whitelist:
- self.federation_domain_whitelist[domain] = True
-
- self.federation_ip_range_blacklist = config.get(
- "federation_ip_range_blacklist", []
- )
-
- # Attempt to create an IPSet from the given ranges
- try:
- self.federation_ip_range_blacklist = IPSet(
- self.federation_ip_range_blacklist
- )
-
- # Always blacklist 0.0.0.0, ::
- self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
- except Exception as e:
- raise ConfigError(
- "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e
- )
-
if self.public_baseurl is not None:
if self.public_baseurl[-1] != "/":
self.public_baseurl += "/"
@@ -727,7 +693,9 @@ class ServerConfig(Config):
#gc_thresholds: [700, 10, 10]
# Set the limit on the returned events in the timeline in the get
- # and sync operations. The default value is -1, means no upper limit.
+ # and sync operations. The default value is 100. -1 means no upper limit.
+ #
+ # Uncomment the following to increase the limit to 5000.
#
#filter_timeline_limit: 5000
@@ -743,38 +711,6 @@ class ServerConfig(Config):
#
#enable_search: false
- # Restrict federation to the following whitelist of domains.
- # N.B. we recommend also firewalling your federation listener to limit
- # inbound federation traffic as early as possible, rather than relying
- # purely on this application-layer restriction. If not specified, the
- # default is to whitelist everything.
- #
- #federation_domain_whitelist:
- # - lon.example.com
- # - nyc.example.com
- # - syd.example.com
-
- # Prevent federation requests from being sent to the following
- # blacklist IP address CIDR ranges. If this option is not specified, or
- # specified with an empty list, no ip range blacklist will be enforced.
- #
- # As of Synapse v1.4.0 this option also affects any outbound requests to identity
- # servers provided by user input.
- #
- # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
- # listed here, since they correspond to unroutable addresses.)
- #
- federation_ip_range_blacklist:
- - '127.0.0.0/8'
- - '10.0.0.0/8'
- - '172.16.0.0/12'
- - '192.168.0.0/16'
- - '100.64.0.0/10'
- - '169.254.0.0/16'
- - '::1/128'
- - 'fe80::/64'
- - 'fc00::/7'
-
# List of ports that Synapse should listen on, their purpose and their
# configuration.
#
@@ -803,7 +739,7 @@ class ServerConfig(Config):
# names: a list of names of HTTP resources. See below for a list of
# valid resource names.
#
- # compress: set to true to enable HTTP comression for this resource.
+ # compress: set to true to enable HTTP compression for this resource.
#
# additional_resources: Only valid for an 'http' listener. A map of
# additional endpoints which should be loaded via dynamic modules.
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index dbc66163..c784a715 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -15,7 +15,7 @@
import attr
-from ._base import Config, ConfigError
+from ._base import Config, ConfigError, ShardedWorkerHandlingConfig
from .server import ListenerConfig, parse_listener_def
@@ -34,9 +34,11 @@ class WriterLocations:
Attributes:
events: The instance that writes to the event and backfill streams.
+ typing: The instance that writes to the typing stream.
"""
events = attr.ib(default="master", type=str)
+ typing = attr.ib(default="master", type=str)
class WorkerConfig(Config):
@@ -83,6 +85,16 @@ class WorkerConfig(Config):
)
)
+ # Whether to send federation traffic out in this process. This only
+ # applies to some federation traffic, and so shouldn't be used to
+ # "disable" federation
+ self.send_federation = config.get("send_federation", True)
+
+ federation_sender_instances = config.get("federation_sender_instances") or []
+ self.federation_shard_config = ShardedWorkerHandlingConfig(
+ federation_sender_instances
+ )
+
# A map from instance name to host/port of their HTTP replication endpoint.
instance_map = config.get("instance_map") or {}
self.instance_map = {
@@ -93,16 +105,52 @@ class WorkerConfig(Config):
writers = config.get("stream_writers") or {}
self.writers = WriterLocations(**writers)
- # Check that the configured writer for events also appears in
+ # Check that the configured writer for events and typing also appears in
# `instance_map`.
- if (
- self.writers.events != "master"
- and self.writers.events not in self.instance_map
- ):
- raise ConfigError(
- "Instance %r is configured to write events but does not appear in `instance_map` config."
- % (self.writers.events,)
- )
+ for stream in ("events", "typing"):
+ instance = getattr(self.writers, stream)
+ if instance != "master" and instance not in self.instance_map:
+ raise ConfigError(
+ "Instance %r is configured to write %s but does not appear in `instance_map` config."
+ % (instance, stream)
+ )
+
+ def generate_config_section(self, config_dir_path, server_name, **kwargs):
+ return """\
+ ## Workers ##
+
+ # Disables sending of outbound federation transactions on the main process.
+ # Uncomment if using a federation sender worker.
+ #
+ #send_federation: false
+
+ # It is possible to run multiple federation sender workers, in which case the
+ # work is balanced across them.
+ #
+ # This configuration must be shared between all federation sender workers; if it
+ # changes, all federation sender workers must be stopped at the same time and
+ # then started together, to ensure that all instances are running with the same
+ # config (otherwise events may be dropped).
+ #
+ #federation_sender_instances:
+ # - federation_sender1
+
+ # When using workers this should be a map from `worker_name` to the
+ # HTTP replication listener of the worker, if configured.
+ #
+ #instance_map:
+ # worker1:
+ # host: localhost
+ # port: 8034
+
+ # Experimental: When using workers you can define which workers should
+ # handle event persistence and typing notifications. Any worker
+ # specified here must also be in the `instance_map`.
+ #
+ #stream_writers:
+ # events: worker1
+ # typing: worker1
+ """
def read_arguments(self, args):
# We support a bunch of command line arguments that override options in
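A standalone sketch of the stream-writer check above, with a hypothetical configuration; any writer other than master must have a replication endpoint in `instance_map`:

```python
# Hypothetical homeserver.yaml values, expressed as Python literals.
instance_map = {"worker1": {"host": "localhost", "port": 8034}}
writers = {"events": "worker1", "typing": "master"}

for stream, instance in writers.items():
    if instance != "master" and instance not in instance_map:
        raise ValueError(
            "Instance %r is configured to write %s but does not appear "
            "in `instance_map` config." % (instance, stream)
        )
print("stream writer config OK")
```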
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index c5823551..c0981eee 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -65,14 +65,16 @@ def check(
room_id = event.room_id
- # I'm not really expecting to get auth events in the wrong room, but let's
- # sanity-check it
+ # We need to ensure that the auth events are actually for the same room, to
+ # stop people from using powers they've been granted in other rooms for
+ # example.
for auth_event in auth_events.values():
if auth_event.room_id != room_id:
- raise Exception(
+ raise AuthError(
+ 403,
"During auth for event %s in room %s, found event %s in the state "
"which is in room %s"
- % (event.event_id, room_id, auth_event.event_id, auth_event.room_id)
+ % (event.event_id, room_id, auth_event.event_id, auth_event.room_id),
)
if do_sig_check:
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 92aadfe7..0bb21641 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -106,8 +106,8 @@ class EventBuilder(object):
Deferred[FrozenEvent]
"""
- state_ids = yield self._state.get_current_state_ids(
- self.room_id, prev_event_ids
+ state_ids = yield defer.ensureDeferred(
+ self._state.get_current_state_ids(self.room_id, prev_event_ids)
)
auth_ids = yield self._auth.compute_auth_events(self, state_ids)
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index f6b50797..11f0d34e 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import collections
+import collections.abc
import re
from typing import Any, Mapping, Union
@@ -424,7 +424,7 @@ def copy_power_levels_contents(
Raises:
TypeError if the input does not look like a valid power levels event content
"""
- if not isinstance(old_power_levels, collections.Mapping):
+ if not isinstance(old_power_levels, collections.abc.Mapping):
raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,))
power_levels = {}
@@ -434,7 +434,7 @@ def copy_power_levels_contents(
power_levels[k] = v
continue
- if isinstance(v, collections.Mapping):
+ if isinstance(v, collections.abc.Mapping):
power_levels[k] = h = {}
for k1, v1 in v.items():
# we should only have one level of nesting
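The collections.abc change is forward-compatibility work: the ABC aliases in the bare collections namespace have been deprecated since Python 3.3 and are removed in Python 3.10, after which `collections.Mapping` raises AttributeError:

```python
import collections.abc

# The portable spelling, as used in the hunk above.
print(isinstance({"users_default": 0}, collections.abc.Mapping))  # True
```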
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index a37cc9cb..994e6c8d 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -374,29 +374,26 @@ class FederationClient(FederationBase):
"""
deferreds = self._check_sigs_and_hashes(room_version, pdus)
- @defer.inlineCallbacks
- def handle_check_result(pdu: EventBase, deferred: Deferred):
+ async def handle_check_result(pdu: EventBase, deferred: Deferred):
try:
- res = yield make_deferred_yieldable(deferred)
+ res = await make_deferred_yieldable(deferred)
except SynapseError:
res = None
if not res:
# Check local db.
- res = yield self.store.get_event(
+ res = await self.store.get_event(
pdu.event_id, allow_rejected=True, allow_none=True
)
if not res and pdu.origin != origin:
try:
- res = yield defer.ensureDeferred(
- self.get_pdu(
- destinations=[pdu.origin],
- event_id=pdu.event_id,
- room_version=room_version,
- outlier=outlier,
- timeout=10000,
- )
+ res = await self.get_pdu(
+ destinations=[pdu.origin],
+ event_id=pdu.event_id,
+ room_version=room_version,
+ outlier=outlier,
+ timeout=10000,
)
except SynapseError:
pass
@@ -995,24 +992,25 @@ class FederationClient(FederationBase):
raise RuntimeError("Failed to send to any server.")
- @defer.inlineCallbacks
- def get_room_complexity(self, destination, room_id):
+ async def get_room_complexity(
+ self, destination: str, room_id: str
+ ) -> Optional[dict]:
"""
Fetch the complexity of a remote room from another server.
Args:
- destination (str): The remote server
- room_id (str): The room ID to ask about.
+ destination: The remote server
+ room_id: The room ID to ask about.
Returns:
- Deferred[dict] or Deferred[None]: Dict contains the complexity
- metric versions, while None means we could not fetch the complexity.
+ Dict contains the complexity metric versions, while None means we
+ could not fetch the complexity.
"""
try:
- complexity = yield self.transport_layer.get_room_complexity(
+ complexity = await self.transport_layer.get_room_complexity(
destination=destination, room_id=room_id
)
- defer.returnValue(complexity)
+ return complexity
except CodeMessageException as e:
# We didn't manage to get it -- probably a 404. We are okay if other
# servers don't give it to us.
@@ -1029,4 +1027,4 @@ class FederationClient(FederationBase):
# If we don't manage to find it, return None. It's not an error if a
# server doesn't give it to us.
- defer.returnValue(None)
+ return None
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 86051dec..11c5d632 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -15,7 +15,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ List,
+ Match,
+ Optional,
+ Tuple,
+ Union,
+)
from canonicaljson import json
from prometheus_client import Counter, Histogram
@@ -56,6 +67,9 @@ from synapse.util import glob_to_regex, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
# when processing incoming transactions, we try to handle multiple rooms in
# parallel, up to this limit.
TRANSACTION_CONCURRENCY_LIMIT = 10
@@ -95,6 +109,9 @@ class FederationServer(FederationBase):
# We cache responses to state queries, as they take a while and often
# come in waves.
self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
+ self._state_ids_resp_cache = ResponseCache(
+ hs, "state_ids_resp", timeout_ms=30000
+ )
async def on_backfill_request(
self, origin: str, room_id: str, versions: List[str], limit: int
@@ -362,10 +379,16 @@ class FederationServer(FederationBase):
if not in_room:
raise AuthError(403, "Host not in room.")
+ resp = await self._state_ids_resp_cache.wrap(
+ (room_id, event_id), self._on_state_ids_request_compute, room_id, event_id,
+ )
+
+ return 200, resp
+
+ async def _on_state_ids_request_compute(self, room_id, event_id):
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
auth_chain_ids = await self.store.get_auth_chain_ids(state_ids)
-
- return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
+ return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
async def _on_context_state_request_compute(
self, room_id: str, event_id: str
@@ -526,9 +549,9 @@ class FederationServer(FederationBase):
json_result = {} # type: Dict[str, Dict[str, dict]]
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
- for key_id, json_bytes in keys.items():
+ for key_id, json_str in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
- key_id: json.loads(json_bytes)
+ key_id: json.loads(json_str)
}
logger.info(
@@ -768,11 +791,30 @@ class FederationHandlerRegistry(object):
query type for incoming federation traffic.
"""
- def __init__(self):
- self.edu_handlers = {}
- self.query_handlers = {}
+ def __init__(self, hs: "HomeServer"):
+ self.config = hs.config
+ self.http_client = hs.get_simple_http_client()
+ self.clock = hs.get_clock()
+ self._instance_name = hs.get_instance_name()
- def register_edu_handler(self, edu_type: str, handler: Callable[[str, dict], None]):
+ # These are safe to load in monolith mode, but will explode if we try
+ # and use them. However we have guards before we use them to ensure that
+ # we don't route to ourselves, and in monolith mode that will always be
+ # the case.
+ self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs)
+ self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs)
+
+ self.edu_handlers = (
+ {}
+ ) # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
+ self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]]
+
+ # Map from type to instance name that we should route EDU handling to.
+ self._edu_type_to_instance = {} # type: Dict[str, str]
+
+ def register_edu_handler(
+ self, edu_type: str, handler: Callable[[str, dict], Awaitable[None]]
+ ):
"""Sets the handler callable that will be used to handle an incoming
federation EDU of the given type.
@@ -809,66 +851,56 @@ class FederationHandlerRegistry(object):
self.query_handlers[query_type] = handler
+ def register_instance_for_edu(self, edu_type: str, instance_name: str):
+ """Register that the EDU handler is on a different instance than master.
+ """
+ self._edu_type_to_instance[edu_type] = instance_name
+
async def on_edu(self, edu_type: str, origin: str, content: dict):
+ if not self.config.use_presence and edu_type == "m.presence":
+ return
+
+ # Check if we have a handler on this instance
handler = self.edu_handlers.get(edu_type)
- if not handler:
- logger.warning("No handler registered for EDU type %s", edu_type)
+ if handler:
+ with start_active_span_from_edu(content, "handle_edu"):
+ try:
+ await handler(origin, content)
+ except SynapseError as e:
+ logger.info("Failed to handle edu %r: %r", edu_type, e)
+ except Exception:
+ logger.exception("Failed to handle edu %r", edu_type)
return
- with start_active_span_from_edu(content, "handle_edu"):
+ # Check if we can route it somewhere else that isn't us
+ route_to = self._edu_type_to_instance.get(edu_type, "master")
+ if route_to != self._instance_name:
try:
- await handler(origin, content)
+ await self._send_edu(
+ instance_name=route_to,
+ edu_type=edu_type,
+ origin=origin,
+ content=content,
+ )
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception:
logger.exception("Failed to handle edu %r", edu_type)
-
- def on_query(self, query_type: str, args: dict) -> defer.Deferred:
- handler = self.query_handlers.get(query_type)
- if not handler:
- logger.warning("No handler registered for query type %s", query_type)
- raise NotFoundError("No handler for Query type '%s'" % (query_type,))
-
- return handler(args)
-
-
-class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
- """A FederationHandlerRegistry for worker processes.
-
- When receiving EDU or queries it will check if an appropriate handler has
- been registered on the worker, if there isn't one then it calls off to the
- master process.
- """
-
- def __init__(self, hs):
- self.config = hs.config
- self.http_client = hs.get_simple_http_client()
- self.clock = hs.get_clock()
-
- self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs)
- self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs)
-
- super(ReplicationFederationHandlerRegistry, self).__init__()
-
- async def on_edu(self, edu_type: str, origin: str, content: dict):
- """Overrides FederationHandlerRegistry
- """
- if not self.config.use_presence and edu_type == "m.presence":
return
- handler = self.edu_handlers.get(edu_type)
- if handler:
- return await super(ReplicationFederationHandlerRegistry, self).on_edu(
- edu_type, origin, content
- )
-
- return await self._send_edu(edu_type=edu_type, origin=origin, content=content)
+ # Oh well, let's just log and move on.
+ logger.warning("No handler registered for EDU type %s", edu_type)
async def on_query(self, query_type: str, args: dict):
- """Overrides FederationHandlerRegistry
- """
handler = self.query_handlers.get(query_type)
if handler:
return await handler(args)
- return await self._get_query_client(query_type=query_type, args=args)
+ # Check if we can route it somewhere else that isn't us
+ if self._instance_name == "master":
+ return await self._get_query_client(query_type=query_type, args=args)
+
+ # Uh oh, no handler! Let's raise an exception so the request returns an
+ # error.
+ logger.warning("No handler registered for query type %s", query_type)
+ raise NotFoundError("No handler for Query type '%s'" % (query_type,))
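The merged registry above handles an EDU locally when a handler is registered, and otherwise forwards it over replication to whichever instance claimed the type, defaulting to master. A self-contained model of that routing decision; names here are stand-ins for the registry's attributes, and send_edu stands in for the replication client:

    from typing import Awaitable, Callable, Dict

    EDU_HANDLERS = {}  # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
    EDU_TYPE_TO_INSTANCE = {}  # type: Dict[str, str]  # via register_instance_for_edu
    INSTANCE_NAME = "master"

    async def route_edu(edu_type: str, origin: str, content: dict, send_edu) -> None:
        handler = EDU_HANDLERS.get(edu_type)
        if handler:
            await handler(origin, content)  # handle on this instance
            return
        route_to = EDU_TYPE_TO_INSTANCE.get(edu_type, "master")
        if route_to != INSTANCE_NAME:
            # Another instance claimed this EDU type: forward over replication.
            await send_edu(
                instance_name=route_to, edu_type=edu_type,
                origin=origin, content=content,
            )
            return
        # We are the designated instance but nothing is registered: drop it.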
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 860b03f7..2b0ab2dc 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -55,6 +55,11 @@ class FederationRemoteSendQueue(object):
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
+ # We may have multiple federation sender instances, so we need to track
+ # their positions separately.
+ self._sender_instances = hs.config.worker.federation_shard_config.instances
+ self._sender_positions = {}
+
# Pending presence map user_id -> UserPresenceState
self.presence_map = {} # type: Dict[str, UserPresenceState]
@@ -261,7 +266,14 @@ class FederationRemoteSendQueue(object):
def get_current_token(self):
return self.pos - 1
- def federation_ack(self, token):
+ def federation_ack(self, instance_name, token):
+ if self._sender_instances:
+ # If we have configured multiple federation sender instances, we need
+ # to track their positions separately, and only clear the queue up
+ # to the token all instances have acked.
+ self._sender_positions[instance_name] = token
+ token = min(self._sender_positions.values())
+
self._clear_queue_before_pos(token)
async def get_replication_rows(
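With several federation sender instances acking independently, the send queue may only be cleared up to the lowest token that every instance has confirmed. A toy model of the bookkeeping in federation_ack(); note that, as in the hunk above, an instance that has never acked simply does not participate in the minimum yet:

    sender_positions = {}  # instance_name -> last acked token

    def on_federation_ack(instance_name: str, token: int) -> int:
        sender_positions[instance_name] = token
        # Only positions we have actually seen take part in the minimum.
        return min(sender_positions.values())

    assert on_federation_ack("sender1", 120) == 120
    assert on_federation_ack("sender2", 95) == 95  # queue clears up to 95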
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 464d7a41..6ae6522f 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -69,6 +69,9 @@ class FederationSender(object):
self._transaction_manager = TransactionManager(hs)
+ self._instance_name = hs.get_instance_name()
+ self._federation_shard_config = hs.config.worker.federation_shard_config
+
# map from destination to PerDestinationQueue
self._per_destination_queues = {} # type: Dict[str, PerDestinationQueue]
@@ -191,7 +194,13 @@ class FederationSender(object):
)
return
- destinations = set(destinations)
+ destinations = {
+ d
+ for d in destinations
+ if self._federation_shard_config.should_handle(
+ self._instance_name, d
+ )
+ }
if send_on_behalf_of is not None:
# If we are sending the event on behalf of another server
@@ -321,8 +330,15 @@ class FederationSender(object):
room_id = receipt.room_id
# Work out which remote servers should be poked and poke them.
- domains = yield self.state.get_current_hosts_in_room(room_id)
- domains = [d for d in domains if d != self.server_name]
+ domains = yield defer.ensureDeferred(
+ self.state.get_current_hosts_in_room(room_id)
+ )
+ domains = [
+ d
+ for d in domains
+ if d != self.server_name
+ and self._federation_shard_config.should_handle(self._instance_name, d)
+ ]
if not domains:
return
@@ -427,6 +443,10 @@ class FederationSender(object):
for destination in destinations:
if destination == self.server_name:
continue
+ if not self._federation_shard_config.should_handle(
+ self._instance_name, destination
+ ):
+ continue
self._get_per_destination_queue(destination).send_presence(states)
@measure_func("txnqueue._process_presence")
@@ -435,12 +455,20 @@ class FederationSender(object):
"""Given a list of states populate self.pending_presence_by_dest and
poke to send a new transaction to each destination
"""
- hosts_and_states = yield get_interested_remotes(self.store, states, self.state)
+ hosts_and_states = yield defer.ensureDeferred(
+ get_interested_remotes(self.store, states, self.state)
+ )
for destinations, states in hosts_and_states:
for destination in destinations:
if destination == self.server_name:
continue
+
+ if not self._federation_shard_config.should_handle(
+ self._instance_name, destination
+ ):
+ continue
+
self._get_per_destination_queue(destination).send_presence(states)
def build_and_send_edu(
@@ -462,6 +490,11 @@ class FederationSender(object):
logger.info("Not sending EDU to ourselves")
return
+ if not self._federation_shard_config.should_handle(
+ self._instance_name, destination
+ ):
+ return
+
edu = Edu(
origin=self.server_name,
destination=destination,
@@ -478,6 +511,11 @@ class FederationSender(object):
edu: edu to send
key: clobbering key for this edu
"""
+ if not self._federation_shard_config.should_handle(
+ self._instance_name, edu.destination
+ ):
+ return
+
queue = self._get_per_destination_queue(edu.destination)
if key:
queue.send_keyed_edu(edu, key)
@@ -489,6 +527,11 @@ class FederationSender(object):
logger.warning("Not sending device update to ourselves")
return
+ if not self._federation_shard_config.should_handle(
+ self._instance_name, destination
+ ):
+ return
+
self._get_per_destination_queue(destination).attempt_new_transaction()
def wake_destination(self, destination: str):
@@ -502,6 +545,11 @@ class FederationSender(object):
logger.warning("Not waking up ourselves")
return
+ if not self._federation_shard_config.should_handle(
+ self._instance_name, destination
+ ):
+ return
+
self._get_per_destination_queue(destination).attempt_new_transaction()
@staticmethod
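Every outbound path above now defers to should_handle(instance_name, destination), so each destination is owned by exactly one sender instance. The real check lives in the new synapse/config/federation.py; a plausible reconstruction is a stable hash of the destination taken modulo the configured instance list (the hash choice and details here are assumptions):

    from hashlib import sha256
    from typing import List

    def should_handle(instances: List[str], instance_name: str, destination: str) -> bool:
        # With no sharding configured, a single sender handles everything.
        if not instances:
            return True
        # A stable hash maps each destination to one index in the instance
        # list, so every process independently agrees on the owner.
        digest = sha256(destination.encode("utf8")).digest()
        index = int.from_bytes(digest, "little") % len(instances)
        return instances[index] == instance_name

    assert should_handle([], "federation_sender1", "matrix.org")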
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 12966e23..dd150f89 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -74,6 +74,20 @@ class PerDestinationQueue(object):
self._clock = hs.get_clock()
self._store = hs.get_datastore()
self._transaction_manager = transaction_manager
+ self._instance_name = hs.get_instance_name()
+ self._federation_shard_config = hs.config.worker.federation_shard_config
+
+ self._should_send_on_this_instance = True
+ if not self._federation_shard_config.should_handle(
+ self._instance_name, destination
+ ):
+ # We don't raise an exception here to avoid taking out any other
+ # processing. We have a guard in `attempt_new_transaction` that
+ # ensures we don't start sending anything.
+ logger.error(
+ "Created a per destination queue for %s on the wrong worker", destination,
+ )
+ self._should_send_on_this_instance = False
self._destination = destination
self.transmission_loop_running = False
@@ -180,6 +194,14 @@ class PerDestinationQueue(object):
logger.debug("TX [%s] Transaction already in progress", self._destination)
return
+ if not self._should_send_on_this_instance:
+ # We don't raise an exception here to avoid taking out any other
+ # processing.
+ logger.error(
+ "Trying to start a transaction to %s on wrong worker", self._destination
+ )
+ return
+
logger.debug("TX [%s] Starting transaction loop", self._destination)
run_as_background_process(
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index a2752a54..8280f8b9 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -61,8 +61,6 @@ class TransactionManager(object):
# all the edus in that transaction. This needs to be done since there is
# no active span here, so if the edus were not received by the remote the
# span would have no causality and it would be forgotten.
- # The span_contexts is a generator so that it won't be evaluated if
- # opentracing is disabled. (Yay speed!)
span_contexts = []
keep_destination = whitelisted_homeserver(destination)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index d1bac318..5e111aa9 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -20,8 +20,6 @@ import logging
import re
from typing import Optional, Tuple, Type
-from twisted.internet.defer import maybeDeferred
-
import synapse
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.room_versions import RoomVersions
@@ -340,6 +338,12 @@ class BaseFederationServlet(object):
if origin:
with ratelimiter.ratelimit(origin) as d:
await d
+ if request._disconnected:
+ logger.warning(
+ "client disconnected before we started processing "
+ "request"
+ )
+ return -1, None
response = await func(
origin, content, request.args, *args, **kwargs
)
@@ -795,12 +799,8 @@ class PublicRoomList(BaseFederationServlet):
# zero is a special value which corresponds to no limit.
limit = None
- data = await maybeDeferred(
- self.handler.get_local_public_room_list,
- limit,
- since_token,
- network_tuple=network_tuple,
- from_federation=True,
+ data = await self.handler.get_local_public_room_list(
+ limit, since_token, network_tuple=network_tuple, from_federation=True
)
return 200, data
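The disconnect check added above relies on Twisted's Request tracking whether the client connection has dropped, and uses -1 as the status code sentinel meaning "no response to write". Roughly, with the surrounding servlet machinery elided:

    import logging

    logger = logging.getLogger(__name__)

    # Sketch of the early exit in the federation servlet wrapper.
    async def process(request, func, origin, content, *args, **kwargs):
        if request._disconnected:  # private Twisted attribute, as used above
            logger.warning(
                "client disconnected before we started processing request"
            )
            return -1, None  # sentinel: caller skips writing a response
        return await func(origin, content, request.args, *args, **kwargs)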
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 61dc4bea..ba2bf998 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -15,8 +15,8 @@
import logging
-from twisted.internet import defer
-
+import synapse.state
+import synapse.storage
import synapse.types
from synapse.api.constants import EventTypes, Membership
from synapse.api.ratelimiting import Ratelimiter
@@ -28,10 +28,6 @@ logger = logging.getLogger(__name__)
class BaseHandler(object):
"""
Common base class for the event handlers.
-
- Attributes:
- store (synapse.storage.DataStore):
- state_handler (synapse.state.StateHandler):
"""
def __init__(self, hs):
@@ -39,10 +35,10 @@ class BaseHandler(object):
Args:
hs (synapse.server.HomeServer):
"""
- self.store = hs.get_datastore()
+ self.store = hs.get_datastore() # type: synapse.storage.DataStore
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
- self.state_handler = hs.get_state_handler()
+ self.state_handler = hs.get_state_handler() # type: synapse.state.StateHandler
self.distributor = hs.get_distributor()
self.clock = hs.get_clock()
self.hs = hs
@@ -68,8 +64,7 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory()
- @defer.inlineCallbacks
- def ratelimit(self, requester, update=True, is_admin_redaction=False):
+ async def ratelimit(self, requester, update=True, is_admin_redaction=False):
"""Ratelimits requests.
Args:
@@ -101,7 +96,7 @@ class BaseHandler(object):
burst_count = self._rc_message.burst_count
# Check if there is a per user override in the DB.
- override = yield self.store.get_ratelimit_for_user(user_id)
+ override = await self.store.get_ratelimit_for_user(user_id)
if override:
# If overridden with a null Hz then ratelimiting has been entirely
# disabled for the user
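The change to ratelimit() above is the conversion pattern applied throughout this changeset: drop @defer.inlineCallbacks, make the function async, replace each yield with await, and return values directly instead of via defer.returnValue. Schematically, with illustrative names:

    from twisted.internet import defer

    # Before: generator-based coroutine.
    @defer.inlineCallbacks
    def get_override(store, user_id):
        override = yield store.get_ratelimit_for_user(user_id)
        defer.returnValue(override)

    # After: native coroutine, awaitable from async callers.
    async def get_override(store, user_id):
        return await store.get_ratelimit_for_user(user_id)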
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index f3c0aece..506bb2b2 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -72,7 +72,7 @@ class AdminHandler(BaseHandler):
writer (ExfiltrationWriter)
Returns:
- defer.Deferred: Resolves when all data for a user has been written.
+ Resolves when all data for a user has been written.
The returned value is that returned by `writer.finished()`.
"""
# Get all rooms the user is in or has been in
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index a162392e..c7d921c2 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import logging
import time
import unicodedata
@@ -863,11 +864,15 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this
for provider in self.password_providers:
if hasattr(provider, "on_logged_out"):
- await provider.on_logged_out(
+ # This might return an awaitable; if it does, block the logout
+ # until it completes.
+ result = provider.on_logged_out(
user_id=str(user_info["user"]),
device_id=user_info["device_id"],
access_token=access_token,
)
+ if inspect.isawaitable(result):
+ await result
# delete pushers associated with this access token
if user_info["token_id"] is not None:
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index d79ffefd..786e608f 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -104,7 +104,7 @@ class CasHandler:
return user, displayname
def _parse_cas_response(
- self, cas_response_body: str
+ self, cas_response_body: bytes
) -> Tuple[str, Dict[str, Optional[str]]]:
"""
Retrieve the user and other parameters from the CAS response.
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 2afb390a..25169157 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Optional
from synapse.api.errors import SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -29,6 +30,7 @@ class DeactivateAccountHandler(BaseHandler):
def __init__(self, hs):
super(DeactivateAccountHandler, self).__init__(hs)
+ self.hs = hs
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
self._room_member_handler = hs.get_room_member_handler()
@@ -40,23 +42,25 @@ class DeactivateAccountHandler(BaseHandler):
# Start the user parter loop so it can resume parting users from rooms where
# it left off (if it has work left to do).
- hs.get_reactor().callWhenRunning(self._start_user_parting)
+ if hs.config.worker_app is None:
+ hs.get_reactor().callWhenRunning(self._start_user_parting)
self._account_validity_enabled = hs.config.account_validity.enabled
- async def deactivate_account(self, user_id, erase_data, id_server=None):
+ async def deactivate_account(
+ self, user_id: str, erase_data: bool, id_server: Optional[str] = None
+ ) -> bool:
"""Deactivate a user's account
Args:
- user_id (str): ID of user to be deactivated
- erase_data (bool): whether to GDPR-erase the user's data
- id_server (str|None): Use the given identity server when unbinding
+ user_id: ID of user to be deactivated
+ erase_data: whether to GDPR-erase the user's data
+ id_server: Use the given identity server when unbinding
any threepids. If None then will attempt to unbind using the
identity server specified when binding (if known).
Returns:
- Deferred[bool]: True if identity server supports removing
- threepids, otherwise False.
+ True if identity server supports removing threepids, otherwise False.
"""
# FIXME: Theoretically there is a race here wherein user resets
# password using threepid.
@@ -133,11 +137,11 @@ class DeactivateAccountHandler(BaseHandler):
return identity_server_supports_unbinding
- async def _reject_pending_invites_for_user(self, user_id):
+ async def _reject_pending_invites_for_user(self, user_id: str):
"""Reject pending invites addressed to a given user ID.
Args:
- user_id (str): The user ID to reject pending invites for.
+ user_id: The user ID to reject pending invites for.
"""
user = UserID.from_string(user_id)
pending_invites = await self.store.get_invited_rooms_for_local_user(user_id)
@@ -165,22 +169,16 @@ class DeactivateAccountHandler(BaseHandler):
room.room_id,
)
- def _start_user_parting(self):
+ def _start_user_parting(self) -> None:
"""
Start the process that goes through the table of users
pending deactivation, if it isn't already running.
-
- Returns:
- None
"""
if not self._user_parter_running:
run_as_background_process("user_parter_loop", self._user_parter_loop)
- async def _user_parter_loop(self):
+ async def _user_parter_loop(self) -> None:
"""Loop that parts deactivated users from rooms
-
- Returns:
- None
"""
self._user_parter_running = True
logger.info("Starting user parter")
@@ -197,11 +195,8 @@ class DeactivateAccountHandler(BaseHandler):
finally:
self._user_parter_running = False
- async def _part_user(self, user_id):
+ async def _part_user(self, user_id: str) -> None:
"""Causes the given user_id to leave all the rooms they're joined to
-
- Returns:
- None
"""
user = UserID.from_string(user_id)
@@ -223,3 +218,31 @@ class DeactivateAccountHandler(BaseHandler):
user_id,
room_id,
)
+
+ async def activate_account(self, user_id: str) -> None:
+ """
+ Activate an account that was previously deactivated.
+
+ This marks the user as active and not erased in the database, but does
+ not attempt to rejoin rooms, re-add threepids, etc.
+
+ If enabled, the user will be re-added to the user directory.
+
+ The user will also need a password hash set to actually log in.
+
+ Args:
+ user_id: ID of user to be re-activated
+ """
+ # Add the user to the directory, if necessary.
+ user = UserID.from_string(user_id)
+ if self.hs.config.user_directory_search_all_users:
+ profile = await self.store.get_profileinfo(user.localpart)
+ await self.user_directory_handler.handle_local_profile_change(
+ user_id, profile
+ )
+
+ # Ensure the user is not marked as erased.
+ await self.store.mark_user_not_erased(user_id)
+
+ # Mark the user as active.
+ await self.store.set_user_deactivated_status(user_id, False)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 31346b56..db417d60 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -15,9 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Dict, Optional
-
-from twisted.internet import defer
+from typing import Any, Dict, List, Optional
from synapse.api import errors
from synapse.api.constants import EventTypes
@@ -57,21 +55,20 @@ class DeviceWorkerHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler()
@trace
- @defer.inlineCallbacks
- def get_devices_by_user(self, user_id):
+ async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]:
"""
Retrieve the given user's devices
Args:
- user_id (str):
+ user_id: The user ID to query for devices.
Returns:
- defer.Deferred: list[dict[str, X]]: info on each device
+ info on each device
"""
set_tag("user_id", user_id)
- device_map = yield self.store.get_devices_by_user(user_id)
+ device_map = await self.store.get_devices_by_user(user_id)
- ips = yield self.store.get_last_client_ip_by_device(user_id, device_id=None)
+ ips = await self.store.get_last_client_ip_by_device(user_id, device_id=None)
devices = list(device_map.values())
for device in devices:
@@ -81,24 +78,23 @@ class DeviceWorkerHandler(BaseHandler):
return devices
@trace
- @defer.inlineCallbacks
- def get_device(self, user_id, device_id):
+ async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
""" Retrieve the given device
Args:
- user_id (str):
- device_id (str):
+ user_id: The user to get the device from
+ device_id: The device to fetch.
Returns:
- defer.Deferred: dict[str, X]: info on the device
+ info on the device
Raises:
errors.NotFoundError: if the device was not found
"""
try:
- device = yield self.store.get_device(user_id, device_id)
+ device = await self.store.get_device(user_id, device_id)
except errors.StoreError:
raise errors.NotFoundError
- ips = yield self.store.get_last_client_ip_by_device(user_id, device_id)
+ ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
_update_device_from_client_ips(device, ips)
set_tag("device", device)
@@ -106,10 +102,9 @@ class DeviceWorkerHandler(BaseHandler):
return device
- @measure_func("device.get_user_ids_changed")
@trace
- @defer.inlineCallbacks
- def get_user_ids_changed(self, user_id, from_token):
+ @measure_func("device.get_user_ids_changed")
+ async def get_user_ids_changed(self, user_id, from_token):
"""Get list of users that have had the devices updated, or have newly
joined a room, that `user_id` may be interested in.
@@ -120,13 +115,13 @@ class DeviceWorkerHandler(BaseHandler):
set_tag("user_id", user_id)
set_tag("from_token", from_token)
- now_room_key = yield self.store.get_room_events_max_id()
+ now_room_key = await self.store.get_room_events_max_id()
- room_ids = yield self.store.get_rooms_for_user(user_id)
+ room_ids = await self.store.get_rooms_for_user(user_id)
# First we check if any devices have changed for users that we share
# rooms with.
- users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+ users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id
)
@@ -135,14 +130,14 @@ class DeviceWorkerHandler(BaseHandler):
# Always tell the user about their own devices
tracked_users.add(user_id)
- changed = yield self.store.get_users_whose_devices_changed(
+ changed = await self.store.get_users_whose_devices_changed(
from_token.device_list_key, tracked_users
)
# Then work out if any users have since joined
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
- member_events = yield self.store.get_membership_changes_for_user(
+ member_events = await self.store.get_membership_changes_for_user(
user_id, from_token.room_key, now_room_key
)
rooms_changed.update(event.room_id for event in member_events)
@@ -152,7 +147,7 @@ class DeviceWorkerHandler(BaseHandler):
possibly_changed = set(changed)
possibly_left = set()
for room_id in rooms_changed:
- current_state_ids = yield self.store.get_current_state_ids(room_id)
+ current_state_ids = await self.store.get_current_state_ids(room_id)
# The user may have left the room
# TODO: Check if they actually did or if we were just invited.
@@ -166,7 +161,7 @@ class DeviceWorkerHandler(BaseHandler):
# Fetch the current state at the time.
try:
- event_ids = yield self.store.get_forward_extremeties_for_room(
+ event_ids = await self.store.get_forward_extremeties_for_room(
room_id, stream_ordering=stream_ordering
)
except errors.StoreError:
@@ -192,7 +187,7 @@ class DeviceWorkerHandler(BaseHandler):
continue
# mapping from event_id -> state_dict
- prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids)
+ prev_state_ids = await self.state_store.get_state_ids_for_events(event_ids)
# Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users.
@@ -238,11 +233,10 @@ class DeviceWorkerHandler(BaseHandler):
return result
- @defer.inlineCallbacks
- def on_federation_query_user_devices(self, user_id):
- stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
- master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
- self_signing_key = yield self.store.get_e2e_cross_signing_key(
+ async def on_federation_query_user_devices(self, user_id):
+ stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id)
+ master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
+ self_signing_key = await self.store.get_e2e_cross_signing_key(
user_id, "self_signing"
)
@@ -271,8 +265,7 @@ class DeviceHandler(DeviceWorkerHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
- @defer.inlineCallbacks
- def check_device_registered(
+ async def check_device_registered(
self, user_id, device_id, initial_device_display_name=None
):
"""
@@ -290,13 +283,13 @@ class DeviceHandler(DeviceWorkerHandler):
str: device id (generated if none was supplied)
"""
if device_id is not None:
- new_device = yield self.store.store_device(
+ new_device = await self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
)
if new_device:
- yield self.notify_device_update(user_id, [device_id])
+ await self.notify_device_update(user_id, [device_id])
return device_id
# if the device id is not specified, we'll autogen one, but loop a few
@@ -304,33 +297,29 @@ class DeviceHandler(DeviceWorkerHandler):
attempts = 0
while attempts < 5:
device_id = stringutils.random_string(10).upper()
- new_device = yield self.store.store_device(
+ new_device = await self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
)
if new_device:
- yield self.notify_device_update(user_id, [device_id])
+ await self.notify_device_update(user_id, [device_id])
return device_id
attempts += 1
raise errors.StoreError(500, "Couldn't generate a device ID.")
@trace
- @defer.inlineCallbacks
- def delete_device(self, user_id, device_id):
+ async def delete_device(self, user_id: str, device_id: str) -> None:
""" Delete the given device
Args:
- user_id (str):
- device_id (str):
-
- Returns:
- defer.Deferred:
+ user_id: The user to delete the device from.
+ device_id: The device to delete.
"""
try:
- yield self.store.delete_device(user_id, device_id)
+ await self.store.delete_device(user_id, device_id)
except errors.StoreError as e:
if e.code == 404:
# no match
@@ -342,49 +331,40 @@ class DeviceHandler(DeviceWorkerHandler):
else:
raise
- yield defer.ensureDeferred(
- self._auth_handler.delete_access_tokens_for_user(
- user_id, device_id=device_id
- )
+ await self._auth_handler.delete_access_tokens_for_user(
+ user_id, device_id=device_id
)
- yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
+ await self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
- yield self.notify_device_update(user_id, [device_id])
+ await self.notify_device_update(user_id, [device_id])
@trace
- @defer.inlineCallbacks
- def delete_all_devices_for_user(self, user_id, except_device_id=None):
+ async def delete_all_devices_for_user(
+ self, user_id: str, except_device_id: Optional[str] = None
+ ) -> None:
"""Delete all of the user's devices
Args:
- user_id (str):
- except_device_id (str|None): optional device id which should not
- be deleted
-
- Returns:
- defer.Deferred:
+ user_id: The user to remove all devices from
+ except_device_id: optional device id which should not be deleted
"""
- device_map = yield self.store.get_devices_by_user(user_id)
+ device_map = await self.store.get_devices_by_user(user_id)
device_ids = list(device_map)
if except_device_id is not None:
device_ids = [d for d in device_ids if d != except_device_id]
- yield self.delete_devices(user_id, device_ids)
+ await self.delete_devices(user_id, device_ids)
- @defer.inlineCallbacks
- def delete_devices(self, user_id, device_ids):
+ async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
""" Delete several devices
Args:
- user_id (str):
- device_ids (List[str]): The list of device IDs to delete
-
- Returns:
- defer.Deferred:
+ user_id: The user to delete devices from.
+ device_ids: The list of device IDs to delete
"""
try:
- yield self.store.delete_devices(user_id, device_ids)
+ await self.store.delete_devices(user_id, device_ids)
except errors.StoreError as e:
if e.code == 404:
# no match
@@ -397,28 +377,22 @@ class DeviceHandler(DeviceWorkerHandler):
# Delete access tokens and e2e keys for each device. Not optimised as it is not
# considered as part of a critical path.
for device_id in device_ids:
- yield defer.ensureDeferred(
- self._auth_handler.delete_access_tokens_for_user(
- user_id, device_id=device_id
- )
+ await self._auth_handler.delete_access_tokens_for_user(
+ user_id, device_id=device_id
)
- yield self.store.delete_e2e_keys_by_device(
+ await self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
)
- yield self.notify_device_update(user_id, device_ids)
+ await self.notify_device_update(user_id, device_ids)
- @defer.inlineCallbacks
- def update_device(self, user_id, device_id, content):
+ async def update_device(self, user_id: str, device_id: str, content: dict) -> None:
""" Update the given device
Args:
- user_id (str):
- device_id (str):
- content (dict): body of update request
-
- Returns:
- defer.Deferred:
+ user_id: The user to update devices of.
+ device_id: The device to update.
+ content: body of update request
"""
# Reject a new displayname which is too long.
@@ -431,10 +405,10 @@ class DeviceHandler(DeviceWorkerHandler):
)
try:
- yield self.store.update_device(
+ await self.store.update_device(
user_id, device_id, new_display_name=new_display_name
)
- yield self.notify_device_update(user_id, [device_id])
+ await self.notify_device_update(user_id, [device_id])
except errors.StoreError as e:
if e.code == 404:
raise errors.NotFoundError()
@@ -443,12 +417,15 @@ class DeviceHandler(DeviceWorkerHandler):
@trace
@measure_func("notify_device_update")
- @defer.inlineCallbacks
- def notify_device_update(self, user_id, device_ids):
+ async def notify_device_update(self, user_id, device_ids):
"""Notify that a user's device(s) has changed. Pokes the notifier, and
remote servers if the user is local.
"""
- users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+ if not device_ids:
+ # No changes to notify about, so this is a no-op.
+ return
+
+ users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id
)
@@ -459,20 +436,24 @@ class DeviceHandler(DeviceWorkerHandler):
set_tag("target_hosts", hosts)
- position = yield self.store.add_device_change_to_streams(
+ position = await self.store.add_device_change_to_streams(
user_id, device_ids, list(hosts)
)
+ if not position:
+ # This should only happen if there are no updates, so we bail.
+ return
+
for device_id in device_ids:
logger.debug(
"Notifying about update %r/%r, ID: %r", user_id, device_id, position
)
- room_ids = yield self.store.get_rooms_for_user(user_id)
+ room_ids = await self.store.get_rooms_for_user(user_id)
# specify the user ID too since the user should always get their own device list
# updates, even if they aren't in any rooms.
- yield self.notifier.on_new_event(
+ self.notifier.on_new_event(
"device_list_key", position, users=[user_id], rooms=room_ids
)
@@ -484,29 +465,29 @@ class DeviceHandler(DeviceWorkerHandler):
self.federation_sender.send_device_messages(host)
log_kv({"message": "sent device update to host", "host": host})
- @defer.inlineCallbacks
- def notify_user_signature_update(self, from_user_id, user_ids):
+ async def notify_user_signature_update(
+ self, from_user_id: str, user_ids: List[str]
+ ) -> None:
"""Notify a user that they have made new signatures of other users.
Args:
- from_user_id (str): the user who made the signature
- user_ids (list[str]): the users IDs that have new signatures
+ from_user_id: the user who made the signature
+ user_ids: the users IDs that have new signatures
"""
- position = yield self.store.add_user_signature_change_to_streams(
+ position = await self.store.add_user_signature_change_to_streams(
from_user_id, user_ids
)
self.notifier.on_new_event("device_list_key", position, users=[from_user_id])
- @defer.inlineCallbacks
- def user_left_room(self, user, room_id):
+ async def user_left_room(self, user, room_id):
user_id = user.to_string()
- room_ids = yield self.store.get_rooms_for_user(user_id)
+ room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
# We no longer share rooms with this user, so we'll no longer
# receive device updates. Mark this in DB.
- yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
+ await self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
def _update_device_from_client_ips(device, client_ips):
@@ -549,8 +530,7 @@ class DeviceListUpdater(object):
)
@trace
- @defer.inlineCallbacks
- def incoming_device_list_update(self, origin, edu_content):
+ async def incoming_device_list_update(self, origin, edu_content):
"""Called on incoming device list update from federation. Responsible
for parsing the EDU and adding to pending updates list.
"""
@@ -583,7 +563,7 @@ class DeviceListUpdater(object):
)
return
- room_ids = yield self.store.get_rooms_for_user(user_id)
+ room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
# We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates.
@@ -608,14 +588,13 @@ class DeviceListUpdater(object):
(device_id, stream_id, prev_ids, edu_content)
)
- yield self._handle_device_updates(user_id)
+ await self._handle_device_updates(user_id)
@measure_func("_incoming_device_list_update")
- @defer.inlineCallbacks
- def _handle_device_updates(self, user_id):
+ async def _handle_device_updates(self, user_id):
"Actually handle pending updates."
- with (yield self._remote_edu_linearizer.queue(user_id)):
+ with (await self._remote_edu_linearizer.queue(user_id)):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
# This can happen since we batch updates
@@ -632,7 +611,7 @@ class DeviceListUpdater(object):
# Given a list of updates we check if we need to resync. This
# happens if we've missed updates.
- resync = yield self._need_to_do_resync(user_id, pending_updates)
+ resync = await self._need_to_do_resync(user_id, pending_updates)
if logger.isEnabledFor(logging.INFO):
logger.info(
@@ -643,16 +622,16 @@ class DeviceListUpdater(object):
)
if resync:
- yield self.user_device_resync(user_id)
+ await self.user_device_resync(user_id)
else:
# Simply update the single device, since we know that is the only
# change (because of the single prev_id matching the current cache)
for device_id, stream_id, prev_ids, content in pending_updates:
- yield self.store.update_remote_device_list_cache_entry(
+ await self.store.update_remote_device_list_cache_entry(
user_id, device_id, content, stream_id
)
- yield self.device_handler.notify_device_update(
+ await self.device_handler.notify_device_update(
user_id, [device_id for device_id, _, _, _ in pending_updates]
)
@@ -660,14 +639,13 @@ class DeviceListUpdater(object):
stream_id for _, stream_id, _, _ in pending_updates
)
- @defer.inlineCallbacks
- def _need_to_do_resync(self, user_id, updates):
+ async def _need_to_do_resync(self, user_id, updates):
"""Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta.
"""
seen_updates = self._seen_updates.get(user_id, set())
- extremity = yield self.store.get_device_list_last_stream_id_for_remote(user_id)
+ extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
logger.debug("Current extremity for %r: %r", user_id, extremity)
@@ -692,8 +670,7 @@ class DeviceListUpdater(object):
return False
@trace
- @defer.inlineCallbacks
- def _maybe_retry_device_resync(self):
+ async def _maybe_retry_device_resync(self):
"""Retry to resync device lists that are out of sync, except if another retry is
in progress.
"""
@@ -705,12 +682,12 @@ class DeviceListUpdater(object):
# we don't send too many requests.
self._resync_retry_in_progress = True
# Get all of the users that need resyncing.
- need_resync = yield self.store.get_user_ids_requiring_device_list_resync()
+ need_resync = await self.store.get_user_ids_requiring_device_list_resync()
# Iterate over the set of user IDs.
for user_id in need_resync:
try:
# Try to resync the current user's devices list.
- result = yield self.user_device_resync(
+ result = await self.user_device_resync(
user_id=user_id, mark_failed_as_stale=False,
)
@@ -734,16 +711,17 @@ class DeviceListUpdater(object):
# Allow future calls to retry resyncing out-of-sync device lists.
self._resync_retry_in_progress = False
- @defer.inlineCallbacks
- def user_device_resync(self, user_id, mark_failed_as_stale=True):
+ async def user_device_resync(
+ self, user_id: str, mark_failed_as_stale: bool = True
+ ) -> Optional[dict]:
"""Fetches all devices for a user and updates the device cache with them.
Args:
- user_id (str): The user's id whose device_list will be updated.
- mark_failed_as_stale (bool): Whether to mark the user's device list as stale
+ user_id: The user's id whose device_list will be updated.
+ mark_failed_as_stale: Whether to mark the user's device list as stale
if the attempt to resync failed.
Returns:
- Deferred[dict]: a dict with device info as under the "devices" in the result of this
+ A dict with device info, as returned under the "devices" key in the result of this
request:
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
"""
@@ -752,12 +730,12 @@ class DeviceListUpdater(object):
# Fetch all devices for the user.
origin = get_domain_from_id(user_id)
try:
- result = yield self.federation.query_user_devices(origin, user_id)
+ result = await self.federation.query_user_devices(origin, user_id)
except NotRetryingDestination:
if mark_failed_as_stale:
# Mark the remote user's device list as stale so we know we need to retry
# it later.
- yield self.store.mark_remote_user_device_cache_as_stale(user_id)
+ await self.store.mark_remote_user_device_cache_as_stale(user_id)
return
except (RequestSendFailed, HttpResponseException) as e:
@@ -768,7 +746,7 @@ class DeviceListUpdater(object):
if mark_failed_as_stale:
# Mark the remote user's device list as stale so we know we need to retry
# it later.
- yield self.store.mark_remote_user_device_cache_as_stale(user_id)
+ await self.store.mark_remote_user_device_cache_as_stale(user_id)
# We abort on exceptions rather than accepting the update
# as otherwise synapse will 'forget' that its device list
@@ -792,7 +770,7 @@ class DeviceListUpdater(object):
if mark_failed_as_stale:
# Mark the remote user's device list as stale so we know we need to retry
# it later.
- yield self.store.mark_remote_user_device_cache_as_stale(user_id)
+ await self.store.mark_remote_user_device_cache_as_stale(user_id)
return
log_kv({"result": result})
@@ -833,25 +811,24 @@ class DeviceListUpdater(object):
stream_id,
)
- yield self.store.update_remote_device_list_cache(user_id, devices, stream_id)
+ await self.store.update_remote_device_list_cache(user_id, devices, stream_id)
device_ids = [device["device_id"] for device in devices]
# Handle cross-signing keys.
- cross_signing_device_ids = yield self.process_cross_signing_key_update(
+ cross_signing_device_ids = await self.process_cross_signing_key_update(
user_id, master_key, self_signing_key,
)
device_ids = device_ids + cross_signing_device_ids
- yield self.device_handler.notify_device_update(user_id, device_ids)
+ await self.device_handler.notify_device_update(user_id, device_ids)
# We clobber the seen updates since we've re-synced from a given
# point.
self._seen_updates[user_id] = {stream_id}
- defer.returnValue(result)
+ return result
- @defer.inlineCallbacks
- def process_cross_signing_key_update(
+ async def process_cross_signing_key_update(
self,
user_id: str,
master_key: Optional[Dict[str, Any]],
@@ -872,14 +849,14 @@ class DeviceListUpdater(object):
device_ids = []
if master_key:
- yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
+ await self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
_, verify_key = get_verify_key_from_cross_signing_key(master_key)
# verify_key is a VerifyKey from signedjson, which uses
# .version to denote the portion of the key ID after the
# algorithm and colon, which is the device ID
device_ids.append(verify_key.version)
if self_signing_key:
- yield self.store.set_e2e_cross_signing_key(
+ await self.store.set_e2e_cross_signing_key(
user_id, "self_signing", self_signing_key
)
_, verify_key = get_verify_key_from_cross_signing_key(self_signing_key)
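_handle_device_updates above serialises work per user with a Linearizer: batches for the same user_id queue behind each other while different users proceed in parallel. A minimal sketch of that idiom (the pre-Python-3.10 `with (await ...)` form matches the code above; pending_updates is a stand-in for the handler's attribute):

    from synapse.util.async_helpers import Linearizer

    linearizer = Linearizer(name="remote_device_updates")
    pending_updates = {}  # user_id -> list of queued updates

    async def handle_updates(user_id: str) -> None:
        # Only one batch per user runs at a time; concurrent callers for
        # the same user wait here, then find pending_updates drained.
        with (await linearizer.queue(user_id)):
            updates = pending_updates.pop(user_id, [])
            for update in updates:
                ...  # apply each update in stream order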
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index a7e60cbc..84169c10 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -16,10 +16,11 @@
# limitations under the License.
import logging
+from typing import Dict, List, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json, json
-from signedjson.key import decode_verify_key_bytes
+from signedjson.key import VerifyKey, decode_verify_key_bytes
from signedjson.sign import SignatureVerifyException, verify_signed_json
from unpaddedbase64 import decode_base64
@@ -77,8 +78,7 @@ class E2eKeysHandler(object):
)
@trace
- @defer.inlineCallbacks
- def query_devices(self, query_body, timeout, from_user_id):
+ async def query_devices(self, query_body, timeout, from_user_id):
""" Handle a device key query from a client
{
@@ -124,7 +124,7 @@ class E2eKeysHandler(object):
failures = {}
results = {}
if local_query:
- local_result = yield self.query_local_devices(local_query)
+ local_result = await self.query_local_devices(local_query)
for user_id, keys in local_result.items():
if user_id in local_query:
results[user_id] = keys
@@ -142,7 +142,7 @@ class E2eKeysHandler(object):
(
user_ids_not_in_cache,
remote_results,
- ) = yield self.store.get_user_devices_from_cache(query_list)
+ ) = await self.store.get_user_devices_from_cache(query_list)
for user_id, devices in remote_results.items():
user_devices = results.setdefault(user_id, {})
for device_id, device in devices.items():
@@ -161,14 +161,13 @@ class E2eKeysHandler(object):
r[user_id] = remote_queries[user_id]
# Get cached cross-signing keys
- cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
+ cross_signing_keys = await self.get_cross_signing_keys_from_cache(
device_keys_query, from_user_id
)
# Now fetch any devices that we don't have in our cache
@trace
- @defer.inlineCallbacks
- def do_remote_query(destination):
+ async def do_remote_query(destination):
"""This is called when we are querying the device list of a user on
a remote homeserver and their device list is not in the device list
cache. If we share a room with this user and we're not querying for
@@ -192,7 +191,7 @@ class E2eKeysHandler(object):
if device_list:
continue
- room_ids = yield self.store.get_rooms_for_user(user_id)
+ room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
continue
@@ -201,11 +200,11 @@ class E2eKeysHandler(object):
# done an initial sync on the device list so we do it now.
try:
if self._is_master:
- user_devices = yield self.device_handler.device_list_updater.user_device_resync(
+ user_devices = await self.device_handler.device_list_updater.user_device_resync(
user_id
)
else:
- user_devices = yield self._user_device_resync_client(
+ user_devices = await self._user_device_resync_client(
user_id=user_id
)
@@ -227,7 +226,7 @@ class E2eKeysHandler(object):
destination_query.pop(user_id)
try:
- remote_result = yield self.federation.query_client_keys(
+ remote_result = await self.federation.query_client_keys(
destination, {"device_keys": destination_query}, timeout=timeout
)
@@ -251,7 +250,7 @@ class E2eKeysHandler(object):
set_tag("error", True)
set_tag("reason", failure)
- yield make_deferred_yieldable(
+ await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(do_remote_query, destination)
@@ -267,8 +266,9 @@ class E2eKeysHandler(object):
return ret
- @defer.inlineCallbacks
- def get_cross_signing_keys_from_cache(self, query, from_user_id):
+ async def get_cross_signing_keys_from_cache(
+ self, query, from_user_id
+ ) -> Dict[str, Dict[str, dict]]:
"""Get cross-signing keys for users from the database
Args:
@@ -280,8 +280,7 @@ class E2eKeysHandler(object):
can see.
Returns:
- defer.Deferred[dict[str, dict[str, dict]]]: map from
- (master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
+ A map from (master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
"""
master_keys = {}
self_signing_keys = {}
@@ -289,7 +288,7 @@ class E2eKeysHandler(object):
user_ids = list(query)
- keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id)
+ keys = await self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id)
for user_id, user_info in keys.items():
if user_info is None:
@@ -315,17 +314,17 @@ class E2eKeysHandler(object):
}
@trace
- @defer.inlineCallbacks
- def query_local_devices(self, query):
+ async def query_local_devices(
+ self, query: Dict[str, Optional[List[str]]]
+ ) -> Dict[str, Dict[str, dict]]:
"""Get E2E device keys for local users
Args:
- query (dict[string, list[string]|None): map from user_id to a list
+ query: map from user_id to a list
of devices to query (None for all devices)
Returns:
- defer.Deferred: (resolves to dict[string, dict[string, dict]]):
- map from user_id -> device_id -> device details
+ A map from user_id -> device_id -> device details
"""
set_tag("local_query", query)
local_query = []
@@ -354,7 +353,7 @@ class E2eKeysHandler(object):
# make sure that each queried user appears in the result dict
result_dict[user_id] = {}
- results = yield self.store.get_e2e_device_keys(local_query)
+ results = await self.store.get_e2e_device_keys(local_query)
# Build the result structure
for user_id, device_keys in results.items():
@@ -364,16 +363,15 @@ class E2eKeysHandler(object):
log_kv(results)
return result_dict
- @defer.inlineCallbacks
- def on_federation_query_client_keys(self, query_body):
+ async def on_federation_query_client_keys(self, query_body):
""" Handle a device key query from a federated server
"""
device_keys_query = query_body.get("device_keys", {})
- res = yield self.query_local_devices(device_keys_query)
+ res = await self.query_local_devices(device_keys_query)
ret = {"device_keys": res}
# add in the cross-signing keys
- cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
+ cross_signing_keys = await self.get_cross_signing_keys_from_cache(
device_keys_query, None
)
@@ -382,8 +380,7 @@ class E2eKeysHandler(object):
return ret
@trace
- @defer.inlineCallbacks
- def claim_one_time_keys(self, query, timeout):
+ async def claim_one_time_keys(self, query, timeout):
local_query = []
remote_queries = {}
@@ -399,7 +396,7 @@ class E2eKeysHandler(object):
set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries)
- results = yield self.store.claim_e2e_one_time_keys(local_query)
+ results = await self.store.claim_e2e_one_time_keys(local_query)
json_result = {}
failures = {}
@@ -411,12 +408,11 @@ class E2eKeysHandler(object):
}
@trace
- @defer.inlineCallbacks
- def claim_client_keys(destination):
+ async def claim_client_keys(destination):
set_tag("destination", destination)
device_keys = remote_queries[destination]
try:
- remote_result = yield self.federation.claim_client_keys(
+ remote_result = await self.federation.claim_client_keys(
destination, {"one_time_keys": device_keys}, timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
@@ -429,7 +425,7 @@ class E2eKeysHandler(object):
set_tag("error", True)
set_tag("reason", failure)
- yield make_deferred_yieldable(
+ await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(claim_client_keys, destination)
@@ -454,9 +450,8 @@ class E2eKeysHandler(object):
log_kv({"one_time_keys": json_result, "failures": failures})
return {"one_time_keys": json_result, "failures": failures}
- @defer.inlineCallbacks
@tag_args
- def upload_keys_for_user(self, user_id, device_id, keys):
+ async def upload_keys_for_user(self, user_id, device_id, keys):
time_now = self.clock.time_msec()
@@ -477,12 +472,12 @@ class E2eKeysHandler(object):
}
)
# TODO: Sign the JSON with the server key
- changed = yield self.store.set_e2e_device_keys(
+ changed = await self.store.set_e2e_device_keys(
user_id, device_id, time_now, device_keys
)
if changed:
# Only notify about device updates *if* the keys actually changed
- yield self.device_handler.notify_device_update(user_id, [device_id])
+ await self.device_handler.notify_device_update(user_id, [device_id])
else:
log_kv({"message": "Not updating device_keys for user", "user_id": user_id})
one_time_keys = keys.get("one_time_keys", None)
@@ -494,7 +489,7 @@ class E2eKeysHandler(object):
"device_id": device_id,
}
)
- yield self._upload_one_time_keys_for_user(
+ await self._upload_one_time_keys_for_user(
user_id, device_id, time_now, one_time_keys
)
else:
@@ -507,15 +502,14 @@ class E2eKeysHandler(object):
# old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with
# keys without a corresponding device.
- yield self.device_handler.check_device_registered(user_id, device_id)
+ await self.device_handler.check_device_registered(user_id, device_id)
- result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
+ result = await self.store.count_e2e_one_time_keys(user_id, device_id)
set_tag("one_time_key_counts", result)
return {"one_time_key_counts": result}
- @defer.inlineCallbacks
- def _upload_one_time_keys_for_user(
+ async def _upload_one_time_keys_for_user(
self, user_id, device_id, time_now, one_time_keys
):
logger.info(
@@ -533,7 +527,7 @@ class E2eKeysHandler(object):
key_list.append((algorithm, key_id, key_obj))
# First we check if we have already persisted any of the keys.
- existing_key_map = yield self.store.get_e2e_one_time_keys(
+ existing_key_map = await self.store.get_e2e_one_time_keys(
user_id, device_id, [k_id for _, k_id, _ in key_list]
)
@@ -556,10 +550,9 @@ class E2eKeysHandler(object):
)
log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
- yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
+ await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
- @defer.inlineCallbacks
- def upload_signing_keys_for_user(self, user_id, keys):
+ async def upload_signing_keys_for_user(self, user_id, keys):
"""Upload signing keys for cross-signing
Args:
@@ -574,7 +567,7 @@ class E2eKeysHandler(object):
_check_cross_signing_key(master_key, user_id, "master")
else:
- master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
+ master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
# if there is no master key, then we can't do anything, because all the
# other cross-signing keys need to be signed by the master key
@@ -613,10 +606,10 @@ class E2eKeysHandler(object):
# if everything checks out, then store the keys and send notifications
deviceids = []
if "master_key" in keys:
- yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
+ await self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
deviceids.append(master_verify_key.version)
if "self_signing_key" in keys:
- yield self.store.set_e2e_cross_signing_key(
+ await self.store.set_e2e_cross_signing_key(
user_id, "self_signing", self_signing_key
)
try:
@@ -626,23 +619,22 @@ class E2eKeysHandler(object):
except ValueError:
raise SynapseError(400, "Invalid self-signing key", Codes.INVALID_PARAM)
if "user_signing_key" in keys:
- yield self.store.set_e2e_cross_signing_key(
+ await self.store.set_e2e_cross_signing_key(
user_id, "user_signing", user_signing_key
)
# the signature stream matches the semantics that we want for
# user-signing key updates: only the user themselves is notified of
# their own user-signing key updates
- yield self.device_handler.notify_user_signature_update(user_id, [user_id])
+ await self.device_handler.notify_user_signature_update(user_id, [user_id])
# master key and self-signing key updates match the semantics of device
# list updates: all users who share an encrypted room are notified
if len(deviceids):
- yield self.device_handler.notify_device_update(user_id, deviceids)
+ await self.device_handler.notify_device_update(user_id, deviceids)
return {}
- @defer.inlineCallbacks
- def upload_signatures_for_device_keys(self, user_id, signatures):
+ async def upload_signatures_for_device_keys(self, user_id, signatures):
"""Upload device signatures for cross-signing
Args:
@@ -667,13 +659,13 @@ class E2eKeysHandler(object):
self_signatures = signatures.get(user_id, {})
other_signatures = {k: v for k, v in signatures.items() if k != user_id}
- self_signature_list, self_failures = yield self._process_self_signatures(
+ self_signature_list, self_failures = await self._process_self_signatures(
user_id, self_signatures
)
signature_list.extend(self_signature_list)
failures.update(self_failures)
- other_signature_list, other_failures = yield self._process_other_signatures(
+ other_signature_list, other_failures = await self._process_other_signatures(
user_id, other_signatures
)
signature_list.extend(other_signature_list)
@@ -681,21 +673,20 @@ class E2eKeysHandler(object):
# store the signature, and send the appropriate notifications for sync
logger.debug("upload signature failures: %r", failures)
- yield self.store.store_e2e_cross_signing_signatures(user_id, signature_list)
+ await self.store.store_e2e_cross_signing_signatures(user_id, signature_list)
self_device_ids = [item.target_device_id for item in self_signature_list]
if self_device_ids:
- yield self.device_handler.notify_device_update(user_id, self_device_ids)
+ await self.device_handler.notify_device_update(user_id, self_device_ids)
signed_users = [item.target_user_id for item in other_signature_list]
if signed_users:
- yield self.device_handler.notify_user_signature_update(
+ await self.device_handler.notify_user_signature_update(
user_id, signed_users
)
return {"failures": failures}
- @defer.inlineCallbacks
- def _process_self_signatures(self, user_id, signatures):
+ async def _process_self_signatures(self, user_id, signatures):
"""Process uploaded signatures of the user's own keys.
Signatures of the user's own keys from this API come in two forms:
@@ -728,7 +719,7 @@ class E2eKeysHandler(object):
_,
self_signing_key_id,
self_signing_verify_key,
- ) = yield self._get_e2e_cross_signing_verify_key(user_id, "self_signing")
+ ) = await self._get_e2e_cross_signing_verify_key(user_id, "self_signing")
# get our master key, since we may have received a signature of it.
# We need to fetch it here so that we know what its key ID is, so
@@ -738,12 +729,12 @@ class E2eKeysHandler(object):
master_key,
_,
master_verify_key,
- ) = yield self._get_e2e_cross_signing_verify_key(user_id, "master")
+ ) = await self._get_e2e_cross_signing_verify_key(user_id, "master")
# fetch our stored devices. This is used to 1. verify
# signatures on the master key, and 2. to compare with what
# was sent if the device was signed
- devices = yield self.store.get_e2e_device_keys([(user_id, None)])
+ devices = await self.store.get_e2e_device_keys([(user_id, None)])
if user_id not in devices:
raise NotFoundError("No device keys found")
@@ -853,8 +844,7 @@ class E2eKeysHandler(object):
return master_key_signature_list
- @defer.inlineCallbacks
- def _process_other_signatures(self, user_id, signatures):
+ async def _process_other_signatures(self, user_id, signatures):
"""Process uploaded signatures of other users' keys. These will be the
target user's master keys, signed by the uploading user's user-signing
key.
@@ -882,7 +872,7 @@ class E2eKeysHandler(object):
user_signing_key,
user_signing_key_id,
user_signing_verify_key,
- ) = yield self._get_e2e_cross_signing_verify_key(user_id, "user_signing")
+ ) = await self._get_e2e_cross_signing_verify_key(user_id, "user_signing")
except SynapseError as e:
failure = _exception_to_failure(e)
for user, devicemap in signatures.items():
@@ -905,7 +895,7 @@ class E2eKeysHandler(object):
master_key,
master_key_id,
_,
- ) = yield self._get_e2e_cross_signing_verify_key(
+ ) = await self._get_e2e_cross_signing_verify_key(
target_user, "master", user_id
)
@@ -958,8 +948,7 @@ class E2eKeysHandler(object):
return signature_list, failures
- @defer.inlineCallbacks
- def _get_e2e_cross_signing_verify_key(
+ async def _get_e2e_cross_signing_verify_key(
self, user_id: str, key_type: str, from_user_id: str = None
):
"""Fetch locally or remotely query for a cross-signing public key.
@@ -983,7 +972,7 @@ class E2eKeysHandler(object):
SynapseError: if `user_id` is invalid
"""
user = UserID.from_string(user_id)
- key = yield self.store.get_e2e_cross_signing_key(
+ key = await self.store.get_e2e_cross_signing_key(
user_id, key_type, from_user_id
)
@@ -1009,17 +998,16 @@ class E2eKeysHandler(object):
key,
key_id,
verify_key,
- ) = yield self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
+ ) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
if key is None:
raise NotFoundError("No %s key found for %s" % (key_type, user_id))
return key, key_id, verify_key
- @defer.inlineCallbacks
- def _retrieve_cross_signing_keys_for_remote_user(
+ async def _retrieve_cross_signing_keys_for_remote_user(
self, user: UserID, desired_key_type: str,
- ):
+ ) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]:
"""Queries cross-signing keys for a remote user and saves them to the database
Only the key specified by `key_type` will be returned, while all retrieved keys
@@ -1030,12 +1018,11 @@ class E2eKeysHandler(object):
desired_key_type: The type of key to receive. One of "master", "self_signing"
Returns:
- Deferred[Tuple[Optional[Dict], Optional[str], Optional[VerifyKey]]]: A tuple
- of the retrieved key content, the key's ID and the matching VerifyKey.
+ A tuple of the retrieved key content, the key's ID and the matching VerifyKey.
If the key cannot be retrieved, all values in the tuple will instead be None.
"""
try:
- remote_result = yield self.federation.query_user_devices(
+ remote_result = await self.federation.query_user_devices(
user.domain, user.to_string()
)
except Exception as e:
@@ -1101,14 +1088,14 @@ class E2eKeysHandler(object):
desired_key_id = key_id
# At the same time, store this key in the db for subsequent queries
- yield self.store.set_e2e_cross_signing_key(
+ await self.store.set_e2e_cross_signing_key(
user.to_string(), key_type, key_content
)
# Notify clients that new devices for this user have been discovered
if retrieved_device_ids:
# XXX is this necessary?
- yield self.device_handler.notify_device_update(
+ await self.device_handler.notify_device_update(
user.to_string(), retrieved_device_ids
)
@@ -1250,8 +1237,7 @@ class SigningKeyEduUpdater(object):
iterable=True,
)
- @defer.inlineCallbacks
- def incoming_signing_key_update(self, origin, edu_content):
+ async def incoming_signing_key_update(self, origin, edu_content):
"""Called on incoming signing key update from federation. Responsible for
parsing the EDU and adding to pending updates list.
@@ -1268,7 +1254,7 @@ class SigningKeyEduUpdater(object):
logger.warning("Got signing key update edu for %r from %r", user_id, origin)
return
- room_ids = yield self.store.get_rooms_for_user(user_id)
+ room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
# We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates.
@@ -1278,10 +1264,9 @@ class SigningKeyEduUpdater(object):
(master_key, self_signing_key)
)
- yield self._handle_signing_key_updates(user_id)
+ await self._handle_signing_key_updates(user_id)
- @defer.inlineCallbacks
- def _handle_signing_key_updates(self, user_id):
+ async def _handle_signing_key_updates(self, user_id):
"""Actually handle pending updates.
Args:
@@ -1291,7 +1276,7 @@ class SigningKeyEduUpdater(object):
device_handler = self.e2e_keys_handler.device_handler
device_list_updater = device_handler.device_list_updater
- with (yield self._remote_edu_linearizer.queue(user_id)):
+ with (await self._remote_edu_linearizer.queue(user_id)):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
# This can happen since we batch updates
@@ -1302,9 +1287,9 @@ class SigningKeyEduUpdater(object):
logger.info("pending updates: %r", pending_updates)
for master_key, self_signing_key in pending_updates:
- new_device_ids = yield device_list_updater.process_cross_signing_key_update(
+ new_device_ids = await device_list_updater.process_cross_signing_key_update(
user_id, master_key, self_signing_key,
)
device_ids = device_ids + new_device_ids
- yield device_handler.notify_device_update(user_id, device_ids)
+ await device_handler.notify_device_update(user_id, device_ids)
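The file above is one instance of the pattern applied throughout this changeset: @defer.inlineCallbacks generators become native coroutines, `yield` on a Deferred becomes `await`, and `Deferred[...]` return types in docstrings become plain types. A minimal sketch of the shape, with a hypothetical handler and store (asyncio stands in for Twisted's reactor here):

    import asyncio

    class Store:
        async def fetch_thing(self, user_id):
            return {"user": user_id}

    class Handler:
        def __init__(self, store):
            self.store = store

        # Previously:
        #   @defer.inlineCallbacks
        #   def get_thing(self, user_id):
        #       thing = yield self.store.fetch_thing(user_id)
        #       return thing
        async def get_thing(self, user_id):
            return await self.store.fetch_thing(user_id)

    print(asyncio.run(Handler(Store()).get_thing("@alice:example.com")))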
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index f55470a7..0bb983dc 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -16,8 +16,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import (
Codes,
NotFoundError,
@@ -50,8 +48,7 @@ class E2eRoomKeysHandler(object):
self._upload_linearizer = Linearizer("upload_room_keys_lock")
@trace
- @defer.inlineCallbacks
- def get_room_keys(self, user_id, version, room_id=None, session_id=None):
+ async def get_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session.
See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
@@ -71,17 +68,17 @@ class E2eRoomKeysHandler(object):
# we deliberately take the lock to get keys so that changing the version
# works atomically
- with (yield self._upload_linearizer.queue(user_id)):
+ with (await self._upload_linearizer.queue(user_id)):
# make sure the backup version exists
try:
- yield self.store.get_e2e_room_keys_version_info(user_id, version)
+ await self.store.get_e2e_room_keys_version_info(user_id, version)
except StoreError as e:
if e.code == 404:
raise NotFoundError("Unknown backup version")
else:
raise
- results = yield self.store.get_e2e_room_keys(
+ results = await self.store.get_e2e_room_keys(
user_id, version, room_id, session_id
)
@@ -89,8 +86,7 @@ class E2eRoomKeysHandler(object):
return results
@trace
- @defer.inlineCallbacks
- def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
+ async def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
room or a given session.
See EndToEndRoomKeyStore.delete_e2e_room_keys for full details.
@@ -109,10 +105,10 @@ class E2eRoomKeysHandler(object):
"""
# lock for consistency with uploading
- with (yield self._upload_linearizer.queue(user_id)):
+ with (await self._upload_linearizer.queue(user_id)):
# make sure the backup version exists
try:
- version_info = yield self.store.get_e2e_room_keys_version_info(
+ version_info = await self.store.get_e2e_room_keys_version_info(
user_id, version
)
except StoreError as e:
@@ -121,19 +117,18 @@ class E2eRoomKeysHandler(object):
else:
raise
- yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id)
+ await self.store.delete_e2e_room_keys(user_id, version, room_id, session_id)
version_etag = version_info["etag"] + 1
- yield self.store.update_e2e_room_keys_version(
+ await self.store.update_e2e_room_keys_version(
user_id, version, None, version_etag
)
- count = yield self.store.count_e2e_room_keys(user_id, version)
+ count = await self.store.count_e2e_room_keys(user_id, version)
return {"etag": str(version_etag), "count": count}
@trace
- @defer.inlineCallbacks
- def upload_room_keys(self, user_id, version, room_keys):
+ async def upload_room_keys(self, user_id, version, room_keys):
"""Bulk upload a list of room keys into a given backup version, asserting
that the given version is the current backup version. room_keys are merged
into the current backup as described in RoomKeysServlet.on_PUT().
@@ -169,11 +164,11 @@ class E2eRoomKeysHandler(object):
# TODO: Validate the JSON to make sure it has the right keys.
# XXX: perhaps we should use a finer grained lock here?
- with (yield self._upload_linearizer.queue(user_id)):
+ with (await self._upload_linearizer.queue(user_id)):
# Check that the version we're trying to upload is the current version
try:
- version_info = yield self.store.get_e2e_room_keys_version_info(user_id)
+ version_info = await self.store.get_e2e_room_keys_version_info(user_id)
except StoreError as e:
if e.code == 404:
raise NotFoundError("Version '%s' not found" % (version,))
@@ -183,7 +178,7 @@ class E2eRoomKeysHandler(object):
if version_info["version"] != version:
# Check that the version we're trying to upload actually exists
try:
- version_info = yield self.store.get_e2e_room_keys_version_info(
+ version_info = await self.store.get_e2e_room_keys_version_info(
user_id, version
)
# if we get this far, the version must exist
@@ -198,7 +193,7 @@ class E2eRoomKeysHandler(object):
# submitted. Then compare them with the submitted keys. If the
# key is new, insert it; if the key should be updated, then update
# it; otherwise, drop it.
- existing_keys = yield self.store.get_e2e_room_keys_multi(
+ existing_keys = await self.store.get_e2e_room_keys_multi(
user_id, version, room_keys["rooms"]
)
to_insert = [] # batch the inserts together
@@ -227,7 +222,7 @@ class E2eRoomKeysHandler(object):
# updates are done one at a time in the DB, so send
# updates right away rather than batching them up,
# like we do with the inserts
- yield self.store.update_e2e_room_key(
+ await self.store.update_e2e_room_key(
user_id, version, room_id, session_id, room_key
)
changed = True
@@ -246,16 +241,16 @@ class E2eRoomKeysHandler(object):
changed = True
if len(to_insert):
- yield self.store.add_e2e_room_keys(user_id, version, to_insert)
+ await self.store.add_e2e_room_keys(user_id, version, to_insert)
version_etag = version_info["etag"]
if changed:
version_etag = version_etag + 1
- yield self.store.update_e2e_room_keys_version(
+ await self.store.update_e2e_room_keys_version(
user_id, version, None, version_etag
)
- count = yield self.store.count_e2e_room_keys(user_id, version)
+ count = await self.store.count_e2e_room_keys(user_id, version)
return {"etag": str(version_etag), "count": count}
@staticmethod
@@ -291,8 +286,7 @@ class E2eRoomKeysHandler(object):
return True
@trace
- @defer.inlineCallbacks
- def create_version(self, user_id, version_info):
+ async def create_version(self, user_id, version_info):
"""Create a new backup version. This automatically becomes the new
backup version for the user's keys; previous backups will no longer be
writeable.
@@ -313,14 +307,13 @@ class E2eRoomKeysHandler(object):
# TODO: Validate the JSON to make sure it has the right keys.
# lock everyone out until we've switched version
- with (yield self._upload_linearizer.queue(user_id)):
- new_version = yield self.store.create_e2e_room_keys_version(
+ with (await self._upload_linearizer.queue(user_id)):
+ new_version = await self.store.create_e2e_room_keys_version(
user_id, version_info
)
return new_version
- @defer.inlineCallbacks
- def get_version_info(self, user_id, version=None):
+ async def get_version_info(self, user_id, version=None):
"""Get the info about a given version of the user's backup
Args:
@@ -339,22 +332,21 @@ class E2eRoomKeysHandler(object):
}
"""
- with (yield self._upload_linearizer.queue(user_id)):
+ with (await self._upload_linearizer.queue(user_id)):
try:
- res = yield self.store.get_e2e_room_keys_version_info(user_id, version)
+ res = await self.store.get_e2e_room_keys_version_info(user_id, version)
except StoreError as e:
if e.code == 404:
raise NotFoundError("Unknown backup version")
else:
raise
- res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"])
+ res["count"] = await self.store.count_e2e_room_keys(user_id, res["version"])
res["etag"] = str(res["etag"])
return res
@trace
- @defer.inlineCallbacks
- def delete_version(self, user_id, version=None):
+ async def delete_version(self, user_id, version=None):
"""Deletes a given version of the user's e2e_room_keys backup
Args:
@@ -364,9 +356,9 @@ class E2eRoomKeysHandler(object):
NotFoundError: if this backup version doesn't exist
"""
- with (yield self._upload_linearizer.queue(user_id)):
+ with (await self._upload_linearizer.queue(user_id)):
try:
- yield self.store.delete_e2e_room_keys_version(user_id, version)
+ await self.store.delete_e2e_room_keys_version(user_id, version)
except StoreError as e:
if e.code == 404:
raise NotFoundError("Unknown backup version")
@@ -374,8 +366,7 @@ class E2eRoomKeysHandler(object):
raise
@trace
- @defer.inlineCallbacks
- def update_version(self, user_id, version, version_info):
+ async def update_version(self, user_id, version, version_info):
"""Update the info about a given version of the user's backup
Args:
@@ -393,9 +384,9 @@ class E2eRoomKeysHandler(object):
raise SynapseError(
400, "Version in body does not match", Codes.INVALID_PARAM
)
- with (yield self._upload_linearizer.queue(user_id)):
+ with (await self._upload_linearizer.queue(user_id)):
try:
- old_info = yield self.store.get_e2e_room_keys_version_info(
+ old_info = await self.store.get_e2e_room_keys_version_info(
user_id, version
)
except StoreError as e:
@@ -406,7 +397,7 @@ class E2eRoomKeysHandler(object):
if old_info["algorithm"] != version_info["algorithm"]:
raise SynapseError(400, "Algorithm does not match", Codes.INVALID_PARAM)
- yield self.store.update_e2e_room_keys_version(
+ await self.store.update_e2e_room_keys_version(
user_id, version, version_info
)
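Every method in E2eRoomKeysHandler takes the same per-user linearizer before touching backup versions, so one user's reads, writes and version bumps are serialised while different users proceed concurrently. A toy stand-in for that pattern, using asyncio.Lock in place of Synapse's Linearizer (which has more machinery than shown here):

    import asyncio
    from collections import defaultdict

    class PerKeyLock:
        # One lock per key; a rough sketch of Linearizer.queue(key).
        def __init__(self):
            self._locks = defaultdict(asyncio.Lock)

        def queue(self, key):
            return self._locks[key]

    upload_lock = PerKeyLock()

    async def upload_room_keys(user_id):
        async with upload_lock.queue(user_id):
            # version check and key writes happen atomically per user
            await asyncio.sleep(0)

    asyncio.run(upload_room_keys("@alice:example.com"))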
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index ca7da42a..f5f683bf 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,7 +19,7 @@
import itertools
import logging
-from collections import Container
+from collections.abc import Container
from http import HTTPStatus
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
@@ -44,6 +44,7 @@ from synapse.api.errors import (
FederationDeniedError,
FederationError,
HttpResponseException,
+ NotFoundError,
RequestSendFailed,
SynapseError,
)
@@ -61,6 +62,7 @@ from synapse.logging.context import (
run_in_background,
)
from synapse.logging.utils import log_function
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.replication.http.federation import (
ReplicationCleanRoomRestServlet,
@@ -618,6 +620,11 @@ class FederationHandler(BaseHandler):
will be omitted from the result. Likewise, any events which turn out not to
be in the given room.
+ This function *does not* automatically get missing auth events of the
+ newly fetched events. Callers must include the full auth chain of
+ the missing events in the `event_ids` argument, to ensure that any
+ missing auth events are correctly fetched.
+
Returns:
map from event_id to event
"""
@@ -784,15 +791,25 @@ class FederationHandler(BaseHandler):
resync = True
if resync:
- await self.store.mark_remote_user_device_cache_as_stale(event.sender)
+ run_as_background_process(
+ "resync_device_due_to_pdu", self._resync_device, event.sender
+ )
- # Immediately attempt a resync in the background
- if self.config.worker_app:
- return run_in_background(self._user_device_resync, event.sender)
- else:
- return run_in_background(
- self._device_list_updater.user_device_resync, event.sender
- )
+ async def _resync_device(self, sender: str) -> None:
+ """We have detected that the device list for the given user may be out
+ of sync, so we try to resync it.
+ """
+
+ try:
+ await self.store.mark_remote_user_device_cache_as_stale(sender)
+
+ # Immediately attempt a resync in the background
+ if self.config.worker_app:
+ await self._user_device_resync(user_id=sender)
+ else:
+ await self._device_list_updater.user_device_resync(sender)
+ except Exception:
+ logger.exception("Failed to resync device for %s", sender)
@log_function
async def backfill(self, dest, room_id, limit, extremities):
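The resync now runs via run_as_background_process rather than being returned to the caller, and the new try/except ensures a failed resync is logged rather than lost. A rough asyncio equivalent of that detach-but-log shape (Synapse's helper also tracks metrics and logcontexts, omitted here):

    import asyncio
    import logging

    logger = logging.getLogger(__name__)

    def run_detached(desc, func, *args):
        # Detach the coroutine, but never let its exception vanish silently.
        async def wrapper():
            try:
                await func(*args)
            except Exception:
                logger.exception("Background process %s failed", desc)
        return asyncio.ensure_future(wrapper())

    async def main():
        async def resync(user_id):
            raise RuntimeError("remote unreachable")

        run_detached("resync_device_due_to_pdu", resync, "@bob:remote.example")
        await asyncio.sleep(0)  # let the detached task run

    asyncio.run(main())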
@@ -1131,12 +1148,16 @@ class FederationHandler(BaseHandler):
):
"""Fetch the given events from a server, and persist them as outliers.
+ This function *does not* recursively get missing auth events of the
+ newly fetched events. Callers must include in the `events` argument
+ any missing events from the auth chain.
+
Logs a warning if we can't find the given event.
"""
room_version = await self.store.get_room_version(room_id)
- event_infos = []
+ event_map = {} # type: Dict[str, EventBase]
async def get_event(event_id: str):
with nested_logging_context(event_id):
@@ -1150,17 +1171,7 @@ class FederationHandler(BaseHandler):
)
return
- # recursively fetch the auth events for this event
- auth_events = await self._get_events_from_store_or_dest(
- destination, room_id, event.auth_event_ids()
- )
- auth = {}
- for auth_event_id in event.auth_event_ids():
- ae = auth_events.get(auth_event_id)
- if ae:
- auth[(ae.type, ae.state_key)] = ae
-
- event_infos.append(_NewEventInfo(event, None, auth))
+ event_map[event.event_id] = event
except Exception as e:
logger.warning(
@@ -1172,6 +1183,32 @@ class FederationHandler(BaseHandler):
await concurrently_execute(get_event, events, 5)
+ # Make a map of auth events for each event. We do this after fetching
+ # all the events as some of the events' auth events will be in the list
+ # of requested events.
+
+ auth_events = [
+ aid
+ for event in event_map.values()
+ for aid in event.auth_event_ids()
+ if aid not in event_map
+ ]
+ persisted_events = await self.store.get_events(
+ auth_events, allow_rejected=True,
+ )
+
+ event_infos = []
+ for event in event_map.values():
+ auth = {}
+ for auth_event_id in event.auth_event_ids():
+ ae = persisted_events.get(auth_event_id) or event_map.get(auth_event_id)
+ if ae:
+ auth[(ae.type, ae.state_key)] = ae
+ else:
+ logger.info("Missing auth event %s", auth_event_id)
+
+ event_infos.append(_NewEventInfo(event, None, auth))
+
await self._handle_new_events(
destination, event_infos,
)
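Instead of recursively fetching each event's auth chain from the remote as it went, the code now makes one pass to fetch the requested events and a second pass to resolve their auth events, preferring events already in the batch and falling back to a single batched local lookup. The two-pass core, with the event objects and store call as placeholders:

    async def build_auth_maps(event_map, store_get_events):
        # Auth events not already in the requested batch are fetched
        # locally in one call rather than one request per event.
        missing = {
            aid
            for ev in event_map.values()
            for aid in ev.auth_event_ids()
            if aid not in event_map
        }
        persisted = await store_get_events(missing)

        auth_maps = {}
        for ev in event_map.values():
            auth = {}
            for aid in ev.auth_event_ids():
                ae = persisted.get(aid) or event_map.get(aid)
                if ae:
                    auth[(ae.type, ae.state_key)] = ae
            auth_maps[ev.event_id] = auth
        return auth_maps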
@@ -1357,7 +1394,7 @@ class FederationHandler(BaseHandler):
# it's just a best-effort thing at this point. We do want to do
# them roughly in order, though, otherwise we'll end up making
# lots of requests for missing prev_events which we do actually
- # have. Hence we fire off the deferred, but don't wait for it.
+ # have. Hence we fire off the background task, but don't wait for it.
run_in_background(self._handle_queued_pdus, room_queue)
@@ -1403,10 +1440,20 @@ class FederationHandler(BaseHandler):
)
raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
- event_content = {"membership": Membership.JOIN}
-
+ # checking the room version will check that we've actually heard of the room
+ # (and return a 404 otherwise)
room_version = await self.store.get_room_version_id(room_id)
+ # now check that we are *still* in the room
+ is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
+ if not is_in_room:
+ logger.info(
+ "Got /make_join request for room %s we are no longer in", room_id,
+ )
+ raise NotFoundError("Not an active room on this server")
+
+ event_content = {"membership": Membership.JOIN}
+
builder = self.event_builder_factory.new(
room_version,
{
@@ -1840,9 +1887,6 @@ class FederationHandler(BaseHandler):
origin, event, state=state, auth_events=auth_events, backfilled=backfilled
)
- # reraise does not allow inlineCallbacks to preserve the stacktrace, so we
- # hack around with a try/finally instead.
- success = False
try:
if (
not event.internal_metadata.is_outlier()
@@ -1856,12 +1900,11 @@ class FederationHandler(BaseHandler):
await self.persist_events_and_notify(
[(event, context)], backfilled=backfilled
)
- success = True
- finally:
- if not success:
- run_in_background(
- self.store.remove_push_actions_from_staging, event.event_id
- )
+ except Exception:
+ run_in_background(
+ self.store.remove_push_actions_from_staging, event.event_id
+ )
+ raise
return context
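The success-flag-plus-finally construction was only ever a workaround: re-raising inside an inlineCallbacks generator lost the original stack trace. With native coroutines a bare raise inside an except block preserves it, so the cleanup moves into the except. The pattern in isolation (persist and cleanup are placeholders for persist_events_and_notify and remove_push_actions_from_staging):

    async def persist_with_cleanup(event_id, persist, cleanup):
        try:
            await persist(event_id)
        except Exception:
            cleanup(event_id)  # fired off in the background in the real code
            raise  # the original traceback survives under async/await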
@@ -2947,7 +2990,9 @@ class FederationHandler(BaseHandler):
else:
user_joined_room(self.distributor, user, room_id)
- async def get_room_complexity(self, remote_room_hosts, room_id):
+ async def get_room_complexity(
+ self, remote_room_hosts: List[str], room_id: str
+ ) -> Optional[dict]:
"""
Fetch the complexity of a remote room over federation.
@@ -2956,7 +3001,7 @@ class FederationHandler(BaseHandler):
room_id (str): The room ID to ask about.
Returns:
- Deferred[dict] or Deferred[None]: Dict contains the complexity
+ Dict contains the complexity
metric versions, while None means we could not fetch the complexity.
"""
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 701233eb..0bd2c3e3 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -19,6 +19,7 @@
import logging
import urllib.parse
+from typing import Awaitable, Callable, Dict, List, Optional, Tuple
from canonicaljson import json
from signedjson.key import decode_verify_key_bytes
@@ -36,6 +37,7 @@ from synapse.api.errors import (
)
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http.client import SimpleHttpClient
+from synapse.types import JsonDict, Requester
from synapse.util.hash import sha256_and_url_safe_base64
from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -59,23 +61,23 @@ class IdentityHandler(BaseHandler):
self.federation_http_client = hs.get_http_client()
self.hs = hs
- async def threepid_from_creds(self, id_server, creds):
+ async def threepid_from_creds(
+ self, id_server: str, creds: Dict[str, str]
+ ) -> Optional[JsonDict]:
"""
Retrieve and validate a threepid identifier from a "credentials" dictionary against a
given identity server
Args:
- id_server (str): The identity server to validate 3PIDs against. Must be a
+ id_server: The identity server to validate 3PIDs against. Must be a
complete URL including the protocol (http(s)://)
-
- creds (dict[str, str]): Dictionary containing the following keys:
+ creds: Dictionary containing the following keys:
* client_secret|clientSecret: A unique secret str provided by the client
* sid: The ID of the validation session
Returns:
- Deferred[dict[str,str|int]|None]: A dictionary consisting of response params to
- the /getValidated3pid endpoint of the Identity Service API, or None if the
- threepid was not found
+ A dictionary consisting of response params to the /getValidated3pid
+ endpoint of the Identity Service API, or None if the threepid was not found
"""
client_secret = creds.get("client_secret") or creds.get("clientSecret")
if not client_secret:
@@ -119,26 +121,27 @@ class IdentityHandler(BaseHandler):
return None
async def bind_threepid(
- self, client_secret, sid, mxid, id_server, id_access_token=None, use_v2=True
- ):
+ self,
+ client_secret: str,
+ sid: str,
+ mxid: str,
+ id_server: str,
+ id_access_token: Optional[str] = None,
+ use_v2: bool = True,
+ ) -> JsonDict:
"""Bind a 3PID to an identity server
Args:
- client_secret (str): A unique secret provided by the client
-
- sid (str): The ID of the validation session
-
- mxid (str): The MXID to bind the 3PID to
-
- id_server (str): The domain of the identity server to query
-
- id_access_token (str): The access token to authenticate to the identity
+ client_secret: A unique secret provided by the client
+ sid: The ID of the validation session
+ mxid: The MXID to bind the 3PID to
+ id_server: The domain of the identity server to query
+ id_access_token: The access token to authenticate to the identity
server with, if necessary. Required if use_v2 is true
-
- use_v2 (bool): Whether to use v2 Identity Service API endpoints. Defaults to True
+ use_v2: Whether to use v2 Identity Service API endpoints. Defaults to True
Returns:
- Deferred[dict]: The response from the identity server
+ The response from the identity server
"""
logger.debug("Proxying threepid bind request for %s to %s", mxid, id_server)
@@ -151,7 +154,7 @@ class IdentityHandler(BaseHandler):
bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
if use_v2:
bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
- headers["Authorization"] = create_id_access_token_header(id_access_token)
+ headers["Authorization"] = create_id_access_token_header(id_access_token) # type: ignore
else:
bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,)
@@ -187,20 +190,20 @@ class IdentityHandler(BaseHandler):
)
return res
- async def try_unbind_threepid(self, mxid, threepid):
+ async def try_unbind_threepid(self, mxid: str, threepid: dict) -> bool:
"""Attempt to remove a 3PID from an identity server, or if one is not provided, all
identity servers we're aware the binding is present on
Args:
- mxid (str): Matrix user ID of binding to be removed
- threepid (dict): Dict with medium & address of binding to be
+ mxid: Matrix user ID of binding to be removed
+ threepid: Dict with medium & address of binding to be
removed, and an optional id_server.
Raises:
SynapseError: If we failed to contact the identity server
Returns:
- Deferred[bool]: True on success, otherwise False if the identity
+ True on success, otherwise False if the identity
server doesn't support unbinding (or no identity server found to
contact).
"""
@@ -223,19 +226,21 @@ class IdentityHandler(BaseHandler):
return changed
- async def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
+ async def try_unbind_threepid_with_id_server(
+ self, mxid: str, threepid: dict, id_server: str
+ ) -> bool:
"""Removes a binding from an identity server
Args:
- mxid (str): Matrix user ID of binding to be removed
- threepid (dict): Dict with medium & address of binding to be removed
- id_server (str): Identity server to unbind from
+ mxid: Matrix user ID of binding to be removed
+ threepid: Dict with medium & address of binding to be removed
+ id_server: Identity server to unbind from
Raises:
SynapseError: If we failed to contact the identity server
Returns:
- Deferred[bool]: True on success, otherwise False if the identity
+ True on success, otherwise False if the identity
server doesn't support unbinding
"""
url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
@@ -287,23 +292,23 @@ class IdentityHandler(BaseHandler):
async def send_threepid_validation(
self,
- email_address,
- client_secret,
- send_attempt,
- send_email_func,
- next_link=None,
- ):
+ email_address: str,
+ client_secret: str,
+ send_attempt: int,
+ send_email_func: Callable[[str, str, str, str], Awaitable],
+ next_link: Optional[str] = None,
+ ) -> str:
"""Send a threepid validation email for password reset or
registration purposes
Args:
- email_address (str): The user's email address
- client_secret (str): The provided client secret
- send_attempt (int): Which send attempt this is
- send_email_func (func): A function that takes an email address, token,
- client_secret and session_id, sends an email
- and returns a Deferred.
- next_link (str|None): The URL to redirect the user to after validation
+ email_address: The user's email address
+ client_secret: The provided client secret
+ send_attempt: Which send attempt this is
+ send_email_func: A function that takes an email address, token,
+ client_secret and session_id, sends an email
+ and returns an Awaitable.
+ next_link: The URL to redirect the user to after validation
Returns:
The new session_id upon success
@@ -372,17 +377,22 @@ class IdentityHandler(BaseHandler):
return session_id
async def requestEmailToken(
- self, id_server, email, client_secret, send_attempt, next_link=None
- ):
+ self,
+ id_server: str,
+ email: str,
+ client_secret: str,
+ send_attempt: int,
+ next_link: Optional[str] = None,
+ ) -> JsonDict:
"""
Request an external server send an email on our behalf for the purposes of threepid
validation.
Args:
- id_server (str): The identity server to proxy to
- email (str): The email to send the message to
- client_secret (str): The unique client_secret sends by the user
- send_attempt (int): Which attempt this is
+ id_server: The identity server to proxy to
+ email: The email to send the message to
+ client_secret: The unique client_secret sent by the user
+ send_attempt: Which attempt this is
next_link: A link to redirect the user to once they submit the token
Returns:
@@ -419,22 +429,22 @@ class IdentityHandler(BaseHandler):
async def requestMsisdnToken(
self,
- id_server,
- country,
- phone_number,
- client_secret,
- send_attempt,
- next_link=None,
- ):
+ id_server: str,
+ country: str,
+ phone_number: str,
+ client_secret: str,
+ send_attempt: int,
+ next_link: Optional[str] = None,
+ ) -> JsonDict:
"""
Request an external server send an SMS message on our behalf for the purposes of
threepid validation.
Args:
- id_server (str): The identity server to proxy to
- country (str): The country code of the phone number
- phone_number (str): The number to send the message to
- client_secret (str): The unique client_secret sends by the user
- send_attempt (int): Which attempt this is
+ id_server: The identity server to proxy to
+ country: The country code of the phone number
+ phone_number: The number to send the message to
+ client_secret: The unique client_secret sent by the user
+ send_attempt: Which attempt this is
next_link: A link to redirect the user to once they submit the token
Returns:
@@ -480,17 +490,18 @@ class IdentityHandler(BaseHandler):
)
return data
- async def validate_threepid_session(self, client_secret, sid):
+ async def validate_threepid_session(
+ self, client_secret: str, sid: str
+ ) -> Optional[JsonDict]:
"""Validates a threepid session with only the client secret and session ID
Tries validating against any configured account_threepid_delegates as well as locally.
Args:
- client_secret (str): A secret provided by the client
-
- sid (str): The ID of the session
+ client_secret: A secret provided by the client
+ sid: The ID of the session
Returns:
- Dict[str, str|int] if validation was successful, otherwise None
+ The json response if validation was successful, otherwise None
"""
# XXX: We shouldn't need to keep wrapping and unwrapping this value
threepid_creds = {"client_secret": client_secret, "sid": sid}
@@ -523,23 +534,22 @@ class IdentityHandler(BaseHandler):
return validation_session
- async def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token):
+ async def proxy_msisdn_submit_token(
+ self, id_server: str, client_secret: str, sid: str, token: str
+ ) -> JsonDict:
"""Proxy a POST submitToken request to an identity server for verification purposes
Args:
- id_server (str): The identity server URL to contact
-
- client_secret (str): Secret provided by the client
-
- sid (str): The ID of the session
-
- token (str): The verification token
+ id_server: The identity server URL to contact
+ client_secret: Secret provided by the client
+ sid: The ID of the session
+ token: The verification token
Raises:
SynapseError: If we failed to contact the identity server
Returns:
- Deferred[dict]: The response dict from the identity server
+ The response dict from the identity server
"""
body = {"client_secret": client_secret, "sid": sid, "token": token}
@@ -554,19 +564,25 @@ class IdentityHandler(BaseHandler):
logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
raise SynapseError(400, "Error contacting the identity server")
- async def lookup_3pid(self, id_server, medium, address, id_access_token=None):
+ async def lookup_3pid(
+ self,
+ id_server: str,
+ medium: str,
+ address: str,
+ id_access_token: Optional[str] = None,
+ ) -> Optional[str]:
"""Looks up a 3pid in the passed identity server.
Args:
- id_server (str): The server name (including port, if required)
+ id_server: The server name (including port, if required)
of the identity server to use.
- medium (str): The type of the third party identifier (e.g. "email").
- address (str): The third party identifier (e.g. "foo@example.com").
- id_access_token (str|None): The access token to authenticate to the identity
+ medium: The type of the third party identifier (e.g. "email").
+ address: The third party identifier (e.g. "foo@example.com").
+ id_access_token: The access token to authenticate to the identity
server with
Returns:
- str|None: the matrix ID of the 3pid, or None if it is not recognized.
+ the matrix ID of the 3pid, or None if it is not recognized.
"""
if id_access_token is not None:
try:
@@ -591,17 +607,19 @@ class IdentityHandler(BaseHandler):
return await self._lookup_3pid_v1(id_server, medium, address)
- async def _lookup_3pid_v1(self, id_server, medium, address):
+ async def _lookup_3pid_v1(
+ self, id_server: str, medium: str, address: str
+ ) -> Optional[str]:
"""Looks up a 3pid in the passed identity server using v1 lookup.
Args:
- id_server (str): The server name (including port, if required)
+ id_server: The server name (including port, if required)
of the identity server to use.
- medium (str): The type of the third party identifier (e.g. "email").
- address (str): The third party identifier (e.g. "foo@example.com").
+ medium: The type of the third party identifier (e.g. "email").
+ address: The third party identifier (e.g. "foo@example.com").
Returns:
- str: the matrix ID of the 3pid, or None if it is not recognized.
+ the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
data = await self.blacklisting_http_client.get_json(
@@ -621,18 +639,20 @@ class IdentityHandler(BaseHandler):
return None
- async def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
+ async def _lookup_3pid_v2(
+ self, id_server: str, id_access_token: str, medium: str, address: str
+ ) -> Optional[str]:
"""Looks up a 3pid in the passed identity server using v2 lookup.
Args:
- id_server (str): The server name (including port, if required)
+ id_server: The server name (including port, if required)
of the identity server to use.
- id_access_token (str): The access token to authenticate to the identity server with
- medium (str): The type of the third party identifier (e.g. "email").
- address (str): The third party identifier (e.g. "foo@example.com").
+ id_access_token: The access token to authenticate to the identity server with
+ medium: The type of the third party identifier (e.g. "email").
+ address: The third party identifier (e.g. "foo@example.com").
Returns:
- Deferred[str|None]: the matrix ID of the 3pid, or None if it is not recognised.
+ the matrix ID of the 3pid, or None if it is not recognised.
"""
# Check what hashing details are supported by this identity server
try:
@@ -757,49 +777,48 @@ class IdentityHandler(BaseHandler):
async def ask_id_server_for_third_party_invite(
self,
- requester,
- id_server,
- medium,
- address,
- room_id,
- inviter_user_id,
- room_alias,
- room_avatar_url,
- room_join_rules,
- room_name,
- inviter_display_name,
- inviter_avatar_url,
- id_access_token=None,
- ):
+ requester: Requester,
+ id_server: str,
+ medium: str,
+ address: str,
+ room_id: str,
+ inviter_user_id: str,
+ room_alias: str,
+ room_avatar_url: str,
+ room_join_rules: str,
+ room_name: str,
+ inviter_display_name: str,
+ inviter_avatar_url: str,
+ id_access_token: Optional[str] = None,
+ ) -> Tuple[str, List[Dict[str, str]], Dict[str, str], str]:
"""
Asks an identity server for a third party invite.
Args:
- requester (Requester)
- id_server (str): hostname + optional port for the identity server.
- medium (str): The literal string "email".
- address (str): The third party address being invited.
- room_id (str): The ID of the room to which the user is invited.
- inviter_user_id (str): The user ID of the inviter.
- room_alias (str): An alias for the room, for cosmetic notifications.
- room_avatar_url (str): The URL of the room's avatar, for cosmetic
+ requester
+ id_server: hostname + optional port for the identity server.
+ medium: The literal string "email".
+ address: The third party address being invited.
+ room_id: The ID of the room to which the user is invited.
+ inviter_user_id: The user ID of the inviter.
+ room_alias: An alias for the room, for cosmetic notifications.
+ room_avatar_url: The URL of the room's avatar, for cosmetic
notifications.
- room_join_rules (str): The join rules of the email (e.g. "public").
- room_name (str): The m.room.name of the room.
- inviter_display_name (str): The current display name of the
+ room_join_rules: The join rules of the room (e.g. "public").
+ room_name: The m.room.name of the room.
+ inviter_display_name: The current display name of the
inviter.
- inviter_avatar_url (str): The URL of the inviter's avatar.
+ inviter_avatar_url: The URL of the inviter's avatar.
id_access_token: The access token to authenticate to the identity
server with
Returns:
- A deferred tuple containing:
- token (str): The token which must be signed to prove authenticity.
+ A tuple containing:
+ token: The token which must be signed to prove authenticity.
public_keys ([{"public_key": str, "key_validity_url": str}]):
public_key is a base64-encoded ed25519 public key.
fallback_public_key: One element from public_keys.
- display_name (str): A user-friendly name to represent the invited
- user.
+ display_name: A user-friendly name to represent the invited user.
"""
invite_config = {
"medium": medium,
@@ -896,15 +915,15 @@ class IdentityHandler(BaseHandler):
return token, public_keys, fallback_public_key, display_name
-def create_id_access_token_header(id_access_token):
+def create_id_access_token_header(id_access_token: str) -> List[str]:
"""Create an Authorization header for passing to SimpleHttpClient as the header value
of an HTTP request.
Args:
- id_access_token (str): An identity server access token.
+ id_access_token: An identity server access token.
Returns:
- list[str]: The ascii-encoded bearer token encased in a list.
+ The ascii-encoded bearer token encased in a list.
"""
# Prefix with Bearer
bearer_token = "Bearer %s" % id_access_token
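lookup_3pid earlier in this file tries the v2 lookup API whenever an access token is available, falling back to v1 only when the identity server predates v2. A compact sketch of that control flow; the helper functions, error type and exact fallback conditions here are illustrative rather than lifted from the handler:

    import asyncio

    class HttpError(Exception):
        def __init__(self, status):
            self.status = status

    async def lookup_v2(id_server, token, medium, address):
        raise HttpError(404)  # pretend this server has no v2 API

    async def lookup_v1(id_server, medium, address):
        return "@alice:example.com"

    async def lookup_3pid(id_server, medium, address, id_access_token=None):
        if id_access_token is not None:
            try:
                return await lookup_v2(id_server, id_access_token, medium, address)
            except HttpError as e:
                if e.status != 404:
                    return None  # hard failure talking to the server
                # 404: old identity server, fall through to v1
        return await lookup_v1(id_server, medium, address)

    print(asyncio.run(lookup_3pid("id.example.com", "email", "a@b.c", "tok")))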
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index da206e1e..e451d6dc 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -15,12 +15,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
from canonicaljson import encode_canonical_json, json
-from twisted.internet import defer
-from twisted.internet.defer import succeed
from twisted.internet.interfaces import IDelayedCall
from synapse import event_auth
@@ -41,13 +39,22 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.api.urls import ConsentURIBuilder
from synapse.events import EventBase
+from synapse.events.builder import EventBuilder
+from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
-from synapse.types import Collection, RoomAlias, UserID, create_requester
+from synapse.types import (
+ Collection,
+ Requester,
+ RoomAlias,
+ StreamToken,
+ UserID,
+ create_requester,
+)
from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.metrics import measure_func
@@ -84,14 +91,22 @@ class MessageHandler(object):
"_schedule_next_expiry", self._schedule_next_expiry
)
- @defer.inlineCallbacks
- def get_room_data(
- self, user_id=None, room_id=None, event_type=None, state_key="", is_guest=False
- ):
+ async def get_room_data(
+ self,
+ user_id: Optional[str] = None,
+ room_id: Optional[str] = None,
+ event_type: Optional[str] = None,
+ state_key: str = "",
+ is_guest: bool = False,
+ ) -> dict:
""" Get data from a room.
Args:
- event : The room path event
+ user_id: The user requesting the data.
+ room_id: The room to fetch data from.
+ event_type: The event type to fetch.
+ state_key: The state key of the event to fetch.
+ is_guest: Whether the requesting user is a guest.
Returns:
The path data content.
Raises:
@@ -100,30 +115,29 @@ class MessageHandler(object):
(
membership,
membership_event_id,
- ) = yield self.auth.check_user_in_room_or_world_readable(
+ ) = await self.auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True
)
if membership == Membership.JOIN:
- data = yield self.state.get_current_state(room_id, event_type, state_key)
+ data = await self.state.get_current_state(room_id, event_type, state_key)
elif membership == Membership.LEAVE:
key = (event_type, state_key)
- room_state = yield self.state_store.get_state_for_events(
+ room_state = await self.state_store.get_state_for_events(
[membership_event_id], StateFilter.from_types([key])
)
data = room_state[membership_event_id].get(key)
return data
- @defer.inlineCallbacks
- def get_state_events(
+ async def get_state_events(
self,
- user_id,
- room_id,
- state_filter=StateFilter.all(),
- at_token=None,
- is_guest=False,
- ):
+ user_id: str,
+ room_id: str,
+ state_filter: StateFilter = StateFilter.all(),
+ at_token: Optional[StreamToken] = None,
+ is_guest: bool = False,
+ ) -> List[dict]:
"""Retrieve all state events for a given room. If the user is
joined to the room then return the current state. If the user has
left the room return the state events from when they left. If an explicit
@@ -131,15 +145,14 @@ class MessageHandler(object):
visible.
Args:
- user_id(str): The user requesting state events.
- room_id(str): The room ID to get all state events from.
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
- at_token(StreamToken|None): the stream token of the at which we are requesting
+ user_id: The user requesting state events.
+ room_id: The room ID to get all state events from.
+ state_filter: The state filter used to fetch state from the database.
+ at_token: the stream token at which we are requesting
the state. If the user is not allowed to view the state as of that
stream token, we raise a 403 SynapseError. If None, returns the current
state based on the current_state_events table.
- is_guest(bool): whether this user is a guest
+ is_guest: whether this user is a guest
Returns:
A list of dicts representing state events. [{}, {}, {}]
Raises:
@@ -153,20 +166,20 @@ class MessageHandler(object):
# get_recent_events_for_room operates by topo ordering. This therefore
# does not reliably give you the state at the given stream position.
# (https://github.com/matrix-org/synapse/issues/3305)
- last_events, _ = yield self.store.get_recent_events_for_room(
+ last_events, _ = await self.store.get_recent_events_for_room(
room_id, end_token=at_token.room_key, limit=1
)
if not last_events:
raise NotFoundError("Can't find event for token %s" % (at_token,))
- visible_events = yield filter_events_for_client(
+ visible_events = await filter_events_for_client(
self.storage, user_id, last_events, filter_send_to_client=False
)
event = last_events[0]
if visible_events:
- room_state = yield self.state_store.get_state_for_events(
+ room_state = await self.state_store.get_state_for_events(
[event.event_id], state_filter=state_filter
)
room_state = room_state[event.event_id]
@@ -180,23 +193,23 @@ class MessageHandler(object):
(
membership,
membership_event_id,
- ) = yield self.auth.check_user_in_room_or_world_readable(
+ ) = await self.auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True
)
if membership == Membership.JOIN:
- state_ids = yield self.store.get_filtered_current_state_ids(
+ state_ids = await self.store.get_filtered_current_state_ids(
room_id, state_filter=state_filter
)
- room_state = yield self.store.get_events(state_ids.values())
+ room_state = await self.store.get_events(state_ids.values())
elif membership == Membership.LEAVE:
- room_state = yield self.state_store.get_state_for_events(
+ room_state = await self.state_store.get_state_for_events(
[membership_event_id], state_filter=state_filter
)
room_state = room_state[membership_event_id]
now = self.clock.time_msec()
- events = yield self._event_serializer.serialize_events(
+ events = await self._event_serializer.serialize_events(
room_state.values(),
now,
# We don't bother bundling aggregations in when asked for state
@@ -205,15 +218,14 @@ class MessageHandler(object):
)
return events
- @defer.inlineCallbacks
- def get_joined_members(self, requester, room_id):
+ async def get_joined_members(self, requester: Requester, room_id: str) -> dict:
"""Get all the joined members in the room and their profile information.
If the user has left the room return the state events from when they left.
Args:
- requester(Requester): The user requesting state events.
- room_id(str): The room ID to get all state events from.
+ requester: The user requesting state events.
+ room_id: The room ID to get all state events from.
Returns:
A dict of user_id to profile info
"""
@@ -221,7 +233,7 @@ class MessageHandler(object):
if not requester.app_service:
# We check AS auth after fetching the room membership, as it
# requires us to pull out all joined members anyway.
- membership, _ = yield self.auth.check_user_in_room_or_world_readable(
+ membership, _ = await self.auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True
)
if membership != Membership.JOIN:
@@ -229,7 +241,7 @@ class MessageHandler(object):
"Getting joined members after leaving is not implemented"
)
- users_with_profile = yield self.state.get_current_users_in_room(room_id)
+ users_with_profile = await self.state.get_current_users_in_room(room_id)
# If this is an AS, double check that they are allowed to see the members.
# This can either be because the AS user is in the room or because there
@@ -250,7 +262,7 @@ class MessageHandler(object):
for user_id, profile in users_with_profile.items()
}
- def maybe_schedule_expiry(self, event):
+ def maybe_schedule_expiry(self, event: EventBase):
"""Schedule the expiry of an event if there's not already one scheduled,
or if the one running is for an event that will expire after the provided
timestamp.
@@ -259,7 +271,7 @@ class MessageHandler(object):
the master process, and therefore needs to be run on there.
Args:
- event (EventBase): The event to schedule the expiry of.
+ event: The event to schedule the expiry of.
"""
expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
@@ -270,8 +282,7 @@ class MessageHandler(object):
# a task scheduled for a timestamp that's sooner than the provided one.
self._schedule_expiry_for_event(event.event_id, expiry_ts)
- @defer.inlineCallbacks
- def _schedule_next_expiry(self):
+ async def _schedule_next_expiry(self):
"""Retrieve the ID and the expiry timestamp of the next event to be expired,
and schedule an expiry task for it.
@@ -279,18 +290,18 @@ class MessageHandler(object):
future call to save_expiry_ts can schedule a new expiry task.
"""
# Try to get the expiry timestamp of the next event to expire.
- res = yield self.store.get_next_event_to_expire()
+ res = await self.store.get_next_event_to_expire()
if res:
event_id, expiry_ts = res
self._schedule_expiry_for_event(event_id, expiry_ts)
- def _schedule_expiry_for_event(self, event_id, expiry_ts):
+ def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int):
"""Schedule an expiry task for the provided event if there's not already one
scheduled at a timestamp that's sooner than the provided one.
Args:
- event_id (str): The ID of the event to expire.
- expiry_ts (int): The timestamp at which to expire the event.
+ event_id: The ID of the event to expire.
+ expiry_ts: The timestamp at which to expire the event.
"""
if self._scheduled_expiry:
# If the provided timestamp refers to a time before the scheduled time of the
@@ -320,8 +331,7 @@ class MessageHandler(object):
event_id,
)
- @defer.inlineCallbacks
- def _expire_event(self, event_id):
+ async def _expire_event(self, event_id: str):
"""Retrieve and expire an event that needs to be expired from the database.
If the event doesn't exist in the database, log it and delete the expiry date
@@ -336,12 +346,12 @@ class MessageHandler(object):
try:
# Expire the event if we know about it. This function also deletes the expiry
# date from the database in the same database transaction.
- yield self.store.expire_event(event_id)
+ await self.store.expire_event(event_id)
except Exception as e:
logger.error("Could not expire event %s: %r", event_id, e)
# Schedule the expiry of the next event to expire.
- yield self._schedule_next_expiry()
+ await self._schedule_next_expiry()
# The duration (in ms) after which rooms should be removed
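_schedule_expiry_for_event keeps at most one pending timer and replaces it only when a newly seen event expires sooner, so self-destructing events cost one timer rather than one per event. The policy in isolation, with asyncio's call_later standing in for Twisted's IDelayedCall:

    import asyncio

    class ExpiryScheduler:
        # At most one pending expiry; reschedule only for a sooner timestamp.
        def __init__(self, loop):
            self._loop = loop
            self._scheduled = None  # (expiry_ts, TimerHandle) or None

        def schedule(self, event_id, expiry_ts, callback):
            if self._scheduled is not None:
                when, handle = self._scheduled
                if when <= expiry_ts:
                    return  # the existing timer already fires first
                handle.cancel()
            delay = max(0.0, expiry_ts - self._loop.time())
            handle = self._loop.call_later(delay, callback, event_id)
            self._scheduled = (expiry_ts, handle)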
@@ -423,16 +433,15 @@ class EventCreationHandler(object):
self._dummy_events_threshold = hs.config.dummy_events_threshold
- @defer.inlineCallbacks
- def create_event(
+ async def create_event(
self,
- requester,
- event_dict,
- token_id=None,
- txn_id=None,
+ requester: Requester,
+ event_dict: dict,
+ token_id: Optional[str] = None,
+ txn_id: Optional[str] = None,
prev_event_ids: Optional[Collection[str]] = None,
- require_consent=True,
- ):
+ require_consent: bool = True,
+ ) -> Tuple[EventBase, EventContext]:
"""
Given a dict from a client, create a new event.
@@ -443,31 +452,29 @@ class EventCreationHandler(object):
Args:
requester
- event_dict (dict): An entire event
- token_id (str)
- txn_id (str)
-
+ event_dict: An entire event
+ token_id: The ID of the access token the event was sent with.
+ txn_id: The client's transaction ID, used for deduplication.
prev_event_ids:
the forward extremities to use as the prev_events for the
new event.
If None, they will be requested from the database.
-
- require_consent (bool): Whether to check if the requester has
- consented to privacy policy.
+ require_consent: Whether to check if the requester has
+ consented to the privacy policy.
Raises:
ResourceLimitError if server is blocked to some resource being
exceeded
Returns:
- Tuple of created event (FrozenEvent), Context
+ Tuple of created event, Context
"""
- yield self.auth.check_auth_blocking(requester.user.to_string())
+ await self.auth.check_auth_blocking(requester.user.to_string())
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
room_version = event_dict["content"]["room_version"]
else:
try:
- room_version = yield self.store.get_room_version_id(
+ room_version = await self.store.get_room_version_id(
event_dict["room_id"]
)
except NotFoundError:
@@ -488,11 +495,11 @@ class EventCreationHandler(object):
try:
if "displayname" not in content:
- displayname = yield profile.get_displayname(target)
+ displayname = await profile.get_displayname(target)
if displayname is not None:
content["displayname"] = displayname
if "avatar_url" not in content:
- avatar_url = yield profile.get_avatar_url(target)
+ avatar_url = await profile.get_avatar_url(target)
if avatar_url is not None:
content["avatar_url"] = avatar_url
except Exception as e:
@@ -500,9 +507,9 @@ class EventCreationHandler(object):
"Failed to get profile information for %r: %s", target, e
)
- is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester)
+ is_exempt = await self._is_exempt_from_privacy_policy(builder, requester)
if require_consent and not is_exempt:
- yield self.assert_accepted_privacy_policy(requester)
+ await self.assert_accepted_privacy_policy(requester)
if token_id is not None:
builder.internal_metadata.token_id = token_id
@@ -510,7 +517,7 @@ class EventCreationHandler(object):
if txn_id is not None:
builder.internal_metadata.txn_id = txn_id
- event, context = yield self.create_new_client_event(
+ event, context = await self.create_new_client_event(
builder=builder, requester=requester, prev_event_ids=prev_event_ids,
)
@@ -526,10 +533,10 @@ class EventCreationHandler(object):
# federation as well as those created locally. As of room v3, aliases events
# can be created by users that are not in the room, therefore we have to
# tolerate them in event_auth.check().
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = await context.get_prev_state_ids()
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
prev_event = (
- yield self.store.get_event(prev_event_id, allow_none=True)
+ await self.store.get_event(prev_event_id, allow_none=True)
if prev_event_id
else None
)
@@ -552,37 +559,36 @@ class EventCreationHandler(object):
return (event, context)
- def _is_exempt_from_privacy_policy(self, builder, requester):
+ async def _is_exempt_from_privacy_policy(
+ self, builder: EventBuilder, requester: Requester
+ ) -> bool:
""""Determine if an event to be sent is exempt from having to consent
to the privacy policy
Args:
- builder (synapse.events.builder.EventBuilder): event being created
- requester (Requster): user requesting this event
+ builder: event being created
+ requester: user requesting this event
Returns:
- Deferred[bool]: true if the event can be sent without the user
- consenting
+ true if the event can be sent without the user consenting
"""
# the only thing the user can do is join the server notices room.
if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None)
if membership == Membership.JOIN:
- return self._is_server_notices_room(builder.room_id)
+ return await self._is_server_notices_room(builder.room_id)
elif membership == Membership.LEAVE:
# the user is always allowed to leave (but not kick people)
return builder.state_key == requester.user.to_string()
- return succeed(False)
+ return False
- @defer.inlineCallbacks
- def _is_server_notices_room(self, room_id):
+ async def _is_server_notices_room(self, room_id: str) -> bool:
if self.config.server_notices_mxid is None:
return False
- user_ids = yield self.store.get_users_in_room(room_id)
+ user_ids = await self.store.get_users_in_room(room_id)
return self.config.server_notices_mxid in user_ids
- @defer.inlineCallbacks
- def assert_accepted_privacy_policy(self, requester):
+ async def assert_accepted_privacy_policy(self, requester: Requester) -> None:
"""Check if a user has accepted the privacy policy
Called when the given user is about to do something that requires
@@ -591,12 +597,10 @@ class EventCreationHandler(object):
raised.
Args:
- requester (synapse.types.Requester):
- The user making the request
+ requester: The user making the request
Returns:
- Deferred[None]: returns normally if the user has consented or is
- exempt
+ Returns normally if the user has consented or is exempt
Raises:
ConsentNotGivenError: if the user has not given consent yet
@@ -617,7 +621,7 @@ class EventCreationHandler(object):
):
return
- u = yield self.store.get_user_by_id(user_id)
+ u = await self.store.get_user_by_id(user_id)
assert u is not None
if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT):
# support and bot users are not required to consent
@@ -635,16 +639,20 @@ class EventCreationHandler(object):
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
async def send_nonmember_event(
- self, requester, event, context, ratelimit=True
+ self,
+ requester: Requester,
+ event: EventBase,
+ context: EventContext,
+ ratelimit: bool = True,
) -> int:
"""
Persists and notifies local clients and federation of an event.
Args:
- event (FrozenEvent) the event to send.
- context (Context) the context of the event.
- ratelimit (bool): Whether to rate limit this send.
- is_guest (bool): Whether the sender is a guest.
+ requester: The requester sending the event.
+ event: the event to send.
+ context: the context of the event.
+ ratelimit: Whether to rate limit this send.
Return:
The stream_id of the persisted event.
@@ -672,19 +680,20 @@ class EventCreationHandler(object):
requester=requester, event=event, context=context, ratelimit=ratelimit
)
- @defer.inlineCallbacks
- def deduplicate_state_event(self, event, context):
+ async def deduplicate_state_event(
+ self, event: EventBase, context: EventContext
+ ) -> Optional[EventBase]:
"""
Checks whether event is in the latest resolved state in context.
If so, returns the version of the event in context.
Otherwise, returns None.
"""
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = await context.get_prev_state_ids()
prev_event_id = prev_state_ids.get((event.type, event.state_key))
if not prev_event_id:
return
- prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+ prev_event = await self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
return
@@ -696,7 +705,11 @@ class EventCreationHandler(object):
return
async def create_and_send_nonmember_event(
- self, requester, event_dict, ratelimit=True, txn_id=None
+ self,
+ requester: Requester,
+ event_dict: dict,
+ ratelimit: bool = True,
+ txn_id: Optional[str] = None,
) -> Tuple[EventBase, int]:
"""
Creates an event, then sends it.
@@ -726,17 +739,17 @@ class EventCreationHandler(object):
return event, stream_id
@measure_func("create_new_client_event")
- @defer.inlineCallbacks
- def create_new_client_event(
- self, builder, requester=None, prev_event_ids: Optional[Collection[str]] = None
- ):
+ async def create_new_client_event(
+ self,
+ builder: EventBuilder,
+ requester: Optional[Requester] = None,
+ prev_event_ids: Optional[Collection[str]] = None,
+ ) -> Tuple[EventBase, EventContext]:
"""Create a new event for a local client
Args:
- builder (EventBuilder):
-
- requester (synapse.types.Requester|None):
-
+ builder:
+ requester:
prev_event_ids:
the forward extremities to use as the prev_events for the
new event.
@@ -744,7 +757,7 @@ class EventCreationHandler(object):
If None, they will be requested from the database.
Returns:
- Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)]
+            Tuple of the created event and its context
"""
if prev_event_ids is not None:
@@ -753,10 +766,10 @@ class EventCreationHandler(object):
% (len(prev_event_ids),)
)
else:
- prev_event_ids = yield self.store.get_prev_events_for_room(builder.room_id)
+ prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
- event = yield builder.build(prev_event_ids=prev_event_ids)
- context = yield self.state.compute_event_context(event)
+ event = await builder.build(prev_event_ids=prev_event_ids)
+ context = await self.state.compute_event_context(event)
if requester:
context.app_service = requester.app_service
@@ -770,7 +783,7 @@ class EventCreationHandler(object):
relates_to = relation["event_id"]
aggregation_key = relation["key"]
- already_exists = yield self.store.has_user_annotated_event(
+ already_exists = await self.store.has_user_annotated_event(
relates_to, event.type, aggregation_key, event.sender
)
if already_exists:
@@ -782,7 +795,12 @@ class EventCreationHandler(object):
@measure_func("handle_new_client_event")
async def handle_new_client_event(
- self, requester, event, context, ratelimit=True, extra_users=[]
+ self,
+ requester: Requester,
+ event: EventBase,
+ context: EventContext,
+ ratelimit: bool = True,
+ extra_users: List[UserID] = [],
) -> int:
"""Processes a new event. This includes checking auth, persisting it,
notifying users, sending to remote servers, etc.
@@ -791,11 +809,11 @@ class EventCreationHandler(object):
processing.
Args:
- requester (Requester)
- event (FrozenEvent)
- context (EventContext)
- ratelimit (bool)
- extra_users (list(UserID)): Any extra users to notify about event
+ requester
+ event
+ context
+ ratelimit
+ extra_users: Any extra users to notify about event
Return:
The stream_id of the persisted event.
@@ -839,9 +857,6 @@ class EventCreationHandler(object):
await self.action_generator.handle_push_actions_for_event(event, context)
- # reraise does not allow inlineCallbacks to preserve the stacktrace, so we
- # hack around with a try/finally instead.
- success = False
try:
# If we're a worker we need to hit out to the master.
if not self._is_event_writer:
@@ -857,27 +872,24 @@ class EventCreationHandler(object):
)
stream_id = result["stream_id"]
event.internal_metadata.stream_ordering = stream_id
- success = True
return stream_id
stream_id = await self.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
- success = True
return stream_id
- finally:
- if not success:
- # Ensure that we actually remove the entries in the push actions
- # staging area, if we calculated them.
- run_in_background(
- self.store.remove_push_actions_from_staging, event.event_id
- )
+ except Exception:
+ # Ensure that we actually remove the entries in the push actions
+ # staging area, if we calculated them.
+ run_in_background(
+ self.store.remove_push_actions_from_staging, event.event_id
+ )
+ raise
- @defer.inlineCallbacks
- def _validate_canonical_alias(
- self, directory_handler, room_alias_str, expected_room_id
- ):
+ async def _validate_canonical_alias(
+ self, directory_handler, room_alias_str: str, expected_room_id: str
+ ) -> None:
"""
Ensure that the given room alias points to the expected room ID.
@@ -888,9 +900,7 @@ class EventCreationHandler(object):
"""
room_alias = RoomAlias.from_string(room_alias_str)
try:
- mapping = yield defer.ensureDeferred(
- directory_handler.get_association(room_alias)
- )
+ mapping = await directory_handler.get_association(room_alias)
except SynapseError as e:
# Turn M_NOT_FOUND errors into M_BAD_ALIAS errors.
if e.errcode == Codes.NOT_FOUND:
@@ -909,7 +919,12 @@ class EventCreationHandler(object):
)
async def persist_and_notify_client_event(
- self, requester, event, context, ratelimit=True, extra_users=[]
+ self,
+ requester: Requester,
+ event: EventBase,
+ context: EventContext,
+ ratelimit: bool = True,
+ extra_users: List[UserID] = [],
) -> int:
"""Called when we have fully built the event, have already
calculated the push actions for the event, and checked auth.
@@ -1102,7 +1117,7 @@ class EventCreationHandler(object):
return event_stream_id
- async def _bump_active_time(self, user):
+ async def _bump_active_time(self, user: UserID) -> None:
try:
presence = self.hs.get_presence_handler()
await presence.bump_presence_active_time(user)
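A minimal sketch of the mechanical conversion pattern applied throughout this file (illustrative code, not Synapse's; `store` is a stand-in object with an awaitable `get_users_in_room`):

# Before: a Twisted inlineCallbacks generator.
from twisted.internet import defer

@defer.inlineCallbacks
def is_empty_old(store, room_id):
    user_ids = yield store.get_users_in_room(room_id)
    return len(user_ids) == 0  # result implicitly wrapped in a Deferred

# After: a native coroutine. `yield` becomes `await`, and
# `defer.succeed(x)` collapses to a plain `return x`. As a side benefit,
# the try/finally-with-success-flag hack above can become an ordinary
# try/except that re-raises, since coroutines preserve stack traces.
async def is_empty_new(store, room_id):
    user_ids = await store.get_users_in_room(room_id)
    return len(user_ids) == 0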
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index d2f25ae1..b3a3bb8c 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -30,8 +30,6 @@ from typing import Dict, Iterable, List, Set, Tuple
from prometheus_client import Counter
from typing_extensions import ContextManager
-from twisted.internet import defer
-
import synapse.metrics
from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.errors import SynapseError
@@ -39,6 +37,8 @@ from synapse.logging.context import run_in_background
from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.state import StateHandler
+from synapse.storage.data_stores.main import DataStore
from synapse.storage.presence import UserPresenceState
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
@@ -895,16 +895,9 @@ class PresenceHandler(BasePresenceHandler):
await self._on_user_joined_room(room_id, state_key)
- async def _on_user_joined_room(self, room_id, user_id):
+ async def _on_user_joined_room(self, room_id: str, user_id: str) -> None:
"""Called when we detect a user joining the room via the current state
delta stream.
-
- Args:
- room_id (str)
- user_id (str)
-
- Returns:
- Deferred
"""
if self.is_mine_id(user_id):
@@ -935,8 +928,8 @@ class PresenceHandler(BasePresenceHandler):
# TODO: Check that this is actually a new server joining the
# room.
- user_ids = await self.state.get_current_users_in_room(room_id)
- user_ids = list(filter(self.is_mine_id, user_ids))
+ users = await self.state.get_current_users_in_room(room_id)
+ user_ids = list(filter(self.is_mine_id, users))
states_d = await self.current_state_for_users(user_ids)
@@ -1296,22 +1289,24 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now):
return new_state, persist_and_notify, federation_ping
-@defer.inlineCallbacks
-def get_interested_parties(store, states):
+async def get_interested_parties(
+ store: DataStore, states: List[UserPresenceState]
+) -> Tuple[Dict[str, List[UserPresenceState]], Dict[str, List[UserPresenceState]]]:
"""Given a list of states return which entities (rooms, users)
are interested in the given states.
Args:
- states (list(UserPresenceState))
+ store
+ states
Returns:
- 2-tuple: `(room_ids_to_states, users_to_states)`,
+ A 2-tuple of `(room_ids_to_states, users_to_states)`,
with each item being a dict of `entity_name` -> `[UserPresenceState]`
"""
room_ids_to_states = {} # type: Dict[str, List[UserPresenceState]]
users_to_states = {} # type: Dict[str, List[UserPresenceState]]
for state in states:
- room_ids = yield store.get_rooms_for_user(state.user_id)
+ room_ids = await store.get_rooms_for_user(state.user_id)
for room_id in room_ids:
room_ids_to_states.setdefault(room_id, []).append(state)
@@ -1321,20 +1316,22 @@ def get_interested_parties(store, states):
return room_ids_to_states, users_to_states
-@defer.inlineCallbacks
-def get_interested_remotes(store, states, state_handler):
+async def get_interested_remotes(
+ store: DataStore, states: List[UserPresenceState], state_handler: StateHandler
+) -> List[Tuple[List[str], List[UserPresenceState]]]:
"""Given a list of presence states figure out which remote servers
should be sent which.
All the presence states should be for local users only.
Args:
- store (DataStore)
- states (list(UserPresenceState))
+ store
+ states
+ state_handler
Returns:
- Deferred list of ([destinations], [UserPresenceState]), where for
- each row the list of UserPresenceState should be sent to each
+ A list of 2-tuples of destinations and states, where for
+ each tuple the list of UserPresenceState should be sent to each
destination
"""
hosts_and_states = []
@@ -1342,10 +1339,10 @@ def get_interested_remotes(store, states, state_handler):
# First we look up the rooms each user is in (as well as any explicit
# subscriptions), then for each distinct room we look up the remote
# hosts in those rooms.
- room_ids_to_states, users_to_states = yield get_interested_parties(store, states)
+ room_ids_to_states, users_to_states = await get_interested_parties(store, states)
for room_id, states in room_ids_to_states.items():
- hosts = yield state_handler.get_current_hosts_in_room(room_id)
+ hosts = await state_handler.get_current_hosts_in_room(room_id)
hosts_and_states.append((hosts, states))
for user_id, states in users_to_states.items():
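A hedged usage sketch of the two helpers above; the sender hook name is a hypothetical stand-in, and `store`, `states` and `state_handler` are assumed to be a DataStore, a list of UserPresenceState and a StateHandler:

# Sketch: fan presence updates out to the remote servers that care.
async def send_presence_to_remotes(store, states, state_handler, federation):
    hosts_and_states = await get_interested_remotes(store, states, state_handler)
    for destinations, dest_states in hosts_and_states:
        for destination in destinations:
            # `send_presence_to_destination` is a hypothetical sender hook.
            federation.send_presence_to_destination(destination, dest_states)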
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 4b1e3073..31a2e5ea 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import (
AuthError,
Codes,
@@ -54,16 +52,15 @@ class BaseProfileHandler(BaseHandler):
self.user_directory_handler = hs.get_user_directory_handler()
- @defer.inlineCallbacks
- def get_profile(self, user_id):
+ async def get_profile(self, user_id):
target_user = UserID.from_string(user_id)
if self.hs.is_mine(target_user):
try:
- displayname = yield self.store.get_profile_displayname(
+ displayname = await self.store.get_profile_displayname(
target_user.localpart
)
- avatar_url = yield self.store.get_profile_avatar_url(
+ avatar_url = await self.store.get_profile_avatar_url(
target_user.localpart
)
except StoreError as e:
@@ -74,7 +71,7 @@ class BaseProfileHandler(BaseHandler):
return {"displayname": displayname, "avatar_url": avatar_url}
else:
try:
- result = yield self.federation.make_query(
+ result = await self.federation.make_query(
destination=target_user.domain,
query_type="profile",
args={"user_id": user_id},
@@ -86,8 +83,7 @@ class BaseProfileHandler(BaseHandler):
except HttpResponseException as e:
raise e.to_synapse_error()
- @defer.inlineCallbacks
- def get_profile_from_cache(self, user_id):
+ async def get_profile_from_cache(self, user_id):
"""Get the profile information from our local cache. If the user is
ours then the profile information will always be correct. Otherwise,
it may be out of date/missing.
@@ -95,10 +91,10 @@ class BaseProfileHandler(BaseHandler):
target_user = UserID.from_string(user_id)
if self.hs.is_mine(target_user):
try:
- displayname = yield self.store.get_profile_displayname(
+ displayname = await self.store.get_profile_displayname(
target_user.localpart
)
- avatar_url = yield self.store.get_profile_avatar_url(
+ avatar_url = await self.store.get_profile_avatar_url(
target_user.localpart
)
except StoreError as e:
@@ -108,14 +104,13 @@ class BaseProfileHandler(BaseHandler):
return {"displayname": displayname, "avatar_url": avatar_url}
else:
- profile = yield self.store.get_from_remote_profile_cache(user_id)
+ profile = await self.store.get_from_remote_profile_cache(user_id)
return profile or {}
- @defer.inlineCallbacks
- def get_displayname(self, target_user):
+ async def get_displayname(self, target_user):
if self.hs.is_mine(target_user):
try:
- displayname = yield self.store.get_profile_displayname(
+ displayname = await self.store.get_profile_displayname(
target_user.localpart
)
except StoreError as e:
@@ -126,7 +121,7 @@ class BaseProfileHandler(BaseHandler):
return displayname
else:
try:
- result = yield self.federation.make_query(
+ result = await self.federation.make_query(
destination=target_user.domain,
query_type="profile",
args={"user_id": target_user.to_string(), "field": "displayname"},
@@ -189,11 +184,10 @@ class BaseProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user)
- @defer.inlineCallbacks
- def get_avatar_url(self, target_user):
+ async def get_avatar_url(self, target_user):
if self.hs.is_mine(target_user):
try:
- avatar_url = yield self.store.get_profile_avatar_url(
+ avatar_url = await self.store.get_profile_avatar_url(
target_user.localpart
)
except StoreError as e:
@@ -203,7 +197,7 @@ class BaseProfileHandler(BaseHandler):
return avatar_url
else:
try:
- result = yield self.federation.make_query(
+ result = await self.federation.make_query(
destination=target_user.domain,
query_type="profile",
args={"user_id": target_user.to_string(), "field": "avatar_url"},
@@ -253,8 +247,7 @@ class BaseProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user)
- @defer.inlineCallbacks
- def on_profile_query(self, args):
+ async def on_profile_query(self, args):
user = UserID.from_string(args["user_id"])
if not self.hs.is_mine(user):
raise SynapseError(400, "User is not hosted on this homeserver")
@@ -264,12 +257,12 @@ class BaseProfileHandler(BaseHandler):
response = {}
try:
if just_field is None or just_field == "displayname":
- response["displayname"] = yield self.store.get_profile_displayname(
+ response["displayname"] = await self.store.get_profile_displayname(
user.localpart
)
if just_field is None or just_field == "avatar_url":
- response["avatar_url"] = yield self.store.get_profile_avatar_url(
+ response["avatar_url"] = await self.store.get_profile_avatar_url(
user.localpart
)
except StoreError as e:
@@ -304,8 +297,7 @@ class BaseProfileHandler(BaseHandler):
"Failed to update join event for room %s - %s", room_id, str(e)
)
- @defer.inlineCallbacks
- def check_profile_query_allowed(self, target_user, requester=None):
+ async def check_profile_query_allowed(self, target_user, requester=None):
"""Checks whether a profile query is allowed. If the
'require_auth_for_profile_requests' config flag is set to True and a
'requester' is provided, the query is only allowed if the two users
@@ -337,8 +329,8 @@ class BaseProfileHandler(BaseHandler):
return
try:
- requester_rooms = yield self.store.get_rooms_for_user(requester.to_string())
- target_user_rooms = yield self.store.get_rooms_for_user(
+ requester_rooms = await self.store.get_rooms_for_user(requester.to_string())
+ target_user_rooms = await self.store.get_rooms_for_user(
target_user.to_string()
)
@@ -371,25 +363,24 @@ class MasterProfileHandler(BaseProfileHandler):
"Update remote profile", self._update_remote_profile_cache
)
- @defer.inlineCallbacks
- def _update_remote_profile_cache(self):
+ async def _update_remote_profile_cache(self):
"""Called periodically to check profiles of remote users we haven't
checked in a while.
"""
- entries = yield self.store.get_remote_profile_cache_entries_that_expire(
+ entries = await self.store.get_remote_profile_cache_entries_that_expire(
last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS
)
for user_id, displayname, avatar_url in entries:
- is_subscribed = yield self.store.is_subscribed_remote_profile_for_user(
+ is_subscribed = await self.store.is_subscribed_remote_profile_for_user(
user_id
)
if not is_subscribed:
- yield self.store.maybe_delete_remote_profile_cache(user_id)
+ await self.store.maybe_delete_remote_profile_cache(user_id)
continue
try:
- profile = yield self.federation.make_query(
+ profile = await self.federation.make_query(
destination=get_domain_from_id(user_id),
query_type="profile",
args={"user_id": user_id},
@@ -398,7 +389,7 @@ class MasterProfileHandler(BaseProfileHandler):
except Exception:
logger.exception("Failed to get avatar_url")
- yield self.store.update_remote_profile_cache(
+ await self.store.update_remote_profile_cache(
user_id, displayname, avatar_url
)
continue
@@ -407,4 +398,4 @@ class MasterProfileHandler(BaseProfileHandler):
new_avatar = profile.get("avatar_url")
# We always hit update to update the last_check timestamp
- yield self.store.update_remote_profile_cache(user_id, new_name, new_avatar)
+ await self.store.update_remote_profile_cache(user_id, new_name, new_avatar)
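A hedged usage sketch for the cache path above: local users always resolve from the local profile tables, while remote users may come back stale or empty, so callers should fall back gracefully.

async def display_name_or_id(profile_handler, user_id: str) -> str:
    # Returns the cached display name, or the bare user ID if the
    # remote cache has nothing for this user.
    profile = await profile_handler.get_profile_from_cache(user_id)
    return profile.get("displayname") or user_id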
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 8bc100db..f922d8a5 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -14,8 +14,6 @@
# limitations under the License.
import logging
-from twisted.internet import defer
-
from synapse.handlers._base import BaseHandler
from synapse.types import ReadReceipt, get_domain_from_id
from synapse.util.async_helpers import maybe_awaitable
@@ -129,15 +127,14 @@ class ReceiptEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def get_new_events(self, from_key, room_ids, **kwargs):
+ async def get_new_events(self, from_key, room_ids, **kwargs):
from_key = int(from_key)
- to_key = yield self.get_current_key()
+ to_key = self.get_current_key()
if from_key == to_key:
return [], to_key
- events = yield self.store.get_linearized_receipts_for_rooms(
+ events = await self.store.get_linearized_receipts_for_rooms(
room_ids, from_key=from_key, to_key=to_key
)
@@ -146,8 +143,7 @@ class ReceiptEventSource(object):
def get_current_key(self, direction="f"):
return self.store.get_max_receipt_stream_id()
- @defer.inlineCallbacks
- def get_pagination_rows(self, user, config, key):
+ async def get_pagination_rows(self, user, config, key):
to_key = int(config.from_key)
if config.to_key:
@@ -155,8 +151,8 @@ class ReceiptEventSource(object):
else:
from_key = None
- room_ids = yield self.store.get_rooms_for_user(user.to_string())
- events = yield self.store.get_linearized_receipts_for_rooms(
+ room_ids = await self.store.get_rooms_for_user(user.to_string())
+ events = await self.store.get_linearized_receipts_for_rooms(
room_ids, from_key=from_key, to_key=to_key
)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 78c3772a..501f0fe7 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -28,7 +28,6 @@ from synapse.replication.http.register import (
)
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester
-from synapse.util.async_helpers import Linearizer
from ._base import BaseHandler
@@ -50,14 +49,7 @@ class RegistrationHandler(BaseHandler):
self.user_directory_handler = hs.get_user_directory_handler()
self.identity_handler = self.hs.get_handlers().identity_handler
self.ratelimiter = hs.get_registration_ratelimiter()
-
- self._next_generated_user_id = None
-
self.macaroon_gen = hs.get_macaroon_generator()
-
- self._generate_user_id_linearizer = Linearizer(
- name="_generate_user_id_linearizer"
- )
self._server_notices_mxid = hs.config.server_notices_mxid
if hs.config.worker_app:
@@ -219,7 +211,7 @@ class RegistrationHandler(BaseHandler):
if fail_count > 10:
raise SynapseError(500, "Unable to find a suitable guest user ID")
- localpart = await self._generate_user_id()
+ localpart = await self.store.generate_user_id()
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
self.check_user_id_not_appservice_exclusive(user_id)
@@ -510,18 +502,6 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE,
)
- async def _generate_user_id(self):
- if self._next_generated_user_id is None:
- with await self._generate_user_id_linearizer.queue(()):
- if self._next_generated_user_id is None:
- self._next_generated_user_id = (
- await self.store.find_next_generated_user_id_localpart()
- )
-
- id = self._next_generated_user_id
- self._next_generated_user_id += 1
- return str(id)
-
def check_registration_ratelimit(self, address):
"""A simple helper method to check whether the registration rate limit has been hit
for a given IP address
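A hedged sketch of the store-side replacement (an assumption, not the actual Synapse implementation): the handler's linearizer-guarded counter moves behind the datastore, e.g. backed by a database sequence, so multiple workers can allocate guest user IDs safely.

class SketchStore:
    def __init__(self):
        self._next_id = 1  # stands in for a database sequence

    async def generate_user_id(self) -> str:
        # Each call hands out the next numeric localpart.
        localpart = str(self._next_id)
        self._next_id += 1
        return localpart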
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 950a84ac..0c5b9923 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -22,11 +22,12 @@ import logging
import math
import string
from collections import OrderedDict
-from typing import Tuple
+from typing import Optional, Tuple
from synapse.api.constants import (
EventTypes,
JoinRules,
+ Membership,
RoomCreationPreset,
RoomEncryptionAlgorithms,
)
@@ -43,9 +44,10 @@ from synapse.types import (
StateMap,
StreamToken,
UserID,
+ create_requester,
)
from synapse.util import stringutils
-from synapse.util.async_helpers import Linearizer
+from synapse.util.async_helpers import Linearizer, maybe_awaitable
from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client
@@ -117,7 +119,7 @@ class RoomCreationHandler(BaseHandler):
async def upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
- ):
+ ) -> str:
"""Replace a room with a new room with a different version
Args:
@@ -126,7 +128,7 @@ class RoomCreationHandler(BaseHandler):
new_version: the new room version to use
Returns:
- Deferred[unicode]: the new room id
+ the new room id
"""
await self.ratelimit(requester)
@@ -237,7 +239,7 @@ class RoomCreationHandler(BaseHandler):
old_room_id: str,
new_room_id: str,
old_room_state: StateMap[str],
- ):
+ ) -> None:
"""Send updated power levels in both rooms after an upgrade
Args:
@@ -245,9 +247,6 @@ class RoomCreationHandler(BaseHandler):
old_room_id: the id of the room to be replaced
new_room_id: the id of the replacement room
old_room_state: the state map for the old room
-
- Returns:
- Deferred
"""
old_room_pl_event_id = old_room_state.get((EventTypes.PowerLevels, ""))
@@ -320,7 +319,7 @@ class RoomCreationHandler(BaseHandler):
new_room_id: str,
new_room_version: RoomVersion,
tombstone_event_id: str,
- ):
+ ) -> None:
"""Populate a new room based on an old room
Args:
@@ -330,8 +329,6 @@ class RoomCreationHandler(BaseHandler):
created with _generate_room_id())
new_room_version: the new room version to use
tombstone_event_id: the ID of the tombstone event in the old room.
- Returns:
- Deferred
"""
user_id = requester.user.to_string()
@@ -1089,3 +1086,205 @@ class RoomEventSource(object):
def get_current_key_for_room(self, room_id):
return self.store.get_room_events_max_id(room_id)
+
+
+class RoomShutdownHandler(object):
+
+ DEFAULT_MESSAGE = (
+ "Sharing illegal content on this server is not permitted and rooms in"
+ " violation will be blocked."
+ )
+ DEFAULT_ROOM_NAME = "Content Violation Notification"
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.room_member_handler = hs.get_room_member_handler()
+ self._room_creation_handler = hs.get_room_creation_handler()
+ self._replication = hs.get_replication_data_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.state = hs.get_state_handler()
+ self.store = hs.get_datastore()
+
+ async def shutdown_room(
+ self,
+ room_id: str,
+ requester_user_id: str,
+ new_room_user_id: Optional[str] = None,
+ new_room_name: Optional[str] = None,
+ message: Optional[str] = None,
+ block: bool = False,
+ ) -> dict:
+ """
+ Shuts down a room. Moves all local users and room aliases automatically
+        to a new room if `new_room_user_id` is set. Otherwise local users
+        simply leave the room, without any further explanation.
+
+ The new room will be created with the user specified by the
+ `new_room_user_id` parameter as room administrator and will contain a
+ message explaining what happened. Users invited to the new room will
+ have power level `-10` by default, and thus be unable to speak.
+
+        The local server will only have the power to move local users and room
+ aliases to the new room. Users on other servers will be unaffected.
+
+ Args:
+ room_id: The ID of the room to shut down.
+ requester_user_id:
+ User who requested the action and put the room on the
+ blocking list.
+ new_room_user_id:
+ If set, a new room will be created with this user ID
+ as the creator and admin, and all users in the old room will be
+ moved into that room. If not set, no new room will be created
+ and the users will just be removed from the old room.
+ new_room_name:
+ A string representing the name of the room that new users will
+ be invited to. Defaults to `Content Violation Notification`
+ message:
+ A string containing the first message that will be sent as
+ `new_room_user_id` in the new room. Ideally this will clearly
+ convey why the original room was shut down.
+ Defaults to `Sharing illegal content on this server is not
+ permitted and rooms in violation will be blocked.`
+ block:
+ If set to `true`, this room will be added to a blocking list,
+ preventing future attempts to join the room. Defaults to `false`.
+
+ Returns: a dict containing the following keys:
+ kicked_users: An array of users (`user_id`) that were kicked.
+ failed_to_kick_users:
+                An array of users (`user_id`) that were not kicked.
+ local_aliases:
+ An array of strings representing the local aliases that were
+ migrated from the old room to the new.
+ new_room_id: A string representing the room ID of the new room.
+ """
+
+ if not new_room_name:
+ new_room_name = self.DEFAULT_ROOM_NAME
+ if not message:
+ message = self.DEFAULT_MESSAGE
+
+ if not RoomID.is_valid(room_id):
+ raise SynapseError(400, "%s is not a legal room ID" % (room_id,))
+
+ if not await self.store.get_room(room_id):
+ raise NotFoundError("Unknown room id %s" % (room_id,))
+
+ # This will work even if the room is already blocked, but that is
+ # desirable in case the first attempt at blocking the room failed below.
+ if block:
+ await self.store.block_room(room_id, requester_user_id)
+
+ if new_room_user_id is not None:
+ if not self.hs.is_mine_id(new_room_user_id):
+ raise SynapseError(
+ 400, "User must be our own: %s" % (new_room_user_id,)
+ )
+
+ room_creator_requester = create_requester(new_room_user_id)
+
+ info, stream_id = await self._room_creation_handler.create_room(
+ room_creator_requester,
+ config={
+ "preset": RoomCreationPreset.PUBLIC_CHAT,
+ "name": new_room_name,
+ "power_level_content_override": {"users_default": -10},
+ },
+ ratelimit=False,
+ )
+ new_room_id = info["room_id"]
+
+ logger.info(
+ "Shutting down room %r, joining to new room: %r", room_id, new_room_id
+ )
+
+        # We now wait for the room creation to come back in via replication so
+        # that we can assume that all the joins/invites have propagated before
+        # we try to auto-join below.
+ #
+ # TODO: Currently the events stream is written to from master
+ await self._replication.wait_for_stream_position(
+ self.hs.config.worker.writers.events, "events", stream_id
+ )
+ else:
+ new_room_id = None
+ logger.info("Shutting down room %r", room_id)
+
+ users = await self.state.get_current_users_in_room(room_id)
+ kicked_users = []
+ failed_to_kick_users = []
+ for user_id in users:
+ if not self.hs.is_mine_id(user_id):
+ continue
+
+ logger.info("Kicking %r from %r...", user_id, room_id)
+
+ try:
+ # Kick users from room
+ target_requester = create_requester(user_id)
+ _, stream_id = await self.room_member_handler.update_membership(
+ requester=target_requester,
+ target=target_requester.user,
+ room_id=room_id,
+ action=Membership.LEAVE,
+ content={},
+ ratelimit=False,
+ require_consent=False,
+ )
+
+ # Wait for leave to come in over replication before trying to forget.
+ await self._replication.wait_for_stream_position(
+ self.hs.config.worker.writers.events, "events", stream_id
+ )
+
+ await self.room_member_handler.forget(target_requester.user, room_id)
+
+ # Join users to new room
+ if new_room_user_id:
+ await self.room_member_handler.update_membership(
+ requester=target_requester,
+ target=target_requester.user,
+ room_id=new_room_id,
+ action=Membership.JOIN,
+ content={},
+ ratelimit=False,
+ require_consent=False,
+ )
+
+ kicked_users.append(user_id)
+ except Exception:
+ logger.exception(
+ "Failed to leave old room and join new room for %r", user_id
+ )
+ failed_to_kick_users.append(user_id)
+
+ # Send message in new room and move aliases
+ if new_room_user_id:
+ await self.event_creation_handler.create_and_send_nonmember_event(
+ room_creator_requester,
+ {
+ "type": "m.room.message",
+ "content": {"body": message, "msgtype": "m.text"},
+ "room_id": new_room_id,
+ "sender": new_room_user_id,
+ },
+ ratelimit=False,
+ )
+
+ aliases_for_room = await maybe_awaitable(
+ self.store.get_aliases_for_room(room_id)
+ )
+
+ await self.store.update_aliases_for_room(
+ room_id, new_room_id, requester_user_id
+ )
+ else:
+ aliases_for_room = []
+
+ return {
+ "kicked_users": kicked_users,
+ "failed_to_kick_users": failed_to_kick_users,
+ "local_aliases": aliases_for_room,
+ "new_room_id": new_room_id,
+ }
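A hedged usage sketch for the new handler; the admin and notice-sender user IDs are illustrative assumptions:

async def block_abusive_room(hs, room_id: str) -> None:
    handler = RoomShutdownHandler(hs)
    result = await handler.shutdown_room(
        room_id=room_id,
        requester_user_id="@admin:example.com",
        new_room_user_id="@server:example.com",
        block=True,
    )
    # The returned dict carries the kick results and the replacement room.
    logger.info(
        "Kicked %d users; new room: %s",
        len(result["kicked_users"]),
        result["new_room_id"],
    )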
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 5e05be61..5dd7b283 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -20,12 +20,10 @@ from typing import Any, Dict, Optional
import msgpack
from unpaddedbase64 import decode_base64, encode_base64
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, HttpResponseException
from synapse.types import ThirdPartyInstanceID
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
from synapse.util.caches.response_cache import ResponseCache
from ._base import BaseHandler
@@ -47,7 +45,7 @@ class RoomListHandler(BaseHandler):
hs, "remote_room_list", timeout_ms=30 * 1000
)
- def get_local_public_room_list(
+ async def get_local_public_room_list(
self,
limit=None,
since_token=None,
@@ -72,7 +70,7 @@ class RoomListHandler(BaseHandler):
API
"""
if not self.enable_room_list_search:
- return defer.succeed({"chunk": [], "total_room_count_estimate": 0})
+ return {"chunk": [], "total_room_count_estimate": 0}
logger.info(
"Getting public room list: limit=%r, since=%r, search=%r, network=%r",
@@ -87,7 +85,7 @@ class RoomListHandler(BaseHandler):
# appservice specific lists.
logger.info("Bypassing cache as search request.")
- return self._get_public_room_list(
+ return await self._get_public_room_list(
limit,
since_token,
search_filter,
@@ -96,7 +94,7 @@ class RoomListHandler(BaseHandler):
)
key = (limit, since_token, network_tuple)
- return self.response_cache.wrap(
+ return await self.response_cache.wrap(
key,
self._get_public_room_list,
limit,
@@ -105,8 +103,7 @@ class RoomListHandler(BaseHandler):
from_federation=from_federation,
)
- @defer.inlineCallbacks
- def _get_public_room_list(
+ async def _get_public_room_list(
self,
limit: Optional[int] = None,
since_token: Optional[str] = None,
@@ -145,7 +142,7 @@ class RoomListHandler(BaseHandler):
# we request one more than wanted to see if there are more pages to come
probing_limit = limit + 1 if limit is not None else None
- results = yield self.store.get_largest_public_rooms(
+ results = await self.store.get_largest_public_rooms(
network_tuple,
search_filter,
probing_limit,
@@ -221,44 +218,44 @@ class RoomListHandler(BaseHandler):
response["chunk"] = results
- response["total_room_count_estimate"] = yield self.store.count_public_rooms(
+ response["total_room_count_estimate"] = await self.store.count_public_rooms(
network_tuple, ignore_non_federatable=from_federation
)
return response
- @cachedInlineCallbacks(num_args=1, cache_context=True)
- def generate_room_entry(
+ @cached(num_args=1, cache_context=True)
+ async def generate_room_entry(
self,
- room_id,
- num_joined_users,
+ room_id: str,
+ num_joined_users: int,
cache_context,
- with_alias=True,
- allow_private=False,
- ):
+ with_alias: bool = True,
+ allow_private: bool = False,
+ ) -> Optional[dict]:
"""Returns the entry for a room
Args:
- room_id (str): The room's ID.
- num_joined_users (int): Number of users in the room.
+ room_id: The room's ID.
+ num_joined_users: Number of users in the room.
cache_context: Information for cached responses.
- with_alias (bool): Whether to return the room's aliases in the result.
- allow_private (bool): Whether invite-only rooms should be shown.
+ with_alias: Whether to return the room's aliases in the result.
+ allow_private: Whether invite-only rooms should be shown.
Returns:
- Deferred[dict|None]: Returns a room entry as a dictionary, or None if this
+ Returns a room entry as a dictionary, or None if this
room was determined not to be shown publicly.
"""
result = {"room_id": room_id, "num_joined_members": num_joined_users}
if with_alias:
- aliases = yield self.store.get_aliases_for_room(
+ aliases = await self.store.get_aliases_for_room(
room_id, on_invalidate=cache_context.invalidate
)
if aliases:
result["aliases"] = aliases
- current_state_ids = yield self.store.get_current_state_ids(
+ current_state_ids = await self.store.get_current_state_ids(
room_id, on_invalidate=cache_context.invalidate
)
@@ -266,7 +263,7 @@ class RoomListHandler(BaseHandler):
# We're not in the room, so may as well bail out here.
return result
- event_map = yield self.store.get_events(
+ event_map = await self.store.get_events(
[
event_id
for key, event_id in current_state_ids.items()
@@ -336,8 +333,7 @@ class RoomListHandler(BaseHandler):
return result
- @defer.inlineCallbacks
- def get_remote_public_room_list(
+ async def get_remote_public_room_list(
self,
server_name,
limit=None,
@@ -356,7 +352,7 @@ class RoomListHandler(BaseHandler):
# to a locally-filtered search if we must.
try:
- res = yield self._get_remote_list_cached(
+ res = await self._get_remote_list_cached(
server_name,
limit=limit,
since_token=since_token,
@@ -381,7 +377,7 @@ class RoomListHandler(BaseHandler):
limit = None
since_token = None
- res = yield self._get_remote_list_cached(
+ res = await self._get_remote_list_cached(
server_name,
limit=limit,
since_token=since_token,
@@ -400,7 +396,7 @@ class RoomListHandler(BaseHandler):
return res
- def _get_remote_list_cached(
+ async def _get_remote_list_cached(
self,
server_name,
limit=None,
@@ -412,7 +408,7 @@ class RoomListHandler(BaseHandler):
repl_layer = self.hs.get_federation_client()
if search_filter:
# We can't cache when asking for search
- return repl_layer.get_public_rooms(
+ return await repl_layer.get_public_rooms(
server_name,
limit=limit,
since_token=since_token,
@@ -428,7 +424,7 @@ class RoomListHandler(BaseHandler):
include_all_networks,
third_party_instance_id,
)
- return self.remote_response_cache.wrap(
+ return await self.remote_response_cache.wrap(
key,
repl_layer.get_public_rooms,
server_name,
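A toy illustration of the response-cache contract relied on above (a sketch, not Synapse's ResponseCache, and using asyncio for brevity): concurrent callers with the same key share one in-flight computation instead of duplicating work.

import asyncio

class ToyResponseCache:
    def __init__(self):
        self._pending = {}

    async def wrap(self, key, callback, *args, **kwargs):
        # First caller for a key starts the work; later callers with the
        # same key await the same pending future.
        fut = self._pending.get(key)
        if fut is None:
            fut = asyncio.ensure_future(callback(*args, **kwargs))
            self._pending[key] = fut
        try:
            return await fut
        finally:
            self._pending.pop(key, None)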
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 4d40d3ac..9b312a15 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -15,6 +15,7 @@
import itertools
import logging
+from typing import Iterable
from unpaddedbase64 import decode_base64, encode_base64
@@ -37,7 +38,7 @@ class SearchHandler(BaseHandler):
self.state_store = self.storage.state
self.auth = hs.get_auth()
- async def get_old_rooms_from_upgraded_room(self, room_id):
+ async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]:
"""Retrieves room IDs of old rooms in the history of an upgraded room.
We do so by checking the m.room.create event of the room for a
@@ -48,10 +49,10 @@ class SearchHandler(BaseHandler):
The full list of all found rooms in then returned.
Args:
- room_id (str): id of the room to search through.
+ room_id: id of the room to search through.
Returns:
- Deferred[iterable[str]]: predecessor room ids
+ Predecessor room ids
"""
historical_room_ids = []
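A hedged sketch of the predecessor walk the docstring describes: each m.room.create event may carry a "predecessor" field naming the room it upgraded. `get_create_event_for_room` is assumed here to exist on the store with this shape.

async def walk_predecessors(store, room_id: str):
    historical_room_ids = []
    while True:
        create_event = await store.get_create_event_for_room(room_id)
        predecessor = create_event.content.get("predecessor")
        if not predecessor:
            break
        # Follow the chain one room back and record it.
        room_id = predecessor["room_id"]
        historical_room_ids.append(room_id)
    return historical_room_ids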
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 4c752449..ebd3e981 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -283,6 +283,7 @@ class SyncHandler(object):
timeout,
full_state,
)
+ logger.debug("Returning sync response for %s", user_id)
return res
async def _wait_for_sync_for_user(
@@ -420,10 +421,6 @@ class SyncHandler(object):
potential_recents: Optional[List[EventBase]] = None,
newly_joined_room: bool = False,
) -> TimelineBatch:
- """
- Returns:
- a Deferred TimelineBatch
- """
with Measure(self.clock, "load_filtered_recents"):
timeline_limit = sync_config.filter_collection.timeline_limit()
block_all_timeline = (
@@ -990,10 +987,14 @@ class SyncHandler(object):
joined_room_ids=joined_room_ids,
)
+ logger.debug("Fetching account data")
+
account_data_by_room = await self._generate_sync_entry_for_account_data(
sync_result_builder
)
+ logger.debug("Fetching room data")
+
res = await self._generate_sync_entry_for_rooms(
sync_result_builder, account_data_by_room
)
@@ -1004,10 +1005,12 @@ class SyncHandler(object):
since_token is None and sync_config.filter_collection.blocks_all_presence()
)
if self.hs_config.use_presence and not block_all_presence_data:
+ logger.debug("Fetching presence data")
await self._generate_sync_entry_for_presence(
sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
)
+ logger.debug("Fetching to-device data")
await self._generate_sync_entry_for_to_device(sync_result_builder)
device_lists = await self._generate_sync_entry_for_device_list(
@@ -1018,6 +1021,7 @@ class SyncHandler(object):
newly_left_users=newly_left_users,
)
+ logger.debug("Fetching OTK data")
device_id = sync_config.device_id
one_time_key_counts = {} # type: JsonDict
if device_id:
@@ -1025,6 +1029,7 @@ class SyncHandler(object):
user_id, device_id
)
+ logger.debug("Fetching group data")
await self._generate_sync_entry_for_groups(sync_result_builder)
# debug for https://github.com/matrix-org/synapse/issues/4422
@@ -1035,6 +1040,7 @@ class SyncHandler(object):
"Sync result for newly joined room %s: %r", room_id, joined_room
)
+ logger.debug("Sync response calculation complete")
return SyncResult(
presence=sync_result_builder.presence,
account_data=sync_result_builder.account_data,
@@ -1407,8 +1413,9 @@ class SyncHandler(object):
newly_joined_rooms = room_changes.newly_joined_rooms
newly_left_rooms = room_changes.newly_left_rooms
- def handle_room_entries(room_entry):
- return self._generate_room_entry(
+ async def handle_room_entries(room_entry):
+ logger.debug("Generating room entry for %s", room_entry.room_id)
+ res = await self._generate_room_entry(
sync_result_builder,
ignored_users,
room_entry,
@@ -1417,6 +1424,8 @@ class SyncHandler(object):
account_data=account_data_by_room.get(room_entry.room_id, {}),
always_include=sync_result_builder.full_state,
)
+ logger.debug("Generated room entry for %s", room_entry.room_id)
+ return res
await concurrently_execute(handle_room_entries, room_entries, 10)
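A sketch of the concurrently_execute() contract assumed at the call site above: run an async callback over an iterable with at most `limit` invocations in flight at once (asyncio used for brevity; Synapse's helper is Twisted-based).

import asyncio

async def concurrently_execute_sketch(func, args, limit: int):
    it = iter(args)

    async def _worker():
        # The shared iterator hands each item to exactly one worker.
        for arg in it:
            await func(arg)

    await asyncio.gather(*(_worker() for _ in range(limit)))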
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 879c4c07..a86ac015 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -15,15 +15,19 @@
import logging
from collections import namedtuple
-from typing import List, Tuple
+from typing import TYPE_CHECKING, List, Set, Tuple
from synapse.api.errors import AuthError, SynapseError
-from synapse.logging.context import run_in_background
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.tcp.streams import TypingStream
from synapse.types import UserID, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -39,48 +43,48 @@ FEDERATION_TIMEOUT = 60 * 1000
FEDERATION_PING_INTERVAL = 40 * 1000
-class TypingHandler(object):
- def __init__(self, hs):
+class FollowerTypingHandler:
+ """A typing handler on a different process than the writer that is updated
+ via replication.
+ """
+
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.server_name = hs.config.server_name
- self.auth = hs.get_auth()
- self.is_mine_id = hs.is_mine_id
- self.notifier = hs.get_notifier()
- self.state = hs.get_state_handler()
-
- self.hs = hs
-
self.clock = hs.get_clock()
- self.wheel_timer = WheelTimer(bucket_size=5000)
+ self.is_mine_id = hs.is_mine_id
- self.federation = hs.get_federation_sender()
+ self.federation = None
+ if hs.should_send_federation():
+ self.federation = hs.get_federation_sender()
- hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
+ if hs.config.worker.writers.typing != hs.get_instance_name():
+ hs.get_federation_registry().register_instance_for_edu(
+ "m.typing", hs.config.worker.writers.typing,
+ )
- hs.get_distributor().observe("user_left_room", self.user_left_room)
+ # map room IDs to serial numbers
+ self._room_serials = {}
+ # map room IDs to sets of users currently typing
+ self._room_typing = {}
- self._member_typing_until = {} # clock time we expect to stop
self._member_last_federation_poke = {}
-
+ self.wheel_timer = WheelTimer(bucket_size=5000)
self._latest_room_serial = 0
- self._reset()
-
- # caches which room_ids changed at which serials
- self._typing_stream_change_cache = StreamChangeCache(
- "TypingStreamChangeCache", self._latest_room_serial
- )
self.clock.looping_call(self._handle_timeouts, 5000)
def _reset(self):
- """
- Reset the typing handler's data caches.
+ """Reset the typing handler's data caches.
"""
# map room IDs to serial numbers
self._room_serials = {}
# map room IDs to sets of users currently typing
self._room_typing = {}
+ self._member_last_federation_poke = {}
+ self.wheel_timer = WheelTimer(bucket_size=5000)
+
def _handle_timeouts(self):
logger.debug("Checking for typing timeouts")
@@ -89,30 +93,140 @@ class TypingHandler(object):
members = set(self.wheel_timer.fetch(now))
for member in members:
- if not self.is_typing(member):
- # Nothing to do if they're no longer typing
- continue
-
- until = self._member_typing_until.get(member, None)
- if not until or until <= now:
- logger.info("Timing out typing for: %s", member.user_id)
- self._stopped_typing(member)
- continue
-
- # Check if we need to resend a keep alive over federation for this
- # user.
- if self.hs.is_mine_id(member.user_id):
- last_fed_poke = self._member_last_federation_poke.get(member, None)
- if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
- run_in_background(self._push_remote, member=member, typing=True)
-
- # Add a paranoia timer to ensure that we always have a timer for
- # each person typing.
- self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
+ self._handle_timeout_for_member(now, member)
+
+ def _handle_timeout_for_member(self, now: int, member: RoomMember):
+ if not self.is_typing(member):
+ # Nothing to do if they're no longer typing
+ return
+
+ # Check if we need to resend a keep alive over federation for this
+ # user.
+ if self.federation and self.is_mine_id(member.user_id):
+ last_fed_poke = self._member_last_federation_poke.get(member, None)
+ if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
+ run_as_background_process(
+ "typing._push_remote", self._push_remote, member=member, typing=True
+ )
+
+ # Add a paranoia timer to ensure that we always have a timer for
+ # each person typing.
+ self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
def is_typing(self, member):
return member.user_id in self._room_typing.get(member.room_id, [])
+ async def _push_remote(self, member, typing):
+ if not self.federation:
+ return
+
+ try:
+ users = await self.store.get_users_in_room(member.room_id)
+ self._member_last_federation_poke[member] = self.clock.time_msec()
+
+ now = self.clock.time_msec()
+ self.wheel_timer.insert(
+ now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
+ )
+
+ for domain in {get_domain_from_id(u) for u in users}:
+ if domain != self.server_name:
+ logger.debug("sending typing update to %s", domain)
+ self.federation.build_and_send_edu(
+ destination=domain,
+ edu_type="m.typing",
+ content={
+ "room_id": member.room_id,
+ "user_id": member.user_id,
+ "typing": typing,
+ },
+ key=member,
+ )
+ except Exception:
+ logger.exception("Error pushing typing notif to remotes")
+
+ def process_replication_rows(
+ self, token: int, rows: List[TypingStream.TypingStreamRow]
+ ):
+ """Should be called whenever we receive updates for typing stream.
+ """
+
+ if self._latest_room_serial > token:
+ # The master has gone backwards. To prevent inconsistent data, just
+ # clear everything.
+ self._reset()
+
+ # Set the latest serial token to whatever the server gave us.
+ self._latest_room_serial = token
+
+ for row in rows:
+ self._room_serials[row.room_id] = token
+
+ prev_typing = set(self._room_typing.get(row.room_id, []))
+ now_typing = set(row.user_ids)
+ self._room_typing[row.room_id] = row.user_ids
+
+ run_as_background_process(
+ "_handle_change_in_typing",
+ self._handle_change_in_typing,
+ row.room_id,
+ prev_typing,
+ now_typing,
+ )
+
+ async def _handle_change_in_typing(
+ self, room_id: str, prev_typing: Set[str], now_typing: Set[str]
+ ):
+ """Process a change in typing of a room from replication, sending EDUs
+ for any local users.
+ """
+ for user_id in now_typing - prev_typing:
+ if self.is_mine_id(user_id):
+ await self._push_remote(RoomMember(room_id, user_id), True)
+
+ for user_id in prev_typing - now_typing:
+ if self.is_mine_id(user_id):
+ await self._push_remote(RoomMember(room_id, user_id), False)
+
+ def get_current_token(self):
+ return self._latest_room_serial
+
+
+class TypingWriterHandler(FollowerTypingHandler):
+ def __init__(self, hs):
+ super().__init__(hs)
+
+ assert hs.config.worker.writers.typing == hs.get_instance_name()
+
+ self.auth = hs.get_auth()
+ self.notifier = hs.get_notifier()
+
+ self.hs = hs
+
+ hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
+
+ hs.get_distributor().observe("user_left_room", self.user_left_room)
+
+ self._member_typing_until = {} # clock time we expect to stop
+
+ # caches which room_ids changed at which serials
+ self._typing_stream_change_cache = StreamChangeCache(
+ "TypingStreamChangeCache", self._latest_room_serial
+ )
+
+ def _handle_timeout_for_member(self, now: int, member: RoomMember):
+ super()._handle_timeout_for_member(now, member)
+
+ if not self.is_typing(member):
+ # Nothing to do if they're no longer typing
+ return
+
+ until = self._member_typing_until.get(member, None)
+ if not until or until <= now:
+ logger.info("Timing out typing for: %s", member.user_id)
+ self._stopped_typing(member)
+ return
+
async def started_typing(self, target_user, auth_user, room_id, timeout):
target_user_id = target_user.to_string()
auth_user_id = auth_user.to_string()
@@ -179,35 +293,11 @@ class TypingHandler(object):
def _push_update(self, member, typing):
if self.hs.is_mine_id(member.user_id):
# Only send updates for changes to our own users.
- run_in_background(self._push_remote, member, typing)
-
- self._push_update_local(member=member, typing=typing)
-
- async def _push_remote(self, member, typing):
- try:
- users = await self.state.get_current_users_in_room(member.room_id)
- self._member_last_federation_poke[member] = self.clock.time_msec()
-
- now = self.clock.time_msec()
- self.wheel_timer.insert(
- now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
+ run_as_background_process(
+ "typing._push_remote", self._push_remote, member, typing
)
- for domain in {get_domain_from_id(u) for u in users}:
- if domain != self.server_name:
- logger.debug("sending typing update to %s", domain)
- self.federation.build_and_send_edu(
- destination=domain,
- edu_type="m.typing",
- content={
- "room_id": member.room_id,
- "user_id": member.user_id,
- "typing": typing,
- },
- key=member,
- )
- except Exception:
- logger.exception("Error pushing typing notif to remotes")
+ self._push_update_local(member=member, typing=typing)
async def _recv_edu(self, origin, content):
room_id = content["room_id"]
@@ -224,7 +314,7 @@ class TypingHandler(object):
)
return
- users = await self.state.get_current_users_in_room(room_id)
+ users = await self.store.get_users_in_room(room_id)
domains = {get_domain_from_id(u) for u in users}
if self.server_name in domains:
@@ -304,8 +394,11 @@ class TypingHandler(object):
return rows, current_id, limited
- def get_current_token(self):
- return self._latest_room_serial
+ def process_replication_rows(
+ self, token: int, rows: List[TypingStream.TypingStreamRow]
+ ):
+ # The writing process should never get updates from replication.
+ raise Exception("Typing writer instance got typing info over replication")
class TypingNotificationEventSource(object):
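A hedged sketch of feeding a typing update into a follower. The row shape mirrors TypingStream.TypingStreamRow as used above; handler construction is elided.

def apply_typing_update(follower, token: int, room_id: str, user_ids):
    # One replication row: the full set of users now typing in the room.
    row = TypingStream.TypingStreamRow(room_id=room_id, user_ids=user_ids)
    follower.process_replication_rows(token, [row])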
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 8b24a733..a011e9fe 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -12,11 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import logging
+from typing import Any
from canonicaljson import json
-from twisted.internet import defer
from twisted.web.client import PartialDownloadError
from synapse.api.constants import LoginType
@@ -32,25 +33,25 @@ class UserInteractiveAuthChecker:
def __init__(self, hs):
pass
- def is_enabled(self):
+ def is_enabled(self) -> bool:
"""Check if the configuration of the homeserver allows this checker to work
Returns:
- bool: True if this login type is enabled.
+ True if this login type is enabled.
"""
- def check_auth(self, authdict, clientip):
+ async def check_auth(self, authdict: dict, clientip: str) -> Any:
"""Given the authentication dict from the client, attempt to check this step
Args:
- authdict (dict): authentication dictionary from the client
- clientip (str): The IP address of the client.
+ authdict: authentication dictionary from the client
+ clientip: The IP address of the client.
Raises:
SynapseError if authentication failed
Returns:
- Deferred: the result of authentication (to pass back to the client?)
+ The result of authentication (to pass back to the client?)
"""
raise NotImplementedError()
@@ -61,8 +62,8 @@ class DummyAuthChecker(UserInteractiveAuthChecker):
def is_enabled(self):
return True
- def check_auth(self, authdict, clientip):
- return defer.succeed(True)
+ async def check_auth(self, authdict, clientip):
+ return True
class TermsAuthChecker(UserInteractiveAuthChecker):
@@ -71,8 +72,8 @@ class TermsAuthChecker(UserInteractiveAuthChecker):
def is_enabled(self):
return True
- def check_auth(self, authdict, clientip):
- return defer.succeed(True)
+ async def check_auth(self, authdict, clientip):
+ return True
class RecaptchaAuthChecker(UserInteractiveAuthChecker):
@@ -88,8 +89,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
def is_enabled(self):
return self._enabled
- @defer.inlineCallbacks
- def check_auth(self, authdict, clientip):
+ async def check_auth(self, authdict, clientip):
try:
user_response = authdict["response"]
except KeyError:
@@ -106,7 +106,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
# TODO: get this from the homeserver rather than creating a new one for
# each request
try:
- resp_body = yield self._http_client.post_urlencoded_get_json(
+ resp_body = await self._http_client.post_urlencoded_get_json(
self._url,
args={
"secret": self._secret,
@@ -117,7 +117,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
- resp_body = json.loads(data)
+ resp_body = json.loads(data.decode("utf-8"))
if "success" in resp_body:
# Note that we do NOT check the hostname here: we explicitly
@@ -218,8 +218,8 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec
ThreepidBehaviour.LOCAL,
)
- def check_auth(self, authdict, clientip):
- return defer.ensureDeferred(self._check_threepid("email", authdict))
+ async def check_auth(self, authdict, clientip):
+ return await self._check_threepid("email", authdict)
class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
@@ -232,8 +232,8 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
def is_enabled(self):
return bool(self.hs.config.account_threepid_delegate_msisdn)
- def check_auth(self, authdict, clientip):
- return defer.ensureDeferred(self._check_threepid("msisdn", authdict))
+ async def check_auth(self, authdict, clientip):
+ return await self._check_threepid("msisdn", authdict)
INTERACTIVE_AUTH_CHECKERS = [
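A hedged sketch of a custom checker written against the now-async interface: is_enabled() stays synchronous, while check_auth() is a coroutine that returns a result or raises SynapseError. The AUTH_TYPE value is a hypothetical login type for illustration.

class ExampleAuthChecker(UserInteractiveAuthChecker):
    AUTH_TYPE = "example.dummy"  # hypothetical login type

    def is_enabled(self) -> bool:
        return True

    async def check_auth(self, authdict: dict, clientip: str):
        # Succeed only if the client flagged the step as accepted.
        if authdict.get("accepted") is not True:
            raise SynapseError(400, "Example auth step not completed")
        return True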
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 8743e983..6bc51202 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -31,6 +31,7 @@ from twisted.internet.interfaces import (
IReactorPluggableNameResolver,
IResolutionReceiver,
)
+from twisted.internet.task import Cooperator
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
from twisted.web.client import Agent, HTTPConnectionPool, readBody
@@ -69,6 +70,21 @@ def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
return False
+_EPSILON = 0.00000001
+
+
+def _make_scheduler(reactor):
+ """Makes a schedular suitable for a Cooperator using the given reactor.
+
+ (This is effectively just a copy from `twisted.internet.task`)
+ """
+
+ def _scheduler(x):
+ return reactor.callLater(_EPSILON, x)
+
+ return _scheduler
+
+
class IPBlacklistingResolver(object):
"""
A proxy for reactor.nameResolver which only produces non-blacklisted IP
@@ -212,6 +228,10 @@ class SimpleHttpClient(object):
if hs.config.user_agent_suffix:
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix)
+ # We use this for our body producers to ensure that they use the correct
+ # reactor.
+ self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_reactor()))
+
self.user_agent = self.user_agent.encode("ascii")
if self._ip_blacklist:
@@ -292,7 +312,9 @@ class SimpleHttpClient(object):
try:
body_producer = None
if data is not None:
- body_producer = QuieterFileBodyProducer(BytesIO(data))
+ body_producer = QuieterFileBodyProducer(
+ BytesIO(data), cooperator=self._cooperator,
+ )
request_deferred = treq.request(
method,
@@ -371,7 +393,7 @@ class SimpleHttpClient(object):
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
- return json.loads(body)
+ return json.loads(body.decode("utf-8"))
else:
raise HttpResponseException(response.code, response.phrase, body)
@@ -412,7 +434,7 @@ class SimpleHttpClient(object):
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
- return json.loads(body)
+ return json.loads(body.decode("utf-8"))
else:
raise HttpResponseException(response.code, response.phrase, body)
@@ -441,7 +463,7 @@ class SimpleHttpClient(object):
actual_headers.update(headers)
body = yield self.get_raw(uri, args, headers=headers)
- return json.loads(body)
+ return json.loads(body.decode("utf-8"))
@defer.inlineCallbacks
def put_json(self, uri, json_body, args={}, headers=None):
@@ -485,7 +507,7 @@ class SimpleHttpClient(object):
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
- return json.loads(body)
+ return json.loads(body.decode("utf-8"))
else:
raise HttpResponseException(response.code, response.phrase, body)
@@ -503,7 +525,7 @@ class SimpleHttpClient(object):
header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
- HTTP body at text.
+ HTTP body as bytes.
Raises:
HttpResponseException on a non-2xx HTTP response.
"""
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index c5fc746f..0c026480 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -15,6 +15,7 @@
import logging
import urllib
+from typing import List
from netaddr import AddrFormatError, IPAddress
from zope.interface import implementer
@@ -236,11 +237,10 @@ class MatrixHostnameEndpoint(object):
return run_in_background(self._do_connect, protocol_factory)
- @defer.inlineCallbacks
- def _do_connect(self, protocol_factory):
+ async def _do_connect(self, protocol_factory):
first_exception = None
- server_list = yield self._resolve_server()
+ server_list = await self._resolve_server()
for server in server_list:
host = server.host
@@ -251,7 +251,7 @@ class MatrixHostnameEndpoint(object):
endpoint = HostnameEndpoint(self._reactor, host, port)
if self._tls_options:
endpoint = wrapClientTLS(self._tls_options, endpoint)
- result = yield make_deferred_yieldable(
+ result = await make_deferred_yieldable(
endpoint.connect(protocol_factory)
)
@@ -271,13 +271,9 @@ class MatrixHostnameEndpoint(object):
# to try and if that doesn't work then we'll have an exception.
raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,))
- @defer.inlineCallbacks
- def _resolve_server(self):
+ async def _resolve_server(self) -> List[Server]:
"""Resolves the server name to a list of hosts and ports to attempt to
connect to.
-
- Returns:
- Deferred[list[Server]]
"""
if self._parsed_uri.scheme != b"matrix":
@@ -298,7 +294,7 @@ class MatrixHostnameEndpoint(object):
if port or _is_ip_literal(host):
return [Server(host, port or 8448)]
- server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
+ server_list = await self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
if server_list:
return server_list
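(The conversion pattern used in this hunk, reduced to a sketch; the `resolver` argument is a hypothetical stand-in, not the real SrvResolver:)

from twisted.internet import defer

@defer.inlineCallbacks
def resolve_old(resolver, name):
    servers = yield resolver.lookup(name)  # yields a Deferred
    return servers

async def resolve_new(resolver, name):
    servers = await resolver.lookup(name)  # Deferreds can be awaited directly
    return servers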
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index 021b233a..2ede90a9 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -17,10 +17,10 @@
import logging
import random
import time
+from typing import List
import attr
-from twisted.internet import defer
from twisted.internet.error import ConnectError
from twisted.names import client, dns
from twisted.names.error import DNSNameError, DomainError
@@ -113,16 +113,14 @@ class SrvResolver(object):
self._cache = cache
self._get_time = get_time
- @defer.inlineCallbacks
- def resolve_service(self, service_name):
+ async def resolve_service(self, service_name: bytes) -> List[Server]:
"""Look up a SRV record
Args:
service_name (bytes): record to look up
Returns:
- Deferred[list[Server]]:
- a list of the SRV records, or an empty list if none found
+ a list of the SRV records, or an empty list if none found
"""
now = int(self._get_time())
@@ -136,7 +134,7 @@ class SrvResolver(object):
return _sort_server_list(servers)
try:
- answers, _, _ = yield make_deferred_yieldable(
+ answers, _, _ = await make_deferred_yieldable(
self._dns_client.lookupService(service_name)
)
except DNSNameError:
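(For orientation, a rough sketch, assuming the default `twisted.names.client` resolver and working DNS, of the `lookupService` call awaited above; it fires with a 3-tuple of answers, authority records and additional records, of which only the answers are used:)

from twisted.names import client, dns

async def lookup_srv(service_name: bytes):
    answers, _authority, _additional = await client.lookupService(service_name)
    return [
        (answer.payload.target.name, answer.payload.port)
        for answer in answers
        if answer.type == dns.SRV
    ]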
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 2b35f860..d4f9ad6e 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -217,7 +217,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
return NOT_DONE_YET
@wrap_async_request_handler
- async def _async_render_wrapper(self, request):
+ async def _async_render_wrapper(self, request: SynapseRequest):
"""This is a wrapper that delegates to `_async_render` and handles
exceptions, return values, metrics, etc.
"""
@@ -237,7 +237,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
f = failure.Failure()
self._send_error_response(f, request)
- async def _async_render(self, request):
+ async def _async_render(self, request: Request):
"""Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
no appropriate method exists. Can be overridden in subclasses for
different routing.
@@ -278,7 +278,7 @@ class DirectServeJsonResource(_AsyncResource):
"""
def _send_response(
- self, request, code, response_object,
+ self, request: Request, code: int, response_object: Any,
):
"""Implements _AsyncResource._send_response
"""
@@ -442,21 +442,6 @@ class StaticResource(File):
return super().render_GET(request)
-def _options_handler(request):
- """Request handler for OPTIONS requests
-
- This is a request handler suitable for return from
- _get_handler_for_request. It returns a 200 and an empty body.
-
- Args:
- request (twisted.web.http.Request):
-
- Returns:
- Tuple[int, dict]: http code, response body.
- """
- return 200, {}
-
-
def _unrecognised_request_handler(request):
"""Request handler for unrecognised requests
@@ -490,11 +475,12 @@ class OptionsResource(resource.Resource):
"""Responds to OPTION requests for itself and all children."""
def render_OPTIONS(self, request):
- code, response_json_object = _options_handler(request)
+ request.setResponseCode(204)
+ request.setHeader(b"Content-Length", b"0")
- return respond_with_json(
- request, code, response_json_object, send_cors=True, canonical_json=False,
- )
+ set_cors_headers(request)
+
+ return b""
def getChildWithDefault(self, path, request):
if request.method == b"OPTIONS":
@@ -507,14 +493,29 @@ class RootOptionsRedirectResource(OptionsResource, RootRedirect):
def respond_with_json(
- request,
- code,
- json_object,
- send_cors=False,
- response_code_message=None,
- pretty_print=False,
- canonical_json=True,
+ request: Request,
+ code: int,
+ json_object: Any,
+ send_cors: bool = False,
+ pretty_print: bool = False,
+ canonical_json: bool = True,
):
+ """Sends encoded JSON in response to the given request.
+
+ Args:
+ request: The http request to respond to.
+ code: The HTTP response code.
+ json_object: The object to serialize to JSON.
+ send_cors: Whether to send Cross-Origin Resource Sharing headers
+ https://fetch.spec.whatwg.org/#http-cors-protocol
+ pretty_print: Whether to include indentation and line-breaks in the
+ resulting JSON bytes.
+ canonical_json: Whether to use the canonicaljson algorithm when encoding
+ the JSON bytes.
+
+ Returns:
+ twisted.web.server.NOT_DONE_YET if the request is still active.
+ """
# could alternatively use request.notifyFinish() and flip a flag when
# the Deferred fires, but since the flag is RIGHT THERE it seems like
# a waste.
@@ -522,7 +523,7 @@ def respond_with_json(
logger.warning(
"Not sending response to request %s, already disconnected.", request
)
- return
+ return None
if pretty_print:
json_bytes = encode_pretty_printed_json(json_object) + b"\n"
@@ -533,30 +534,26 @@ def respond_with_json(
else:
json_bytes = json.dumps(json_object).encode("utf-8")
- return respond_with_json_bytes(
- request,
- code,
- json_bytes,
- send_cors=send_cors,
- response_code_message=response_code_message,
- )
+ return respond_with_json_bytes(request, code, json_bytes, send_cors=send_cors)
def respond_with_json_bytes(
- request, code, json_bytes, send_cors=False, response_code_message=None
+ request: Request, code: int, json_bytes: bytes, send_cors: bool = False,
):
"""Sends encoded JSON in response to the given request.
Args:
- request (twisted.web.http.Request): The http request to respond to.
- code (int): The HTTP response code.
- json_bytes (bytes): The json bytes to use as the response body.
- send_cors (bool): Whether to send Cross-Origin Resource Sharing headers
+ request: The http request to respond to.
+ code: The HTTP response code.
+ json_bytes: The json bytes to use as the response body.
+ send_cors: Whether to send Cross-Origin Resource Sharing headers
https://fetch.spec.whatwg.org/#http-cors-protocol
+
Returns:
- twisted.web.server.NOT_DONE_YET"""
+ twisted.web.server.NOT_DONE_YET if the request is still active.
+ """
- request.setResponseCode(code, message=response_code_message)
+ request.setResponseCode(code)
request.setHeader(b"Content-Type", b"application/json")
request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),))
request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")
@@ -564,8 +561,8 @@ def respond_with_json_bytes(
if send_cors:
set_cors_headers(request)
- # todo: we can almost certainly avoid this copy and encode the json straight into
- # the bytesIO, but it would involve faffing around with string->bytes wrappers.
+ # note that this is zero-copy (the bytesio shares a copy-on-write buffer with
+ # the original `bytes`).
bytes_io = BytesIO(json_bytes)
producer = NoRangeStaticProducer(request, bytes_io)
@@ -573,12 +570,12 @@ def respond_with_json_bytes(
return NOT_DONE_YET
-def set_cors_headers(request):
- """Set the CORs headers so that javascript running in a web browsers can
+def set_cors_headers(request: Request):
+ """Set the CORS headers so that javascript running in a web browsers can
use this API
Args:
- request (twisted.web.http.Request): The http request to add CORs to.
+ request: The http request to add CORS to.
"""
request.setHeader(b"Access-Control-Allow-Origin", b"*")
request.setHeader(
@@ -643,7 +640,7 @@ def set_clickjacking_protection_headers(request: Request):
request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';")
-def finish_request(request):
+def finish_request(request: Request):
""" Finish writing the response to the request.
Twisted throws a RuntimeException if the connection closed before the
@@ -662,7 +659,7 @@ def finish_request(request):
logger.info("Connection disconnected before response was written: %r", e)
-def _request_user_agent_is_curl(request):
+def _request_user_agent_is_curl(request: Request) -> bool:
user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[])
for user_agent in user_agents:
if b"curl" in user_agent:
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 13fcb408..a34e5ead 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -214,16 +214,8 @@ def parse_json_value_from_request(request, allow_empty_body=False):
if not content_bytes and allow_empty_body:
return None
- # Decode to Unicode so that simplejson will return Unicode strings on
- # Python 2
try:
- content_unicode = content_bytes.decode("utf8")
- except UnicodeDecodeError:
- logger.warning("Unable to decode UTF-8")
- raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
-
- try:
- content = json.loads(content_unicode)
+ content = json.loads(content_bytes.decode("utf-8"))
except Exception as e:
logger.warning("Unable to parse JSON: %s", e)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
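(Undecodable bytes and malformed JSON now share one except clause, since both raise from the same line; condensed:)

import json

def parse_json_bytes(content_bytes: bytes):
    try:
        return json.loads(content_bytes.decode("utf-8"))
    except Exception as e:
        raise ValueError("Content not JSON: %s" % (e,))

for bad in (b"\xff", b"{not json}"):
    try:
        parse_json_bytes(bad)
    except ValueError as e:
        print(e)  # covers both the UnicodeDecodeError and the JSONDecodeError case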
diff --git a/synapse/http/site.py b/synapse/http/site.py
index cbc37eac..6f3b2258 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -215,9 +215,7 @@ class SynapseRequest(Request):
# It's useful to log it here so that we can get an idea of when
# the client disconnects.
with PreserveLoggingContext(self.logcontext):
- logger.warning(
- "Error processing request %r: %s %s", self, reason.type, reason.value
- )
+ logger.info("Connection from client lost before response was sent")
if not self._is_processing:
self._finished_processing()
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 8b9c4e38..cbeeb870 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -566,36 +566,33 @@ class LoggingContextFilter(logging.Filter):
return True
-class PreserveLoggingContext(object):
- """Captures the current logging context and restores it when the scope is
- exited. Used to restore the context after a function using
- @defer.inlineCallbacks is resumed by a callback from the reactor."""
+class PreserveLoggingContext:
+ """Context manager which replaces the logging context
- __slots__ = ["current_context", "new_context", "has_parent"]
+ The previous logging context is restored on exit."""
+
+ __slots__ = ["_old_context", "_new_context"]
def __init__(
self, new_context: LoggingContextOrSentinel = SENTINEL_CONTEXT
) -> None:
- self.new_context = new_context
+ self._new_context = new_context
def __enter__(self) -> None:
- """Captures the current logging context"""
- self.current_context = set_current_context(self.new_context)
-
- if self.current_context:
- self.has_parent = self.current_context.previous_context is not None
+ self._old_context = set_current_context(self._new_context)
def __exit__(self, type, value, traceback) -> None:
- """Restores the current logging context"""
- context = set_current_context(self.current_context)
+ context = set_current_context(self._old_context)
- if context != self.new_context:
+ if context != self._new_context:
if not context:
- logger.warning("Expected logging context %s was lost", self.new_context)
+ logger.warning(
+ "Expected logging context %s was lost", self._new_context
+ )
else:
logger.warning(
"Expected logging context %s but found %s",
- self.new_context,
+ self._new_context,
context,
)
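(Usage of the slimmed-down context manager, for orientation; the context names are illustrative:)

from synapse.logging.context import LoggingContext, PreserveLoggingContext

with LoggingContext("main"):
    with PreserveLoggingContext(LoggingContext("request-1")):
        pass  # log lines here carry the "request-1" context
    # the previous ("main") context is restored here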
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index c6c0e623..21dbd9f4 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -733,37 +733,43 @@ def trace(func=None, opname=None):
_opname = opname if opname else func.__name__
- @wraps(func)
- def _trace_inner(*args, **kwargs):
- if opentracing is None:
- return func(*args, **kwargs)
+ if inspect.iscoroutinefunction(func):
- scope = start_active_span(_opname)
- scope.__enter__()
+ @wraps(func)
+ async def _trace_inner(*args, **kwargs):
+ with start_active_span(_opname):
+ return await func(*args, **kwargs)
- try:
- result = func(*args, **kwargs)
- if isinstance(result, defer.Deferred):
+ else:
+ # The other case here handles both sync functions and those
+ # decorated with @defer.inlineCallbacks.
+ @wraps(func)
+ def _trace_inner(*args, **kwargs):
+ scope = start_active_span(_opname)
+ scope.__enter__()
- def call_back(result):
- scope.__exit__(None, None, None)
- return result
+ try:
+ result = func(*args, **kwargs)
+ if isinstance(result, defer.Deferred):
- def err_back(result):
- scope.span.set_tag(tags.ERROR, True)
- scope.__exit__(None, None, None)
- return result
+ def call_back(result):
+ scope.__exit__(None, None, None)
+ return result
+
+ def err_back(result):
+ scope.__exit__(None, None, None)
+ return result
- result.addCallbacks(call_back, err_back)
+ result.addCallbacks(call_back, err_back)
- else:
- scope.__exit__(None, None, None)
+ else:
+ scope.__exit__(None, None, None)
- return result
+ return result
- except Exception as e:
- scope.__exit__(type(e), None, e.__traceback__)
- raise
+ except Exception as e:
+ scope.__exit__(type(e), None, e.__traceback__)
+ raise
return _trace_inner
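(A condensed sketch of the two decoration paths above, with a dummy context manager standing in for the real `start_active_span`; the Deferred branch is elided for brevity:)

import inspect
from contextlib import contextmanager
from functools import wraps

@contextmanager
def start_active_span(opname):  # stand-in for the opentracing helper
    print("start", opname)
    try:
        yield
    finally:
        print("finish", opname)

def trace(func):
    if inspect.iscoroutinefunction(func):
        @wraps(func)
        async def _trace_inner(*args, **kwargs):
            # the span stays open across every await in the coroutine
            with start_active_span(func.__name__):
                return await func(*args, **kwargs)
    else:
        @wraps(func)
        def _trace_inner(*args, **kwargs):
            with start_active_span(func.__name__):
                return func(*args, **kwargs)
    return _trace_inner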
diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
index dc3ab00c..026854b4 100644
--- a/synapse/logging/scopecontextmanager.py
+++ b/synapse/logging/scopecontextmanager.py
@@ -116,6 +116,8 @@ class _LogContextScope(Scope):
if self._enter_logcontext:
self.logcontext.__enter__()
+ return self
+
def __exit__(self, type, value, traceback):
if type == twisted.internet.defer._DefGen_Return:
super(_LogContextScope, self).__exit__(None, None, None)
diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py
index 99049bb5..fea774e2 100644
--- a/synapse/logging/utils.py
+++ b/synapse/logging/utils.py
@@ -14,9 +14,7 @@
# limitations under the License.
-import inspect
import logging
-import time
from functools import wraps
from inspect import getcallargs
@@ -74,127 +72,3 @@ def log_function(f):
wrapped.__name__ = func_name
return wrapped
-
-
-def time_function(f):
- func_name = f.__name__
-
- @wraps(f)
- def wrapped(*args, **kwargs):
- global _TIME_FUNC_ID
- id = _TIME_FUNC_ID
- _TIME_FUNC_ID += 1
-
- start = time.clock()
-
- try:
- _log_debug_as_f(f, "[FUNC START] {%s-%d}", (func_name, id))
-
- r = f(*args, **kwargs)
- finally:
- end = time.clock()
- _log_debug_as_f(
- f, "[FUNC END] {%s-%d} %.3f sec", (func_name, id, end - start)
- )
-
- return r
-
- return wrapped
-
-
-def trace_function(f):
- func_name = f.__name__
- linenum = f.func_code.co_firstlineno
- pathname = f.func_code.co_filename
-
- @wraps(f)
- def wrapped(*args, **kwargs):
- name = f.__module__
- logger = logging.getLogger(name)
- level = logging.DEBUG
-
- frame = inspect.currentframe()
- if frame is None:
- raise Exception("Can't get current frame!")
-
- s = frame.f_back
-
- to_print = [
- "\t%s:%s %s. Args: args=%s, kwargs=%s"
- % (pathname, linenum, func_name, args, kwargs)
- ]
- while s:
- if True or s.f_globals["__name__"].startswith("synapse"):
- filename, lineno, function, _, _ = inspect.getframeinfo(s)
- args_string = inspect.formatargvalues(*inspect.getargvalues(s))
-
- to_print.append(
- "\t%s:%d %s. Args: %s" % (filename, lineno, function, args_string)
- )
-
- s = s.f_back
-
- msg = "\nTraceback for %s:\n" % (func_name,) + "\n".join(to_print)
-
- record = logging.LogRecord(
- name=name,
- level=level,
- pathname=pathname,
- lineno=lineno,
- msg=msg,
- args=(),
- exc_info=None,
- )
-
- logger.handle(record)
-
- return f(*args, **kwargs)
-
- wrapped.__name__ = func_name
- return wrapped
-
-
-def get_previous_frames():
-
- frame = inspect.currentframe()
- if frame is None:
- raise Exception("Can't get current frame!")
-
- s = frame.f_back.f_back
- to_return = []
- while s:
- if s.f_globals["__name__"].startswith("synapse"):
- filename, lineno, function, _, _ = inspect.getframeinfo(s)
- args_string = inspect.formatargvalues(*inspect.getargvalues(s))
-
- to_return.append(
- "{{ %s:%d %s - Args: %s }}" % (filename, lineno, function, args_string)
- )
-
- s = s.f_back
-
- return ", ".join(to_return)
-
-
-def get_previous_frame(ignore=[]):
- frame = inspect.currentframe()
- if frame is None:
- raise Exception("Can't get current frame!")
- s = frame.f_back.f_back
-
- while s:
- if s.f_globals["__name__"].startswith("synapse"):
- if not any(s.f_globals["__name__"].startswith(ig) for ig in ignore):
- filename, lineno, function, _, _ = inspect.getframeinfo(s)
- args_string = inspect.formatargvalues(*inspect.getargvalues(s))
-
- return "{{ %s:%d %s - Args: %s }}" % (
- filename,
- lineno,
- function,
- args_string,
- )
-
- s = s.f_back
-
- return None
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 43ffe6fa..472ddf9f 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -304,7 +304,9 @@ class RulesForRoom(object):
push_rules_delta_state_cache_metric.inc_hits()
else:
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(
+ context.get_current_state_ids()
+ )
push_rules_delta_state_cache_metric.inc_misses()
push_rules_state_size_counter.inc(len(current_state_ids))
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index dda560b2..af117fdd 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -27,6 +27,7 @@ import jinja2
from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
+from synapse.config.emailconfig import EmailSubjectConfig
from synapse.logging.context import make_deferred_yieldable
from synapse.push.presentable_names import (
calculate_room_name,
@@ -42,23 +43,6 @@ logger = logging.getLogger(__name__)
T = TypeVar("T")
-MESSAGE_FROM_PERSON_IN_ROOM = (
- "You have a message on %(app)s from %(person)s in the %(room)s room..."
-)
-MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..."
-MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..."
-MESSAGES_IN_ROOM = "You have messages on %(app)s in the %(room)s room..."
-MESSAGES_IN_ROOM_AND_OTHERS = (
- "You have messages on %(app)s in the %(room)s room and others..."
-)
-MESSAGES_FROM_PERSON_AND_OTHERS = (
- "You have messages on %(app)s from %(person)s and others..."
-)
-INVITE_FROM_PERSON_TO_ROOM = (
- "%(person)s has invited you to join the %(room)s room on %(app)s..."
-)
-INVITE_FROM_PERSON = "%(person)s has invited you to chat on %(app)s..."
-
CONTEXT_BEFORE = 1
CONTEXT_AFTER = 1
@@ -121,6 +105,7 @@ class Mailer(object):
self.state_handler = self.hs.get_state_handler()
self.storage = hs.get_storage()
self.app_name = app_name
+ self.email_subjects = hs.config.email_subjects # type: EmailSubjectConfig
logger.info("Created Mailer for app_name %s" % app_name)
@@ -147,7 +132,8 @@ class Mailer(object):
await self.send_email(
email_address,
- "[%s] Password Reset" % self.hs.config.server_name,
+ self.email_subjects.password_reset
+ % {"server_name": self.hs.config.server_name},
template_vars,
)
@@ -174,7 +160,8 @@ class Mailer(object):
await self.send_email(
email_address,
- "[%s] Register your Email Address" % self.hs.config.server_name,
+ self.email_subjects.email_validation
+ % {"server_name": self.hs.config.server_name},
template_vars,
)
@@ -202,7 +189,8 @@ class Mailer(object):
await self.send_email(
email_address,
- "[%s] Validate Your Email" % self.hs.config.server_name,
+ self.email_subjects.email_validation
+ % {"server_name": self.hs.config.server_name},
template_vars,
)
@@ -269,16 +257,13 @@ class Mailer(object):
user_id, app_id, email_address
),
"summary_text": summary_text,
- "app_name": self.app_name,
"rooms": rooms,
"reason": reason,
}
- await self.send_email(
- email_address, "[%s] %s" % (self.app_name, summary_text), template_vars
- )
+ await self.send_email(email_address, summary_text, template_vars)
- async def send_email(self, email_address, subject, template_vars):
+ async def send_email(self, email_address, subject, extra_template_vars):
"""Send an email with the given information and template text"""
try:
from_string = self.hs.config.email_notif_from % {"app": self.app_name}
@@ -291,6 +276,13 @@ class Mailer(object):
if raw_to == "":
raise RuntimeError("Invalid 'to' address")
+ template_vars = {
+ "app_name": self.app_name,
+ "server_name": self.hs.config.server.server_name,
+ }
+
+ template_vars.update(extra_template_vars)
+
html_text = self.template_html.render(**template_vars)
html_part = MIMEText(html_text, "html", "utf8")
@@ -476,12 +468,12 @@ class Mailer(object):
inviter_name = name_from_member_event(inviter_member_event)
if room_name is None:
- return INVITE_FROM_PERSON % {
+ return self.email_subjects.invite_from_person % {
"person": inviter_name,
"app": self.app_name,
}
else:
- return INVITE_FROM_PERSON_TO_ROOM % {
+ return self.email_subjects.invite_from_person_to_room % {
"person": inviter_name,
"room": room_name,
"app": self.app_name,
@@ -499,13 +491,13 @@ class Mailer(object):
sender_name = name_from_member_event(state_event)
if sender_name is not None and room_name is not None:
- return MESSAGE_FROM_PERSON_IN_ROOM % {
+ return self.email_subjects.message_from_person_in_room % {
"person": sender_name,
"room": room_name,
"app": self.app_name,
}
elif sender_name is not None:
- return MESSAGE_FROM_PERSON % {
+ return self.email_subjects.message_from_person % {
"person": sender_name,
"app": self.app_name,
}
@@ -513,7 +505,10 @@ class Mailer(object):
# There's more than one notification for this room, so just
# say there are several
if room_name is not None:
- return MESSAGES_IN_ROOM % {"room": room_name, "app": self.app_name}
+ return self.email_subjects.messages_in_room % {
+ "room": room_name,
+ "app": self.app_name,
+ }
else:
# If the room doesn't have a name, say who the messages
# are from explicitly to avoid, "messages in the Bob room"
@@ -531,7 +526,7 @@ class Mailer(object):
]
)
- return MESSAGES_FROM_PERSON % {
+ return self.email_subjects.messages_from_person % {
"person": descriptor_from_member_events(member_events.values()),
"app": self.app_name,
}
@@ -540,7 +535,7 @@ class Mailer(object):
# ...but we still refer to the 'reason' room which triggered the mail
if reason["room_name"] is not None:
- return MESSAGES_IN_ROOM_AND_OTHERS % {
+ return self.email_subjects.messages_in_room_and_others % {
"room": reason["room_name"],
"app": self.app_name,
}
@@ -560,7 +555,7 @@ class Mailer(object):
[room_state_ids[room_id][("m.room.member", s)] for s in sender_ids]
)
- return MESSAGES_FROM_PERSON_AND_OTHERS % {
+ return self.email_subjects.messages_from_person_and_others % {
"person": descriptor_from_member_events(member_events.values()),
"app": self.app_name,
}
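(The subjects move from hard-coded module constants to configurable %-style templates interpolated with a dict, along the lines of the following; the values are illustrative:)

message_from_person_in_room = (
    "You have a message on %(app)s from %(person)s in the %(room)s room..."
)

subject = message_from_person_in_room % {
    "app": "Matrix",
    "person": "Alice",
    "room": "lounge",
}
# -> "You have a message on Matrix from Alice in the lounge room..."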
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index f6a54586..2456f12f 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -15,13 +15,12 @@
# limitations under the License.
import logging
-from collections import defaultdict
-from threading import Lock
-from typing import Dict, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Union
+
+from prometheus_client import Gauge
from twisted.internet import defer
-from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import PusherConfigException
from synapse.push.emailpusher import EmailPusher
@@ -29,9 +28,18 @@ from synapse.push.httppusher import HttpPusher
from synapse.push.pusher import PusherFactory
from synapse.util.async_helpers import concurrently_execute
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
logger = logging.getLogger(__name__)
+synapse_pushers = Gauge(
+ "synapse_pushers", "Number of active synapse pushers", ["kind", "app_id"]
+)
+
+
class PusherPool:
"""
The pusher pool. This is responsible for dispatching notifications of new events to
@@ -47,36 +55,20 @@ class PusherPool:
Pusher.on_new_receipts are not expected to return deferreds.
"""
- def __init__(self, _hs):
- self.hs = _hs
- self.pusher_factory = PusherFactory(_hs)
- self._should_start_pushers = _hs.config.start_pushers
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.pusher_factory = PusherFactory(hs)
+ self._should_start_pushers = hs.config.start_pushers
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
+ # We shard the handling of push notifications by user ID.
+ self._pusher_shard_config = hs.config.push.pusher_shard_config
+ self._instance_name = hs.get_instance_name()
+
# map from user id to app_id:pushkey to pusher
self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]]
- # a lock for the pushers dict, since `count_pushers` is called from a different
- # thread and we otherwise get concurrent modification errors
- self._pushers_lock = Lock()
-
- def count_pushers():
- results = defaultdict(int) # type: Dict[Tuple[str, str], int]
- with self._pushers_lock:
- for pushers in self.pushers.values():
- for pusher in pushers.values():
- k = (type(pusher).__name__, pusher.app_id)
- results[k] += 1
- return results
-
- LaterGauge(
- name="synapse_pushers",
- desc="the number of active pushers",
- labels=["kind", "app_id"],
- caller=count_pushers,
- )
-
def start(self):
"""Starts the pushers off in a background process.
"""
@@ -104,6 +96,7 @@ class PusherPool:
Returns:
Deferred[EmailPusher|HttpPusher]
"""
+
time_now_msec = self.clock.time_msec()
# we try to create the pusher just to validate the config: it
@@ -176,6 +169,9 @@ class PusherPool:
access_tokens (Iterable[int]): access token *ids* to remove pushers
for
"""
+ if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
+ return
+
tokens = set(access_tokens)
for p in (yield self.store.get_pushers_by_user_id(user_id)):
if p["access_token"] in tokens:
@@ -237,6 +233,9 @@ class PusherPool:
if not self._should_start_pushers:
return
+ if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
+ return
+
resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
pusher_dict = None
@@ -275,6 +274,11 @@ class PusherPool:
Returns:
Deferred[EmailPusher|HttpPusher]
"""
+ if not self._pusher_shard_config.should_handle(
+ self._instance_name, pusherdict["user_name"]
+ ):
+ return
+
try:
p = self.pusher_factory.create_pusher(pusherdict)
except PusherConfigException as e:
@@ -298,11 +302,12 @@ class PusherPool:
appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])
- with self._pushers_lock:
- byuser = self.pushers.setdefault(pusherdict["user_name"], {})
- if appid_pushkey in byuser:
- byuser[appid_pushkey].on_stop()
- byuser[appid_pushkey] = p
+ byuser = self.pushers.setdefault(pusherdict["user_name"], {})
+ if appid_pushkey in byuser:
+ byuser[appid_pushkey].on_stop()
+ byuser[appid_pushkey] = p
+
+ synapse_pushers.labels(type(p).__name__, p.app_id).inc()
# Check if there *may* be push to process. We do this as this check is a
# lot cheaper to do than actually fetching the exact rows we need to
@@ -330,9 +335,10 @@ class PusherPool:
if appid_pushkey in byuser:
logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
- byuser[appid_pushkey].on_stop()
- with self._pushers_lock:
- del byuser[appid_pushkey]
+ pusher = byuser.pop(appid_pushkey)
+ pusher.on_stop()
+
+ synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
yield self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id, pushkey, user_id
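(The metrics change in miniature: instead of recomputing counts under a lock at scrape time via LaterGauge plus callback, a labelled prometheus_client Gauge is bumped as pushers start and stop. The metric name here is illustrative, to avoid clashing with the real one:)

from prometheus_client import Gauge

pushers_gauge = Gauge(
    "example_pushers", "Number of active pushers", ["kind", "app_id"]
)

pushers_gauge.labels("HttpPusher", "com.example.app").inc()  # on start
pushers_gauge.labels("HttpPusher", "com.example.app").dec()  # on stop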
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index 5ef1c6c1..a84a064c 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -39,10 +39,10 @@ class ReplicationRestResource(JsonResource):
federation.register_servlets(hs, self)
presence.register_servlets(hs, self)
membership.register_servlets(hs, self)
+ streams.register_servlets(hs, self)
# The following can't currently be instantiated on workers.
if hs.config.worker.worker_app is None:
login.register_servlets(hs, self)
register.register_servlets(hs, self)
devices.register_servlets(hs, self)
- streams.register_servlets(hs, self)
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index bd394f6b..a8a16dbc 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -26,7 +26,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs):
super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
self._device_inbox_id_gen = SlavedIdTracker(
- db_conn, "device_max_stream_id", "stream_id"
+ db_conn, "device_inbox", "stream_id"
)
self._device_inbox_stream_cache = StreamChangeCache(
"DeviceInboxStreamChangeCache",
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 4985e40b..fcf8ebf1 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -24,6 +24,7 @@ from twisted.internet.protocol import ReconnectingClientFactory
from synapse.api.constants import EventTypes
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
+from synapse.replication.tcp.streams import TypingStream
from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamEventRow,
@@ -104,6 +105,7 @@ class ReplicationDataHandler:
self._clock = hs.get_clock()
self._streams = hs.get_replication_streams()
self._instance_name = hs.get_instance_name()
+ self._typing_handler = hs.get_typing_handler()
# Map from stream to list of deferreds waiting for the stream to
# arrive at a particular position. The lists are sorted by stream position.
@@ -127,6 +129,12 @@ class ReplicationDataHandler:
"""
self.store.process_replication_rows(stream_name, instance_name, token, rows)
+ if stream_name == TypingStream.NAME:
+ self._typing_handler.process_replication_rows(token, rows)
+ self.notifier.on_new_event(
+ "typing_key", token, rooms=[row.room_id for row in rows]
+ )
+
if stream_name == EventsStream.NAME:
# We shouldn't get multiple rows per token for events stream, so
# we don't need to optimise this for multiple rows.
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index ccc7f1f0..f33801f8 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -293,20 +293,22 @@ class FederationAckCommand(Command):
Format::
- FEDERATION_ACK <token>
+ FEDERATION_ACK <instance_name> <token>
"""
NAME = "FEDERATION_ACK"
- def __init__(self, token):
+ def __init__(self, instance_name, token):
+ self.instance_name = instance_name
self.token = token
@classmethod
def from_line(cls, line):
- return cls(int(line))
+ instance_name, token = line.split(" ")
+ return cls(instance_name, int(token))
def to_line(self):
- return str(self.token)
+ return "%s %s" % (self.instance_name, self.token)
class RemovePusherCommand(Command):
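(A round-trip of the new wire format, using a stripped-down copy of the class above rather than an import from synapse:)

class FederationAckCommand:
    NAME = "FEDERATION_ACK"

    def __init__(self, instance_name, token):
        self.instance_name = instance_name
        self.token = token

    @classmethod
    def from_line(cls, line):
        instance_name, token = line.split(" ")
        return cls(instance_name, int(token))

    def to_line(self):
        return "%s %s" % (self.instance_name, self.token)

cmd = FederationAckCommand("federation_sender1", 1234)
assert cmd.to_line() == "federation_sender1 1234"
assert FederationAckCommand.from_line(cmd.to_line()).token == 1234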
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 55b3b790..1c303f3a 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -14,13 +14,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
+from typing import (
+ Any,
+ Awaitable,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ TypeVar,
+ Union,
+)
from prometheus_client import Counter
+from typing_extensions import Deque
from twisted.internet.protocol import ReconnectingClientFactory
from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.client import DirectTcpReplicationClientFactory
from synapse.replication.tcp.commands import (
ClearUserSyncsCommand,
@@ -42,8 +56,8 @@ from synapse.replication.tcp.streams import (
EventsStream,
FederationStream,
Stream,
+ TypingStream,
)
-from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
@@ -55,12 +69,16 @@ inbound_rdata_count = Counter(
user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
-invalidate_cache_counter = Counter(
- "synapse_replication_tcp_resource_invalidate_cache", ""
-)
+
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
+# the type of the entries in _command_queues_by_stream
+_StreamCommandQueue = Deque[
+ Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
+]
+
+
class ReplicationCommandHandler:
"""Handles incoming commands from replication as well as sending commands
back out to connections.
@@ -96,6 +114,14 @@ class ReplicationCommandHandler:
continue
+ if isinstance(stream, TypingStream):
+ # Only add TypingStream as a source on the instance in charge of
+ # typing.
+ if hs.config.worker.writers.typing == hs.get_instance_name():
+ self._streams_to_replicate.append(stream)
+
+ continue
+
# Only add any other streams if we're on master.
if hs.config.worker_app is not None:
continue
@@ -107,10 +133,6 @@ class ReplicationCommandHandler:
self._streams_to_replicate.append(stream)
- self._position_linearizer = Linearizer(
- "replication_position", clock=self._clock
- )
-
# Map of stream name to batched updates. See RdataCommand for info on
# how batching works.
self._pending_batches = {} # type: Dict[str, List[Any]]
@@ -122,10 +144,6 @@ class ReplicationCommandHandler:
# outgoing replication commands to.)
self._connections = [] # type: List[AbstractConnection]
- # For each connection, the incoming stream names that are coming from
- # that connection.
- self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
-
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
"",
@@ -133,6 +151,32 @@ class ReplicationCommandHandler:
lambda: len(self._connections),
)
+ # When POSITION or RDATA commands arrive, we stick them in a queue and process
+ # them in order in a separate background process.
+
+ # the streams which are currently being processed by _unsafe_process_queue
+ self._processing_streams = set() # type: Set[str]
+
+ # for each stream, a queue of commands that are awaiting processing, and the
+ # connection that they arrived on.
+ self._command_queues_by_stream = {
+ stream_name: _StreamCommandQueue() for stream_name in self._streams
+ }
+
+ # For each connection, the incoming stream names that have received a POSITION
+ # from that connection.
+ self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
+
+ LaterGauge(
+ "synapse_replication_tcp_command_queue",
+ "Number of inbound RDATA/POSITION commands queued for processing",
+ ["stream_name"],
+ lambda: {
+ (stream_name,): len(queue)
+ for stream_name, queue in self._command_queues_by_stream.items()
+ },
+ )
+
self._is_master = hs.config.worker_app is None
self._federation_sender = None
@@ -143,6 +187,65 @@ class ReplicationCommandHandler:
if self._is_master:
self._server_notices_sender = hs.get_server_notices_sender()
+ def _add_command_to_stream_queue(
+ self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
+ ) -> None:
+ """Queue the given received command for processing
+
+ Adds the given command to the per-stream queue, and processes the queue if
+ necessary
+ """
+ stream_name = cmd.stream_name
+ queue = self._command_queues_by_stream.get(stream_name)
+ if queue is None:
+ logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name)
+ return
+
+ queue.append((cmd, conn))
+
+ # if we're already processing this stream, there's nothing more to do:
+ # the new entry on the queue will get picked up in due course
+ if stream_name in self._processing_streams:
+ return
+
+ # fire off a background process to start processing the queue.
+ run_as_background_process(
+ "process-replication-data", self._unsafe_process_queue, stream_name
+ )
+
+ async def _unsafe_process_queue(self, stream_name: str):
+ """Processes the command queue for the given stream, until it is empty
+
+ Does not check if there is already a thread processing the queue, hence "unsafe"
+ """
+ assert stream_name not in self._processing_streams
+
+ self._processing_streams.add(stream_name)
+ try:
+ queue = self._command_queues_by_stream.get(stream_name)
+ while queue:
+ cmd, conn = queue.popleft()
+ try:
+ await self._process_command(cmd, conn, stream_name)
+ except Exception:
+ logger.exception("Failed to handle command %s", cmd)
+ finally:
+ self._processing_streams.discard(stream_name)
+
+ async def _process_command(
+ self,
+ cmd: Union[PositionCommand, RdataCommand],
+ conn: AbstractConnection,
+ stream_name: str,
+ ) -> None:
+ if isinstance(cmd, PositionCommand):
+ await self._process_position(stream_name, conn, cmd)
+ elif isinstance(cmd, RdataCommand):
+ await self._process_rdata(stream_name, conn, cmd)
+ else:
+ # This shouldn't be possible
+ raise Exception("Unrecognised command %s in stream queue", cmd.NAME)
+
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
using TCP.
@@ -199,7 +302,7 @@ class ReplicationCommandHandler:
"""
return self._streams_to_replicate
- async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
+ def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
self.send_positions_to_connection(conn)
def send_positions_to_connection(self, conn: AbstractConnection):
@@ -218,57 +321,73 @@ class ReplicationCommandHandler:
)
)
- async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
+ def on_USER_SYNC(
+ self, conn: AbstractConnection, cmd: UserSyncCommand
+ ) -> Optional[Awaitable[None]]:
user_sync_counter.inc()
if self._is_master:
- await self._presence_handler.update_external_syncs_row(
+ return self._presence_handler.update_external_syncs_row(
cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
+ else:
+ return None
- async def on_CLEAR_USER_SYNC(
+ def on_CLEAR_USER_SYNC(
self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
- ):
+ ) -> Optional[Awaitable[None]]:
if self._is_master:
- await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
+ return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
+ else:
+ return None
- async def on_FEDERATION_ACK(
- self, conn: AbstractConnection, cmd: FederationAckCommand
- ):
+ def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
federation_ack_counter.inc()
if self._federation_sender:
- self._federation_sender.federation_ack(cmd.token)
+ self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
- async def on_REMOVE_PUSHER(
+ def on_REMOVE_PUSHER(
self, conn: AbstractConnection, cmd: RemovePusherCommand
- ):
+ ) -> Optional[Awaitable[None]]:
remove_pusher_counter.inc()
if self._is_master:
- await self._store.delete_pusher_by_app_id_pushkey_user_id(
- app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
- )
+ return self._handle_remove_pusher(cmd)
+ else:
+ return None
- self._notifier.on_new_replication_data()
+ async def _handle_remove_pusher(self, cmd: RemovePusherCommand):
+ await self._store.delete_pusher_by_app_id_pushkey_user_id(
+ app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
+ )
+
+ self._notifier.on_new_replication_data()
- async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
+ def on_USER_IP(
+ self, conn: AbstractConnection, cmd: UserIpCommand
+ ) -> Optional[Awaitable[None]]:
user_ip_cache_counter.inc()
if self._is_master:
- await self._store.insert_client_ip(
- cmd.user_id,
- cmd.access_token,
- cmd.ip,
- cmd.user_agent,
- cmd.device_id,
- cmd.last_seen,
- )
+ return self._handle_user_ip(cmd)
+ else:
+ return None
+
+ async def _handle_user_ip(self, cmd: UserIpCommand):
+ await self._store.insert_client_ip(
+ cmd.user_id,
+ cmd.access_token,
+ cmd.ip,
+ cmd.user_agent,
+ cmd.device_id,
+ cmd.last_seen,
+ )
- if self._server_notices_sender:
- await self._server_notices_sender.on_user_ip(cmd.user_id)
+ assert self._server_notices_sender is not None
+ await self._server_notices_sender.on_user_ip(cmd.user_id)
- async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
+ def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
if cmd.instance_name == self._instance_name:
# Ignore RDATA that are just our own echoes
return
@@ -276,63 +395,71 @@ class ReplicationCommandHandler:
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
- try:
- row = STREAMS_MAP[stream_name].parse_row(cmd.row)
- except Exception:
- logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
- raise
-
- # We linearize here for two reasons:
+ # We put the received command into a queue here for two reasons:
# 1. so we don't try and concurrently handle multiple rows for the
# same stream, and
# 2. so we don't race with getting a POSITION command and fetching
# missing RDATA.
- with await self._position_linearizer.queue(cmd.stream_name):
- # make sure that we've processed a POSITION for this stream *on this
- # connection*. (A POSITION on another connection is no good, as there
- # is no guarantee that we have seen all the intermediate updates.)
- sbc = self._streams_by_connection.get(conn)
- if not sbc or stream_name not in sbc:
- # Let's drop the row for now, on the assumption we'll receive a
- # `POSITION` soon and we'll catch up correctly then.
- logger.debug(
- "Discarding RDATA for unconnected stream %s -> %s",
- stream_name,
- cmd.token,
- )
- return
-
- if cmd.token is None:
- # I.e. this is part of a batch of updates for this stream (in
- # which case batch until we get an update for the stream with a non
- # None token).
- self._pending_batches.setdefault(stream_name, []).append(row)
- else:
- # Check if this is the last of a batch of updates
- rows = self._pending_batches.pop(stream_name, [])
- rows.append(row)
-
- stream = self._streams.get(stream_name)
- if not stream:
- logger.error("Got RDATA for unknown stream: %s", stream_name)
- return
-
- # Find where we previously streamed up to.
- current_token = stream.current_token(cmd.instance_name)
-
- # Discard this data if this token is earlier than the current
- # position. Note that streams can be reset (in which case you
- # expect an earlier token), but that must be preceded by a
- # POSITION command.
- if cmd.token <= current_token:
- logger.debug(
- "Discarding RDATA from stream %s at position %s before previous position %s",
- stream_name,
- cmd.token,
- current_token,
- )
- else:
- await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
+
+ self._add_command_to_stream_queue(conn, cmd)
+
+ async def _process_rdata(
+ self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
+ ) -> None:
+ """Process an RDATA command
+
+ Called after the command has been popped off the queue of inbound commands
+ """
+ try:
+ row = STREAMS_MAP[stream_name].parse_row(cmd.row)
+ except Exception as e:
+ raise Exception(
+ "Failed to parse RDATA: %r %r" % (stream_name, cmd.row)
+ ) from e
+
+ # make sure that we've processed a POSITION for this stream *on this
+ # connection*. (A POSITION on another connection is no good, as there
+ # is no guarantee that we have seen all the intermediate updates.)
+ sbc = self._streams_by_connection.get(conn)
+ if not sbc or stream_name not in sbc:
+ # Let's drop the row for now, on the assumption we'll receive a
+ # `POSITION` soon and we'll catch up correctly then.
+ logger.debug(
+ "Discarding RDATA for unconnected stream %s -> %s",
+ stream_name,
+ cmd.token,
+ )
+ return
+
+ if cmd.token is None:
+ # I.e. this is part of a batch of updates for this stream (in
+ # which case batch until we get an update for the stream with a non
+ # None token).
+ self._pending_batches.setdefault(stream_name, []).append(row)
+ return
+
+ # Check if this is the last of a batch of updates
+ rows = self._pending_batches.pop(stream_name, [])
+ rows.append(row)
+
+ stream = self._streams[stream_name]
+
+ # Find where we previously streamed up to.
+ current_token = stream.current_token(cmd.instance_name)
+
+ # Discard this data if this token is earlier than the current
+ # position. Note that streams can be reset (in which case you
+ # expect an earlier token), but that must be preceded by a
+ # POSITION command.
+ if cmd.token <= current_token:
+ logger.debug(
+ "Discarding RDATA from stream %s at position %s before previous position %s",
+ stream_name,
+ cmd.token,
+ current_token,
+ )
+ else:
+ await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
@@ -351,78 +478,74 @@ class ReplicationCommandHandler:
stream_name, instance_name, token, rows
)
- async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
+ def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
if cmd.instance_name == self._instance_name:
# Ignore POSITION that are just our own echoes
return
logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
- stream_name = cmd.stream_name
- stream = self._streams.get(stream_name)
- if not stream:
- logger.error("Got POSITION for unknown stream: %s", stream_name)
- return
+ self._add_command_to_stream_queue(conn, cmd)
- # We protect catching up with a linearizer in case the replication
- # connection reconnects under us.
- with await self._position_linearizer.queue(stream_name):
- # We're about to go and catch up with the stream, so remove from set
- # of connected streams.
- for streams in self._streams_by_connection.values():
- streams.discard(stream_name)
-
- # We clear the pending batches for the stream as the fetching of the
- # missing updates below will fetch all rows in the batch.
- self._pending_batches.pop(stream_name, [])
-
- # Find where we previously streamed up to.
- current_token = stream.current_token(cmd.instance_name)
-
- # If the position token matches our current token then we're up to
- # date and there's nothing to do. Otherwise, fetch all updates
- # between then and now.
- missing_updates = cmd.token != current_token
- while missing_updates:
- logger.info(
- "Fetching replication rows for '%s' between %i and %i",
- stream_name,
- current_token,
- cmd.token,
- )
- (
- updates,
- current_token,
- missing_updates,
- ) = await stream.get_updates_since(
- cmd.instance_name, current_token, cmd.token
- )
+ async def _process_position(
+ self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
+ ) -> None:
+ """Process a POSITION command
- # TODO: add some tests for this
+ Called after the command has been popped off the queue of inbound commands
+ """
+ stream = self._streams[stream_name]
- # Some streams return multiple rows with the same stream IDs,
- # which need to be processed in batches.
+ # We're about to go and catch up with the stream, so remove from set
+ # of connected streams.
+ for streams in self._streams_by_connection.values():
+ streams.discard(stream_name)
- for token, rows in _batch_updates(updates):
- await self.on_rdata(
- stream_name,
- cmd.instance_name,
- token,
- [stream.parse_row(row) for row in rows],
- )
+ # We clear the pending batches for the stream as the fetching of the
+ # missing updates below will fetch all rows in the batch.
+ self._pending_batches.pop(stream_name, [])
- logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+ # Find where we previously streamed up to.
+ current_token = stream.current_token(cmd.instance_name)
- # We've now caught up to position sent to us, notify handler.
- await self._replication_data_handler.on_position(
- cmd.stream_name, cmd.instance_name, cmd.token
+ # If the position token matches our current token then we're up to
+ # date and there's nothing to do. Otherwise, fetch all updates
+ # between then and now.
+ missing_updates = cmd.token != current_token
+ while missing_updates:
+ logger.info(
+ "Fetching replication rows for '%s' between %i and %i",
+ stream_name,
+ current_token,
+ cmd.token,
+ )
+ (updates, current_token, missing_updates) = await stream.get_updates_since(
+ cmd.instance_name, current_token, cmd.token
)
- self._streams_by_connection.setdefault(conn, set()).add(stream_name)
+ # TODO: add some tests for this
- async def on_REMOTE_SERVER_UP(
- self, conn: AbstractConnection, cmd: RemoteServerUpCommand
- ):
+ # Some streams return multiple rows with the same stream IDs,
+ # which need to be processed in batches.
+
+ for token, rows in _batch_updates(updates):
+ await self.on_rdata(
+ stream_name,
+ cmd.instance_name,
+ token,
+ [stream.parse_row(row) for row in rows],
+ )
+
+ logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+
+ # We've now caught up to position sent to us, notify handler.
+ await self._replication_data_handler.on_position(
+ cmd.stream_name, cmd.instance_name, cmd.token
+ )
+
+ self._streams_by_connection.setdefault(conn, set()).add(stream_name)
+
+ def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
""""Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data)
@@ -527,7 +650,7 @@ class ReplicationCommandHandler:
"""Ack data for the federation stream. This allows the master to drop
data stored purely in memory.
"""
- self.send_command(FederationAckCommand(token))
+ self.send_command(FederationAckCommand(self._instance_name, token))
def send_user_sync(
self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
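(The queue-per-stream pattern that replaces the Linearizer, reduced to an asyncio sketch; synapse itself runs on Twisted and uses run_as_background_process, and all names here are illustrative:)

import asyncio
from collections import deque

queues = {"events": deque(), "typing": deque()}
processing = set()

async def process_queue(stream_name):
    queue = queues[stream_name]
    while queue:
        cmd = queue.popleft()
        print("handling", stream_name, cmd)
        await asyncio.sleep(0)  # the real handler awaits storage/HTTP here
    processing.discard(stream_name)

def add_command(stream_name, cmd):
    queues[stream_name].append(cmd)
    if stream_name in processing:
        return  # the running task will pick up the new entry in due course
    processing.add(stream_name)
    asyncio.ensure_future(process_queue(stream_name))

async def main():
    add_command("events", "RDATA ev1")
    add_command("events", "RDATA ev2")   # queued behind ev1 on the same task
    add_command("typing", "POSITION 7")  # separate stream, separate task
    await asyncio.sleep(0.1)

asyncio.run(main())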
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index ca47f5cc..03509238 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -50,6 +50,7 @@ import abc
import fcntl
import logging
import struct
+from inspect import isawaitable
from typing import TYPE_CHECKING, List
from prometheus_client import Counter
@@ -57,8 +58,12 @@ from prometheus_client import Counter
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
+from synapse.logging.context import PreserveLoggingContext
from synapse.metrics import LaterGauge
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+ BackgroundProcessLoggingContext,
+ run_as_background_process,
+)
from synapse.replication.tcp.commands import (
VALID_CLIENT_COMMANDS,
VALID_SERVER_COMMANDS,
@@ -124,6 +129,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed
command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`.
+ `ReplicationCommandHandler.on_<COMMAND_NAME>` can optionally return a coroutine;
+ if so, that will get run as a background process.
It also sends `PING` periodically, and correctly times out remote connections
(if they send a `PING` command)
@@ -160,6 +167,12 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# The LoopingCall for sending pings.
self._send_ping_loop = None
+ # a logcontext which we use for processing incoming commands. We declare it as a
+ # background process so that the CPU stats get reported to prometheus.
+ ctx_name = "replication-conn-%s" % self.conn_id
+ self._logging_context = BackgroundProcessLoggingContext(ctx_name)
+ self._logging_context.request = ctx_name
+
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
@@ -210,6 +223,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def lineReceived(self, line: bytes):
"""Called when we've received a line
"""
+ with PreserveLoggingContext(self._logging_context):
+ self._parse_and_dispatch_line(line)
+
+ def _parse_and_dispatch_line(self, line: bytes):
if line.strip() == "":
# Ignore blank lines
return
@@ -232,18 +249,17 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc()
- # Now lets try and call on_<CMD_NAME> function
- run_as_background_process(
- "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
- )
+ self.handle_command(cmd)
- async def handle_command(self, cmd: Command):
+ def handle_command(self, cmd: Command) -> None:
"""Handle a command we have received over the replication stream.
First calls `self.on_<COMMAND>` if it exists, then calls
- `self.command_handler.on_<COMMAND>` if it exists. This allows for
- protocol level handling of commands (e.g. PINGs), before delegating to
- the handler.
+ `self.command_handler.on_<COMMAND>` if it exists (which can optionally
+ return an Awaitable).
+
+ This allows for protocol level handling of commands (e.g. PINGs), before
+ delegating to the handler.
Args:
cmd: received command
@@ -254,13 +270,22 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# specific handling.
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
if cmd_func:
- await cmd_func(cmd)
+ cmd_func(cmd)
handled = True
# Then call out to the handler.
cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
- await cmd_func(self, cmd)
+ res = cmd_func(self, cmd)
+
+ # the handler might be a coroutine: fire it off as a background process
+ # if so.
+
+ if isawaitable(res):
+ run_as_background_process(
+ "replication-" + cmd.get_logcontext_id(), lambda: res
+ )
+
handled = True
if not handled:
@@ -336,10 +361,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
for cmd in pending:
self.send_command(cmd)
- async def on_PING(self, line):
+ def on_PING(self, line):
self.received_ping = True
- async def on_ERROR(self, cmd):
+ def on_ERROR(self, cmd):
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
def pauseProducing(self):
@@ -397,6 +422,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
if self.transport:
self.transport.unregisterProducer()
+ # mark the logging context as finished
+ self._logging_context.__exit__(None, None, None)
+
def __str__(self):
addr = None
if self.transport:
@@ -431,7 +459,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_command(ServerCommand(self.server_name))
super().connectionMade()
- async def on_NAME(self, cmd):
+ def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data
@@ -460,7 +488,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Once we've connected subscribe to the necessary streams
self.replicate()
- async def on_SERVER(self, cmd):
+ def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
self.send_error("Wrong remote")
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 0a7e7f67..f225e533 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -14,12 +14,16 @@
# limitations under the License.
import logging
+from inspect import isawaitable
from typing import TYPE_CHECKING
import txredisapi
-from synapse.logging.context import make_deferred_yieldable
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.metrics.background_process_metrics import (
+ BackgroundProcessLoggingContext,
+ run_as_background_process,
+)
from synapse.replication.tcp.commands import (
Command,
ReplicateCommand,
@@ -66,6 +70,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
stream_name = None # type: str
outbound_redis_connection = None # type: txredisapi.RedisProtocol
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # a logcontext which we use for processing incoming commands. We declare it as a
+ # background process so that the CPU stats get reported to prometheus.
+ self._logging_context = BackgroundProcessLoggingContext(
+ "replication_command_handler"
+ )
+
def connectionMade(self):
logger.info("Connected to redis")
super().connectionMade()
@@ -92,7 +105,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
def messageReceived(self, pattern: str, channel: str, message: str):
"""Received a message from redis.
"""
+ with PreserveLoggingContext(self._logging_context):
+ self._parse_and_dispatch_message(message)
+ def _parse_and_dispatch_message(self, message: str):
if message.strip() == "":
# Ignore blank lines
return
@@ -109,42 +125,41 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
# remote instances.
tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc()
- # Now lets try and call on_<CMD_NAME> function
- run_as_background_process(
- "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
- )
+ self.handle_command(cmd)
- async def handle_command(self, cmd: Command):
+ def handle_command(self, cmd: Command) -> None:
"""Handle a command we have received over the replication stream.
- By default delegates to on_<COMMAND>, which should return an awaitable.
+ Delegates to `self.handler.on_<COMMAND>` (which can optionally return an
+ Awaitable).
Args:
cmd: received command
"""
- handled = False
-
- # First call any command handlers on this instance. These are for redis
- # specific handling.
- cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
- if cmd_func:
- await cmd_func(cmd)
- handled = True
- # Then call out to the handler.
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
- if cmd_func:
- await cmd_func(self, cmd)
- handled = True
-
- if not handled:
+ if not cmd_func:
logger.warning("Unhandled command: %r", cmd)
+ return
+
+ res = cmd_func(self, cmd)
+
+ # the handler might return a coroutine: fire it off as a background
+ # process if so.
+
+ if isawaitable(res):
+ run_as_background_process(
+ "replication-" + cmd.get_logcontext_id(), lambda: res
+ )
def connectionLost(self, reason):
logger.info("Lost connection to redis")
super().connectionLost(reason)
self.handler.lost_connection(self)
+ # mark the logging context as finished
+ self._logging_context.__exit__(None, None, None)
+
def send_command(self, cmd: Command):
"""Send a command if connection has been established.
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 9076bbe9..7a42de3f 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -294,11 +294,12 @@ class TypingStream(Stream):
def __init__(self, hs):
typing_handler = hs.get_typing_handler()
- if hs.config.worker_app is None:
- # on the master, query the typing handler
+ writer_instance = hs.config.worker.writers.typing
+ if writer_instance == hs.get_instance_name():
+ # On the writer, query the typing handler
update_function = typing_handler.get_all_typing_updates
else:
- # Query master process
+ # Query the typing writer process
update_function = make_http_update_function(hs, self.NAME)
super().__init__(
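The same writer-instance check appears again below, in RoomTypingRestServlet and in HomeServer.build_typing_handler: exactly one instance (hs.config.worker.writers.typing) writes typing updates, and every other instance follows it over replication. A reduced sketch of that selection logic (the two handler classes are empty stand-ins for the real ones):

    class FollowerTypingHandler:
        """Reads typing updates that some other instance wrote."""

    class TypingWriterHandler(FollowerTypingHandler):
        """The single instance allowed to write typing updates."""

    def build_typing_handler(writer_instance: str, this_instance: str):
        # Mirrors the instance-name comparison in the diff.
        if writer_instance == this_instance:
            return TypingWriterHandler()
        return FollowerTypingHandler()

    assert type(build_typing_handler("master", "master")) is TypingWriterHandler
    assert type(build_typing_handler("master", "worker1")) is FollowerTypingHandler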
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 1c2a4cce..16c63ff4 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import heapq
-from collections import Iterable
+from collections.abc import Iterable
from typing import List, Tuple, Type
import attr
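This import fix matters beyond style: the abstract base classes moved to collections.abc in Python 3.3, importing them from collections has emitted a DeprecationWarning since then, and the aliases were removed entirely in Python 3.10. The corrected form:

    from collections.abc import Iterable  # not `from collections import Iterable`

    assert isinstance([1, 2, 3], Iterable)
    assert not isinstance(42, Iterable)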
diff --git a/synapse/res/templates/mail-Element.css b/synapse/res/templates/mail-Element.css
new file mode 100644
index 00000000..6a3e36ed
--- /dev/null
+++ b/synapse/res/templates/mail-Element.css
@@ -0,0 +1,7 @@
+.header {
+ border-bottom: 4px solid #e4f7ed ! important;
+}
+
+.notif_link a, .footer a {
+ color: #76CFA6 ! important;
+}
diff --git a/synapse/res/templates/notice_expiry.html b/synapse/res/templates/notice_expiry.html
index 6b94d8c3..d87311f6 100644
--- a/synapse/res/templates/notice_expiry.html
+++ b/synapse/res/templates/notice_expiry.html
@@ -22,6 +22,8 @@
<img src="http://riot.im/img/external/riot-logo-email.png" width="83" height="83" alt="[Riot]"/>
{% elif app_name == "Vector" %}
<img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
+ {% elif app_name == "Element" %}
+ <img src="https://static.element.io/images/email-logo.png" width="83" height="83" alt="[Element]"/>
{% else %}
<img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>
{% endif %}
diff --git a/synapse/res/templates/notif_mail.html b/synapse/res/templates/notif_mail.html
index 019506e5..a2dfeb9e 100644
--- a/synapse/res/templates/notif_mail.html
+++ b/synapse/res/templates/notif_mail.html
@@ -22,6 +22,8 @@
<img src="http://riot.im/img/external/riot-logo-email.png" width="83" height="83" alt="[Riot]"/>
{% elif app_name == "Vector" %}
<img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
+ {% elif app_name == "Element" %}
+ <img src="https://static.element.io/images/email-logo.png" width="83" height="83" alt="[Element]"/>
{% else %}
<img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>
{% endif %}
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 9eda592d..1c88c93f 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -35,8 +35,10 @@ from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
from synapse.rest.admin.rooms import (
+ DeleteRoomRestServlet,
JoinRoomAliasServlet,
ListRoomRestServlet,
+ RoomMembersRestServlet,
RoomRestServlet,
ShutdownRoomRestServlet,
)
@@ -200,6 +202,8 @@ def register_servlets(hs, http_server):
register_servlets_for_client_rest_resource(hs, http_server)
ListRoomRestServlet(hs).register(http_server)
RoomRestServlet(hs).register(http_server)
+ RoomMembersRestServlet(hs).register(http_server)
+ DeleteRoomRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
PurgeRoomServlet(hs).register(http_server)
SendServerNoticeServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index e07c3211..b8c95d04 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from http import HTTPStatus
from typing import List, Optional
-from synapse.api.constants import EventTypes, JoinRules, Membership, RoomCreationPreset
+from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
RestServlet,
@@ -32,7 +33,6 @@ from synapse.rest.admin._base import (
)
from synapse.storage.data_stores.main.room import RoomSortOrder
from synapse.types import RoomAlias, RoomID, UserID, create_requester
-from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@@ -46,20 +46,10 @@ class ShutdownRoomRestServlet(RestServlet):
PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)")
- DEFAULT_MESSAGE = (
- "Sharing illegal content on this server is not permitted and rooms in"
- " violation will be blocked."
- )
-
def __init__(self, hs):
self.hs = hs
- self.store = hs.get_datastore()
- self.state = hs.get_state_handler()
- self._room_creation_handler = hs.get_room_creation_handler()
- self.event_creation_handler = hs.get_event_creation_handler()
- self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
- self._replication = hs.get_replication_data_handler()
+ self.room_shutdown_handler = hs.get_room_shutdown_handler()
async def on_POST(self, request, room_id):
requester = await self.auth.get_user_by_req(request)
@@ -67,116 +57,65 @@ class ShutdownRoomRestServlet(RestServlet):
content = parse_json_object_from_request(request)
assert_params_in_dict(content, ["new_room_user_id"])
- new_room_user_id = content["new_room_user_id"]
-
- room_creator_requester = create_requester(new_room_user_id)
-
- message = content.get("message", self.DEFAULT_MESSAGE)
- room_name = content.get("room_name", "Content Violation Notification")
-
- info, stream_id = await self._room_creation_handler.create_room(
- room_creator_requester,
- config={
- "preset": RoomCreationPreset.PUBLIC_CHAT,
- "name": room_name,
- "power_level_content_override": {"users_default": -10},
- },
- ratelimit=False,
- )
- new_room_id = info["room_id"]
-
- requester_user_id = requester.user.to_string()
- logger.info(
- "Shutting down room %r, joining to new room: %r", room_id, new_room_id
+ ret = await self.room_shutdown_handler.shutdown_room(
+ room_id=room_id,
+ new_room_user_id=content["new_room_user_id"],
+ new_room_name=content.get("room_name"),
+ message=content.get("message"),
+ requester_user_id=requester.user.to_string(),
+ block=True,
)
- # This will work even if the room is already blocked, but that is
- # desirable in case the first attempt at blocking the room failed below.
- await self.store.block_room(room_id, requester_user_id)
-
- # We now wait for the create room to come back in via replication so
- # that we can assume that all the joins/invites have propogated before
- # we try and auto join below.
- #
- # TODO: Currently the events stream is written to from master
- await self._replication.wait_for_stream_position(
- self.hs.config.worker.writers.events, "events", stream_id
- )
+ return (200, ret)
- users = await self.state.get_current_users_in_room(room_id)
- kicked_users = []
- failed_to_kick_users = []
- for user_id in users:
- if not self.hs.is_mine_id(user_id):
- continue
- logger.info("Kicking %r from %r...", user_id, room_id)
+class DeleteRoomRestServlet(RestServlet):
+ """Delete a room from server. It is a combination and improvement of
+ shut down and purge room.
+ Shuts down a room by removing all local users from the room.
+ Blocking all future invites and joins to the room is optional.
+ If desired any local aliases will be repointed to a new room
+ created by `new_room_user_id` and kicked users will be auto
+ joined to the new room.
+ It will remove all trace of a room from the database.
+ """
- try:
- target_requester = create_requester(user_id)
- _, stream_id = await self.room_member_handler.update_membership(
- requester=target_requester,
- target=target_requester.user,
- room_id=room_id,
- action=Membership.LEAVE,
- content={},
- ratelimit=False,
- require_consent=False,
- )
+ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$")
- # Wait for leave to come in over replication before trying to forget.
- await self._replication.wait_for_stream_position(
- self.hs.config.worker.writers.events, "events", stream_id
- )
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.room_shutdown_handler = hs.get_room_shutdown_handler()
+ self.pagination_handler = hs.get_pagination_handler()
- await self.room_member_handler.forget(target_requester.user, room_id)
+ async def on_POST(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
- await self.room_member_handler.update_membership(
- requester=target_requester,
- target=target_requester.user,
- room_id=new_room_id,
- action=Membership.JOIN,
- content={},
- ratelimit=False,
- require_consent=False,
- )
+ content = parse_json_object_from_request(request)
- kicked_users.append(user_id)
- except Exception:
- logger.exception(
- "Failed to leave old room and join new room for %r", user_id
- )
- failed_to_kick_users.append(user_id)
-
- await self.event_creation_handler.create_and_send_nonmember_event(
- room_creator_requester,
- {
- "type": "m.room.message",
- "content": {"body": message, "msgtype": "m.text"},
- "room_id": new_room_id,
- "sender": new_room_user_id,
- },
- ratelimit=False,
- )
+ block = content.get("block", False)
+ if not isinstance(block, bool):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Param 'block' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
- aliases_for_room = await maybe_awaitable(
- self.store.get_aliases_for_room(room_id)
+ ret = await self.room_shutdown_handler.shutdown_room(
+ room_id=room_id,
+ new_room_user_id=content.get("new_room_user_id"),
+ new_room_name=content.get("room_name"),
+ message=content.get("message"),
+ requester_user_id=requester.user.to_string(),
+ block=block,
)
- await self.store.update_aliases_for_room(
- room_id, new_room_id, requester_user_id
- )
+ # Purge room
+ await self.pagination_handler.purge_room(room_id)
- return (
- 200,
- {
- "kicked_users": kicked_users,
- "failed_to_kick_users": failed_to_kick_users,
- "local_aliases": aliases_for_room,
- "new_room_id": new_room_id,
- },
- )
+ return (200, ret)
class ListRoomRestServlet(RestServlet):
@@ -292,6 +231,31 @@ class RoomRestServlet(RestServlet):
return 200, ret
+class RoomMembersRestServlet(RestServlet):
+ """
+ Get the list of members of a room.
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request, room_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ ret = await self.store.get_room(room_id)
+ if not ret:
+ raise NotFoundError("Room not found")
+
+ members = await self.store.get_users_in_room(room_id)
+ ret = {"members": members, "total": len(members)}
+
+ return 200, ret
+
+
class JoinRoomAliasServlet(RestServlet):
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index e4330c39..cc0bdfa5 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -239,6 +239,15 @@ class UserRestServletV2(RestServlet):
await self.deactivate_account_handler.deactivate_account(
target_user.to_string(), False
)
+ elif not deactivate and user["deactivated"]:
+ if "password" not in body:
+ raise SynapseError(
+ 400, "Must provide a password to re-activate an account."
+ )
+
+ await self.deactivate_account_handler.activate_account(
+ target_user.to_string()
+ )
user = await self.admin_handler.get_user(target_user)
return 200, user
@@ -254,7 +263,6 @@ class UserRestServletV2(RestServlet):
admin = body.get("admin", None)
user_type = body.get("user_type", None)
displayname = body.get("displayname", None)
- threepids = body.get("threepids", None)
if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
raise SynapseError(400, "Invalid user type")
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 64d5c58b..379f668d 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -89,12 +89,19 @@ class LoginRestServlet(RestServlet):
def __init__(self, hs):
super(LoginRestServlet, self).__init__()
self.hs = hs
+
+ # JWT configuration variables.
self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm
+ self.jwt_issuer = hs.config.jwt_issuer
+ self.jwt_audiences = hs.config.jwt_audiences
+
+ # SSO configuration.
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
self.oidc_enabled = hs.config.oidc_enabled
+
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self.handlers = hs.get_handlers()
@@ -364,24 +371,28 @@ class LoginRestServlet(RestServlet):
token = login_submission.get("token", None)
if token is None:
raise LoginError(
- 401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED
+ 403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN
)
import jwt
- from jwt.exceptions import InvalidTokenError
try:
payload = jwt.decode(
- token, self.jwt_secret, algorithms=[self.jwt_algorithm]
+ token,
+ self.jwt_secret,
+ algorithms=[self.jwt_algorithm],
+ issuer=self.jwt_issuer,
+ audience=self.jwt_audiences,
+ )
+ except jwt.PyJWTError as e:
+ # A JWT error occurred, return some info back to the client.
+ raise LoginError(
+ 403, "JWT validation failed: %s" % (str(e),), errcode=Codes.FORBIDDEN,
)
- except jwt.ExpiredSignatureError:
- raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
- except InvalidTokenError:
- raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
user = payload.get("sub", None)
if user is None:
- raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
+ raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
user_id = UserID(user, self.hs.hostname).to_string()
result = await self._complete_login(
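A standalone sketch of the stricter JWT validation above, using PyJWT directly (the secret, issuer, and audience values are illustrative):

    import jwt

    secret = "shared-secret"
    token = jwt.encode(
        {"sub": "alice", "iss": "https://issuer.example", "aud": "synapse"},
        secret,
        algorithm="HS256",
    )

    try:
        payload = jwt.decode(
            token,
            secret,
            algorithms=["HS256"],
            issuer="https://issuer.example",  # a mismatched `iss` now fails
            audience="synapse",               # a mismatched `aud` now fails
        )
    except jwt.PyJWTError as e:
        # Any PyJWT failure is now surfaced as a 403 M_FORBIDDEN LoginError.
        print("JWT validation failed: %s" % (e,))
    else:
        print("localpart:", payload["sub"])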
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index f40ed821..26d5a51c 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -15,6 +15,7 @@
# limitations under the License.
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
+
import logging
import re
from typing import List, Optional
@@ -515,9 +516,9 @@ class RoomMessageListRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request, default_limit=10)
as_client_event = b"raw" not in request.args
- filter_bytes = parse_string(request, b"filter", encoding=None)
- if filter_bytes:
- filter_json = urlparse.unquote(filter_bytes.decode("UTF-8"))
+ filter_str = parse_string(request, b"filter", encoding="utf-8")
+ if filter_str:
+ filter_json = urlparse.unquote(filter_str)
event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
if (
event_filter
@@ -627,9 +628,9 @@ class RoomEventContextServlet(RestServlet):
limit = parse_integer(request, "limit", default=10)
# picking the API shape for symmetry with /messages
- filter_bytes = parse_string(request, "filter")
- if filter_bytes:
- filter_json = urlparse.unquote(filter_bytes)
+ filter_str = parse_string(request, b"filter", encoding="utf-8")
+ if filter_str:
+ filter_json = urlparse.unquote(filter_str)
event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
else:
event_filter = None
@@ -816,9 +817,18 @@ class RoomTypingRestServlet(RestServlet):
self.typing_handler = hs.get_typing_handler()
self.auth = hs.get_auth()
+ # If we're not on the typing writer instance we should scream if we get
+ # requests.
+ self._is_typing_writer = (
+ hs.config.worker.writers.typing == hs.get_instance_name()
+ )
+
async def on_PUT(self, request, room_id, user_id):
requester = await self.auth.get_user_by_req(request)
+ if not self._is_typing_writer:
+ raise Exception("Got /typing request on instance that is not typing writer")
+
room_id = urlparse.unquote(room_id)
target_user = UserID.from_string(urlparse.unquote(user_id))
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index bc11b4dd..f016b4f1 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -17,24 +17,32 @@
"""
import logging
import re
-
-from twisted.internet import defer
+from typing import Iterable, Pattern
from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.urls import CLIENT_API_PREFIX
+from synapse.types import JsonDict
logger = logging.getLogger(__name__)
-def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
+def client_patterns(
+ path_regex: str,
+ releases: Iterable[int] = (0,),
+ unstable: bool = True,
+ v1: bool = False,
+) -> Iterable[Pattern]:
"""Creates a regex compiled client path with the correct client path
prefix.
Args:
- path_regex (str): The regex string to match. This should NOT have a ^
+ path_regex: The regex string to match. This should NOT have a ^
as this will be prefixed.
+ releases: An iterable of releases to include this endpoint under.
+ unstable: If true, include this endpoint under the "unstable" prefix.
+ v1: If true, include this endpoint under the "api/v1" prefix.
Returns:
- SRE_Pattern
+ An iterable of patterns.
"""
patterns = []
@@ -51,7 +59,15 @@ def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
return patterns
-def set_timeline_upper_limit(filter_json, filter_timeline_limit):
+def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int) -> None:
+ """
+ Enforces a maximum limit on a timeline query.
+
+ Params:
+ filter_json: The timeline query to modify.
+ filter_timeline_limit: The maximum limit to allow; passing -1
+ disables enforcing a maximum limit.
+ """
if filter_timeline_limit < 0:
return # no upper limits
timeline = filter_json.get("room", {}).get("timeline", {})
@@ -64,34 +80,22 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit):
def interactive_auth_handler(orig):
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
- Takes a on_POST method which returns a deferred (errcode, body) response
+ Takes an on_POST method which returns an Awaitable (errcode, body) response
and adds exception handling to turn an InteractiveAuthIncompleteError into
a 401 response.
Normal usage is:
@interactive_auth_handler
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
# ...
- yield self.auth_handler.check_auth
- """
+ await self.auth_handler.check_auth
+ """
- def wrapped(*args, **kwargs):
- res = defer.ensureDeferred(orig(*args, **kwargs))
- res.addErrback(_catch_incomplete_interactive_auth)
- return res
+ async def wrapped(*args, **kwargs):
+ try:
+ return await orig(*args, **kwargs)
+ except InteractiveAuthIncompleteError as e:
+ return 401, e.result
return wrapped
-
-
-def _catch_incomplete_interactive_auth(f):
- """helper for interactive_auth_handler
-
- Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
-
- Args:
- f (failure.Failure):
- """
- f.trap(InteractiveAuthIncompleteError)
- return 401, f.value.result
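The rewritten decorator is small enough to demonstrate end to end; this self-contained sketch uses a stub error class in place of synapse.api.errors.InteractiveAuthIncompleteError:

    import asyncio

    class InteractiveAuthIncompleteError(Exception):
        # Stub for the Synapse error; `result` carries the UIA session data.
        def __init__(self, result):
            self.result = result

    def interactive_auth_handler(orig):
        async def wrapped(*args, **kwargs):
            try:
                return await orig(*args, **kwargs)
            except InteractiveAuthIncompleteError as e:
                return 401, e.result
        return wrapped

    @interactive_auth_handler
    async def on_POST(request):
        raise InteractiveAuthIncompleteError({"session": "abc123", "flows": []})

    print(asyncio.run(on_POST(None)))  # (401, {'session': 'abc123', 'flows': []})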
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 8fa68dd3..a5c24fbd 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -178,14 +178,22 @@ class SyncRestServlet(RestServlet):
full_state=full_state,
)
+ # the client may have disconnected by now; don't bother to serialize the
+ # response if so.
+ if request._disconnected:
+ logger.info("Client has disconnected; not serializing response.")
+ return 200, {}
+
time_now = self.clock.time_msec()
response_content = await self.encode_response(
time_now, sync_result, requester.access_token_id, filter_collection
)
+ logger.debug("Event formatting complete")
return 200, response_content
async def encode_response(self, time_now, sync_result, access_token_id, filter):
+ logger.debug("Formatting events in sync response")
if filter.event_format == "client":
event_formatter = format_event_for_client_v2_without_room_id
elif filter.event_format == "federation":
@@ -213,6 +221,7 @@ class SyncRestServlet(RestServlet):
event_formatter,
)
+ logger.debug("building sync response dict")
return {
"account_data": {"events": sync_result.account_data},
"to_device": {"events": sync_result.to_device},
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index e149ac17..9b3f85b3 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -202,9 +202,11 @@ class RemoteKey(DirectServeJsonResource):
if miss:
cache_misses.setdefault(server_name, set()).add(key_id)
+ # Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(most_recent_result["key_json"]))
else:
for ts_added, result in results:
+ # Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(result["key_json"]))
if cache_misses and query_remote_on_cache_miss:
@@ -213,7 +215,7 @@ class RemoteKey(DirectServeJsonResource):
else:
signed_keys = []
for key_json in json_results:
- key_json = json.loads(key_json)
+ key_json = json.loads(key_json.decode("utf-8"))
for signing_key in self.config.key_server_signing_keys:
key_json = sign_json(key_json, self.config.server_name, signing_key)
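The casts are needed because psycopg2 returns BYTEA columns as memoryview objects, and json.loads() accepts str, bytes, or bytearray but not memoryview; a small demonstration:

    import json

    raw = memoryview(b'{"server_name": "example.org"}')  # what postgres hands back

    try:
        json.loads(raw)
    except TypeError as e:
        print("without the cast:", e)

    key_json = bytes(raw)  # the cast added above
    print(json.loads(key_json.decode("utf-8")))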
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 595849f9..9a847130 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -18,7 +18,6 @@ import logging
import os
import urllib
-from twisted.internet import defer
from twisted.protocols.basic import FileSender
from synapse.api.errors import Codes, SynapseError, cs_error
@@ -77,8 +76,9 @@ def respond_404(request):
)
-@defer.inlineCallbacks
-def respond_with_file(request, media_type, file_path, file_size=None, upload_name=None):
+async def respond_with_file(
+ request, media_type, file_path, file_size=None, upload_name=None
+):
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
@@ -89,7 +89,7 @@ def respond_with_file(request, media_type, file_path, file_size=None, upload_name=None):
add_file_headers(request, media_type, file_size, upload_name)
with open(file_path, "rb") as f:
- yield make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
+ await make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
finish_request(request)
else:
@@ -198,8 +198,9 @@ def _can_encode_filename_as_token(x):
return True
-@defer.inlineCallbacks
-def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
+async def respond_with_responder(
+ request, responder, media_type, file_size, upload_name=None
+):
"""Responds to the request with given responder. If responder is None then
returns 404.
@@ -218,7 +219,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
add_file_headers(request, media_type, file_size, upload_name)
try:
with responder:
- yield responder.write_to_consumer(request)
+ await responder.write_to_consumer(request)
except Exception as e:
# The majority of the time this will be due to the client having gone
# away. Unfortunately, Twisted simply throws a generic exception at us
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 79cb0ddd..66bc1c33 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -14,17 +14,18 @@
# limitations under the License.
import contextlib
+import inspect
import logging
import os
import shutil
+from typing import Optional
-from twisted.internet import defer
from twisted.protocols.basic import FileSender
from synapse.logging.context import defer_to_thread, make_deferred_yieldable
from synapse.util.file_consumer import BackgroundFileConsumer
-from ._base import Responder
+from ._base import FileInfo, Responder
logger = logging.getLogger(__name__)
@@ -46,25 +47,24 @@ class MediaStorage(object):
self.filepaths = filepaths
self.storage_providers = storage_providers
- @defer.inlineCallbacks
- def store_file(self, source, file_info):
+ async def store_file(self, source, file_info: FileInfo) -> str:
"""Write `source` to the on disk media store, and also any other
configured storage providers
Args:
source: A file like object that should be written
- file_info (FileInfo): Info about the file to store
+ file_info: Info about the file to store
Returns:
- Deferred[str]: the file path written to in the primary media store
+ the file path written to in the primary media store
"""
with self.store_into_file(file_info) as (f, fname, finish_cb):
# Write to the main repository
- yield defer_to_thread(
+ await defer_to_thread(
self.hs.get_reactor(), _write_file_synchronously, source, f
)
- yield finish_cb()
+ await finish_cb()
return fname
@@ -75,7 +75,7 @@ class MediaStorage(object):
Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
like object that can be written to, fname is the absolute path of file
- on disk, and finish_cb is a function that returns a Deferred.
+ on disk, and finish_cb is a function that returns an awaitable.
fname can be used to read the contents from after upload, e.g. to
generate thumbnails.
@@ -91,7 +91,7 @@ class MediaStorage(object):
with media_storage.store_into_file(info) as (f, fname, finish_cb):
# .. write into f ...
- yield finish_cb()
+ await finish_cb()
"""
path = self._file_info_to_path(file_info)
@@ -103,10 +103,13 @@ class MediaStorage(object):
finished_called = [False]
- @defer.inlineCallbacks
- def finish():
+ async def finish():
for provider in self.storage_providers:
- yield provider.store_file(path, file_info)
+ # store_file is supposed to return an Awaitable, but guard
+ # against improper implementations.
+ result = provider.store_file(path, file_info)
+ if inspect.isawaitable(result):
+ await result
finished_called[0] = True
@@ -123,17 +126,15 @@ class MediaStorage(object):
if not finished_called:
raise Exception("Finished callback not called")
- @defer.inlineCallbacks
- def fetch_media(self, file_info):
+ async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
"""Attempts to fetch media described by file_info from the local cache
and configured storage providers.
Args:
- file_info (FileInfo)
+ file_info
Returns:
- Deferred[Responder|None]: Returns a Responder if the file was found,
- otherwise None.
+ Returns a Responder if the file was found, otherwise None.
"""
path = self._file_info_to_path(file_info)
@@ -142,23 +143,26 @@ class MediaStorage(object):
return FileResponder(open(local_path, "rb"))
for provider in self.storage_providers:
- res = yield provider.fetch(path, file_info)
+ res = provider.fetch(path, file_info)
+ # Fetch is supposed to return an Awaitable, but guard against
+ # improper implementations.
+ if inspect.isawaitable(res):
+ res = await res
if res:
logger.debug("Streaming %s from %s", path, provider)
return res
return None
- @defer.inlineCallbacks
- def ensure_media_is_in_local_cache(self, file_info):
+ async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str:
"""Ensures that the given file is in the local cache. Attempts to
download it from storage providers if it isn't.
Args:
- file_info (FileInfo)
+ file_info
Returns:
- Deferred[str]: Full path to local file
+ Full path to local file
"""
path = self._file_info_to_path(file_info)
local_path = os.path.join(self.local_media_directory, path)
@@ -170,14 +174,18 @@ class MediaStorage(object):
os.makedirs(dirname)
for provider in self.storage_providers:
- res = yield provider.fetch(path, file_info)
+ res = provider.fetch(path, file_info)
+ # Fetch is supposed to return an Awaitable, but guard against
+ # improper implementations.
+ if inspect.isawaitable(res):
+ res = await res
if res:
with res:
consumer = BackgroundFileConsumer(
open(local_path, "wb"), self.hs.get_reactor()
)
- yield res.write_to_consumer(consumer)
- yield consumer.wait()
+ await res.write_to_consumer(consumer)
+ await consumer.wait()
return local_path
raise Exception("file could not be found")
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index e52c86c7..13d1a6d2 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -26,6 +26,7 @@ import traceback
from typing import Dict, Optional
from urllib import parse as urlparse
+import attr
from canonicaljson import json
from twisted.internet import defer
@@ -56,6 +57,65 @@ _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
OG_TAG_NAME_MAXLEN = 50
OG_TAG_VALUE_MAXLEN = 1000
+ONE_HOUR = 60 * 60 * 1000
+
+# A map of globs to API endpoints.
+_oembed_globs = {
+ # Twitter.
+ "https://publish.twitter.com/oembed": [
+ "https://twitter.com/*/status/*",
+ "https://*.twitter.com/*/status/*",
+ "https://twitter.com/*/moments/*",
+ "https://*.twitter.com/*/moments/*",
+ # Include the HTTP versions too.
+ "http://twitter.com/*/status/*",
+ "http://*.twitter.com/*/status/*",
+ "http://twitter.com/*/moments/*",
+ "http://*.twitter.com/*/moments/*",
+ ],
+}
+# Convert the globs to regular expressions.
+_oembed_patterns = {}
+for endpoint, globs in _oembed_globs.items():
+ for glob in globs:
+ # Convert the glob into a sane regular expression to match against. The
+ # rules followed will be slightly different for the domain portion vs.
+ # the rest.
+ #
+ # 1. The scheme must be one of HTTP / HTTPS (and have no globs).
+ # 2. The domain can have globs, but we limit it to characters that can
+ # reasonably be a domain part.
+ # TODO: This does not attempt to handle Unicode domain names.
+ # 3. Other parts allow a glob to be any one, or more, characters.
+ results = urlparse.urlparse(glob)
+
+ # Ensure the scheme does not have wildcards (and is a sane scheme).
+ if results.scheme not in {"http", "https"}:
+ raise ValueError("Insecure oEmbed glob scheme: %s" % (results.scheme,))
+
+ pattern = urlparse.urlunparse(
+ [
+ results.scheme,
+ re.escape(results.netloc).replace("\\*", "[a-zA-Z0-9_-]+"),
+ ]
+ + [re.escape(part).replace("\\*", ".+") for part in results[2:]]
+ )
+ _oembed_patterns[re.compile(pattern)] = endpoint
+
+
+@attr.s
+class OEmbedResult:
+ # Either HTML content or URL must be provided.
+ html = attr.ib(type=Optional[str])
+ url = attr.ib(type=Optional[str])
+ title = attr.ib(type=Optional[str])
+ # Number of seconds to cache the content.
+ cache_age = attr.ib(type=int)
+
+
+class OEmbedError(Exception):
+ """An error occurred processing the oEmbed object."""
+
class PreviewUrlResource(DirectServeJsonResource):
isLeaf = True
@@ -99,7 +159,7 @@ class PreviewUrlResource(DirectServeJsonResource):
cache_name="url_previews",
clock=self.clock,
# don't spider URLs more often than once an hour
- expiry_ms=60 * 60 * 1000,
+ expiry_ms=ONE_HOUR,
)
if self._worker_run_media_background_jobs:
@@ -310,6 +370,87 @@ class PreviewUrlResource(DirectServeJsonResource):
return jsonog.encode("utf8")
+ def _get_oembed_url(self, url: str) -> Optional[str]:
+ """
+ Check whether the URL should be downloaded as oEmbed content instead.
+
+ Params:
+ url: The URL to check.
+
+ Returns:
+ A URL to use instead or None if the original URL should be used.
+ """
+ for url_pattern, endpoint in _oembed_patterns.items():
+ if url_pattern.fullmatch(url):
+ return endpoint
+
+ # No match.
+ return None
+
+ async def _get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult:
+ """
+ Request content from an oEmbed endpoint.
+
+ Params:
+ endpoint: The oEmbed API endpoint.
+ url: The URL to pass to the API.
+
+ Returns:
+ An object representing the metadata returned.
+
+ Raises:
+ OEmbedError if fetching or parsing of the oEmbed information fails.
+ """
+ try:
+ logger.debug("Trying to get oEmbed content for url '%s'", url)
+ result = await self.client.get_json(
+ endpoint,
+ # TODO Specify max height / width.
+ # Note that only the JSON format is supported.
+ args={"url": url},
+ )
+
+ # Ensure there's a version of 1.0.
+ if result.get("version") != "1.0":
+ raise OEmbedError("Invalid version: %s" % (result.get("version"),))
+
+ oembed_type = result.get("type")
+
+ # Ensure the cache age is None or an int.
+ cache_age = result.get("cache_age")
+ if cache_age:
+ cache_age = int(cache_age)
+
+ oembed_result = OEmbedResult(None, None, result.get("title"), cache_age)
+
+ # HTML content.
+ if oembed_type == "rich":
+ oembed_result.html = result.get("html")
+ return oembed_result
+
+ if oembed_type == "photo":
+ oembed_result.url = result.get("url")
+ return oembed_result
+
+ # TODO Handle link and video types.
+
+ if "thumbnail_url" in result:
+ oembed_result.url = result.get("thumbnail_url")
+ return oembed_result
+
+ raise OEmbedError("Incompatible oEmbed information.")
+
+ except OEmbedError as e:
+ # Trap OEmbedErrors first so we can directly re-raise them.
+ logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
+ raise
+
+ except Exception as e:
+ # Trap any exception and let the code follow as usual.
+ # FIXME: pass through 404s and other error messages nicely
+ logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
+ raise OEmbedError() from e
+
async def _download_url(self, url, user):
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
@@ -319,54 +460,90 @@ class PreviewUrlResource(DirectServeJsonResource):
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
- with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ # If this URL can be accessed via oEmbed, use that instead.
+ url_to_download = url
+ oembed_url = self._get_oembed_url(url)
+ if oembed_url:
+ # The result might be a new URL to download, or it might be HTML content.
try:
- logger.debug("Trying to get preview for url '%s'", url)
- length, headers, uri, code = await self.client.get_file(
- url,
- output_stream=f,
- max_size=self.max_spider_size,
- headers={"Accept-Language": self.url_preview_accept_language},
- )
- except SynapseError:
- # Pass SynapseErrors through directly, so that the servlet
- # handler will return a SynapseError to the client instead of
- # blank data or a 500.
- raise
- except DNSLookupError:
- # DNS lookup returned no results
- # Note: This will also be the case if one of the resolved IP
- # addresses is blacklisted
- raise SynapseError(
- 502,
- "DNS resolution failure during URL preview generation",
- Codes.UNKNOWN,
- )
- except Exception as e:
- # FIXME: pass through 404s and other error messages nicely
- logger.warning("Error downloading %s: %r", url, e)
+ oembed_result = await self._get_oembed_content(oembed_url, url)
+ if oembed_result.url:
+ url_to_download = oembed_result.url
+ elif oembed_result.html:
+ url_to_download = None
+ except OEmbedError:
+ # If an error occurs, try doing a normal preview.
+ pass
- raise SynapseError(
- 500,
- "Failed to download content: %s"
- % (traceback.format_exception_only(sys.exc_info()[0], e),),
- Codes.UNKNOWN,
- )
- await finish()
+ if url_to_download:
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ try:
+ logger.debug("Trying to get preview for url '%s'", url_to_download)
+ length, headers, uri, code = await self.client.get_file(
+ url_to_download,
+ output_stream=f,
+ max_size=self.max_spider_size,
+ headers={"Accept-Language": self.url_preview_accept_language},
+ )
+ except SynapseError:
+ # Pass SynapseErrors through directly, so that the servlet
+ # handler will return a SynapseError to the client instead of
+ # blank data or a 500.
+ raise
+ except DNSLookupError:
+ # DNS lookup returned no results
+ # Note: This will also be the case if one of the resolved IP
+ # addresses is blacklisted
+ raise SynapseError(
+ 502,
+ "DNS resolution failure during URL preview generation",
+ Codes.UNKNOWN,
+ )
+ except Exception as e:
+ # FIXME: pass through 404s and other error messages nicely
+ logger.warning("Error downloading %s: %r", url_to_download, e)
+
+ raise SynapseError(
+ 500,
+ "Failed to download content: %s"
+ % (traceback.format_exception_only(sys.exc_info()[0], e),),
+ Codes.UNKNOWN,
+ )
+ await finish()
+
+ if b"Content-Type" in headers:
+ media_type = headers[b"Content-Type"][0].decode("ascii")
+ else:
+ media_type = "application/octet-stream"
+
+ download_name = get_filename_from_headers(headers)
+
+ # FIXME: we should calculate a proper expiration based on the
+ # Cache-Control and Expire headers. But for now, assume 1 hour.
+ expires = ONE_HOUR
+ etag = headers["ETag"][0] if "ETag" in headers else None
+ else:
+ html_bytes = oembed_result.html.encode("utf-8") # type: ignore
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ f.write(html_bytes)
+ await finish()
+
+ media_type = "text/html"
+ download_name = oembed_result.title
+ length = len(html_bytes)
+ # If a specific cache age was not given, assume 1 hour.
+ expires = oembed_result.cache_age or ONE_HOUR
+ uri = oembed_url
+ code = 200
+ etag = None
try:
- if b"Content-Type" in headers:
- media_type = headers[b"Content-Type"][0].decode("ascii")
- else:
- media_type = "application/octet-stream"
time_now_ms = self.clock.time_msec()
- download_name = get_filename_from_headers(headers)
-
await self.store.store_local_media(
media_id=file_id,
media_type=media_type,
- time_now_ms=self.clock.time_msec(),
+ time_now_ms=time_now_ms,
upload_name=download_name,
media_length=length,
user_id=user,
@@ -389,10 +566,8 @@ class PreviewUrlResource(DirectServeJsonResource):
"filename": fname,
"uri": uri,
"response_code": code,
- # FIXME: we should calculate a proper expiration based on the
- # Cache-Control and Expire headers. But for now, assume 1 hour.
- "expires": 60 * 60 * 1000,
- "etag": headers["ETag"][0] if "ETag" in headers else None,
+ "expires": expires,
+ "etag": etag,
}
def _start_expire_url_cache_data(self):
@@ -449,7 +624,7 @@ class PreviewUrlResource(DirectServeJsonResource):
# These may be cached for a bit on the client (i.e., they
# may have a room open with a preview url thing open).
# So we wait a couple of days before deleting, just in case.
- expire_before = now - 2 * 24 * 60 * 60 * 1000
+ expire_before = now - 2 * 24 * ONE_HOUR
media_ids = await self.store.get_url_cache_media_before(expire_before)
removed_media = []
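The glob-to-regex conversion at the top of this file is easy to check in isolation; this sketch applies the same transformation to a single Twitter glob:

    import re
    from urllib import parse as urlparse

    glob = "https://*.twitter.com/*/status/*"
    results = urlparse.urlparse(glob)
    assert results.scheme in {"http", "https"}

    # Same construction as in the diff: domain wildcards become a conservative
    # hostname-character class, other wildcards become `.+`.
    pattern = urlparse.urlunparse(
        [
            results.scheme,
            re.escape(results.netloc).replace("\\*", "[a-zA-Z0-9_-]+"),
        ]
        + [re.escape(part).replace("\\*", ".+") for part in results[2:]]
    )
    compiled = re.compile(pattern)

    assert compiled.fullmatch("https://mobile.twitter.com/someone/status/123")
    assert not compiled.fullmatch("https://evil.example/someone/status/123")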
diff --git a/synapse/server.py b/synapse/server.py
index 6acce2e2..8e411125 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -44,7 +44,6 @@ from synapse.federation.federation_client import FederationClient
from synapse.federation.federation_server import (
FederationHandlerRegistry,
FederationServer,
- ReplicationFederationHandlerRegistry,
)
from synapse.federation.send_queue import FederationRemoteSendQueue
from synapse.federation.sender import FederationSender
@@ -73,14 +72,18 @@ from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler
from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.receipts import ReceiptsHandler
from synapse.handlers.register import RegistrationHandler
-from synapse.handlers.room import RoomContextHandler, RoomCreationHandler
+from synapse.handlers.room import (
+ RoomContextHandler,
+ RoomCreationHandler,
+ RoomShutdownHandler,
+)
from synapse.handlers.room_list import RoomListHandler
from synapse.handlers.room_member import RoomMemberMasterHandler
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
from synapse.handlers.set_password import SetPasswordHandler
from synapse.handlers.stats import StatsHandler
from synapse.handlers.sync import SyncHandler
-from synapse.handlers.typing import TypingHandler
+from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
from synapse.handlers.user_directory import UserDirectoryHandler
from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpClient
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
@@ -102,7 +105,7 @@ from synapse.server_notices.worker_server_notices_sender import (
WorkerServerNoticesSender,
)
from synapse.state import StateHandler, StateResolutionHandler
-from synapse.storage import DataStores, Storage
+from synapse.storage import DataStore, DataStores, Storage
from synapse.streams.events import EventSources
from synapse.util import Clock
from synapse.util.distributor import Distributor
@@ -144,6 +147,7 @@ class HomeServer(object):
"handlers",
"auth",
"room_creation_handler",
+ "room_shutdown_handler",
"state_handler",
"state_resolution_handler",
"presence_handler",
@@ -307,7 +311,7 @@ class HomeServer(object):
def get_clock(self):
return self.clock
- def get_datastore(self):
+ def get_datastore(self) -> DataStore:
return self.datastores.main
def get_datastores(self):
@@ -357,6 +361,9 @@ class HomeServer(object):
def build_room_creation_handler(self):
return RoomCreationHandler(self)
+ def build_room_shutdown_handler(self):
+ return RoomShutdownHandler(self)
+
def build_sendmail(self):
return sendmail
@@ -370,7 +377,10 @@ class HomeServer(object):
return PresenceHandler(self)
def build_typing_handler(self):
- return TypingHandler(self)
+ if self.config.worker.writers.typing == self.get_instance_name():
+ return TypingWriterHandler(self)
+ else:
+ return FollowerTypingHandler(self)
def build_sync_handler(self):
return SyncHandler(self)
@@ -526,10 +536,7 @@ class HomeServer(object):
return RoomMemberMasterHandler(self)
def build_federation_registry(self):
- if self.config.worker_app:
- return ReplicationFederationHandlerRegistry(self)
- else:
- return FederationHandlerRegistry()
+ return FederationHandlerRegistry(self)
def build_server_notices_manager(self):
if self.config.worker_app:
diff --git a/synapse/server.pyi b/synapse/server.pyi
index fe8024d2..1aba408c 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -20,6 +20,7 @@ import synapse.handlers.room
import synapse.handlers.room_member
import synapse.handlers.set_password
import synapse.http.client
+import synapse.http.matrixfederationclient
import synapse.notifier
import synapse.push.pusherpool
import synapse.replication.tcp.client
@@ -30,6 +31,7 @@ import synapse.server_notices.server_notices_sender
import synapse.state
import synapse.storage
from synapse.events.builder import EventBuilderFactory
+from synapse.handlers.typing import FollowerTypingHandler
from synapse.replication.tcp.streams import Stream
class HomeServer(object):
@@ -71,6 +73,8 @@ class HomeServer(object):
pass
def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler:
pass
+ def get_room_shutdown_handler(self) -> synapse.handlers.room.RoomShutdownHandler:
+ pass
def get_event_creation_handler(
self,
) -> synapse.handlers.message.EventCreationHandler:
@@ -141,3 +145,11 @@ class HomeServer(object):
pass
def get_replication_streams(self) -> Dict[str, Stream]:
pass
+ def get_http_client(
+ self,
+ ) -> synapse.http.matrixfederationclient.MatrixFederationHttpClient:
+ pass
+ def should_send_federation(self) -> bool:
+ pass
+ def get_typing_handler(self) -> FollowerTypingHandler:
+ pass
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 495d9f04..25ccef5a 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -16,14 +16,12 @@
import logging
from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Set
+from typing import Awaitable, Dict, Iterable, List, Optional, Set
import attr
from frozendict import frozendict
from prometheus_client import Histogram
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
from synapse.events import EventBase
@@ -31,6 +29,7 @@ from synapse.events.snapshot import EventContext
from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
+from synapse.storage.roommember import ProfileInfo
from synapse.types import StateMap
from synapse.util import Clock
from synapse.util.async_helpers import Linearizer
@@ -108,8 +107,7 @@ class StateHandler(object):
self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
- @defer.inlineCallbacks
- def get_current_state(
+ async def get_current_state(
self, room_id, event_type=None, state_key="", latest_event_ids=None
):
""" Retrieves the current state for the room. This is done by
@@ -126,20 +124,20 @@ class StateHandler(object):
map from (type, state_key) to event
"""
if not latest_event_ids:
- latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+ latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state")
- ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
+ ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
if event_type:
event_id = state.get((event_type, state_key))
event = None
if event_id:
- event = yield self.store.get_event(event_id, allow_none=True)
+ event = await self.store.get_event(event_id, allow_none=True)
return event
- state_map = yield self.store.get_events(
+ state_map = await self.store.get_events(
list(state.values()), get_prev_content=False
)
state = {
@@ -148,8 +146,7 @@ class StateHandler(object):
return state
- @defer.inlineCallbacks
- def get_current_state_ids(self, room_id, latest_event_ids=None):
+ async def get_current_state_ids(self, room_id, latest_event_ids=None):
"""Get the current state, or the state at a set of events, for a room
Args:
@@ -164,41 +161,38 @@ class StateHandler(object):
(event_type, state_key) -> event_id
"""
if not latest_event_ids:
- latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+ latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state_ids")
- ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
+ ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
return state
- @defer.inlineCallbacks
- def get_current_users_in_room(self, room_id, latest_event_ids=None):
+ async def get_current_users_in_room(
+ self, room_id: str, latest_event_ids: Optional[List[str]] = None
+ ) -> Dict[str, ProfileInfo]:
"""
Get the users who are currently in a room.
Args:
- room_id (str): The ID of the room.
- latest_event_ids (List[str]|None): Precomputed list of latest
- event IDs. Will be computed if None.
+ room_id: The ID of the room.
+ latest_event_ids: Precomputed list of latest event IDs. Will be computed if None.
Returns:
- Deferred[Dict[str,ProfileInfo]]: Dictionary of user IDs to their
- profileinfo.
+ Dictionary of user IDs to their profileinfo.
"""
if not latest_event_ids:
- latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+ latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_users_in_room")
- entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
- joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
+ entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
+ joined_users = await self.store.get_joined_users_from_state(room_id, entry)
return joined_users
- @defer.inlineCallbacks
- def get_current_hosts_in_room(self, room_id):
- event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
- return (yield self.get_hosts_in_room_at_events(room_id, event_ids))
+ async def get_current_hosts_in_room(self, room_id):
+ event_ids = await self.store.get_latest_event_ids_in_room(room_id)
+ return await self.get_hosts_in_room_at_events(room_id, event_ids)
- @defer.inlineCallbacks
- def get_hosts_in_room_at_events(self, room_id, event_ids):
+ async def get_hosts_in_room_at_events(self, room_id, event_ids):
"""Get the hosts that were in a room at the given event ids
Args:
@@ -208,12 +202,11 @@ class StateHandler(object):
Returns:
Deferred[list[str]]: the hosts in the room at the given events
"""
- entry = yield self.resolve_state_groups_for_events(room_id, event_ids)
- joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
+ entry = await self.resolve_state_groups_for_events(room_id, event_ids)
+ joined_hosts = await self.store.get_joined_hosts(room_id, entry)
return joined_hosts
- @defer.inlineCallbacks
- def compute_event_context(
+ async def compute_event_context(
self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
):
"""Build an EventContext structure for the event.
@@ -278,7 +271,7 @@ class StateHandler(object):
# otherwise, we'll need to resolve the state across the prev_events.
logger.debug("calling resolve_state_groups from compute_event_context")
- entry = yield self.resolve_state_groups_for_events(
+ entry = await self.resolve_state_groups_for_events(
event.room_id, event.prev_event_ids()
)
@@ -295,7 +288,7 @@ class StateHandler(object):
#
if not state_group_before_event:
- state_group_before_event = yield self.state_store.store_state_group(
+ state_group_before_event = await self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=state_group_before_event_prev_group,
@@ -335,7 +328,7 @@ class StateHandler(object):
state_ids_after_event[key] = event.event_id
delta_ids = {key: event.event_id}
- state_group_after_event = yield self.state_store.store_state_group(
+ state_group_after_event = await self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=state_group_before_event,
@@ -353,8 +346,7 @@ class StateHandler(object):
)
@measure_func()
- @defer.inlineCallbacks
- def resolve_state_groups_for_events(self, room_id, event_ids):
+ async def resolve_state_groups_for_events(self, room_id, event_ids):
""" Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
@@ -373,7 +365,7 @@ class StateHandler(object):
# map from state group id to the state in that state group (where
# 'state' is a map from state key to event id)
# dict[int, dict[(str, str), str]]
- state_groups_ids = yield self.state_store.get_state_groups_ids(
+ state_groups_ids = await self.state_store.get_state_groups_ids(
room_id, event_ids
)
@@ -382,7 +374,7 @@ class StateHandler(object):
elif len(state_groups_ids) == 1:
name, state_list = list(state_groups_ids.items()).pop()
- prev_group, delta_ids = yield self.state_store.get_state_group_delta(name)
+ prev_group, delta_ids = await self.state_store.get_state_group_delta(name)
return _StateCacheEntry(
state=state_list,
@@ -391,9 +383,9 @@ class StateHandler(object):
delta_ids=delta_ids,
)
- room_version = yield self.store.get_room_version_id(room_id)
+ room_version = await self.store.get_room_version_id(room_id)
- result = yield self._state_resolution_handler.resolve_state_groups(
+ result = await self._state_resolution_handler.resolve_state_groups(
room_id,
room_version,
state_groups_ids,
@@ -402,8 +394,7 @@ class StateHandler(object):
)
return result
- @defer.inlineCallbacks
- def resolve_events(self, room_version, state_sets, event):
+ async def resolve_events(self, room_version, state_sets, event):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
@@ -414,7 +405,7 @@ class StateHandler(object):
state_map = {ev.event_id: ev for st in state_sets for ev in st}
with Measure(self.clock, "state._resolve_events"):
- new_state = yield resolve_events_with_store(
+ new_state = await resolve_events_with_store(
self.clock,
event.room_id,
room_version,
@@ -451,9 +442,8 @@ class StateResolutionHandler(object):
reset_expiry_on_get=True,
)
- @defer.inlineCallbacks
@log_function
- def resolve_state_groups(
+ async def resolve_state_groups(
self, room_id, room_version, state_groups_ids, event_map, state_res_store
):
"""Resolves conflicts between a set of state groups
@@ -479,13 +469,13 @@ class StateResolutionHandler(object):
state_res_store (StateResolutionStore)
Returns:
- Deferred[_StateCacheEntry]: resolved state
+ _StateCacheEntry: resolved state
"""
logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
group_names = frozenset(state_groups_ids.keys())
- with (yield self.resolve_linearizer.queue(group_names)):
+ with (await self.resolve_linearizer.queue(group_names)):
if self._state_cache is not None:
cache = self._state_cache.get(group_names, None)
if cache:
@@ -517,7 +507,7 @@ class StateResolutionHandler(object):
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
- new_state = yield resolve_events_with_store(
+ new_state = await resolve_events_with_store(
self.clock,
room_id,
room_version,
@@ -598,7 +588,7 @@ def resolve_events_with_store(
state_sets: List[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
-):
+) -> Awaitable[StateMap[str]]:
"""
Args:
room_id: the room we are working in
@@ -619,8 +609,7 @@ def resolve_events_with_store(
state_res_store: a place to fetch events from
Returns:
- Deferred[dict[(str, str), str]]:
- a map from (type, state_key) to event_id.
+ a map from (type, state_key) to event_id.
"""
v = KNOWN_ROOM_VERSIONS[room_version]
if v.state_res == StateResolutionVersions.V1:
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 7b531a83..ab5e2484 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -15,9 +15,7 @@
import hashlib
import logging
-from typing import Callable, Dict, List, Optional
-
-from twisted.internet import defer
+from typing import Awaitable, Callable, Dict, List, Optional
from synapse import event_auth
from synapse.api.constants import EventTypes
@@ -32,12 +30,11 @@ logger = logging.getLogger(__name__)
POWER_KEY = (EventTypes.PowerLevels, "")
-@defer.inlineCallbacks
-def resolve_events_with_store(
+async def resolve_events_with_store(
room_id: str,
state_sets: List[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
- state_map_factory: Callable,
+ state_map_factory: Callable[[List[str]], Awaitable],
):
"""
Args:
@@ -56,7 +53,7 @@ def resolve_events_with_store(
state_map_factory: will be called
with a list of event_ids that are needed, and should return with
- a Deferred of dict of event_id to event.
+ an Awaitable that resolves to a dict of event_id to event.
Returns:
- Deferred[dict[(str, str), str]]:
+ dict[(str, str), str]:
@@ -80,7 +77,7 @@ def resolve_events_with_store(
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
# the state events which are in conflict (and those in event_map)
- state_map = yield state_map_factory(needed_events)
+ state_map = await state_map_factory(needed_events)
if event_map is not None:
state_map.update(event_map)
@@ -110,7 +107,7 @@ def resolve_events_with_store(
"Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count
)
- state_map_new = yield state_map_factory(new_needed_events)
+ state_map_new = await state_map_factory(new_needed_events)
for event in state_map_new.values():
if event.room_id != room_id:
raise Exception(
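The tightened annotation above, Callable[[List[str]], Awaitable], reflects that the factory no longer has to return a Deferred specifically, only something awaitable. A hedged sketch of a conforming factory, reusing the get_events(..., allow_rejected=True) call seen elsewhere in this patch:

    def make_state_map_factory(store):
        async def state_map_factory(event_ids):
            # Must resolve to a dict of event_id -> event.
            return await store.get_events(event_ids, allow_rejected=True)

        return state_map_factory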
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index bf6caa09..6634955c 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -18,8 +18,6 @@ import itertools
import logging
from typing import Dict, List, Optional
-from twisted.internet import defer
-
import synapse.state
from synapse import event_auth
from synapse.api.constants import EventTypes
@@ -32,14 +30,13 @@ from synapse.util import Clock
logger = logging.getLogger(__name__)
-# We want to yield to the reactor occasionally during state res when dealing
+# We want to hand control back to the reactor occasionally during state res when dealing
# with large data sets, so that we don't exhaust the reactor. This is done by
-# yielding to reactor during loops every N iterations.
-_YIELD_AFTER_ITERATIONS = 100
+# awaiting clock.sleep(0) during loops every N iterations.
+_AWAIT_AFTER_ITERATIONS = 100
-@defer.inlineCallbacks
-def resolve_events_with_store(
+async def resolve_events_with_store(
clock: Clock,
room_id: str,
room_version: str,
@@ -87,7 +84,7 @@ def resolve_events_with_store(
# Also fetch all auth events that appear in only some of the state sets'
# auth chains.
- auth_diff = yield _get_auth_chain_difference(state_sets, event_map, state_res_store)
+ auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store)
full_conflicted_set = set(
itertools.chain(
@@ -95,7 +92,7 @@ def resolve_events_with_store(
)
)
- events = yield state_res_store.get_events(
+ events = await state_res_store.get_events(
[eid for eid in full_conflicted_set if eid not in event_map],
allow_rejected=True,
)
@@ -118,14 +115,14 @@ def resolve_events_with_store(
eid for eid in full_conflicted_set if _is_power_event(event_map[eid])
)
- sorted_power_events = yield _reverse_topological_power_sort(
+ sorted_power_events = await _reverse_topological_power_sort(
clock, room_id, power_events, event_map, state_res_store, full_conflicted_set
)
logger.debug("sorted %d power events", len(sorted_power_events))
# Now sequentially auth each one
- resolved_state = yield _iterative_auth_checks(
+ resolved_state = await _iterative_auth_checks(
clock,
room_id,
room_version,
@@ -148,13 +145,13 @@ def resolve_events_with_store(
logger.debug("sorting %d remaining events", len(leftover_events))
pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
- leftover_events = yield _mainline_sort(
+ leftover_events = await _mainline_sort(
clock, room_id, leftover_events, pl, event_map, state_res_store
)
logger.debug("resolving remaining events")
- resolved_state = yield _iterative_auth_checks(
+ resolved_state = await _iterative_auth_checks(
clock,
room_id,
room_version,
@@ -174,8 +171,7 @@ def resolve_events_with_store(
return resolved_state
-@defer.inlineCallbacks
-def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
+async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
"""Return the power level of the sender of the given event according to
their auth events.
@@ -188,11 +184,11 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
Returns:
- Deferred[int]
+ int
"""
- event = yield _get_event(room_id, event_id, event_map, state_res_store)
+ event = await _get_event(room_id, event_id, event_map, state_res_store)
pl = None
for aid in event.auth_event_ids():
- aev = yield _get_event(
+ aev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
@@ -202,7 +198,7 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
if pl is None:
# Couldn't find power level. Check if they're the creator of the room
for aid in event.auth_event_ids():
- aev = yield _get_event(
+ aev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if aev and (aev.type, aev.state_key) == (EventTypes.Create, ""):
@@ -221,8 +217,7 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
return int(level)
-@defer.inlineCallbacks
-def _get_auth_chain_difference(state_sets, event_map, state_res_store):
+async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
"""Compare the auth chains of each state set and return the set of events
that only appear in some but not all of the auth chains.
@@ -235,7 +230,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
- Deferred[set[str]]: Set of event IDs
+ set[str]: Set of event IDs
"""
- difference = yield state_res_store.get_auth_chain_difference(
+ difference = await state_res_store.get_auth_chain_difference(
[set(state_set.values()) for state_set in state_sets]
)
@@ -292,8 +287,7 @@ def _is_power_event(event):
return False
-@defer.inlineCallbacks
-def _add_event_and_auth_chain_to_graph(
+async def _add_event_and_auth_chain_to_graph(
graph, room_id, event_id, event_map, state_res_store, auth_diff
):
"""Helper function for _reverse_topological_power_sort that add the event
@@ -314,7 +308,7 @@ def _add_event_and_auth_chain_to_graph(
eid = state.pop()
graph.setdefault(eid, set())
- event = yield _get_event(room_id, eid, event_map, state_res_store)
+ event = await _get_event(room_id, eid, event_map, state_res_store)
for aid in event.auth_event_ids():
if aid in auth_diff:
if aid not in graph:
@@ -323,8 +317,7 @@ def _add_event_and_auth_chain_to_graph(
graph.setdefault(eid, set()).add(aid)
-@defer.inlineCallbacks
-def _reverse_topological_power_sort(
+async def _reverse_topological_power_sort(
clock, room_id, event_ids, event_map, state_res_store, auth_diff
):
"""Returns a list of the event_ids sorted by reverse topological ordering,
@@ -344,26 +337,26 @@ def _reverse_topological_power_sort(
graph = {}
for idx, event_id in enumerate(event_ids, start=1):
- yield _add_event_and_auth_chain_to_graph(
+ await _add_event_and_auth_chain_to_graph(
graph, room_id, event_id, event_map, state_res_store, auth_diff
)
- # We yield occasionally when we're working with large data sets to
+ # We await clock.sleep(0) occasionally when we're working with large data sets to
# ensure that we don't block the reactor loop for too long.
- if idx % _YIELD_AFTER_ITERATIONS == 0:
- yield clock.sleep(0)
+ if idx % _AWAIT_AFTER_ITERATIONS == 0:
+ await clock.sleep(0)
event_to_pl = {}
for idx, event_id in enumerate(graph, start=1):
- pl = yield _get_power_level_for_sender(
+ pl = await _get_power_level_for_sender(
room_id, event_id, event_map, state_res_store
)
event_to_pl[event_id] = pl
- # We yield occasionally when we're working with large data sets to
+ # We await clock.sleep(0) occasionally when we're working with large data sets to
# ensure that we don't block the reactor loop for too long.
- if idx % _YIELD_AFTER_ITERATIONS == 0:
- yield clock.sleep(0)
+ if idx % _AWAIT_AFTER_ITERATIONS == 0:
+ await clock.sleep(0)
def _get_power_order(event_id):
ev = event_map[event_id]
@@ -378,8 +371,7 @@ def _reverse_topological_power_sort(
return sorted_events
-@defer.inlineCallbacks
-def _iterative_auth_checks(
+async def _iterative_auth_checks(
clock, room_id, room_version, event_ids, base_state, event_map, state_res_store
):
"""Sequentially apply auth checks to each event in given list, updating the
@@ -405,7 +397,7 @@ def _iterative_auth_checks(
auth_events = {}
for aid in event.auth_event_ids():
- ev = yield _get_event(
+ ev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
@@ -420,7 +412,7 @@ def _iterative_auth_checks(
for key in event_auth.auth_types_for_event(event):
if key in resolved_state:
ev_id = resolved_state[key]
- ev = yield _get_event(room_id, ev_id, event_map, state_res_store)
+ ev = await _get_event(room_id, ev_id, event_map, state_res_store)
if ev.rejected_reason is None:
auth_events[key] = event_map[ev_id]
@@ -438,16 +430,15 @@ def _iterative_auth_checks(
except AuthError:
pass
- # We yield occasionally when we're working with large data sets to
+ # We await clock.sleep(0) occasionally when we're working with large data sets to
# ensure that we don't block the reactor loop for too long.
- if idx % _YIELD_AFTER_ITERATIONS == 0:
- yield clock.sleep(0)
+ if idx % _AWAIT_AFTER_ITERATIONS == 0:
+ await clock.sleep(0)
return resolved_state
-@defer.inlineCallbacks
-def _mainline_sort(
+async def _mainline_sort(
clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store
):
"""Returns a sorted list of event_ids sorted by mainline ordering based on
@@ -474,21 +465,21 @@ def _mainline_sort(
idx = 0
while pl:
mainline.append(pl)
- pl_ev = yield _get_event(room_id, pl, event_map, state_res_store)
+ pl_ev = await _get_event(room_id, pl, event_map, state_res_store)
auth_events = pl_ev.auth_event_ids()
pl = None
for aid in auth_events:
- ev = yield _get_event(
+ ev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if ev and (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""):
pl = aid
break
- # We yield occasionally when we're working with large data sets to
+ # We await clock.sleep(0) occasionally when we're working with large data sets to
# ensure that we don't block the reactor loop for too long.
- if idx != 0 and idx % _YIELD_AFTER_ITERATIONS == 0:
- yield clock.sleep(0)
+ if idx != 0 and idx % _AWAIT_AFTER_ITERATIONS == 0:
+ await clock.sleep(0)
idx += 1
@@ -498,23 +489,24 @@ def _mainline_sort(
order_map = {}
for idx, ev_id in enumerate(event_ids, start=1):
- depth = yield _get_mainline_depth_for_event(
+ depth = await _get_mainline_depth_for_event(
event_map[ev_id], mainline_map, event_map, state_res_store
)
order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)
- # We yield occasionally when we're working with large data sets to
+ # We await clock.sleep(0) occasionally when we're working with large data sets to
# ensure that we don't block the reactor loop for too long.
- if idx % _YIELD_AFTER_ITERATIONS == 0:
- yield clock.sleep(0)
+ if idx % _AWAIT_AFTER_ITERATIONS == 0:
+ await clock.sleep(0)
event_ids.sort(key=lambda ev_id: order_map[ev_id])
return event_ids
-@defer.inlineCallbacks
-def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_store):
+async def _get_mainline_depth_for_event(
+ event, mainline_map, event_map, state_res_store
+):
"""Get the mainline depths for the given event based on the mainline map
Args:
@@ -541,7 +533,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
event = None
for aid in auth_events:
- aev = yield _get_event(
+ aev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
@@ -552,8 +544,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
return 0
-@defer.inlineCallbacks
-def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
+async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
"""Helper function to look up event in event_map, falling back to looking
it up in the store
@@ -569,7 +560,7 @@ def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
- Deferred[Optional[FrozenEvent]]
+ Optional[FrozenEvent]
"""
if event_id not in event_map:
- events = yield state_res_store.get_events([event_id], allow_rejected=True)
+ events = await state_res_store.get_events([event_id], allow_rejected=True)
event_map.update(events)
event = event_map.get(event_id)
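The renamed _AWAIT_AFTER_ITERATIONS constant implements cooperative scheduling: each state-resolution loop awaits a zero-second sleep every 100 iterations so that other reactor work can run. A self-contained sketch of the pattern (clock stands in for Synapse's Clock):

    _AWAIT_AFTER_ITERATIONS = 100

    async def process_all(clock, event_ids, handle_one):
        for idx, event_id in enumerate(event_ids, start=1):
            handle_one(event_id)
            # Suspend briefly so a large room cannot monopolise the
            # reactor: sleep(0) reschedules this coroutine on the next
            # pass of the event loop.
            if idx % _AWAIT_AFTER_ITERATIONS == 0:
                await clock.sleep(0)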
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index bfce541c..985a0428 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -100,8 +100,8 @@ def db_to_json(db_content):
if isinstance(db_content, memoryview):
db_content = db_content.tobytes()
- # Decode it to a Unicode string before feeding it to json.loads, so we
- # consistenty get a Unicode-containing object out.
+ # Decode it to a Unicode string before feeding it to json.loads, since
+ # Python 3.5 does not support deserializing bytes.
if isinstance(db_content, (bytes, bytearray)):
db_content = db_content.decode("utf8")
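For reference, a simplified sketch of the db_to_json helper that the storage hunks below migrate to (the real helper in synapse/storage/_base.py may also log parse failures; that detail is an assumption here):

    import json

    def db_to_json(db_content):
        # Normalise whatever the driver returned (memoryview or bytes on
        # some database engines) to str before parsing, since json.loads
        # cannot accept bytes on Python 3.5.
        if isinstance(db_content, memoryview):
            db_content = db_content.tobytes()
        if isinstance(db_content, (bytes, bytearray)):
            db_content = db_content.decode("utf8")
        return json.loads(db_content)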
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 59f3394b..018826ef 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -249,7 +249,10 @@ class BackgroundUpdater(object):
retcol="progress_json",
)
- progress = json.loads(progress_json)
+ # Avoid a circular import.
+ from synapse.storage._base import db_to_json
+
+ progress = db_to_json(progress_json)
time_start = self._clock.time_msec()
items_updated = await update_handler(progress, batch_size)
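The function-local import above is the usual way to break an import cycle: the import executes on first call rather than at module load. The idiom in isolation:

    def load_progress(progress_json):
        # Deferred to call time because (per the comment in the hunk) a
        # module-level import of synapse.storage._base would be circular.
        from synapse.storage._base import db_to_json

        return db_to_json(progress_json)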
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index 4b4763c7..932458f6 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -128,7 +128,7 @@ class DataStore(
db_conn, "presence_stream", "stream_id"
)
self._device_inbox_id_gen = StreamIdGenerator(
- db_conn, "device_max_stream_id", "stream_id"
+ db_conn, "device_inbox", "stream_id"
)
self._public_room_id_gen = StreamIdGenerator(
db_conn, "public_room_list_stream", "stream_id"
diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py
index b58f04d0..33cc372d 100644
--- a/synapse/storage/data_stores/main/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -22,7 +22,7 @@ from canonicaljson import json
from twisted.internet import defer
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -77,7 +77,7 @@ class AccountDataWorkerStore(SQLBaseStore):
)
global_account_data = {
- row["account_data_type"]: json.loads(row["content"]) for row in rows
+ row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
rows = self.db.simple_select_list_txn(
@@ -90,7 +90,7 @@ class AccountDataWorkerStore(SQLBaseStore):
by_room = {}
for row in rows:
room_data = by_room.setdefault(row["room_id"], {})
- room_data[row["account_data_type"]] = json.loads(row["content"])
+ room_data[row["account_data_type"]] = db_to_json(row["content"])
return global_account_data, by_room
@@ -113,7 +113,7 @@ class AccountDataWorkerStore(SQLBaseStore):
)
if result:
- return json.loads(result)
+ return db_to_json(result)
else:
return None
@@ -137,7 +137,7 @@ class AccountDataWorkerStore(SQLBaseStore):
)
return {
- row["account_data_type"]: json.loads(row["content"]) for row in rows
+ row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
return self.db.runInteraction(
@@ -170,7 +170,7 @@ class AccountDataWorkerStore(SQLBaseStore):
allow_none=True,
)
- return json.loads(content_json) if content_json else None
+ return db_to_json(content_json) if content_json else None
return self.db.runInteraction(
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
@@ -255,7 +255,7 @@ class AccountDataWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id))
- global_account_data = {row[0]: json.loads(row[1]) for row in txn}
+ global_account_data = {row[0]: db_to_json(row[1]) for row in txn}
sql = (
"SELECT room_id, account_data_type, content FROM room_account_data"
@@ -267,7 +267,7 @@ class AccountDataWorkerStore(SQLBaseStore):
account_data_by_room = {}
for row in txn:
room_account_data = account_data_by_room.setdefault(row[0], {})
- room_account_data[row[1]] = json.loads(row[2])
+ room_account_data[row[1]] = db_to_json(row[2])
return global_account_data, account_data_by_room
diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py
index 7a1fe8cd..56659fed 100644
--- a/synapse/storage/data_stores/main/appservice.py
+++ b/synapse/storage/data_stores/main/appservice.py
@@ -22,7 +22,7 @@ from twisted.internet import defer
from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.database import Database
@@ -303,7 +303,7 @@ class ApplicationServiceTransactionWorkerStore(
if not entry:
return None
- event_ids = json.loads(entry["event_ids"])
+ event_ids = db_to_json(entry["event_ids"])
events = yield self.get_events_as_list(event_ids)
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index d313b970..da297b31 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -21,7 +21,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.util.caches.expiringcache import ExpiringCache
@@ -65,7 +65,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
messages = []
for row in txn:
stream_pos = row[0]
- messages.append(json.loads(row[1]))
+ messages.append(db_to_json(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return messages, stream_pos
@@ -173,7 +173,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
messages = []
for row in txn:
stream_pos = row[0]
- messages.append(json.loads(row[1]))
+ messages.append(db_to_json(row[1]))
if len(messages) < limit:
log_kv({"message": "Set stream position to current position"})
stream_pos = current_stream_id
@@ -424,9 +424,6 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
def _add_messages_to_local_device_inbox_txn(
self, txn, stream_id, messages_by_user_then_device
):
- sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?"
- txn.execute(sql, (stream_id, stream_id))
-
local_by_user_then_device = {}
for user_id, messages_by_device in messages_by_user_then_device.items():
messages_json_for_user = {}
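This removal pairs with the main/__init__.py hunk earlier: the stream-id generator is now seeded from device_inbox itself, so the single-row device_max_stream_id counter no longer needs to be kept in sync on every write. A hedged sketch of how such a generator might seed itself (the real logic lives in synapse.storage.util.id_generators):

    def load_current_stream_id(db_conn, table, column):
        # Read the high-water mark straight from the data table, e.g.
        # SELECT MAX(stream_id) FROM device_inbox. table and column are
        # trusted identifiers here, not user input.
        cur = db_conn.cursor()
        cur.execute("SELECT COALESCE(MAX(%s), 1) FROM %s" % (column, table))
        (current_id,) = cur.fetchone()
        cur.close()
        return current_id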
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 343cf9a2..45581a65 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -577,7 +577,7 @@ class DeviceWorkerStore(SQLBaseStore):
rows = yield self.db.execute(
"get_users_whose_signatures_changed", None, sql, user_id, from_key
)
- return {user for row in rows for user in json.loads(row[0])}
+ return {user for row in rows for user in db_to_json(row[0])}
else:
return set()
diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py
index 23f4570c..615364f0 100644
--- a/synapse/storage/data_stores/main/e2e_room_keys.py
+++ b/synapse/storage/data_stores/main/e2e_room_keys.py
@@ -14,13 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
+from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.logging.opentracing import log_kv, trace
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
class EndToEndRoomKeyStore(SQLBaseStore):
@@ -148,7 +148,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"forwarded_count": row["forwarded_count"],
# is_verified must be returned to the client as a boolean
"is_verified": bool(row["is_verified"]),
- "session_data": json.loads(row["session_data"]),
+ "session_data": db_to_json(row["session_data"]),
}
return sessions
@@ -222,7 +222,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"first_message_index": row[2],
"forwarded_count": row[3],
"is_verified": row[4],
- "session_data": json.loads(row[5]),
+ "session_data": db_to_json(row[5]),
}
return ret
@@ -319,7 +319,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
retcols=("version", "algorithm", "auth_data", "etag"),
)
- result["auth_data"] = json.loads(result["auth_data"])
+ result["auth_data"] = db_to_json(result["auth_data"])
result["version"] = str(result["version"])
if result["etag"] is None:
result["etag"] = 0
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 6c3cff82..317c07a8 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -366,7 +366,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for row in rows:
user_id = row["user_id"]
key_type = row["keytype"]
- key = json.loads(row["keydata"])
+ key = db_to_json(row["keydata"])
user_info = result.setdefault(user_id, {})
user_info[key_type] = key
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index bc9f4f08..504babaa 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -21,7 +21,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import LoggingTransaction, SQLBaseStore
+from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
from synapse.storage.database import Database
from synapse.util.caches.descriptors import cachedInlineCallbacks
@@ -58,7 +58,7 @@ def _deserialize_action(actions, is_highlight):
"""Custom deserializer for actions. This allows us to "compress" common actions
"""
if actions:
- return json.loads(actions)
+ return db_to_json(actions)
if is_highlight:
return DEFAULT_HIGHLIGHT_ACTION
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 230fb5cd..6f2e0d15 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -17,11 +17,9 @@
import itertools
import logging
from collections import OrderedDict, namedtuple
-from functools import wraps
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
import attr
-from canonicaljson import json
from prometheus_client import Counter
from twisted.internet import defer
@@ -33,7 +31,7 @@ from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
from synapse.logging.utils import log_function
-from synapse.storage._base import make_in_list_sql_clause
+from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.data_stores.main.search import SearchEntry
from synapse.storage.database import Database, LoggingTransaction
from synapse.storage.util.id_generators import StreamIdGenerator
@@ -69,27 +67,6 @@ def encode_json(json_object):
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
-def _retry_on_integrity_error(func):
- """Wraps a database function so that it gets retried on IntegrityError,
- with `delete_existing=True` passed in.
-
- Args:
- func: function that returns a Deferred and accepts a `delete_existing` arg
- """
-
- @wraps(func)
- @defer.inlineCallbacks
- def f(self, *args, **kwargs):
- try:
- res = yield func(self, *args, delete_existing=False, **kwargs)
- except self.database_engine.module.IntegrityError:
- logger.exception("IntegrityError, retrying.")
- res = yield func(self, *args, delete_existing=True, **kwargs)
- return res
-
- return f
-
-
@attr.s(slots=True)
class DeltaState:
"""Deltas to use to update the `current_state_events` table.
@@ -134,7 +111,6 @@ class PersistEventsStore:
hs.config.worker.writers.events == hs.get_instance_name()
), "Can only instantiate EventsStore on master"
- @_retry_on_integrity_error
@defer.inlineCallbacks
def _persist_events_and_state_updates(
self,
@@ -143,7 +119,6 @@ class PersistEventsStore:
state_delta_for_room: Dict[str, DeltaState],
new_forward_extremeties: Dict[str, List[str]],
backfilled: bool = False,
- delete_existing: bool = False,
):
"""Persist a set of events alongside updates to the current state and
forward extremities tables.
@@ -157,7 +132,6 @@ class PersistEventsStore:
new_forward_extremities: Map from room_id to list of event IDs
that are the new forward extremities of the room.
backfilled
- delete_existing
Returns:
Deferred: resolves when the events have been persisted
@@ -197,7 +171,6 @@ class PersistEventsStore:
self._persist_events_txn,
events_and_contexts=events_and_contexts,
backfilled=backfilled,
- delete_existing=delete_existing,
state_delta_for_room=state_delta_for_room,
new_forward_extremeties=new_forward_extremeties,
)
@@ -262,7 +235,7 @@ class PersistEventsStore:
)
txn.execute(sql + clause, args)
- results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed"))
+ results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed"))
for chunk in batch_iter(event_ids, 100):
yield self.db.runInteraction(
@@ -323,7 +296,7 @@ class PersistEventsStore:
if prev_event_id in existing_prevs:
continue
- soft_failed = json.loads(metadata).get("soft_failed")
+ soft_failed = db_to_json(metadata).get("soft_failed")
if soft_failed or rejected:
to_recursively_check.append(prev_event_id)
existing_prevs.add(prev_event_id)
@@ -341,7 +314,6 @@ class PersistEventsStore:
txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool,
- delete_existing: bool = False,
state_delta_for_room: Dict[str, DeltaState] = {},
new_forward_extremeties: Dict[str, List[str]] = {},
):
@@ -393,13 +365,6 @@ class PersistEventsStore:
# From this point onwards the events are only events that we haven't
# seen before.
- if delete_existing:
- # For paranoia reasons, we go and delete all the existing entries
- # for these events so we can reinsert them.
- # This gets around any problems with some tables already having
- # entries.
- self._delete_existing_rows_txn(txn, events_and_contexts=events_and_contexts)
-
self._store_event_txn(txn, events_and_contexts=events_and_contexts)
# Insert into event_to_state_groups.
@@ -617,7 +582,7 @@ class PersistEventsStore:
txn.execute(sql, (room_id, EventTypes.Create, ""))
row = txn.fetchone()
if row:
- event_json = json.loads(row[0])
+ event_json = db_to_json(row[0])
content = event_json.get("content", {})
creator = content.get("creator")
room_version_id = content.get("room_version", RoomVersions.V1.identifier)
@@ -797,39 +762,6 @@ class PersistEventsStore:
return [ec for ec in events_and_contexts if ec[0] not in to_remove]
- @classmethod
- def _delete_existing_rows_txn(cls, txn, events_and_contexts):
- if not events_and_contexts:
- # nothing to do here
- return
-
- logger.info("Deleting existing")
-
- for table in (
- "events",
- "event_auth",
- "event_json",
- "event_edges",
- "event_forward_extremities",
- "event_reference_hashes",
- "event_search",
- "event_to_state_groups",
- "state_events",
- "rejections",
- "redactions",
- "room_memberships",
- ):
- txn.executemany(
- "DELETE FROM %s WHERE event_id = ?" % (table,),
- [(ev.event_id,) for ev, _ in events_and_contexts],
- )
-
- for table in ("event_push_actions",):
- txn.executemany(
- "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,),
- [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts],
- )
-
def _store_event_txn(self, txn, events_and_contexts):
"""Insert new events into the event and event_json tables
diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py
index 62d28f44..663c94b2 100644
--- a/synapse/storage/data_stores/main/events_bg_updates.py
+++ b/synapse/storage/data_stores/main/events_bg_updates.py
@@ -15,12 +15,10 @@
import logging
-from canonicaljson import json
-
from twisted.internet import defer
from synapse.api.constants import EventContentFields
-from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
logger = logging.getLogger(__name__)
@@ -125,7 +123,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
for row in rows:
try:
event_id = row[1]
- event_json = json.loads(row[2])
+ event_json = db_to_json(row[2])
sender = event_json["sender"]
content = event_json["content"]
@@ -208,7 +206,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
for row in ev_rows:
event_id = row["event_id"]
- event_json = json.loads(row["json"])
+ event_json = db_to_json(row["json"])
try:
origin_server_ts = event_json["origin_server_ts"]
except (KeyError, AttributeError):
@@ -317,7 +315,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
soft_failed = False
if metadata:
- soft_failed = json.loads(metadata).get("soft_failed")
+ soft_failed = db_to_json(metadata).get("soft_failed")
if soft_failed or rejected:
soft_failed_events_to_lookup.add(event_id)
@@ -358,7 +356,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
graph[event_id] = {prev_event_id}
- soft_failed = json.loads(metadata).get("soft_failed")
+ soft_failed = db_to_json(metadata).get("soft_failed")
if soft_failed or rejected:
soft_failed_events_to_lookup.add(event_id)
else:
@@ -543,7 +541,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
last_row_event_id = ""
for (event_id, event_json_raw) in results:
try:
- event_json = json.loads(event_json_raw)
+ event_json = db_to_json(event_json_raw)
self.db.simple_insert_many_txn(
txn=txn,
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 01cad7d4..e812c670 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -21,7 +21,6 @@ import threading
from collections import namedtuple
from typing import List, Optional, Tuple
-from canonicaljson import json
from constantly import NamedConstant, Names
from twisted.internet import defer
@@ -40,7 +39,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
-from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id
@@ -611,8 +610,8 @@ class EventsWorkerStore(SQLBaseStore):
if not allow_rejected and rejected_reason:
continue
- d = json.loads(row["json"])
- internal_metadata = json.loads(row["internal_metadata"])
+ d = db_to_json(row["json"])
+ internal_metadata = db_to_json(row["internal_metadata"])
format_version = row["format_version"]
if format_version is None:
@@ -640,7 +639,7 @@ class EventsWorkerStore(SQLBaseStore):
else:
room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not room_version:
- logger.error(
+ logger.warning(
"Event %s in room %s has unknown room version %s",
event_id,
d["room_id"],
diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py
index 4fb9f985..01ff561e 100644
--- a/synapse/storage/data_stores/main/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -21,7 +21,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import SynapseError
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
# The category ID for the "default" category. We don't store as null in the
# database to avoid the fun of null != null
@@ -197,7 +197,7 @@ class GroupServerWorkerStore(SQLBaseStore):
categories = {
row[0]: {
"is_public": row[1],
- "profile": json.loads(row[2]),
+ "profile": db_to_json(row[2]),
"order": row[3],
}
for row in txn
@@ -221,7 +221,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return {
row["category_id"]: {
"is_public": row["is_public"],
- "profile": json.loads(row["profile"]),
+ "profile": db_to_json(row["profile"]),
}
for row in rows
}
@@ -235,7 +235,7 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_group_category",
)
- category["profile"] = json.loads(category["profile"])
+ category["profile"] = db_to_json(category["profile"])
return category
@@ -251,7 +251,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return {
row["role_id"]: {
"is_public": row["is_public"],
- "profile": json.loads(row["profile"]),
+ "profile": db_to_json(row["profile"]),
}
for row in rows
}
@@ -265,7 +265,7 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_group_role",
)
- role["profile"] = json.loads(role["profile"])
+ role["profile"] = db_to_json(role["profile"])
return role
@@ -333,7 +333,7 @@ class GroupServerWorkerStore(SQLBaseStore):
roles = {
row[0]: {
"is_public": row[1],
- "profile": json.loads(row[2]),
+ "profile": db_to_json(row[2]),
"order": row[3],
}
for row in txn
@@ -462,7 +462,7 @@ class GroupServerWorkerStore(SQLBaseStore):
now = int(self._clock.time_msec())
if row and now < row["valid_until_ms"]:
- return json.loads(row["attestation_json"])
+ return db_to_json(row["attestation_json"])
return None
@@ -489,7 +489,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"group_id": row[0],
"type": row[1],
"membership": row[2],
- "content": json.loads(row[3]),
+ "content": db_to_json(row[3]),
}
for row in txn
]
@@ -519,7 +519,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"group_id": group_id,
"membership": membership,
"type": gtype,
- "content": json.loads(content_json),
+ "content": db_to_json(content_json),
}
for group_id, membership, gtype, content_json in txn
]
@@ -567,7 +567,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"""
txn.execute(sql, (last_id, current_id, limit))
updates = [
- (stream_id, (group_id, user_id, gtype, json.loads(content_json)))
+ (stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
for stream_id, group_id, user_id, gtype, content_json in txn
]
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index f6e78ca5..c2292481 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -24,7 +24,7 @@ from twisted.internet import defer
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.pusher import PusherWorkerStore
@@ -43,8 +43,8 @@ def _load_rules(rawrules, enabled_map):
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
- rule["conditions"] = json.loads(rawrule["conditions"])
- rule["actions"] = json.loads(rawrule["actions"])
+ rule["conditions"] = db_to_json(rawrule["conditions"])
+ rule["actions"] = db_to_json(rawrule["actions"])
rule["default"] = False
ruleslist.append(rule)
@@ -259,7 +259,7 @@ class PushRulesWorkerStore(
# To do this we set the state_group to a new object as object() != object()
state_group = object()
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
result = yield self._bulk_get_push_rules_for_room(
event.room_id, state_group, current_state_ids, event=event
)
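get_current_state_ids is now a native coroutine, and @defer.inlineCallbacks generators can only yield Deferreds, hence the defer.ensureDeferred wrapper above (the same change appears in roommember.py below). The bridging pattern in isolation:

    from twisted.internet import defer

    async def get_current_state_ids():
        return {}  # stand-in for the real coroutine

    @defer.inlineCallbacks
    def legacy_caller():
        # ensureDeferred converts the coroutine into a Deferred that a
        # generator-based function can yield on.
        state_ids = yield defer.ensureDeferred(get_current_state_ids())
        return state_ids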
diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py
index 54610162..e18f1ca8 100644
--- a/synapse/storage/data_stores/main/pusher.py
+++ b/synapse/storage/data_stores/main/pusher.py
@@ -17,11 +17,11 @@
import logging
from typing import Iterable, Iterator, List, Tuple
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from twisted.internet import defer
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
logger = logging.getLogger(__name__)
@@ -36,7 +36,7 @@ class PusherWorkerStore(SQLBaseStore):
for r in rows:
dataJson = r["data"]
try:
- r["data"] = json.loads(dataJson)
+ r["data"] = db_to_json(dataJson)
except Exception as e:
logger.warning(
"Invalid JSON in data for pusher %d: %s, %s",
diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py
index 8f5505bd..1d723f2d 100644
--- a/synapse/storage/data_stores/main/receipts.py
+++ b/synapse/storage/data_stores/main/receipts.py
@@ -22,7 +22,7 @@ from canonicaljson import json
from twisted.internet import defer
-from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.async_helpers import ObservableDeferred
@@ -203,7 +203,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
for row in rows:
content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
row["user_id"]
- ] = json.loads(row["data"])
+ ] = db_to_json(row["data"])
return [{"type": "m.receipt", "room_id": room_id, "content": content}]
@@ -260,7 +260,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
event_entry = room_event["content"].setdefault(row["event_id"], {})
receipt_type = event_entry.setdefault(row["receipt_type"], {})
- receipt_type[row["user_id"]] = json.loads(row["data"])
+ receipt_type[row["user_id"]] = db_to_json(row["data"])
results = {
room_id: [results[room_id]] if room_id in results else []
@@ -329,7 +329,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""
txn.execute(sql, (last_id, current_id, limit))
- updates = [(r[0], r[1:5] + (json.loads(r[5]),)) for r in txn]
+ updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn]
limited = False
upper_bound = current_id
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index 587d4b91..27d2c502 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -27,6 +27,8 @@ from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidati
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database
+from synapse.storage.types import Cursor
+from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -42,6 +44,10 @@ class RegistrationWorkerStore(SQLBaseStore):
self.config = hs.config
self.clock = hs.get_clock()
+ self._user_id_seq = build_sequence_generator(
+ database.engine, find_max_generated_user_id_localpart, "user_id_seq",
+ )
+
@cached()
def get_user_by_id(self, user_id):
return self.db.simple_select_one(
@@ -481,39 +487,17 @@ class RegistrationWorkerStore(SQLBaseStore):
ret = yield self.db.runInteraction("count_real_users", _count_users)
return ret
- @defer.inlineCallbacks
- def find_next_generated_user_id_localpart(self):
- """
- Gets the localpart of the next generated user ID.
+ async def generate_user_id(self) -> str:
+ """Generate a suitable localpart for a guest user
- Generated user IDs are integers, so we find the largest integer user ID
- already taken and return that plus one.
+ Returns: a (hopefully) free localpart
"""
-
- def _find_next_generated_user_id(txn):
- # We bound between '@0' and '@a' to avoid pulling the entire table
- # out.
- txn.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")
-
- regex = re.compile(r"^@(\d+):")
-
- max_found = 0
-
- for (user_id,) in txn:
- match = regex.search(user_id)
- if match:
- max_found = max(int(match.group(1)), max_found)
-
- return max_found + 1
-
- return (
- (
- yield self.db.runInteraction(
- "find_next_generated_user_id", _find_next_generated_user_id
- )
- )
+ next_id = await self.db.runInteraction(
+ "generate_user_id", self._user_id_seq.get_next_id_txn
)
+ return str(next_id)
+
async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]:
"""Returns user id from threepid
@@ -1573,3 +1557,26 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
keyvalues={"user_id": user_id},
values={"expiration_ts_ms": expiration_ts, "email_sent": False},
)
+
+
+def find_max_generated_user_id_localpart(cur: Cursor) -> int:
+ """
+ Gets the localpart of the largest generated user ID already taken.
+
+ Generated user IDs are integers, so we find the largest integer user ID
+ already taken and return that.
+ """
+
+ # We bound between '@0' and '@a' to avoid pulling the entire table
+ # out.
+ cur.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")
+
+ regex = re.compile(r"^@(\d+):")
+
+ max_found = 0
+
+ for (user_id,) in cur:
+ match = regex.search(user_id)
+ if match:
+ max_found = max(int(match.group(1)), max_found)
+ return max_found
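The rewrite above swaps a per-allocation table scan for a database sequence: find_max_generated_user_id_localpart now only seeds the sequence, and get_next_id_txn allocates from it. A hedged sketch of the Postgres side of build_sequence_generator (the SQLite build presumably falls back to re-running the max scan at startup):

    class PostgresSequenceGenerator:
        """Allocates IDs from a sequence, so concurrent workers can never
        hand out the same generated user ID."""

        def __init__(self, sequence_name):
            self._sequence_name = sequence_name

        def get_next_id_txn(self, txn):
            # Placeholder style follows Synapse's engine layer, which
            # rewrites "?" for the active database driver.
            txn.execute("SELECT nextval(?)", (self._sequence_name,))
            return txn.fetchone()[0]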
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index c473cf15..d2e1e36e 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -28,7 +28,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.data_stores.main.search import SearchStore
from synapse.storage.database import Database, LoggingTransaction
from synapse.types import ThirdPartyInstanceID
@@ -118,7 +118,12 @@ class RoomWorkerStore(SQLBaseStore):
WHERE room_id = ?
"""
txn.execute(sql, [room_id])
- res = self.db.cursor_to_dict(txn)[0]
+ # If the query returns no rows the room is unknown; return None instead of erroring.
+ try:
+ res = self.db.cursor_to_dict(txn)[0]
+ except IndexError:
+ return None
+
res["federatable"] = bool(res["federatable"])
res["public"] = bool(res["public"])
return res
@@ -665,7 +670,7 @@ class RoomWorkerStore(SQLBaseStore):
next_token = None
for stream_ordering, content_json in txn:
next_token = stream_ordering
- event_json = json.loads(content_json)
+ event_json = db_to_json(content_json)
content = event_json["content"]
content_url = content.get("url")
thumbnail_url = content.get("info", {}).get("thumbnail_url")
@@ -910,8 +915,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
if not row["json"]:
retention_policy = {}
else:
- ev = json.loads(row["json"])
- retention_policy = json.dumps(ev["content"])
+ ev = db_to_json(row["json"])
+ retention_policy = ev["content"]
self.db.simple_insert_txn(
txn=txn,
@@ -966,7 +971,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
updates = []
for room_id, event_json in txn:
- event_dict = json.loads(event_json)
+ event_dict = db_to_json(event_json)
room_version_id = event_dict.get("content", {}).get(
"room_version", RoomVersions.V1.identifier
)
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index 44bab65e..a92e401e 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -17,8 +17,6 @@
import logging
from typing import Iterable, List, Set
-from canonicaljson import json
-
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
@@ -27,6 +25,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import (
LoggingTransaction,
SQLBaseStore,
+ db_to_json,
make_in_list_sql_clause,
)
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
@@ -498,7 +497,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
result = yield self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids, event=event, context=context
)
@@ -938,7 +937,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
event_id = row["event_id"]
room_id = row["room_id"]
try:
- event_json = json.loads(row["json"])
+ event_json = db_to_json(row["json"])
content = event_json["content"]
except Exception:
continue
diff --git a/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql b/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql
new file mode 100644
index 00000000..1cc2633a
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We need to store the stream positions by instance in a sharded config world.
+--
+-- We default to master as we want the column to be NOT NULL and we correctly
+-- reset the instance name to match the config each time we start up.
+ALTER TABLE federation_stream_position ADD COLUMN instance_name TEXT NOT NULL DEFAULT 'master';
+
+CREATE UNIQUE INDEX federation_stream_position_instance ON federation_stream_position(type, instance_name);
diff --git a/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py b/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py
new file mode 100644
index 00000000..2011f6bc
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py
@@ -0,0 +1,34 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Adds a postgres SEQUENCE for generating guest user IDs.
+"""
+
+from synapse.storage.data_stores.main.registration import (
+ find_max_generated_user_id_localpart,
+)
+from synapse.storage.engines import PostgresEngine
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ if not isinstance(database_engine, PostgresEngine):
+ return
+
+ next_id = find_max_generated_user_id_localpart(cur) + 1
+ cur.execute("CREATE SEQUENCE user_id_seq START WITH %s", (next_id,))
+
+
+def run_upgrade(*args, **kwargs):
+ pass
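A worked illustration of run_create: if @5:example.com is the largest generated user at upgrade time, find_max_generated_user_id_localpart returns 5 and the sequence is created with START WITH 6. A hypothetical direct-psycopg2 check (Synapse itself goes through its engine layer):

    import psycopg2  # illustration only

    def allocate_next_generated_localpart(conn):
        with conn.cursor() as cur:
            cur.execute("SELECT nextval('user_id_seq')")
            return cur.fetchone()[0]  # -> 6 on the first call after upgrade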
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index a8381dc5..d5222829 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -17,12 +17,10 @@ import logging
import re
from collections import namedtuple
-from canonicaljson import json
-
from twisted.internet import defer
from synapse.api.errors import SynapseError
-from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
@@ -157,7 +155,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
stream_ordering = row["stream_ordering"]
origin_server_ts = row["origin_server_ts"]
try:
- event_json = json.loads(row["json"])
+ event_json = db_to_json(row["json"])
content = event_json["content"]
except Exception:
continue
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 347cc507..bb38a04e 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -353,6 +353,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
last_room_id = progress.get("last_room_id", "")
def _background_remove_left_rooms_txn(txn):
+ # get a batch of room ids to consider
sql = """
SELECT DISTINCT room_id FROM current_state_events
WHERE room_id > ? ORDER BY room_id LIMIT ?
@@ -363,24 +364,68 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
if not room_ids:
return True, set()
+ ###########################################################################
+ #
+ # exclude rooms where we have active members
+
sql = """
SELECT room_id
- FROM current_state_events
+ FROM local_current_membership
WHERE
room_id > ? AND room_id <= ?
- AND type = 'm.room.member'
AND membership = 'join'
- AND state_key LIKE ?
GROUP BY room_id
"""
- txn.execute(sql, (last_room_id, room_ids[-1], "%:" + self.server_name))
-
+ txn.execute(sql, (last_room_id, room_ids[-1]))
joined_room_ids = {row[0] for row in txn}
+ to_delete = set(room_ids) - joined_room_ids
+
+ ###########################################################################
+ #
+ # exclude rooms which we are in the process of constructing; these otherwise
+ # qualify as "rooms with no local users", and would have their
+ # forward extremities cleaned up.
+
+ # the following query will return a list of rooms which have forward
+ # extremities that are *not* also the create event in the room - ie
+ # those that are not being created currently.
+
+ sql = """
+ SELECT DISTINCT efe.room_id
+ FROM event_forward_extremities efe
+ LEFT JOIN current_state_events cse ON
+ cse.event_id = efe.event_id
+ AND cse.type = 'm.room.create'
+ AND cse.state_key = ''
+ WHERE
+ cse.event_id IS NULL
+ AND efe.room_id > ? AND efe.room_id <= ?
+ """
+
+ txn.execute(sql, (last_room_id, room_ids[-1]))
+
+ # build a set of those rooms within `to_delete` that do not appear in
+ # the above, leaving us with the rooms in `to_delete` that *are* being
+ # created.
+ creating_rooms = to_delete.difference(row[0] for row in txn)
+ logger.info("skipping rooms which are being created: %s", creating_rooms)
+
+ # now remove the rooms being created from the list of those to delete.
+ #
+ # (we could have just taken the intersection of `to_delete` with the result
+ # of the sql query, but it's useful to be able to log `creating_rooms`; and
+ # having done so, it's quicker to remove the (few) creating rooms from
+ # `to_delete` than it is to form the intersection with the (larger) list of
+ # not-creating-rooms)
+
+ to_delete -= creating_rooms
- left_rooms = set(room_ids) - joined_room_ids
+ ###########################################################################
+ #
+ # now clear the state for the rooms
- logger.info("Deleting current state left rooms: %r", left_rooms)
+ logger.info("Deleting current state left rooms: %r", to_delete)
# First we get all users that we still think were joined to the
# room. This is so that we can mark those device lists as
@@ -391,7 +436,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
txn,
table="current_state_events",
column="room_id",
- iterable=left_rooms,
+ iterable=to_delete,
keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN},
retcols=("state_key",),
)
@@ -403,7 +448,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
txn,
table="current_state_events",
column="room_id",
- iterable=left_rooms,
+ iterable=to_delete,
keyvalues={},
)
@@ -411,7 +456,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
txn,
table="event_forward_extremities",
column="room_id",
- iterable=left_rooms,
+ iterable=to_delete,
keyvalues={},
)
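The three filtering steps above reduce to set arithmetic; a condensed sketch:

    def rooms_to_clean(batch_room_ids, joined_room_ids, past_creation_room_ids):
        # batch_room_ids:         rooms considered in this pass
        # joined_room_ids:        rooms that still have a local joined member
        # past_creation_room_ids: rooms with a forward extremity other than
        #                         the create event, i.e. fully constructed
        to_delete = set(batch_room_ids) - set(joined_room_ids)
        creating = to_delete.difference(past_creation_room_ids)
        return to_delete - creating  # leave rooms still being created alone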
diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
index 379d758b..10d39b36 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -45,7 +45,7 @@ from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import Database, make_in_list_sql_clause
from synapse.storage.engines import PostgresEngine
from synapse.types import RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -253,6 +253,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def __init__(self, database: Database, db_conn, hs):
super(StreamWorkerStore, self).__init__(database, db_conn, hs)
+ self._instance_name = hs.get_instance_name()
+ self._send_federation = hs.should_send_federation()
+ self._federation_shard_config = hs.config.worker.federation_shard_config
+
+ # If we're a process that sends federation we may need to reset the
+ # `federation_stream_position` table to match the current sharding
+ # config. We don't do this now as otherwise two processes could conflict
+ # during startup which would cause one to die.
+ self._need_to_reset_federation_stream_positions = self._send_federation
+
events_max = self.get_room_max_stream_ordering()
event_cache_prefill, min_event_val = self.db.get_cache_dict(
db_conn,
@@ -793,22 +803,95 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, events
- def get_federation_out_pos(self, typ):
- return self.db.simple_select_one_onecol(
+ async def get_federation_out_pos(self, typ: str) -> int:
+ if self._need_to_reset_federation_stream_positions:
+ await self.db.runInteraction(
+ "_reset_federation_positions_txn", self._reset_federation_positions_txn
+ )
+ self._need_to_reset_federation_stream_positions = False
+
+ return await self.db.simple_select_one_onecol(
table="federation_stream_position",
retcol="stream_id",
- keyvalues={"type": typ},
+ keyvalues={"type": typ, "instance_name": self._instance_name},
desc="get_federation_out_pos",
)
- def update_federation_out_pos(self, typ, stream_id):
- return self.db.simple_update_one(
+ async def update_federation_out_pos(self, typ, stream_id):
+ if self._need_to_reset_federation_stream_positions:
+ await self.db.runInteraction(
+ "_reset_federation_positions_txn", self._reset_federation_positions_txn
+ )
+ self._need_to_reset_federation_stream_positions = False
+
+ return await self.db.simple_update_one(
table="federation_stream_position",
- keyvalues={"type": typ},
+ keyvalues={"type": typ, "instance_name": self._instance_name},
updatevalues={"stream_id": stream_id},
desc="update_federation_out_pos",
)
+ def _reset_federation_positions_txn(self, txn):
+ """Fiddles with the `federation_stream_position` table to make it match
+ the configured federation sender instances during startup.
+ """
+
+ # The federation sender instances may have changed, so we need to
+ # massage the `federation_stream_position` table to have a row per type
+ # per instance sending federation. If there is a mismatch we update the
+ # table with the correct rows using the *minimum* stream ID seen. This
+ # may result in resending of events/EDUs to remote servers, but that is
+ # preferable to dropping them.
+
+ if not self._send_federation:
+ return
+
+ # Pull out the configured instances. If we don't have a shard config then
+ # we assume that we're the only instance sending.
+ configured_instances = self._federation_shard_config.instances
+ if not configured_instances:
+ configured_instances = [self._instance_name]
+ elif self._instance_name not in configured_instances:
+ return
+
+ instances_in_table = self.db.simple_select_onecol_txn(
+ txn,
+ table="federation_stream_position",
+ keyvalues={},
+ retcol="instance_name",
+ )
+
+ if set(instances_in_table) == set(configured_instances):
+ # Nothing to do
+ return
+
+ sql = """
+ SELECT type, MIN(stream_id) FROM federation_stream_position
+ GROUP BY type
+ """
+ txn.execute(sql)
+ min_positions = dict(txn) # Map from type -> min position
+
+ # Ensure we do actually have some values here
+ assert set(min_positions) == {"federation", "events"}
+
+ sql = """
+ DELETE FROM federation_stream_position
+ WHERE NOT (%s)
+ """
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "instance_name", configured_instances
+ )
+ txn.execute(sql % (clause,), args)
+
+ for typ, stream_id in min_positions.items():
+ self.db.simple_upsert_txn(
+ txn,
+ table="federation_stream_position",
+ keyvalues={"type": typ, "instance_name": self._instance_name},
+ values={"stream_id": stream_id},
+ )
+
def has_room_changed_since(self, room_id, stream_id):
return self._events_stream_cache.has_entity_changed(room_id, stream_id)
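
The effect of `_reset_federation_positions_txn` above, stripped of the transaction plumbing, is small enough to sketch in plain Python. This is illustrative only (the helper name and the `rows` representation of the `federation_stream_position` table are assumptions, not the storage API):

    # Illustrative sketch of the reset semantics, not the storage code.
    # rows: list of (typ, instance_name, stream_id) tuples.
    def reset_positions(rows, configured_instances, this_instance):
        # Per-type minimum stream ID across all existing rows: resuming
        # from the minimum may resend events/EDUs, but never drops any.
        min_positions = {}
        for typ, _instance, stream_id in rows:
            min_positions[typ] = min(stream_id, min_positions.get(typ, stream_id))

        # Drop rows for instances that are no longer configured to send,
        # and any stale row for this instance.
        kept = [
            r
            for r in rows
            if r[1] in configured_instances and r[1] != this_instance
        ]

        # Re-add this instance's row for each type at the minimum position.
        kept.extend((typ, this_instance, pos) for typ, pos in min_positions.items())
        return kept
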
diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py
index 290317fd..bd722777 100644
--- a/synapse/storage/data_stores/main/tags.py
+++ b/synapse/storage/data_stores/main/tags.py
@@ -21,6 +21,7 @@ from canonicaljson import json
from twisted.internet import defer
+from synapse.storage._base import db_to_json
from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
from synapse.util.caches.descriptors import cached
@@ -49,7 +50,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
tags_by_room = {}
for row in rows:
room_tags = tags_by_room.setdefault(row["room_id"], {})
- room_tags[row["tag"]] = json.loads(row["content"])
+ room_tags[row["tag"]] = db_to_json(row["content"])
return tags_by_room
return deferred
@@ -180,7 +181,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
retcols=("tag", "content"),
desc="get_tags_for_room",
).addCallback(
- lambda rows: {row["tag"]: json.loads(row["content"]) for row in rows}
+ lambda rows: {row["tag"]: db_to_json(row["content"]) for row in rows}
)
diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py
index 4c044b1a..5f1b9197 100644
--- a/synapse/storage/data_stores/main/ui_auth.py
+++ b/synapse/storage/data_stores/main/ui_auth.py
@@ -12,13 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
from typing import Any, Dict, Optional, Union
import attr
+from canonicaljson import json
from synapse.api.errors import StoreError
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.types import JsonDict
from synapse.util import stringutils as stringutils
@@ -118,7 +118,7 @@ class UIAuthWorkerStore(SQLBaseStore):
desc="get_ui_auth_session",
)
- result["clientdict"] = json.loads(result["clientdict"])
+ result["clientdict"] = db_to_json(result["clientdict"])
return UIAuthSessionData(session_id, **result)
@@ -168,7 +168,7 @@ class UIAuthWorkerStore(SQLBaseStore):
retcols=("stage_type", "result"),
desc="get_completed_ui_auth_stages",
):
- results[row["stage_type"]] = json.loads(row["result"])
+ results[row["stage_type"]] = db_to_json(row["result"])
return results
@@ -224,7 +224,7 @@ class UIAuthWorkerStore(SQLBaseStore):
)
# Update it and add it back to the database.
- serverdict = json.loads(result["serverdict"])
+ serverdict = db_to_json(result["serverdict"])
serverdict[key] = value
self.db.simple_update_one_txn(
@@ -254,7 +254,7 @@ class UIAuthWorkerStore(SQLBaseStore):
desc="get_ui_auth_session_data",
)
- serverdict = json.loads(result["serverdict"])
+ serverdict = db_to_json(result["serverdict"])
return serverdict.get(key, default)
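
For reference, `db_to_json` (in `synapse/storage/_base.py`) is essentially `json.loads` with a bytes-decoding step in front, which is why it can replace the bare `json.loads` calls in these hunks. Roughly (a sketch of the idea, not the verbatim helper):

    import json
    import logging

    def db_to_json(db_content):
        # Decode buffer/bytes values returned by some database drivers
        # before handing them to the JSON parser.
        if isinstance(db_content, memoryview):
            db_content = db_content.tobytes()
        if isinstance(db_content, (bytes, bytearray)):
            db_content = db_content.decode("utf8")
        try:
            return json.loads(db_content)
        except Exception:
            logging.warning("Tried to decode '%r' as JSON and failed", db_content)
            raise
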
diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py
index 6b8130bf..942e51fd 100644
--- a/synapse/storage/data_stores/main/user_directory.py
+++ b/synapse/storage/data_stores/main/user_directory.py
@@ -198,7 +198,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
room_id
)
- users_with_profile = yield state.get_current_users_in_room(room_id)
+ users_with_profile = yield defer.ensureDeferred(
+ state.get_current_users_in_room(room_id)
+ )
user_ids = set(users_with_profile)
# Update each user in the user directory.
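
This wrapping pattern recurs throughout the test changes below: an `@defer.inlineCallbacks` generator can only yield Deferreds, so a handler method that has been converted to `async def` (and therefore returns a coroutine) must be wrapped first. In miniature:

    from twisted.internet import defer

    async def get_thing():
        # stands in for a handler method converted to async/await
        return 42

    @defer.inlineCallbacks
    def caller():
        # Yielding get_thing() directly would fail: a coroutine object
        # is not a Deferred. ensureDeferred converts it.
        thing = yield defer.ensureDeferred(get_thing())
        return thing
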
diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py
index ec6b8a4f..d3038ff0 100644
--- a/synapse/storage/data_stores/main/user_erasure_store.py
+++ b/synapse/storage/data_stores/main/user_erasure_store.py
@@ -70,11 +70,11 @@ class UserErasureWorkerStore(SQLBaseStore):
class UserErasureStore(UserErasureWorkerStore):
- def mark_user_erased(self, user_id):
+ def mark_user_erased(self, user_id: str) -> None:
"""Indicate that user_id wishes their message history to be erased.
Args:
- user_id (str): full user_id to be erased
+ user_id: full user_id to be erased
"""
def f(txn):
@@ -89,3 +89,25 @@ class UserErasureStore(UserErasureWorkerStore):
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
return self.db.runInteraction("mark_user_erased", f)
+
+ def mark_user_not_erased(self, user_id: str) -> None:
+ """Indicate that user_id is no longer erased.
+
+ Args:
+ user_id: full user_id to be un-erased
+ """
+
+ def f(txn):
+ # first check if they are already in the list
+ txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,))
+ if not txn.fetchone():
+ return
+
+ # They are there, delete them.
+ self.simple_delete_one_txn(
+ txn, "erased_users", keyvalues={"user_id": user_id}
+ )
+
+ self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
+
+ return self.db.runInteraction("mark_user_not_erased", f)
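
The new method mirrors `mark_user_erased`. A hedged usage sketch from a hypothetical reactivation path (the surrounding handler method is assumed, not part of this patch; the returned Deferred is awaitable):

    # Hypothetical call site, for illustration only:
    async def reactivate_account(self, user_id: str) -> None:
        # Allow the user's message history to be visible again.
        await self.store.mark_user_not_erased(user_id)
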
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
index 5db9f201..128c09a2 100644
--- a/synapse/storage/data_stores/state/store.py
+++ b/synapse/storage/data_stores/state/store.py
@@ -24,6 +24,8 @@ from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.database import Database
from synapse.storage.state import StateFilter
+from synapse.storage.types import Cursor
+from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import StateMap
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
@@ -92,6 +94,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"*stateGroupMembersCache*", 500000,
)
+ def get_max_state_group_txn(txn: Cursor):
+ txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
+ return txn.fetchone()[0]
+
+ self._state_group_seq_gen = build_sequence_generator(
+ self.database_engine, get_max_state_group_txn, "state_group_id_seq"
+ )
+
@cached(max_entries=10000, iterable=True)
def get_state_group_delta(self, state_group):
"""Given a state group try to return a previous group and a delta between
@@ -386,7 +396,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# AFAIK, this can never happen
raise Exception("current_state_ids cannot be None")
- state_group = self.database_engine.get_next_state_group_id(txn)
+ state_group = self._state_group_seq_gen.get_next_id_txn(txn)
self.db.simple_insert_txn(
txn,
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index ab0bbe4b..908cbc79 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -91,12 +91,6 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
def lock_table(self, txn, table: str) -> None:
...
- @abc.abstractmethod
- def get_next_state_group_id(self, txn) -> int:
- """Returns an int that can be used as a new state_group ID
- """
- ...
-
@property
@abc.abstractmethod
def server_version(self) -> str:
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index a3158808..ff39281f 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -154,12 +154,6 @@ class PostgresEngine(BaseDatabaseEngine):
def lock_table(self, txn, table):
txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
- def get_next_state_group_id(self, txn):
- """Returns an int that can be used as a new state_group ID
- """
- txn.execute("SELECT nextval('state_group_id_seq')")
- return txn.fetchone()[0]
-
@property
def server_version(self):
"""Returns a string giving the server version. For example: '8.1.5'
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 215a9494..8a0f8c89 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -96,19 +96,6 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
def lock_table(self, txn, table):
return
- def get_next_state_group_id(self, txn):
- """Returns an int that can be used as a new state_group ID
- """
- # We do application locking here since if we're using sqlite then
- # we are a single process synapse.
- with self._current_state_group_id_lock:
- if self._current_state_group_id is None:
- txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
- self._current_state_group_id = txn.fetchone()[0]
-
- self._current_state_group_id += 1
- return self._current_state_group_id
-
@property
def server_version(self):
"""Gets a string giving the server version. For example: '3.22.0'
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index fa460416..78fbdcde 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -29,7 +29,6 @@ from synapse.events import FrozenEvent
from synapse.events.snapshot import EventContext
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.state import StateResolutionStore
from synapse.storage.data_stores import DataStores
from synapse.storage.data_stores.main.events import DeltaState
from synapse.types import StateMap
@@ -648,6 +647,10 @@ class EventsPersistenceStorage(object):
room_version = await self.main_store.get_room_version_id(room_id)
logger.debug("calling resolve_state_groups from preserve_events")
+
+ # Avoid a circular import.
+ from synapse.state import StateResolutionStore
+
res = await self._state_resolution_handler.resolve_state_groups(
room_id,
room_version,
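
The function-local import is the standard way to break an import cycle: by the time the function body runs, both modules are fully initialised. A minimal reproduction of the failure mode it avoids (module names are illustrative):

    # a.py
    from b import g           # b starts importing while a is half-initialised

    # b.py
    from a import f           # ImportError: cannot import name 'f' from
                               # partially initialised module 'a'

    # Fix (as in the patch above): move one import inside the function
    # that needs it, deferring resolution to call time.
    def g():
        from a import f
        return f()
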
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index f89ce0be..787cebfb 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -21,6 +21,7 @@ from typing import Dict, Set, Tuple
from typing_extensions import Deque
from synapse.storage.database import Database, LoggingTransaction
+from synapse.storage.util.sequence import PostgresSequenceGenerator
class IdGenerator(object):
@@ -247,7 +248,6 @@ class MultiWriterIdGenerator:
):
self._db = db
self._instance_name = instance_name
- self._sequence_name = sequence_name
# We lock as some functions may be called from DB threads.
self._lock = threading.Lock()
@@ -260,6 +260,8 @@ class MultiWriterIdGenerator:
# should be less than the minimum of this set (if not empty).
self._unfinished_ids = set() # type: Set[int]
+ self._sequence_gen = PostgresSequenceGenerator(sequence_name)
+
def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str
) -> Dict[str, int]:
@@ -283,9 +285,7 @@ class MultiWriterIdGenerator:
return current_positions
def _load_next_id_txn(self, txn):
- txn.execute("SELECT nextval(?)", (self._sequence_name,))
- (next_id,) = txn.fetchone()
- return next_id
+ return self._sequence_gen.get_next_id_txn(txn)
async def get_next(self):
"""
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
new file mode 100644
index 00000000..63dfea42
--- /dev/null
+++ b/synapse/storage/util/sequence.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import abc
+import threading
+from typing import Callable, Optional
+
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.types import Cursor
+
+
+class SequenceGenerator(metaclass=abc.ABCMeta):
+ """A class which generates a unique sequence of integers"""
+
+ @abc.abstractmethod
+ def get_next_id_txn(self, txn: Cursor) -> int:
+ """Gets the next ID in the sequence"""
+ ...
+
+
+class PostgresSequenceGenerator(SequenceGenerator):
+ """An implementation of SequenceGenerator which uses a postgres sequence"""
+
+ def __init__(self, sequence_name: str):
+ self._sequence_name = sequence_name
+
+ def get_next_id_txn(self, txn: Cursor) -> int:
+ txn.execute("SELECT nextval(?)", (self._sequence_name,))
+ return txn.fetchone()[0]
+
+
+GetFirstCallbackType = Callable[[Cursor], int]
+
+
+class LocalSequenceGenerator(SequenceGenerator):
+ """An implementation of SequenceGenerator which uses local locking
+
+ This only works reliably if there are no other worker processes generating IDs at
+ the same time.
+ """
+
+ def __init__(self, get_first_callback: GetFirstCallbackType):
+ """
+ Args:
+ get_first_callback: a callback which is called on the first call to
+ get_next_id_txn; should return the current maximum id
+ """
+ # the callback. this is cleared after it is called, so that it can be GCed.
+ self._callback = get_first_callback # type: Optional[GetFirstCallbackType]
+
+ # The current max value, or None if we haven't looked in the DB yet.
+ self._current_max_id = None # type: Optional[int]
+ self._lock = threading.Lock()
+
+ def get_next_id_txn(self, txn: Cursor) -> int:
+ # We do application locking here, since if we're using sqlite then
+ # we are a single-process synapse.
+ with self._lock:
+ if self._current_max_id is None:
+ assert self._callback is not None
+ self._current_max_id = self._callback(txn)
+ self._callback = None
+
+ self._current_max_id += 1
+ return self._current_max_id
+
+
+def build_sequence_generator(
+ database_engine: BaseDatabaseEngine,
+ get_first_callback: GetFirstCallbackType,
+ sequence_name: str,
+) -> SequenceGenerator:
+ """Get the best impl of SequenceGenerator available
+
+ This uses PostgresSequenceGenerator on postgres, and a locally-locked impl on
+ sqlite.
+
+ Args:
+ database_engine: the database engine we are connected to
+ get_first_callback: a callback which returns the current maximum value
+ of the sequence; used if we're on sqlite.
+ sequence_name: the name of a postgres sequence to use.
+ """
+ if isinstance(database_engine, PostgresEngine):
+ return PostgresSequenceGenerator(sequence_name)
+ else:
+ return LocalSequenceGenerator(get_first_callback)
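
Usage mirrors the state-store hunk earlier in this diff; a hedged sketch with assumed table and sequence names:

    # Sketch: portable ID allocation inside a transaction (names assumed).
    def get_max_widget_id_txn(txn):
        txn.execute("SELECT COALESCE(max(id), 0) FROM widgets")
        return txn.fetchone()[0]

    seq = build_sequence_generator(
        database_engine, get_max_widget_id_txn, "widget_id_seq"
    )

    def insert_widget_txn(txn, name):
        # nextval() on postgres, a locally-locked counter on sqlite.
        new_id = seq.get_next_id_txn(txn)
        txn.execute("INSERT INTO widgets (id, name) VALUES (?, ?)", (new_id, name))
        return new_id
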
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index da20523b..22a857a3 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -12,10 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import inspect
import logging
from twisted.internet import defer
+from twisted.internet.defer import Deferred, fail, succeed
+from twisted.python import failure
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -79,6 +81,28 @@ class Distributor(object):
run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
+def maybeAwaitableDeferred(f, *args, **kw):
+ """
+ Invoke a function that may or may not return a Deferred or an Awaitable.
+
+ This is a modified version of twisted.internet.defer.maybeDeferred.
+ """
+ try:
+ result = f(*args, **kw)
+ except Exception:
+ return fail(failure.Failure(captureVars=Deferred.debug))
+
+ if isinstance(result, Deferred):
+ return result
+ # Handle the additional case of an awaitable being returned.
+ elif inspect.isawaitable(result):
+ return defer.ensureDeferred(result)
+ elif isinstance(result, failure.Failure):
+ return fail(result)
+ else:
+ return succeed(result)
+
+
class Signal(object):
"""A Signal is a dispatch point that stores a list of callables as
observers of it.
@@ -122,7 +146,7 @@ class Signal(object):
),
)
- return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
+ return maybeAwaitableDeferred(observer, *args, **kwargs).addErrback(eb)
deferreds = [run_in_background(do, o) for o in self.observers]
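
With this helper, observers registered with a `Signal` may be plain functions, Deferred-returning functions, or `async def` coroutine functions; all three end up as Deferreds. A small demonstration of the three cases the function above handles:

    from twisted.internet import defer

    def plain(x):
        return x + 1                       # wrapped with succeed()

    def deferred_style(x):
        return defer.succeed(x + 1)        # passed through unchanged

    async def coroutine_style(x):
        return x + 1                       # wrapped with ensureDeferred()

    for observer in (plain, deferred_style, coroutine_style):
        d = maybeAwaitableDeferred(observer, 1)   # each fires with 2
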
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 08c86e92..2e2b40a4 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -17,7 +17,7 @@ import itertools
import random
import re
import string
-from collections import Iterable
+from collections.abc import Iterable
from synapse.api.errors import Codes, SynapseError
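
The bare `collections` aliases for the ABCs have been deprecated since Python 3.3 and are removed in Python 3.10, so the import must come from `collections.abc`:

    from collections.abc import Iterable   # works on all supported Pythons

    assert isinstance([1, 2, 3], Iterable)
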
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index 640f5f3b..3a806262 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -41,8 +41,10 @@ class TestEventContext(unittest.HomeserverTestCase):
serialize/deserialize.
"""
- event, context = create_event(
- self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
+ event, context = self.get_success(
+ create_event(
+ self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
+ )
)
self._check_serialize_deserialize(event, context)
@@ -51,12 +53,14 @@ class TestEventContext(unittest.HomeserverTestCase):
"""Test that an EventContext for a state event (with not previous entry)
is the same after serialize/deserialize.
"""
- event, context = create_event(
- self.hs,
- room_id=self.room_id,
- type="m.test",
- sender=self.user_id,
- state_key="",
+ event, context = self.get_success(
+ create_event(
+ self.hs,
+ room_id=self.room_id,
+ type="m.test",
+ sender=self.user_id,
+ state_key="",
+ )
)
self._check_serialize_deserialize(event, context)
@@ -65,13 +69,15 @@ class TestEventContext(unittest.HomeserverTestCase):
"""Test that an EventContext for a state event (which replaces a
previous entry) is the same after serialize/deserialize.
"""
- event, context = create_event(
- self.hs,
- room_id=self.room_id,
- type="m.room.member",
- sender=self.user_id,
- state_key=self.user_id,
- content={"membership": "leave"},
+ event, context = self.get_success(
+ create_event(
+ self.hs,
+ room_id=self.room_id,
+ type="m.room.member",
+ sender=self.user_id,
+ state_key=self.user_id,
+ content={"membership": "leave"},
+ )
)
self._check_serialize_deserialize(event, context)
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 1a9bd5f3..d1bd18da 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -26,21 +26,24 @@ from synapse.rest import admin
from synapse.rest.client.v1 import login
from synapse.types import JsonDict, ReadReceipt
+from tests.test_utils import make_awaitable
from tests.unittest import HomeserverTestCase, override_config
class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
+ mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
+ # Ensure a new Awaitable is created for each call.
+ mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_id: make_awaitable(
+ ["test", "host2"]
+ )
return self.setup_test_homeserver(
- state_handler=Mock(spec=["get_current_hosts_in_room"]),
+ state_handler=mock_state_handler,
federation_transport_client=Mock(spec=["send_transaction"]),
)
@override_config({"send_federation": True})
def test_send_receipts(self):
- mock_state_handler = self.hs.get_state_handler()
- mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
-
mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction
)
@@ -81,9 +84,6 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def test_send_receipts_with_backoff(self):
"""Send two receipts in quick succession; the second should be flushed, but
only after 20ms"""
- mock_state_handler = self.hs.get_state_handler()
- mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
-
mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction
)
@@ -164,7 +164,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
return self.setup_test_homeserver(
- state_handler=Mock(spec=["get_current_hosts_in_room"]),
federation_transport_client=Mock(spec=["send_transaction"]),
)
@@ -174,10 +173,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
return c
def prepare(self, reactor, clock, hs):
- # stub out get_current_hosts_in_room
- mock_state_handler = hs.get_state_handler()
- mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
-
# stub out get_users_who_share_room_with_user so that it claims that
# `@user2:host2` is in the room
def get_users_who_share_room_with_user(user_id):
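
`make_awaitable` comes from `tests.test_utils`; a sketch of the usual pattern (assuming it wraps the value in an already-resolved future). It is given as a `side_effect` rather than a `return_value` so that every call gets a fresh awaitable — a single coroutine object could only be awaited once:

    from asyncio import Future
    from typing import Any, Awaitable

    def make_awaitable(result: Any) -> Awaitable[Any]:
        # A completed Future is awaitable and yields `result` when awaited.
        future = Future()
        future.set_result(result)
        return future
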
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 62b47f65..6aa322bf 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -142,10 +142,8 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.delete_device(user1, "abc"))
# check the device was deleted
- res = self.handler.get_device(user1, "abc")
- self.pump()
- self.assertIsInstance(
- self.failureResultOf(res).value, synapse.api.errors.NotFoundError
+ self.get_failure(
+ self.handler.get_device(user1, "abc"), synapse.api.errors.NotFoundError
)
# we'd like to check the access token was invalidated, but that's a
@@ -180,10 +178,9 @@ class DeviceTestCase(unittest.HomeserverTestCase):
def test_update_unknown_device(self):
update = {"display_name": "new_display"}
- res = self.handler.update_device("user_id", "unknown_device_id", update)
- self.pump()
- self.assertIsInstance(
- self.failureResultOf(res).value, synapse.api.errors.NotFoundError
+ self.get_failure(
+ self.handler.update_device("user_id", "unknown_device_id", update),
+ synapse.api.errors.NotFoundError,
)
def _record_users(self):
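
`get_failure` (like `get_success`) is a `HomeserverTestCase` helper that replaces the manual pump/failureResultOf dance removed above; its contract is roughly the following sketch, not its exact body:

    def get_failure(self, d, exc):
        # Drive the reactor until `d` completes, then assert it failed
        # with `exc` and hand back the Failure for further inspection.
        deferred = defer.ensureDeferred(d)
        self.pump()
        return self.failureResultOf(deferred, exc)
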
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 1acf287c..210ddcbb 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -46,7 +46,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"""If the user has no devices, we expect an empty list.
"""
local_user = "@boris:" + self.hs.hostname
- res = yield self.handler.query_local_devices({local_user: None})
+ res = yield defer.ensureDeferred(
+ self.handler.query_local_devices({local_user: None})
+ )
self.assertDictEqual(res, {local_user: {}})
@defer.inlineCallbacks
@@ -60,15 +62,19 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"alg2:k3": {"key": "key3"},
}
- res = yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": keys}
+ res = yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": keys}
+ )
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
# we should be able to change the signature without a problem
keys["alg2:k2"]["signatures"]["k1"] = "sig2"
- res = yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": keys}
+ res = yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": keys}
+ )
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
@@ -84,44 +90,56 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"alg2:k3": {"key": "key3"},
}
- res = yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": keys}
+ res = yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": keys}
+ )
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
try:
- yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
+ )
)
self.fail("No error when changing string key")
except errors.SynapseError:
pass
try:
- yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
+ )
)
self.fail("No error when replacing dict key with string")
except errors.SynapseError:
pass
try:
- yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}}
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user,
+ device_id,
+ {"one_time_keys": {"alg1:k1": {"key": "key"}}},
+ )
)
self.fail("No error when replacing string key with dict")
except errors.SynapseError:
pass
try:
- yield self.handler.upload_keys_for_user(
- local_user,
- device_id,
- {
- "one_time_keys": {
- "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
- }
- },
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user,
+ device_id,
+ {
+ "one_time_keys": {
+ "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
+ }
+ },
+ )
)
self.fail("No error when replacing dict key")
except errors.SynapseError:
@@ -133,13 +151,17 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_id = "xyz"
keys = {"alg1:k1": "key1"}
- res = yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": keys}
+ res = yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": keys}
+ )
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}})
- res2 = yield self.handler.claim_one_time_keys(
- {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ res2 = yield defer.ensureDeferred(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ )
)
self.assertEqual(
res2,
@@ -163,7 +185,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
},
}
}
- yield self.handler.upload_signing_keys_for_user(local_user, keys1)
+ yield defer.ensureDeferred(
+ self.handler.upload_signing_keys_for_user(local_user, keys1)
+ )
keys2 = {
"master_key": {
@@ -175,10 +199,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
},
}
}
- yield self.handler.upload_signing_keys_for_user(local_user, keys2)
+ yield defer.ensureDeferred(
+ self.handler.upload_signing_keys_for_user(local_user, keys2)
+ )
- devices = yield self.handler.query_devices(
- {"device_keys": {local_user: []}}, 0, local_user
+ devices = yield defer.ensureDeferred(
+ self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
)
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
@@ -215,7 +241,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
"2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0",
)
- yield self.handler.upload_signing_keys_for_user(local_user, keys1)
+ yield defer.ensureDeferred(
+ self.handler.upload_signing_keys_for_user(local_user, keys1)
+ )
# upload two device keys, which will be signed later by the self-signing key
device_key_1 = {
@@ -245,18 +273,24 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"signatures": {local_user: {"ed25519:def": "base64+signature"}},
}
- yield self.handler.upload_keys_for_user(
- local_user, "abc", {"device_keys": device_key_1}
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, "abc", {"device_keys": device_key_1}
+ )
)
- yield self.handler.upload_keys_for_user(
- local_user, "def", {"device_keys": device_key_2}
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, "def", {"device_keys": device_key_2}
+ )
)
# sign the first device key and upload it
del device_key_1["signatures"]
sign.sign_json(device_key_1, local_user, signing_key)
- yield self.handler.upload_signatures_for_device_keys(
- local_user, {local_user: {"abc": device_key_1}}
+ yield defer.ensureDeferred(
+ self.handler.upload_signatures_for_device_keys(
+ local_user, {local_user: {"abc": device_key_1}}
+ )
)
# sign the second device key and upload both device keys. The server
@@ -264,14 +298,16 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
# signature for it
del device_key_2["signatures"]
sign.sign_json(device_key_2, local_user, signing_key)
- yield self.handler.upload_signatures_for_device_keys(
- local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
+ yield defer.ensureDeferred(
+ self.handler.upload_signatures_for_device_keys(
+ local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
+ )
)
device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
- devices = yield self.handler.query_devices(
- {"device_keys": {local_user: []}}, 0, local_user
+ devices = yield defer.ensureDeferred(
+ self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
)
del devices["device_keys"][local_user]["abc"]["unsigned"]
del devices["device_keys"][local_user]["def"]["unsigned"]
@@ -292,20 +328,26 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
},
}
}
- yield self.handler.upload_signing_keys_for_user(local_user, keys1)
+ yield defer.ensureDeferred(
+ self.handler.upload_signing_keys_for_user(local_user, keys1)
+ )
res = None
try:
- yield self.hs.get_device_handler().check_device_registered(
- user_id=local_user,
- device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
- initial_device_display_name="new display name",
+ yield defer.ensureDeferred(
+ self.hs.get_device_handler().check_device_registered(
+ user_id=local_user,
+ device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
+ initial_device_display_name="new display name",
+ )
)
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 400)
- res = yield self.handler.query_local_devices({local_user: None})
+ res = yield defer.ensureDeferred(
+ self.handler.query_local_devices({local_user: None})
+ )
self.assertDictEqual(res, {local_user: {}})
@defer.inlineCallbacks
@@ -331,8 +373,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA"
)
- yield self.handler.upload_keys_for_user(
- local_user, device_id, {"device_keys": device_key}
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"device_keys": device_key}
+ )
)
# private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
@@ -372,7 +416,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"user_signing_key": usersigning_key,
"self_signing_key": selfsigning_key,
}
- yield self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
+ )
# set up another user with a master key. This user will be signed by
# the first user
@@ -384,76 +430,90 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"usage": ["master"],
"keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
}
- yield self.handler.upload_signing_keys_for_user(
- other_user, {"master_key": other_master_key}
+ yield defer.ensureDeferred(
+ self.handler.upload_signing_keys_for_user(
+ other_user, {"master_key": other_master_key}
+ )
)
# test various signature failures (see below)
- ret = yield self.handler.upload_signatures_for_device_keys(
- local_user,
- {
- local_user: {
- # fails because the signature is invalid
- # should fail with INVALID_SIGNATURE
- device_id: {
- "user_id": local_user,
- "device_id": device_id,
- "algorithms": [
- "m.olm.curve25519-aes-sha2",
- RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
- ],
- "keys": {
- "curve25519:xyz": "curve25519+key",
- # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
- "ed25519:xyz": device_pubkey,
- },
- "signatures": {
- local_user: {"ed25519:" + selfsigning_pubkey: "something"}
+ ret = yield defer.ensureDeferred(
+ self.handler.upload_signatures_for_device_keys(
+ local_user,
+ {
+ local_user: {
+ # fails because the signature is invalid
+ # should fail with INVALID_SIGNATURE
+ device_id: {
+ "user_id": local_user,
+ "device_id": device_id,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
+ "keys": {
+ "curve25519:xyz": "curve25519+key",
+ # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
+ "ed25519:xyz": device_pubkey,
+ },
+ "signatures": {
+ local_user: {
+ "ed25519:" + selfsigning_pubkey: "something"
+ }
+ },
},
- },
- # fails because device is unknown
- # should fail with NOT_FOUND
- "unknown": {
- "user_id": local_user,
- "device_id": "unknown",
- "signatures": {
- local_user: {"ed25519:" + selfsigning_pubkey: "something"}
+ # fails because device is unknown
+ # should fail with NOT_FOUND
+ "unknown": {
+ "user_id": local_user,
+ "device_id": "unknown",
+ "signatures": {
+ local_user: {
+ "ed25519:" + selfsigning_pubkey: "something"
+ }
+ },
},
- },
- # fails because the signature is invalid
- # should fail with INVALID_SIGNATURE
- master_pubkey: {
- "user_id": local_user,
- "usage": ["master"],
- "keys": {"ed25519:" + master_pubkey: master_pubkey},
- "signatures": {
- local_user: {"ed25519:" + device_pubkey: "something"}
+ # fails because the signature is invalid
+ # should fail with INVALID_SIGNATURE
+ master_pubkey: {
+ "user_id": local_user,
+ "usage": ["master"],
+ "keys": {"ed25519:" + master_pubkey: master_pubkey},
+ "signatures": {
+ local_user: {"ed25519:" + device_pubkey: "something"}
+ },
},
},
- },
- other_user: {
- # fails because the device is not the user's master-signing key
- # should fail with NOT_FOUND
- "unknown": {
- "user_id": other_user,
- "device_id": "unknown",
- "signatures": {
- local_user: {"ed25519:" + usersigning_pubkey: "something"}
+ other_user: {
+ # fails because the device is not the user's master-signing key
+ # should fail with NOT_FOUND
+ "unknown": {
+ "user_id": other_user,
+ "device_id": "unknown",
+ "signatures": {
+ local_user: {
+ "ed25519:" + usersigning_pubkey: "something"
+ }
+ },
},
- },
- other_master_pubkey: {
- # fails because the key doesn't match what the server has
- # should fail with UNKNOWN
- "user_id": other_user,
- "usage": ["master"],
- "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
- "something": "random",
- "signatures": {
- local_user: {"ed25519:" + usersigning_pubkey: "something"}
+ other_master_pubkey: {
+ # fails because the key doesn't match what the server has
+ # should fail with UNKNOWN
+ "user_id": other_user,
+ "usage": ["master"],
+ "keys": {
+ "ed25519:" + other_master_pubkey: other_master_pubkey
+ },
+ "something": "random",
+ "signatures": {
+ local_user: {
+ "ed25519:" + usersigning_pubkey: "something"
+ }
+ },
},
},
},
- },
+ )
)
user_failures = ret["failures"][local_user]
@@ -478,19 +538,23 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
sign.sign_json(device_key, local_user, selfsigning_signing_key)
sign.sign_json(master_key, local_user, device_signing_key)
sign.sign_json(other_master_key, local_user, usersigning_signing_key)
- ret = yield self.handler.upload_signatures_for_device_keys(
- local_user,
- {
- local_user: {device_id: device_key, master_pubkey: master_key},
- other_user: {other_master_pubkey: other_master_key},
- },
+ ret = yield defer.ensureDeferred(
+ self.handler.upload_signatures_for_device_keys(
+ local_user,
+ {
+ local_user: {device_id: device_key, master_pubkey: master_key},
+ other_user: {other_master_pubkey: other_master_key},
+ },
+ )
)
self.assertEqual(ret["failures"], {})
# fetch the signed keys/devices and make sure that the signatures are there
- ret = yield self.handler.query_devices(
- {"device_keys": {local_user: [], other_user: []}}, 0, local_user
+ ret = yield defer.ensureDeferred(
+ self.handler.query_devices(
+ {"device_keys": {local_user: [], other_user: []}}, 0, local_user
+ )
)
self.assertEqual(
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 822ea42d..3362050c 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -66,7 +66,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.get_version_info(self.local_user)
+ yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -78,7 +78,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.get_version_info(self.local_user, "bogus_version")
+ yield defer.ensureDeferred(
+ self.handler.get_version_info(self.local_user, "bogus_version")
+ )
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -87,14 +89,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_create_version(self):
"""Check that we can create and then retrieve versions.
"""
- res = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ res = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(res, "1")
# check we can retrieve it as the current version
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
version_etag = res["etag"]
self.assertIsInstance(version_etag, str)
del res["etag"]
@@ -109,7 +116,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
# check we can retrieve it as a specific version
- res = yield self.handler.get_version_info(self.local_user, "1")
+ res = yield defer.ensureDeferred(
+ self.handler.get_version_info(self.local_user, "1")
+ )
self.assertEqual(res["etag"], version_etag)
del res["etag"]
self.assertDictEqual(
@@ -123,17 +132,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
# upload a new one...
- res = yield self.handler.create_version(
- self.local_user,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "second_version_auth_data",
- },
+ res = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "second_version_auth_data",
+ },
+ )
)
self.assertEqual(res, "2")
# check we can retrieve it as the current version
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
del res["etag"]
self.assertDictEqual(
res,
@@ -149,25 +160,32 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_update_version(self):
"""Check that we can update versions.
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
- res = yield self.handler.update_version(
- self.local_user,
- version,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- "version": version,
- },
+ res = yield defer.ensureDeferred(
+ self.handler.update_version(
+ self.local_user,
+ version,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": version,
+ },
+ )
)
self.assertDictEqual(res, {})
# check we can retrieve it as the current version
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
del res["etag"]
self.assertDictEqual(
res,
@@ -185,14 +203,16 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.update_version(
- self.local_user,
- "1",
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- "version": "1",
- },
+ yield defer.ensureDeferred(
+ self.handler.update_version(
+ self.local_user,
+ "1",
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": "1",
+ },
+ )
)
except errors.SynapseError as e:
res = e.code
@@ -202,23 +222,30 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_update_omitted_version(self):
"""Check that the update succeeds if the version is missing from the body
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
- yield self.handler.update_version(
- self.local_user,
- version,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- },
+ yield defer.ensureDeferred(
+ self.handler.update_version(
+ self.local_user,
+ version,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ },
+ )
)
# check we can retrieve it as the current version
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
del res["etag"] # etag is opaque, so don't test its contents
self.assertDictEqual(
res,
@@ -234,22 +261,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_update_bad_version(self):
"""Check that we get a 400 if the version in the body doesn't match
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
res = None
try:
- yield self.handler.update_version(
- self.local_user,
- version,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- "version": "incorrect",
- },
+ yield defer.ensureDeferred(
+ self.handler.update_version(
+ self.local_user,
+ version,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": "incorrect",
+ },
+ )
)
except errors.SynapseError as e:
res = e.code
@@ -261,7 +295,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.delete_version(self.local_user, "1")
+ yield defer.ensureDeferred(
+ self.handler.delete_version(self.local_user, "1")
+ )
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -272,7 +308,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.delete_version(self.local_user)
+ yield defer.ensureDeferred(self.handler.delete_version(self.local_user))
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -281,19 +317,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_delete_version(self):
"""Check that we can create and then delete versions.
"""
- res = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ res = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(res, "1")
# check we can delete it
- yield self.handler.delete_version(self.local_user, "1")
+ yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1"))
# check that it's gone
res = None
try:
- yield self.handler.get_version_info(self.local_user, "1")
+ yield defer.ensureDeferred(
+ self.handler.get_version_info(self.local_user, "1")
+ )
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -304,7 +347,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.get_room_keys(self.local_user, "bogus_version")
+ yield defer.ensureDeferred(
+ self.handler.get_room_keys(self.local_user, "bogus_version")
+ )
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -313,13 +358,20 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_get_missing_room_keys(self):
"""Check we get an empty response from an empty backup
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
- res = yield self.handler.get_room_keys(self.local_user, version)
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertDictEqual(res, {"rooms": {}})
# TODO: test the locking semantics when uploading room_keys,
@@ -331,8 +383,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.upload_room_keys(
- self.local_user, "no_version", room_keys
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, "no_version", room_keys)
)
except errors.SynapseError as e:
res = e.code
@@ -343,16 +395,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""Check that we get a 404 on uploading keys when an nonexistent version
is specified
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
res = None
try:
- yield self.handler.upload_room_keys(
- self.local_user, "bogus_version", room_keys
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(
+ self.local_user, "bogus_version", room_keys
+ )
)
except errors.SynapseError as e:
res = e.code
@@ -362,24 +421,33 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_upload_room_keys_wrong_version(self):
"""Check that we get a 403 on uploading keys for an old version
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
- version = yield self.handler.create_version(
- self.local_user,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "second_version_auth_data",
- },
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "second_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "2")
res = None
try:
- yield self.handler.upload_room_keys(self.local_user, "1", room_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, "1", room_keys)
+ )
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 403)
@@ -388,26 +456,39 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_upload_room_keys_insert(self):
"""Check that we can insert and retrieve keys for a session
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
- yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, room_keys)
+ )
- res = yield self.handler.get_room_keys(self.local_user, version)
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertDictEqual(res, room_keys)
# check getting room_keys for a given room
- res = yield self.handler.get_room_keys(
- self.local_user, version, room_id="!abc:matrix.org"
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org"
+ )
)
self.assertDictEqual(res, room_keys)
# check getting room_keys for a given session_id
- res = yield self.handler.get_room_keys(
- self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ )
)
self.assertDictEqual(res, room_keys)
@@ -415,16 +496,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_upload_room_keys_merge(self):
"""Check that we can upload a new room_key for an existing session and
have it correctly merged"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
- yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, room_keys)
+ )
# get the etag to compare to future versions
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
backup_etag = res["etag"]
self.assertEqual(res["count"], 1)
@@ -434,29 +522,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# test that increasing the message_index doesn't replace the existing session
new_room_key["first_message_index"] = 2
new_room_key["session_data"] = "new"
- yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+ )
- res = yield self.handler.get_room_keys(self.local_user, version)
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
"SSBBTSBBIEZJU0gK",
)
# the etag should be the same since the session did not change
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
self.assertEqual(res["etag"], backup_etag)
# test that marking the session as verified however /does/ replace it
new_room_key["is_verified"] = True
- yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+ )
- res = yield self.handler.get_room_keys(self.local_user, version)
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
)
# the etag should NOT be equal now, since the key changed
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
self.assertNotEqual(res["etag"], backup_etag)
backup_etag = res["etag"]
@@ -464,15 +560,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# with a lower forwarding count
new_room_key["forwarded_count"] = 2
new_room_key["session_data"] = "other"
- yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+ )
- res = yield self.handler.get_room_keys(self.local_user, version)
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
)
# the etag should be the same since the session did not change
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
self.assertEqual(res["etag"], backup_etag)
# TODO: check edge cases as well as the common variations here
@@ -481,36 +581,59 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_delete_room_keys(self):
"""Check that we can insert and delete keys for a session
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
# check for bulk-delete
- yield self.handler.upload_room_keys(self.local_user, version, room_keys)
- yield self.handler.delete_room_keys(self.local_user, version)
- res = yield self.handler.get_room_keys(
- self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, room_keys)
+ )
+ yield defer.ensureDeferred(
+ self.handler.delete_room_keys(self.local_user, version)
+ )
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ )
)
self.assertDictEqual(res, {"rooms": {}})
# check for bulk-delete per room
- yield self.handler.upload_room_keys(self.local_user, version, room_keys)
- yield self.handler.delete_room_keys(
- self.local_user, version, room_id="!abc:matrix.org"
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, room_keys)
+ )
+ yield defer.ensureDeferred(
+ self.handler.delete_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org"
+ )
)
- res = yield self.handler.get_room_keys(
- self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ )
)
self.assertDictEqual(res, {"rooms": {}})
# check for bulk-delete per session
- yield self.handler.upload_room_keys(self.local_user, version, room_keys)
- yield self.handler.delete_room_keys(
- self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, room_keys)
+ )
+ yield defer.ensureDeferred(
+ self.handler.delete_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ )
)
- res = yield self.handler.get_room_keys(
- self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ )
)
self.assertDictEqual(res, {"rooms": {}})
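The pattern running through these test changes is mechanical: handler methods that used to return Deferreds are now `async def` coroutine functions, and an `@defer.inlineCallbacks` generator only waits on Deferreds, so each call is wrapped in `defer.ensureDeferred`. A minimal sketch of the idiom (the handler function here is illustrative, not part of this change):

    from twisted.internet import defer

    async def get_version_info(user):
        # Stands in for a handler method that was converted to async.
        return {"etag": "1"}

    @defer.inlineCallbacks
    def check(user):
        # Yielding a bare coroutine from inlineCallbacks hands it back
        # un-awaited; ensureDeferred turns it into a Deferred first.
        res = yield defer.ensureDeferred(get_version_info(user))
        return res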
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 29dd7d9c..4f1347cd 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -72,7 +72,9 @@ class ProfileTestCase(unittest.TestCase):
def test_get_my_name(self):
yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
- displayname = yield self.handler.get_displayname(self.frank)
+ displayname = yield defer.ensureDeferred(
+ self.handler.get_displayname(self.frank)
+ )
self.assertEquals("Frank", displayname)
@@ -140,7 +142,9 @@ class ProfileTestCase(unittest.TestCase):
{"displayname": "Alice"}
)
- displayname = yield self.handler.get_displayname(self.alice)
+ displayname = yield defer.ensureDeferred(
+ self.handler.get_displayname(self.alice)
+ )
self.assertEquals(displayname, "Alice")
self.mock_federation.make_query.assert_called_with(
@@ -155,8 +159,10 @@ class ProfileTestCase(unittest.TestCase):
yield self.store.create_profile("caroline")
yield self.store.set_profile_displayname("caroline", "Caroline")
- response = yield self.query_handlers["profile"](
- {"user_id": "@caroline:test", "field": "displayname"}
+ response = yield defer.ensureDeferred(
+ self.query_handlers["profile"](
+ {"user_id": "@caroline:test", "field": "displayname"}
+ )
)
self.assertEquals({"displayname": "Caroline"}, response)
@@ -166,8 +172,7 @@ class ProfileTestCase(unittest.TestCase):
yield self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
)
-
- avatar_url = yield self.handler.get_avatar_url(self.frank)
+ avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
self.assertEquals("http://my.server/me.png", avatar_url)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 1e6a53bf..5878f741 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -138,10 +138,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
- def get_current_users_in_room(room_id):
+ def get_users_in_room(room_id):
return defer.succeed({str(u) for u in self.room_members})
- hs.get_state_handler().get_current_users_in_room = get_current_users_in_room
+ self.datastore.get_users_in_room = get_users_in_room
self.datastore.get_user_directory_stream_pos.return_value = (
# we deliberately return a non-None stream pos to avoid doing an initial_spam
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 954e059e..69945a8f 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -67,6 +67,14 @@ def get_connection_factory():
return test_server_connection_factory
+# Once AsyncMock or async lambdas are supported this can go away.
+def generate_resolve_service(result):
+ async def resolve_service(_):
+ return result
+
+ return resolve_service
+
+
class MatrixFederationAgentTests(unittest.TestCase):
def setUp(self):
self.reactor = ThreadedMemoryReactorClock()
@@ -373,7 +381,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""
Test the behaviour when the certificate on the server doesn't match the hostname
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv1"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv1/foo/bar")
@@ -456,7 +464,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
Test the behaviour when the server name has no port, no SRV, and no well-known
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@@ -510,7 +518,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""Test the behaviour when the .well-known delegates elsewhere
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["target-server"] = "1::f"
@@ -572,7 +580,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""Test the behaviour when the server name has no port and no SRV record, but
the .well-known has a 300 redirect
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["target-server"] = "1::f"
@@ -661,7 +669,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
Test the behaviour when the server name has an *invalid* well-known (and no SRV)
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@@ -717,7 +725,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
# the config left to the default, which will not trust it (since the
# presented cert is signed by a test CA)
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
config = default_config("test", parse=True)
@@ -764,9 +772,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""
Test the behaviour when there is a single SRV record
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: [
- Server(host=b"srvtarget", port=8443)
- ]
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+ [Server(host=b"srvtarget", port=8443)]
+ )
self.reactor.lookups["srvtarget"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@@ -819,9 +827,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 443)
- self.mock_resolver.resolve_service.side_effect = lambda _: [
- Server(host=b"srvtarget", port=8443)
- ]
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+ [Server(host=b"srvtarget", port=8443)]
+ )
self._handle_well_known_connection(
client_factory,
@@ -861,7 +869,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_idna_servername(self):
"""test the behaviour when the server name has idna chars in"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
# the resolver is always called with the IDNA hostname as a native string.
self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4"
@@ -922,9 +930,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_idna_srv_target(self):
"""test the behaviour when the target of a SRV record has idna chars"""
- self.mock_resolver.resolve_service.side_effect = lambda _: [
- Server(host=b"xn--trget-3qa.com", port=8443) # târget.com
- ]
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+ [Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com
+ )
self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar")
@@ -1087,11 +1095,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_srv_fallbacks(self):
"""Test that other SRV results are tried if the first one fails.
"""
-
- self.mock_resolver.resolve_service.side_effect = lambda _: [
- Server(host=b"target.com", port=8443),
- Server(host=b"target.com", port=8444),
- ]
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+ [
+ Server(host=b"target.com", port=8443),
+ Server(host=b"target.com", port=8444),
+ ]
+ )
self.reactor.lookups["target.com"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
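`generate_resolve_service` replaces the old `lambda _: [...]` side effects because `resolve_service` is now awaited by its callers: the mock must return an awaitable rather than a plain list. A sketch of the before/after, assuming a bare Mock stands in for the resolver:

    from mock import Mock

    def generate_resolve_service(result):
        async def resolve_service(_):
            return result
        return resolve_service

    resolver = Mock()
    # Broken once callers await the result: the side effect yields a list.
    resolver.resolve_service.side_effect = lambda _: []
    # Works: the side effect returns a coroutine the caller can await.
    resolver.resolve_service.side_effect = generate_resolve_service([])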
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index babc2016..fee2985d 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError
from twisted.names import dns, error
from synapse.http.federation.srv_resolver import SrvResolver
-from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
+from synapse.logging.context import LoggingContext, current_context
from tests import unittest
from tests.utils import MockClock
@@ -50,13 +50,7 @@ class SrvResolverTestCase(unittest.TestCase):
with LoggingContext("one") as ctx:
resolve_d = resolver.resolve_service(service_name)
-
- self.assertNoResult(resolve_d)
-
- # should have reset to the sentinel context
- self.assertIs(current_context(), SENTINEL_CONTEXT)
-
- result = yield resolve_d
+ result = yield defer.ensureDeferred(resolve_d)
# should have restored our context
self.assertIs(current_context(), ctx)
@@ -91,7 +85,7 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {service_name: [entry]}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- servers = yield resolver.resolve_service(service_name)
+ servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
dns_client_mock.lookupService.assert_called_once_with(service_name)
@@ -117,7 +111,7 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client=dns_client_mock, cache=cache, get_time=clock.time
)
- servers = yield resolver.resolve_service(service_name)
+ servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
self.assertFalse(dns_client_mock.lookupService.called)
@@ -136,7 +130,7 @@ class SrvResolverTestCase(unittest.TestCase):
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
with self.assertRaises(error.DNSServerError):
- yield resolver.resolve_service(service_name)
+ yield defer.ensureDeferred(resolver.resolve_service(service_name))
@defer.inlineCallbacks
def test_name_error(self):
@@ -149,7 +143,7 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- servers = yield resolver.resolve_service(service_name)
+ servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
self.assertEquals(len(servers), 0)
self.assertEquals(len(cache), 0)
@@ -166,8 +160,8 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- resolve_d = resolver.resolve_service(service_name)
- self.assertNoResult(resolve_d)
+        # Old versions of Twisted don't call ensureDeferred in failureResultOf.
+ resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name))
# returning a single "." should make the lookup fail with a ConnectError
lookup_deferred.callback(
@@ -192,8 +186,8 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- resolve_d = resolver.resolve_service(service_name)
- self.assertNoResult(resolve_d)
+        # Old versions of Twisted don't call ensureDeferred in successResultOf.
+ resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name))
lookup_deferred.callback(
(
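The comments about old Twisted refer to `successResultOf`/`failureResultOf`, which on older releases do not coerce coroutines into Deferreds themselves; the tests therefore wrap the coroutine explicitly. A self-contained sketch under trial (the resolver here is a stand-in, not Synapse's):

    from twisted.internet import defer
    from twisted.trial import unittest

    class EnsureDeferredExample(unittest.TestCase):
        def test_sketch(self):
            async def resolve_service(name):
                return []
            # Convert the coroutine before handing it to successResultOf,
            # which only accepts Deferreds on old Twisted versions.
            d = defer.ensureDeferred(resolve_service("_matrix._tcp.example"))
            self.assertEqual(self.successResultOf(d), [])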
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 9d4f0bbe..06575ba0 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Any, List, Optional, Tuple
+from typing import Any, Callable, List, Optional, Tuple
import attr
@@ -26,8 +26,9 @@ from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
GenericWorkerServer,
)
+from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
-from synapse.replication.http import streams
+from synapse.replication.http import ReplicationRestResource, streams
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -35,7 +36,7 @@ from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
-from tests.server import FakeTransport
+from tests.server import FakeTransport, render
logger = logging.getLogger(__name__)
@@ -180,6 +181,159 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.assertEqual(request.method, b"GET")
+class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
+ """Base class for tests running multiple workers.
+
+    Automatically handles HTTP replication requests from workers to master,
+ unlike `BaseStreamTestCase`.
+ """
+
+ servlets = [] # type: List[Callable[[HomeServer, JsonResource], None]]
+
+ def setUp(self):
+ super().setUp()
+
+ # build a replication server
+ self.server_factory = ReplicationStreamProtocolFactory(self.hs)
+ self.streamer = self.hs.get_replication_streamer()
+
+ store = self.hs.get_datastore()
+ self.database = store.db
+
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+
+ self._worker_hs_to_resource = {}
+
+ # When we see a connection attempt to the master replication listener we
+        # automatically set up the connection. This saves each test from
+        # having to set the connection up explicitly (and sometimes the
+        # handling cannot be written explicitly in the tests at all).
+ self.reactor.add_tcp_client_callback(
+ "1.2.3.4", 8765, self._handle_http_replication_attempt
+ )
+
+ def create_test_json_resource(self):
+ """Overrides `HomeserverTestCase.create_test_json_resource`.
+ """
+ # We override this so that it automatically registers all the HTTP
+ # replication servlets, without having to explicitly do that in all
+        # subclasses.
+
+ resource = ReplicationRestResource(self.hs)
+
+ for servlet in self.servlets:
+ servlet(self.hs, resource)
+
+ return resource
+
+ def make_worker_hs(
+ self, worker_app: str, extra_config: dict = {}, **kwargs
+ ) -> HomeServer:
+ """Make a new worker HS instance, correctly connecting replcation
+ stream to the master HS.
+
+ Args:
+ worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
+            extra_config: Any extra config to use for this instance.
+ **kwargs: Options that get passed to `self.setup_test_homeserver`,
+                useful e.g. to pass mocks for things like `http_client`.
+
+ Returns:
+ The new worker HomeServer instance.
+ """
+
+ config = self._get_worker_hs_config()
+ config["worker_app"] = worker_app
+ config.update(extra_config)
+
+ worker_hs = self.setup_test_homeserver(
+ homeserverToUse=GenericWorkerServer,
+ config=config,
+ reactor=self.reactor,
+ **kwargs
+ )
+
+ store = worker_hs.get_datastore()
+ store.db._db_pool = self.database._db_pool
+
+ repl_handler = ReplicationCommandHandler(worker_hs)
+ client = ClientReplicationStreamProtocol(
+ worker_hs, "client", "test", self.clock, repl_handler,
+ )
+ server = self.server_factory.buildProtocol(None)
+
+ client_transport = FakeTransport(server, self.reactor)
+ client.makeConnection(client_transport)
+
+ server_transport = FakeTransport(client, self.reactor)
+ server.makeConnection(server_transport)
+
+ # Set up a resource for the worker
+ resource = ReplicationRestResource(self.hs)
+
+ for servlet in self.servlets:
+ servlet(worker_hs, resource)
+
+ self._worker_hs_to_resource[worker_hs] = resource
+
+ return worker_hs
+
+ def _get_worker_hs_config(self) -> dict:
+ config = self.default_config()
+ config["worker_replication_host"] = "testserv"
+ config["worker_replication_http_port"] = "8765"
+ return config
+
+ def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
+ render(request, self._worker_hs_to_resource[worker_hs], self.reactor)
+
+ def replicate(self):
+ """Tell the master side of replication that something has happened, and then
+ wait for the replication to occur.
+ """
+ self.streamer.on_notifier_poke()
+ self.pump()
+
+ def _handle_http_replication_attempt(self):
+ """Handles a connection attempt to the master replication HTTP
+ listener.
+ """
+
+ # We should have at least one outbound connection attempt, where the
+        # last is to the HTTP replication IP/port.
+ clients = self.reactor.tcpClients
+ self.assertGreaterEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8765)
+
+ # Set up client side protocol
+ client_protocol = client_factory.buildProtocol(None)
+
+ request_factory = OneShotRequestFactory()
+
+ # Set up the server side protocol
+ channel = _PushHTTPChannel(self.reactor)
+ channel.requestFactory = request_factory
+ channel.site = self.site
+
+ # Connect client to server and vice versa.
+ client_to_server_transport = FakeTransport(
+ channel, self.reactor, client_protocol
+ )
+ client_protocol.makeConnection(client_to_server_transport)
+
+ server_to_client_transport = FakeTransport(
+ client_protocol, self.reactor, channel
+ )
+ channel.makeConnection(server_to_client_transport)
+
+ # Note: at this point we've wired everything up, but we need to return
+ # before the data starts flowing over the connections as this is called
+        # inside `connectTCP` before the connection has been passed back to the
+ # code that requested the TCP connection.
+
+
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
@@ -241,6 +395,14 @@ class _PushHTTPChannel(HTTPChannel):
# We need to manually stop the _PullToPushProducer.
self._pull_to_push_producer.stop()
+ def checkPersistence(self, request, version):
+ """Check whether the connection can be re-used
+ """
+ # We hijack this to always say no for ease of wiring stuff up in
+        # `_handle_http_replication_attempt`.
+ request.responseHeaders.setRawHeaders(b"connection", [b"close"])
+ return False
+
class _PullToPushProducer:
"""A push producer that wraps a pull producer.
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 097e1653..c9998e88 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -119,7 +119,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
OTHER_USER = "@other_user:localhost"
# have the user join
- inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
+ self.get_success(
+ inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
+ )
# Update existing power levels with mod at PL50
pls = self.helper.get_state(
@@ -157,14 +159,16 @@ class EventsStreamTestCase(BaseStreamTestCase):
# roll back all the state by de-modding the user
prev_events = fork_point
pls["users"][OTHER_USER] = 0
- pl_event = inject_event(
- self.hs,
- prev_event_ids=prev_events,
- type=EventTypes.PowerLevels,
- state_key="",
- sender=self.user_id,
- room_id=self.room_id,
- content=pls,
+ pl_event = self.get_success(
+ inject_event(
+ self.hs,
+ prev_event_ids=prev_events,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ sender=self.user_id,
+ room_id=self.room_id,
+ content=pls,
+ )
)
# one more bit of state that doesn't get rolled back
@@ -268,7 +272,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
# have the users join
for u in user_ids:
- inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
+ self.get_success(
+ inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
+ )
# Update existing power levels with mod at PL50
pls = self.helper.get_state(
@@ -306,14 +312,16 @@ class EventsStreamTestCase(BaseStreamTestCase):
pl_events = []
for u in user_ids:
pls["users"][u] = 0
- e = inject_event(
- self.hs,
- prev_event_ids=prev_events,
- type=EventTypes.PowerLevels,
- state_key="",
- sender=self.user_id,
- room_id=self.room_id,
- content=pls,
+ e = self.get_success(
+ inject_event(
+ self.hs,
+ prev_event_ids=prev_events,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ sender=self.user_id,
+ room_id=self.room_id,
+ content=pls,
+ )
)
prev_events = [e.event_id]
pl_events.append(e)
@@ -434,13 +442,15 @@ class EventsStreamTestCase(BaseStreamTestCase):
body = "event %i" % (self.event_count,)
self.event_count += 1
- return inject_event(
- self.hs,
- room_id=self.room_id,
- sender=sender,
- type="test_event",
- content={"body": body},
- **kwargs
+ return self.get_success(
+ inject_event(
+ self.hs,
+ room_id=self.room_id,
+ sender=sender,
+ type="test_event",
+ content={"body": body},
+ **kwargs
+ )
)
def _inject_state_event(
@@ -459,11 +469,13 @@ class EventsStreamTestCase(BaseStreamTestCase):
if body is None:
body = "state event %s" % (state_key,)
- return inject_event(
- self.hs,
- room_id=self.room_id,
- sender=sender,
- type="test_state_event",
- state_key=state_key,
- content={"body": body},
+ return self.get_success(
+ inject_event(
+ self.hs,
+ room_id=self.room_id,
+ sender=sender,
+ type="test_state_event",
+ state_key=state_key,
+ content={"body": body},
+ )
)
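`inject_event` and `inject_member_event` now return coroutines too, so the events-stream tests wrap them in `get_success`, the `HomeserverTestCase` helper that pumps the memory reactor until the awaitable finishes and returns its result (or fails the test on error). A sketch of the idiom (the store method here is an assumption, used only for illustration):

    from tests import unittest

    class GetSuccessExample(unittest.HomeserverTestCase):
        def test_get_success(self):
            store = self.hs.get_datastore()
            # get_success drives the fake reactor to completion and
            # unwraps the result of the coroutine/Deferred.
            total = self.get_success(store.count_all_users())
            self.assertEqual(total, 0)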
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
new file mode 100644
index 00000000..86c03fd8
--- /dev/null
+++ b/tests/replication/test_client_reader_shard.py
@@ -0,0 +1,96 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.api.constants import LoginType
+from synapse.http.site import SynapseRequest
+from synapse.rest.client.v2_alpha import register
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
+from tests.server import FakeChannel
+
+logger = logging.getLogger(__name__)
+
+
+class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
+ """Base class for tests of the replication streams"""
+
+ servlets = [register.register_servlets]
+
+ def prepare(self, reactor, clock, hs):
+ self.recaptcha_checker = DummyRecaptchaChecker(hs)
+ auth_handler = hs.get_auth_handler()
+ auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
+
+ def _get_worker_hs_config(self) -> dict:
+ config = self.default_config()
+ config["worker_app"] = "synapse.app.client_reader"
+ config["worker_replication_host"] = "testserv"
+ config["worker_replication_http_port"] = "8765"
+ return config
+
+ def test_register_single_worker(self):
+ """Test that registration works when using a single client reader worker.
+ """
+ worker_hs = self.make_worker_hs("synapse.app.client_reader")
+
+ request_1, channel_1 = self.make_request(
+ "POST",
+ "register",
+ {"username": "user", "type": "m.login.password", "password": "bar"},
+ ) # type: SynapseRequest, FakeChannel
+ self.render_on_worker(worker_hs, request_1)
+ self.assertEqual(request_1.code, 401)
+
+ # Grab the session
+ session = channel_1.json_body["session"]
+
+ # also complete the dummy auth
+ request_2, channel_2 = self.make_request(
+ "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
+ ) # type: SynapseRequest, FakeChannel
+ self.render_on_worker(worker_hs, request_2)
+ self.assertEqual(request_2.code, 200)
+
+ # We're given a registered user.
+ self.assertEqual(channel_2.json_body["user_id"], "@user:test")
+
+ def test_register_multi_worker(self):
+ """Test that registration works when using multiple client reader workers.
+ """
+ worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
+ worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
+
+ request_1, channel_1 = self.make_request(
+ "POST",
+ "register",
+ {"username": "user", "type": "m.login.password", "password": "bar"},
+ ) # type: SynapseRequest, FakeChannel
+ self.render_on_worker(worker_hs_1, request_1)
+ self.assertEqual(request_1.code, 401)
+
+ # Grab the session
+ session = channel_1.json_body["session"]
+
+ # also complete the dummy auth
+ request_2, channel_2 = self.make_request(
+ "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
+ ) # type: SynapseRequest, FakeChannel
+ self.render_on_worker(worker_hs_2, request_2)
+ self.assertEqual(request_2.code, 200)
+
+ # We're given a registered user.
+ self.assertEqual(channel_2.json_body["user_id"], "@user:test")
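Starting registration on `worker_hs_1` and completing it on `worker_hs_2` works because the user-interactive auth session lives in the database, and every worker built by `make_worker_hs` shares the master's connection pool. The relevant wiring, excerpted from `tests/replication/_base.py` above with comments added:

    # Point the worker's datastore at the master's database pool so all
    # homeservers in the test see the same (in-memory) database.
    store = worker_hs.get_datastore()
    store.db._db_pool = self.database._db_pool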
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 5448d9f0..23be1167 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -32,6 +32,7 @@ class FederationAckTestCase(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(homeserverToUse=GenericWorkerServer)
+
return hs
def test_federation_ack_sent(self):
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
new file mode 100644
index 00000000..8d4dbf23
--- /dev/null
+++ b/tests/replication/test_federation_sender_shard.py
@@ -0,0 +1,235 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.events.builder import EventBuilderFactory
+from synapse.rest.admin import register_servlets_for_client_rest_resource
+from synapse.rest.client.v1 import login, room
+from synapse.types import UserID
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+
+logger = logging.getLogger(__name__)
+
+
+class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
+ servlets = [
+ login.register_servlets,
+ register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ ]
+
+ def default_config(self):
+ conf = super().default_config()
+ conf["send_federation"] = False
+ return conf
+
+ def test_send_event_single_sender(self):
+ """Test that using a single federation sender worker correctly sends a
+ new event.
+ """
+ mock_client = Mock(spec=["put_json"])
+ mock_client.put_json.side_effect = lambda *_, **__: defer.succeed({})
+
+ self.make_worker_hs(
+ "synapse.app.federation_sender",
+ {"send_federation": True},
+ http_client=mock_client,
+ )
+
+ user = self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ room = self.create_room_with_remote_server(user, token)
+
+ mock_client.put_json.reset_mock()
+
+ self.create_and_send_event(room, UserID.from_string(user))
+ self.replicate()
+
+ # Assert that the event was sent out over federation.
+ mock_client.put_json.assert_called()
+ self.assertEqual(mock_client.put_json.call_args[0][0], "other_server")
+ self.assertTrue(mock_client.put_json.call_args[1]["data"].get("pdus"))
+
+ def test_send_event_sharded(self):
+ """Test that using two federation sender workers correctly sends
+ new events.
+ """
+ mock_client1 = Mock(spec=["put_json"])
+ mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({})
+ self.make_worker_hs(
+ "synapse.app.federation_sender",
+ {
+ "send_federation": True,
+ "worker_name": "sender1",
+ "federation_sender_instances": ["sender1", "sender2"],
+ },
+ http_client=mock_client1,
+ )
+
+ mock_client2 = Mock(spec=["put_json"])
+ mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({})
+ self.make_worker_hs(
+ "synapse.app.federation_sender",
+ {
+ "send_federation": True,
+ "worker_name": "sender2",
+ "federation_sender_instances": ["sender1", "sender2"],
+ },
+ http_client=mock_client2,
+ )
+
+ user = self.register_user("user2", "pass")
+ token = self.login("user2", "pass")
+
+ sent_on_1 = False
+ sent_on_2 = False
+ for i in range(20):
+ server_name = "other_server_%d" % (i,)
+ room = self.create_room_with_remote_server(user, token, server_name)
+ mock_client1.reset_mock() # type: ignore[attr-defined]
+ mock_client2.reset_mock() # type: ignore[attr-defined]
+
+ self.create_and_send_event(room, UserID.from_string(user))
+ self.replicate()
+
+ if mock_client1.put_json.called:
+ sent_on_1 = True
+ mock_client2.put_json.assert_not_called()
+ self.assertEqual(mock_client1.put_json.call_args[0][0], server_name)
+ self.assertTrue(mock_client1.put_json.call_args[1]["data"].get("pdus"))
+ elif mock_client2.put_json.called:
+ sent_on_2 = True
+ mock_client1.put_json.assert_not_called()
+ self.assertEqual(mock_client2.put_json.call_args[0][0], server_name)
+ self.assertTrue(mock_client2.put_json.call_args[1]["data"].get("pdus"))
+ else:
+ raise AssertionError(
+ "Expected send transaction from one or the other sender"
+ )
+
+ if sent_on_1 and sent_on_2:
+ break
+
+ self.assertTrue(sent_on_1)
+ self.assertTrue(sent_on_2)
+
+ def test_send_typing_sharded(self):
+ """Test that using two federation sender workers correctly sends
+ new typing EDUs.
+ """
+ mock_client1 = Mock(spec=["put_json"])
+ mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({})
+ self.make_worker_hs(
+ "synapse.app.federation_sender",
+ {
+ "send_federation": True,
+ "worker_name": "sender1",
+ "federation_sender_instances": ["sender1", "sender2"],
+ },
+ http_client=mock_client1,
+ )
+
+ mock_client2 = Mock(spec=["put_json"])
+ mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({})
+ self.make_worker_hs(
+ "synapse.app.federation_sender",
+ {
+ "send_federation": True,
+ "worker_name": "sender2",
+ "federation_sender_instances": ["sender1", "sender2"],
+ },
+ http_client=mock_client2,
+ )
+
+ user = self.register_user("user3", "pass")
+ token = self.login("user3", "pass")
+
+ typing_handler = self.hs.get_typing_handler()
+
+ sent_on_1 = False
+ sent_on_2 = False
+ for i in range(20):
+ server_name = "other_server_%d" % (i,)
+ room = self.create_room_with_remote_server(user, token, server_name)
+ mock_client1.reset_mock() # type: ignore[attr-defined]
+ mock_client2.reset_mock() # type: ignore[attr-defined]
+
+ self.get_success(
+ typing_handler.started_typing(
+ target_user=UserID.from_string(user),
+ auth_user=UserID.from_string(user),
+ room_id=room,
+ timeout=20000,
+ )
+ )
+
+ self.replicate()
+
+ if mock_client1.put_json.called:
+ sent_on_1 = True
+ mock_client2.put_json.assert_not_called()
+ self.assertEqual(mock_client1.put_json.call_args[0][0], server_name)
+ self.assertTrue(mock_client1.put_json.call_args[1]["data"].get("edus"))
+ elif mock_client2.put_json.called:
+ sent_on_2 = True
+ mock_client1.put_json.assert_not_called()
+ self.assertEqual(mock_client2.put_json.call_args[0][0], server_name)
+ self.assertTrue(mock_client2.put_json.call_args[1]["data"].get("edus"))
+ else:
+ raise AssertionError(
+ "Expected send transaction from one or the other sender"
+ )
+
+ if sent_on_1 and sent_on_2:
+ break
+
+ self.assertTrue(sent_on_1)
+ self.assertTrue(sent_on_2)
+
+ def create_room_with_remote_server(self, user, token, remote_server="other_server"):
+ room = self.helper.create_room_as(user, tok=token)
+ store = self.hs.get_datastore()
+ federation = self.hs.get_handlers().federation_handler
+
+ prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
+ room_version = self.get_success(store.get_room_version(room))
+
+ factory = EventBuilderFactory(self.hs)
+ factory.hostname = remote_server
+
+ user_id = UserID("user", remote_server).to_string()
+
+ event_dict = {
+ "type": EventTypes.Member,
+ "state_key": user_id,
+ "content": {"membership": Membership.JOIN},
+ "sender": user_id,
+ "room_id": room,
+ }
+
+ builder = factory.for_room_version(room_version, event_dict)
+ join_event = self.get_success(builder.build(prev_event_ids))
+
+ self.get_success(federation.on_send_join_request(remote_server, join_event))
+ self.replicate()
+
+ return room
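Both sharded-sender tests loop over fresh destination servers until each instance has sent at least one transaction, which relies on keys being assigned to instances deterministically. A self-contained sketch of hash-based partitioning in that spirit (the hashing scheme here is an assumption for illustration, not necessarily Synapse's exact algorithm):

    from hashlib import sha256

    def should_handle(instance_name, instances, key):
        # Stably map a key (e.g. a destination server name) to exactly
        # one of the configured instances.
        digest = sha256(key.encode("utf8")).digest()
        index = int.from_bytes(digest, "little") % len(instances)
        return instances[index] == instance_name

    senders = ["sender1", "sender2"]
    hits = {name: 0 for name in senders}
    for i in range(20):
        for name in senders:
            if should_handle(name, senders, "other_server_%d" % i):
                hits[name] += 1
    # With 20 destinations, both instances almost certainly get work.
    assert all(count > 0 for count in hits.values())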
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
new file mode 100644
index 00000000..2bdc6edb
--- /dev/null
+++ b/tests/replication/test_pusher_shard.py
@@ -0,0 +1,193 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+
+logger = logging.getLogger(__name__)
+
+
+class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
+ """Checks pusher sharding works
+ """
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ # Register a user who sends a message that we'll get notified about
+ self.other_user_id = self.register_user("otheruser", "pass")
+ self.other_access_token = self.login("otheruser", "pass")
+
+ def default_config(self):
+ conf = super().default_config()
+ conf["start_pushers"] = False
+ return conf
+
+ def _create_pusher_and_send_msg(self, localpart):
+ # Create a user that will get push notifications
+ user_id = self.register_user(localpart, "pass")
+ access_token = self.login(localpart, "pass")
+
+ # Register a pusher
+ user_dict = self.get_success(
+ self.hs.get_datastore().get_user_by_access_token(access_token)
+ )
+ token_id = user_dict["token_id"]
+
+ self.get_success(
+ self.hs.get_pusherpool().add_pusher(
+ user_id=user_id,
+ access_token=token_id,
+ kind="http",
+ app_id="m.http",
+ app_display_name="HTTP Push Notifications",
+ device_display_name="pushy push",
+ pushkey="a@example.com",
+ lang=None,
+ data={"url": "https://push.example.com/push"},
+ )
+ )
+
+ self.pump()
+
+ # Create a room
+ room = self.helper.create_room_as(user_id, tok=access_token)
+
+ # The other user joins
+ self.helper.join(
+ room=room, user=self.other_user_id, tok=self.other_access_token
+ )
+
+ # The other user sends some messages
+ response = self.helper.send(room, body="Hi!", tok=self.other_access_token)
+ event_id = response["event_id"]
+
+ return event_id
+
+ def test_send_push_single_worker(self):
+ """Test that registration works when using a pusher worker.
+ """
+ http_client_mock = Mock(spec_set=["post_json_get_json"])
+ http_client_mock.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
+ {}
+ )
+
+ self.make_worker_hs(
+ "synapse.app.pusher",
+ {"start_pushers": True},
+ proxied_http_client=http_client_mock,
+ )
+
+ event_id = self._create_pusher_and_send_msg("user")
+
+ # Advance time a bit, so the pusher will register something has happened
+ self.pump()
+
+ http_client_mock.post_json_get_json.assert_called_once()
+ self.assertEqual(
+ http_client_mock.post_json_get_json.call_args[0][0],
+ "https://push.example.com/push",
+ )
+ self.assertEqual(
+ event_id,
+ http_client_mock.post_json_get_json.call_args[0][1]["notification"][
+ "event_id"
+ ],
+ )
+
+ def test_send_push_multiple_workers(self):
+ """Test that registration works when using sharded pusher workers.
+ """
+ http_client_mock1 = Mock(spec_set=["post_json_get_json"])
+ http_client_mock1.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
+ {}
+ )
+
+ self.make_worker_hs(
+ "synapse.app.pusher",
+ {
+ "start_pushers": True,
+ "worker_name": "pusher1",
+ "pusher_instances": ["pusher1", "pusher2"],
+ },
+ proxied_http_client=http_client_mock1,
+ )
+
+ http_client_mock2 = Mock(spec_set=["post_json_get_json"])
+ http_client_mock2.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
+ {}
+ )
+
+ self.make_worker_hs(
+ "synapse.app.pusher",
+ {
+ "start_pushers": True,
+ "worker_name": "pusher2",
+ "pusher_instances": ["pusher1", "pusher2"],
+ },
+ proxied_http_client=http_client_mock2,
+ )
+
+ # We choose a user name that we know should go to pusher1.
+ event_id = self._create_pusher_and_send_msg("user2")
+
+ # Advance time a bit, so the pusher will register something has happened
+ self.pump()
+
+ http_client_mock1.post_json_get_json.assert_called_once()
+ http_client_mock2.post_json_get_json.assert_not_called()
+ self.assertEqual(
+ http_client_mock1.post_json_get_json.call_args[0][0],
+ "https://push.example.com/push",
+ )
+ self.assertEqual(
+ event_id,
+ http_client_mock1.post_json_get_json.call_args[0][1]["notification"][
+ "event_id"
+ ],
+ )
+
+ http_client_mock1.post_json_get_json.reset_mock()
+ http_client_mock2.post_json_get_json.reset_mock()
+
+ # Now we choose a user name that we know should go to pusher2.
+ event_id = self._create_pusher_and_send_msg("user4")
+
+ # Advance time a bit, so the pusher will register something has happened
+ self.pump()
+
+ http_client_mock1.post_json_get_json.assert_not_called()
+ http_client_mock2.post_json_get_json.assert_called_once()
+ self.assertEqual(
+ http_client_mock2.post_json_get_json.call_args[0][0],
+ "https://push.example.com/push",
+ )
+ self.assertEqual(
+ event_id,
+ http_client_mock2.post_json_get_json.call_args[0][1]["notification"][
+ "event_id"
+ ],
+ )
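The mocked HTTP clients are built with `spec_set` so that any call other than `post_json_get_json` fails immediately, which keeps the assertions about which pusher instance fired trustworthy. A small sketch of that Mock behaviour:

    from mock import Mock

    client = Mock(spec_set=["post_json_get_json"])
    client.post_json_get_json("https://push.example.com/push", {})  # allowed
    try:
        client.get_json("https://push.example.com")  # not in spec_set
    except AttributeError:
        # spec_set turns unexpected attribute access into a loud failure.
        pass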
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index ae6d05a0..ba8552c2 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1,1006 +1,1447 @@
-# -*- coding: utf-8 -*-
-# Copyright 2020 Dirk Klimpel
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import json
-import urllib.parse
-from typing import List, Optional
-
-from mock import Mock
-
-import synapse.rest.admin
-from synapse.api.errors import Codes
-from synapse.rest.client.v1 import directory, events, login, room
-
-from tests import unittest
-
-"""Tests admin REST events for /rooms paths."""
-
-
-class ShutdownRoomTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- login.register_servlets,
- events.register_servlets,
- room.register_servlets,
- room.register_deprecated_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.event_creation_handler = hs.get_event_creation_handler()
- hs.config.user_consent_version = "1"
-
- consent_uri_builder = Mock()
- consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
- self.event_creation_handler._consent_uri_builder = consent_uri_builder
-
- self.store = hs.get_datastore()
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- self.other_user = self.register_user("user", "pass")
- self.other_user_token = self.login("user", "pass")
-
- # Mark the admin user as having consented
- self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
-
- def test_shutdown_room_consent(self):
- """Test that we can shutdown rooms with local users who have not
- yet accepted the privacy policy. This used to fail when we tried to
- force part the user from the old room.
- """
- self.event_creation_handler._block_events_without_consent_error = None
-
- room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
-
- # Assert one user in room
- users_in_room = self.get_success(self.store.get_users_in_room(room_id))
- self.assertEqual([self.other_user], users_in_room)
-
- # Enable require consent to send events
- self.event_creation_handler._block_events_without_consent_error = "Error"
-
- # Assert that the user is getting consent error
- self.helper.send(
- room_id, body="foo", tok=self.other_user_token, expect_code=403
- )
-
- # Test that the admin can still send shutdown
- url = "admin/shutdown_room/" + room_id
- request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- json.dumps({"new_room_user_id": self.admin_user}),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Assert there is now no longer anyone in the room
- users_in_room = self.get_success(self.store.get_users_in_room(room_id))
- self.assertEqual([], users_in_room)
-
- def test_shutdown_room_block_peek(self):
- """Test that a world_readable room can no longer be peeked into after
- it has been shut down.
- """
-
- self.event_creation_handler._block_events_without_consent_error = None
-
- room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
-
- # Enable world readable
- url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
- request, channel = self.make_request(
- "PUT",
- url.encode("ascii"),
- json.dumps({"history_visibility": "world_readable"}),
- access_token=self.other_user_token,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Test that the admin can still send shutdown
- url = "admin/shutdown_room/" + room_id
- request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- json.dumps({"new_room_user_id": self.admin_user}),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Assert we can no longer peek into the room
- self._assert_peek(room_id, expect_code=403)
-
- def _assert_peek(self, room_id, expect_code):
- """Assert that the admin user can (or cannot) peek into the room.
- """
-
- url = "rooms/%s/initialSync" % (room_id,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
- )
- self.render(request)
- self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
- )
-
- url = "events?timeout=0&room_id=" + room_id
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
- )
- self.render(request)
- self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
- )
-
-
-class PurgeRoomTestCase(unittest.HomeserverTestCase):
- """Test /purge_room admin API.
- """
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- def test_purge_room(self):
- room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- # All users have to have left the room.
- self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
-
- url = "/_synapse/admin/v1/purge_room"
- request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- {"room_id": room_id},
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Test that the following tables have been purged of all rows related to the room.
- for table in (
- "current_state_events",
- "event_backward_extremities",
- "event_forward_extremities",
- "event_json",
- "event_push_actions",
- "event_search",
- "events",
- "group_rooms",
- "public_room_list_stream",
- "receipts_graph",
- "receipts_linearized",
- "room_aliases",
- "room_depth",
- "room_memberships",
- "room_stats_state",
- "room_stats_current",
- "room_stats_historical",
- "room_stats_earliest_token",
- "rooms",
- "stream_ordering_to_exterm",
- "users_in_public_rooms",
- "users_who_share_private_rooms",
- "appservice_room_list",
- "e2e_room_keys",
- "event_push_summary",
- "pusher_throttle",
- "group_summary_rooms",
- "room_account_data",
- "room_tags",
- # "state_groups", # Current impl leaves orphaned state groups around.
- "state_groups_state",
- ):
- count = self.get_success(
- self.store.db.simple_select_one_onecol(
- table=table,
- keyvalues={"room_id": room_id},
- retcol="COUNT(*)",
- desc="test_purge_room",
- )
- )
-
- self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
-
-
-class RoomTestCase(unittest.HomeserverTestCase):
- """Test /room admin API.
- """
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- directory.register_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
- # Create user
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- def test_list_rooms(self):
- """Test that we can list rooms"""
- # Create 3 test rooms
- total_rooms = 3
- room_ids = []
- for x in range(total_rooms):
- room_id = self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok
- )
- room_ids.append(room_id)
-
- # Request the list of rooms
- url = "/_synapse/admin/v1/rooms"
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
-
- # Check request completed successfully
- self.assertEqual(200, int(channel.code), msg=channel.json_body)
-
- # Check that response json body contains a "rooms" key
- self.assertTrue(
- "rooms" in channel.json_body,
- msg="Response body does not " "contain a 'rooms' key",
- )
-
- # Check that 3 rooms were returned
- self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body)
-
- # Check their room_ids match
- returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]]
- self.assertEqual(room_ids, returned_room_ids)
-
- # Check that all fields are available
- for r in channel.json_body["rooms"]:
- self.assertIn("name", r)
- self.assertIn("canonical_alias", r)
- self.assertIn("joined_members", r)
- self.assertIn("joined_local_members", r)
- self.assertIn("version", r)
- self.assertIn("creator", r)
- self.assertIn("encryption", r)
- self.assertIn("federatable", r)
- self.assertIn("public", r)
- self.assertIn("join_rules", r)
- self.assertIn("guest_access", r)
- self.assertIn("history_visibility", r)
- self.assertIn("state_events", r)
-
- # Check that the correct number of total rooms was returned
- self.assertEqual(channel.json_body["total_rooms"], total_rooms)
-
- # Check that the offset is correct
- # Should be 0 as we aren't paginating
- self.assertEqual(channel.json_body["offset"], 0)
-
- # Check that the prev_batch parameter is not present
- self.assertNotIn("prev_batch", channel.json_body)
-
- # We shouldn't receive a next token here as there's no further rooms to show
- self.assertNotIn("next_batch", channel.json_body)
-
- def test_list_rooms_pagination(self):
- """Test that we can get a full list of rooms through pagination"""
- # Create 5 test rooms
- total_rooms = 5
- room_ids = []
- for x in range(total_rooms):
- room_id = self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok
- )
- room_ids.append(room_id)
-
- # Set the name of the rooms so we get a consistent returned ordering
- for idx, room_id in enumerate(room_ids):
- self.helper.send_state(
- room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok,
- )
-
- # Request the list of rooms
- returned_room_ids = []
- start = 0
- limit = 2
-
- run_count = 0
- should_repeat = True
- while should_repeat:
- run_count += 1
-
- url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % (
- start,
- limit,
- "name",
- )
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(
- 200, int(channel.result["code"]), msg=channel.result["body"]
- )
-
- self.assertTrue("rooms" in channel.json_body)
- for r in channel.json_body["rooms"]:
- returned_room_ids.append(r["room_id"])
-
- # Check that the correct number of total rooms was returned
- self.assertEqual(channel.json_body["total_rooms"], total_rooms)
-
- # Check that the offset is correct
- # We're only getting 2 rooms each page, so should be 2 * last run_count
- self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1))
-
- if run_count > 1:
- # Check the value of prev_batch is correct
- self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2))
-
- if "next_batch" not in channel.json_body:
- # We have reached the end of the list
- should_repeat = False
- else:
- # Make another query with an updated start value
- start = channel.json_body["next_batch"]
-
- # We should've queried the endpoint 3 times
- self.assertEqual(
- run_count,
- 3,
- msg="Should've queried 3 times for 5 rooms with limit 2 per query",
- )
-
- # Check that we received all of the room ids
- self.assertEqual(room_ids, returned_room_ids)
-
- url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- def test_correct_room_attributes(self):
- """Test the correct attributes for a room are returned"""
- # Create a test room
- room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- test_alias = "#test:test"
- test_room_name = "something"
-
- # Have another user join the room
- user_2 = self.register_user("user4", "pass")
- user_tok_2 = self.login("user4", "pass")
- self.helper.join(room_id, user_2, tok=user_tok_2)
-
- # Create a new alias to this room
- url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
- request, channel = self.make_request(
- "PUT",
- url.encode("ascii"),
- {"room_id": room_id},
- access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Set this new alias as the canonical alias for this room
- self.helper.send_state(
- room_id,
- "m.room.aliases",
- {"aliases": [test_alias]},
- tok=self.admin_user_tok,
- state_key="test",
- )
- self.helper.send_state(
- room_id,
- "m.room.canonical_alias",
- {"alias": test_alias},
- tok=self.admin_user_tok,
- )
-
- # Set a name for the room
- self.helper.send_state(
- room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok,
- )
-
- # Request the list of rooms
- url = "/_synapse/admin/v1/rooms"
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Check that rooms were returned
- self.assertTrue("rooms" in channel.json_body)
- rooms = channel.json_body["rooms"]
-
- # Check that only one room was returned
- self.assertEqual(len(rooms), 1)
-
- # And that the value of the total_rooms key was correct
- self.assertEqual(channel.json_body["total_rooms"], 1)
-
- # Check that the offset is correct
- # We're not paginating, so should be 0
- self.assertEqual(channel.json_body["offset"], 0)
-
- # Check that there is no `prev_batch`
- self.assertNotIn("prev_batch", channel.json_body)
-
- # Check that there is no `next_batch`
- self.assertNotIn("next_batch", channel.json_body)
-
- # Check that all provided attributes are set
- r = rooms[0]
- self.assertEqual(room_id, r["room_id"])
- self.assertEqual(test_room_name, r["name"])
- self.assertEqual(test_alias, r["canonical_alias"])
-
- def test_room_list_sort_order(self):
- """Test room list sort ordering. alphabetical name versus number of members,
- reversing the order, etc.
- """
-
- def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str):
- # Create a new alias to this room
- url = "/_matrix/client/r0/directory/room/%s" % (
- urllib.parse.quote(test_alias),
- )
- request, channel = self.make_request(
- "PUT",
- url.encode("ascii"),
- {"room_id": room_id},
- access_token=admin_user_tok,
- )
- self.render(request)
- self.assertEqual(
- 200, int(channel.result["code"]), msg=channel.result["body"]
- )
-
- # Set this new alias as the canonical alias for this room
- self.helper.send_state(
- room_id,
- "m.room.aliases",
- {"aliases": [test_alias]},
- tok=admin_user_tok,
- state_key="test",
- )
- self.helper.send_state(
- room_id,
- "m.room.canonical_alias",
- {"alias": test_alias},
- tok=admin_user_tok,
- )
-
- def _order_test(
- order_type: str, expected_room_list: List[str], reverse: bool = False,
- ):
- """Request the list of rooms in a certain order. Assert that order is what
- we expect
-
- Args:
- order_type: The type of ordering to give the server
- expected_room_list: The list of room_ids in the order we expect to get
- back from the server
- """
- # Request the list of rooms in the given order
- url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,)
- if reverse:
- url += "&dir=b"
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, channel.code, msg=channel.json_body)
-
- # Check that rooms were returned
- self.assertTrue("rooms" in channel.json_body)
- rooms = channel.json_body["rooms"]
-
- # Check for the correct total_rooms value
- self.assertEqual(channel.json_body["total_rooms"], 3)
-
- # Check that the offset is correct
- # We're not paginating, so should be 0
- self.assertEqual(channel.json_body["offset"], 0)
-
- # Check that there is no `prev_batch`
- self.assertNotIn("prev_batch", channel.json_body)
-
- # Check that there is no `next_batch`
- self.assertNotIn("next_batch", channel.json_body)
-
- # Check that rooms were returned in alphabetical order
- returned_order = [r["room_id"] for r in rooms]
- self.assertListEqual(expected_room_list, returned_order) # order is checked
-
- # Create 3 test rooms
- room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C
- self.helper.send_state(
- room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok,
- )
- self.helper.send_state(
- room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok,
- )
- self.helper.send_state(
- room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok,
- )
-
- # Set room canonical room aliases
- _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok)
- _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok)
- _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok)
-
- # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3
- user_1 = self.register_user("bob1", "pass")
- user_1_tok = self.login("bob1", "pass")
- self.helper.join(room_id_2, user_1, tok=user_1_tok)
-
- user_2 = self.register_user("bob2", "pass")
- user_2_tok = self.login("bob2", "pass")
- self.helper.join(room_id_3, user_2, tok=user_2_tok)
-
- user_3 = self.register_user("bob3", "pass")
- user_3_tok = self.login("bob3", "pass")
- self.helper.join(room_id_3, user_3, tok=user_3_tok)
-
- # Test different sort orders, with forward and reverse directions
- _order_test("name", [room_id_1, room_id_2, room_id_3])
- _order_test("name", [room_id_3, room_id_2, room_id_1], reverse=True)
-
- _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3])
- _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True)
-
- _order_test("joined_members", [room_id_3, room_id_2, room_id_1])
- _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1])
- _order_test(
- "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True
- )
-
- _order_test("version", [room_id_1, room_id_2, room_id_3])
- _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("creator", [room_id_1, room_id_2, room_id_3])
- _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("encryption", [room_id_1, room_id_2, room_id_3])
- _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("federatable", [room_id_1, room_id_2, room_id_3])
- _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("public", [room_id_1, room_id_2, room_id_3])
- # Different sort order of SQlite and PostreSQL
- # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True)
-
- _order_test("join_rules", [room_id_1, room_id_2, room_id_3])
- _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("guest_access", [room_id_1, room_id_2, room_id_3])
- _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("history_visibility", [room_id_1, room_id_2, room_id_3])
- _order_test(
- "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True
- )
-
- _order_test("state_events", [room_id_3, room_id_2, room_id_1])
- _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- def test_search_term(self):
- """Test that searching for a room works correctly"""
- # Create two test rooms
- room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- room_name_1 = "something"
- room_name_2 = "else"
-
- # Set the name for each room
- self.helper.send_state(
- room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
- )
- self.helper.send_state(
- room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
- )
-
- def _search_test(
- expected_room_id: Optional[str],
- search_term: str,
- expected_http_code: int = 200,
- ):
- """Search for a room and check that the returned room's id is a match
-
- Args:
- expected_room_id: The room_id expected to be returned by the API. Set
- to None to expect zero results for the search
- search_term: The term to search for room names with
- expected_http_code: The expected http code for the request
- """
- url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
-
- if expected_http_code != 200:
- return
-
- # Check that rooms were returned
- self.assertTrue("rooms" in channel.json_body)
- rooms = channel.json_body["rooms"]
-
- # Check that the expected number of rooms were returned
- expected_room_count = 1 if expected_room_id else 0
- self.assertEqual(len(rooms), expected_room_count)
- self.assertEqual(channel.json_body["total_rooms"], expected_room_count)
-
- # Check that the offset is correct
- # We're not paginating, so should be 0
- self.assertEqual(channel.json_body["offset"], 0)
-
- # Check that there is no `prev_batch`
- self.assertNotIn("prev_batch", channel.json_body)
-
- # Check that there is no `next_batch`
- self.assertNotIn("next_batch", channel.json_body)
-
- if expected_room_id:
- # Check that the first returned room id is correct
- r = rooms[0]
- self.assertEqual(expected_room_id, r["room_id"])
-
- # Perform search tests
- _search_test(room_id_1, "something")
- _search_test(room_id_1, "thing")
-
- _search_test(room_id_2, "else")
- _search_test(room_id_2, "se")
-
- _search_test(None, "foo")
- _search_test(None, "bar")
- _search_test(None, "", expected_http_code=400)
-
- def test_single_room(self):
- """Test that a single room can be requested correctly"""
- # Create two test rooms
- room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- room_name_1 = "something"
- room_name_2 = "else"
-
- # Set the name for each room
- self.helper.send_state(
- room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
- )
- self.helper.send_state(
- room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
- )
-
- url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, channel.code, msg=channel.json_body)
-
- self.assertIn("room_id", channel.json_body)
- self.assertIn("name", channel.json_body)
- self.assertIn("canonical_alias", channel.json_body)
- self.assertIn("joined_members", channel.json_body)
- self.assertIn("joined_local_members", channel.json_body)
- self.assertIn("version", channel.json_body)
- self.assertIn("creator", channel.json_body)
- self.assertIn("encryption", channel.json_body)
- self.assertIn("federatable", channel.json_body)
- self.assertIn("public", channel.json_body)
- self.assertIn("join_rules", channel.json_body)
- self.assertIn("guest_access", channel.json_body)
- self.assertIn("history_visibility", channel.json_body)
- self.assertIn("state_events", channel.json_body)
-
- self.assertEqual(room_id_1, channel.json_body["room_id"])
-
-
-class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
-
- servlets = [
- synapse.rest.admin.register_servlets,
- room.register_servlets,
- login.register_servlets,
- ]
-
- def prepare(self, reactor, clock, homeserver):
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- self.creator = self.register_user("creator", "test")
- self.creator_tok = self.login("creator", "test")
-
- self.second_user_id = self.register_user("second", "test")
- self.second_tok = self.login("second", "test")
-
- self.public_room_id = self.helper.create_room_as(
- self.creator, tok=self.creator_tok, is_public=True
- )
- self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id)
-
- def test_requester_is_no_admin(self):
- """
- If the user is not a server admin, an error 403 is returned.
- """
- body = json.dumps({"user_id": self.second_user_id})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.second_tok,
- )
- self.render(request)
-
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- def test_invalid_parameter(self):
- """
- If a parameter is missing, return an error
- """
- body = json.dumps({"unknown_parameter": "@unknown:test"})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
-
- def test_local_user_does_not_exist(self):
- """
- Tests that a lookup for a user that does not exist returns a 404
- """
- body = json.dumps({"user_id": "@unknown:test"})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
- def test_remote_user(self):
- """
- Check that only local user can join rooms.
- """
- body = json.dumps({"user_id": "@not:exist.bla"})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(
- "This endpoint can only be used with local users",
- channel.json_body["error"],
- )
-
- def test_room_does_not_exist(self):
- """
- Check that unknown rooms/server return error 404.
- """
- body = json.dumps({"user_id": self.second_user_id})
- url = "/_synapse/admin/v1/join/!unknown:test"
-
- request, channel = self.make_request(
- "POST",
- url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("No known servers", channel.json_body["error"])
-
- def test_room_is_not_valid(self):
- """
- Check that invalid room names, return an error 400.
- """
- body = json.dumps({"user_id": self.second_user_id})
- url = "/_synapse/admin/v1/join/invalidroom"
-
- request, channel = self.make_request(
- "POST",
- url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(
- "invalidroom was not legal room ID or room alias",
- channel.json_body["error"],
- )
-
- def test_join_public_room(self):
- """
- Test joining a local user to a public room with "JoinRules.PUBLIC"
- """
- body = json.dumps({"user_id": self.second_user_id})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(self.public_room_id, channel.json_body["room_id"])
-
- # Validate if user is a member of the room
-
- request, channel = self.make_request(
- "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
- )
- self.render(request)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
-
- def test_join_private_room_if_not_member(self):
- """
- Test joining a local user to a private room with "JoinRules.INVITE"
- when server admin is not member of this room.
- """
- private_room_id = self.helper.create_room_as(
- self.creator, tok=self.creator_tok, is_public=False
- )
- url = "/_synapse/admin/v1/join/{}".format(private_room_id)
- body = json.dumps({"user_id": self.second_user_id})
-
- request, channel = self.make_request(
- "POST",
- url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- def test_join_private_room_if_member(self):
- """
- Test joining a local user to a private room with "JoinRules.INVITE",
- when server admin is member of this room.
- """
- private_room_id = self.helper.create_room_as(
- self.creator, tok=self.creator_tok, is_public=False
- )
- self.helper.invite(
- room=private_room_id,
- src=self.creator,
- targ=self.admin_user,
- tok=self.creator_tok,
- )
- self.helper.join(
- room=private_room_id, user=self.admin_user, tok=self.admin_user_tok
- )
-
- # Validate if server admin is a member of the room
-
- request, channel = self.make_request(
- "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
-
- # Join user to room.
-
- url = "/_synapse/admin/v1/join/{}".format(private_room_id)
- body = json.dumps({"user_id": self.second_user_id})
-
- request, channel = self.make_request(
- "POST",
- url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(private_room_id, channel.json_body["room_id"])
-
- # Validate if user is a member of the room
-
- request, channel = self.make_request(
- "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
- )
- self.render(request)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
-
- def test_join_private_room_if_owner(self):
- """
- Test joining a local user to a private room with "JoinRules.INVITE",
- when server admin is owner of this room.
- """
- private_room_id = self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok, is_public=False
- )
- url = "/_synapse/admin/v1/join/{}".format(private_room_id)
- body = json.dumps({"user_id": self.second_user_id})
-
- request, channel = self.make_request(
- "POST",
- url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(private_room_id, channel.json_body["room_id"])
-
- # Validate if user is a member of the room
-
- request, channel = self.make_request(
- "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
- )
- self.render(request)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import urllib.parse
+from typing import List, Optional
+
+from mock import Mock
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import directory, events, login, room
+
+from tests import unittest
+
+"""Tests admin REST events for /rooms paths."""
+
+
+class ShutdownRoomTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ events.register_servlets,
+ room.register_servlets,
+ room.register_deprecated_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.event_creation_handler = hs.get_event_creation_handler()
+ hs.config.user_consent_version = "1"
+
+ consent_uri_builder = Mock()
+ consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
+ self.event_creation_handler._consent_uri_builder = consent_uri_builder
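+ # user_consent_version is set above, so the event creation handler will
+ # try to build consent URIs; a Mock stands in for the real builder so the
+ # tests don't need a full consent configuration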
+
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_token = self.login("user", "pass")
+
+ # Mark the admin user as having consented
+ self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
+
+ def test_shutdown_room_consent(self):
+ """Test that we can shutdown rooms with local users who have not
+ yet accepted the privacy policy. This used to fail when we tried to
+ force part the user from the old room.
+ """
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
+
+ # Assert one user in room
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual([self.other_user], users_in_room)
+
+ # Require consent before events can be sent
+ self.event_creation_handler._block_events_without_consent_error = "Error"
+
+ # Assert that the user now gets a consent error
+ self.helper.send(
+ room_id, body="foo", tok=self.other_user_token, expect_code=403
+ )
+
+ # Test that the admin can still shut down the room
+ url = "admin/shutdown_room/" + room_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ json.dumps({"new_room_user_id": self.admin_user}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Assert that there is no longer anyone in the room
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual([], users_in_room)
+
+ def test_shutdown_room_block_peek(self):
+ """Test that a world_readable room can no longer be peeked into after
+ it has been shut down.
+ """
+
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
+
+ # Make the room world-readable
+ url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ json.dumps({"history_visibility": "world_readable"}),
+ access_token=self.other_user_token,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Test that the admin can still shut down the room
+ url = "admin/shutdown_room/" + room_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ json.dumps({"new_room_user_id": self.admin_user}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Assert we can no longer peek into the room
+ self._assert_peek(room_id, expect_code=403)
+
+ def _assert_peek(self, room_id, expect_code):
+ """Assert that the admin user can (or cannot) peek into the room.
+ """
+
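+ # Peeking is attempted via both the deprecated initialSync endpoint and
+ # the events stream; both should return the same code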
+ url = "rooms/%s/initialSync" % (room_id,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ url = "events?timeout=0&room_id=" + room_id
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+
+class DeleteRoomTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ events.register_servlets,
+ room.register_servlets,
+ room.register_deprecated_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.event_creation_handler = hs.get_event_creation_handler()
+ hs.config.user_consent_version = "1"
+
+ consent_uri_builder = Mock()
+ consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
+ self.event_creation_handler._consent_uri_builder = consent_uri_builder
+
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ # Mark the admin user as having consented
+ self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
+
+ self.room_id = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok
+ )
+ self.url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, a 403 error is returned.
+ """
+
+ request, channel = self.make_request(
+ "POST", self.url, json.dumps({}), access_token=self.other_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_room_does_not_exist(self):
+ """
+ Check that unknown rooms/servers return a 404 error.
+ """
+ url = "/_synapse/admin/v1/rooms/!unknown:test/delete"
+
+ request, channel = self.make_request(
+ "POST", url, json.dumps({}), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_room_is_not_valid(self):
+ """
+ Check that invalid room names return a 400 error.
+ """
+ url = "/_synapse/admin/v1/rooms/invalidroom/delete"
+
+ request, channel = self.make_request(
+ "POST", url, json.dumps({}), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ "invalidroom is not a legal room ID", channel.json_body["error"],
+ )
+
+ def test_new_room_user_does_not_exist(self):
+ """
+ Tests that the new room user ID must belong to the local server, but does not have to exist.
+ """
+ body = json.dumps({"new_room_user_id": "@unknown:test"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertIn("new_room_id", channel.json_body)
+ self.assertIn("kicked_users", channel.json_body)
+ self.assertIn("failed_to_kick_users", channel.json_body)
+ self.assertIn("local_aliases", channel.json_body)
+
+ def test_new_room_user_is_not_local(self):
+ """
+ Check that only a local user can be used to create the new room for moved members.
+ """
+ body = json.dumps({"new_room_user_id": "@not:exist.bla"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ "User must be our own: @not:exist.bla", channel.json_body["error"],
+ )
+
+ def test_block_is_not_bool(self):
+ """
+ If the parameter `block` is not a boolean, an error is returned
+ """
+ body = json.dumps({"block": "NotBool"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+ def test_purge_room_and_block(self):
+ """Test to purge a room and block it.
+ Members will not be moved to a new room and will not receive a message.
+ """
+ # Test that room is not purged
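+ # (_is_purged asserts that no rows remain, so an AssertionError here
+ # means the room has not been purged yet)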
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Test that room is not blocked
+ self._is_blocked(self.room_id, expect=False)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ body = json.dumps({"block": True})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url.encode("ascii"),
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
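+ # No new_room_user_id was supplied, so members are kicked without being
+ # moved to a new room and new_room_id is null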
+ self.assertEqual(None, channel.json_body["new_room_id"])
+ self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+ self.assertIn("failed_to_kick_users", channel.json_body)
+ self.assertIn("local_aliases", channel.json_body)
+
+ self._is_purged(self.room_id)
+ self._is_blocked(self.room_id, expect=True)
+ self._has_no_members(self.room_id)
+
+ def test_purge_room_and_not_block(self):
+ """Test to purge a room and do not block it.
+ Members will not be moved to a new room and will not receive a message.
+ """
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Test that room is not blocked
+ self._is_blocked(self.room_id, expect=False)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ body = json.dumps({"block": False})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url.encode("ascii"),
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(None, channel.json_body["new_room_id"])
+ self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+ self.assertIn("failed_to_kick_users", channel.json_body)
+ self.assertIn("local_aliases", channel.json_body)
+
+ self._is_purged(self.room_id)
+ self._is_blocked(self.room_id, expect=False)
+ self._has_no_members(self.room_id)
+
+ def test_shutdown_room_consent(self):
+ """Test that we can shutdown rooms with local users who have not
+ yet accepted the privacy policy. This used to fail when we tried to
+ force part the user from the old room.
+ Members will be moved to a new room and will receive a message.
+ """
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ # Assert one user in room
+ users_in_room = self.get_success(self.store.get_users_in_room(self.room_id))
+ self.assertEqual([self.other_user], users_in_room)
+
+ # Require consent before events can be sent
+ self.event_creation_handler._block_events_without_consent_error = "Error"
+
+ # Assert that the user now gets a consent error
+ self.helper.send(
+ self.room_id, body="foo", tok=self.other_user_tok, expect_code=403
+ )
+
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ # Test that the admin can still shut down the room
+ url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ json.dumps({"new_room_user_id": self.admin_user}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+ self.assertIn("new_room_id", channel.json_body)
+ self.assertIn("failed_to_kick_users", channel.json_body)
+ self.assertIn("local_aliases", channel.json_body)
+
+ # Test that the member has moved to the new room
+ self._is_member(
+ room_id=channel.json_body["new_room_id"], user_id=self.other_user
+ )
+
+ self._is_purged(self.room_id)
+ self._has_no_members(self.room_id)
+
+ def test_shutdown_room_block_peek(self):
+ """Test that a world_readable room can no longer be peeked into after
+ it has been shut down.
+ Members will be moved to a new room and will receive a message.
+ """
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ # Make the room world-readable
+ url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,)
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ json.dumps({"history_visibility": "world_readable"}),
+ access_token=self.other_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ # Test that the admin can still shut down the room
+ url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ json.dumps({"new_room_user_id": self.admin_user}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+ self.assertIn("new_room_id", channel.json_body)
+ self.assertIn("failed_to_kick_users", channel.json_body)
+ self.assertIn("local_aliases", channel.json_body)
+
+ # Test that the member has moved to the new room
+ self._is_member(
+ room_id=channel.json_body["new_room_id"], user_id=self.other_user
+ )
+
+ self._is_purged(self.room_id)
+ self._has_no_members(self.room_id)
+
+ # Assert we can no longer peek into the room
+ self._assert_peek(self.room_id, expect_code=403)
+
+ def _is_blocked(self, room_id, expect=True):
+ """Assert that the room is blocked or not
+ """
+ d = self.store.is_room_blocked(room_id)
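+ # is_room_blocked returns a truthy value for blocked rooms and None
+ # otherwise, hence the assertIsNone below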
+ if expect:
+ self.assertTrue(self.get_success(d))
+ else:
+ self.assertIsNone(self.get_success(d))
+
+ def _has_no_members(self, room_id):
+ """Assert there is now no longer anyone in the room
+ """
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual([], users_in_room)
+
+ def _is_member(self, room_id, user_id):
+ """Test that user is member of the room
+ """
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertIn(user_id, users_in_room)
+
+ def _is_purged(self, room_id):
+ """Test that the following tables have been purged of all rows related to the room.
+ """
+ for table in (
+ "current_state_events",
+ "event_backward_extremities",
+ "event_forward_extremities",
+ "event_json",
+ "event_push_actions",
+ "event_search",
+ "events",
+ "group_rooms",
+ "public_room_list_stream",
+ "receipts_graph",
+ "receipts_linearized",
+ "room_aliases",
+ "room_depth",
+ "room_memberships",
+ "room_stats_state",
+ "room_stats_current",
+ "room_stats_historical",
+ "room_stats_earliest_token",
+ "rooms",
+ "stream_ordering_to_exterm",
+ "users_in_public_rooms",
+ "users_who_share_private_rooms",
+ "appservice_room_list",
+ "e2e_room_keys",
+ "event_push_summary",
+ "pusher_throttle",
+ "group_summary_rooms",
+ "local_invites",
+ "room_account_data",
+ "room_tags",
+ # "state_groups", # Current impl leaves orphaned state groups around.
+ "state_groups_state",
+ ):
+ count = self.get_success(
+ self.store.db.simple_select_one_onecol(
+ table=table,
+ keyvalues={"room_id": room_id},
+ retcol="COUNT(*)",
+ desc="test_purge_room",
+ )
+ )
+
+ self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+
+ def _assert_peek(self, room_id, expect_code):
+ """Assert that the admin user can (or cannot) peek into the room.
+ """
+
+ url = "rooms/%s/initialSync" % (room_id,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ url = "events?timeout=0&room_id=" + room_id
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+
+class PurgeRoomTestCase(unittest.HomeserverTestCase):
+ """Test /purge_room admin API.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ def test_purge_room(self):
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # All users have to have left the room.
+ self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
+
+ url = "/_synapse/admin/v1/purge_room"
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Test that the following tables have been purged of all rows related to the room.
+ for table in (
+ "current_state_events",
+ "event_backward_extremities",
+ "event_forward_extremities",
+ "event_json",
+ "event_push_actions",
+ "event_search",
+ "events",
+ "group_rooms",
+ "public_room_list_stream",
+ "receipts_graph",
+ "receipts_linearized",
+ "room_aliases",
+ "room_depth",
+ "room_memberships",
+ "room_stats_state",
+ "room_stats_current",
+ "room_stats_historical",
+ "room_stats_earliest_token",
+ "rooms",
+ "stream_ordering_to_exterm",
+ "users_in_public_rooms",
+ "users_who_share_private_rooms",
+ "appservice_room_list",
+ "e2e_room_keys",
+ "event_push_summary",
+ "pusher_throttle",
+ "group_summary_rooms",
+ "room_account_data",
+ "room_tags",
+ # "state_groups", # Current impl leaves orphaned state groups around.
+ "state_groups_state",
+ ):
+ count = self.get_success(
+ self.store.db.simple_select_one_onecol(
+ table=table,
+ keyvalues={"room_id": room_id},
+ retcol="COUNT(*)",
+ desc="test_purge_room",
+ )
+ )
+
+ self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+
+
+class RoomTestCase(unittest.HomeserverTestCase):
+ """Test /room admin API.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ directory.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ # Create user
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ def test_list_rooms(self):
+ """Test that we can list rooms"""
+ # Create 3 test rooms
+ total_rooms = 3
+ room_ids = []
+ for x in range(total_rooms):
+ room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+ room_ids.append(room_id)
+
+ # Request the list of rooms
+ url = "/_synapse/admin/v1/rooms"
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ # Check request completed successfully
+ self.assertEqual(200, int(channel.code), msg=channel.json_body)
+
+ # Check that response json body contains a "rooms" key
+ self.assertTrue(
+ "rooms" in channel.json_body,
+ msg="Response body does not " "contain a 'rooms' key",
+ )
+
+ # Check that 3 rooms were returned
+ self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body)
+
+ # Check their room_ids match
+ returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]]
+ self.assertEqual(room_ids, returned_room_ids)
+
+ # Check that all fields are available
+ for r in channel.json_body["rooms"]:
+ self.assertIn("name", r)
+ self.assertIn("canonical_alias", r)
+ self.assertIn("joined_members", r)
+ self.assertIn("joined_local_members", r)
+ self.assertIn("version", r)
+ self.assertIn("creator", r)
+ self.assertIn("encryption", r)
+ self.assertIn("federatable", r)
+ self.assertIn("public", r)
+ self.assertIn("join_rules", r)
+ self.assertIn("guest_access", r)
+ self.assertIn("history_visibility", r)
+ self.assertIn("state_events", r)
+
+ # Check that the correct number of total rooms was returned
+ self.assertEqual(channel.json_body["total_rooms"], total_rooms)
+
+ # Check that the offset is correct
+ # Should be 0 as we aren't paginating
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that the prev_batch parameter is not present
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # We shouldn't receive a next token here as there are no further rooms to show
+ self.assertNotIn("next_batch", channel.json_body)
+
+ def test_list_rooms_pagination(self):
+ """Test that we can get a full list of rooms through pagination"""
+ # Create 5 test rooms
+ total_rooms = 5
+ room_ids = []
+ for x in range(total_rooms):
+ room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+ room_ids.append(room_id)
+
+ # Set the name of the rooms so we get a consistent returned ordering
+ for idx, room_id in enumerate(room_ids):
+ self.helper.send_state(
+ room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok,
+ )
+
+ # Request the list of rooms
+ returned_room_ids = []
+ start = 0
+ limit = 2
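+ # Page through the full list two rooms at a time, following the
+ # next_batch token until the server stops returning one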
+
+ run_count = 0
+ should_repeat = True
+ while should_repeat:
+ run_count += 1
+
+ url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % (
+ start,
+ limit,
+ "name",
+ )
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(
+ 200, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ self.assertTrue("rooms" in channel.json_body)
+ for r in channel.json_body["rooms"]:
+ returned_room_ids.append(r["room_id"])
+
+ # Check that the correct number of total rooms was returned
+ self.assertEqual(channel.json_body["total_rooms"], total_rooms)
+
+ # Check that the offset is correct
+ # We're getting 2 rooms per page, so the offset should be 2 * (run_count - 1)
+ self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1))
+
+ if run_count > 1:
+ # Check the value of prev_batch is correct
+ self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2))
+
+ if "next_batch" not in channel.json_body:
+ # We have reached the end of the list
+ should_repeat = False
+ else:
+ # Make another query with an updated start value
+ start = channel.json_body["next_batch"]
+
+ # We should've queried the endpoint 3 times
+ self.assertEqual(
+ run_count,
+ 3,
+ msg="Should've queried 3 times for 5 rooms with limit 2 per query",
+ )
+
+ # Check that we received all of the room ids
+ self.assertEqual(room_ids, returned_room_ids)
+
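+ # Sanity check: the same from/limit query without an explicit order_by
+ # should also succeed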
+ url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_correct_room_attributes(self):
+ """Test the correct attributes for a room are returned"""
+ # Create a test room
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ test_alias = "#test:test"
+ test_room_name = "something"
+
+ # Have another user join the room
+ user_2 = self.register_user("user4", "pass")
+ user_tok_2 = self.login("user4", "pass")
+ self.helper.join(room_id, user_2, tok=user_tok_2)
+
+ # Create a new alias to this room
+ url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Set this new alias as the canonical alias for this room
+ self.helper.send_state(
+ room_id,
+ "m.room.aliases",
+ {"aliases": [test_alias]},
+ tok=self.admin_user_tok,
+ state_key="test",
+ )
+ self.helper.send_state(
+ room_id,
+ "m.room.canonical_alias",
+ {"alias": test_alias},
+ tok=self.admin_user_tok,
+ )
+
+ # Set a name for the room
+ self.helper.send_state(
+ room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok,
+ )
+
+ # Request the list of rooms
+ url = "/_synapse/admin/v1/rooms"
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Check that rooms were returned
+ self.assertTrue("rooms" in channel.json_body)
+ rooms = channel.json_body["rooms"]
+
+ # Check that only one room was returned
+ self.assertEqual(len(rooms), 1)
+
+ # And that the value of the total_rooms key was correct
+ self.assertEqual(channel.json_body["total_rooms"], 1)
+
+ # Check that the offset is correct
+ # We're not paginating, so should be 0
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that there is no `prev_batch`
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # Check that there is no `next_batch`
+ self.assertNotIn("next_batch", channel.json_body)
+
+ # Check that all provided attributes are set
+ r = rooms[0]
+ self.assertEqual(room_id, r["room_id"])
+ self.assertEqual(test_room_name, r["name"])
+ self.assertEqual(test_alias, r["canonical_alias"])
+
+ def test_room_list_sort_order(self):
+ """Test room list sort ordering. alphabetical name versus number of members,
+ reversing the order, etc.
+ """
+
+ def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str):
+ # Create a new alias to this room
+ url = "/_matrix/client/r0/directory/room/%s" % (
+ urllib.parse.quote(test_alias),
+ )
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(
+ 200, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ # Set this new alias as the canonical alias for this room
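+ # (m.room.aliases is per-server state; its state_key is the publishing
+ # server's name, which is "test" on this homeserver)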
+ self.helper.send_state(
+ room_id,
+ "m.room.aliases",
+ {"aliases": [test_alias]},
+ tok=admin_user_tok,
+ state_key="test",
+ )
+ self.helper.send_state(
+ room_id,
+ "m.room.canonical_alias",
+ {"alias": test_alias},
+ tok=admin_user_tok,
+ )
+
+ def _order_test(
+ order_type: str, expected_room_list: List[str], reverse: bool = False,
+ ):
+ """Request the list of rooms in a certain order. Assert that order is what
+ we expect
+
+ Args:
+ order_type: The type of ordering to give the server
+ expected_room_list: The list of room_ids in the order we expect to get
+ back from the server
+ reverse: Whether to request the list in reverse (dir=b) order
+ """
+ # Request the list of rooms in the given order
+ url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,)
+ if reverse:
+ url += "&dir=b"
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Check that rooms were returned
+ self.assertTrue("rooms" in channel.json_body)
+ rooms = channel.json_body["rooms"]
+
+ # Check for the correct total_rooms value
+ self.assertEqual(channel.json_body["total_rooms"], 3)
+
+ # Check that the offset is correct
+ # We're not paginating, so should be 0
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that there is no `prev_batch`
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # Check that there is no `next_batch`
+ self.assertNotIn("next_batch", channel.json_body)
+
+ # Check that rooms were returned in the expected order
+ returned_order = [r["room_id"] for r in rooms]
+ self.assertListEqual(expected_room_list, returned_order) # order is checked
+
+ # Create 3 test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # Set room names in alphabetical order: room 1 -> A, 2 -> B, 3 -> C
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok,
+ )
+
+ # Set the rooms' canonical aliases
+ _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok)
+ _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok)
+ _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok)
+
+ # Set member counts: room 1 -> 1 member, 2 -> 2, 3 -> 3 (so ordering by
+ # members reverses the ordering by name)
+ user_1 = self.register_user("bob1", "pass")
+ user_1_tok = self.login("bob1", "pass")
+ self.helper.join(room_id_2, user_1, tok=user_1_tok)
+
+ user_2 = self.register_user("bob2", "pass")
+ user_2_tok = self.login("bob2", "pass")
+ self.helper.join(room_id_3, user_2, tok=user_2_tok)
+
+ user_3 = self.register_user("bob3", "pass")
+ user_3_tok = self.login("bob3", "pass")
+ self.helper.join(room_id_3, user_3, tok=user_3_tok)
+
+ # Test different sort orders, with forward and reverse directions
+ _order_test("name", [room_id_1, room_id_2, room_id_3])
+ _order_test("name", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+ _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3])
+ _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+ _order_test("joined_members", [room_id_3, room_id_2, room_id_1])
+ _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1])
+ _order_test(
+ "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True
+ )
+
+ _order_test("version", [room_id_1, room_id_2, room_id_3])
+ _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("creator", [room_id_1, room_id_2, room_id_3])
+ _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("encryption", [room_id_1, room_id_2, room_id_3])
+ _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("federatable", [room_id_1, room_id_2, room_id_3])
+ _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("public", [room_id_1, room_id_2, room_id_3])
+ # SQLite and PostgreSQL sort this column differently, so the reverse direction is not asserted
+ # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+ _order_test("join_rules", [room_id_1, room_id_2, room_id_3])
+ _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("guest_access", [room_id_1, room_id_2, room_id_3])
+ _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("history_visibility", [room_id_1, room_id_2, room_id_3])
+ _order_test(
+ "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True
+ )
+
+ _order_test("state_events", [room_id_3, room_id_2, room_id_1])
+ _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ def test_search_term(self):
+ """Test that searching for a room works correctly"""
+ # Create two test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ room_name_1 = "something"
+ room_name_2 = "else"
+
+ # Set the name for each room
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+ )
+
+ def _search_test(
+ expected_room_id: Optional[str],
+ search_term: str,
+ expected_http_code: int = 200,
+ ):
+ """Search for a room and check that the returned room's id is a match
+
+ Args:
+ expected_room_id: The room_id expected to be returned by the API. Set
+ to None to expect zero results for the search
+ search_term: The term to search for room names with
+ expected_http_code: The expected http code for the request
+ """
+ url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
+
+ if expected_http_code != 200:
+ return
+
+ # Check that rooms were returned
+ self.assertTrue("rooms" in channel.json_body)
+ rooms = channel.json_body["rooms"]
+
+ # Check that the expected number of rooms were returned
+ expected_room_count = 1 if expected_room_id else 0
+ self.assertEqual(len(rooms), expected_room_count)
+ self.assertEqual(channel.json_body["total_rooms"], expected_room_count)
+
+ # Check that the offset is correct
+ # We're not paginating, so should be 0
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that there is no `prev_batch`
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # Check that there is no `next_batch`
+ self.assertNotIn("next_batch", channel.json_body)
+
+ if expected_room_id:
+ # Check that the first returned room id is correct
+ r = rooms[0]
+ self.assertEqual(expected_room_id, r["room_id"])
+
+ # Perform search tests
+ _search_test(room_id_1, "something")
+ _search_test(room_id_1, "thing")
+
+ _search_test(room_id_2, "else")
+ _search_test(room_id_2, "se")
+
+ _search_test(None, "foo")
+ _search_test(None, "bar")
+ _search_test(None, "", expected_http_code=400)
+
+ def test_single_room(self):
+ """Test that a single room can be requested correctly"""
+ # Create two test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ room_name_1 = "something"
+ room_name_2 = "else"
+
+ # Set the name for each room
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+ )
+
+ url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ self.assertIn("room_id", channel.json_body)
+ self.assertIn("name", channel.json_body)
+ self.assertIn("canonical_alias", channel.json_body)
+ self.assertIn("joined_members", channel.json_body)
+ self.assertIn("joined_local_members", channel.json_body)
+ self.assertIn("version", channel.json_body)
+ self.assertIn("creator", channel.json_body)
+ self.assertIn("encryption", channel.json_body)
+ self.assertIn("federatable", channel.json_body)
+ self.assertIn("public", channel.json_body)
+ self.assertIn("join_rules", channel.json_body)
+ self.assertIn("guest_access", channel.json_body)
+ self.assertIn("history_visibility", channel.json_body)
+ self.assertIn("state_events", channel.json_body)
+
+ self.assertEqual(room_id_1, channel.json_body["room_id"])
+
+ def test_room_members(self):
+ """Test that room members can be requested correctly"""
+ # Create two test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # Have another user join the room
+ user_1 = self.register_user("foo", "pass")
+ user_tok_1 = self.login("foo", "pass")
+ self.helper.join(room_id_1, user_1, tok=user_tok_1)
+
+ # Have another user join the room
+ user_2 = self.register_user("bar", "pass")
+ user_tok_2 = self.login("bar", "pass")
+ self.helper.join(room_id_1, user_2, tok=user_tok_2)
+ self.helper.join(room_id_2, user_2, tok=user_tok_2)
+
+ # Have another user join the room
+ user_3 = self.register_user("foobar", "pass")
+ user_tok_3 = self.login("foobar", "pass")
+ self.helper.join(room_id_2, user_3, tok=user_tok_3)
+
+ url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
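+ # assertCountEqual checks the members match regardless of ordering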
+ self.assertCountEqual(
+ ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"]
+ )
+ self.assertEqual(channel.json_body["total"], 3)
+
+ url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ self.assertCountEqual(
+ ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"]
+ )
+ self.assertEqual(channel.json_body["total"], 3)
+
+
+class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.creator = self.register_user("creator", "test")
+ self.creator_tok = self.login("creator", "test")
+
+ self.second_user_id = self.register_user("second", "test")
+ self.second_tok = self.login("second", "test")
+
+ self.public_room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=True
+ )
+ self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id)
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, a 403 error is returned.
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.second_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_invalid_parameter(self):
+ """
+ If a required parameter is missing, an error is returned
+ """
+ body = json.dumps({"unknown_parameter": "@unknown:test"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+ def test_local_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ body = json.dumps({"user_id": "@unknown:test"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_remote_user(self):
+ """
+ Check that only local users can be joined to rooms.
+ """
+ body = json.dumps({"user_id": "@not:exist.bla"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ "This endpoint can only be used with local users",
+ channel.json_body["error"],
+ )
+
+ def test_room_does_not_exist(self):
+ """
+ Check that unknown rooms/servers return a 404 error.
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+ url = "/_synapse/admin/v1/join/!unknown:test"
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("No known servers", channel.json_body["error"])
+
+ def test_room_is_not_valid(self):
+ """
+ Check that invalid room names return a 400 error.
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+ url = "/_synapse/admin/v1/join/invalidroom"
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ "invalidroom was not legal room ID or room alias",
+ channel.json_body["error"],
+ )
+
+ def test_join_public_room(self):
+ """
+ Test joining a local user to a public room with "JoinRules.PUBLIC"
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(self.public_room_id, channel.json_body["room_id"])
+
+ # Validate that the user is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
+
+ def test_join_private_room_if_not_member(self):
+ """
+ Test joining a local user to a private room with "JoinRules.INVITE"
+ when the server admin is not a member of the room.
+ """
+ private_room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=False
+ )
+ url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_join_private_room_if_member(self):
+ """
+ Test joining a local user to a private room with "JoinRules.INVITE",
+ when the server admin is a member of the room.
+ """
+ private_room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=False
+ )
+ self.helper.invite(
+ room=private_room_id,
+ src=self.creator,
+ targ=self.admin_user,
+ tok=self.creator_tok,
+ )
+ self.helper.join(
+ room=private_room_id, user=self.admin_user, tok=self.admin_user_tok
+ )
+
+ # Validate that the server admin is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+
+ # Join user to room.
+
+ url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["room_id"])
+
+ # Validate that the user is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+
+ def test_join_private_room_if_owner(self):
+ """
+ Test joining a local user to a private room with "JoinRules.INVITE",
+ when the server admin is the owner of the room.
+ """
+ private_room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok, is_public=False
+ )
+ url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["room_id"])
+
+ # Validate that the user is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index cca5f548..f16eef15 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -857,6 +857,53 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
+ def test_reactivate_user(self):
+ """
+ Test reactivating another user.
+ """
+
+ # Deactivate the user.
+ request, channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content=json.dumps({"deactivated": True}).encode(encoding="utf_8"),
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Attempt to reactivate the user (without a password).
+ request, channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content=json.dumps({"deactivated": False}).encode(encoding="utf_8"),
+ )
+ self.render(request)
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
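+ # Reactivation requires a new password, since deactivation removed the old one.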
+
+ # Reactivate the user.
+ request, channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content=json.dumps({"deactivated": False, "password": "foo"}).encode(
+ encoding="utf_8"
+ ),
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_other_user, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(False, channel.json_body["deactivated"])
+
def test_set_user_as_admin(self):
"""
Test setting the admin flag on a user.
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index fd979999..db52725c 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -398,7 +398,7 @@ class CASTestCase(unittest.HomeserverTestCase):
</cas:serviceResponse>
"""
% cas_user_id
- )
+ ).encode("utf-8")
mocked_http_client = Mock(spec=["get_raw"])
mocked_http_client.get_raw.side_effect = get_raw
@@ -514,16 +514,17 @@ class JWTTestCase(unittest.HomeserverTestCase):
]
jwt_secret = "secret"
+ jwt_algorithm = "HS256"
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
self.hs.config.jwt_enabled = True
self.hs.config.jwt_secret = self.jwt_secret
- self.hs.config.jwt_algorithm = "HS256"
+ self.hs.config.jwt_algorithm = self.jwt_algorithm
return self.hs
def jwt_encode(self, token, secret=jwt_secret):
- return jwt.encode(token, secret, "HS256").decode("ascii")
+ return jwt.encode(token, secret, self.jwt_algorithm).decode("ascii")
def jwt_login(self, *args):
params = json.dumps(
@@ -546,35 +547,126 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_invalid_signature(self):
channel = self.jwt_login({"sub": "frog"}, "notsecret")
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
- self.assertEqual(channel.json_body["error"], "Invalid JWT")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"],
+ "JWT validation failed: Signature verification failed",
+ )
def test_login_jwt_expired(self):
channel = self.jwt_login({"sub": "frog", "exp": 864000})
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
- self.assertEqual(channel.json_body["error"], "JWT expired")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"], "JWT validation failed: Signature has expired"
+ )
def test_login_jwt_not_before(self):
now = int(time.time())
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
- self.assertEqual(channel.json_body["error"], "Invalid JWT")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"],
+ "JWT validation failed: The token is not yet valid (nbf)",
+ )
def test_login_no_sub(self):
channel = self.jwt_login({"username": "root"})
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Invalid JWT")
+ @override_config(
+ {
+ "jwt_config": {
+ "jwt_enabled": True,
+ "secret": jwt_secret,
+ "algorithm": jwt_algorithm,
+ "issuer": "test-issuer",
+ }
+ }
+ )
+ def test_login_iss(self):
+ """Test validating the issuer claim."""
+ # A valid issuer.
+ channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ # An invalid issuer.
+ channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"], "JWT validation failed: Invalid issuer"
+ )
+
+ # Not providing an issuer.
+ channel = self.jwt_login({"sub": "kermit"})
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"],
+ 'JWT validation failed: Token is missing the "iss" claim',
+ )
+
+ def test_login_iss_no_config(self):
+ """Test providing an issuer claim without requiring it in the configuration."""
+ channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ @override_config(
+ {
+ "jwt_config": {
+ "jwt_enabled": True,
+ "secret": jwt_secret,
+ "algorithm": jwt_algorithm,
+ "audiences": ["test-audience"],
+ }
+ }
+ )
+ def test_login_aud(self):
+ """Test validating the audience claim."""
+ # A valid audience.
+ channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ # An invalid audience.
+ channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"], "JWT validation failed: Invalid audience"
+ )
+
+ # Not providing an audience.
+ channel = self.jwt_login({"sub": "kermit"})
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"],
+ 'JWT validation failed: Token is missing the "aud" claim',
+ )
+
+ def test_login_aud_no_config(self):
+ """Test providing an audience without requiring it in the configuration."""
+ channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"], "JWT validation failed: Invalid audience"
+ )
+
def test_login_no_token(self):
params = json.dumps({"type": "org.matrix.login.jwt"})
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
@@ -656,6 +748,9 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
def test_login_jwt_invalid_signature(self):
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
- self.assertEqual(channel.json_body["error"], "Invalid JWT")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"],
+ "JWT validation failed: Signature verification failed",
+ )
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 66fa5978..f4f3e567 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -26,6 +26,7 @@ import attr
from parameterized import parameterized_class
from PIL import Image as Image
+from twisted.internet import defer
from twisted.internet.defer import Deferred
from synapse.logging.context import make_deferred_yieldable
@@ -77,7 +78,9 @@ class MediaStorageTests(unittest.HomeserverTestCase):
# This uses a real blocking threadpool so we have to wait for it to be
# actually done :/
- x = self.media_storage.ensure_media_is_in_local_cache(file_info)
+ x = defer.ensureDeferred(
+ self.media_storage.ensure_media_is_in_local_cache(file_info)
+ )
# Hotloop until the threadpool does its job...
self.wait_on_thread(x)
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 2826211f..74765a58 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -12,8 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import json
import os
+import re
+
+from mock import patch
import attr
@@ -131,7 +134,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.reactor.nameResolver = Resolver()
def test_cache_returns_correct_type(self):
- self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
request, channel = self.make_request(
"GET", "url_preview?url=http://matrix.org", shorthand=False
@@ -187,7 +190,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
def test_non_ascii_preview_httpequiv(self):
- self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = (
b"<html><head>"
@@ -221,7 +224,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
def test_non_ascii_preview_content_type(self):
- self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = (
b"<html><head>"
@@ -254,7 +257,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
def test_overlong_title(self):
- self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = (
b"<html><head>"
@@ -292,7 +295,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
IP addresses can be previewed directly.
"""
- self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")]
+ self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
request, channel = self.make_request(
"GET", "url_preview?url=http://example.com", shorthand=False
@@ -439,7 +442,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# Hardcode the URL resolving to the IP we want.
self.lookups["example.com"] = [
(IPv4Address, "1.1.1.2"),
- (IPv4Address, "8.8.8.8"),
+ (IPv4Address, "10.1.2.3"),
]
request, channel = self.make_request(
@@ -518,7 +521,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
Accept-Language header is sent to the remote server
"""
- self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")]
+ self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
# Build and make a request to the server
request, channel = self.make_request(
@@ -562,3 +565,126 @@ class URLPreviewTests(unittest.HomeserverTestCase):
),
server.data,
)
+
+ def test_oembed_photo(self):
+ """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL."""
+ # Map the URL to a plain-HTTP oEmbed endpoint so the test's fake connections can serve it.
+ with patch.dict(
+ "synapse.rest.media.v1.preview_url_resource._oembed_patterns",
+ {
+ re.compile(
+ r"http://twitter\.com/.+/status/.+"
+ ): "http://publish.twitter.com/oembed",
+ },
+ clear=True,
+ ):
+
+ self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+ self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+ result = {
+ "version": "1.0",
+ "type": "photo",
+ "url": "http://cdn.twitter.com/matrixdotorg",
+ }
+ oembed_content = json.dumps(result).encode("utf-8")
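+ # The "photo" type tells the previewer to fetch the "url" above and preview that instead.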
+
+ end_content = (
+ b"<html><head>"
+ b"<title>Some Title</title>"
+ b'<meta property="og:description" content="hi" />'
+ b"</head></html>"
+ )
+
+ request, channel = self.make_request(
+ "GET",
+ "url_preview?url=http://twitter.com/matrixdotorg/status/12345",
+ shorthand=False,
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: application/json; charset="utf8"\r\n\r\n'
+ )
+ % (len(oembed_content),)
+ + oembed_content
+ )
+
+ self.pump()
+
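+ # Second connection: the previewer follows the photo URL from the oEmbed response.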
+ client = self.reactor.tcpClients[1][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ )
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body, {"og:title": "Some Title", "og:description": "hi"}
+ )
+
+ def test_oembed_rich(self):
+ """Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
+ # Map the URL to a plain-HTTP oEmbed endpoint so the test's fake connections can serve it.
+ with patch.dict(
+ "synapse.rest.media.v1.preview_url_resource._oembed_patterns",
+ {
+ re.compile(
+ r"http://twitter\.com/.+/status/.+"
+ ): "http://publish.twitter.com/oembed",
+ },
+ clear=True,
+ ):
+
+ self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+ result = {
+ "version": "1.0",
+ "type": "rich",
+ "html": "<div>Content Preview</div>",
+ }
+ end_content = json.dumps(result).encode("utf-8")
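+ # For the "rich" type, the preview is built from the returned HTML snippet.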
+
+ request, channel = self.make_request(
+ "GET",
+ "url_preview?url=http://twitter.com/matrixdotorg/status/12345",
+ shorthand=False,
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: application/json; charset="utf8"\r\n\r\n'
+ )
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body,
+ {"og:title": None, "og:description": "Content Preview"},
+ )
diff --git a/tests/server.py b/tests/server.py
index a5e57c52..b6e0b14e 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -237,6 +237,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def __init__(self):
self.threadpool = ThreadPool(self)
+ self._tcp_callbacks = {}
self._udp = []
lookups = self.lookups = {}
@@ -268,6 +269,29 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def getThreadPool(self):
return self.threadpool
+ def add_tcp_client_callback(self, host, port, callback):
+ """Add a callback that will be invoked when we receive a connection
+ attempt to the given IP/port using `connectTCP`.
+
+ Note that the callback gets run before we return the connection to the
+ client, which means callbacks cannot block while waiting for writes.
+ """
+ self._tcp_callbacks[(host, port)] = callback
+
+ def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
+ """Fake L{IReactorTCP.connectTCP}.
+ """
+
+ conn = super().connectTCP(
+ host, port, factory, timeout=timeout, bindAddress=bindAddress
+ )
+
+ callback = self._tcp_callbacks.get((host, port))
+ if callback:
+ callback()
+
+ return conn
+
class ThreadPool:
"""
@@ -486,7 +510,7 @@ class FakeTransport(object):
try:
self.other.dataReceived(to_write)
except Exception as e:
- logger.warning("Exception writing to protocol: %s", e)
+ logger.exception("Exception writing to protocol: %s", e)
return
self.buffer = self.buffer[len(to_write) :]
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 38f9b423..f2955a9c 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -14,6 +14,7 @@
# limitations under the License.
import itertools
+from typing import List
import attr
@@ -432,7 +433,7 @@ class StateTestCase(unittest.TestCase):
state_res_store=TestStateResolutionStore(event_map),
)
- state_before = self.successResultOf(state_d)
+ state_before = self.successResultOf(defer.ensureDeferred(state_d))
state_after = dict(state_before)
if fake_event.state_key is not None:
@@ -581,7 +582,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
state_res_store=TestStateResolutionStore(self.event_map),
)
- state = self.successResultOf(state_d)
+ state = self.successResultOf(defer.ensureDeferred(state_d))
self.assert_dict(self.expected_combined_state, state)
@@ -608,9 +609,11 @@ class TestStateResolutionStore(object):
Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
"""
- return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
+ return defer.succeed(
+ {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
+ )
- def _get_auth_chain(self, event_ids):
+ def _get_auth_chain(self, event_ids: List[str]) -> List[str]:
"""Gets the full auth chain for a set of events (including rejected
events).
@@ -622,10 +625,10 @@ class TestStateResolutionStore(object):
presence of rejected events
Args:
- event_ids (list): The event IDs of the events to fetch the auth
+ event_ids: The event IDs of the events to fetch the auth
chain for. Must be state events.
Returns:
- Deferred[list[str]]: List of event IDs of the auth chain.
+ List of event IDs of the auth chain.
"""
# Simple DFS for auth chain
@@ -648,4 +651,4 @@ class TestStateResolutionStore(object):
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
common = set(chains[0]).intersection(*chains[1:])
- return set(chains[0]).union(*chains[1:]) - common
+ return defer.succeed(set(chains[0]).union(*chains[1:]) - common)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 3b78d488..1d77b4a2 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -56,6 +56,10 @@ class RoomStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
+ def test_get_room_unknown_room(self):
+ self.assertIsNone((yield self.store.get_room("!uknown:test")),)
+
+ @defer.inlineCallbacks
def test_get_room_with_stats(self):
self.assertDictContainsSubset(
{
@@ -66,6 +70,10 @@ class RoomStoreTestCase(unittest.TestCase):
(yield self.store.get_room_with_stats(self.room.to_string())),
)
+ @defer.inlineCallbacks
+ def test_get_room_with_stats_unknown_room(self):
+ self.assertIsNone((yield self.store.get_room_with_stats("!unknown:test")))
+
class RoomEventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
@@ -101,7 +109,9 @@ class RoomEventsStoreTestCase(unittest.TestCase):
etype=EventTypes.Name, name=name, content={"name": name}, depth=1
)
- state = yield self.store.get_current_state(room_id=self.room.to_string())
+ state = yield defer.ensureDeferred(
+ self.store.get_current_state(room_id=self.room.to_string())
+ )
self.assertEquals(1, len(state))
self.assertObjectHasAttributes(
@@ -117,7 +127,9 @@ class RoomEventsStoreTestCase(unittest.TestCase):
etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
)
- state = yield self.store.get_current_state(room_id=self.room.to_string())
+ state = yield defer.ensureDeferred(
+ self.store.get_current_state(room_id=self.room.to_string())
+ )
self.assertEquals(1, len(state))
self.assertObjectHasAttributes(
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 5dd46005..f2829215 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -118,18 +118,22 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
def test_get_joined_users_from_context(self):
room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
- bob_event = event_injection.inject_member_event(
- self.hs, room, self.u_bob, Membership.JOIN
+ bob_event = self.get_success(
+ event_injection.inject_member_event(
+ self.hs, room, self.u_bob, Membership.JOIN
+ )
)
# first, create a regular event
- event, context = event_injection.create_event(
- self.hs,
- room_id=room,
- sender=self.u_alice,
- prev_event_ids=[bob_event.event_id],
- type="m.test.1",
- content={},
+ event, context = self.get_success(
+ event_injection.create_event(
+ self.hs,
+ room_id=room,
+ sender=self.u_alice,
+ prev_event_ids=[bob_event.event_id],
+ type="m.test.1",
+ content={},
+ )
)
users = self.get_success(
@@ -140,22 +144,26 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# Regression test for #7376: create a state event whose key matches bob's
# user_id, but which is *not* a membership event, and persist that; then check
# that `get_joined_users_from_context` returns the correct users for the next event.
- non_member_event = event_injection.inject_event(
- self.hs,
- room_id=room,
- sender=self.u_bob,
- prev_event_ids=[bob_event.event_id],
- type="m.test.2",
- state_key=self.u_bob,
- content={},
+ non_member_event = self.get_success(
+ event_injection.inject_event(
+ self.hs,
+ room_id=room,
+ sender=self.u_bob,
+ prev_event_ids=[bob_event.event_id],
+ type="m.test.2",
+ state_key=self.u_bob,
+ content={},
+ )
)
- event, context = event_injection.create_event(
- self.hs,
- room_id=room,
- sender=self.u_alice,
- prev_event_ids=[non_member_event.event_id],
- type="m.test.3",
- content={},
+ event, context = self.get_success(
+ event_injection.create_event(
+ self.hs,
+ room_id=room,
+ sender=self.u_alice,
+ prev_event_ids=[non_member_event.event_id],
+ type="m.test.3",
+ content={},
+ )
)
users = self.get_success(
self.store.get_joined_users_from_context(event, context)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 0b88308f..a0e133cd 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -64,8 +64,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
},
)
- event, context = yield self.event_creation_handler.create_new_client_event(
- builder
+ event, context = yield defer.ensureDeferred(
+ self.event_creation_handler.create_new_client_event(builder)
)
yield self.storage.persistence.persist_event(event, context)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 89dcc58b..87a16d7d 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -173,7 +173,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register a mock on the store so that the incoming update doesn't fail because
# we don't share a room with the user.
store = self.homeserver.get_datastore()
- store.get_rooms_for_user = Mock(return_value=["!someroom:test"])
+ store.get_rooms_for_user = Mock(return_value=succeed(["!someroom:test"]))
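+ # Wrapping the value in succeed() keeps the mock usable by code that awaits it.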
# Manually inject a fake device list update. We need this update to include at
# least one prev_id so that the user's device list will need to be retried.
@@ -218,23 +218,26 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register mock device list retrieval on the federation client.
federation_client = self.homeserver.get_federation_client()
federation_client.query_user_devices = Mock(
- return_value={
- "user_id": remote_user_id,
- "stream_id": 1,
- "devices": [],
- "master_key": {
+ return_value=succeed(
+ {
"user_id": remote_user_id,
- "usage": ["master"],
- "keys": {"ed25519:" + remote_master_key: remote_master_key},
- },
- "self_signing_key": {
- "user_id": remote_user_id,
- "usage": ["self_signing"],
- "keys": {
- "ed25519:" + remote_self_signing_key: remote_self_signing_key
+ "stream_id": 1,
+ "devices": [],
+ "master_key": {
+ "user_id": remote_user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
},
- },
- }
+ "self_signing_key": {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:"
+ + remote_self_signing_key: remote_self_signing_key
+ },
+ },
+ }
+ )
)
# Resync the device list.
diff --git a/tests/test_server.py b/tests/test_server.py
index 030f58cb..073b2362 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -12,26 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
import re
-from io import StringIO
from twisted.internet.defer import Deferred
-from twisted.python.failure import Failure
-from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web.resource import Resource
-from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, RedirectException, SynapseError
from synapse.config.server import parse_listener_def
from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource
-from synapse.http.site import SynapseSite, logger
+from synapse.http.site import SynapseSite
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock
from tests import unittest
from tests.server import (
- FakeTransport,
ThreadedMemoryReactorClock,
make_request,
render,
@@ -199,10 +193,10 @@ class OptionsResourceTests(unittest.TestCase):
return channel
def test_unknown_options_request(self):
- """An OPTIONS requests to an unknown URL still returns 200 OK."""
+ """An OPTIONS requests to an unknown URL still returns 204 No Content."""
channel = self._make_request(b"OPTIONS", b"/foo/")
- self.assertEqual(channel.result["code"], b"200")
- self.assertEqual(channel.result["body"], b"{}")
+ self.assertEqual(channel.result["code"], b"204")
+ self.assertNotIn("body", channel.result)
# Ensure the correct CORS headers have been added
self.assertTrue(
@@ -219,10 +213,10 @@ class OptionsResourceTests(unittest.TestCase):
)
def test_known_options_request(self):
- """An OPTIONS requests to an known URL still returns 200 OK."""
+ """An OPTIONS requests to an known URL still returns 204 No Content."""
channel = self._make_request(b"OPTIONS", b"/res/")
- self.assertEqual(channel.result["code"], b"200")
- self.assertEqual(channel.result["body"], b"{}")
+ self.assertEqual(channel.result["code"], b"204")
+ self.assertNotIn("body", channel.result)
# Ensure the correct CORS headers have been added
self.assertTrue(
@@ -318,54 +312,3 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
self.assertEqual(location_headers, [b"/no/over/there"])
cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
self.assertEqual(cookies_headers, [b"session=yespls"])
-
-
-class SiteTestCase(unittest.HomeserverTestCase):
- def test_lose_connection(self):
- """
- We log the URI correctly redacted when we lose the connection.
- """
-
- class HangingResource(Resource):
- """
- A Resource that strategically hangs, as if it were processing an
- answer.
- """
-
- def render(self, request):
- return NOT_DONE_YET
-
- # Set up a logging handler that we can inspect afterwards
- output = StringIO()
- handler = logging.StreamHandler(output)
- logger.addHandler(handler)
- old_level = logger.level
- logger.setLevel(10)
- self.addCleanup(logger.setLevel, old_level)
- self.addCleanup(logger.removeHandler, handler)
-
- # Make a resource and a Site, the resource will hang and allow us to
- # time out the request while it's 'processing'
- base_resource = Resource()
- base_resource.putChild(b"", HangingResource())
- site = SynapseSite(
- "test", "site_tag", self.hs.config.listeners[0], base_resource, "1.0"
- )
-
- server = site.buildProtocol(None)
- client = AccumulatingProtocol()
- client.makeConnection(FakeTransport(server, self.reactor))
- server.makeConnection(FakeTransport(client, self.reactor))
-
- # Send a request with an access token that will get redacted
- server.dataReceived(b"GET /?access_token=bar HTTP/1.0\r\n\r\n")
- self.pump()
-
- # Lose the connection
- e = Failure(Exception("Failed123"))
- server.connectionLost(e)
- handler.flush()
-
- # Our access token is redacted and the failure reason is logged.
- self.assertIn("/?access_token=<redacted>", output.getvalue())
- self.assertIn("Failed123", output.getvalue())
diff --git a/tests/test_state.py b/tests/test_state.py
index 66f22f68..4858e8fc 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -97,17 +97,19 @@ class StateGroupStore(object):
self._group_to_state[state_group] = dict(current_state_ids)
- return state_group
+ return defer.succeed(state_group)
def get_events(self, event_ids, **kwargs):
- return {
- e_id: self._event_id_to_event[e_id]
- for e_id in event_ids
- if e_id in self._event_id_to_event
- }
+ return defer.succeed(
+ {
+ e_id: self._event_id_to_event[e_id]
+ for e_id in event_ids
+ if e_id in self._event_id_to_event
+ }
+ )
def get_state_group_delta(self, name):
- return None, None
+ return defer.succeed((None, None))
def register_events(self, events):
for e in events:
@@ -120,7 +122,7 @@ class StateGroupStore(object):
self._event_to_state_group[event_id] = state_group
def get_room_version_id(self, room_id):
- return RoomVersions.V1.identifier
+ return defer.succeed(RoomVersions.V1.identifier)
class DictObj(dict):
@@ -202,7 +204,9 @@ class StateTestCase(unittest.TestCase):
context_store = {} # type: dict[str, EventContext]
for event in graph.walk():
- context = yield self.state.compute_event_context(event)
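+ # compute_event_context is now a coroutine; ensureDeferred lets inlineCallbacks yield it.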
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event)
+ )
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@@ -244,7 +248,9 @@ class StateTestCase(unittest.TestCase):
context_store = {}
for event in graph.walk():
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event)
+ )
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@@ -300,7 +306,9 @@ class StateTestCase(unittest.TestCase):
context_store = {}
for event in graph.walk():
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event)
+ )
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@@ -373,7 +381,9 @@ class StateTestCase(unittest.TestCase):
context_store = {}
for event in graph.walk():
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event)
+ )
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@@ -411,12 +421,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- context = yield self.state.compute_event_context(event, old_state=old_state)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event, old_state=old_state)
+ )
prev_state_ids = yield context.get_prev_state_ids()
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertCountEqual(
(e.event_id for e in old_state), current_state_ids.values()
)
@@ -434,12 +446,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- context = yield self.state.compute_event_context(event, old_state=old_state)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event, old_state=old_state)
+ )
prev_state_ids = yield context.get_prev_state_ids()
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertCountEqual(
(e.event_id for e in old_state + [event]), current_state_ids.values()
)
@@ -462,7 +476,7 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- group_name = self.store.store_state_group(
+ group_name = yield self.store.store_state_group(
prev_event_id,
event.room_id,
None,
@@ -471,9 +485,9 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id, group_name)
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(self.state.compute_event_context(event))
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(
{e.event_id for e in old_state}, set(current_state_ids.values())
@@ -494,7 +508,7 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- group_name = self.store.store_state_group(
+ group_name = yield self.store.store_state_group(
prev_event_id,
event.room_id,
None,
@@ -503,7 +517,7 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id, group_name)
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(self.state.compute_event_context(event))
prev_state_ids = yield context.get_prev_state_ids()
@@ -544,7 +558,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(len(current_state_ids), 6)
@@ -586,7 +600,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(len(current_state_ids), 6)
@@ -641,7 +655,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
@@ -669,14 +683,15 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
+ @defer.inlineCallbacks
def _get_context(
self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
):
- sg1 = self.store.store_state_group(
+ sg1 = yield self.store.store_state_group(
prev_event_id_1,
event.room_id,
None,
@@ -685,7 +700,7 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id_1, sg1)
- sg2 = self.store.store_state_group(
+ sg2 = yield self.store.store_state_group(
prev_event_id_2,
event.room_id,
None,
@@ -694,4 +709,5 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id_2, sg2)
- return self.state.compute_event_context(event)
+ result = yield defer.ensureDeferred(self.state.compute_event_context(event))
+ return result
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 7b345b03..508aeba0 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -17,7 +17,7 @@
"""
Utilities for running the unit tests
"""
-from typing import Awaitable, TypeVar
+from typing import Any, Awaitable, TypeVar
TV = TypeVar("TV")
@@ -36,3 +36,8 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
# if next didn't raise, the awaitable hasn't completed.
raise Exception("awaitable has not yet completed")
+
+
+async def make_awaitable(result: Any):
+ """Create an awaitable that just returns a result."""
+ return result
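+#
+# Sketch of intended use (illustrative example, not part of this patch):
+#     store.get_rooms_for_user = Mock(return_value=make_awaitable(["!room:test"]))
+# The returned coroutine can then be awaited once by the code under test.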
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 43297b53..8522c6fc 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -22,14 +22,12 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import Collection
-from tests.test_utils import get_awaitable_result
-
"""
Utility functions for poking events into the storage of the server under test.
"""
-def inject_member_event(
+async def inject_member_event(
hs: synapse.server.HomeServer,
room_id: str,
sender: str,
@@ -46,7 +44,7 @@ def inject_member_event(
if extra_content:
content.update(extra_content)
- return inject_event(
+ return await inject_event(
hs,
room_id=room_id,
type=EventTypes.Member,
@@ -57,7 +55,7 @@ def inject_member_event(
)
-def inject_event(
+async def inject_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[Collection[str]] = None,
@@ -72,37 +70,27 @@ def inject_event(
prev_event_ids: prev_events for the event. If not specified, will be looked up
kwargs: fields for the event to be created
"""
- test_reactor = hs.get_reactor()
-
- event, context = create_event(hs, room_version, prev_event_ids, **kwargs)
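+ # create_event is now a coroutine, so it can be awaited directly instead of pumping the reactor.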
+ event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
- d = hs.get_storage().persistence.persist_event(event, context)
- test_reactor.advance(0)
- get_awaitable_result(d)
+ await hs.get_storage().persistence.persist_event(event, context)
return event
-def create_event(
+async def create_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[Collection[str]] = None,
**kwargs
) -> Tuple[EventBase, EventContext]:
- test_reactor = hs.get_reactor()
-
if room_version is None:
- d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
- test_reactor.advance(0)
- room_version = get_awaitable_result(d)
+ room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"])
builder = hs.get_event_builder_factory().for_room_version(
KNOWN_ROOM_VERSIONS[room_version], kwargs
)
- d = hs.get_event_creation_handler().create_new_client_event(
+ event, context = await hs.get_event_creation_handler().create_new_client_event(
builder, prev_event_ids=prev_event_ids
)
- test_reactor.advance(0)
- event, context = get_awaitable_result(d)
return event, context
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index f7381b28..b371efc0 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -53,7 +53,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
#
# before we do that, we persist some other events to act as state.
- self.inject_visibility("@admin:hs", "joined")
+ yield self.inject_visibility("@admin:hs", "joined")
for i in range(0, 10):
yield self.inject_room_member("@resident%i:hs" % i)
@@ -137,8 +137,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
},
)
- event, context = yield self.event_creation_handler.create_new_client_event(
- builder
+ event, context = yield defer.ensureDeferred(
+ self.event_creation_handler.create_new_client_event(builder)
)
yield self.storage.persistence.persist_event(event, context)
return event
@@ -158,8 +158,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
},
)
- event, context = yield self.event_creation_handler.create_new_client_event(
- builder
+ event, context = yield defer.ensureDeferred(
+ self.event_creation_handler.create_new_client_event(builder)
)
yield self.storage.persistence.persist_event(event, context)
@@ -179,8 +179,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
},
)
- event, context = yield self.event_creation_handler.create_new_client_event(
- builder
+ event, context = yield defer.ensureDeferred(
+ self.event_creation_handler.create_new_client_event(builder)
)
yield self.storage.persistence.persist_event(event, context)
diff --git a/tests/unittest.py b/tests/unittest.py
index 3175a3fa..68d2586e 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -603,7 +603,9 @@ class HomeserverTestCase(TestCase):
user: MXID of the user to inject the membership for.
membership: The membership type.
"""
- event_injection.inject_member_event(self.hs, room, user, membership)
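+ # inject_member_event is now a coroutine; get_success drives it to completion.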
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room, user, membership)
+ )
class FederatingHomeserverTestCase(HomeserverTestCase):
diff --git a/tests/utils.py b/tests/utils.py
index 4d17355a..ac643679 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -671,6 +671,8 @@ def create_room(hs, room_id, creator_id):
},
)
- event, context = yield event_creation_handler.create_new_client_event(builder)
+ event, context = yield defer.ensureDeferred(
+ event_creation_handler.create_new_client_event(builder)
+ )
yield persistence_store.persist_event(event, context)
diff --git a/tox.ini b/tox.ini
index 1c042cb2..595ab3ba 100644
--- a/tox.ini
+++ b/tox.ini
@@ -126,7 +126,7 @@ deps =
black==19.10b0
commands =
python -m black --check --diff .
- /bin/sh -c "flake8 synapse tests scripts scripts-dev synctl {env:PEP8SUFFIX:}"
+ /bin/sh -c "flake8 synapse tests scripts scripts-dev contrib synctl {env:PEP8SUFFIX:}"
{toxinidir}/scripts-dev/config-lint.sh
[testenv:check_isort]
@@ -185,6 +185,7 @@ commands = mypy \
synapse/handlers/cas_handler.py \
synapse/handlers/directory.py \
synapse/handlers/federation.py \
+ synapse/handlers/identity.py \
synapse/handlers/oidc_handler.py \
synapse/handlers/presence.py \
synapse/handlers/room_member.py \