author     Andrej Shadura <andrewsh@debian.org>    2022-06-19 15:20:00 +0200
committer  Andrej Shadura <andrewsh@debian.org>    2022-06-19 15:21:39 +0200
commit     734a8e556ce00029d9d7ab0fed73336d24fa91f3 (patch)
tree       b277733532b1b141d534133a4715a2fe765ab533
parent     7a966d08c8403bcff00ac636d977097602501a69 (diff)
parent     6dc64c92c6991f09910f3e6db368e6eeb4b1981e (diff)
Update upstream source from tag 'upstream/1.61.0'
Update to upstream version '1.61.0' with Debian dir 5b9bb60cc861cbccd0027b7db7acf826071dc6a0
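
The merge above is the shape of history a git-buildpackage upstream import produces. A minimal, hypothetical sketch of the commands behind such an update (version numbers and options are illustrative only; the exact flags depend on debian/gbp.conf, which is not shown here):

    # Hypothetical sketch only -- not taken from this repository's tooling.
    gbp import-orig --uscan                    # let uscan fetch the 1.61.0 tarball and merge the upstream tag
    dch -v 1.61.0-1 "New upstream release."    # open a new Debian changelog entry for the packaging upload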
-rwxr-xr-x  .ci/scripts/checkout_complement.sh  25
-rw-r--r--  .git-blame-ignore-revs  3
-rw-r--r--  .github/workflows/tests.yml  60
-rw-r--r--  CHANGES.md  244
-rwxr-xr-x  contrib/cmdclient/console.py  9
-rw-r--r--  contrib/experiments/cursesio.py  165
-rw-r--r--  contrib/experiments/test_messaging.py  367
-rw-r--r--  contrib/jitsimeetbridge/jitsimeetbridge.py  298
-rw-r--r--  contrib/jitsimeetbridge/syweb-jitsi-conference.patch  188
-rw-r--r--  contrib/jitsimeetbridge/unjingle/strophe.jingle.sdp.js  712
-rw-r--r--  contrib/jitsimeetbridge/unjingle/strophe.jingle.sdp.util.js  408
-rw-r--r--  contrib/jitsimeetbridge/unjingle/strophe/XMLHttpRequest.js  254
-rw-r--r--  contrib/jitsimeetbridge/unjingle/strophe/base64.js  83
-rw-r--r--  contrib/jitsimeetbridge/unjingle/strophe/md5.js  279
-rw-r--r--  contrib/jitsimeetbridge/unjingle/strophe/strophe.js  3256
-rw-r--r--  contrib/jitsimeetbridge/unjingle/unjingle.js  48
-rwxr-xr-x  contrib/scripts/kick_users.py  88
-rw-r--r--  debian/copyright  23
-rwxr-xr-x  demo/start.sh  20
-rw-r--r--  docker/Dockerfile  2
-rw-r--r--  docker/complement/SynapseWorkers.Dockerfile  12
-rw-r--r--  docker/complement/conf-workers/caddy.complement.json  72
-rw-r--r--  docker/complement/conf-workers/caddy.supervisord.conf  7
-rwxr-xr-x  docker/complement/conf-workers/start-complement-synapse-workers.sh  23
-rw-r--r--  docker/complement/conf-workers/workers-shared.yaml  18
-rw-r--r--  docker/complement/conf/homeserver.yaml  12
-rw-r--r--  docker/conf-workers/nginx.conf.j2  16
-rw-r--r--  docker/conf-workers/shared.yaml.j2  11
-rwxr-xr-x  docker/configure_workers_and_start.py  21
-rw-r--r--  docs/SUMMARY.md  1
-rw-r--r--  docs/admin_api/media_admin_api.md  2
-rw-r--r--  docs/admin_api/user_admin_api.md  4
-rw-r--r--  docs/development/contributing_guide.md  35
-rw-r--r--  docs/development/demo.md  9
-rw-r--r--  docs/development/synapse_architecture/cancellation.md  392
-rw-r--r--  docs/message_retention_policies.md  2
-rw-r--r--  docs/modules/spam_checker_callbacks.md  44
-rw-r--r--  docs/openid.md  20
-rw-r--r--  docs/sample_config.yaml  72
-rw-r--r--  docs/structured_logging.md  2
-rw-r--r--  docs/upgrade.md  137
-rw-r--r--  docs/usage/configuration/config_documentation.md  174
-rw-r--r--  docs/welcome_and_overview.md  6
-rw-r--r--  docs/workers.md  27
-rw-r--r--  mypy.ini  136
-rw-r--r--  poetry.lock  14
-rw-r--r--  pyproject.toml  3
-rwxr-xr-x  scripts-dev/complement.sh  9
-rw-r--r--  scripts-dev/mypy_synapse_plugin.py  25
-rwxr-xr-x  synapse/_scripts/hash_password.py  10
-rwxr-xr-x  synapse/_scripts/synapse_port_db.py  28
-rw-r--r--  synapse/api/auth.py  45
-rw-r--r--  synapse/api/constants.py  16
-rw-r--r--  synapse/api/errors.py  37
-rw-r--r--  synapse/api/filtering.py  9
-rw-r--r--  synapse/api/room_versions.py  32
-rw-r--r--  synapse/app/_base.py  44
-rw-r--r--  synapse/app/admin_cmd.py  2
-rw-r--r--  synapse/app/generic_worker.py  11
-rw-r--r--  synapse/app/homeserver.py  36
-rw-r--r--  synapse/appservice/__init__.py  45
-rw-r--r--  synapse/appservice/api.py  15
-rw-r--r--  synapse/appservice/scheduler.py  6
-rw-r--r--  synapse/config/_base.py  81
-rw-r--r--  synapse/config/_base.pyi  17
-rw-r--r--  synapse/config/appservice.py  1
-rw-r--r--  synapse/config/auth.py  17
-rw-r--r--  synapse/config/cache.py  82
-rw-r--r--  synapse/config/experimental.py  6
-rw-r--r--  synapse/config/groups.py  39
-rw-r--r--  synapse/config/homeserver.py  2
-rw-r--r--  synapse/config/oembed.py  6
-rw-r--r--  synapse/config/repository.py  16
-rw-r--r--  synapse/config/room.py  47
-rw-r--r--  synapse/config/server.py  13
-rw-r--r--  synapse/config/tracer.py  6
-rw-r--r--  synapse/event_auth.py  21
-rw-r--r--  synapse/events/__init__.py  45
-rw-r--r--  synapse/events/snapshot.py  199
-rw-r--r--  synapse/events/spamcheck.py  169
-rw-r--r--  synapse/events/third_party_rules.py  3
-rw-r--r--  synapse/events/validator.py  9
-rw-r--r--  synapse/federation/federation_base.py  94
-rw-r--r--  synapse/federation/federation_client.py  50
-rw-r--r--  synapse/federation/federation_server.py  90
-rw-r--r--  synapse/federation/sender/__init__.py  30
-rw-r--r--  synapse/federation/sender/per_destination_queue.py  9
-rw-r--r--  synapse/federation/sender/transaction_manager.py  6
-rw-r--r--  synapse/federation/transport/client.py  507
-rw-r--r--  synapse/federation/transport/server/__init__.py  48
-rw-r--r--  synapse/federation/transport/server/_base.py  21
-rw-r--r--  synapse/federation/transport/server/federation.py  11
-rw-r--r--  synapse/federation/transport/server/groups_local.py  115
-rw-r--r--  synapse/federation/transport/server/groups_server.py  755
-rw-r--r--  synapse/groups/__init__.py  0
-rw-r--r--  synapse/groups/attestations.py  218
-rw-r--r--  synapse/groups/groups_server.py  1019
-rw-r--r--  synapse/handlers/account_data.py  10
-rw-r--r--  synapse/handlers/admin.py  12
-rw-r--r--  synapse/handlers/appservice.py  43
-rw-r--r--  synapse/handlers/auth.py  29
-rw-r--r--  synapse/handlers/device.py  54
-rw-r--r--  synapse/handlers/devicemessage.py  12
-rw-r--r--  synapse/handlers/directory.py  12
-rw-r--r--  synapse/handlers/e2e_keys.py  34
-rw-r--r--  synapse/handlers/event_auth.py  10
-rw-r--r--  synapse/handlers/events.py  6
-rw-r--r--  synapse/handlers/federation.py  163
-rw-r--r--  synapse/handlers/federation_event.py  272
-rw-r--r--  synapse/handlers/groups_local.py  503
-rw-r--r--  synapse/handlers/initial_sync.py  46
-rw-r--r--  synapse/handlers/message.py  270
-rw-r--r--  synapse/handlers/oidc.py  4
-rw-r--r--  synapse/handlers/pagination.py  48
-rw-r--r--  synapse/handlers/presence.py  20
-rw-r--r--  synapse/handlers/profile.py  83
-rw-r--r--  synapse/handlers/receipts.py  112
-rw-r--r--  synapse/handlers/register.py  3
-rw-r--r--  synapse/handlers/relations.py  85
-rw-r--r--  synapse/handlers/room.py  218
-rw-r--r--  synapse/handlers/room_batch.py  6
-rw-r--r--  synapse/handlers/room_list.py  3
-rw-r--r--  synapse/handlers/room_member.py  41
-rw-r--r--  synapse/handlers/room_summary.py  26
-rw-r--r--  synapse/handlers/search.py  26
-rw-r--r--  synapse/handlers/stats.py  6
-rw-r--r--  synapse/handlers/sync.py  168
-rw-r--r--  synapse/handlers/typing.py  22
-rw-r--r--  synapse/handlers/user_directory.py  6
-rw-r--r--  synapse/http/client.py  18
-rw-r--r--  synapse/http/connectproxyclient.py  39
-rw-r--r--  synapse/http/federation/matrix_federation_agent.py  2
-rw-r--r--  synapse/http/federation/srv_resolver.py  4
-rw-r--r--  synapse/http/federation/well_known_resolver.py  6
-rw-r--r--  synapse/http/matrixfederationclient.py  73
-rw-r--r--  synapse/http/proxyagent.py  2
-rw-r--r--  synapse/http/request_metrics.py  10
-rw-r--r--  synapse/http/server.py  74
-rw-r--r--  synapse/http/site.py  25
-rw-r--r--  synapse/logging/_remote.py  20
-rw-r--r--  synapse/logging/formatter.py  14
-rw-r--r--  synapse/logging/handlers.py  4
-rw-r--r--  synapse/logging/opentracing.py  114
-rw-r--r--  synapse/logging/scopecontextmanager.py  28
-rw-r--r--  synapse/metrics/background_process_metrics.py  9
-rw-r--r--  synapse/metrics/jemalloc.py  114
-rw-r--r--  synapse/module_api/__init__.py  54
-rw-r--r--  synapse/module_api/errors.py  2
-rw-r--r--  synapse/notifier.py  15
-rw-r--r--  synapse/push/__init__.py  74
-rw-r--r--  synapse/push/action_generator.py  44
-rw-r--r--  synapse/push/baserules.py  16
-rw-r--r--  synapse/push/bulk_push_rule_evaluator.py  135
-rw-r--r--  synapse/push/clientformat.py  4
-rw-r--r--  synapse/push/httppusher.py  8
-rw-r--r--  synapse/push/mailer.py  18
-rw-r--r--  synapse/push/push_rule_evaluator.py  118
-rw-r--r--  synapse/push/push_tools.py  4
-rw-r--r--  synapse/replication/http/_base.py  21
-rw-r--r--  synapse/replication/http/federation.py  4
-rw-r--r--  synapse/replication/http/send_event.py  6
-rw-r--r--  synapse/replication/slave/storage/groups.py  58
-rw-r--r--  synapse/replication/tcp/client.py  19
-rw-r--r--  synapse/replication/tcp/commands.py  12
-rw-r--r--  synapse/replication/tcp/handler.py  34
-rw-r--r--  synapse/replication/tcp/redis.py  39
-rw-r--r--  synapse/replication/tcp/streams/__init__.py  3
-rw-r--r--  synapse/replication/tcp/streams/_base.py  20
-rw-r--r--  synapse/rest/__init__.py  3
-rw-r--r--  synapse/rest/admin/__init__.py  3
-rw-r--r--  synapse/rest/admin/groups.py  50
-rw-r--r--  synapse/rest/admin/media.py  8
-rw-r--r--  synapse/rest/admin/rooms.py  37
-rw-r--r--  synapse/rest/admin/users.py  8
-rw-r--r--  synapse/rest/client/groups.py  962
-rw-r--r--  synapse/rest/client/mutual_rooms.py  15
-rw-r--r--  synapse/rest/client/push_rule.py  4
-rw-r--r--  synapse/rest/client/receipts.py  13
-rw-r--r--  synapse/rest/client/room.py  20
-rw-r--r--  synapse/rest/client/sync.py  12
-rw-r--r--  synapse/rest/media/v1/media_repository.py  358
-rw-r--r--  synapse/rest/media/v1/preview_html.py  64
-rw-r--r--  synapse/rest/media/v1/preview_url_resource.py  53
-rw-r--r--  synapse/rest/media/v1/thumbnailer.py  71
-rw-r--r--  synapse/server.py  54
-rw-r--r--  synapse/server_notices/resource_limits_server_notices.py  47
-rw-r--r--  synapse/server_notices/server_notices_manager.py  74
-rw-r--r--  synapse/spam_checker_api/__init__.py  2
-rw-r--r--  synapse/state/__init__.py  180
-rw-r--r--  synapse/storage/__init__.py  35
-rw-r--r--  synapse/storage/_base.py  3
-rw-r--r--  synapse/storage/background_updates.py  42
-rw-r--r--  synapse/storage/controllers/__init__.py  46
-rw-r--r--  synapse/storage/controllers/persist_events.py (renamed from synapse/storage/persist_events.py)  71
-rw-r--r--  synapse/storage/controllers/purge_events.py (renamed from synapse/storage/purge_events.py)  2
-rw-r--r--  synapse/storage/controllers/state.py  492
-rw-r--r--  synapse/storage/database.py  46
-rw-r--r--  synapse/storage/databases/main/__init__.py  26
-rw-r--r--  synapse/storage/databases/main/appservice.py  47
-rw-r--r--  synapse/storage/databases/main/cache.py  8
-rw-r--r--  synapse/storage/databases/main/devices.py  44
-rw-r--r--  synapse/storage/databases/main/e2e_room_keys.py  4
-rw-r--r--  synapse/storage/databases/main/event_federation.py  196
-rw-r--r--  synapse/storage/databases/main/event_push_actions.py  2
-rw-r--r--  synapse/storage/databases/main/events.py  245
-rw-r--r--  synapse/storage/databases/main/events_worker.py  110
-rw-r--r--  synapse/storage/databases/main/group_server.py  1407
-rw-r--r--  synapse/storage/databases/main/lock.py  19
-rw-r--r--  synapse/storage/databases/main/media_repository.py  72
-rw-r--r--  synapse/storage/databases/main/metrics.py  74
-rw-r--r--  synapse/storage/databases/main/monthly_active_users.py  45
-rw-r--r--  synapse/storage/databases/main/presence.py  75
-rw-r--r--  synapse/storage/databases/main/profile.py  107
-rw-r--r--  synapse/storage/databases/main/purge_events.py  23
-rw-r--r--  synapse/storage/databases/main/push_rule.py  305
-rw-r--r--  synapse/storage/databases/main/pusher.py  6
-rw-r--r--  synapse/storage/databases/main/receipts.py  102
-rw-r--r--  synapse/storage/databases/main/relations.py  59
-rw-r--r--  synapse/storage/databases/main/room.py  125
-rw-r--r--  synapse/storage/databases/main/roommember.py  190
-rw-r--r--  synapse/storage/databases/main/search.py  33
-rw-r--r--  synapse/storage/databases/main/state.py  103
-rw-r--r--  synapse/storage/databases/main/state_deltas.py  4
-rw-r--r--  synapse/storage/databases/main/stream.py  46
-rw-r--r--  synapse/storage/databases/main/user_directory.py  47
-rw-r--r--  synapse/storage/databases/state/bg_updates.py  16
-rw-r--r--  synapse/storage/databases/state/store.py  2
-rw-r--r--  synapse/storage/engines/__init__.py  12
-rw-r--r--  synapse/storage/engines/_base.py  26
-rw-r--r--  synapse/storage/engines/postgres.py  92
-rw-r--r--  synapse/storage/engines/sqlite.py  72
-rw-r--r--  synapse/storage/prepare_database.py  2
-rw-r--r--  synapse/storage/schema/__init__.py  14
-rw-r--r--  synapse/storage/schema/main/delta/69/02cache_invalidation_index.sql  18
-rw-r--r--  synapse/storage/schema/main/delta/70/01clean_table_purged_rooms.sql  19
-rw-r--r--  synapse/storage/schema/state/delta/70/08_state_group_edges_unique.sql  17
-rw-r--r--  synapse/storage/state.py  338
-rw-r--r--  synapse/storage/types.py  80
-rw-r--r--  synapse/storage/util/partial_state_events_tracker.py  60
-rw-r--r--  synapse/streams/events.py  4
-rw-r--r--  synapse/types.py  103
-rw-r--r--  synapse/util/caches/descriptors.py  15
-rw-r--r--  synapse/util/caches/lrucache.py  79
-rw-r--r--  synapse/util/retryutils.py  4
-rw-r--r--  synapse/visibility.py  44
-rw-r--r--  tests/api/test_auth.py  2
-rw-r--r--  tests/api/test_filtering.py  6
-rw-r--r--  tests/api/test_ratelimiting.py  2
-rw-r--r--  tests/appservice/test_api.py  101
-rw-r--r--  tests/appservice/test_appservice.py  3
-rw-r--r--  tests/config/test_cache.py  8
-rw-r--r--  tests/crypto/test_event_signing.py  17
-rw-r--r--  tests/crypto/test_keyring.py  2
-rw-r--r--  tests/events/test_presence_router.py  2
-rw-r--r--  tests/events/test_snapshot.py  4
-rw-r--r--  tests/federation/test_federation_sender.py  40
-rw-r--r--  tests/federation/test_federation_server.py  6
-rw-r--r--  tests/federation/transport/server/__init__.py  13
-rw-r--r--  tests/federation/transport/server/test__base.py  141
-rw-r--r--  tests/federation/transport/test_server.py  4
-rw-r--r--  tests/handlers/test_appservice.py  16
-rw-r--r--  tests/handlers/test_directory.py  3
-rw-r--r--  tests/handlers/test_federation.py  19
-rw-r--r--  tests/handlers/test_federation_event.py  15
-rw-r--r--  tests/handlers/test_message.py  14
-rw-r--r--  tests/handlers/test_receipts.py  94
-rw-r--r--  tests/handlers/test_room_summary.py  20
-rw-r--r--  tests/handlers/test_sync.py  1
-rw-r--r--  tests/handlers/test_typing.py  41
-rw-r--r--  tests/handlers/test_user_directory.py  3
-rw-r--r--  tests/http/server/__init__.py  13
-rw-r--r--  tests/http/server/_base.py  100
-rw-r--r--  tests/http/test_fedclient.py  6
-rw-r--r--  tests/http/test_servlet.py  74
-rw-r--r--  tests/http/test_site.py  2
-rw-r--r--  tests/module_api/test_api.py  2
-rw-r--r--  tests/push/test_push_rule_evaluator.py  84
-rw-r--r--  tests/replication/_base.py  54
-rw-r--r--  tests/replication/http/__init__.py  13
-rw-r--r--  tests/replication/http/test__base.py  106
-rw-r--r--  tests/replication/slave/storage/_base.py  2
-rw-r--r--  tests/replication/slave/storage/test_events.py  10
-rw-r--r--  tests/replication/slave/storage/test_receipts.py  12
-rw-r--r--  tests/replication/tcp/test_handler.py  73
-rw-r--r--  tests/replication/test_sharded_event_persister.py  14
-rw-r--r--  tests/rest/admin/test_admin.py  90
-rw-r--r--  tests/rest/admin/test_room.py  3
-rw-r--r--  tests/rest/admin/test_user.py  4
-rw-r--r--  tests/rest/client/test_account.py  1
-rw-r--r--  tests/rest/client/test_auth.py  41
-rw-r--r--  tests/rest/client/test_devices.py (renamed from tests/rest/client/test_device_lists.py)  43
-rw-r--r--  tests/rest/client/test_events.py  3
-rw-r--r--  tests/rest/client/test_groups.py  56
-rw-r--r--  tests/rest/client/test_login.py  2
-rw-r--r--  tests/rest/client/test_mutual_rooms.py  2
-rw-r--r--  tests/rest/client/test_notifications.py  91
-rw-r--r--  tests/rest/client/test_register.py  2
-rw-r--r--  tests/rest/client/test_relations.py  89
-rw-r--r--  tests/rest/client/test_retention.py  39
-rw-r--r--  tests/rest/client/test_room_batch.py  7
-rw-r--r--  tests/rest/client/test_rooms.py  267
-rw-r--r--  tests/rest/client/test_sendtodevice.py  5
-rw-r--r--  tests/rest/client/test_shadow_banned.py  4
-rw-r--r--  tests/rest/client/test_sync.py  41
-rw-r--r--  tests/rest/client/test_typing.py  3
-rw-r--r--  tests/rest/client/test_upgrade_room.py  38
-rw-r--r--  tests/rest/media/test_media_retention.py  321
-rw-r--r--  tests/rest/media/v1/test_html_preview.py  37
-rw-r--r--  tests/rest/media/v1/test_url_preview.py  35
-rw-r--r--  tests/scripts/test_new_matrix_user.py  13
-rw-r--r--  tests/server.py  14
-rw-r--r--  tests/server_notices/test_resource_limits_server_notices.py  11
-rw-r--r--  tests/storage/databases/main/test_events_worker.py  25
-rw-r--r--  tests/storage/databases/main/test_lock.py  54
-rw-r--r--  tests/storage/test_appservice.py  27
-rw-r--r--  tests/storage/test_base.py  2
-rw-r--r--  tests/storage/test_devices.py  7
-rw-r--r--  tests/storage/test_event_chain.py  3
-rw-r--r--  tests/storage/test_event_federation.py  9
-rw-r--r--  tests/storage/test_events.py  58
-rw-r--r--  tests/storage/test_monthly_active_users.py  83
-rw-r--r--  tests/storage/test_purge.py  19
-rw-r--r--  tests/storage/test_redaction.py  14
-rw-r--r--  tests/storage/test_room.py  12
-rw-r--r--  tests/storage/test_room_search.py  4
-rw-r--r--  tests/storage/test_roommember.py  2
-rw-r--r--  tests/storage/test_state.py  2
-rw-r--r--  tests/storage/test_user_directory.py  1
-rw-r--r--  tests/storage/util/test_partial_state_events_tracker.py  59
-rw-r--r--  tests/test_mau.py  3
-rw-r--r--  tests/test_server.py  111
-rw-r--r--  tests/test_state.py  36
-rw-r--r--  tests/test_types.py  21
-rw-r--r--  tests/test_utils/event_injection.py  2
-rw-r--r--  tests/test_visibility.py  46
-rw-r--r--  tests/unittest.py  2
-rw-r--r--  tests/util/test_lrucache.py  58
-rw-r--r--  tests/utils.py  2
338 files changed, 9723 insertions, 16202 deletions
diff --git a/.ci/scripts/checkout_complement.sh b/.ci/scripts/checkout_complement.sh
new file mode 100755
index 00000000..379f5d43
--- /dev/null
+++ b/.ci/scripts/checkout_complement.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+#
+# Fetches a version of complement which best matches the current build.
+#
+# The tarball is unpacked into `./complement`.
+
+set -e
+mkdir -p complement
+
+# Pick an appropriate version of complement. Depending on whether this is a PR or release,
+# etc. we need to use different fallbacks:
+#
+# 1. First check if there's a similarly named branch (GITHUB_HEAD_REF
+# for pull requests, otherwise GITHUB_REF).
+# 2. Attempt to use the base branch, e.g. when merging into release-vX.Y
+# (GITHUB_BASE_REF for pull requests).
+# 3. Use the default complement branch ("HEAD").
+for BRANCH_NAME in "$GITHUB_HEAD_REF" "$GITHUB_BASE_REF" "${GITHUB_REF#refs/heads/}" "HEAD"; do
+ # Skip empty branch names and merge commits.
+ if [[ -z "$BRANCH_NAME" || $BRANCH_NAME =~ ^refs/pull/.* ]]; then
+ continue
+ fi
+
+ (wget -O - "https://github.com/matrix-org/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break
+done
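
The script above is driven entirely by the branch-name variables that GitHub Actions exports. A hedged sketch of running it by hand (every value below is made up purely for illustration):

    # Illustrative local invocation; the branch names are examples, not real refs.
    GITHUB_HEAD_REF="my-feature-branch" \
    GITHUB_BASE_REF="develop" \
    GITHUB_REF="refs/pull/1234/merge" \
        .ci/scripts/checkout_complement.sh
    # On success, the best-matching Complement checkout is unpacked into ./complement/.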
diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
index 83ddd568..50d28c68 100644
--- a/.git-blame-ignore-revs
+++ b/.git-blame-ignore-revs
@@ -6,3 +6,6 @@ aff1eb7c671b0a3813407321d2702ec46c71fa56
# Update black to 20.8b1 (#9381).
0a00b7ff14890987f09112a2ae696c61001e6cf1
+
+# Convert tests/rest/admin/test_room.py to unix file endings (#7953).
+c4268e3da64f1abb5b31deaeb5769adb6510c0a7
\ No newline at end of file
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index efa35b71..83ab7273 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -306,7 +306,7 @@ jobs:
- run: .ci/scripts/test_synapse_port_db.sh
complement:
- if: ${{ !failure() && !cancelled() }}
+ if: "${{ !failure() && !cancelled() }}"
needs: linting-done
runs-on: ubuntu-latest
@@ -333,30 +333,50 @@ jobs:
# Attempt to check out the same branch of Complement as the PR. If it
# doesn't exist, fallback to HEAD.
- name: Checkout complement
+ run: synapse/.ci/scripts/checkout_complement.sh
+
+ - run: |
+ set -o pipefail
+ COMPLEMENT_DIR=`pwd`/complement synapse/scripts-dev/complement.sh -json 2>&1 | gotestfmt
shell: bash
+ name: Run Complement Tests
+
+ # We only run the workers tests on `develop` for now, because they're too slow to wait for on PRs.
+ # Sadly, you can't have an `if` condition on the value of a matrix, so this is a temporary, separate job for now.
+ # GitHub Actions doesn't support YAML anchors, so it's full-on duplication for now.
+ complement-developonly:
+ if: "${{ !failure() && !cancelled() && (github.ref == 'refs/heads/develop') }}"
+ needs: linting-done
+ runs-on: ubuntu-latest
+
+ steps:
+ # The path is set via a file given by $GITHUB_PATH. We need both Go 1.17 and GOPATH on the path to run Complement.
+ # See https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#adding-a-system-path
+ - name: "Set Go Version"
run: |
- mkdir -p complement
- # Attempt to use the version of complement which best matches the current
- # build. Depending on whether this is a PR or release, etc. we need to
- # use different fallbacks.
- #
- # 1. First check if there's a similarly named branch (GITHUB_HEAD_REF
- # for pull requests, otherwise GITHUB_REF).
- # 2. Attempt to use the base branch, e.g. when merging into release-vX.Y
- # (GITHUB_BASE_REF for pull requests).
- # 3. Use the default complement branch ("HEAD").
- for BRANCH_NAME in "$GITHUB_HEAD_REF" "$GITHUB_BASE_REF" "${GITHUB_REF#refs/heads/}" "HEAD"; do
- # Skip empty branch names and merge commits.
- if [[ -z "$BRANCH_NAME" || $BRANCH_NAME =~ ^refs/pull/.* ]]; then
- continue
- fi
-
- (wget -O - "https://github.com/matrix-org/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break
- done
+ # Add Go 1.17 to the PATH: see https://github.com/actions/virtual-environments/blob/main/images/linux/Ubuntu2004-Readme.md#environment-variables-2
+ echo "$GOROOT_1_17_X64/bin" >> $GITHUB_PATH
+ # Add the Go path to the PATH: We need this so we can call gotestfmt
+ echo "~/go/bin" >> $GITHUB_PATH
+
+ - name: "Install Complement Dependencies"
+ run: |
+ sudo apt-get update && sudo apt-get install -y libolm3 libolm-dev
+ go get -v github.com/haveyoudebuggedit/gotestfmt/v2/cmd/gotestfmt@latest
+
+ - name: Run actions/checkout@v2 for synapse
+ uses: actions/checkout@v2
+ with:
+ path: synapse
+
+ # Attempt to check out the same branch of Complement as the PR. If it
+ # doesn't exist, fallback to HEAD.
+ - name: Checkout complement
+ run: synapse/.ci/scripts/checkout_complement.sh
- run: |
set -o pipefail
- COMPLEMENT_DIR=`pwd`/complement synapse/scripts-dev/complement.sh -json 2>&1 | gotestfmt
+ WORKERS=1 COMPLEMENT_DIR=`pwd`/complement synapse/scripts-dev/complement.sh -json 2>&1 | gotestfmt
shell: bash
name: Run Complement Tests
diff --git a/CHANGES.md b/CHANGES.md
index e10ac031..84641aee 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,245 @@
+Synapse 1.61.0 (2022-06-14)
+===========================
+
+This release removes support for the non-standard feature known both as 'groups' and as 'communities', which have been superseded by *Spaces*.
+
+See [the upgrade notes](https://github.com/matrix-org/synapse/blob/develop/docs/upgrade.md#upgrading-to-v1610)
+for more details.
+
+Improved Documentation
+----------------------
+
+- Mention removed community/group worker endpoints in [upgrade.md](https://github.com/matrix-org/synapse/blob/develop/docs/upgrade.md#upgrading-to-v1610s). Contributed by @olmari. ([\#13023](https://github.com/matrix-org/synapse/issues/13023))
+
+
+Synapse 1.61.0rc1 (2022-06-07)
+==============================
+
+Features
+--------
+
+- Add new `media_retention` options to the homeserver config for routinely cleaning up non-recently accessed media. ([\#12732](https://github.com/matrix-org/synapse/issues/12732), [\#12972](https://github.com/matrix-org/synapse/issues/12972), [\#12977](https://github.com/matrix-org/synapse/issues/12977))
+- Experimental support for [MSC3772](https://github.com/matrix-org/matrix-spec-proposals/pull/3772): Push rule for mutually related events. ([\#12740](https://github.com/matrix-org/synapse/issues/12740), [\#12859](https://github.com/matrix-org/synapse/issues/12859))
+- Update to the `check_event_for_spam` module callback: Deprecate the current callback signature, replace it with a new signature that is both less ambiguous (replacing booleans with explicit allow/block) and more powerful (ability to return explicit error codes). ([\#12808](https://github.com/matrix-org/synapse/issues/12808))
+- Add storage and module API methods to get monthly active users (and their corresponding appservices) within an optionally specified time range. ([\#12838](https://github.com/matrix-org/synapse/issues/12838), [\#12917](https://github.com/matrix-org/synapse/issues/12917))
+- Support the new error code `ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED` from [MSC3823](https://github.com/matrix-org/matrix-spec-proposals/pull/3823). ([\#12845](https://github.com/matrix-org/synapse/issues/12845), [\#12923](https://github.com/matrix-org/synapse/issues/12923))
+- Add a configurable background job to delete stale devices. ([\#12855](https://github.com/matrix-org/synapse/issues/12855))
+- Improve URL previews for pages with empty elements. ([\#12951](https://github.com/matrix-org/synapse/issues/12951))
+- Allow updating a user's password using the admin API without logging out their devices. Contributed by @jcgruenhage. ([\#12952](https://github.com/matrix-org/synapse/issues/12952))
+
+
+Bugfixes
+--------
+
+- Always send an `access_token` in `/thirdparty/` requests to appservices, as required by the [Application Service API specification](https://spec.matrix.org/v1.1/application-service-api/#third-party-networks). ([\#12746](https://github.com/matrix-org/synapse/issues/12746))
+- Implement [MSC3816](https://github.com/matrix-org/matrix-spec-proposals/pull/3816): sending the root event in a thread should count as having 'participated' in it. ([\#12766](https://github.com/matrix-org/synapse/issues/12766))
+- Delete events from the `federation_inbound_events_staging` table when a room is purged through the admin API. ([\#12784](https://github.com/matrix-org/synapse/issues/12784))
+- Fix a bug where we did not correctly handle invalid device list updates over federation. Contributed by Carl Bordum Hansen. ([\#12829](https://github.com/matrix-org/synapse/issues/12829))
+- Fix a bug which allowed multiple async operations to access database locks concurrently. Contributed by @sumnerevans @ Beeper. ([\#12832](https://github.com/matrix-org/synapse/issues/12832))
+- Fix an issue introduced in Synapse 0.34 where the `/notifications` endpoint would only return notifications if a user registered at least one pusher. Contributed by Famedly. ([\#12840](https://github.com/matrix-org/synapse/issues/12840))
+- Fix a bug where servers using a Postgres database would fail to backfill from an insertion event when MSC2716 is enabled (`experimental_features.msc2716_enabled`). ([\#12843](https://github.com/matrix-org/synapse/issues/12843))
+- Fix [MSC3787](https://github.com/matrix-org/matrix-spec-proposals/pull/3787) rooms being omitted from room directory, room summary and space hierarchy responses. ([\#12858](https://github.com/matrix-org/synapse/issues/12858))
+- Fix a bug introduced in Synapse 1.54.0 which could sometimes cause exceptions when handling federated traffic. ([\#12877](https://github.com/matrix-org/synapse/issues/12877))
+- Fix a bug introduced in Synapse 1.59.0 which caused room deletion to fail with a foreign key violation error. ([\#12889](https://github.com/matrix-org/synapse/issues/12889))
+- Fix a long-standing bug which caused the `/messages` endpoint to return an incorrect `end` attribute when there were no more events. Contributed by @Vetchu. ([\#12903](https://github.com/matrix-org/synapse/issues/12903))
+- Fix a bug introduced in Synapse 1.58.0 where `/sync` would fail if the most recent event in a room was a redaction of an event that has since been purged. ([\#12905](https://github.com/matrix-org/synapse/issues/12905))
+- Fix a potential memory leak when generating thumbnails. ([\#12932](https://github.com/matrix-org/synapse/issues/12932))
+- Fix a long-standing bug where a URL preview would break if the image failed to download. ([\#12950](https://github.com/matrix-org/synapse/issues/12950))
+
+
+Improved Documentation
+----------------------
+
+- Fix typographical errors in documentation. ([\#12863](https://github.com/matrix-org/synapse/issues/12863))
+- Fix documentation incorrectly stating the `sendToDevice` endpoint can be directed at generic workers. Contributed by Nick @ Beeper. ([\#12867](https://github.com/matrix-org/synapse/issues/12867))
+
+
+Deprecations and Removals
+-------------------------
+
+- Remove support for the non-standard groups/communities feature from Synapse. ([\#12553](https://github.com/matrix-org/synapse/issues/12553), [\#12558](https://github.com/matrix-org/synapse/issues/12558), [\#12563](https://github.com/matrix-org/synapse/issues/12563), [\#12895](https://github.com/matrix-org/synapse/issues/12895), [\#12897](https://github.com/matrix-org/synapse/issues/12897), [\#12899](https://github.com/matrix-org/synapse/issues/12899), [\#12900](https://github.com/matrix-org/synapse/issues/12900), [\#12936](https://github.com/matrix-org/synapse/issues/12936), [\#12966](https://github.com/matrix-org/synapse/issues/12966))
+- Remove contributed `kick_users.py` script. This is broken under Python 3, and is not added to the environment when `pip install`ing Synapse. ([\#12908](https://github.com/matrix-org/synapse/issues/12908))
+- Remove `contrib/jitsimeetbridge`. This was an unused experiment that hasn't been meaningfully changed since 2014. ([\#12909](https://github.com/matrix-org/synapse/issues/12909))
+- Remove unused `contrib/experiments/cursesio.py` script, which fails to run under Python 3. ([\#12910](https://github.com/matrix-org/synapse/issues/12910))
+- Remove unused `contrib/experiments/test_messaging.py` script. This fails to run on Python 3. ([\#12911](https://github.com/matrix-org/synapse/issues/12911))
+
+
+Internal Changes
+----------------
+
+- Test Synapse against Complement with workers. ([\#12810](https://github.com/matrix-org/synapse/issues/12810), [\#12933](https://github.com/matrix-org/synapse/issues/12933))
+- Reduce the amount of state we pull from the DB. ([\#12811](https://github.com/matrix-org/synapse/issues/12811), [\#12964](https://github.com/matrix-org/synapse/issues/12964))
+- Try other homeservers when re-syncing state for rooms with partial state. ([\#12812](https://github.com/matrix-org/synapse/issues/12812))
+- Resume state re-syncing for rooms with partial state after a Synapse restart. ([\#12813](https://github.com/matrix-org/synapse/issues/12813))
+- Remove Mutual Rooms' ([MSC2666](https://github.com/matrix-org/matrix-spec-proposals/pull/2666)) endpoint dependency on the User Directory. ([\#12836](https://github.com/matrix-org/synapse/issues/12836))
+- Experimental: expand `check_event_for_spam` with ability to return additional fields. This enables spam-checker implementations to experiment with mechanisms to give users more information about why they are blocked and whether any action is needed from them to be unblocked. ([\#12846](https://github.com/matrix-org/synapse/issues/12846))
+- Remove `dont_notify` from the `.m.rule.room.server_acl` rule. ([\#12849](https://github.com/matrix-org/synapse/issues/12849))
+- Remove the unstable `/hierarchy` endpoint from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). ([\#12851](https://github.com/matrix-org/synapse/issues/12851))
+- Pull out less state when handling gaps in room DAG. ([\#12852](https://github.com/matrix-org/synapse/issues/12852), [\#12904](https://github.com/matrix-org/synapse/issues/12904))
+- Clean-up the push rules datastore. ([\#12856](https://github.com/matrix-org/synapse/issues/12856))
+- Correct a type annotation in the URL preview source code. ([\#12860](https://github.com/matrix-org/synapse/issues/12860))
+- Update `pyjwt` dependency to [2.4.0](https://github.com/jpadilla/pyjwt/releases/tag/2.4.0). ([\#12865](https://github.com/matrix-org/synapse/issues/12865))
+- Enable the `/account/whoami` endpoint on synapse worker processes. Contributed by Nick @ Beeper. ([\#12866](https://github.com/matrix-org/synapse/issues/12866))
+- Enable the `batch_send` endpoint on synapse worker processes. Contributed by Nick @ Beeper. ([\#12868](https://github.com/matrix-org/synapse/issues/12868))
+- Don't generate empty AS transactions when the AS is flagged as down. Contributed by Nick @ Beeper. ([\#12869](https://github.com/matrix-org/synapse/issues/12869))
+- Fix up the variable `state_store` naming. ([\#12871](https://github.com/matrix-org/synapse/issues/12871))
+- Faster room joins: when querying the current state of the room, wait for state to be populated. ([\#12872](https://github.com/matrix-org/synapse/issues/12872))
+- Avoid running queries which will never result in deletions. ([\#12879](https://github.com/matrix-org/synapse/issues/12879))
+- Use constants for EDU types. ([\#12884](https://github.com/matrix-org/synapse/issues/12884))
+- Reduce database load of `/sync` when presence is enabled. ([\#12885](https://github.com/matrix-org/synapse/issues/12885))
+- Refactor `have_seen_events` to reduce memory consumed when processing federation traffic. ([\#12886](https://github.com/matrix-org/synapse/issues/12886))
+- Refactor receipt linearization code. ([\#12888](https://github.com/matrix-org/synapse/issues/12888))
+- Add type annotations to `synapse.logging.opentracing`. ([\#12894](https://github.com/matrix-org/synapse/issues/12894))
+- Remove PyNaCl occurrences directly used in Synapse code. ([\#12902](https://github.com/matrix-org/synapse/issues/12902))
+- Bump types-jsonschema from 4.4.1 to 4.4.6. ([\#12912](https://github.com/matrix-org/synapse/issues/12912))
+- Rename storage classes. ([\#12913](https://github.com/matrix-org/synapse/issues/12913))
+- Preparation for database schema simplifications: stop reading from `event_edges.room_id`. ([\#12914](https://github.com/matrix-org/synapse/issues/12914))
+- Check if we are in a virtual environment before overriding the `PYTHONPATH` environment variable in the demo script. ([\#12916](https://github.com/matrix-org/synapse/issues/12916))
+- Improve the logging when signature checks on events fail. ([\#12925](https://github.com/matrix-org/synapse/issues/12925))
+
+
+Synapse 1.60.0 (2022-05-31)
+===========================
+
+This release of Synapse adds a unique index to the `state_group_edges` table, in
+order to prevent accidentally introducing duplicate information (for example,
+because a database backup was restored multiple times). If your Synapse database
+already has duplicate rows in this table, this could fail with an error and
+require manual remediation.
+
+Additionally, the signature of the `check_event_for_spam` module callback has changed.
+The previous signature has been deprecated and remains working for now. Module authors
+should update their modules to use the new signature where possible.
+
+See [the upgrade notes](https://github.com/matrix-org/synapse/blob/develop/docs/upgrade.md#upgrading-to-v1600)
+for more details.
+
+Bugfixes
+--------
+
+- Fix a bug introduced in Synapse 1.60.0rc1 that would break some imports from `synapse.module_api`. ([\#12918](https://github.com/matrix-org/synapse/issues/12918))
+
+
+Synapse 1.60.0rc2 (2022-05-27)
+==============================
+
+Features
+--------
+
+- Add an option allowing users to use their password to reauthenticate for privileged actions even though password login is disabled. ([\#12883](https://github.com/matrix-org/synapse/issues/12883))
+
+
+Bugfixes
+--------
+
+- Explicitly close `ijson` coroutines once we are done with them, instead of leaving the garbage collector to close them. ([\#12875](https://github.com/matrix-org/synapse/issues/12875))
+
+
+Internal Changes
+----------------
+
+- Improve URL previews by not including the content of media tags in the generated description. ([\#12887](https://github.com/matrix-org/synapse/issues/12887))
+
+
+Synapse 1.60.0rc1 (2022-05-24)
+==============================
+
+Features
+--------
+
+- Measure the time taken in spam-checking callbacks and expose those measurements as metrics. ([\#12513](https://github.com/matrix-org/synapse/issues/12513))
+- Add a `default_power_level_content_override` config option to set default room power levels per room preset. ([\#12618](https://github.com/matrix-org/synapse/issues/12618))
+- Add support for [MSC3787: Allowing knocks to restricted rooms](https://github.com/matrix-org/matrix-spec-proposals/pull/3787). ([\#12623](https://github.com/matrix-org/synapse/issues/12623))
+- Send `USER_IP` commands on a different Redis channel, in order to reduce traffic to workers that do not process these commands. ([\#12672](https://github.com/matrix-org/synapse/issues/12672), [\#12809](https://github.com/matrix-org/synapse/issues/12809))
+- Synapse will now reload [cache config](https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#caching) when it receives a [SIGHUP](https://en.wikipedia.org/wiki/SIGHUP) signal. ([\#12673](https://github.com/matrix-org/synapse/issues/12673))
+- Add a config options to allow for auto-tuning of caches. ([\#12701](https://github.com/matrix-org/synapse/issues/12701))
+- Update [MSC2716](https://github.com/matrix-org/matrix-spec-proposals/pull/2716) implementation to process marker events from the current state to avoid markers being lost in timeline gaps for federated servers which would cause the imported history to be undiscovered. ([\#12718](https://github.com/matrix-org/synapse/issues/12718))
+- Add a `drop_federated_event` callback to `SpamChecker` to disregard inbound federated events before they take up much processing power, in an emergency. ([\#12744](https://github.com/matrix-org/synapse/issues/12744))
+- Implement [MSC3818: Copy room type on upgrade](https://github.com/matrix-org/matrix-spec-proposals/pull/3818). ([\#12786](https://github.com/matrix-org/synapse/issues/12786), [\#12792](https://github.com/matrix-org/synapse/issues/12792))
+- Update to the `check_event_for_spam` module callback. Deprecate the current callback signature, replace it with a new signature that is both less ambiguous (replacing booleans with explicit allow/block) and more powerful (ability to return explicit error codes). ([\#12808](https://github.com/matrix-org/synapse/issues/12808))
+
+
+Bugfixes
+--------
+
+- Fix a bug introduced in Synapse 1.7.0 that would prevent events from being sent to clients if there's a retention policy in the room when the support for retention policies is disabled. ([\#12611](https://github.com/matrix-org/synapse/issues/12611))
+- Fix a bug introduced in Synapse 1.57.0 where `/messages` would throw a 500 error when querying for a non-existent room. ([\#12683](https://github.com/matrix-org/synapse/issues/12683))
+- Add a unique index to `state_group_edges` to prevent duplicates being accidentally introduced and the consequential impact to performance. ([\#12687](https://github.com/matrix-org/synapse/issues/12687))
+- Fix a long-standing bug where an empty room would be created when a user with an insufficient power level tried to upgrade a room. ([\#12696](https://github.com/matrix-org/synapse/issues/12696))
+- Fix a bug introduced in Synapse 1.30.0 where empty rooms could be automatically created if a monthly active users limit is set. ([\#12713](https://github.com/matrix-org/synapse/issues/12713))
+- Fix push to dismiss notifications when read on another client. Contributed by @SpiritCroc @ Beeper. ([\#12721](https://github.com/matrix-org/synapse/issues/12721))
+- Fix poor database performance when reading the cache invalidation stream for large servers with lots of workers. ([\#12747](https://github.com/matrix-org/synapse/issues/12747))
+- Fix a long-standing bug where the user directory background process would fail to make forward progress if a user included a null codepoint in their display name or avatar. ([\#12762](https://github.com/matrix-org/synapse/issues/12762))
+- Delete events from the `federation_inbound_events_staging` table when a room is purged through the admin API. ([\#12770](https://github.com/matrix-org/synapse/issues/12770))
+- Give a meaningful error message when a client tries to create a room with an invalid alias localpart. ([\#12779](https://github.com/matrix-org/synapse/issues/12779))
+- Fix a bug introduced in 1.43.0 where a file (`providers.json`) was never closed. Contributed by @arkamar. ([\#12794](https://github.com/matrix-org/synapse/issues/12794))
+- Fix a long-standing bug where finished log contexts would be re-started when failing to contact remote homeservers. ([\#12803](https://github.com/matrix-org/synapse/issues/12803))
+- Fix a bug, introduced in Synapse 1.21.0, that led to media thumbnails being unusable before the index has been added in the background. ([\#12823](https://github.com/matrix-org/synapse/issues/12823))
+
+
+Updates to the Docker image
+---------------------------
+
+- Fix the docker file after a dependency update. ([\#12853](https://github.com/matrix-org/synapse/issues/12853))
+
+
+Improved Documentation
+----------------------
+
+- Fix a typo in the Media Admin API documentation. ([\#12715](https://github.com/matrix-org/synapse/issues/12715))
+- Update the OpenID Connect example for Keycloak to be compatible with newer versions of Keycloak. Contributed by @nhh. ([\#12727](https://github.com/matrix-org/synapse/issues/12727))
+- Fix typo in server listener documentation. ([\#12742](https://github.com/matrix-org/synapse/issues/12742))
+- Link to the configuration manual from the welcome page of the documentation. ([\#12748](https://github.com/matrix-org/synapse/issues/12748))
+- Fix typo in `run_background_tasks_on` option name in configuration manual documentation. ([\#12749](https://github.com/matrix-org/synapse/issues/12749))
+- Add information regarding the `rc_invites` ratelimiting option to the configuration docs. ([\#12759](https://github.com/matrix-org/synapse/issues/12759))
+- Add documentation for cancellation of request processing. ([\#12761](https://github.com/matrix-org/synapse/issues/12761))
+- Recommend using docker to run tests against postgres. ([\#12765](https://github.com/matrix-org/synapse/issues/12765))
+- Add missing user directory endpoint from the generic worker documentation. Contributed by @olmari. ([\#12773](https://github.com/matrix-org/synapse/issues/12773))
+- Add additional info to documentation of config option `cache_autotuning`. ([\#12776](https://github.com/matrix-org/synapse/issues/12776))
+- Update configuration manual documentation to document size-related suffixes. ([\#12777](https://github.com/matrix-org/synapse/issues/12777))
+- Fix invalid YAML syntax in the example documentation for the `url_preview_accept_language` config option. ([\#12785](https://github.com/matrix-org/synapse/issues/12785))
+
+
+Deprecations and Removals
+-------------------------
+
+- Require a body in POST requests to `/rooms/{roomId}/receipt/{receiptType}/{eventId}`, as required by the [Matrix specification](https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidreceiptreceipttypeeventid). This breaks compatibility with Element Android 1.2.0 and earlier: users of those clients will be unable to send read receipts. ([\#12709](https://github.com/matrix-org/synapse/issues/12709))
+
+
+Internal Changes
+----------------
+
+- Improve event caching mechanism to avoid having multiple copies of an event in memory at a time. ([\#10533](https://github.com/matrix-org/synapse/issues/10533))
+- Preparation for faster-room-join work: return subsets of room state which we already have, immediately. ([\#12498](https://github.com/matrix-org/synapse/issues/12498))
+- Add `@cancellable` decorator, for use on endpoint methods that can be cancelled when clients disconnect. ([\#12586](https://github.com/matrix-org/synapse/issues/12586), [\#12588](https://github.com/matrix-org/synapse/issues/12588), [\#12630](https://github.com/matrix-org/synapse/issues/12630), [\#12694](https://github.com/matrix-org/synapse/issues/12694), [\#12698](https://github.com/matrix-org/synapse/issues/12698), [\#12699](https://github.com/matrix-org/synapse/issues/12699), [\#12700](https://github.com/matrix-org/synapse/issues/12700), [\#12705](https://github.com/matrix-org/synapse/issues/12705))
+- Enable cancellation of `GET /rooms/$room_id/members`, `GET /rooms/$room_id/state` and `GET /rooms/$room_id/state/$event_type/*` requests. ([\#12708](https://github.com/matrix-org/synapse/issues/12708))
+- Improve documentation of the `synapse.push` module. ([\#12676](https://github.com/matrix-org/synapse/issues/12676))
+- Refactor functions to on `PushRuleEvaluatorForEvent`. ([\#12677](https://github.com/matrix-org/synapse/issues/12677))
+- Preparation for database schema simplifications: stop writing to `event_reference_hashes`. ([\#12679](https://github.com/matrix-org/synapse/issues/12679))
+- Remove code which updates unused database column `application_services_state.last_txn`. ([\#12680](https://github.com/matrix-org/synapse/issues/12680))
+- Refactor `EventContext` class. ([\#12689](https://github.com/matrix-org/synapse/issues/12689))
+- Remove an unneeded class in the push code. ([\#12691](https://github.com/matrix-org/synapse/issues/12691))
+- Consolidate parsing of relation information from events. ([\#12693](https://github.com/matrix-org/synapse/issues/12693))
+- Convert namespace class `Codes` into a string enum. ([\#12703](https://github.com/matrix-org/synapse/issues/12703))
+- Optimize private read receipt filtering. ([\#12711](https://github.com/matrix-org/synapse/issues/12711))
+- Drop the logging level of status messages for the URL preview cache expiry job from INFO to DEBUG. ([\#12720](https://github.com/matrix-org/synapse/issues/12720))
+- Downgrade some OIDC errors to warnings in the logs, to reduce the noise of Sentry reports. ([\#12723](https://github.com/matrix-org/synapse/issues/12723))
+- Update configs used by Complement to allow more invites/3PID validations during tests. ([\#12731](https://github.com/matrix-org/synapse/issues/12731))
+- Tweak the mypy plugin so that `@cached` can accept `on_invalidate=None`. ([\#12769](https://github.com/matrix-org/synapse/issues/12769))
+- Move methods that call `add_push_rule` to the `PushRuleStore` class. ([\#12772](https://github.com/matrix-org/synapse/issues/12772))
+- Make handling of federation Authorization header (more) compliant with RFC7230. ([\#12774](https://github.com/matrix-org/synapse/issues/12774))
+- Refactor `resolve_state_groups_for_events` to not pull out full state when no state resolution happens. ([\#12775](https://github.com/matrix-org/synapse/issues/12775))
+- Do not keep going if there are 5 back-to-back background update failures. ([\#12781](https://github.com/matrix-org/synapse/issues/12781))
+- Fix federation when using the demo scripts. ([\#12783](https://github.com/matrix-org/synapse/issues/12783))
+- The `hash_password` script now fails when it is called without specifying a config file. Contributed by @jae1911. ([\#12789](https://github.com/matrix-org/synapse/issues/12789))
+- Improve and fix type hints. ([\#12567](https://github.com/matrix-org/synapse/issues/12567), [\#12477](https://github.com/matrix-org/synapse/issues/12477), [\#12717](https://github.com/matrix-org/synapse/issues/12717), [\#12753](https://github.com/matrix-org/synapse/issues/12753), [\#12695](https://github.com/matrix-org/synapse/issues/12695), [\#12734](https://github.com/matrix-org/synapse/issues/12734), [\#12716](https://github.com/matrix-org/synapse/issues/12716), [\#12726](https://github.com/matrix-org/synapse/issues/12726), [\#12790](https://github.com/matrix-org/synapse/issues/12790), [\#12833](https://github.com/matrix-org/synapse/issues/12833))
+- Update EventContext `get_current_event_ids` and `get_prev_event_ids` to accept state filters and update calls where possible. ([\#12791](https://github.com/matrix-org/synapse/issues/12791))
+- Remove Caddy from the Synapse workers image used in Complement. ([\#12818](https://github.com/matrix-org/synapse/issues/12818))
+- Add Complement's shared registration secret to the Complement worker image. This fixes tests that depend on it. ([\#12819](https://github.com/matrix-org/synapse/issues/12819))
+- Support registering Application Services when running with workers under Complement. ([\#12826](https://github.com/matrix-org/synapse/issues/12826))
+- Disable 'faster room join' Complement tests when testing against Synapse with workers. ([\#12842](https://github.com/matrix-org/synapse/issues/12842))
+
+
Synapse 1.59.1 (2022-05-18)
===========================
@@ -89,7 +331,7 @@ Deprecations and Removals
-------------------------
- Remove unstable identifiers from [MSC3069](https://github.com/matrix-org/matrix-doc/pull/3069). ([\#12596](https://github.com/matrix-org/synapse/issues/12596))
-- Remove the unspecified `m.login.jwt` login type and the unstable `uk.half-shot.msc2778.login.application_service` from
+- Remove the unspecified `m.login.jwt` login type and the unstable `uk.half-shot.msc2778.login.application_service` from
[MSC2778](https://github.com/matrix-org/matrix-doc/pull/2778). ([\#12597](https://github.com/matrix-org/synapse/issues/12597))
- Synapse now requires at least Python 3.7.1 (up from 3.7.0), for compatibility with the latest Twisted trunk. ([\#12613](https://github.com/matrix-org/synapse/issues/12613))
diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py
index 856dd437..895b2a7a 100755
--- a/contrib/cmdclient/console.py
+++ b/contrib/cmdclient/console.py
@@ -16,6 +16,7 @@
""" Starts a synapse client console. """
import argparse
+import binascii
import cmd
import getpass
import json
@@ -26,9 +27,8 @@ import urllib
from http import TwistedHttpClient
from typing import Optional
-import nacl.encoding
-import nacl.signing
import urlparse
+from signedjson.key import NACL_ED25519, decode_verify_key_bytes
from signedjson.sign import SignatureVerifyException, verify_signed_json
from twisted.internet import defer, reactor, threads
@@ -41,7 +41,6 @@ TRUSTED_ID_SERVERS = ["localhost:8001"]
class SynapseCmd(cmd.Cmd):
-
"""Basic synapse command-line processor.
This processes commands from the user and calls the relevant HTTP methods.
@@ -420,8 +419,8 @@ class SynapseCmd(cmd.Cmd):
pubKey = None
pubKeyObj = yield self.http_client.do_request("GET", url)
if "public_key" in pubKeyObj:
- pubKey = nacl.signing.VerifyKey(
- pubKeyObj["public_key"], encoder=nacl.encoding.HexEncoder
+ pubKey = decode_verify_key_bytes(
+ NACL_ED25519, binascii.unhexlify(pubKeyObj["public_key"])
)
else:
print("No public key found in pubkey response!")
diff --git a/contrib/experiments/cursesio.py b/contrib/experiments/cursesio.py
deleted file mode 100644
index 7695cc77..00000000
--- a/contrib/experiments/cursesio.py
+++ /dev/null
@@ -1,165 +0,0 @@
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import curses
-import curses.wrapper
-from curses.ascii import isprint
-
-from twisted.internet import reactor
-
-
-class CursesStdIO:
- def __init__(self, stdscr, callback=None):
- self.statusText = "Synapse test app -"
- self.searchText = ""
- self.stdscr = stdscr
-
- self.logLine = ""
-
- self.callback = callback
-
- self._setup()
-
- def _setup(self):
- self.stdscr.nodelay(1) # Make non blocking
-
- self.rows, self.cols = self.stdscr.getmaxyx()
- self.lines = []
-
- curses.use_default_colors()
-
- self.paintStatus(self.statusText)
- self.stdscr.refresh()
-
- def set_callback(self, callback):
- self.callback = callback
-
- def fileno(self):
- """We want to select on FD 0"""
- return 0
-
- def connectionLost(self, reason):
- self.close()
-
- def print_line(self, text):
- """add a line to the internal list of lines"""
-
- self.lines.append(text)
- self.redraw()
-
- def print_log(self, text):
- self.logLine = text
- self.redraw()
-
- def redraw(self):
- """method for redisplaying lines based on internal list of lines"""
-
- self.stdscr.clear()
- self.paintStatus(self.statusText)
- i = 0
- index = len(self.lines) - 1
- while i < (self.rows - 3) and index >= 0:
- self.stdscr.addstr(self.rows - 3 - i, 0, self.lines[index], curses.A_NORMAL)
- i = i + 1
- index = index - 1
-
- self.printLogLine(self.logLine)
-
- self.stdscr.refresh()
-
- def paintStatus(self, text):
- if len(text) > self.cols:
- raise RuntimeError("TextTooLongError")
-
- self.stdscr.addstr(
- self.rows - 2, 0, text + " " * (self.cols - len(text)), curses.A_STANDOUT
- )
-
- def printLogLine(self, text):
- self.stdscr.addstr(
- 0, 0, text + " " * (self.cols - len(text)), curses.A_STANDOUT
- )
-
- def doRead(self):
- """Input is ready!"""
- curses.noecho()
- c = self.stdscr.getch() # read a character
-
- if c == curses.KEY_BACKSPACE:
- self.searchText = self.searchText[:-1]
-
- elif c == curses.KEY_ENTER or c == 10:
- text = self.searchText
- self.searchText = ""
-
- self.print_line(">> %s" % text)
-
- try:
- if self.callback:
- self.callback.on_line(text)
- except Exception as e:
- self.print_line(str(e))
-
- self.stdscr.refresh()
-
- elif isprint(c):
- if len(self.searchText) == self.cols - 2:
- return
- self.searchText = self.searchText + chr(c)
-
- self.stdscr.addstr(
- self.rows - 1,
- 0,
- self.searchText + (" " * (self.cols - len(self.searchText) - 2)),
- )
-
- self.paintStatus(self.statusText + " %d" % len(self.searchText))
- self.stdscr.move(self.rows - 1, len(self.searchText))
- self.stdscr.refresh()
-
- def logPrefix(self):
- return "CursesStdIO"
-
- def close(self):
- """clean up"""
-
- curses.nocbreak()
- self.stdscr.keypad(0)
- curses.echo()
- curses.endwin()
-
-
-class Callback:
- def __init__(self, stdio):
- self.stdio = stdio
-
- def on_line(self, text):
- self.stdio.print_line(text)
-
-
-def main(stdscr):
- screen = CursesStdIO(stdscr) # create Screen object
-
- callback = Callback(screen)
-
- screen.set_callback(callback)
-
- stdscr.refresh()
- reactor.addReader(screen)
- reactor.run()
- screen.close()
-
-
-if __name__ == "__main__":
- curses.wrapper(main)
diff --git a/contrib/experiments/test_messaging.py b/contrib/experiments/test_messaging.py
deleted file mode 100644
index 31b8a682..00000000
--- a/contrib/experiments/test_messaging.py
+++ /dev/null
@@ -1,367 +0,0 @@
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-""" This is an example of using the server to server implementation to do a
-basic chat style thing. It accepts commands from stdin and outputs to stdout.
-
-It assumes that ucids are of the form <user>@<domain>, and uses <domain> as
-the address of the remote home server to hit.
-
-Usage:
- python test_messaging.py <port>
-
-Currently assumes the local address is localhost:<port>
-
-"""
-
-
-import argparse
-import curses.wrapper
-import json
-import logging
-import os
-import re
-
-import cursesio
-
-from twisted.internet import defer, reactor
-from twisted.python import log
-
-from synapse.app.homeserver import SynapseHomeServer
-from synapse.federation import ReplicationHandler
-from synapse.federation.units import Pdu
-from synapse.util import origin_from_ucid
-
-# from synapse.logging.utils import log_function
-
-
-logger = logging.getLogger("example")
-
-
-def excpetion_errback(failure):
- logging.exception(failure)
-
-
-class InputOutput:
- """This is responsible for basic I/O so that a user can interact with
- the example app.
- """
-
- def __init__(self, screen, user):
- self.screen = screen
- self.user = user
-
- def set_home_server(self, server):
- self.server = server
-
- def on_line(self, line):
- """This is where we process commands."""
-
- try:
- m = re.match(r"^join (\S+)$", line)
- if m:
- # The `sender` wants to join a room.
- (room_name,) = m.groups()
- self.print_line("%s joining %s" % (self.user, room_name))
- self.server.join_room(room_name, self.user, self.user)
- # self.print_line("OK.")
- return
-
- m = re.match(r"^invite (\S+) (\S+)$", line)
- if m:
- # `sender` wants to invite someone to a room
- room_name, invitee = m.groups()
- self.print_line("%s invited to %s" % (invitee, room_name))
- self.server.invite_to_room(room_name, self.user, invitee)
- # self.print_line("OK.")
- return
-
- m = re.match(r"^send (\S+) (.*)$", line)
- if m:
- # `sender` wants to message a room
- room_name, body = m.groups()
- self.print_line("%s send to %s" % (self.user, room_name))
- self.server.send_message(room_name, self.user, body)
- # self.print_line("OK.")
- return
-
- m = re.match(r"^backfill (\S+)$", line)
- if m:
- # we want to backfill a room
- (room_name,) = m.groups()
- self.print_line("backfill %s" % room_name)
- self.server.backfill(room_name)
- return
-
- self.print_line("Unrecognized command")
-
- except Exception as e:
- logger.exception(e)
-
- def print_line(self, text):
- self.screen.print_line(text)
-
- def print_log(self, text):
- self.screen.print_log(text)
-
-
-class IOLoggerHandler(logging.Handler):
- def __init__(self, io):
- logging.Handler.__init__(self)
- self.io = io
-
- def emit(self, record):
- if record.levelno < logging.WARN:
- return
-
- msg = self.format(record)
- self.io.print_log(msg)
-
-
-class Room:
- """Used to store (in memory) the current membership state of a room, and
- which home servers we should send PDUs associated with the room to.
- """
-
- def __init__(self, room_name):
- self.room_name = room_name
- self.invited = set()
- self.participants = set()
- self.servers = set()
-
- self.oldest_server = None
-
- self.have_got_metadata = False
-
- def add_participant(self, participant):
- """Someone has joined the room"""
- self.participants.add(participant)
- self.invited.discard(participant)
-
- server = origin_from_ucid(participant)
- self.servers.add(server)
-
- if not self.oldest_server:
- self.oldest_server = server
-
- def add_invited(self, invitee):
- """Someone has been invited to the room"""
- self.invited.add(invitee)
- self.servers.add(origin_from_ucid(invitee))
-
-
-class HomeServer(ReplicationHandler):
- """A very basic home server implentation that allows people to join a
- room and then invite other people.
- """
-
- def __init__(self, server_name, replication_layer, output):
- self.server_name = server_name
- self.replication_layer = replication_layer
- self.replication_layer.set_handler(self)
-
- self.joined_rooms = {}
-
- self.output = output
-
- def on_receive_pdu(self, pdu):
- """We just received a PDU"""
- pdu_type = pdu.pdu_type
-
- if pdu_type == "sy.room.message":
- self._on_message(pdu)
- elif pdu_type == "sy.room.member" and "membership" in pdu.content:
- if pdu.content["membership"] == "join":
- self._on_join(pdu.context, pdu.state_key)
- elif pdu.content["membership"] == "invite":
- self._on_invite(pdu.origin, pdu.context, pdu.state_key)
- else:
- self.output.print_line(
- "#%s (unrec) %s = %s"
- % (pdu.context, pdu.pdu_type, json.dumps(pdu.content))
- )
-
- def _on_message(self, pdu):
- """We received a message"""
- self.output.print_line(
- "#%s %s %s" % (pdu.context, pdu.content["sender"], pdu.content["body"])
- )
-
- def _on_join(self, context, joinee):
- """Someone has joined a room, either a remote user or a local user"""
- room = self._get_or_create_room(context)
- room.add_participant(joinee)
-
- self.output.print_line("#%s %s %s" % (context, joinee, "*** JOINED"))
-
- def _on_invite(self, origin, context, invitee):
- """Someone has been invited"""
- room = self._get_or_create_room(context)
- room.add_invited(invitee)
-
- self.output.print_line("#%s %s %s" % (context, invitee, "*** INVITED"))
-
-        if not room.have_got_metadata and origin != self.server_name:
- logger.debug("Get room state")
- self.replication_layer.get_state_for_context(origin, context)
- room.have_got_metadata = True
-
- @defer.inlineCallbacks
- def send_message(self, room_name, sender, body):
- """Send a message to a room!"""
- destinations = yield self.get_servers_for_context(room_name)
-
- try:
- yield self.replication_layer.send_pdu(
- Pdu.create_new(
- context=room_name,
- pdu_type="sy.room.message",
- content={"sender": sender, "body": body},
- origin=self.server_name,
- destinations=destinations,
- )
- )
- except Exception as e:
- logger.exception(e)
-
- @defer.inlineCallbacks
- def join_room(self, room_name, sender, joinee):
- """Join a room!"""
- self._on_join(room_name, joinee)
-
- destinations = yield self.get_servers_for_context(room_name)
-
- try:
- pdu = Pdu.create_new(
- context=room_name,
- pdu_type="sy.room.member",
- is_state=True,
- state_key=joinee,
- content={"membership": "join"},
- origin=self.server_name,
- destinations=destinations,
- )
- yield self.replication_layer.send_pdu(pdu)
- except Exception as e:
- logger.exception(e)
-
- @defer.inlineCallbacks
- def invite_to_room(self, room_name, sender, invitee):
- """Invite someone to a room!"""
- self._on_invite(self.server_name, room_name, invitee)
-
- destinations = yield self.get_servers_for_context(room_name)
-
- try:
- yield self.replication_layer.send_pdu(
- Pdu.create_new(
- context=room_name,
- is_state=True,
- pdu_type="sy.room.member",
- state_key=invitee,
- content={"membership": "invite"},
- origin=self.server_name,
- destinations=destinations,
- )
- )
- except Exception as e:
- logger.exception(e)
-
- def backfill(self, room_name, limit=5):
- room = self.joined_rooms.get(room_name)
-
- if not room:
- return
-
- dest = room.oldest_server
-
- return self.replication_layer.backfill(dest, room_name, limit)
-
- def _get_room_remote_servers(self, room_name):
-        return list(self.joined_rooms.setdefault(room_name, Room(room_name)).servers)
-
- def _get_or_create_room(self, room_name):
- return self.joined_rooms.setdefault(room_name, Room(room_name))
-
- def get_servers_for_context(self, context):
- return defer.succeed(
- self.joined_rooms.setdefault(context, Room(context)).servers
- )
-
-
-def main(stdscr):
- parser = argparse.ArgumentParser()
- parser.add_argument("user", type=str)
- parser.add_argument("-v", "--verbose", action="count")
- args = parser.parse_args()
-
- user = args.user
- server_name = origin_from_ucid(user)
-
- # Set up logging
-
- root_logger = logging.getLogger()
-
- formatter = logging.Formatter(
- "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
- )
- if not os.path.exists("logs"):
- os.makedirs("logs")
- fh = logging.FileHandler("logs/%s" % user)
- fh.setFormatter(formatter)
-
- root_logger.addHandler(fh)
- root_logger.setLevel(logging.DEBUG)
-
- # Hack: The only way to get it to stop logging to sys.stderr :(
- log.theLogPublisher.observers = []
- observer = log.PythonLoggingObserver()
- observer.start()
-
- # Set up synapse server
-
- curses_stdio = cursesio.CursesStdIO(stdscr)
- input_output = InputOutput(curses_stdio, user)
-
- curses_stdio.set_callback(input_output)
-
- app_hs = SynapseHomeServer(server_name, db_name="dbs/%s" % user)
- replication = app_hs.get_replication_layer()
-
- hs = HomeServer(server_name, replication, curses_stdio)
-
- input_output.set_home_server(hs)
-
- # Add input_output logger
- io_logger = IOLoggerHandler(input_output)
- io_logger.setFormatter(formatter)
- root_logger.addHandler(io_logger)
-
- # Start!
-
- try:
- port = int(server_name.split(":")[1])
- except Exception:
- port = 12345
-
- app_hs.get_http_server().start_listening(port)
-
- reactor.addReader(curses_stdio)
-
- reactor.run()
-
-
-if __name__ == "__main__":
- curses.wrapper(main)
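
The deleted example above assumes ucids of the form <user>@<domain> and contacts the home server named by the <domain> part (origin_from_ucid is imported from synapse.util). As a rough sketch of that split, assuming rather than quoting the real helper's behaviour at the time:

    def origin_from_ucid(ucid: str) -> str:
        """Return the <domain> part of a '<user>@<domain>' ucid (illustrative sketch)."""
        _user, sep, domain = ucid.rpartition("@")
        if not sep or not domain:
            raise ValueError("ucid must look like <user>@<domain>: %r" % (ucid,))
        return domain

    # e.g. origin_from_ucid("alice@localhost:8448") == "localhost:8448"
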
diff --git a/contrib/jitsimeetbridge/jitsimeetbridge.py b/contrib/jitsimeetbridge/jitsimeetbridge.py
deleted file mode 100644
index b3de4686..00000000
--- a/contrib/jitsimeetbridge/jitsimeetbridge.py
+++ /dev/null
@@ -1,298 +0,0 @@
-#!/usr/bin/env python
-
-"""
-This is an attempt at bridging Matrix clients into a Jitsi Meet room via Matrix
-video call. It uses hard-coded XML strings over XMPP BOSH. It can display one
-of the streams from the Jitsi bridge until the second lot of SDP comes down and
-we set the remote SDP, at which point the stream ends. Our video never gets to
-the bridge.
-
-Requires:
-npm install jquery jsdom
-"""
-import json
-import subprocess
-import time
-
-import gevent
-import grequests
-from BeautifulSoup import BeautifulSoup
-
-ACCESS_TOKEN = ""
-
-MATRIXBASE = "https://matrix.org/_matrix/client/api/v1/"
-MYUSERNAME = "@davetest:matrix.org"
-
-HTTPBIND = "https://meet.jit.si/http-bind"
-# HTTPBIND = 'https://jitsi.vuc.me/http-bind'
-# ROOMNAME = "matrix"
-ROOMNAME = "pibble"
-
-HOST = "guest.jit.si"
-# HOST="jitsi.vuc.me"
-
-TURNSERVER = "turn.guest.jit.si"
-# TURNSERVER="turn.jitsi.vuc.me"
-
-ROOMDOMAIN = "meet.jit.si"
-# ROOMDOMAIN="conference.jitsi.vuc.me"
-
-
-class TrivialMatrixClient:
- def __init__(self, access_token):
- self.token = None
- self.access_token = access_token
-
- def getEvent(self):
- while True:
- url = (
- MATRIXBASE
- + "events?access_token="
- + self.access_token
- + "&timeout=60000"
- )
- if self.token:
- url += "&from=" + self.token
- req = grequests.get(url)
- resps = grequests.map([req])
- obj = json.loads(resps[0].content)
- print("incoming from matrix", obj)
- if "end" not in obj:
- continue
- self.token = obj["end"]
- if len(obj["chunk"]):
- return obj["chunk"][0]
-
- def joinRoom(self, roomId):
- url = MATRIXBASE + "rooms/" + roomId + "/join?access_token=" + self.access_token
- print(url)
- headers = {"Content-Type": "application/json"}
- req = grequests.post(url, headers=headers, data="{}")
- resps = grequests.map([req])
- obj = json.loads(resps[0].content)
- print("response: ", obj)
-
- def sendEvent(self, roomId, evType, event):
- url = (
- MATRIXBASE
- + "rooms/"
- + roomId
- + "/send/"
- + evType
- + "?access_token="
- + self.access_token
- )
- print(url)
- print(json.dumps(event))
- headers = {"Content-Type": "application/json"}
- req = grequests.post(url, headers=headers, data=json.dumps(event))
- resps = grequests.map([req])
- obj = json.loads(resps[0].content)
- print("response: ", obj)
-
-
-xmppClients = {}
-
-
-def matrixLoop():
- while True:
- ev = matrixCli.getEvent()
- print(ev)
- if ev["type"] == "m.room.member":
- print("membership event")
- if ev["membership"] == "invite" and ev["state_key"] == MYUSERNAME:
- roomId = ev["room_id"]
- print("joining room %s" % (roomId))
- matrixCli.joinRoom(roomId)
- elif ev["type"] == "m.room.message":
- if ev["room_id"] in xmppClients:
- print("already have a bridge for that user, ignoring")
- continue
- print("got message, connecting")
- xmppClients[ev["room_id"]] = TrivialXmppClient(ev["room_id"], ev["user_id"])
- gevent.spawn(xmppClients[ev["room_id"]].xmppLoop)
- elif ev["type"] == "m.call.invite":
- print("Incoming call")
- # sdp = ev['content']['offer']['sdp']
- # print "sdp: %s" % (sdp)
- # xmppClients[ev['room_id']] = TrivialXmppClient(ev['room_id'], ev['user_id'])
- # gevent.spawn(xmppClients[ev['room_id']].xmppLoop)
- elif ev["type"] == "m.call.answer":
- print("Call answered")
- sdp = ev["content"]["answer"]["sdp"]
- if ev["room_id"] not in xmppClients:
- print("We didn't have a call for that room")
- continue
- # should probably check call ID too
- xmppCli = xmppClients[ev["room_id"]]
- xmppCli.sendAnswer(sdp)
- elif ev["type"] == "m.call.hangup":
- if ev["room_id"] in xmppClients:
- xmppClients[ev["room_id"]].stop()
- del xmppClients[ev["room_id"]]
-
-
-class TrivialXmppClient:
- def __init__(self, matrixRoom, userId):
- self.rid = 0
- self.matrixRoom = matrixRoom
- self.userId = userId
- self.running = True
-
- def stop(self):
- self.running = False
-
- def nextRid(self):
- self.rid += 1
- return "%d" % (self.rid)
-
- def sendIq(self, xml):
- fullXml = (
- "<body rid='%s' xmlns='http://jabber.org/protocol/httpbind' sid='%s'>%s</body>"
- % (self.nextRid(), self.sid, xml)
- )
- # print "\t>>>%s" % (fullXml)
- return self.xmppPoke(fullXml)
-
- def xmppPoke(self, xml):
- headers = {"Content-Type": "application/xml"}
- req = grequests.post(HTTPBIND, verify=False, headers=headers, data=xml)
- resps = grequests.map([req])
- obj = BeautifulSoup(resps[0].content)
- return obj
-
- def sendAnswer(self, answer):
- print("sdp from matrix client", answer)
- p = subprocess.Popen(
- ["node", "unjingle/unjingle.js", "--sdp"],
- stdin=subprocess.PIPE,
- stdout=subprocess.PIPE,
- )
- jingle, out_err = p.communicate(answer)
- jingle = jingle % {
- "tojid": self.callfrom,
- "action": "session-accept",
- "initiator": self.callfrom,
- "responder": self.jid,
- "sid": self.callsid,
- }
- print("answer jingle from sdp", jingle)
- res = self.sendIq(jingle)
- print("reply from answer: ", res)
-
- self.ssrcs = {}
- jingleSoup = BeautifulSoup(jingle)
- for cont in jingleSoup.iq.jingle.findAll("content"):
- if cont.description:
- self.ssrcs[cont["name"]] = cont.description["ssrc"]
- print("my ssrcs:", self.ssrcs)
-
- gevent.joinall([gevent.spawn(self.advertiseSsrcs)])
-
- def advertiseSsrcs(self):
- time.sleep(7)
- print("SSRC spammer started")
- while self.running:
- ssrcMsg = (
- "<presence to='%(tojid)s' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%(nick)s</nick><stats xmlns='http://jitsi.org/jitmeet/stats'><stat name='bitrate_download' value='175'/><stat name='bitrate_upload' value='176'/><stat name='packetLoss_total' value='0'/><stat name='packetLoss_download' value='0'/><stat name='packetLoss_upload' value='0'/></stats><media xmlns='http://estos.de/ns/mjs'><source type='audio' ssrc='%(assrc)s' direction='sendre'/><source type='video' ssrc='%(vssrc)s' direction='sendre'/></media></presence>"
- % {
- "tojid": "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid),
- "nick": self.userId,
- "assrc": self.ssrcs["audio"],
- "vssrc": self.ssrcs["video"],
- }
- )
- res = self.sendIq(ssrcMsg)
- print("reply from ssrc announce: ", res)
- time.sleep(10)
-
- def xmppLoop(self):
- self.matrixCallId = time.time()
- res = self.xmppPoke(
- "<body rid='%s' xmlns='http://jabber.org/protocol/httpbind' to='%s' xml:lang='en' wait='60' hold='1' content='text/xml; charset=utf-8' ver='1.6' xmpp:version='1.0' xmlns:xmpp='urn:xmpp:xbosh'/>"
- % (self.nextRid(), HOST)
- )
-
- print(res)
- self.sid = res.body["sid"]
- print("sid %s" % (self.sid))
-
- res = self.sendIq(
- "<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='ANONYMOUS'/>"
- )
-
- res = self.xmppPoke(
- "<body rid='%s' xmlns='http://jabber.org/protocol/httpbind' sid='%s' to='%s' xml:lang='en' xmpp:restart='true' xmlns:xmpp='urn:xmpp:xbosh'/>"
- % (self.nextRid(), self.sid, HOST)
- )
-
- res = self.sendIq(
- "<iq type='set' id='_bind_auth_2' xmlns='jabber:client'><bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'/></iq>"
- )
- print(res)
-
- self.jid = res.body.iq.bind.jid.string
- print("jid: %s" % (self.jid))
- self.shortJid = self.jid.split("-")[0]
-
- res = self.sendIq(
- "<iq type='set' id='_session_auth_2' xmlns='jabber:client'><session xmlns='urn:ietf:params:xml:ns:xmpp-session'/></iq>"
- )
-
- # randomthing = res.body.iq['to']
- # whatsitpart = randomthing.split('-')[0]
-
- # print "other random bind thing: %s" % (randomthing)
-
-        # advertise presence to the jitsi room, with our nick
- res = self.sendIq(
- "<iq type='get' to='%s' xmlns='jabber:client' id='1:sendIQ'><services xmlns='urn:xmpp:extdisco:1'><service host='%s'/></services></iq><presence to='%s@%s/d98f6c40' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%s</nick></presence>"
- % (HOST, TURNSERVER, ROOMNAME, ROOMDOMAIN, self.userId)
- )
- self.muc = {"users": []}
- for p in res.body.findAll("presence"):
- u = {}
- u["shortJid"] = p["from"].split("/")[1]
- if p.c and p.c.nick:
- u["nick"] = p.c.nick.string
- self.muc["users"].append(u)
- print("muc: ", self.muc)
-
- # wait for stuff
- while True:
- print("waiting...")
- res = self.sendIq("")
- print("got from stream: ", res)
- if res.body.iq:
- jingles = res.body.iq.findAll("jingle")
- if len(jingles):
- self.callfrom = res.body.iq["from"]
- self.handleInvite(jingles[0])
- elif "type" in res.body and res.body["type"] == "terminate":
- self.running = False
- del xmppClients[self.matrixRoom]
- return
-
- def handleInvite(self, jingle):
- self.initiator = jingle["initiator"]
- self.callsid = jingle["sid"]
- p = subprocess.Popen(
- ["node", "unjingle/unjingle.js", "--jingle"],
- stdin=subprocess.PIPE,
- stdout=subprocess.PIPE,
- )
- print("raw jingle invite", str(jingle))
- sdp, out_err = p.communicate(str(jingle))
- print("transformed remote offer sdp", sdp)
- inviteEvent = {
- "offer": {"type": "offer", "sdp": sdp},
- "call_id": self.matrixCallId,
- "version": 0,
- "lifetime": 30000,
- }
- matrixCli.sendEvent(self.matrixRoom, "m.call.invite", inviteEvent)
-
-
-matrixCli = TrivialMatrixClient(ACCESS_TOKEN)  # referenced as a global by matrixLoop() above
-
-gevent.joinall([gevent.spawn(matrixLoop)])
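
TrivialMatrixClient.getEvent above long-polls the legacy v1 /events endpoint, threading the returned "end" token into the next request's "from" parameter. A minimal, self-contained sketch of the same loop, using the requests library instead of grequests (the endpoint and parameters are taken from the deleted code; everything else is assumed):

    import requests  # assumed substitute for the gevent/grequests pair used above

    MATRIXBASE = "https://matrix.org/_matrix/client/api/v1/"

    def event_stream(access_token):
        """Yield events from the legacy v1 /events long-poll (illustrative sketch)."""
        token = None
        while True:
            params = {"access_token": access_token, "timeout": 60000}
            if token:
                params["from"] = token
            resp = requests.get(MATRIXBASE + "events", params=params)
            body = resp.json()
            if "end" not in body:
                continue
            token = body["end"]  # resume from here on the next poll
            for event in body["chunk"]:
                yield event
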
diff --git a/contrib/jitsimeetbridge/syweb-jitsi-conference.patch b/contrib/jitsimeetbridge/syweb-jitsi-conference.patch
deleted file mode 100644
index aed23c78..00000000
--- a/contrib/jitsimeetbridge/syweb-jitsi-conference.patch
+++ /dev/null
@@ -1,188 +0,0 @@
-diff --git a/syweb/webclient/app/components/matrix/matrix-call.js b/syweb/webclient/app/components/matrix/matrix-call.js
-index 9fbfff0..dc68077 100644
---- a/syweb/webclient/app/components/matrix/matrix-call.js
-+++ b/syweb/webclient/app/components/matrix/matrix-call.js
-@@ -16,6 +16,45 @@ limitations under the License.
-
- 'use strict';
-
-+
-+function sendKeyframe(pc) {
-+ console.log('sendkeyframe', pc.iceConnectionState);
-+ if (pc.iceConnectionState !== 'connected') return; // safe...
-+ pc.setRemoteDescription(
-+ pc.remoteDescription,
-+ function () {
-+ pc.createAnswer(
-+ function (modifiedAnswer) {
-+ pc.setLocalDescription(
-+ modifiedAnswer,
-+ function () {
-+ // noop
-+ },
-+ function (error) {
-+ console.log('triggerKeyframe setLocalDescription failed', error);
-+ messageHandler.showError();
-+ }
-+ );
-+ },
-+ function (error) {
-+ console.log('triggerKeyframe createAnswer failed', error);
-+ messageHandler.showError();
-+ }
-+ );
-+ },
-+ function (error) {
-+ console.log('triggerKeyframe setRemoteDescription failed', error);
-+ messageHandler.showError();
-+ }
-+ );
-+}
-+
-+
-+
-+
-+
-+
-+
- var forAllVideoTracksOnStream = function(s, f) {
- var tracks = s.getVideoTracks();
- for (var i = 0; i < tracks.length; i++) {
-@@ -83,7 +122,7 @@ angular.module('MatrixCall', [])
- }
-
- // FIXME: we should prevent any calls from being placed or accepted before this has finished
-- MatrixCall.getTurnServer();
-+ //MatrixCall.getTurnServer();
-
- MatrixCall.CALL_TIMEOUT = 60000;
- MatrixCall.FALLBACK_STUN_SERVER = 'stun:stun.l.google.com:19302';
-@@ -132,6 +171,22 @@ angular.module('MatrixCall', [])
- pc.onsignalingstatechange = function() { self.onSignallingStateChanged(); };
- pc.onicecandidate = function(c) { self.gotLocalIceCandidate(c); };
- pc.onaddstream = function(s) { self.onAddStream(s); };
-+
-+ var datachan = pc.createDataChannel('RTCDataChannel', {
-+ reliable: false
-+ });
-+ console.log("data chan: "+datachan);
-+ datachan.onopen = function() {
-+ console.log("data channel open");
-+ };
-+ datachan.onmessage = function() {
-+ console.log("data channel message");
-+ };
-+ pc.ondatachannel = function(event) {
-+ console.log("have data channel");
-+ event.channel.binaryType = 'blob';
-+ };
-+
- return pc;
- }
-
-@@ -200,6 +255,12 @@ angular.module('MatrixCall', [])
- }, this.msg.lifetime - event.age);
- };
-
-+ MatrixCall.prototype.receivedInvite = function(event) {
-+ console.log("Got second invite for call "+this.call_id);
-+ this.peerConn.setRemoteDescription(new RTCSessionDescription(this.msg.offer), this.onSetRemoteDescriptionSuccess, this.onSetRemoteDescriptionError);
-+ };
-+
-+
- // perverse as it may seem, sometimes we want to instantiate a call with a hangup message
- // (because when getting the state of the room on load, events come in reverse order and
- // we want to remember that a call has been hung up)
-@@ -349,7 +410,7 @@ angular.module('MatrixCall', [])
- 'mandatory': {
- 'OfferToReceiveAudio': true,
- 'OfferToReceiveVideo': this.type == 'video'
-- },
-+ }
- };
- this.peerConn.createAnswer(function(d) { self.createdAnswer(d); }, function(e) {}, constraints);
- // This can't be in an apply() because it's called by a predecessor call under glare conditions :(
-@@ -359,8 +420,20 @@ angular.module('MatrixCall', [])
- MatrixCall.prototype.gotLocalIceCandidate = function(event) {
- if (event.candidate) {
- console.log("Got local ICE "+event.candidate.sdpMid+" candidate: "+event.candidate.candidate);
-- this.sendCandidate(event.candidate);
-- }
-+ //this.sendCandidate(event.candidate);
-+ } else {
-+ console.log("have all candidates, sending answer");
-+ var content = {
-+ version: 0,
-+ call_id: this.call_id,
-+ answer: this.peerConn.localDescription
-+ };
-+ this.sendEventWithRetry('m.call.answer', content);
-+ var self = this;
-+ $rootScope.$apply(function() {
-+ self.state = 'connecting';
-+ });
-+ }
- }
-
- MatrixCall.prototype.gotRemoteIceCandidate = function(cand) {
-@@ -418,15 +491,6 @@ angular.module('MatrixCall', [])
- console.log("Created answer: "+description);
- var self = this;
- this.peerConn.setLocalDescription(description, function() {
-- var content = {
-- version: 0,
-- call_id: self.call_id,
-- answer: self.peerConn.localDescription
-- };
-- self.sendEventWithRetry('m.call.answer', content);
-- $rootScope.$apply(function() {
-- self.state = 'connecting';
-- });
- }, function() { console.log("Error setting local description!"); } );
- };
-
-@@ -448,6 +512,9 @@ angular.module('MatrixCall', [])
- $rootScope.$apply(function() {
- self.state = 'connected';
- self.didConnect = true;
-+ /*$timeout(function() {
-+ sendKeyframe(self.peerConn);
-+ }, 1000);*/
- });
- } else if (this.peerConn.iceConnectionState == 'failed') {
- this.hangup('ice_failed');
-@@ -518,6 +585,7 @@ angular.module('MatrixCall', [])
-
- MatrixCall.prototype.onRemoteStreamEnded = function(event) {
- console.log("Remote stream ended");
-+ return;
- var self = this;
- $rootScope.$apply(function() {
- self.state = 'ended';
-diff --git a/syweb/webclient/app/components/matrix/matrix-phone-service.js b/syweb/webclient/app/components/matrix/matrix-phone-service.js
-index 55dbbf5..272fa27 100644
---- a/syweb/webclient/app/components/matrix/matrix-phone-service.js
-+++ b/syweb/webclient/app/components/matrix/matrix-phone-service.js
-@@ -48,6 +48,13 @@ angular.module('matrixPhoneService', [])
- return;
- }
-
-+ // do we already have an entry for this call ID?
-+ var existingEntry = matrixPhoneService.allCalls[msg.call_id];
-+ if (existingEntry) {
-+ existingEntry.receivedInvite(msg);
-+ return;
-+ }
-+
- var call = undefined;
- if (!isLive) {
- // if this event wasn't live then this call may already be over
-@@ -108,7 +115,7 @@ angular.module('matrixPhoneService', [])
- call.hangup();
- }
- } else {
-- $rootScope.$broadcast(matrixPhoneService.INCOMING_CALL_EVENT, call);
-+ $rootScope.$broadcast(matrixPhoneService.INCOMING_CALL_EVENT, call);
- }
- } else if (event.type == 'm.call.answer') {
- var call = matrixPhoneService.allCalls[msg.call_id];
diff --git a/contrib/jitsimeetbridge/unjingle/strophe.jingle.sdp.js b/contrib/jitsimeetbridge/unjingle/strophe.jingle.sdp.js
deleted file mode 100644
index e99dd7bf..00000000
--- a/contrib/jitsimeetbridge/unjingle/strophe.jingle.sdp.js
+++ /dev/null
@@ -1,712 +0,0 @@
-/* jshint -W117 */
-// SDP STUFF
-function SDP(sdp) {
- this.media = sdp.split('\r\nm=');
- for (var i = 1; i < this.media.length; i++) {
- this.media[i] = 'm=' + this.media[i];
- if (i != this.media.length - 1) {
- this.media[i] += '\r\n';
- }
- }
- this.session = this.media.shift() + '\r\n';
- this.raw = this.session + this.media.join('');
-}
-
-exports.SDP = SDP;
-
-var jsdom = require("jsdom");
-var window = jsdom.jsdom().parentWindow;
-var $ = require('jquery')(window);
-
-var SDPUtil = require('./strophe.jingle.sdp.util.js').SDPUtil;
-
-/**
- * Returns a map of MediaChannel objects keyed by channel idx.
- */
-SDP.prototype.getMediaSsrcMap = function() {
- var self = this;
- var media_ssrcs = {};
- for (channelNum = 0; channelNum < self.media.length; channelNum++) {
- modified = true;
- tmp = SDPUtil.find_lines(self.media[channelNum], 'a=ssrc:');
- var type = SDPUtil.parse_mid(SDPUtil.find_line(self.media[channelNum], 'a=mid:'));
- var channel = new MediaChannel(channelNum, type);
- media_ssrcs[channelNum] = channel;
- tmp.forEach(function (line) {
- var linessrc = line.substring(7).split(' ')[0];
- // allocate new ChannelSsrc
- if(!channel.ssrcs[linessrc]) {
- channel.ssrcs[linessrc] = new ChannelSsrc(linessrc, type);
- }
- channel.ssrcs[linessrc].lines.push(line);
- });
- tmp = SDPUtil.find_lines(self.media[channelNum], 'a=ssrc-group:');
- tmp.forEach(function(line){
-            var idx = line.indexOf(' '), semantics = line.substr(0, idx).substr(13);
- var ssrcs = line.substr(14 + semantics.length).split(' ');
- if (ssrcs.length != 0) {
- var ssrcGroup = new ChannelSsrcGroup(semantics, ssrcs);
- channel.ssrcGroups.push(ssrcGroup);
- }
- });
- }
- return media_ssrcs;
-};
-/**
- * Returns <tt>true</tt> if this SDP contains given SSRC.
- * @param ssrc the ssrc to check.
- * @returns {boolean} <tt>true</tt> if this SDP contains given SSRC.
- */
-SDP.prototype.containsSSRC = function(ssrc) {
- var channels = this.getMediaSsrcMap();
- var contains = false;
- Object.keys(channels).forEach(function(chNumber){
- var channel = channels[chNumber];
- //console.log("Check", channel, ssrc);
- if(Object.keys(channel.ssrcs).indexOf(ssrc) != -1){
- contains = true;
- }
- });
- return contains;
-};
-
-/**
- * Returns a map of MediaChannel objects containing only media not present in <tt>otherSdp</tt>, keyed by channel idx.
- * @param otherSdp the other SDP to check ssrc with.
- */
-SDP.prototype.getNewMedia = function(otherSdp) {
-
- // this could be useful in Array.prototype.
- function arrayEquals(array) {
- // if the other array is a falsy value, return
- if (!array)
- return false;
-
- // compare lengths - can save a lot of time
- if (this.length != array.length)
- return false;
-
- for (var i = 0, l=this.length; i < l; i++) {
- // Check if we have nested arrays
- if (this[i] instanceof Array && array[i] instanceof Array) {
- // recurse into the nested arrays
-                if (!arrayEquals.apply(this[i], [array[i]]))
- return false;
- }
- else if (this[i] != array[i]) {
- // Warning - two different object instances will never be equal: {x:20} != {x:20}
- return false;
- }
- }
- return true;
- }
-
- var myMedia = this.getMediaSsrcMap();
- var othersMedia = otherSdp.getMediaSsrcMap();
- var newMedia = {};
- Object.keys(othersMedia).forEach(function(channelNum) {
- var myChannel = myMedia[channelNum];
- var othersChannel = othersMedia[channelNum];
- if(!myChannel && othersChannel) {
- // Add whole channel
- newMedia[channelNum] = othersChannel;
- return;
- }
-        // Look for new ssrcs across the channel
- Object.keys(othersChannel.ssrcs).forEach(function(ssrc) {
- if(Object.keys(myChannel.ssrcs).indexOf(ssrc) === -1) {
- // Allocate channel if we've found ssrc that doesn't exist in our channel
- if(!newMedia[channelNum]){
- newMedia[channelNum] = new MediaChannel(othersChannel.chNumber, othersChannel.mediaType);
- }
- newMedia[channelNum].ssrcs[ssrc] = othersChannel.ssrcs[ssrc];
- }
- });
-
- // Look for new ssrc groups across the channels
- othersChannel.ssrcGroups.forEach(function(otherSsrcGroup){
-
- // try to match the other ssrc-group with an ssrc-group of ours
- var matched = false;
- for (var i = 0; i < myChannel.ssrcGroups.length; i++) {
- var mySsrcGroup = myChannel.ssrcGroups[i];
- if (otherSsrcGroup.semantics == mySsrcGroup.semantics
- && arrayEquals.apply(otherSsrcGroup.ssrcs, [mySsrcGroup.ssrcs])) {
-
- matched = true;
- break;
- }
- }
-
- if (!matched) {
- // Allocate channel if we've found an ssrc-group that doesn't
- // exist in our channel
-
- if(!newMedia[channelNum]){
- newMedia[channelNum] = new MediaChannel(othersChannel.chNumber, othersChannel.mediaType);
- }
- newMedia[channelNum].ssrcGroups.push(otherSsrcGroup);
- }
- });
- });
- return newMedia;
-};
-
-// remove iSAC and CN from SDP
-SDP.prototype.mangle = function () {
- var i, j, mline, lines, rtpmap, newdesc;
- for (i = 0; i < this.media.length; i++) {
- lines = this.media[i].split('\r\n');
- lines.pop(); // remove empty last element
- mline = SDPUtil.parse_mline(lines.shift());
- if (mline.media != 'audio')
- continue;
- newdesc = '';
- mline.fmt.length = 0;
- for (j = 0; j < lines.length; j++) {
- if (lines[j].substr(0, 9) == 'a=rtpmap:') {
- rtpmap = SDPUtil.parse_rtpmap(lines[j]);
- if (rtpmap.name == 'CN' || rtpmap.name == 'ISAC')
- continue;
- mline.fmt.push(rtpmap.id);
- newdesc += lines[j] + '\r\n';
- } else {
- newdesc += lines[j] + '\r\n';
- }
- }
- this.media[i] = SDPUtil.build_mline(mline) + '\r\n';
- this.media[i] += newdesc;
- }
- this.raw = this.session + this.media.join('');
-};
-
-// remove lines matching prefix from session section
-SDP.prototype.removeSessionLines = function(prefix) {
- var self = this;
- var lines = SDPUtil.find_lines(this.session, prefix);
- lines.forEach(function(line) {
- self.session = self.session.replace(line + '\r\n', '');
- });
- this.raw = this.session + this.media.join('');
- return lines;
-}
-// remove lines matching prefix from a media section specified by mediaindex
-// TODO: non-numeric mediaindex could match mid
-SDP.prototype.removeMediaLines = function(mediaindex, prefix) {
- var self = this;
- var lines = SDPUtil.find_lines(this.media[mediaindex], prefix);
- lines.forEach(function(line) {
- self.media[mediaindex] = self.media[mediaindex].replace(line + '\r\n', '');
- });
- this.raw = this.session + this.media.join('');
- return lines;
-}
-
-// add content's to a jingle element
-SDP.prototype.toJingle = function (elem, thecreator) {
- var i, j, k, mline, ssrc, rtpmap, tmp, line, lines;
- var self = this;
- // new bundle plan
- if (SDPUtil.find_line(this.session, 'a=group:')) {
- lines = SDPUtil.find_lines(this.session, 'a=group:');
- for (i = 0; i < lines.length; i++) {
- tmp = lines[i].split(' ');
- var semantics = tmp.shift().substr(8);
- elem.c('group', {xmlns: 'urn:xmpp:jingle:apps:grouping:0', semantics:semantics});
- for (j = 0; j < tmp.length; j++) {
- elem.c('content', {name: tmp[j]}).up();
- }
- elem.up();
- }
- }
- // old bundle plan, to be removed
- var bundle = [];
- if (SDPUtil.find_line(this.session, 'a=group:BUNDLE')) {
- bundle = SDPUtil.find_line(this.session, 'a=group:BUNDLE ').split(' ');
- bundle.shift();
- }
- for (i = 0; i < this.media.length; i++) {
- mline = SDPUtil.parse_mline(this.media[i].split('\r\n')[0]);
- if (!(mline.media === 'audio' ||
- mline.media === 'video' ||
- mline.media === 'application'))
- {
- continue;
- }
- if (SDPUtil.find_line(this.media[i], 'a=ssrc:')) {
- ssrc = SDPUtil.find_line(this.media[i], 'a=ssrc:').substring(7).split(' ')[0]; // take the first
- } else {
- ssrc = false;
- }
-
- elem.c('content', {creator: thecreator, name: mline.media});
- if (SDPUtil.find_line(this.media[i], 'a=mid:')) {
- // prefer identifier from a=mid if present
- var mid = SDPUtil.parse_mid(SDPUtil.find_line(this.media[i], 'a=mid:'));
- elem.attrs({ name: mid });
-
- // old BUNDLE plan, to be removed
- if (bundle.indexOf(mid) !== -1) {
- elem.c('bundle', {xmlns: 'http://estos.de/ns/bundle'}).up();
- bundle.splice(bundle.indexOf(mid), 1);
- }
- }
-
- if (SDPUtil.find_line(this.media[i], 'a=rtpmap:').length)
- {
- elem.c('description',
- {xmlns: 'urn:xmpp:jingle:apps:rtp:1',
- media: mline.media });
- if (ssrc) {
- elem.attrs({ssrc: ssrc});
- }
- for (j = 0; j < mline.fmt.length; j++) {
- rtpmap = SDPUtil.find_line(this.media[i], 'a=rtpmap:' + mline.fmt[j]);
- elem.c('payload-type', SDPUtil.parse_rtpmap(rtpmap));
- // put any 'a=fmtp:' + mline.fmt[j] lines into <param name=foo value=bar/>
- if (SDPUtil.find_line(this.media[i], 'a=fmtp:' + mline.fmt[j])) {
- tmp = SDPUtil.parse_fmtp(SDPUtil.find_line(this.media[i], 'a=fmtp:' + mline.fmt[j]));
- for (k = 0; k < tmp.length; k++) {
- elem.c('parameter', tmp[k]).up();
- }
- }
- this.RtcpFbToJingle(i, elem, mline.fmt[j]); // XEP-0293 -- map a=rtcp-fb
-
- elem.up();
- }
- if (SDPUtil.find_line(this.media[i], 'a=crypto:', this.session)) {
- elem.c('encryption', {required: 1});
- var crypto = SDPUtil.find_lines(this.media[i], 'a=crypto:', this.session);
- crypto.forEach(function(line) {
- elem.c('crypto', SDPUtil.parse_crypto(line)).up();
- });
- elem.up(); // end of encryption
- }
-
- if (ssrc) {
- // new style mapping
- elem.c('source', { ssrc: ssrc, xmlns: 'urn:xmpp:jingle:apps:rtp:ssma:0' });
- // FIXME: group by ssrc and support multiple different ssrcs
- var ssrclines = SDPUtil.find_lines(this.media[i], 'a=ssrc:');
- ssrclines.forEach(function(line) {
- idx = line.indexOf(' ');
- var linessrc = line.substr(0, idx).substr(7);
- if (linessrc != ssrc) {
- elem.up();
- ssrc = linessrc;
- elem.c('source', { ssrc: ssrc, xmlns: 'urn:xmpp:jingle:apps:rtp:ssma:0' });
- }
- var kv = line.substr(idx + 1);
- elem.c('parameter');
- if (kv.indexOf(':') == -1) {
- elem.attrs({ name: kv });
- } else {
- elem.attrs({ name: kv.split(':', 2)[0] });
- elem.attrs({ value: kv.split(':', 2)[1] });
- }
- elem.up();
- });
- elem.up();
-
- // old proprietary mapping, to be removed at some point
- tmp = SDPUtil.parse_ssrc(this.media[i]);
- tmp.xmlns = 'http://estos.de/ns/ssrc';
- tmp.ssrc = ssrc;
- elem.c('ssrc', tmp).up(); // ssrc is part of description
-
- // XEP-0339 handle ssrc-group attributes
- var ssrc_group_lines = SDPUtil.find_lines(this.media[i], 'a=ssrc-group:');
- ssrc_group_lines.forEach(function(line) {
- idx = line.indexOf(' ');
- var semantics = line.substr(0, idx).substr(13);
- var ssrcs = line.substr(14 + semantics.length).split(' ');
- if (ssrcs.length != 0) {
- elem.c('ssrc-group', { semantics: semantics, xmlns: 'urn:xmpp:jingle:apps:rtp:ssma:0' });
- ssrcs.forEach(function(ssrc) {
- elem.c('source', { ssrc: ssrc })
- .up();
- });
- elem.up();
- }
- });
- }
-
- if (SDPUtil.find_line(this.media[i], 'a=rtcp-mux')) {
- elem.c('rtcp-mux').up();
- }
-
- // XEP-0293 -- map a=rtcp-fb:*
- this.RtcpFbToJingle(i, elem, '*');
-
- // XEP-0294
- if (SDPUtil.find_line(this.media[i], 'a=extmap:')) {
- lines = SDPUtil.find_lines(this.media[i], 'a=extmap:');
- for (j = 0; j < lines.length; j++) {
- tmp = SDPUtil.parse_extmap(lines[j]);
- elem.c('rtp-hdrext', { xmlns: 'urn:xmpp:jingle:apps:rtp:rtp-hdrext:0',
- uri: tmp.uri,
- id: tmp.value });
- if (tmp.hasOwnProperty('direction')) {
- switch (tmp.direction) {
- case 'sendonly':
- elem.attrs({senders: 'responder'});
- break;
- case 'recvonly':
- elem.attrs({senders: 'initiator'});
- break;
- case 'sendrecv':
- elem.attrs({senders: 'both'});
- break;
- case 'inactive':
- elem.attrs({senders: 'none'});
- break;
- }
- }
- // TODO: handle params
- elem.up();
- }
- }
- elem.up(); // end of description
- }
-
- // map ice-ufrag/pwd, dtls fingerprint, candidates
- this.TransportToJingle(i, elem);
-
- if (SDPUtil.find_line(this.media[i], 'a=sendrecv', this.session)) {
- elem.attrs({senders: 'both'});
- } else if (SDPUtil.find_line(this.media[i], 'a=sendonly', this.session)) {
- elem.attrs({senders: 'initiator'});
- } else if (SDPUtil.find_line(this.media[i], 'a=recvonly', this.session)) {
- elem.attrs({senders: 'responder'});
- } else if (SDPUtil.find_line(this.media[i], 'a=inactive', this.session)) {
- elem.attrs({senders: 'none'});
- }
- if (mline.port == '0') {
- // estos hack to reject an m-line
- elem.attrs({senders: 'rejected'});
- }
- elem.up(); // end of content
- }
- elem.up();
- return elem;
-};
-
-SDP.prototype.TransportToJingle = function (mediaindex, elem) {
- var i = mediaindex;
- var tmp;
- var self = this;
- elem.c('transport');
-
- // XEP-0343 DTLS/SCTP
- if (SDPUtil.find_line(this.media[mediaindex], 'a=sctpmap:').length)
- {
- var sctpmap = SDPUtil.find_line(
- this.media[i], 'a=sctpmap:', self.session);
- if (sctpmap)
- {
- var sctpAttrs = SDPUtil.parse_sctpmap(sctpmap);
- elem.c('sctpmap',
- {
- xmlns: 'urn:xmpp:jingle:transports:dtls-sctp:1',
- number: sctpAttrs[0], /* SCTP port */
- protocol: sctpAttrs[1], /* protocol */
- });
- // Optional stream count attribute
- if (sctpAttrs.length > 2)
- elem.attrs({ streams: sctpAttrs[2]});
- elem.up();
- }
- }
- // XEP-0320
- var fingerprints = SDPUtil.find_lines(this.media[mediaindex], 'a=fingerprint:', this.session);
- fingerprints.forEach(function(line) {
- tmp = SDPUtil.parse_fingerprint(line);
- tmp.xmlns = 'urn:xmpp:jingle:apps:dtls:0';
- elem.c('fingerprint').t(tmp.fingerprint);
- delete tmp.fingerprint;
- line = SDPUtil.find_line(self.media[mediaindex], 'a=setup:', self.session);
- if (line) {
- tmp.setup = line.substr(8);
- }
- elem.attrs(tmp);
- elem.up(); // end of fingerprint
- });
- tmp = SDPUtil.iceparams(this.media[mediaindex], this.session);
- if (tmp) {
- tmp.xmlns = 'urn:xmpp:jingle:transports:ice-udp:1';
- elem.attrs(tmp);
- // XEP-0176
- if (SDPUtil.find_line(this.media[mediaindex], 'a=candidate:', this.session)) { // add any a=candidate lines
- var lines = SDPUtil.find_lines(this.media[mediaindex], 'a=candidate:', this.session);
- lines.forEach(function (line) {
- elem.c('candidate', SDPUtil.candidateToJingle(line)).up();
- });
- }
- }
- elem.up(); // end of transport
-}
-
-SDP.prototype.RtcpFbToJingle = function (mediaindex, elem, payloadtype) { // XEP-0293
- var lines = SDPUtil.find_lines(this.media[mediaindex], 'a=rtcp-fb:' + payloadtype);
- lines.forEach(function (line) {
- var tmp = SDPUtil.parse_rtcpfb(line);
- if (tmp.type == 'trr-int') {
- elem.c('rtcp-fb-trr-int', {xmlns: 'urn:xmpp:jingle:apps:rtp:rtcp-fb:0', value: tmp.params[0]});
- elem.up();
- } else {
- elem.c('rtcp-fb', {xmlns: 'urn:xmpp:jingle:apps:rtp:rtcp-fb:0', type: tmp.type});
- if (tmp.params.length > 0) {
- elem.attrs({'subtype': tmp.params[0]});
- }
- elem.up();
- }
- });
-};
-
-SDP.prototype.RtcpFbFromJingle = function (elem, payloadtype) { // XEP-0293
- var media = '';
- var tmp = elem.find('>rtcp-fb-trr-int[xmlns="urn:xmpp:jingle:apps:rtp:rtcp-fb:0"]');
- if (tmp.length) {
- media += 'a=rtcp-fb:' + '*' + ' ' + 'trr-int' + ' ';
- if (tmp.attr('value')) {
- media += tmp.attr('value');
- } else {
- media += '0';
- }
- media += '\r\n';
- }
- tmp = elem.find('>rtcp-fb[xmlns="urn:xmpp:jingle:apps:rtp:rtcp-fb:0"]');
- tmp.each(function () {
- media += 'a=rtcp-fb:' + payloadtype + ' ' + $(this).attr('type');
- if ($(this).attr('subtype')) {
- media += ' ' + $(this).attr('subtype');
- }
- media += '\r\n';
- });
- return media;
-};
-
-// construct an SDP from a jingle stanza
-SDP.prototype.fromJingle = function (jingle) {
- var self = this;
- this.raw = 'v=0\r\n' +
- 'o=- ' + '1923518516' + ' 2 IN IP4 0.0.0.0\r\n' +// FIXME
- 's=-\r\n' +
- 't=0 0\r\n';
- // http://tools.ietf.org/html/draft-ietf-mmusic-sdp-bundle-negotiation-04#section-8
- if ($(jingle).find('>group[xmlns="urn:xmpp:jingle:apps:grouping:0"]').length) {
- $(jingle).find('>group[xmlns="urn:xmpp:jingle:apps:grouping:0"]').each(function (idx, group) {
- var contents = $(group).find('>content').map(function (idx, content) {
- return content.getAttribute('name');
- }).get();
- if (contents.length > 0) {
- self.raw += 'a=group:' + (group.getAttribute('semantics') || group.getAttribute('type')) + ' ' + contents.join(' ') + '\r\n';
- }
- });
- } else if ($(jingle).find('>group[xmlns="urn:ietf:rfc:5888"]').length) {
- // temporary namespace, not to be used. to be removed soon.
- $(jingle).find('>group[xmlns="urn:ietf:rfc:5888"]').each(function (idx, group) {
- var contents = $(group).find('>content').map(function (idx, content) {
- return content.getAttribute('name');
- }).get();
- if (group.getAttribute('type') !== null && contents.length > 0) {
- self.raw += 'a=group:' + group.getAttribute('type') + ' ' + contents.join(' ') + '\r\n';
- }
- });
- } else {
-        // for backward compatibility, to be removed soon
- // assume all contents are in the same bundle group, can be improved upon later
- var bundle = $(jingle).find('>content').filter(function (idx, content) {
- //elem.c('bundle', {xmlns:'http://estos.de/ns/bundle'});
- return $(content).find('>bundle').length > 0;
- }).map(function (idx, content) {
- return content.getAttribute('name');
- }).get();
- if (bundle.length) {
- this.raw += 'a=group:BUNDLE ' + bundle.join(' ') + '\r\n';
- }
- }
-
- this.session = this.raw;
- jingle.find('>content').each(function () {
- var m = self.jingle2media($(this));
- self.media.push(m);
- });
-
- // reconstruct msid-semantic -- apparently not necessary
- /*
- var msid = SDPUtil.parse_ssrc(this.raw);
- if (msid.hasOwnProperty('mslabel')) {
- this.session += "a=msid-semantic: WMS " + msid.mslabel + "\r\n";
- }
- */
-
- this.raw = this.session + this.media.join('');
-};
-
-// translate a jingle content element into an an SDP media part
-SDP.prototype.jingle2media = function (content) {
- var media = '',
- desc = content.find('description'),
- ssrc = desc.attr('ssrc'),
- self = this,
- tmp;
- var sctp = content.find(
- '>transport>sctpmap[xmlns="urn:xmpp:jingle:transports:dtls-sctp:1"]');
-
- tmp = { media: desc.attr('media') };
- tmp.port = '1';
- if (content.attr('senders') == 'rejected') {
- // estos hack to reject an m-line.
- tmp.port = '0';
- }
- if (content.find('>transport>fingerprint').length || desc.find('encryption').length) {
- if (sctp.length)
- tmp.proto = 'DTLS/SCTP';
- else
- tmp.proto = 'RTP/SAVPF';
- } else {
- tmp.proto = 'RTP/AVPF';
- }
- if (!sctp.length)
- {
- tmp.fmt = desc.find('payload-type').map(
- function () { return this.getAttribute('id'); }).get();
- media += SDPUtil.build_mline(tmp) + '\r\n';
- }
- else
- {
- media += 'm=application 1 DTLS/SCTP ' + sctp.attr('number') + '\r\n';
- media += 'a=sctpmap:' + sctp.attr('number') +
- ' ' + sctp.attr('protocol');
-
- var streamCount = sctp.attr('streams');
- if (streamCount)
- media += ' ' + streamCount + '\r\n';
- else
- media += '\r\n';
- }
-
- media += 'c=IN IP4 0.0.0.0\r\n';
- if (!sctp.length)
- media += 'a=rtcp:1 IN IP4 0.0.0.0\r\n';
- //tmp = content.find('>transport[xmlns="urn:xmpp:jingle:transports:ice-udp:1"]');
- tmp = content.find('>bundle>transport[xmlns="urn:xmpp:jingle:transports:ice-udp:1"]');
- //console.log('transports: '+content.find('>transport[xmlns="urn:xmpp:jingle:transports:ice-udp:1"]').length);
- //console.log('bundle.transports: '+content.find('>bundle>transport[xmlns="urn:xmpp:jingle:transports:ice-udp:1"]').length);
- //console.log("tmp fingerprint: "+tmp.find('>fingerprint').innerHTML);
- if (tmp.length) {
- if (tmp.attr('ufrag')) {
- media += SDPUtil.build_iceufrag(tmp.attr('ufrag')) + '\r\n';
- }
- if (tmp.attr('pwd')) {
- media += SDPUtil.build_icepwd(tmp.attr('pwd')) + '\r\n';
- }
- tmp.find('>fingerprint').each(function () {
- // FIXME: check namespace at some point
- media += 'a=fingerprint:' + this.getAttribute('hash');
- media += ' ' + $(this).text();
- media += '\r\n';
- //console.log("mline "+media);
- if (this.getAttribute('setup')) {
- media += 'a=setup:' + this.getAttribute('setup') + '\r\n';
- }
- });
- }
- switch (content.attr('senders')) {
- case 'initiator':
- media += 'a=sendonly\r\n';
- break;
- case 'responder':
- media += 'a=recvonly\r\n';
- break;
- case 'none':
- media += 'a=inactive\r\n';
- break;
- case 'both':
- media += 'a=sendrecv\r\n';
- break;
- }
- media += 'a=mid:' + content.attr('name') + '\r\n';
- /*if (content.attr('name') == 'video') {
- media += 'a=x-google-flag:conference' + '\r\n';
- }*/
-
- // <description><rtcp-mux/></description>
- // see http://code.google.com/p/libjingle/issues/detail?id=309 -- no spec though
- // and http://mail.jabber.org/pipermail/jingle/2011-December/001761.html
- if (desc.find('rtcp-mux').length) {
- media += 'a=rtcp-mux\r\n';
- }
-
- if (desc.find('encryption').length) {
- desc.find('encryption>crypto').each(function () {
- media += 'a=crypto:' + this.getAttribute('tag');
- media += ' ' + this.getAttribute('crypto-suite');
- media += ' ' + this.getAttribute('key-params');
- if (this.getAttribute('session-params')) {
- media += ' ' + this.getAttribute('session-params');
- }
- media += '\r\n';
- });
- }
- desc.find('payload-type').each(function () {
- media += SDPUtil.build_rtpmap(this) + '\r\n';
- if ($(this).find('>parameter').length) {
- media += 'a=fmtp:' + this.getAttribute('id') + ' ';
- media += $(this).find('parameter').map(function () { return (this.getAttribute('name') ? (this.getAttribute('name') + '=') : '') + this.getAttribute('value'); }).get().join('; ');
- media += '\r\n';
- }
- // xep-0293
- media += self.RtcpFbFromJingle($(this), this.getAttribute('id'));
- });
-
- // xep-0293
- media += self.RtcpFbFromJingle(desc, '*');
-
- // xep-0294
- tmp = desc.find('>rtp-hdrext[xmlns="urn:xmpp:jingle:apps:rtp:rtp-hdrext:0"]');
- tmp.each(function () {
- media += 'a=extmap:' + this.getAttribute('id') + ' ' + this.getAttribute('uri') + '\r\n';
- });
-
- content.find('>bundle>transport[xmlns="urn:xmpp:jingle:transports:ice-udp:1"]>candidate').each(function () {
- media += SDPUtil.candidateFromJingle(this);
- });
-
- // XEP-0339 handle ssrc-group attributes
- tmp = content.find('description>ssrc-group[xmlns="urn:xmpp:jingle:apps:rtp:ssma:0"]').each(function() {
- var semantics = this.getAttribute('semantics');
- var ssrcs = $(this).find('>source').map(function() {
- return this.getAttribute('ssrc');
- }).get();
-
- if (ssrcs.length != 0) {
- media += 'a=ssrc-group:' + semantics + ' ' + ssrcs.join(' ') + '\r\n';
- }
- });
-
- tmp = content.find('description>source[xmlns="urn:xmpp:jingle:apps:rtp:ssma:0"]');
- tmp.each(function () {
- var ssrc = this.getAttribute('ssrc');
- $(this).find('>parameter').each(function () {
- media += 'a=ssrc:' + ssrc + ' ' + this.getAttribute('name');
- if (this.getAttribute('value') && this.getAttribute('value').length)
- media += ':' + this.getAttribute('value');
- media += '\r\n';
- });
- });
-
- if (tmp.length === 0) {
- // fallback to proprietary mapping of a=ssrc lines
- tmp = content.find('description>ssrc[xmlns="http://estos.de/ns/ssrc"]');
- if (tmp.length) {
- media += 'a=ssrc:' + ssrc + ' cname:' + tmp.attr('cname') + '\r\n';
- media += 'a=ssrc:' + ssrc + ' msid:' + tmp.attr('msid') + '\r\n';
- media += 'a=ssrc:' + ssrc + ' mslabel:' + tmp.attr('mslabel') + '\r\n';
- media += 'a=ssrc:' + ssrc + ' label:' + tmp.attr('label') + '\r\n';
- }
- }
- return media;
-};
-
diff --git a/contrib/jitsimeetbridge/unjingle/strophe.jingle.sdp.util.js b/contrib/jitsimeetbridge/unjingle/strophe.jingle.sdp.util.js
deleted file mode 100644
index 042a123c..00000000
--- a/contrib/jitsimeetbridge/unjingle/strophe.jingle.sdp.util.js
+++ /dev/null
@@ -1,408 +0,0 @@
-/**
- * Contains utility classes used in SDP class.
- *
- */
-
-/**
- * Class holds a=ssrc lines and media type a=mid
- * @param ssrc synchronization source identifier number (a=ssrc lines from SDP)
- * @param type media type, eg. "audio" or "video" (a=mid from SDP)
- * @constructor
- */
-function ChannelSsrc(ssrc, type) {
- this.ssrc = ssrc;
- this.type = type;
- this.lines = [];
-}
-
-/**
- * Class holds a=ssrc-group: lines
- * @param semantics
- * @param ssrcs
- * @constructor
- */
-function ChannelSsrcGroup(semantics, ssrcs, line) {
- this.semantics = semantics;
- this.ssrcs = ssrcs;
-}
-
-/**
- * Helper class representing a media channel. It is a container for ChannelSsrc objects and holds the channel idx and media type.
- * @param channelNumber channel idx in SDP media array.
- * @param mediaType media type(a=mid)
- * @constructor
- */
-function MediaChannel(channelNumber, mediaType) {
- /**
- * SDP channel number
- * @type {*}
- */
- this.chNumber = channelNumber;
- /**
- * Channel media type(a=mid)
- * @type {*}
- */
- this.mediaType = mediaType;
- /**
-     * The map of ssrc numbers to ChannelSsrc objects.
- */
- this.ssrcs = {};
-
- /**
- * The array of ChannelSsrcGroup objects.
- * @type {Array}
- */
- this.ssrcGroups = [];
-}
-
-SDPUtil = {
- iceparams: function (mediadesc, sessiondesc) {
- var data = null;
- if (SDPUtil.find_line(mediadesc, 'a=ice-ufrag:', sessiondesc) &&
- SDPUtil.find_line(mediadesc, 'a=ice-pwd:', sessiondesc)) {
- data = {
- ufrag: SDPUtil.parse_iceufrag(SDPUtil.find_line(mediadesc, 'a=ice-ufrag:', sessiondesc)),
- pwd: SDPUtil.parse_icepwd(SDPUtil.find_line(mediadesc, 'a=ice-pwd:', sessiondesc))
- };
- }
- return data;
- },
- parse_iceufrag: function (line) {
- return line.substring(12);
- },
- build_iceufrag: function (frag) {
- return 'a=ice-ufrag:' + frag;
- },
- parse_icepwd: function (line) {
- return line.substring(10);
- },
- build_icepwd: function (pwd) {
- return 'a=ice-pwd:' + pwd;
- },
- parse_mid: function (line) {
- return line.substring(6);
- },
- parse_mline: function (line) {
- var parts = line.substring(2).split(' '),
- data = {};
- data.media = parts.shift();
- data.port = parts.shift();
- data.proto = parts.shift();
- if (parts[parts.length - 1] === '') { // trailing whitespace
- parts.pop();
- }
- data.fmt = parts;
- return data;
- },
- build_mline: function (mline) {
- return 'm=' + mline.media + ' ' + mline.port + ' ' + mline.proto + ' ' + mline.fmt.join(' ');
- },
- parse_rtpmap: function (line) {
- var parts = line.substring(9).split(' '),
- data = {};
- data.id = parts.shift();
- parts = parts[0].split('/');
- data.name = parts.shift();
- data.clockrate = parts.shift();
- data.channels = parts.length ? parts.shift() : '1';
- return data;
- },
- /**
- * Parses SDP line "a=sctpmap:..." and extracts SCTP port from it.
- * @param line eg. "a=sctpmap:5000 webrtc-datachannel"
- * @returns [SCTP port number, protocol, streams]
- */
- parse_sctpmap: function (line)
- {
- var parts = line.substring(10).split(' ');
- var sctpPort = parts[0];
- var protocol = parts[1];
- // Stream count is optional
- var streamCount = parts.length > 2 ? parts[2] : null;
-        return [sctpPort, protocol, streamCount];  // [port, protocol, streamCount]
- },
- build_rtpmap: function (el) {
- var line = 'a=rtpmap:' + el.getAttribute('id') + ' ' + el.getAttribute('name') + '/' + el.getAttribute('clockrate');
- if (el.getAttribute('channels') && el.getAttribute('channels') != '1') {
- line += '/' + el.getAttribute('channels');
- }
- return line;
- },
- parse_crypto: function (line) {
- var parts = line.substring(9).split(' '),
- data = {};
- data.tag = parts.shift();
- data['crypto-suite'] = parts.shift();
- data['key-params'] = parts.shift();
- if (parts.length) {
- data['session-params'] = parts.join(' ');
- }
- return data;
- },
- parse_fingerprint: function (line) { // RFC 4572
- var parts = line.substring(14).split(' '),
- data = {};
- data.hash = parts.shift();
- data.fingerprint = parts.shift();
- // TODO assert that fingerprint satisfies 2UHEX *(":" 2UHEX) ?
- return data;
- },
- parse_fmtp: function (line) {
- var parts = line.split(' '),
- i, key, value,
- data = [];
- parts.shift();
- parts = parts.join(' ').split(';');
- for (i = 0; i < parts.length; i++) {
- key = parts[i].split('=')[0];
- while (key.length && key[0] == ' ') {
- key = key.substring(1);
- }
- value = parts[i].split('=')[1];
- if (key && value) {
- data.push({name: key, value: value});
- } else if (key) {
- // rfc 4733 (DTMF) style stuff
- data.push({name: '', value: key});
- }
- }
- return data;
- },
- parse_icecandidate: function (line) {
- var candidate = {},
- elems = line.split(' ');
- candidate.foundation = elems[0].substring(12);
- candidate.component = elems[1];
- candidate.protocol = elems[2].toLowerCase();
- candidate.priority = elems[3];
- candidate.ip = elems[4];
- candidate.port = elems[5];
- // elems[6] => "typ"
- candidate.type = elems[7];
- candidate.generation = 0; // default value, may be overwritten below
- for (var i = 8; i < elems.length; i += 2) {
- switch (elems[i]) {
- case 'raddr':
- candidate['rel-addr'] = elems[i + 1];
- break;
- case 'rport':
- candidate['rel-port'] = elems[i + 1];
- break;
- case 'generation':
- candidate.generation = elems[i + 1];
- break;
- case 'tcptype':
- candidate.tcptype = elems[i + 1];
- break;
- default: // TODO
- console.log('parse_icecandidate not translating "' + elems[i] + '" = "' + elems[i + 1] + '"');
- }
- }
- candidate.network = '1';
- candidate.id = Math.random().toString(36).substr(2, 10); // not applicable to SDP -- FIXME: should be unique, not just random
- return candidate;
- },
- build_icecandidate: function (cand) {
- var line = ['a=candidate:' + cand.foundation, cand.component, cand.protocol, cand.priority, cand.ip, cand.port, 'typ', cand.type].join(' ');
- line += ' ';
- switch (cand.type) {
- case 'srflx':
- case 'prflx':
- case 'relay':
-            if (cand.hasOwnProperty('rel-addr') && cand.hasOwnProperty('rel-port')) {
- line += 'raddr';
- line += ' ';
- line += cand['rel-addr'];
- line += ' ';
- line += 'rport';
- line += ' ';
- line += cand['rel-port'];
- line += ' ';
- }
- break;
- }
-        if (cand.hasOwnProperty('tcptype')) {
- line += 'tcptype';
- line += ' ';
- line += cand.tcptype;
- line += ' ';
- }
- line += 'generation';
- line += ' ';
-        line += cand.hasOwnProperty('generation') ? cand.generation : '0';
- return line;
- },
- parse_ssrc: function (desc) {
- // proprietary mapping of a=ssrc lines
- // TODO: see "Jingle RTP Source Description" by Juberti and P. Thatcher on google docs
- // and parse according to that
- var lines = desc.split('\r\n'),
- data = {};
- for (var i = 0; i < lines.length; i++) {
- if (lines[i].substring(0, 7) == 'a=ssrc:') {
- var idx = lines[i].indexOf(' ');
- data[lines[i].substr(idx + 1).split(':', 2)[0]] = lines[i].substr(idx + 1).split(':', 2)[1];
- }
- }
- return data;
- },
- parse_rtcpfb: function (line) {
- var parts = line.substr(10).split(' ');
- var data = {};
- data.pt = parts.shift();
- data.type = parts.shift();
- data.params = parts;
- return data;
- },
- parse_extmap: function (line) {
- var parts = line.substr(9).split(' ');
- var data = {};
- data.value = parts.shift();
- if (data.value.indexOf('/') != -1) {
- data.direction = data.value.substr(data.value.indexOf('/') + 1);
- data.value = data.value.substr(0, data.value.indexOf('/'));
- } else {
- data.direction = 'both';
- }
- data.uri = parts.shift();
- data.params = parts;
- return data;
- },
- find_line: function (haystack, needle, sessionpart) {
- var lines = haystack.split('\r\n');
- for (var i = 0; i < lines.length; i++) {
- if (lines[i].substring(0, needle.length) == needle) {
- return lines[i];
- }
- }
- if (!sessionpart) {
- return false;
- }
- // search session part
- lines = sessionpart.split('\r\n');
- for (var j = 0; j < lines.length; j++) {
- if (lines[j].substring(0, needle.length) == needle) {
- return lines[j];
- }
- }
- return false;
- },
- find_lines: function (haystack, needle, sessionpart) {
- var lines = haystack.split('\r\n'),
- needles = [];
- for (var i = 0; i < lines.length; i++) {
- if (lines[i].substring(0, needle.length) == needle)
- needles.push(lines[i]);
- }
- if (needles.length || !sessionpart) {
- return needles;
- }
- // search session part
- lines = sessionpart.split('\r\n');
- for (var j = 0; j < lines.length; j++) {
- if (lines[j].substring(0, needle.length) == needle) {
- needles.push(lines[j]);
- }
- }
- return needles;
- },
- candidateToJingle: function (line) {
- // a=candidate:2979166662 1 udp 2113937151 192.168.2.100 57698 typ host generation 0
- // <candidate component=... foundation=... generation=... id=... ip=... network=... port=... priority=... protocol=... type=.../>
- if (line.indexOf('candidate:') === 0) {
- line = 'a=' + line;
- } else if (line.substring(0, 12) != 'a=candidate:') {
- console.log('parseCandidate called with a line that is not a candidate line');
- console.log(line);
- return null;
- }
- if (line.substring(line.length - 2) == '\r\n') // chomp it
- line = line.substring(0, line.length - 2);
- var candidate = {},
- elems = line.split(' '),
- i;
- if (elems[6] != 'typ') {
- console.log('did not find typ in the right place');
- console.log(line);
- return null;
- }
- candidate.foundation = elems[0].substring(12);
- candidate.component = elems[1];
- candidate.protocol = elems[2].toLowerCase();
- candidate.priority = elems[3];
- candidate.ip = elems[4];
- candidate.port = elems[5];
- // elems[6] => "typ"
- candidate.type = elems[7];
-
- candidate.generation = '0'; // default, may be overwritten below
- for (i = 8; i < elems.length; i += 2) {
- switch (elems[i]) {
- case 'raddr':
- candidate['rel-addr'] = elems[i + 1];
- break;
- case 'rport':
- candidate['rel-port'] = elems[i + 1];
- break;
- case 'generation':
- candidate.generation = elems[i + 1];
- break;
- case 'tcptype':
- candidate.tcptype = elems[i + 1];
- break;
- default: // TODO
- console.log('not translating "' + elems[i] + '" = "' + elems[i + 1] + '"');
- }
- }
- candidate.network = '1';
- candidate.id = Math.random().toString(36).substr(2, 10); // not applicable to SDP -- FIXME: should be unique, not just random
- return candidate;
- },
- candidateFromJingle: function (cand) {
- var line = 'a=candidate:';
- line += cand.getAttribute('foundation');
- line += ' ';
- line += cand.getAttribute('component');
- line += ' ';
- line += cand.getAttribute('protocol'); //.toUpperCase(); // chrome M23 doesn't like this
- line += ' ';
- line += cand.getAttribute('priority');
- line += ' ';
- line += cand.getAttribute('ip');
- line += ' ';
- line += cand.getAttribute('port');
- line += ' ';
- line += 'typ';
- line += ' ' + cand.getAttribute('type');
- line += ' ';
- switch (cand.getAttribute('type')) {
- case 'srflx':
- case 'prflx':
- case 'relay':
- if (cand.getAttribute('rel-addr') && cand.getAttribute('rel-port')) {
- line += 'raddr';
- line += ' ';
- line += cand.getAttribute('rel-addr');
- line += ' ';
- line += 'rport';
- line += ' ';
- line += cand.getAttribute('rel-port');
- line += ' ';
- }
- break;
- }
- if (cand.getAttribute('protocol').toLowerCase() == 'tcp') {
- line += 'tcptype';
- line += ' ';
- line += cand.getAttribute('tcptype');
- line += ' ';
- }
- line += 'generation';
- line += ' ';
- line += cand.getAttribute('generation') || '0';
- return line + '\r\n';
- }
-};
-
-exports.SDPUtil = SDPUtil;
-
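For reference, candidateToJingle above splits an a=candidate: line on spaces and maps the positional SDP fields onto Jingle-style attributes. A minimal sketch of the expected mapping, assuming the SDPUtil object exported above is in scope:

    // Illustrative input, taken from the comment inside candidateToJingle above.
    var line = 'a=candidate:2979166662 1 udp 2113937151 192.168.2.100 57698 typ host generation 0';
    var cand = SDPUtil.candidateToJingle(line);
    // Expected shape (id is random):
    // { foundation: '2979166662', component: '1', protocol: 'udp',
    //   priority: '2113937151', ip: '192.168.2.100', port: '57698',
    //   type: 'host', generation: '0', network: '1', id: '...' }
    console.log(cand);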
diff --git a/contrib/jitsimeetbridge/unjingle/strophe/XMLHttpRequest.js b/contrib/jitsimeetbridge/unjingle/strophe/XMLHttpRequest.js
deleted file mode 100644
index 9c45c2df..00000000
--- a/contrib/jitsimeetbridge/unjingle/strophe/XMLHttpRequest.js
+++ /dev/null
@@ -1,254 +0,0 @@
-/**
- * Wrapper for built-in http.js to emulate the browser XMLHttpRequest object.
- *
- * This can be used with JS designed for browsers to improve reuse of code and
- * allow the use of existing libraries.
- *
- * Usage: include("XMLHttpRequest.js") and use XMLHttpRequest per W3C specs.
- *
- * @todo SSL Support
- * @author Dan DeFelippi <dan@driverdan.com>
- * @license MIT
- */
-
-var Url = require("url")
- ,sys = require("util");
-
-exports.XMLHttpRequest = function() {
- /**
- * Private variables
- */
- var self = this;
- var http = require('http');
- var https = require('https');
-
- // Holds http.js objects
- var client;
- var request;
- var response;
-
- // Request settings
- var settings = {};
-
- // Set some default headers
- var defaultHeaders = {
- "User-Agent": "node.js",
- "Accept": "*/*",
- };
-
- var headers = defaultHeaders;
-
- /**
- * Constants
- */
- this.UNSENT = 0;
- this.OPENED = 1;
- this.HEADERS_RECEIVED = 2;
- this.LOADING = 3;
- this.DONE = 4;
-
- /**
- * Public vars
- */
- // Current state
- this.readyState = this.UNSENT;
-
- // default ready state change handler in case one is not set or is set late
- this.onreadystatechange = function() {};
-
- // Result & response
- this.responseText = "";
- this.responseXML = "";
- this.status = null;
- this.statusText = null;
-
- /**
- * Open the connection. Currently supports local server requests.
- *
- * @param string method Connection method (eg GET, POST)
- * @param string url URL for the connection.
- * @param boolean async Asynchronous connection. Default is true.
- * @param string user Username for basic authentication (optional)
- * @param string password Password for basic authentication (optional)
- */
- this.open = function(method, url, async, user, password) {
- settings = {
- "method": method,
- "url": url,
- "async": async || null,
- "user": user || null,
- "password": password || null
- };
-
- this.abort();
-
- setState(this.OPENED);
- };
-
- /**
- * Sets a header for the request.
- *
- * @param string header Header name
- * @param string value Header value
- */
- this.setRequestHeader = function(header, value) {
- headers[header] = value;
- };
-
- /**
- * Gets a header from the server response.
- *
- * @param string header Name of header to get.
- * @return string Text of the header or null if it doesn't exist.
- */
- this.getResponseHeader = function(header) {
- if (this.readyState > this.OPENED && response.headers[header]) {
- return header + ": " + response.headers[header];
- }
-
- return null;
- };
-
- /**
- * Gets all the response headers.
- *
- * @return string
- */
- this.getAllResponseHeaders = function() {
- if (this.readyState < this.HEADERS_RECEIVED) {
- throw "INVALID_STATE_ERR: Headers have not been received.";
- }
- var result = "";
-
- for (var i in response.headers) {
- result += i + ": " + response.headers[i] + "\r\n";
- }
- return result.substr(0, result.length - 2);
- };
-
- /**
- * Sends the request to the server.
- *
- * @param string data Optional data to send as request body.
- */
- this.send = function(data) {
- if (this.readyState != this.OPENED) {
- throw "INVALID_STATE_ERR: connection must be opened before send() is called";
- }
-
- var ssl = false;
- var url = Url.parse(settings.url);
-
- // Determine the server
- switch (url.protocol) {
- case 'https:':
- ssl = true;
- // SSL & non-SSL both need host, no break here.
- case 'http:':
- var host = url.hostname;
- break;
-
- case undefined:
- case '':
- var host = "localhost";
- break;
-
- default:
- throw "Protocol not supported.";
- }
-
- // Default to port 80. If accessing localhost on another port be sure
- // to use http://localhost:port/path
- var port = url.port || (ssl ? 443 : 80);
- // Add query string if one is used
- var uri = url.pathname + (url.search ? url.search : '');
-
- // Set the Host header or the server may reject the request
- this.setRequestHeader("Host", host);
-
- // Set content length header
- if (settings.method == "GET" || settings.method == "HEAD") {
- data = null;
- } else if (data) {
- this.setRequestHeader("Content-Length", Buffer.byteLength(data));
-
- if (!headers["Content-Type"]) {
- this.setRequestHeader("Content-Type", "text/plain;charset=UTF-8");
- }
- }
-
- // Use the proper protocol
- var doRequest = ssl ? https.request : http.request;
-
- var options = {
- host: host,
- port: port,
- path: uri,
- method: settings.method,
- headers: headers,
- agent: false
- };
-
- var req = doRequest(options, function(res) {
- response = res;
- response.setEncoding("utf8");
-
- setState(self.HEADERS_RECEIVED);
- self.status = response.statusCode;
-
- response.on('data', function(chunk) {
- // Make sure there's some data
- if (chunk) {
- self.responseText += chunk;
- }
- setState(self.LOADING);
- });
-
- response.on('end', function() {
- setState(self.DONE);
- });
-
-		response.on('error', function(error) {
-			self.handleError(error);
-		});
- }).on('error', function(error) {
- self.handleError(error);
- });
-
- req.setHeader("Connection", "Close");
-
- // Node 0.4 and later won't accept empty data. Make sure it's needed.
- if (data) {
- req.write(data);
- }
-
- req.end();
- };
-
- this.handleError = function(error) {
- this.status = 503;
- this.statusText = error;
- this.responseText = error.stack;
- setState(this.DONE);
- };
-
- /**
- * Aborts a request.
- */
- this.abort = function() {
- headers = defaultHeaders;
- this.readyState = this.UNSENT;
- this.responseText = "";
- this.responseXML = "";
- };
-
- /**
- * Changes readyState and calls onreadystatechange.
- *
- * @param int state New state
- */
- var setState = function(state) {
- self.readyState = state;
- self.onreadystatechange();
- }
-};
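The wrapper above follows the W3C XMLHttpRequest shape, as its header comment notes, so browser-style code can drive it unchanged. A rough usage sketch, using the relative path that strophe.js below requires it from and a placeholder URL:

    var XMLHttpRequest = require('./XMLHttpRequest.js').XMLHttpRequest;

    var xhr = new XMLHttpRequest();
    xhr.onreadystatechange = function () {
        if (xhr.readyState === xhr.DONE) {
            console.log(xhr.status, xhr.responseText);
        }
    };
    xhr.open('GET', 'http://localhost:8080/', true);  // URL is a placeholder
    xhr.send(null);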
diff --git a/contrib/jitsimeetbridge/unjingle/strophe/base64.js b/contrib/jitsimeetbridge/unjingle/strophe/base64.js
deleted file mode 100644
index 418caac0..00000000
--- a/contrib/jitsimeetbridge/unjingle/strophe/base64.js
+++ /dev/null
@@ -1,83 +0,0 @@
-// This code was written by Tyler Akins and has been placed in the
-// public domain. It would be nice if you left this header intact.
-// Base64 code from Tyler Akins -- http://rumkin.com
-
-var Base64 = (function () {
- var keyStr = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=";
-
- var obj = {
- /**
- * Encodes a string in base64
- * @param {String} input The string to encode in base64.
- */
- encode: function (input) {
- var output = "";
- var chr1, chr2, chr3;
- var enc1, enc2, enc3, enc4;
- var i = 0;
-
- do {
- chr1 = input.charCodeAt(i++);
- chr2 = input.charCodeAt(i++);
- chr3 = input.charCodeAt(i++);
-
- enc1 = chr1 >> 2;
- enc2 = ((chr1 & 3) << 4) | (chr2 >> 4);
- enc3 = ((chr2 & 15) << 2) | (chr3 >> 6);
- enc4 = chr3 & 63;
-
- if (isNaN(chr2)) {
- enc3 = enc4 = 64;
- } else if (isNaN(chr3)) {
- enc4 = 64;
- }
-
- output = output + keyStr.charAt(enc1) + keyStr.charAt(enc2) +
- keyStr.charAt(enc3) + keyStr.charAt(enc4);
- } while (i < input.length);
-
- return output;
- },
-
- /**
- * Decodes a base64 string.
- * @param {String} input The string to decode.
- */
- decode: function (input) {
- var output = "";
- var chr1, chr2, chr3;
- var enc1, enc2, enc3, enc4;
- var i = 0;
-
- // remove all characters that are not A-Z, a-z, 0-9, +, /, or =
- input = input.replace(/[^A-Za-z0-9\+\/\=]/g, '');
-
- do {
- enc1 = keyStr.indexOf(input.charAt(i++));
- enc2 = keyStr.indexOf(input.charAt(i++));
- enc3 = keyStr.indexOf(input.charAt(i++));
- enc4 = keyStr.indexOf(input.charAt(i++));
-
- chr1 = (enc1 << 2) | (enc2 >> 4);
- chr2 = ((enc2 & 15) << 4) | (enc3 >> 2);
- chr3 = ((enc3 & 3) << 6) | enc4;
-
- output = output + String.fromCharCode(chr1);
-
- if (enc3 != 64) {
- output = output + String.fromCharCode(chr2);
- }
- if (enc4 != 64) {
- output = output + String.fromCharCode(chr3);
- }
- } while (i < input.length);
-
- return output;
- }
- };
-
- return obj;
-})();
-
-// Nodify
-exports.Base64 = Base64;
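The Base64 object above implements standard base64 with '=' padding, so a round trip is straightforward. A small sketch, using the relative path that strophe.js below requires it from:

    var Base64 = require('./base64.js').Base64;

    var encoded = Base64.encode('hello');   // 'aGVsbG8='
    var decoded = Base64.decode(encoded);   // 'hello'
    console.log(encoded, decoded);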
diff --git a/contrib/jitsimeetbridge/unjingle/strophe/md5.js b/contrib/jitsimeetbridge/unjingle/strophe/md5.js
deleted file mode 100644
index 5334325e..00000000
--- a/contrib/jitsimeetbridge/unjingle/strophe/md5.js
+++ /dev/null
@@ -1,279 +0,0 @@
-/*
- * A JavaScript implementation of the RSA Data Security, Inc. MD5 Message
- * Digest Algorithm, as defined in RFC 1321.
- * Version 2.1 Copyright (C) Paul Johnston 1999 - 2002.
- * Other contributors: Greg Holt, Andrew Kepert, Ydnar, Lostinet
- * Distributed under the BSD License
- * See http://pajhome.org.uk/crypt/md5 for more info.
- */
-
-var MD5 = (function () {
- /*
- * Configurable variables. You may need to tweak these to be compatible with
- * the server-side, but the defaults work in most cases.
- */
- var hexcase = 0; /* hex output format. 0 - lowercase; 1 - uppercase */
- var b64pad = ""; /* base-64 pad character. "=" for strict RFC compliance */
- var chrsz = 8; /* bits per input character. 8 - ASCII; 16 - Unicode */
-
- /*
- * Add integers, wrapping at 2^32. This uses 16-bit operations internally
- * to work around bugs in some JS interpreters.
- */
- var safe_add = function (x, y) {
- var lsw = (x & 0xFFFF) + (y & 0xFFFF);
- var msw = (x >> 16) + (y >> 16) + (lsw >> 16);
- return (msw << 16) | (lsw & 0xFFFF);
- };
-
- /*
- * Bitwise rotate a 32-bit number to the left.
- */
- var bit_rol = function (num, cnt) {
- return (num << cnt) | (num >>> (32 - cnt));
- };
-
- /*
- * Convert a string to an array of little-endian words
- * If chrsz is ASCII, characters >255 have their hi-byte silently ignored.
- */
- var str2binl = function (str) {
- var bin = [];
- var mask = (1 << chrsz) - 1;
- for(var i = 0; i < str.length * chrsz; i += chrsz)
- {
- bin[i>>5] |= (str.charCodeAt(i / chrsz) & mask) << (i%32);
- }
- return bin;
- };
-
- /*
- * Convert an array of little-endian words to a string
- */
- var binl2str = function (bin) {
- var str = "";
- var mask = (1 << chrsz) - 1;
- for(var i = 0; i < bin.length * 32; i += chrsz)
- {
- str += String.fromCharCode((bin[i>>5] >>> (i % 32)) & mask);
- }
- return str;
- };
-
- /*
- * Convert an array of little-endian words to a hex string.
- */
- var binl2hex = function (binarray) {
- var hex_tab = hexcase ? "0123456789ABCDEF" : "0123456789abcdef";
- var str = "";
- for(var i = 0; i < binarray.length * 4; i++)
- {
- str += hex_tab.charAt((binarray[i>>2] >> ((i%4)*8+4)) & 0xF) +
- hex_tab.charAt((binarray[i>>2] >> ((i%4)*8 )) & 0xF);
- }
- return str;
- };
-
- /*
- * Convert an array of little-endian words to a base-64 string
- */
- var binl2b64 = function (binarray) {
- var tab = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
- var str = "";
- var triplet, j;
- for(var i = 0; i < binarray.length * 4; i += 3)
- {
- triplet = (((binarray[i >> 2] >> 8 * ( i %4)) & 0xFF) << 16) |
- (((binarray[i+1 >> 2] >> 8 * ((i+1)%4)) & 0xFF) << 8 ) |
- ((binarray[i+2 >> 2] >> 8 * ((i+2)%4)) & 0xFF);
- for(j = 0; j < 4; j++)
- {
- if(i * 8 + j * 6 > binarray.length * 32) { str += b64pad; }
- else { str += tab.charAt((triplet >> 6*(3-j)) & 0x3F); }
- }
- }
- return str;
- };
-
- /*
- * These functions implement the four basic operations the algorithm uses.
- */
- var md5_cmn = function (q, a, b, x, s, t) {
- return safe_add(bit_rol(safe_add(safe_add(a, q),safe_add(x, t)), s),b);
- };
-
- var md5_ff = function (a, b, c, d, x, s, t) {
- return md5_cmn((b & c) | ((~b) & d), a, b, x, s, t);
- };
-
- var md5_gg = function (a, b, c, d, x, s, t) {
- return md5_cmn((b & d) | (c & (~d)), a, b, x, s, t);
- };
-
- var md5_hh = function (a, b, c, d, x, s, t) {
- return md5_cmn(b ^ c ^ d, a, b, x, s, t);
- };
-
- var md5_ii = function (a, b, c, d, x, s, t) {
- return md5_cmn(c ^ (b | (~d)), a, b, x, s, t);
- };
-
- /*
- * Calculate the MD5 of an array of little-endian words, and a bit length
- */
- var core_md5 = function (x, len) {
- /* append padding */
- x[len >> 5] |= 0x80 << ((len) % 32);
- x[(((len + 64) >>> 9) << 4) + 14] = len;
-
- var a = 1732584193;
- var b = -271733879;
- var c = -1732584194;
- var d = 271733878;
-
- var olda, oldb, oldc, oldd;
- for (var i = 0; i < x.length; i += 16)
- {
- olda = a;
- oldb = b;
- oldc = c;
- oldd = d;
-
- a = md5_ff(a, b, c, d, x[i+ 0], 7 , -680876936);
- d = md5_ff(d, a, b, c, x[i+ 1], 12, -389564586);
- c = md5_ff(c, d, a, b, x[i+ 2], 17, 606105819);
- b = md5_ff(b, c, d, a, x[i+ 3], 22, -1044525330);
- a = md5_ff(a, b, c, d, x[i+ 4], 7 , -176418897);
- d = md5_ff(d, a, b, c, x[i+ 5], 12, 1200080426);
- c = md5_ff(c, d, a, b, x[i+ 6], 17, -1473231341);
- b = md5_ff(b, c, d, a, x[i+ 7], 22, -45705983);
- a = md5_ff(a, b, c, d, x[i+ 8], 7 , 1770035416);
- d = md5_ff(d, a, b, c, x[i+ 9], 12, -1958414417);
- c = md5_ff(c, d, a, b, x[i+10], 17, -42063);
- b = md5_ff(b, c, d, a, x[i+11], 22, -1990404162);
- a = md5_ff(a, b, c, d, x[i+12], 7 , 1804603682);
- d = md5_ff(d, a, b, c, x[i+13], 12, -40341101);
- c = md5_ff(c, d, a, b, x[i+14], 17, -1502002290);
- b = md5_ff(b, c, d, a, x[i+15], 22, 1236535329);
-
- a = md5_gg(a, b, c, d, x[i+ 1], 5 , -165796510);
- d = md5_gg(d, a, b, c, x[i+ 6], 9 , -1069501632);
- c = md5_gg(c, d, a, b, x[i+11], 14, 643717713);
- b = md5_gg(b, c, d, a, x[i+ 0], 20, -373897302);
- a = md5_gg(a, b, c, d, x[i+ 5], 5 , -701558691);
- d = md5_gg(d, a, b, c, x[i+10], 9 , 38016083);
- c = md5_gg(c, d, a, b, x[i+15], 14, -660478335);
- b = md5_gg(b, c, d, a, x[i+ 4], 20, -405537848);
- a = md5_gg(a, b, c, d, x[i+ 9], 5 , 568446438);
- d = md5_gg(d, a, b, c, x[i+14], 9 , -1019803690);
- c = md5_gg(c, d, a, b, x[i+ 3], 14, -187363961);
- b = md5_gg(b, c, d, a, x[i+ 8], 20, 1163531501);
- a = md5_gg(a, b, c, d, x[i+13], 5 , -1444681467);
- d = md5_gg(d, a, b, c, x[i+ 2], 9 , -51403784);
- c = md5_gg(c, d, a, b, x[i+ 7], 14, 1735328473);
- b = md5_gg(b, c, d, a, x[i+12], 20, -1926607734);
-
- a = md5_hh(a, b, c, d, x[i+ 5], 4 , -378558);
- d = md5_hh(d, a, b, c, x[i+ 8], 11, -2022574463);
- c = md5_hh(c, d, a, b, x[i+11], 16, 1839030562);
- b = md5_hh(b, c, d, a, x[i+14], 23, -35309556);
- a = md5_hh(a, b, c, d, x[i+ 1], 4 , -1530992060);
- d = md5_hh(d, a, b, c, x[i+ 4], 11, 1272893353);
- c = md5_hh(c, d, a, b, x[i+ 7], 16, -155497632);
- b = md5_hh(b, c, d, a, x[i+10], 23, -1094730640);
- a = md5_hh(a, b, c, d, x[i+13], 4 , 681279174);
- d = md5_hh(d, a, b, c, x[i+ 0], 11, -358537222);
- c = md5_hh(c, d, a, b, x[i+ 3], 16, -722521979);
- b = md5_hh(b, c, d, a, x[i+ 6], 23, 76029189);
- a = md5_hh(a, b, c, d, x[i+ 9], 4 , -640364487);
- d = md5_hh(d, a, b, c, x[i+12], 11, -421815835);
- c = md5_hh(c, d, a, b, x[i+15], 16, 530742520);
- b = md5_hh(b, c, d, a, x[i+ 2], 23, -995338651);
-
- a = md5_ii(a, b, c, d, x[i+ 0], 6 , -198630844);
- d = md5_ii(d, a, b, c, x[i+ 7], 10, 1126891415);
- c = md5_ii(c, d, a, b, x[i+14], 15, -1416354905);
- b = md5_ii(b, c, d, a, x[i+ 5], 21, -57434055);
- a = md5_ii(a, b, c, d, x[i+12], 6 , 1700485571);
- d = md5_ii(d, a, b, c, x[i+ 3], 10, -1894986606);
- c = md5_ii(c, d, a, b, x[i+10], 15, -1051523);
- b = md5_ii(b, c, d, a, x[i+ 1], 21, -2054922799);
- a = md5_ii(a, b, c, d, x[i+ 8], 6 , 1873313359);
- d = md5_ii(d, a, b, c, x[i+15], 10, -30611744);
- c = md5_ii(c, d, a, b, x[i+ 6], 15, -1560198380);
- b = md5_ii(b, c, d, a, x[i+13], 21, 1309151649);
- a = md5_ii(a, b, c, d, x[i+ 4], 6 , -145523070);
- d = md5_ii(d, a, b, c, x[i+11], 10, -1120210379);
- c = md5_ii(c, d, a, b, x[i+ 2], 15, 718787259);
- b = md5_ii(b, c, d, a, x[i+ 9], 21, -343485551);
-
- a = safe_add(a, olda);
- b = safe_add(b, oldb);
- c = safe_add(c, oldc);
- d = safe_add(d, oldd);
- }
- return [a, b, c, d];
- };
-
-
- /*
- * Calculate the HMAC-MD5, of a key and some data
- */
- var core_hmac_md5 = function (key, data) {
- var bkey = str2binl(key);
- if(bkey.length > 16) { bkey = core_md5(bkey, key.length * chrsz); }
-
- var ipad = new Array(16), opad = new Array(16);
- for(var i = 0; i < 16; i++)
- {
- ipad[i] = bkey[i] ^ 0x36363636;
- opad[i] = bkey[i] ^ 0x5C5C5C5C;
- }
-
- var hash = core_md5(ipad.concat(str2binl(data)), 512 + data.length * chrsz);
- return core_md5(opad.concat(hash), 512 + 128);
- };
-
- var obj = {
- /*
- * These are the functions you'll usually want to call.
- * They take string arguments and return either hex or base-64 encoded
- * strings.
- */
- hexdigest: function (s) {
- return binl2hex(core_md5(str2binl(s), s.length * chrsz));
- },
-
- b64digest: function (s) {
- return binl2b64(core_md5(str2binl(s), s.length * chrsz));
- },
-
- hash: function (s) {
- return binl2str(core_md5(str2binl(s), s.length * chrsz));
- },
-
- hmac_hexdigest: function (key, data) {
- return binl2hex(core_hmac_md5(key, data));
- },
-
- hmac_b64digest: function (key, data) {
- return binl2b64(core_hmac_md5(key, data));
- },
-
- hmac_hash: function (key, data) {
- return binl2str(core_hmac_md5(key, data));
- },
-
- /*
- * Perform a simple self-test to see if the VM is working
- */
- test: function () {
- return MD5.hexdigest("abc") === "900150983cd24fb0d6963f7d28e17f72";
- }
- };
-
- return obj;
-})();
-
-// Nodify
-exports.MD5 = MD5;
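The MD5 object above ships its own self-test (MD5.test), which checks the RFC 1321 value for 'abc'. A short sketch, again using the relative path that strophe.js below requires it from:

    var MD5 = require('./md5.js').MD5;

    console.log(MD5.hexdigest('abc'));  // '900150983cd24fb0d6963f7d28e17f72'
    console.log(MD5.test());            // true when the implementation works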
diff --git a/contrib/jitsimeetbridge/unjingle/strophe/strophe.js b/contrib/jitsimeetbridge/unjingle/strophe/strophe.js
deleted file mode 100644
index 06d426cd..00000000
--- a/contrib/jitsimeetbridge/unjingle/strophe/strophe.js
+++ /dev/null
@@ -1,3256 +0,0 @@
-/*
- This program is distributed under the terms of the MIT license.
- Please see the LICENSE file for details.
-
- Copyright 2006-2008, OGG, LLC
-*/
-
-/* jslint configuration: */
-/*global document, window, setTimeout, clearTimeout, console,
- XMLHttpRequest, ActiveXObject,
- Base64, MD5,
- Strophe, $build, $msg, $iq, $pres */
-
-/** File: strophe.js
- * A JavaScript library for XMPP BOSH.
- *
- * This is the JavaScript version of the Strophe library. Since JavaScript
- * has no facilities for persistent TCP connections, this library uses
- * Bidirectional-streams Over Synchronous HTTP (BOSH) to emulate
- * a persistent, stateful, two-way connection to an XMPP server. More
- * information on BOSH can be found in XEP 124.
- */
-
-/** PrivateFunction: Function.prototype.bind
- * Bind a function to an instance.
- *
- * This Function object extension method creates a bound method similar
- * to those in Python. This means that the 'this' object will point
- * to the instance you want. See
- * <a href='https://developer.mozilla.org/en/JavaScript/Reference/Global_Objects/Function/bind'>MDC's bind() documentation</a> and
- * <a href='http://benjamin.smedbergs.us/blog/2007-01-03/bound-functions-and-function-imports-in-javascript/'>Bound Functions and Function Imports in JavaScript</a>
- * for a complete explanation.
- *
- * This extension already exists in some browsers (namely, Firefox 3), but
- * we provide it to support those that don't.
- *
- * Parameters:
- * (Object) obj - The object that will become 'this' in the bound function.
- *  (Object) argN - An optional argument that will be prepended to the
- * arguments given for the function call
- *
- * Returns:
- * The bound function.
- */
-
-/* Make it work on node.js: Nodify
- *
- * Steps:
- * 1. Create the global objects: window, document, Base64, MD5 and XMLHttpRequest
- * 2. Use the node-XMLHttpRequest module.
- * 3. Use jsdom for the document object - since it supports DOM functions.
- * 4. Replace all calls to childNodes with _childNodes (since the former doesn't
- * seem to work on jsdom).
- * 5. While getting the response from XMLHttpRequest, manually convert the text
- * data to XML.
- * 6. All calls to nodeName should be replaced by nodeName.toLowerCase() since jsdom
- * seems to always convert node names to upper case.
- *
- */
-var XMLHttpRequest = require('./XMLHttpRequest.js').XMLHttpRequest;
-var Base64 = require('./base64.js').Base64;
-var MD5 = require('./md5.js').MD5;
-var jsdom = require("jsdom").jsdom;
-
-document = jsdom("<html><head></head><body></body></html>"),
-
-window = {
- XMLHttpRequest: XMLHttpRequest,
- Base64: Base64,
- MD5: MD5
-};
-
-exports.Strophe = window;
-
-if (!Function.prototype.bind) {
- Function.prototype.bind = function (obj /*, arg1, arg2, ... */)
- {
- var func = this;
- var _slice = Array.prototype.slice;
- var _concat = Array.prototype.concat;
- var _args = _slice.call(arguments, 1);
-
- return function () {
- return func.apply(obj ? obj : this,
- _concat.call(_args,
- _slice.call(arguments, 0)));
- };
- };
-}
-
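The bind() shim above fixes 'this' and prepends any extra arguments, as its comment describes. A tiny sketch of that behaviour:

    function greet(greeting, name) {
        return greeting + ', ' + name + ' (' + this.lang + ')';
    }
    var bound = greet.bind({ lang: 'en' }, 'hello');
    console.log(bound('world'));  // 'hello, world (en)'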
-/** PrivateFunction: Array.prototype.indexOf
- * Return the index of an object in an array.
- *
- * This function is not supplied by some JavaScript implementations, so
- * we provide it if it is missing. This code is from:
- * http://developer.mozilla.org/En/Core_JavaScript_1.5_Reference:Objects:Array:indexOf
- *
- * Parameters:
- * (Object) elt - The object to look for.
- * (Integer) from - The index from which to start looking. (optional).
- *
- * Returns:
- * The index of elt in the array or -1 if not found.
- */
-if (!Array.prototype.indexOf)
-{
- Array.prototype.indexOf = function(elt /*, from*/)
- {
- var len = this.length;
-
- var from = Number(arguments[1]) || 0;
- from = (from < 0) ? Math.ceil(from) : Math.floor(from);
- if (from < 0) {
- from += len;
- }
-
- for (; from < len; from++) {
- if (from in this && this[from] === elt) {
- return from;
- }
- }
-
- return -1;
- };
-}
-
-/* All of the Strophe globals are defined in this special function below so
- * that references to the globals become closures. This will ensure that
- * on page reload, these references will still be available to callbacks
- * that are still executing.
- */
-
-(function (callback) {
-var Strophe;
-
-/** Function: $build
- * Create a Strophe.Builder.
- * This is an alias for 'new Strophe.Builder(name, attrs)'.
- *
- * Parameters:
- * (String) name - The root element name.
- * (Object) attrs - The attributes for the root element in object notation.
- *
- * Returns:
- * A new Strophe.Builder object.
- */
-function $build(name, attrs) { return new Strophe.Builder(name, attrs); }
-/** Function: $msg
- * Create a Strophe.Builder with a <message/> element as the root.
- *
- *  Parameters:
- * (Object) attrs - The <message/> element attributes in object notation.
- *
- * Returns:
- * A new Strophe.Builder object.
- */
-function $msg(attrs) { return new Strophe.Builder("message", attrs); }
-/** Function: $iq
- * Create a Strophe.Builder with an <iq/> element as the root.
- *
- * Parameters:
- * (Object) attrs - The <iq/> element attributes in object notation.
- *
- * Returns:
- * A new Strophe.Builder object.
- */
-function $iq(attrs) { return new Strophe.Builder("iq", attrs); }
-/** Function: $pres
- * Create a Strophe.Builder with a <presence/> element as the root.
- *
- * Parameters:
- * (Object) attrs - The <presence/> element attributes in object notation.
- *
- * Returns:
- * A new Strophe.Builder object.
- */
-function $pres(attrs) { return new Strophe.Builder("presence", attrs); }
-
-/** Class: Strophe
- * An object container for all Strophe library functions.
- *
- * This class is just a container for all the objects and constants
- * used in the library. It is not meant to be instantiated, but to
- * provide a namespace for library objects, constants, and functions.
- */
-Strophe = {
- /** Constant: VERSION
- * The version of the Strophe library. Unreleased builds will have
- * a version of head-HASH where HASH is a partial revision.
- */
- VERSION: "@VERSION@",
-
- /** Constants: XMPP Namespace Constants
- * Common namespace constants from the XMPP RFCs and XEPs.
- *
- * NS.HTTPBIND - HTTP BIND namespace from XEP 124.
- * NS.BOSH - BOSH namespace from XEP 206.
- * NS.CLIENT - Main XMPP client namespace.
- * NS.AUTH - Legacy authentication namespace.
- * NS.ROSTER - Roster operations namespace.
- * NS.PROFILE - Profile namespace.
- * NS.DISCO_INFO - Service discovery info namespace from XEP 30.
- * NS.DISCO_ITEMS - Service discovery items namespace from XEP 30.
- * NS.MUC - Multi-User Chat namespace from XEP 45.
- * NS.SASL - XMPP SASL namespace from RFC 3920.
- * NS.STREAM - XMPP Streams namespace from RFC 3920.
- * NS.BIND - XMPP Binding namespace from RFC 3920.
- * NS.SESSION - XMPP Session namespace from RFC 3920.
- */
- NS: {
- HTTPBIND: "http://jabber.org/protocol/httpbind",
- BOSH: "urn:xmpp:xbosh",
- CLIENT: "jabber:client",
- AUTH: "jabber:iq:auth",
- ROSTER: "jabber:iq:roster",
- PROFILE: "jabber:iq:profile",
- DISCO_INFO: "http://jabber.org/protocol/disco#info",
- DISCO_ITEMS: "http://jabber.org/protocol/disco#items",
- MUC: "http://jabber.org/protocol/muc",
- SASL: "urn:ietf:params:xml:ns:xmpp-sasl",
- STREAM: "http://etherx.jabber.org/streams",
- BIND: "urn:ietf:params:xml:ns:xmpp-bind",
- SESSION: "urn:ietf:params:xml:ns:xmpp-session",
- VERSION: "jabber:iq:version",
- STANZAS: "urn:ietf:params:xml:ns:xmpp-stanzas"
- },
-
- /** Function: addNamespace
- * This function is used to extend the current namespaces in
- * Strophe.NS. It takes a key and a value with the key being the
- * name of the new namespace, with its actual value.
- * For example:
- * Strophe.addNamespace('PUBSUB', "http://jabber.org/protocol/pubsub");
- *
- * Parameters:
- * (String) name - The name under which the namespace will be
- * referenced under Strophe.NS
- * (String) value - The actual namespace.
- */
- addNamespace: function (name, value)
- {
- Strophe.NS[name] = value;
- },
-
- /** Constants: Connection Status Constants
- * Connection status constants for use by the connection handler
- * callback.
- *
- * Status.ERROR - An error has occurred
- * Status.CONNECTING - The connection is currently being made
- * Status.CONNFAIL - The connection attempt failed
- * Status.AUTHENTICATING - The connection is authenticating
- * Status.AUTHFAIL - The authentication attempt failed
- * Status.CONNECTED - The connection has succeeded
- * Status.DISCONNECTED - The connection has been terminated
- * Status.DISCONNECTING - The connection is currently being terminated
- * Status.ATTACHED - The connection has been attached
- */
- Status: {
- ERROR: 0,
- CONNECTING: 1,
- CONNFAIL: 2,
- AUTHENTICATING: 3,
- AUTHFAIL: 4,
- CONNECTED: 5,
- DISCONNECTED: 6,
- DISCONNECTING: 7,
- ATTACHED: 8
- },
-
- /** Constants: Log Level Constants
- * Logging level indicators.
- *
- * LogLevel.DEBUG - Debug output
- * LogLevel.INFO - Informational output
- * LogLevel.WARN - Warnings
- * LogLevel.ERROR - Errors
- * LogLevel.FATAL - Fatal errors
- */
- LogLevel: {
- DEBUG: 0,
- INFO: 1,
- WARN: 2,
- ERROR: 3,
- FATAL: 4
- },
-
- /** PrivateConstants: DOM Element Type Constants
- * DOM element types.
- *
- * ElementType.NORMAL - Normal element.
- * ElementType.TEXT - Text data element.
- */
- ElementType: {
- NORMAL: 1,
- TEXT: 3
- },
-
- /** PrivateConstants: Timeout Values
- * Timeout values for error states. These values are in seconds.
- * These should not be changed unless you know exactly what you are
- * doing.
- *
- * TIMEOUT - Timeout multiplier. A waiting request will be considered
- * failed after Math.floor(TIMEOUT * wait) seconds have elapsed.
- * This defaults to 1.1, and with default wait, 66 seconds.
- * SECONDARY_TIMEOUT - Secondary timeout multiplier. In cases where
- * Strophe can detect early failure, it will consider the request
- * failed if it doesn't return after
- * Math.floor(SECONDARY_TIMEOUT * wait) seconds have elapsed.
- * This defaults to 0.1, and with default wait, 6 seconds.
- */
- TIMEOUT: 1.1,
- SECONDARY_TIMEOUT: 0.1,
-
- /** Function: forEachChild
- * Map a function over some or all child elements of a given element.
- *
- * This is a small convenience function for mapping a function over
- * some or all of the children of an element. If elemName is null, all
- * children will be passed to the function, otherwise only children
- * whose tag names match elemName will be passed.
- *
- * Parameters:
- * (XMLElement) elem - The element to operate on.
- * (String) elemName - The child element tag name filter.
- * (Function) func - The function to apply to each child. This
- * function should take a single argument, a DOM element.
- */
- forEachChild: function (elem, elemName, func)
- {
- var i, childNode;
-
- for (i = 0; i < elem._childNodes.length; i++) {
- childNode = elem._childNodes[i];
- if (childNode.nodeType == Strophe.ElementType.NORMAL &&
- (!elemName || this.isTagEqual(childNode, elemName))) {
- func(childNode);
- }
- }
- },
-
- /** Function: isTagEqual
- * Compare an element's tag name with a string.
- *
- * This function is case insensitive.
- *
- * Parameters:
- * (XMLElement) el - A DOM element.
- * (String) name - The element name.
- *
- * Returns:
- *    true if the element's tag name matches _name_, and false
- * otherwise.
- */
- isTagEqual: function (el, name)
- {
- return el.tagName.toLowerCase() == name.toLowerCase();
- },
-
- /** PrivateVariable: _xmlGenerator
- * _Private_ variable that caches a DOM document to
- * generate elements.
- */
- _xmlGenerator: null,
-
- /** PrivateFunction: _makeGenerator
- * _Private_ function that creates a dummy XML DOM document to serve as
- * an element and text node generator.
- */
- _makeGenerator: function () {
- var doc;
-
- if (window.ActiveXObject) {
- doc = this._getIEXmlDom();
- doc.appendChild(doc.createElement('strophe'));
- } else {
- doc = document.implementation
- .createDocument('jabber:client', 'strophe', null);
- }
-
- return doc;
- },
-
- /** Function: xmlGenerator
- * Get the DOM document to generate elements.
- *
- * Returns:
- * The currently used DOM document.
- */
- xmlGenerator: function () {
- if (!Strophe._xmlGenerator) {
- Strophe._xmlGenerator = Strophe._makeGenerator();
- }
- return Strophe._xmlGenerator;
- },
-
- /** PrivateFunction: _getIEXmlDom
- * Gets IE xml doc object
- *
- * Returns:
- * A Microsoft XML DOM Object
- * See Also:
- * http://msdn.microsoft.com/en-us/library/ms757837%28VS.85%29.aspx
- */
- _getIEXmlDom : function() {
- var doc = null;
- var docStrings = [
- "Msxml2.DOMDocument.6.0",
- "Msxml2.DOMDocument.5.0",
- "Msxml2.DOMDocument.4.0",
- "MSXML2.DOMDocument.3.0",
- "MSXML2.DOMDocument",
- "MSXML.DOMDocument",
- "Microsoft.XMLDOM"
- ];
-
- for (var d = 0; d < docStrings.length; d++) {
- if (doc === null) {
- try {
- doc = new ActiveXObject(docStrings[d]);
- } catch (e) {
- doc = null;
- }
- } else {
- break;
- }
- }
-
- return doc;
- },
-
- /** Function: xmlElement
- * Create an XML DOM element.
- *
- * This function creates an XML DOM element correctly across all
- * implementations. Note that these are not HTML DOM elements, which
- * aren't appropriate for XMPP stanzas.
- *
- * Parameters:
- * (String) name - The name for the element.
- * (Array|Object) attrs - An optional array or object containing
- * key/value pairs to use as element attributes. The object should
- * be in the format {'key': 'value'} or {key: 'value'}. The array
- * should have the format [['key1', 'value1'], ['key2', 'value2']].
- * (String) text - The text child data for the element.
- *
- * Returns:
- * A new XML DOM element.
- */
- xmlElement: function (name)
- {
- if (!name) { return null; }
-
- var node = Strophe.xmlGenerator().createElement(name);
-
- // FIXME: this should throw errors if args are the wrong type or
- // there are more than two optional args
- var a, i, k;
- for (a = 1; a < arguments.length; a++) {
- if (!arguments[a]) { continue; }
- if (typeof(arguments[a]) == "string" ||
- typeof(arguments[a]) == "number") {
- node.appendChild(Strophe.xmlTextNode(arguments[a]));
- } else if (typeof(arguments[a]) == "object" &&
- typeof(arguments[a].sort) == "function") {
- for (i = 0; i < arguments[a].length; i++) {
- if (typeof(arguments[a][i]) == "object" &&
- typeof(arguments[a][i].sort) == "function") {
- node.setAttribute(arguments[a][i][0],
- arguments[a][i][1]);
- }
- }
- } else if (typeof(arguments[a]) == "object") {
- for (k in arguments[a]) {
- if (arguments[a].hasOwnProperty(k)) {
- node.setAttribute(k, arguments[a][k]);
- }
- }
- }
- }
-
- return node;
- },
-
- /* Function: xmlescape
-     *  Escapes invalid XML characters.
- *
- * Parameters:
- * (String) text - text to escape.
- *
- * Returns:
- * Escaped text.
- */
- xmlescape: function(text)
- {
- text = text.replace(/\&/g, "&amp;");
- text = text.replace(/</g, "&lt;");
- text = text.replace(/>/g, "&gt;");
- return text;
- },
-
- /** Function: xmlTextNode
- * Creates an XML DOM text node.
- *
- * Provides a cross implementation version of document.createTextNode.
- *
- * Parameters:
- * (String) text - The content of the text node.
- *
- * Returns:
- * A new XML DOM text node.
- */
- xmlTextNode: function (text)
- {
- //ensure text is escaped
- text = Strophe.xmlescape(text);
-
- return Strophe.xmlGenerator().createTextNode(text);
- },
-
- /** Function: getText
- * Get the concatenation of all text children of an element.
- *
- * Parameters:
- * (XMLElement) elem - A DOM element.
- *
- * Returns:
- * A String with the concatenated text of all text element children.
- */
- getText: function (elem)
- {
- if (!elem) { return null; }
-
- var str = "";
- if (elem._childNodes.length === 0 && elem.nodeType ==
- Strophe.ElementType.TEXT) {
- str += elem.nodeValue;
- }
-
- for (var i = 0; i < elem._childNodes.length; i++) {
- if (elem._childNodes[i].nodeType == Strophe.ElementType.TEXT) {
- str += elem._childNodes[i].nodeValue;
- }
- }
-
- return str;
- },
-
- /** Function: copyElement
- * Copy an XML DOM element.
- *
- * This function copies a DOM element and all its descendants and returns
- * the new copy.
- *
- * Parameters:
- * (XMLElement) elem - A DOM element.
- *
- * Returns:
- * A new, copied DOM element tree.
- */
- copyElement: function (elem)
- {
- var i, el;
- if (elem.nodeType == Strophe.ElementType.NORMAL) {
- el = Strophe.xmlElement(elem.tagName);
-
- for (i = 0; i < elem.attributes.length; i++) {
- el.setAttribute(elem.attributes[i].nodeName.toLowerCase(),
- elem.attributes[i].value);
- }
-
- for (i = 0; i < elem._childNodes.length; i++) {
- el.appendChild(Strophe.copyElement(elem._childNodes[i]));
- }
- } else if (elem.nodeType == Strophe.ElementType.TEXT) {
- el = Strophe.xmlTextNode(elem.nodeValue);
- }
-
- return el;
- },
-
- /** Function: escapeNode
- * Escape the node part (also called local part) of a JID.
- *
- * Parameters:
- * (String) node - A node (or local part).
- *
- * Returns:
- * An escaped node (or local part).
- */
- escapeNode: function (node)
- {
- return node.replace(/^\s+|\s+$/g, '')
- .replace(/\\/g, "\\5c")
- .replace(/ /g, "\\20")
- .replace(/\"/g, "\\22")
- .replace(/\&/g, "\\26")
- .replace(/\'/g, "\\27")
- .replace(/\//g, "\\2f")
- .replace(/:/g, "\\3a")
- .replace(/</g, "\\3c")
- .replace(/>/g, "\\3e")
- .replace(/@/g, "\\40");
- },
-
- /** Function: unescapeNode
- * Unescape a node part (also called local part) of a JID.
- *
- * Parameters:
- * (String) node - A node (or local part).
- *
- * Returns:
- * An unescaped node (or local part).
- */
- unescapeNode: function (node)
- {
- return node.replace(/\\20/g, " ")
- .replace(/\\22/g, '"')
- .replace(/\\26/g, "&")
- .replace(/\\27/g, "'")
- .replace(/\\2f/g, "/")
- .replace(/\\3a/g, ":")
- .replace(/\\3c/g, "<")
- .replace(/\\3e/g, ">")
- .replace(/\\40/g, "@")
- .replace(/\\5c/g, "\\");
- },
-
- /** Function: getNodeFromJid
- * Get the node portion of a JID String.
- *
- * Parameters:
- * (String) jid - A JID.
- *
- * Returns:
- * A String containing the node.
- */
- getNodeFromJid: function (jid)
- {
- if (jid.indexOf("@") < 0) { return null; }
- return jid.split("@")[0];
- },
-
- /** Function: getDomainFromJid
- * Get the domain portion of a JID String.
- *
- * Parameters:
- * (String) jid - A JID.
- *
- * Returns:
- * A String containing the domain.
- */
- getDomainFromJid: function (jid)
- {
- var bare = Strophe.getBareJidFromJid(jid);
- if (bare.indexOf("@") < 0) {
- return bare;
- } else {
- var parts = bare.split("@");
- parts.splice(0, 1);
- return parts.join('@');
- }
- },
-
- /** Function: getResourceFromJid
- * Get the resource portion of a JID String.
- *
- * Parameters:
- * (String) jid - A JID.
- *
- * Returns:
- * A String containing the resource.
- */
- getResourceFromJid: function (jid)
- {
- var s = jid.split("/");
- if (s.length < 2) { return null; }
- s.splice(0, 1);
- return s.join('/');
- },
-
- /** Function: getBareJidFromJid
- * Get the bare JID from a JID String.
- *
- * Parameters:
- * (String) jid - A JID.
- *
- * Returns:
- * A String containing the bare JID.
- */
- getBareJidFromJid: function (jid)
- {
- return jid ? jid.split("/")[0] : null;
- },
-
- /** Function: log
- * User overrideable logging function.
- *
- * This function is called whenever the Strophe library calls any
- * of the logging functions. The default implementation of this
- * function does nothing. If client code wishes to handle the logging
- * messages, it should override this with
- * > Strophe.log = function (level, msg) {
- * > (user code here)
- * > };
- *
- * Please note that data sent and received over the wire is logged
- * via Strophe.Connection.rawInput() and Strophe.Connection.rawOutput().
- *
- * The different levels and their meanings are
- *
- * DEBUG - Messages useful for debugging purposes.
- * INFO - Informational messages. This is mostly information like
- * 'disconnect was called' or 'SASL auth succeeded'.
- * WARN - Warnings about potential problems. This is mostly used
- * to report transient connection errors like request timeouts.
- * ERROR - Some error occurred.
- * FATAL - A non-recoverable fatal error occurred.
- *
- * Parameters:
- * (Integer) level - The log level of the log message. This will
- * be one of the values in Strophe.LogLevel.
- * (String) msg - The log message.
- */
- log: function (level, msg)
- {
- return;
- },
-
- /** Function: debug
- * Log a message at the Strophe.LogLevel.DEBUG level.
- *
- * Parameters:
- * (String) msg - The log message.
- */
- debug: function(msg)
- {
- this.log(this.LogLevel.DEBUG, msg);
- },
-
- /** Function: info
- * Log a message at the Strophe.LogLevel.INFO level.
- *
- * Parameters:
- * (String) msg - The log message.
- */
- info: function (msg)
- {
- this.log(this.LogLevel.INFO, msg);
- },
-
- /** Function: warn
- * Log a message at the Strophe.LogLevel.WARN level.
- *
- * Parameters:
- * (String) msg - The log message.
- */
- warn: function (msg)
- {
- this.log(this.LogLevel.WARN, msg);
- },
-
- /** Function: error
- * Log a message at the Strophe.LogLevel.ERROR level.
- *
- * Parameters:
- * (String) msg - The log message.
- */
- error: function (msg)
- {
- this.log(this.LogLevel.ERROR, msg);
- },
-
- /** Function: fatal
- * Log a message at the Strophe.LogLevel.FATAL level.
- *
- * Parameters:
- * (String) msg - The log message.
- */
- fatal: function (msg)
- {
- this.log(this.LogLevel.FATAL, msg);
- },
-
- /** Function: serialize
- * Render a DOM element and all descendants to a String.
- *
- * Parameters:
- * (XMLElement) elem - A DOM element.
- *
- * Returns:
- * The serialized element tree as a String.
- */
- serialize: function (elem)
- {
- var result;
-
- if (!elem) { return null; }
-
- if (typeof(elem.tree) === "function") {
- elem = elem.tree();
- }
-
- var nodeName = elem.nodeName.toLowerCase();
- var i, child;
-
- if (elem.getAttribute("_realname")) {
- nodeName = elem.getAttribute("_realname").toLowerCase();
- }
-
- result = "<" + nodeName.toLowerCase();
- for (i = 0; i < elem.attributes.length; i++) {
- if(elem.attributes[i].nodeName.toLowerCase() != "_realname") {
- result += " " + elem.attributes[i].nodeName.toLowerCase() +
- "='" + elem.attributes[i].value
- .replace(/&/g, "&amp;")
- .replace(/\'/g, "&apos;")
- .replace(/</g, "&lt;") + "'";
- }
- }
-
- if (elem._childNodes.length > 0) {
- result += ">";
- for (i = 0; i < elem._childNodes.length; i++) {
- child = elem._childNodes[i];
- if (child.nodeType == Strophe.ElementType.NORMAL) {
- // normal element, so recurse
- result += Strophe.serialize(child);
- } else if (child.nodeType == Strophe.ElementType.TEXT) {
- // text element
- result += child.nodeValue;
- }
- }
- result += "</" + nodeName.toLowerCase() + ">";
- } else {
- result += "/>";
- }
-
- return result;
- },
-
- /** PrivateVariable: _requestId
- * _Private_ variable that keeps track of the request ids for
- * connections.
- */
- _requestId: 0,
-
- /** PrivateVariable: Strophe.connectionPlugins
-     *  _Private_ variable used to store plugin names that need
- * initialization on Strophe.Connection construction.
- */
- _connectionPlugins: {},
-
- /** Function: addConnectionPlugin
- * Extends the Strophe.Connection object with the given plugin.
- *
-     *  Parameters:
- * (String) name - The name of the extension.
- * (Object) ptype - The plugin's prototype.
- */
- addConnectionPlugin: function (name, ptype)
- {
- Strophe._connectionPlugins[name] = ptype;
- }
-};
-
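The JID helpers defined above (getNodeFromJid and friends) simply split on '@' and '/'. A sketch of the expected results, assuming code running inside this closure where Strophe is in scope:

    var jid = 'alice@example.org/laptop';
    Strophe.getNodeFromJid(jid);      // 'alice'
    Strophe.getDomainFromJid(jid);    // 'example.org'
    Strophe.getResourceFromJid(jid);  // 'laptop'
    Strophe.getBareJidFromJid(jid);   // 'alice@example.org'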
-/** Class: Strophe.Builder
- * XML DOM builder.
- *
- * This object provides an interface similar to JQuery but for building
- *  DOM elements easily and rapidly. All the functions except for toString()
- * and tree() return the object, so calls can be chained. Here's an
- * example using the $iq() builder helper.
- * > $iq({to: 'you', from: 'me', type: 'get', id: '1'})
- * > .c('query', {xmlns: 'strophe:example'})
- * > .c('example')
- * > .toString()
- * The above generates this XML fragment
- * > <iq to='you' from='me' type='get' id='1'>
- * > <query xmlns='strophe:example'>
- * > <example/>
- * > </query>
- * > </iq>
- * The corresponding DOM manipulations to get a similar fragment would be
- * a lot more tedious and probably involve several helper variables.
- *
- * Since adding children makes new operations operate on the child, up()
- * is provided to traverse up the tree. To add two children, do
- * > builder.c('child1', ...).up().c('child2', ...)
- * The next operation on the Builder will be relative to the second child.
- */
-
-/** Constructor: Strophe.Builder
- * Create a Strophe.Builder object.
- *
- * The attributes should be passed in object notation. For example
- * > var b = new Builder('message', {to: 'you', from: 'me'});
- * or
- *  > var b = new Builder('message', {'xml:lang': 'en'});
- *
- * Parameters:
- * (String) name - The name of the root element.
- * (Object) attrs - The attributes for the root element in object notation.
- *
- * Returns:
- * A new Strophe.Builder.
- */
-Strophe.Builder = function (name, attrs)
-{
- // Set correct namespace for jabber:client elements
- if (name == "presence" || name == "message" || name == "iq") {
- if (attrs && !attrs.xmlns) {
- attrs.xmlns = Strophe.NS.CLIENT;
- } else if (!attrs) {
- attrs = {xmlns: Strophe.NS.CLIENT};
- }
- }
-
- // Holds the tree being built.
- this.nodeTree = Strophe.xmlElement(name, attrs);
-
- // Points to the current operation node.
- this.node = this.nodeTree;
-};
-
-Strophe.Builder.prototype = {
- /** Function: tree
- * Return the DOM tree.
- *
- * This function returns the current DOM tree as an element object. This
- * is suitable for passing to functions like Strophe.Connection.send().
- *
- * Returns:
- * The DOM tree as a element object.
- */
- tree: function ()
- {
- return this.nodeTree;
- },
-
- /** Function: toString
- * Serialize the DOM tree to a String.
- *
- * This function returns a string serialization of the current DOM
- * tree. It is often used internally to pass data to a
- * Strophe.Request object.
- *
- * Returns:
- * The serialized DOM tree in a String.
- */
- toString: function ()
- {
- return Strophe.serialize(this.nodeTree);
- },
-
- /** Function: up
- * Make the current parent element the new current element.
- *
- * This function is often used after c() to traverse back up the tree.
- * For example, to add two children to the same element
- * > builder.c('child1', {}).up().c('child2', {});
- *
- * Returns:
-     *  The Strophe.Builder object.
- */
- up: function ()
- {
- this.node = this.node.parentNode;
- return this;
- },
-
- /** Function: attrs
- * Add or modify attributes of the current element.
- *
- * The attributes should be passed in object notation. This function
- * does not move the current element pointer.
- *
- * Parameters:
- * (Object) moreattrs - The attributes to add/modify in object notation.
- *
- * Returns:
- * The Strophe.Builder object.
- */
- attrs: function (moreattrs)
- {
- for (var k in moreattrs) {
- if (moreattrs.hasOwnProperty(k)) {
- this.node.setAttribute(k, moreattrs[k]);
- }
- }
- return this;
- },
-
- /** Function: c
- * Add a child to the current element and make it the new current
- * element.
- *
- * This function moves the current element pointer to the child. If you
- * need to add another child, it is necessary to use up() to go back
- * to the parent in the tree.
- *
- * Parameters:
- * (String) name - The name of the child.
- * (Object) attrs - The attributes of the child in object notation.
- *
- * Returns:
- * The Strophe.Builder object.
- */
- c: function (name, attrs)
- {
- var child = Strophe.xmlElement(name, attrs);
- this.node.appendChild(child);
- this.node = child;
- return this;
- },
-
- /** Function: cnode
- * Add a child to the current element and make it the new current
- * element.
- *
- * This function is the same as c() except that instead of using a
- * name and an attributes object to create the child it uses an
- * existing DOM element object.
- *
- * Parameters:
- * (XMLElement) elem - A DOM element.
- *
- * Returns:
- * The Strophe.Builder object.
- */
- cnode: function (elem)
- {
- var xmlGen = Strophe.xmlGenerator();
- var newElem = xmlGen.importNode ? xmlGen.importNode(elem, true) : Strophe.copyElement(elem);
- this.node.appendChild(newElem);
- this.node = newElem;
- return this;
- },
-
- /** Function: t
- * Add a child text element.
- *
- * This *does not* make the child the new current element since there
- * are no children of text elements.
- *
- * Parameters:
- * (String) text - The text data to append to the current element.
- *
- * Returns:
- * The Strophe.Builder object.
- */
- t: function (text)
- {
- var child = Strophe.xmlTextNode(text);
- this.node.appendChild(child);
- return this;
- }
-};
-
-
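The Builder methods above chain, with c() descending into the new child and up() returning to its parent, as the class comment's $iq example shows. A comparable sketch with $msg, assuming it runs inside this closure where $msg is defined:

    var msg = $msg({ to: 'you@example.org', type: 'chat' })
        .c('body')
        .t('hello')
        .up()
        .c('active', { xmlns: 'http://jabber.org/protocol/chatstates' });
    console.log(msg.toString());
    // roughly: <message to='you@example.org' type='chat' xmlns='jabber:client'>
    //            <body>hello</body>
    //            <active xmlns='http://jabber.org/protocol/chatstates'/>
    //          </message>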
-/** PrivateClass: Strophe.Handler
- * _Private_ helper class for managing stanza handlers.
- *
- * A Strophe.Handler encapsulates a user provided callback function to be
- * executed when matching stanzas are received by the connection.
- *  Handlers can be either one-off or persistent depending on their
- * return value. Returning true will cause a Handler to remain active, and
- * returning false will remove the Handler.
- *
- * Users will not use Strophe.Handler objects directly, but instead they
- * will use Strophe.Connection.addHandler() and
- * Strophe.Connection.deleteHandler().
- */
-
-/** PrivateConstructor: Strophe.Handler
- * Create and initialize a new Strophe.Handler.
- *
- * Parameters:
- * (Function) handler - A function to be executed when the handler is run.
- * (String) ns - The namespace to match.
- * (String) name - The element name to match.
- * (String) type - The element type to match.
- * (String) id - The element id attribute to match.
- * (String) from - The element from attribute to match.
- * (Object) options - Handler options
- *
- * Returns:
- * A new Strophe.Handler object.
- */
-Strophe.Handler = function (handler, ns, name, type, id, from, options)
-{
- this.handler = handler;
- this.ns = ns;
- this.name = name;
- this.type = type;
- this.id = id;
- this.options = options || {matchbare: false};
-
- // default matchBare to false if undefined
- if (!this.options.matchBare) {
- this.options.matchBare = false;
- }
-
- if (this.options.matchBare) {
- this.from = from ? Strophe.getBareJidFromJid(from) : null;
- } else {
- this.from = from;
- }
-
- // whether the handler is a user handler or a system handler
- this.user = true;
-};
-
-Strophe.Handler.prototype = {
- /** PrivateFunction: isMatch
- * Tests if a stanza matches the Strophe.Handler.
- *
- * Parameters:
- * (XMLElement) elem - The XML element to test.
- *
- * Returns:
- * true if the stanza matches and false otherwise.
- */
- isMatch: function (elem)
- {
- var nsMatch;
- var from = null;
-
- if (this.options.matchBare) {
- from = Strophe.getBareJidFromJid(elem.getAttribute('from'));
- } else {
- from = elem.getAttribute('from');
- }
-
- nsMatch = false;
- if (!this.ns) {
- nsMatch = true;
- } else {
- var that = this;
- Strophe.forEachChild(elem, null, function (elem) {
- if (elem.getAttribute("xmlns") == that.ns) {
- nsMatch = true;
- }
- });
-
- nsMatch = nsMatch || elem.getAttribute("xmlns") == this.ns;
- }
-
- if (nsMatch &&
- (!this.name || Strophe.isTagEqual(elem, this.name)) &&
- (!this.type || elem.getAttribute("type") == this.type) &&
- (!this.id || elem.getAttribute("id") == this.id) &&
- (!this.from || from == this.from)) {
- return true;
- }
-
- return false;
- },
-
- /** PrivateFunction: run
- * Run the callback on a matching stanza.
- *
- * Parameters:
- * (XMLElement) elem - The DOM element that triggered the
- * Strophe.Handler.
- *
- * Returns:
- * A boolean indicating if the handler should remain active.
- */
- run: function (elem)
- {
- var result = null;
- try {
- result = this.handler(elem);
- } catch (e) {
- if (e.sourceURL) {
- Strophe.fatal("error: " + this.handler +
- " " + e.sourceURL + ":" +
- e.line + " - " + e.name + ": " + e.message);
- } else if (e.fileName) {
- if (typeof(console) != "undefined") {
- console.trace();
- console.error(this.handler, " - error - ", e, e.message);
- }
- Strophe.fatal("error: " + this.handler + " " +
- e.fileName + ":" + e.lineNumber + " - " +
- e.name + ": " + e.message);
- } else {
- Strophe.fatal("error: " + this.handler);
- }
-
- throw e;
- }
-
- return result;
- },
-
- /** PrivateFunction: toString
- * Get a String representation of the Strophe.Handler object.
- *
- * Returns:
- * A String.
- */
- toString: function ()
- {
- return "{Handler: " + this.handler + "(" + this.name + "," +
- this.id + "," + this.ns + ")}";
- }
-};
-
-/** PrivateClass: Strophe.TimedHandler
- * _Private_ helper class for managing timed handlers.
- *
- * A Strophe.TimedHandler encapsulates a user provided callback that
- * should be called after a certain period of time or at regular
- * intervals. The return value of the callback determines whether the
- * Strophe.TimedHandler will continue to fire.
- *
- * Users will not use Strophe.TimedHandler objects directly, but instead
- * they will use Strophe.Connection.addTimedHandler() and
- * Strophe.Connection.deleteTimedHandler().
- */
-
-/** PrivateConstructor: Strophe.TimedHandler
- * Create and initialize a new Strophe.TimedHandler object.
- *
- * Parameters:
- * (Integer) period - The number of milliseconds to wait before the
- * handler is called.
- * (Function) handler - The callback to run when the handler fires. This
- * function should take no arguments.
- *
- * Returns:
- * A new Strophe.TimedHandler object.
- */
-Strophe.TimedHandler = function (period, handler)
-{
- this.period = period;
- this.handler = handler;
-
- this.lastCalled = new Date().getTime();
- this.user = true;
-};
-
-Strophe.TimedHandler.prototype = {
- /** PrivateFunction: run
- * Run the callback for the Strophe.TimedHandler.
- *
- * Returns:
- * true if the Strophe.TimedHandler should be called again, and false
- * otherwise.
- */
- run: function ()
- {
- this.lastCalled = new Date().getTime();
- return this.handler();
- },
-
- /** PrivateFunction: reset
- * Reset the last called time for the Strophe.TimedHandler.
- */
- reset: function ()
- {
- this.lastCalled = new Date().getTime();
- },
-
- /** PrivateFunction: toString
- * Get a string representation of the Strophe.TimedHandler object.
- *
- * Returns:
- * The string representation.
- */
- toString: function ()
- {
- return "{TimedHandler: " + this.handler + "(" + this.period +")}";
- }
-};
-
-/** PrivateClass: Strophe.Request
- * _Private_ helper class that provides a cross implementation abstraction
- * for a BOSH related XMLHttpRequest.
- *
- * The Strophe.Request class is used internally to encapsulate BOSH request
- * information. It is not meant to be used from user's code.
- */
-
-/** PrivateConstructor: Strophe.Request
- * Create and initialize a new Strophe.Request object.
- *
- * Parameters:
- * (XMLElement) elem - The XML data to be sent in the request.
- * (Function) func - The function that will be called when the
- * XMLHttpRequest readyState changes.
- * (Integer) rid - The BOSH rid attribute associated with this request.
- * (Integer) sends - The number of times this same request has been
- * sent.
- */
-Strophe.Request = function (elem, func, rid, sends)
-{
- this.id = ++Strophe._requestId;
- this.xmlData = elem;
- this.data = Strophe.serialize(elem);
- // save original function in case we need to make a new request
- // from this one.
- this.origFunc = func;
- this.func = func;
- this.rid = rid;
- this.date = NaN;
- this.sends = sends || 0;
- this.abort = false;
- this.dead = null;
- this.age = function () {
- if (!this.date) { return 0; }
- var now = new Date();
- return (now - this.date) / 1000;
- };
- this.timeDead = function () {
- if (!this.dead) { return 0; }
- var now = new Date();
- return (now - this.dead) / 1000;
- };
- this.xhr = this._newXHR();
-};
-
-Strophe.Request.prototype = {
- /** PrivateFunction: getResponse
- * Get a response from the underlying XMLHttpRequest.
- *
- * This function attempts to get a response from the request and checks
- * for errors.
- *
- * Throws:
-     *   "parsererror" - A parser error occurred.
- *
- * Returns:
- * The DOM element tree of the response.
- */
- getResponse: function ()
- {
- // console.log("getResponse:", this.xhr.responseXML, ":", this.xhr.responseText);
-
- var node = null;
- if (this.xhr.responseXML && this.xhr.responseXML.documentElement) {
- node = this.xhr.responseXML.documentElement;
- if (node.tagName == "parsererror") {
- Strophe.error("invalid response received");
- Strophe.error("responseText: " + this.xhr.responseText);
- Strophe.error("responseXML: " +
- Strophe.serialize(this.xhr.responseXML));
- throw "parsererror";
- }
- } else if (this.xhr.responseText) {
- // Hack for node.
- var _div = document.createElement("div");
- _div.innerHTML = this.xhr.responseText;
- node = _div._childNodes[0];
-
- Strophe.error("invalid response received");
- Strophe.error("responseText: " + this.xhr.responseText);
- Strophe.error("responseXML: " +
- Strophe.serialize(this.xhr.responseXML));
- }
-
- return node;
- },
-
- /** PrivateFunction: _newXHR
- * _Private_ helper function to create XMLHttpRequests.
- *
- * This function creates XMLHttpRequests across all implementations.
- *
- * Returns:
- * A new XMLHttpRequest.
- */
- _newXHR: function ()
- {
- var xhr = null;
- if (window.XMLHttpRequest) {
- xhr = new XMLHttpRequest();
- if (xhr.overrideMimeType) {
- xhr.overrideMimeType("text/xml");
- }
- } else if (window.ActiveXObject) {
- xhr = new ActiveXObject("Microsoft.XMLHTTP");
- }
-
- // use Function.bind() to prepend ourselves as an argument
- xhr.onreadystatechange = this.func.bind(null, this);
-
- return xhr;
- }
-};
-
-/** Class: Strophe.Connection
- * XMPP Connection manager.
- *
- *  This class is the main part of Strophe.  It manages a BOSH connection
- * to an XMPP server and dispatches events to the user callbacks as
- * data arrives. It supports SASL PLAIN, SASL DIGEST-MD5, and legacy
- * authentication.
- *
- * After creating a Strophe.Connection object, the user will typically
- * call connect() with a user supplied callback to handle connection level
- * events like authentication failure, disconnection, or connection
- * complete.
- *
- * The user will also have several event handlers defined by using
- * addHandler() and addTimedHandler(). These will allow the user code to
- * respond to interesting stanzas or do something periodically with the
- * connection. These handlers will be active once authentication is
- * finished.
- *
- * To send data to the connection, use send().
- */
-
-/** Constructor: Strophe.Connection
- * Create and initialize a Strophe.Connection object.
- *
- * Parameters:
- * (String) service - The BOSH service URL.
- *
- * Returns:
- * A new Strophe.Connection object.
- */
-Strophe.Connection = function (service)
-{
- /* The path to the httpbind service. */
- this.service = service;
- /* The connected JID. */
- this.jid = "";
- /* request id for body tags */
- this.rid = Math.floor(Math.random() * 4294967295);
- /* The current session ID. */
- this.sid = null;
- this.streamId = null;
- /* stream:features */
- this.features = null;
-
- // SASL
- this.do_session = false;
- this.do_bind = false;
-
- // handler lists
- this.timedHandlers = [];
- this.handlers = [];
- this.removeTimeds = [];
- this.removeHandlers = [];
- this.addTimeds = [];
- this.addHandlers = [];
-
- this._idleTimeout = null;
- this._disconnectTimeout = null;
-
- this.authenticated = false;
- this.disconnecting = false;
- this.connected = false;
-
- this.errors = 0;
-
- this.paused = false;
-
- // default BOSH values
- this.hold = 1;
- this.wait = 60;
- this.window = 5;
-
- this._data = [];
- this._requests = [];
- this._uniqueId = Math.round(Math.random() * 10000);
-
- this._sasl_success_handler = null;
- this._sasl_failure_handler = null;
- this._sasl_challenge_handler = null;
-
- // setup onIdle callback every 1/10th of a second
- this._idleTimeout = setTimeout(this._onIdle.bind(this), 100);
-
- // initialize plugins
- for (var k in Strophe._connectionPlugins) {
- if (Strophe._connectionPlugins.hasOwnProperty(k)) {
- var ptype = Strophe._connectionPlugins[k];
- // jslint complains about the line below, but this is fine
- var F = function () {};
- F.prototype = ptype;
- this[k] = new F();
- this[k].init(this);
- }
- }
-};
-
-Strophe.Connection.prototype = {
- /** Function: reset
- * Reset the connection.
- *
- * This function should be called after a connection is disconnected
- * before that connection is reused.
- */
- reset: function ()
- {
- this.rid = Math.floor(Math.random() * 4294967295);
-
- this.sid = null;
- this.streamId = null;
-
- // SASL
- this.do_session = false;
- this.do_bind = false;
-
- // handler lists
- this.timedHandlers = [];
- this.handlers = [];
- this.removeTimeds = [];
- this.removeHandlers = [];
- this.addTimeds = [];
- this.addHandlers = [];
-
- this.authenticated = false;
- this.disconnecting = false;
- this.connected = false;
-
- this.errors = 0;
-
- this._requests = [];
- this._uniqueId = Math.round(Math.random()*10000);
- },
-
- /** Function: pause
- * Pause the request manager.
- *
- * This will prevent Strophe from sending any more requests to the
- * server. This is very useful for temporarily pausing while a lot
- * of send() calls are happening quickly. This causes Strophe to
- * send the data in a single request, saving many request trips.
- */
- pause: function ()
- {
- this.paused = true;
- },
-
- /** Function: resume
- * Resume the request manager.
- *
- * This resumes after pause() has been called.
- */
- resume: function ()
- {
- this.paused = false;
- },
-
- /** Function: getUniqueId
- * Generate a unique ID for use in <iq/> elements.
- *
- * All <iq/> stanzas are required to have unique id attributes. This
- * function makes creating these easy. Each connection instance has
- * a counter which starts from zero, and the value of this counter
- * plus a colon followed by the suffix becomes the unique id. If no
- * suffix is supplied, the counter is used as the unique id.
- *
- * Suffixes are used to make debugging easier when reading the stream
- * data, and their use is recommended. The counter resets to 0 for
- * every new connection for the same reason. For connections to the
- * same server that authenticate the same way, all the ids should be
- * the same, which makes it easy to see changes. This is useful for
- * automated testing as well.
- *
- * Parameters:
- * (String) suffix - An optional suffix to append to the id.
- *
- * Returns:
- * A unique string to be used for the id attribute.
- */
- getUniqueId: function (suffix)
- {
- if (typeof(suffix) == "string" || typeof(suffix) == "number") {
- return ++this._uniqueId + ":" + suffix;
- } else {
- return ++this._uniqueId + "";
- }
- },
-
- /** Function: connect
- * Starts the connection process.
- *
- * As the connection process proceeds, the user supplied callback will
- * be triggered multiple times with status updates. The callback
- * should take two arguments - the status code and the error condition.
- *
- * The status code will be one of the values in the Strophe.Status
- * constants. The error condition will be one of the conditions
- * defined in RFC 3920 or the condition 'strophe-parsererror'.
- *
- * Please see XEP 124 for a more detailed explanation of the optional
- * parameters below.
- *
- * Parameters:
- * (String) jid - The user's JID. This may be a bare JID,
- * or a full JID. If a node is not supplied, SASL ANONYMOUS
- * authentication will be attempted.
- * (String) pass - The user's password.
- * (Function) callback - The connect callback function.
- * (Integer) wait - The optional HTTPBIND wait value. This is the
- * time the server will wait before returning an empty result for
- * a request. The default setting of 60 seconds is recommended.
- * Other settings will require tweaks to the Strophe.TIMEOUT value.
- * (Integer) hold - The optional HTTPBIND hold value. This is the
- * number of connections the server will hold at one time. This
- * should almost always be set to 1 (the default).
- */
- connect: function (jid, pass, callback, wait, hold, route)
- {
- this.jid = jid;
- this.pass = pass;
- this.connect_callback = callback;
- this.disconnecting = false;
- this.connected = false;
- this.authenticated = false;
- this.errors = 0;
-
- this.wait = wait || this.wait;
- this.hold = hold || this.hold;
-
- // parse jid for domain and resource
- this.domain = Strophe.getDomainFromJid(this.jid);
-
- // build the body tag
- var body_attrs = {
- to: this.domain,
- "xml:lang": "en",
- wait: this.wait,
- hold: this.hold,
- content: "text/xml; charset=utf-8",
- ver: "1.6",
- "xmpp:version": "1.0",
- "xmlns:xmpp": Strophe.NS.BOSH
- };
- if (route) {
- body_attrs.route = route;
- }
-
- var body = this._buildBody().attrs(body_attrs);
-
- this._changeConnectStatus(Strophe.Status.CONNECTING, null);
-
- this._requests.push(
- new Strophe.Request(body.tree(),
- this._onRequestStateChange.bind(
- this, this._connect_cb.bind(this)),
- body.tree().getAttribute("rid")));
- this._throttledRequestHandler();
- },
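For illustration, a minimal sketch of driving this connect flow from user code; the service URL, JID and password are placeholders, and conn stands for a Strophe.Connection built as above:

    var conn = new Strophe.Connection("/http-bind");
    conn.connect("someone@example.com", "secret", function (status, condition) {
        if (status === Strophe.Status.CONNECTED) {
            // handlers registered via addHandler()/addTimedHandler() are active now
        } else if (status === Strophe.Status.CONNFAIL) {
            // condition carries the RFC 3920 error condition or 'strophe-parsererror'
        }
    });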
-
- /** Function: attach
- * Attach to an already created and authenticated BOSH session.
- *
- * This function is provided to allow Strophe to attach to BOSH
- * sessions which have been created externally, perhaps by a Web
- * application. This is often used to support auto-login type features
- * without putting user credentials into the page.
- *
- * Parameters:
- * (String) jid - The full JID that is bound by the session.
- * (String) sid - The SID of the BOSH session.
- * (String) rid - The current RID of the BOSH session. This RID
- * will be used by the next request.
- * (Function) callback - The connect callback function.
- * (Integer) wait - The optional HTTPBIND wait value. This is the
- * time the server will wait before returning an empty result for
- * a request. The default setting of 60 seconds is recommended.
- * Other settings will require tweaks to the Strophe.TIMEOUT value.
- * (Integer) hold - The optional HTTPBIND hold value. This is the
- * number of connections the server will hold at one time. This
- * should almost always be set to 1 (the default).
- * (Integer) wind - The optional HTTPBIND window value. This is the
- * allowed range of request ids that are valid. The default is 5.
- */
- attach: function (jid, sid, rid, callback, wait, hold, wind)
- {
- this.jid = jid;
- this.sid = sid;
- this.rid = rid;
- this.connect_callback = callback;
-
- this.domain = Strophe.getDomainFromJid(this.jid);
-
- this.authenticated = true;
- this.connected = true;
-
- this.wait = wait || this.wait;
- this.hold = hold || this.hold;
- this.window = wind || this.window;
-
- this._changeConnectStatus(Strophe.Status.ATTACHED, null);
- },
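A sketch of attaching to an externally created session; the prebound.sid and prebound.rid values are placeholders that would come from a server-side pre-binding step:

    conn.attach("someone@example.com/web", prebound.sid, prebound.rid,
        function (status) {
            if (status === Strophe.Status.ATTACHED) {
                // session is usable without sending credentials from the page
            }
        });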
-
- /** Function: xmlInput
- * User overrideable function that receives XML data coming into the
- * connection.
- *
- * The default function does nothing. User code can override this with
- * > Strophe.Connection.xmlInput = function (elem) {
- * > (user code)
- * > };
- *
- * Parameters:
- * (XMLElement) elem - The XML data received by the connection.
- */
- xmlInput: function (elem)
- {
- return;
- },
-
- /** Function: xmlOutput
- * User overrideable function that receives XML data sent to the
- * connection.
- *
- * The default function does nothing. User code can override this with
- * > Strophe.Connection.xmlOutput = function (elem) {
- * > (user code)
- * > };
- *
- * Parameters:
- * (XMLElement) elem - The XML data sent by the connection.
- */
- xmlOutput: function (elem)
- {
- return;
- },
-
- /** Function: rawInput
- * User overrideable function that receives raw data coming into the
- * connection.
- *
- * The default function does nothing. User code can override this with
- * > Strophe.Connection.rawInput = function (data) {
- * > (user code)
- * > };
- *
- * Parameters:
- * (String) data - The data received by the connection.
- */
- rawInput: function (data)
- {
- return;
- },
-
- /** Function: rawOutput
- * User overrideable function that receives raw data sent to the
- * connection.
- *
- * The default function does nothing. User code can override this with
- * > Strophe.Connection.rawOutput = function (data) {
- * > (user code)
- * > };
- *
- * Parameters:
- * (String) data - The data sent by the connection.
- */
- rawOutput: function (data)
- {
- return;
- },
-
- /** Function: send
- * Send a stanza.
- *
- * This function is called to push data onto the send queue to
- * go out over the wire. Whenever a request is sent to the BOSH
- * server, all pending data is sent and the queue is flushed.
- *
- * Parameters:
- * (XMLElement |
- * [XMLElement] |
- * Strophe.Builder) elem - The stanza to send.
- */
- send: function (elem)
- {
- if (elem === null) { return ; }
- if (typeof(elem.sort) === "function") {
- for (var i = 0; i < elem.length; i++) {
- this._queueData(elem[i]);
- }
- } else if (typeof(elem.tree) === "function") {
- this._queueData(elem.tree());
- } else {
- this._queueData(elem);
- }
-
- this._throttledRequestHandler();
- clearTimeout(this._idleTimeout);
- this._idleTimeout = setTimeout(this._onIdle.bind(this), 100);
- },
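For example, stanzas built with the $pres/$msg builders (or plain DOM elements) are queued like this and flushed together on the next idle tick, or immediately via flush(); the recipient JID is a placeholder:

    conn.send($pres());
    conn.send($msg({to: "friend@example.com", type: "chat"})
                  .c("body").t("hello"));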
-
- /** Function: flush
- * Immediately send any pending outgoing data.
- *
- * Normally send() queues outgoing data until the next idle period
- * (100ms), which optimizes network use in the common cases when
- * several send()s are called in succession. flush() can be used to
- * immediately send all pending data.
- */
- flush: function ()
- {
- // cancel the pending idle period and run the idle function
- // immediately
- clearTimeout(this._idleTimeout);
- this._onIdle();
- },
-
- /** Function: sendIQ
- * Helper function to send IQ stanzas.
- *
- * Parameters:
- * (XMLElement) elem - The stanza to send.
- * (Function) callback - The callback function for a successful request.
- * (Function) errback - The callback function for a failed or timed
- * out request. On timeout, the stanza will be null.
- * (Integer) timeout - The time specified in milliseconds for a
- * timeout to occur.
- *
- * Returns:
- * The id used to send the IQ.
- */
- sendIQ: function(elem, callback, errback, timeout) {
- var timeoutHandler = null;
- var that = this;
-
- if (typeof(elem.tree) === "function") {
- elem = elem.tree();
- }
- var id = elem.getAttribute('id');
-
- // inject id if not found
- if (!id) {
- id = this.getUniqueId("sendIQ");
- elem.setAttribute("id", id);
- }
-
- var handler = this.addHandler(function (stanza) {
- // remove timeout handler if there is one
- if (timeoutHandler) {
- that.deleteTimedHandler(timeoutHandler);
- }
-
- var iqtype = stanza.getAttribute('type');
- if (iqtype == 'result') {
- if (callback) {
- callback(stanza);
- }
- } else if (iqtype == 'error') {
- if (errback) {
- errback(stanza);
- }
- } else {
- throw {
- name: "StropheError",
- message: "Got bad IQ type of " + iqtype
- };
- }
- }, null, 'iq', null, id);
-
- // if timeout specified, setup timeout handler.
- if (timeout) {
- timeoutHandler = this.addTimedHandler(timeout, function () {
- // get rid of normal handler
- that.deleteHandler(handler);
-
- // call errback on timeout with null stanza
- if (errback) {
- errback(null);
- }
- return false;
- });
- }
-
- this.send(elem);
-
- return id;
- },
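A usage sketch; the ping payload and target domain are illustrative placeholders:

    var ping = $iq({to: "example.com", type: "get"})
        .c("ping", {xmlns: "urn:xmpp:ping"});
    conn.sendIQ(ping,
        function (result) { /* <iq type='result'/> arrived */ },
        function (error) { /* <iq type='error'/>, or null on timeout */ },
        5000);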
-
- /** PrivateFunction: _queueData
- * Queue outgoing data for later sending. Also ensures that the data
- * is a DOMElement.
- */
- _queueData: function (element) {
- if (element === null ||
- !element.tagName ||
- !element._childNodes) {
- throw {
- name: "StropheError",
- message: "Cannot queue non-DOMElement."
- };
- }
-
- this._data.push(element);
- },
-
- /** PrivateFunction: _sendRestart
- * Send an xmpp:restart stanza.
- */
- _sendRestart: function ()
- {
- this._data.push("restart");
-
- this._throttledRequestHandler();
- clearTimeout(this._idleTimeout);
- this._idleTimeout = setTimeout(this._onIdle.bind(this), 100);
- },
-
- /** Function: addTimedHandler
- * Add a timed handler to the connection.
- *
- * This function adds a timed handler. The provided handler will
- * be called every period milliseconds until it returns false,
- * the connection is terminated, or the handler is removed. Handlers
- * that wish to continue being invoked should return true.
- *
- * Because of method binding it is necessary to save the result of
- * this function if you wish to remove a handler with
- * deleteTimedHandler().
- *
- * Note that user handlers are not active until authentication is
- * successful.
- *
- * Parameters:
- * (Integer) period - The period of the handler.
- * (Function) handler - The callback function.
- *
- * Returns:
- * A reference to the handler that can be used to remove it.
- */
- addTimedHandler: function (period, handler)
- {
- var thand = new Strophe.TimedHandler(period, handler);
- this.addTimeds.push(thand);
- return thand;
- },
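For instance, a keep-alive that runs every 30 seconds once authentication completes; returning true keeps the handler installed:

    var keepalive = conn.addTimedHandler(30000, function () {
        conn.send($pres());
        return true;
    });
    // later: conn.deleteTimedHandler(keepalive);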
-
- /** Function: deleteTimedHandler
- * Delete a timed handler for a connection.
- *
- * This function removes a timed handler from the connection. The
- * handRef parameter is *not* the function passed to addTimedHandler(),
- * but is the reference returned from addTimedHandler().
- *
- * Parameters:
- * (Strophe.TimedHandler) handRef - The handler reference.
- */
- deleteTimedHandler: function (handRef)
- {
- // this must be done in the Idle loop so that we don't change
- // the handlers during iteration
- this.removeTimeds.push(handRef);
- },
-
- /** Function: addHandler
- * Add a stanza handler for the connection.
- *
- * This function adds a stanza handler to the connection. The
- * handler callback will be called for any stanza that matches
- * the parameters. Note that if multiple parameters are supplied,
- * they must all match for the handler to be invoked.
- *
- * The handler will receive the stanza that triggered it as its argument.
- * The handler should return true if it is to be invoked again;
- * returning false will remove the handler after it returns.
- *
- * As a convenience, the ns parameter applies to the top-level element
- * and also any of its immediate children. This is primarily to make
- * matching /iq/query elements easy.
- *
- * The options argument contains handler matching flags that affect how
- * matches are determined. Currently the only flag is matchBare (a
- * boolean). When matchBare is true, the from parameter and the from
- * attribute on the stanza will be matched as bare JIDs instead of
- * full JIDs. To use this, pass {matchBare: true} as the value of
- * options. The default value for matchBare is false.
- *
- * The return value should be saved if you wish to remove the handler
- * with deleteHandler().
- *
- * Parameters:
- * (Function) handler - The user callback.
- * (String) ns - The namespace to match.
- * (String) name - The stanza name to match.
- * (String) type - The stanza type attribute to match.
- * (String) id - The stanza id attribute to match.
- * (String) from - The stanza from attribute to match.
- * (Object) options - The handler options.
- *
- * Returns:
- * A reference to the handler that can be used to remove it.
- */
- addHandler: function (handler, ns, name, type, id, from, options)
- {
- var hand = new Strophe.Handler(handler, ns, name, type, id, from, options);
- this.addHandlers.push(hand);
- return hand;
- },
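A sketch of a chat-message handler matched on stanza name and type, with bare-JID matching of the from attribute enabled:

    var msgHandler = conn.addHandler(function (msg) {
        var bodies = msg.getElementsByTagName("body");
        // process the message; return true to keep receiving matches
        return true;
    }, null, "message", "chat", null, null, {matchBare: true});
    // later: conn.deleteHandler(msgHandler);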
-
- /** Function: deleteHandler
- * Delete a stanza handler for a connection.
- *
- * This function removes a stanza handler from the connection. The
- * handRef parameter is *not* the function passed to addHandler(),
- * but is the reference returned from addHandler().
- *
- * Parameters:
- * (Strophe.Handler) handRef - The handler reference.
- */
- deleteHandler: function (handRef)
- {
- // this must be done in the Idle loop so that we don't change
- // the handlers during iteration
- this.removeHandlers.push(handRef);
- },
-
- /** Function: disconnect
- * Start the graceful disconnection process.
- *
- * This function starts the disconnection process. This process starts
- * by sending unavailable presence and sending BOSH body of type
- * terminate. A timeout handler makes sure that disconnection happens
- * even if the BOSH server does not respond.
- *
- * The user supplied connection callback will be notified of the
- * progress as this process happens.
- *
- * Parameters:
- * (String) reason - The reason the disconnect is occurring.
- */
- disconnect: function (reason)
- {
- this._changeConnectStatus(Strophe.Status.DISCONNECTING, reason);
-
- Strophe.info("Disconnect was called because: " + reason);
- if (this.connected) {
- // setup timeout handler
- this._disconnectTimeout = this._addSysTimedHandler(
- 3000, this._onDisconnectTimeout.bind(this));
- this._sendTerminate();
- }
- },
-
- /** PrivateFunction: _changeConnectStatus
- * _Private_ helper function that makes sure plugins and the user's
- * callback are notified of connection status changes.
- *
- * Parameters:
- * (Integer) status - the new connection status, one of the values
- * in Strophe.Status
- * (String) condition - the error condition or null
- */
- _changeConnectStatus: function (status, condition)
- {
- // notify all plugins listening for status changes
- for (var k in Strophe._connectionPlugins) {
- if (Strophe._connectionPlugins.hasOwnProperty(k)) {
- var plugin = this[k];
- if (plugin.statusChanged) {
- try {
- plugin.statusChanged(status, condition);
- } catch (err) {
- Strophe.error("" + k + " plugin caused an exception " +
- "changing status: " + err);
- }
- }
- }
- }
-
- // notify the user's callback
- if (this.connect_callback) {
- try {
- this.connect_callback(status, condition);
- } catch (e) {
- Strophe.error("User connection callback caused an " +
- "exception: " + e);
- }
- }
- },
-
- /** PrivateFunction: _buildBody
- * _Private_ helper function to generate the <body/> wrapper for BOSH.
- *
- * Returns:
- * A Strophe.Builder with a <body/> element.
- */
- _buildBody: function ()
- {
- var bodyWrap = $build('body', {
- rid: this.rid++,
- xmlns: Strophe.NS.HTTPBIND
- });
-
- if (this.sid !== null) {
- bodyWrap.attrs({sid: this.sid});
- }
-
- return bodyWrap;
- },
-
- /** PrivateFunction: _removeRequest
- * _Private_ function to remove a request from the queue.
- *
- * Parameters:
- * (Strophe.Request) req - The request to remove.
- */
- _removeRequest: function (req)
- {
- Strophe.debug("removing request");
-
- var i;
- for (i = this._requests.length - 1; i >= 0; i--) {
- if (req == this._requests[i]) {
- this._requests.splice(i, 1);
- }
- }
-
- // IE6 fails on setting to null, so set to empty function
- req.xhr.onreadystatechange = function () {};
-
- this._throttledRequestHandler();
- },
-
- /** PrivateFunction: _restartRequest
- * _Private_ function to restart a request that is presumed dead.
- *
- * Parameters:
- * (Integer) i - The index of the request in the queue.
- */
- _restartRequest: function (i)
- {
- var req = this._requests[i];
- if (req.dead === null) {
- req.dead = new Date();
- }
-
- this._processRequest(i);
- },
-
- /** PrivateFunction: _processRequest
- * _Private_ function to process a request in the queue.
- *
- * This function takes requests off the queue and sends them and
- * restarts dead requests.
- *
- * Parameters:
- * (Integer) i - The index of the request in the queue.
- */
- _processRequest: function (i)
- {
- var req = this._requests[i];
- var reqStatus = -1;
-
- try {
- if (req.xhr.readyState == 4) {
- reqStatus = req.xhr.status;
- }
- } catch (e) {
- Strophe.error("caught an error in _requests[" + i +
- "], reqStatus: " + reqStatus);
- }
-
- if (typeof(reqStatus) == "undefined") {
- reqStatus = -1;
- }
-
- // make sure we limit the number of retries
- if (req.sends > 5) {
- this._onDisconnectTimeout();
- return;
- }
-
- var time_elapsed = req.age();
- var primaryTimeout = (!isNaN(time_elapsed) &&
- time_elapsed > Math.floor(Strophe.TIMEOUT * this.wait));
- var secondaryTimeout = (req.dead !== null &&
- req.timeDead() > Math.floor(Strophe.SECONDARY_TIMEOUT * this.wait));
- var requestCompletedWithServerError = (req.xhr.readyState == 4 &&
- (reqStatus < 1 ||
- reqStatus >= 500));
- if (primaryTimeout || secondaryTimeout ||
- requestCompletedWithServerError) {
- if (secondaryTimeout) {
- Strophe.error("Request " +
- this._requests[i].id +
- " timed out (secondary), restarting");
- }
- req.abort = true;
- req.xhr.abort();
- // setting to null fails on IE6, so set to empty function
- req.xhr.onreadystatechange = function () {};
- this._requests[i] = new Strophe.Request(req.xmlData,
- req.origFunc,
- req.rid,
- req.sends);
- req = this._requests[i];
- }
-
- if (req.xhr.readyState === 0) {
- Strophe.debug("request id " + req.id +
- "." + req.sends + " posting");
-
- req.date = new Date();
- try {
- req.xhr.open("POST", this.service, true);
- } catch (e2) {
- Strophe.error("XHR open failed.");
- if (!this.connected) {
- this._changeConnectStatus(Strophe.Status.CONNFAIL,
- "bad-service");
- }
- this.disconnect();
- return;
- }
-
- // Fires the XHR request -- may be invoked immediately
- // or on a gradually expanding retry window for reconnects
- var sendFunc = function () {
- req.xhr.send(req.data);
- };
-
- // Implement progressive backoff for reconnects --
- // First retry (send == 1) should also be instantaneous
- if (req.sends > 1) {
- // Using a cube of the retry number creates a nicely
- // expanding retry window
- var backoff = Math.pow(req.sends, 3) * 1000;
- setTimeout(sendFunc, backoff);
- } else {
- sendFunc();
- }
-
- req.sends++;
-
- this.xmlOutput(req.xmlData);
- this.rawOutput(req.data);
- } else {
- Strophe.debug("_processRequest: " +
- (i === 0 ? "first" : "second") +
- " request has readyState of " +
- req.xhr.readyState);
- }
- },
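To make the cubic backoff above concrete: the initial send and the first retry go out immediately, after which each retry waits sends^3 seconds (8 s, 27 s, 64 s, 125 s) until the sends > 5 guard at the top of this function gives up via _onDisconnectTimeout().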
-
- /** PrivateFunction: _throttledRequestHandler
- * _Private_ function to throttle requests to the connection window.
- *
- * This function makes sure we don't send requests so fast that the
- * request ids overflow the connection window in the case that one
- * request died.
- */
- _throttledRequestHandler: function ()
- {
- if (!this._requests) {
- Strophe.debug("_throttledRequestHandler called with " +
- "undefined requests");
- } else {
- Strophe.debug("_throttledRequestHandler called with " +
- this._requests.length + " requests");
- }
-
- if (!this._requests || this._requests.length === 0) {
- return;
- }
-
- if (this._requests.length > 0) {
- this._processRequest(0);
- }
-
- if (this._requests.length > 1 &&
- Math.abs(this._requests[0].rid -
- this._requests[1].rid) < this.window) {
- this._processRequest(1);
- }
- },
-
- /** PrivateFunction: _onRequestStateChange
- * _Private_ handler for Strophe.Request state changes.
- *
- * This function is called when the XMLHttpRequest readyState changes.
- * It contains a lot of error handling logic for the many ways that
- * requests can fail, and calls the request callback when requests
- * succeed.
- *
- * Parameters:
- * (Function) func - The handler for the request.
- * (Strophe.Request) req - The request that is changing readyState.
- */
- _onRequestStateChange: function (func, req)
- {
- Strophe.debug("request id " + req.id +
- "." + req.sends + " state changed to " +
- req.xhr.readyState);
-
- if (req.abort) {
- req.abort = false;
- return;
- }
-
- // request complete
- var reqStatus;
- if (req.xhr.readyState == 4) {
- reqStatus = 0;
- try {
- reqStatus = req.xhr.status;
- } catch (e) {
- // ignore errors from undefined status attribute. works
- // around a browser bug
- }
-
- if (typeof(reqStatus) == "undefined") {
- reqStatus = 0;
- }
-
- if (this.disconnecting) {
- if (reqStatus >= 400) {
- this._hitError(reqStatus);
- return;
- }
- }
-
- var reqIs0 = (this._requests[0] == req);
- var reqIs1 = (this._requests[1] == req);
-
- if ((reqStatus > 0 && reqStatus < 500) || req.sends > 5) {
- // remove from internal queue
- this._removeRequest(req);
- Strophe.debug("request id " +
- req.id +
- " should now be removed");
- }
-
- // request succeeded
- if (reqStatus == 200) {
- // if request 1 finished, or request 0 finished and request
- // 1 is over Strophe.SECONDARY_TIMEOUT seconds old, we need to
- // restart the other - both will be in the first spot, as the
- // completed request has been removed from the queue already
- if (reqIs1 ||
- (reqIs0 && this._requests.length > 0 &&
- this._requests[0].age() > Math.floor(Strophe.SECONDARY_TIMEOUT * this.wait))) {
- this._restartRequest(0);
- }
- // call handler
- Strophe.debug("request id " +
- req.id + "." +
- req.sends + " got 200");
- func(req);
- this.errors = 0;
- } else {
- Strophe.error("request id " +
- req.id + "." +
- req.sends + " error " + reqStatus +
- " happened");
- if (reqStatus === 0 ||
- (reqStatus >= 400 && reqStatus < 600) ||
- reqStatus >= 12000) {
- this._hitError(reqStatus);
- if (reqStatus >= 400 && reqStatus < 500) {
- this._changeConnectStatus(Strophe.Status.DISCONNECTING,
- null);
- this._doDisconnect();
- }
- }
- }
-
- if (!((reqStatus > 0 && reqStatus < 500) ||
- req.sends > 5)) {
- this._throttledRequestHandler();
- }
- }
- },
-
- /** PrivateFunction: _hitError
- * _Private_ function to handle the error count.
- *
- * Requests are resent automatically until their error count reaches
- * 5. Each time an error is encountered, this function is called to
- * increment the count and disconnect if the count is too high.
- *
- * Parameters:
- * (Integer) reqStatus - The request status.
- */
- _hitError: function (reqStatus)
- {
- this.errors++;
- Strophe.warn("request errored, status: " + reqStatus +
- ", number of errors: " + this.errors);
- if (this.errors > 4) {
- this._onDisconnectTimeout();
- }
- },
-
- /** PrivateFunction: _doDisconnect
- * _Private_ function to disconnect.
- *
- * This is the last piece of the disconnection logic. This resets the
- * connection and alerts the user's connection callback.
- */
- _doDisconnect: function ()
- {
- Strophe.info("_doDisconnect was called");
- this.authenticated = false;
- this.disconnecting = false;
- this.sid = null;
- this.streamId = null;
- this.rid = Math.floor(Math.random() * 4294967295);
-
- // tell the parent we disconnected
- if (this.connected) {
- this._changeConnectStatus(Strophe.Status.DISCONNECTED, null);
- this.connected = false;
- }
-
- // delete handlers
- this.handlers = [];
- this.timedHandlers = [];
- this.removeTimeds = [];
- this.removeHandlers = [];
- this.addTimeds = [];
- this.addHandlers = [];
- },
-
- /** PrivateFunction: _dataRecv
- * _Private_ handler to process incoming data from the connection.
- *
- * Except for _connect_cb handling the initial connection request,
- * this function handles the incoming data for all requests. This
- * function also fires stanza handlers that match each incoming
- * stanza.
- *
- * Parameters:
- * (Strophe.Request) req - The request that has data ready.
- */
- _dataRecv: function (req)
- {
- try {
- var elem = req.getResponse();
- } catch (e) {
- if (e != "parsererror") { throw e; }
- this.disconnect("strophe-parsererror");
- }
- if (elem === null) { return; }
-
- this.xmlInput(elem);
- this.rawInput(Strophe.serialize(elem));
-
- // remove handlers scheduled for deletion
- var i, hand;
- while (this.removeHandlers.length > 0) {
- hand = this.removeHandlers.pop();
- i = this.handlers.indexOf(hand);
- if (i >= 0) {
- this.handlers.splice(i, 1);
- }
- }
-
- // add handlers scheduled for addition
- while (this.addHandlers.length > 0) {
- this.handlers.push(this.addHandlers.pop());
- }
-
- // handle graceful disconnect
- if (this.disconnecting && this._requests.length === 0) {
- this.deleteTimedHandler(this._disconnectTimeout);
- this._disconnectTimeout = null;
- this._doDisconnect();
- return;
- }
-
- var typ = elem.getAttribute("type");
- var cond, conflict;
- if (typ !== null && typ == "terminate") {
- // Don't process stanzas that come in after disconnect
- if (this.disconnecting) {
- return;
- }
-
- // an error occurred
- cond = elem.getAttribute("condition");
- conflict = elem.getElementsByTagName("conflict");
- if (cond !== null) {
- if (cond == "remote-stream-error" && conflict.length > 0) {
- cond = "conflict";
- }
- this._changeConnectStatus(Strophe.Status.CONNFAIL, cond);
- } else {
- this._changeConnectStatus(Strophe.Status.CONNFAIL, "unknown");
- }
- this.disconnect();
- return;
- }
-
- // send each incoming stanza through the handler chain
- var that = this;
- Strophe.forEachChild(elem, null, function (child) {
- var i, newList;
- // process handlers
- newList = that.handlers;
- that.handlers = [];
- for (i = 0; i < newList.length; i++) {
- var hand = newList[i];
- if (hand.isMatch(child) &&
- (that.authenticated || !hand.user)) {
- if (hand.run(child)) {
- that.handlers.push(hand);
- }
- } else {
- that.handlers.push(hand);
- }
- }
- });
- },
-
- /** PrivateFunction: _sendTerminate
- * _Private_ function to send initial disconnect sequence.
- *
- * This is the first step in a graceful disconnect. It sends
- * the BOSH server a terminate body and includes an unavailable
- * presence if authentication has completed.
- */
- _sendTerminate: function ()
- {
- Strophe.info("_sendTerminate was called");
- var body = this._buildBody().attrs({type: "terminate"});
-
- if (this.authenticated) {
- body.c('presence', {
- xmlns: Strophe.NS.CLIENT,
- type: 'unavailable'
- });
- }
-
- this.disconnecting = true;
-
- var req = new Strophe.Request(body.tree(),
- this._onRequestStateChange.bind(
- this, this._dataRecv.bind(this)),
- body.tree().getAttribute("rid"));
-
- this._requests.push(req);
- this._throttledRequestHandler();
- },
-
- /** PrivateFunction: _connect_cb
- * _Private_ handler for initial connection request.
- *
- * This handler is used to process the initial connection request
- * response from the BOSH server. It is used to set up authentication
- * handlers and start the authentication process.
- *
- * SASL authentication will be attempted if available, otherwise
- * the code will fall back to legacy authentication.
- *
- * Parameters:
- * (Strophe.Request) req - The current request.
- */
- _connect_cb: function (req)
- {
- Strophe.info("_connect_cb was called");
-
- this.connected = true;
- var bodyWrap = req.getResponse();
- if (!bodyWrap) { return; }
-
- this.xmlInput(bodyWrap);
- this.rawInput(Strophe.serialize(bodyWrap));
-
- var typ = bodyWrap.getAttribute("type");
- var cond, conflict;
- if (typ !== null && typ == "terminate") {
- // an error occurred
- cond = bodyWrap.getAttribute("condition");
- conflict = bodyWrap.getElementsByTagName("conflict");
- if (cond !== null) {
- if (cond == "remote-stream-error" && conflict.length > 0) {
- cond = "conflict";
- }
- this._changeConnectStatus(Strophe.Status.CONNFAIL, cond);
- } else {
- this._changeConnectStatus(Strophe.Status.CONNFAIL, "unknown");
- }
- return;
- }
-
- // check to make sure we don't overwrite these if _connect_cb is
- // called multiple times in the case of missing stream:features
- if (!this.sid) {
- this.sid = bodyWrap.getAttribute("sid");
- }
- if (!this.stream_id) {
- this.stream_id = bodyWrap.getAttribute("authid");
- }
- var wind = bodyWrap.getAttribute('requests');
- if (wind) { this.window = parseInt(wind, 10); }
- var hold = bodyWrap.getAttribute('hold');
- if (hold) { this.hold = parseInt(hold, 10); }
- var wait = bodyWrap.getAttribute('wait');
- if (wait) { this.wait = parseInt(wait, 10); }
-
-
- var do_sasl_plain = false;
- var do_sasl_digest_md5 = false;
- var do_sasl_anonymous = false;
-
- var mechanisms = bodyWrap.getElementsByTagName("mechanism");
- var i, mech, auth_str, hashed_auth_str;
- if (mechanisms.length > 0) {
- for (i = 0; i < mechanisms.length; i++) {
- mech = Strophe.getText(mechanisms[i]);
- if (mech == 'DIGEST-MD5') {
- do_sasl_digest_md5 = true;
- } else if (mech == 'PLAIN') {
- do_sasl_plain = true;
- } else if (mech == 'ANONYMOUS') {
- do_sasl_anonymous = true;
- }
- }
- } else {
- // we didn't get stream:features yet, so we need to wait for it
- // by sending a blank poll request
- var body = this._buildBody();
- this._requests.push(
- new Strophe.Request(body.tree(),
- this._onRequestStateChange.bind(
- this, this._connect_cb.bind(this)),
- body.tree().getAttribute("rid")));
- this._throttledRequestHandler();
- return;
- }
-
- if (Strophe.getNodeFromJid(this.jid) === null &&
- do_sasl_anonymous) {
- this._changeConnectStatus(Strophe.Status.AUTHENTICATING, null);
- this._sasl_success_handler = this._addSysHandler(
- this._sasl_success_cb.bind(this), null,
- "success", null, null);
- this._sasl_failure_handler = this._addSysHandler(
- this._sasl_failure_cb.bind(this), null,
- "failure", null, null);
-
- this.send($build("auth", {
- xmlns: Strophe.NS.SASL,
- mechanism: "ANONYMOUS"
- }).tree());
- } else if (Strophe.getNodeFromJid(this.jid) === null) {
- // we don't have a node, which is required for non-anonymous
- // client connections
- this._changeConnectStatus(Strophe.Status.CONNFAIL,
- 'x-strophe-bad-non-anon-jid');
- this.disconnect();
- } else if (do_sasl_digest_md5) {
- this._changeConnectStatus(Strophe.Status.AUTHENTICATING, null);
- this._sasl_challenge_handler = this._addSysHandler(
- this._sasl_challenge1_cb.bind(this), null,
- "challenge", null, null);
- this._sasl_failure_handler = this._addSysHandler(
- this._sasl_failure_cb.bind(this), null,
- "failure", null, null);
-
- this.send($build("auth", {
- xmlns: Strophe.NS.SASL,
- mechanism: "DIGEST-MD5"
- }).tree());
- } else if (do_sasl_plain) {
- // Build the plain auth string (bare JID, NUL,
- // username, NUL, password) and base64 encode it.
- auth_str = Strophe.getBareJidFromJid(this.jid);
- auth_str = auth_str + "\u0000";
- auth_str = auth_str + Strophe.getNodeFromJid(this.jid);
- auth_str = auth_str + "\u0000";
- auth_str = auth_str + this.pass;
-
- this._changeConnectStatus(Strophe.Status.AUTHENTICATING, null);
- this._sasl_success_handler = this._addSysHandler(
- this._sasl_success_cb.bind(this), null,
- "success", null, null);
- this._sasl_failure_handler = this._addSysHandler(
- this._sasl_failure_cb.bind(this), null,
- "failure", null, null);
-
- hashed_auth_str = Base64.encode(auth_str);
- this.send($build("auth", {
- xmlns: Strophe.NS.SASL,
- mechanism: "PLAIN"
- }).t(hashed_auth_str).tree());
- } else {
- this._changeConnectStatus(Strophe.Status.AUTHENTICATING, null);
- this._addSysHandler(this._auth1_cb.bind(this), null, null,
- null, "_auth_1");
-
- this.send($iq({
- type: "get",
- to: this.domain,
- id: "_auth_1"
- }).c("query", {
- xmlns: Strophe.NS.AUTH
- }).c("username", {}).t(Strophe.getNodeFromJid(this.jid)).tree());
- }
- },
-
- /** PrivateFunction: _sasl_challenge1_cb
- * _Private_ handler for DIGEST-MD5 SASL authentication.
- *
- * Parameters:
- * (XMLElement) elem - The challenge stanza.
- *
- * Returns:
- * false to remove the handler.
- */
- _sasl_challenge1_cb: function (elem)
- {
- var attribMatch = /([a-z]+)=("[^"]+"|[^,"]+)(?:,|$)/;
-
- var challenge = Base64.decode(Strophe.getText(elem));
- var cnonce = MD5.hexdigest(Math.random() * 1234567890);
- var realm = "";
- var host = null;
- var nonce = "";
- var qop = "";
- var matches;
-
- // remove unneeded handlers
- this.deleteHandler(this._sasl_failure_handler);
-
- while (challenge.match(attribMatch)) {
- matches = challenge.match(attribMatch);
- challenge = challenge.replace(matches[0], "");
- matches[2] = matches[2].replace(/^"(.+)"$/, "$1");
- switch (matches[1]) {
- case "realm":
- realm = matches[2];
- break;
- case "nonce":
- nonce = matches[2];
- break;
- case "qop":
- qop = matches[2];
- break;
- case "host":
- host = matches[2];
- break;
- }
- }
-
- var digest_uri = "xmpp/" + this.domain;
- if (host !== null) {
- digest_uri = digest_uri + "/" + host;
- }
-
- var A1 = MD5.hash(Strophe.getNodeFromJid(this.jid) +
- ":" + realm + ":" + this.pass) +
- ":" + nonce + ":" + cnonce;
- var A2 = 'AUTHENTICATE:' + digest_uri;
-
- var responseText = "";
- responseText += 'username=' +
- this._quote(Strophe.getNodeFromJid(this.jid)) + ',';
- responseText += 'realm=' + this._quote(realm) + ',';
- responseText += 'nonce=' + this._quote(nonce) + ',';
- responseText += 'cnonce=' + this._quote(cnonce) + ',';
- responseText += 'nc="00000001",';
- responseText += 'qop="auth",';
- responseText += 'digest-uri=' + this._quote(digest_uri) + ',';
- responseText += 'response=' + this._quote(
- MD5.hexdigest(MD5.hexdigest(A1) + ":" +
- nonce + ":00000001:" +
- cnonce + ":auth:" +
- MD5.hexdigest(A2))) + ',';
- responseText += 'charset="utf-8"';
-
- this._sasl_challenge_handler = this._addSysHandler(
- this._sasl_challenge2_cb.bind(this), null,
- "challenge", null, null);
- this._sasl_success_handler = this._addSysHandler(
- this._sasl_success_cb.bind(this), null,
- "success", null, null);
- this._sasl_failure_handler = this._addSysHandler(
- this._sasl_failure_cb.bind(this), null,
- "failure", null, null);
-
- this.send($build('response', {
- xmlns: Strophe.NS.SASL
- }).t(Base64.encode(responseText)).tree());
-
- return false;
- },
-
- /** PrivateFunction: _quote
- * _Private_ utility function to backslash escape and quote strings.
- *
- * Parameters:
- * (String) str - The string to be quoted.
- *
- * Returns:
- * quoted string
- */
- _quote: function (str)
- {
- return '"' + str.replace(/\\/g, "\\\\").replace(/"/g, '\\"') + '"';
- //" end string workaround for emacs
- },
-
-
- /** PrivateFunction: _sasl_challenge2_cb
- * _Private_ handler for second step of DIGEST-MD5 SASL authentication.
- *
- * Parameters:
- * (XMLElement) elem - The challenge stanza.
- *
- * Returns:
- * false to remove the handler.
- */
- _sasl_challenge2_cb: function (elem)
- {
- // remove unneeded handlers
- this.deleteHandler(this._sasl_success_handler);
- this.deleteHandler(this._sasl_failure_handler);
-
- this._sasl_success_handler = this._addSysHandler(
- this._sasl_success_cb.bind(this), null,
- "success", null, null);
- this._sasl_failure_handler = this._addSysHandler(
- this._sasl_failure_cb.bind(this), null,
- "failure", null, null);
- this.send($build('response', {xmlns: Strophe.NS.SASL}).tree());
- return false;
- },
-
- /** PrivateFunction: _auth1_cb
- * _Private_ handler for legacy authentication.
- *
- * This handler is called in response to the initial <iq type='get'/>
- * for legacy authentication. It builds an authentication <iq/> and
- * sends it, creating a handler (calling back to _auth2_cb()) to
- * handle the result.
- *
- * Parameters:
- * (XMLElement) elem - The stanza that triggered the callback.
- *
- * Returns:
- * false to remove the handler.
- */
- _auth1_cb: function (elem)
- {
- // build plaintext auth iq
- var iq = $iq({type: "set", id: "_auth_2"})
- .c('query', {xmlns: Strophe.NS.AUTH})
- .c('username', {}).t(Strophe.getNodeFromJid(this.jid))
- .up()
- .c('password').t(this.pass);
-
- if (!Strophe.getResourceFromJid(this.jid)) {
- // since the user has not supplied a resource, we pick
- // a default one here. unlike other auth methods, the server
- // cannot do this for us.
- this.jid = Strophe.getBareJidFromJid(this.jid) + '/strophe';
- }
- iq.up().c('resource', {}).t(Strophe.getResourceFromJid(this.jid));
-
- this._addSysHandler(this._auth2_cb.bind(this), null,
- null, null, "_auth_2");
-
- this.send(iq.tree());
-
- return false;
- },
-
- /** PrivateFunction: _sasl_success_cb
- * _Private_ handler for successful SASL authentication.
- *
- * Parameters:
- * (XMLElement) elem - The matching stanza.
- *
- * Returns:
- * false to remove the handler.
- */
- _sasl_success_cb: function (elem)
- {
- Strophe.info("SASL authentication succeeded.");
-
- // remove old handlers
- this.deleteHandler(this._sasl_failure_handler);
- this._sasl_failure_handler = null;
- if (this._sasl_challenge_handler) {
- this.deleteHandler(this._sasl_challenge_handler);
- this._sasl_challenge_handler = null;
- }
-
- this._addSysHandler(this._sasl_auth1_cb.bind(this), null,
- "stream:features", null, null);
-
- // we must send an xmpp:restart now
- this._sendRestart();
-
- return false;
- },
-
- /** PrivateFunction: _sasl_auth1_cb
- * _Private_ handler to start stream binding.
- *
- * Parameters:
- * (XMLElement) elem - The matching stanza.
- *
- * Returns:
- * false to remove the handler.
- */
- _sasl_auth1_cb: function (elem)
- {
- // save stream:features for future usage
- this.features = elem;
-
- var i, child;
-
- for (i = 0; i < elem._childNodes.length; i++) {
- child = elem._childNodes[i];
- if (child.nodeName.toLowerCase() == 'bind') {
- this.do_bind = true;
- }
-
- if (child.nodeName.toLowerCase() == 'session') {
- this.do_session = true;
- }
- }
-
- if (!this.do_bind) {
- this._changeConnectStatus(Strophe.Status.AUTHFAIL, null);
- return false;
- } else {
- this._addSysHandler(this._sasl_bind_cb.bind(this), null, null,
- null, "_bind_auth_2");
-
- var resource = Strophe.getResourceFromJid(this.jid);
- if (resource) {
- this.send($iq({type: "set", id: "_bind_auth_2"})
- .c('bind', {xmlns: Strophe.NS.BIND})
- .c('resource', {}).t(resource).tree());
- } else {
- this.send($iq({type: "set", id: "_bind_auth_2"})
- .c('bind', {xmlns: Strophe.NS.BIND})
- .tree());
- }
- }
-
- return false;
- },
-
- /** PrivateFunction: _sasl_bind_cb
- * _Private_ handler for binding result and session start.
- *
- * Parameters:
- * (XMLElement) elem - The matching stanza.
- *
- * Returns:
- * false to remove the handler.
- */
- _sasl_bind_cb: function (elem)
- {
- if (elem.getAttribute("type") == "error") {
- Strophe.info("SASL binding failed.");
- this._changeConnectStatus(Strophe.Status.AUTHFAIL, null);
- return false;
- }
-
- // TODO - need to grab errors
- var bind = elem.getElementsByTagName("bind");
- var jidNode;
- if (bind.length > 0) {
- // Grab jid
- jidNode = bind[0].getElementsByTagName("jid");
- if (jidNode.length > 0) {
- this.jid = Strophe.getText(jidNode[0]);
-
- if (this.do_session) {
- this._addSysHandler(this._sasl_session_cb.bind(this),
- null, null, null, "_session_auth_2");
-
- this.send($iq({type: "set", id: "_session_auth_2"})
- .c('session', {xmlns: Strophe.NS.SESSION})
- .tree());
- } else {
- this.authenticated = true;
- this._changeConnectStatus(Strophe.Status.CONNECTED, null);
- }
- }
- } else {
- Strophe.info("SASL binding failed.");
- this._changeConnectStatus(Strophe.Status.AUTHFAIL, null);
- return false;
- }
- },
-
- /** PrivateFunction: _sasl_session_cb
- * _Private_ handler to finish successful SASL connection.
- *
- * This sets Connection.authenticated to true on success, which
- * starts the processing of user handlers.
- *
- * Parameters:
- * (XMLElement) elem - The matching stanza.
- *
- * Returns:
- * false to remove the handler.
- */
- _sasl_session_cb: function (elem)
- {
- if (elem.getAttribute("type") == "result") {
- this.authenticated = true;
- this._changeConnectStatus(Strophe.Status.CONNECTED, null);
- } else if (elem.getAttribute("type") == "error") {
- Strophe.info("Session creation failed.");
- this._changeConnectStatus(Strophe.Status.AUTHFAIL, null);
- return false;
- }
-
- return false;
- },
-
- /** PrivateFunction: _sasl_failure_cb
- * _Private_ handler for SASL authentication failure.
- *
- * Parameters:
- * (XMLElement) elem - The matching stanza.
- *
- * Returns:
- * false to remove the handler.
- */
- _sasl_failure_cb: function (elem)
- {
- // delete unneeded handlers
- if (this._sasl_success_handler) {
- this.deleteHandler(this._sasl_success_handler);
- this._sasl_success_handler = null;
- }
- if (this._sasl_challenge_handler) {
- this.deleteHandler(this._sasl_challenge_handler);
- this._sasl_challenge_handler = null;
- }
-
- this._changeConnectStatus(Strophe.Status.AUTHFAIL, null);
- return false;
- },
-
- /** PrivateFunction: _auth2_cb
- * _Private_ handler to finish legacy authentication.
- *
- * This handler is called when the result from the jabber:iq:auth
- * <iq/> stanza is returned.
- *
- * Parameters:
- * (XMLElement) elem - The stanza that triggered the callback.
- *
- * Returns:
- * false to remove the handler.
- */
- _auth2_cb: function (elem)
- {
- if (elem.getAttribute("type") == "result") {
- this.authenticated = true;
- this._changeConnectStatus(Strophe.Status.CONNECTED, null);
- } else if (elem.getAttribute("type") == "error") {
- this._changeConnectStatus(Strophe.Status.AUTHFAIL, null);
- this.disconnect();
- }
-
- return false;
- },
-
- /** PrivateFunction: _addSysTimedHandler
- * _Private_ function to add a system level timed handler.
- *
- * This function is used to add a Strophe.TimedHandler for the
- * library code. System timed handlers are allowed to run before
- * authentication is complete.
- *
- * Parameters:
- * (Integer) period - The period of the handler.
- * (Function) handler - The callback function.
- */
- _addSysTimedHandler: function (period, handler)
- {
- var thand = new Strophe.TimedHandler(period, handler);
- thand.user = false;
- this.addTimeds.push(thand);
- return thand;
- },
-
- /** PrivateFunction: _addSysHandler
- * _Private_ function to add a system level stanza handler.
- *
- * This function is used to add a Strophe.Handler for the
- * library code. System stanza handlers are allowed to run before
- * authentication is complete.
- *
- * Parameters:
- * (Function) handler - The callback function.
- * (String) ns - The namespace to match.
- * (String) name - The stanza name to match.
- * (String) type - The stanza type attribute to match.
- * (String) id - The stanza id attribute to match.
- */
- _addSysHandler: function (handler, ns, name, type, id)
- {
- var hand = new Strophe.Handler(handler, ns, name, type, id);
- hand.user = false;
- this.addHandlers.push(hand);
- return hand;
- },
-
- /** PrivateFunction: _onDisconnectTimeout
- * _Private_ timeout handler for handling non-graceful disconnection.
- *
- * If the graceful disconnect process does not complete within the
- * time allotted, this handler finishes the disconnect anyway.
- *
- * Returns:
- * false to remove the handler.
- */
- _onDisconnectTimeout: function ()
- {
- Strophe.info("_onDisconnectTimeout was called");
-
- // cancel all remaining requests and clear the queue
- var req;
- while (this._requests.length > 0) {
- req = this._requests.pop();
- req.abort = true;
- req.xhr.abort();
- // jslint complains, but this is fine. setting to empty func
- // is necessary for IE6
- req.xhr.onreadystatechange = function () {};
- }
-
- // actually disconnect
- this._doDisconnect();
-
- return false;
- },
-
- /** PrivateFunction: _onIdle
- * _Private_ handler to process events during idle cycle.
- *
- * This handler is called every 100ms to fire timed handlers that
- * are ready and keep poll requests going.
- */
- _onIdle: function ()
- {
- var i, thand, since, newList;
-
- // add timed handlers scheduled for addition
- // NOTE: we add before remove in the case a timed handler is
- // added and then deleted before the next _onIdle() call.
- while (this.addTimeds.length > 0) {
- this.timedHandlers.push(this.addTimeds.pop());
- }
-
- // remove timed handlers that have been scheduled for deletion
- while (this.removeTimeds.length > 0) {
- thand = this.removeTimeds.pop();
- i = this.timedHandlers.indexOf(thand);
- if (i >= 0) {
- this.timedHandlers.splice(i, 1);
- }
- }
-
- // call ready timed handlers
- var now = new Date().getTime();
- newList = [];
- for (i = 0; i < this.timedHandlers.length; i++) {
- thand = this.timedHandlers[i];
- if (this.authenticated || !thand.user) {
- since = thand.lastCalled + thand.period;
- if (since - now <= 0) {
- if (thand.run()) {
- newList.push(thand);
- }
- } else {
- newList.push(thand);
- }
- }
- }
- this.timedHandlers = newList;
-
- var body, time_elapsed;
-
- // if no requests are in progress, poll
- if (this.authenticated && this._requests.length === 0 &&
- this._data.length === 0 && !this.disconnecting) {
- Strophe.info("no requests during idle cycle, sending " +
- "blank request");
- this._data.push(null);
- }
-
- if (this._requests.length < 2 && this._data.length > 0 &&
- !this.paused) {
- body = this._buildBody();
- for (i = 0; i < this._data.length; i++) {
- if (this._data[i] !== null) {
- if (this._data[i] === "restart") {
- body.attrs({
- to: this.domain,
- "xml:lang": "en",
- "xmpp:restart": "true",
- "xmlns:xmpp": Strophe.NS.BOSH
- });
- } else {
- body.cnode(this._data[i]).up();
- }
- }
- }
- delete this._data;
- this._data = [];
- this._requests.push(
- new Strophe.Request(body.tree(),
- this._onRequestStateChange.bind(
- this, this._dataRecv.bind(this)),
- body.tree().getAttribute("rid")));
- this._processRequest(this._requests.length - 1);
- }
-
- if (this._requests.length > 0) {
- time_elapsed = this._requests[0].age();
- if (this._requests[0].dead !== null) {
- if (this._requests[0].timeDead() >
- Math.floor(Strophe.SECONDARY_TIMEOUT * this.wait)) {
- this._throttledRequestHandler();
- }
- }
-
- if (time_elapsed > Math.floor(Strophe.TIMEOUT * this.wait)) {
- Strophe.warn("Request " +
- this._requests[0].id +
- " timed out, over " + Math.floor(Strophe.TIMEOUT * this.wait) +
- " seconds since last activity");
- this._throttledRequestHandler();
- }
- }
-
- // reactivate the timer
- clearTimeout(this._idleTimeout);
- this._idleTimeout = setTimeout(this._onIdle.bind(this), 100);
- }
-};
-
-if (callback) {
- callback(Strophe, $build, $msg, $iq, $pres);
-}
-
-})(function () {
- window.Strophe = arguments[0];
- window.$build = arguments[1];
- window.$msg = arguments[2];
- window.$iq = arguments[3];
- window.$pres = arguments[4];
-});
diff --git a/contrib/jitsimeetbridge/unjingle/unjingle.js b/contrib/jitsimeetbridge/unjingle/unjingle.js
deleted file mode 100644
index 3dfe7599..00000000
--- a/contrib/jitsimeetbridge/unjingle/unjingle.js
+++ /dev/null
@@ -1,48 +0,0 @@
-var strophe = require("./strophe/strophe.js").Strophe;
-
-var Strophe = strophe.Strophe;
-var $iq = strophe.$iq;
-var $msg = strophe.$msg;
-var $build = strophe.$build;
-var $pres = strophe.$pres;
-
-var jsdom = require("jsdom");
-var window = jsdom.jsdom().parentWindow;
-var $ = require('jquery')(window);
-
-var stropheJingle = require("./strophe.jingle.sdp.js");
-
-
-var input = '';
-
-process.stdin.on('readable', function() {
- var chunk = process.stdin.read();
- if (chunk !== null) {
- input += chunk;
- }
-});
-
-process.stdin.on('end', function() {
- if (process.argv[2] == '--jingle') {
- var elem = $(input);
- // app does:
- // sess.setRemoteDescription($(iq).find('>jingle'), 'offer');
- //console.log(elem.find('>content'));
- var sdp = new stropheJingle.SDP('');
- sdp.fromJingle(elem);
- console.log(sdp.raw);
- } else if (process.argv[2] == '--sdp') {
- var sdp = new stropheJingle.SDP(input);
- var accept = $iq({to: '%(tojid)s',
- type: 'set'})
- .c('jingle', {xmlns: 'urn:xmpp:jingle:1',
- //action: 'session-accept',
- action: '%(action)s',
- initiator: '%(initiator)s',
- responder: '%(responder)s',
- sid: '%(sid)s' });
- sdp.toJingle(accept, 'responder');
- console.log(Strophe.serialize(accept));
- }
-});
-
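For reference, this script read Jingle XML or SDP on stdin and was invoked roughly as node unjingle.js --jingle (Jingle in, SDP printed) or node unjingle.js --sdp (SDP in, a Jingle <iq/> template with %(...)s placeholders printed); it assumed jsdom, jquery and the bundled strophe files were available alongside it.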
diff --git a/contrib/scripts/kick_users.py b/contrib/scripts/kick_users.py
deleted file mode 100755
index f8e0c732..00000000
--- a/contrib/scripts/kick_users.py
+++ /dev/null
@@ -1,88 +0,0 @@
-#!/usr/bin/env python
-
-import json
-import sys
-import urllib.parse
-from argparse import ArgumentParser
-
-import requests
-
-
-def _mkurl(template, kws):
- for key in kws:
- template = template.replace(key, kws[key])
- return template
-
-
-def main(hs, room_id, access_token, user_id_prefix, why):
- if not why:
- why = "Automated kick."
- print(
- "Kicking members on %s in room %s matching %s" % (hs, room_id, user_id_prefix)
- )
- room_state_url = _mkurl(
- "$HS/_matrix/client/api/v1/rooms/$ROOM/state?access_token=$TOKEN",
- {"$HS": hs, "$ROOM": room_id, "$TOKEN": access_token},
- )
- print("Getting room state => %s" % room_state_url)
- res = requests.get(room_state_url)
- print("HTTP %s" % res.status_code)
- state_events = res.json()
- if "error" in state_events:
- print("FATAL")
- print(state_events)
- return
-
- kick_list = []
- room_name = room_id
- for event in state_events:
- if not event["type"] == "m.room.member":
- if event["type"] == "m.room.name":
- room_name = event["content"].get("name")
- continue
- if not event["content"].get("membership") == "join":
- continue
- if event["state_key"].startswith(user_id_prefix):
- kick_list.append(event["state_key"])
-
- if len(kick_list) == 0:
- print("No user IDs match the prefix '%s'" % user_id_prefix)
- return
-
- print("The following user IDs will be kicked from %s" % room_name)
- for uid in kick_list:
- print(uid)
- doit = input("Continue? [Y]es\n")
- if len(doit) > 0 and doit.lower() == "y":
- print("Kicking members...")
- # encode them all
- kick_list = [urllib.parse.quote(uid) for uid in kick_list]
- for uid in kick_list:
- kick_url = _mkurl(
- "$HS/_matrix/client/api/v1/rooms/$ROOM/state/m.room.member/$UID?access_token=$TOKEN",
- {"$HS": hs, "$UID": uid, "$ROOM": room_id, "$TOKEN": access_token},
- )
- kick_body = {"membership": "leave", "reason": why}
- print("Kicking %s" % uid)
- res = requests.put(kick_url, data=json.dumps(kick_body))
- if res.status_code != 200:
- print("ERROR: HTTP %s" % res.status_code)
- if res.json().get("error"):
- print("ERROR: JSON %s" % res.json())
-
-
-if __name__ == "__main__":
- parser = ArgumentParser("Kick members in a room matching a certain user ID prefix.")
- parser.add_argument("-u", "--user-id", help="The user ID prefix e.g. '@irc_'")
- parser.add_argument("-t", "--token", help="Your access_token")
- parser.add_argument("-r", "--room", help="The room ID to kick members in")
- parser.add_argument(
- "-s", "--homeserver", help="The base HS url e.g. http://matrix.org"
- )
- parser.add_argument("-w", "--why", help="Reason for the kick. Optional.")
- args = parser.parse_args()
- if not args.room or not args.token or not args.user_id or not args.homeserver:
- parser.print_help()
- sys.exit(1)
- else:
- main(args.homeserver, args.room, args.token, args.user_id, args.why)
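For reference, the script was run along the lines of ./kick_users.py -s https://matrix.example.org -r '!room:example.org' -t <access_token> -u '@irc_' (the hostname and room ID here are placeholders); the homeserver, room, token and user-ID-prefix arguments are all required, and -w optionally supplies a kick reason.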
diff --git a/debian/copyright b/debian/copyright
index ca575c68..bac94c0e 100644
--- a/debian/copyright
+++ b/debian/copyright
@@ -26,29 +26,6 @@ Files: synapse/config/saml2.py
Copyright: 2015 Ericsson
License: Apache-2.0
-Files: contrib/jitsimeetbridge/unjingle/strophe/base64.js
-Copyright: Public Domain (Tyler Akins http://rumkin.com)
-License: public-domain
- This code was written by Tyler Akins and has been placed in the
- public domain. It would be nice if you left this header intact.
- Base64 code from Tyler Akins -- http://rumkin.com
-
-Files: contrib/jitsimeetbridge/unjingle/strophe/md5.js
-Copyright: 1999-2002 Paul Johnston & Contributors
-License: BSD-3-clause
-
-Files: contrib/jitsimeetbridge/unjingle/strophe/strophe.js
-Copyright: 2006-2008 OGG, LLC
-License: Expat
-
-Files: contrib/jitsimeetbridge/unjingle/strophe/XMLHttpRequest.js
-Copyright: 2010 passive.ly LLC
-License: Expat
-
-Files: contrib/jitsimeetbridge/unjingle/*.js
-Copyright: 2014 Jitsi
-License: Apache-2.0
-
Files: debian/*
Copyright:
2015—2017 Erik Johnston <erikj@matrix.org>
diff --git a/demo/start.sh b/demo/start.sh
index 5a9972d2..fdd75816 100755
--- a/demo/start.sh
+++ b/demo/start.sh
@@ -6,12 +6,14 @@ CWD=$(pwd)
cd "$DIR/.." || exit
-PYTHONPATH=$(readlink -f "$(pwd)")
-export PYTHONPATH
-
-
-echo "$PYTHONPATH"
-
+# Do not override PYTHONPATH if we are in a virtual env
+if [ "$VIRTUAL_ENV" = "" ]; then
+ PYTHONPATH=$(readlink -f "$(pwd)")
+ export PYTHONPATH
+ echo "$PYTHONPATH"
+fi
+
+# Create servers which listen on HTTP at 808x and HTTPS at 848x.
for port in 8080 8081 8082; do
echo "Starting server on port $port... "
@@ -19,10 +21,12 @@ for port in 8080 8081 8082; do
mkdir -p demo/$port
pushd demo/$port || exit
- # Generate the configuration for the homeserver at localhost:848x.
+ # Generate the configuration for the homeserver at localhost:848x, note that
+ # the homeserver name needs to match the HTTPS listening port for federation
+ # to work properly.
python3 -m synapse.app.homeserver \
--generate-config \
- --server-name "localhost:$port" \
+ --server-name "localhost:$https_port" \
--config-path "$port.config" \
--report-stats no
diff --git a/docker/Dockerfile b/docker/Dockerfile
index ccc6a9f7..7af0e51f 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -55,7 +55,7 @@ RUN \
# NB: In poetry 1.2 `poetry export` will be moved into a plugin; we'll need to also
# pip install poetry-plugin-export (https://github.com/python-poetry/poetry-plugin-export).
RUN --mount=type=cache,target=/root/.cache/pip \
- pip install --user git+https://github.com/python-poetry/poetry.git@fb13b3a676f476177f7937ffa480ee5cff9a90a5
+ pip install --user "poetry-core==1.1.0a7" "git+https://github.com/python-poetry/poetry.git@fb13b3a676f476177f7937ffa480ee5cff9a90a5"
WORKDIR /synapse
diff --git a/docker/complement/SynapseWorkers.Dockerfile b/docker/complement/SynapseWorkers.Dockerfile
index 9a4438e7..99a09cbc 100644
--- a/docker/complement/SynapseWorkers.Dockerfile
+++ b/docker/complement/SynapseWorkers.Dockerfile
@@ -6,12 +6,6 @@
# https://github.com/matrix-org/synapse/blob/develop/docker/README-testing.md#testing-with-postgresql-and-single-or-multi-process-synapse
FROM matrixdotorg/synapse-workers
-# Download a caddy server to stand in front of nginx and terminate TLS using Complement's
-# custom CA.
-# We include this near the top of the file in order to cache the result.
-RUN curl -OL "https://github.com/caddyserver/caddy/releases/download/v2.3.0/caddy_2.3.0_linux_amd64.tar.gz" && \
- tar xzf caddy_2.3.0_linux_amd64.tar.gz && rm caddy_2.3.0_linux_amd64.tar.gz && mv caddy /root
-
# Install postgresql
RUN apt-get update && \
DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y postgresql-13
@@ -31,16 +25,12 @@ COPY conf-workers/workers-shared.yaml /conf/workers/shared.yaml
WORKDIR /data
-# Copy the caddy config
-COPY conf-workers/caddy.complement.json /root/caddy.json
-
COPY conf-workers/postgres.supervisord.conf /etc/supervisor/conf.d/postgres.conf
-COPY conf-workers/caddy.supervisord.conf /etc/supervisor/conf.d/caddy.conf
# Copy the entrypoint
COPY conf-workers/start-complement-synapse-workers.sh /
-# Expose caddy's listener ports
+# Expose nginx's listener ports
EXPOSE 8008 8448
ENTRYPOINT ["/start-complement-synapse-workers.sh"]
diff --git a/docker/complement/conf-workers/caddy.complement.json b/docker/complement/conf-workers/caddy.complement.json
deleted file mode 100644
index 09e2136a..00000000
--- a/docker/complement/conf-workers/caddy.complement.json
+++ /dev/null
@@ -1,72 +0,0 @@
-{
- "apps": {
- "http": {
- "servers": {
- "srv0": {
- "listen": [
- ":8448"
- ],
- "routes": [
- {
- "match": [
- {
- "host": [
- "{{ server_name }}"
- ]
- }
- ],
- "handle": [
- {
- "handler": "subroute",
- "routes": [
- {
- "handle": [
- {
- "handler": "reverse_proxy",
- "upstreams": [
- {
- "dial": "localhost:8008"
- }
- ]
- }
- ]
- }
- ]
- }
- ],
- "terminal": true
- }
- ]
- }
- }
- },
- "tls": {
- "automation": {
- "policies": [
- {
- "subjects": [
- "{{ server_name }}"
- ],
- "issuers": [
- {
- "module": "internal"
- }
- ],
- "on_demand": true
- }
- ]
- }
- },
- "pki": {
- "certificate_authorities": {
- "local": {
- "name": "Complement CA",
- "root": {
- "certificate": "/complement/ca/ca.crt",
- "private_key": "/complement/ca/ca.key"
- }
- }
- }
- }
- }
- }
diff --git a/docker/complement/conf-workers/caddy.supervisord.conf b/docker/complement/conf-workers/caddy.supervisord.conf
deleted file mode 100644
index d9ddb51d..00000000
--- a/docker/complement/conf-workers/caddy.supervisord.conf
+++ /dev/null
@@ -1,7 +0,0 @@
-[program:caddy]
-command=/usr/local/bin/prefix-log /root/caddy run --config /root/caddy.json
-autorestart=unexpected
-stdout_logfile=/dev/stdout
-stdout_logfile_maxbytes=0
-stderr_logfile=/dev/stderr
-stderr_logfile_maxbytes=0
diff --git a/docker/complement/conf-workers/start-complement-synapse-workers.sh b/docker/complement/conf-workers/start-complement-synapse-workers.sh
index b9a6b55b..b7e24440 100755
--- a/docker/complement/conf-workers/start-complement-synapse-workers.sh
+++ b/docker/complement/conf-workers/start-complement-synapse-workers.sh
@@ -9,9 +9,6 @@ function log {
echo "$d $@"
}
-# Replace the server name in the caddy config
-sed -i "s/{{ server_name }}/${SERVER_NAME}/g" /root/caddy.json
-
# Set the server name of the homeserver
export SYNAPSE_SERVER_NAME=${SERVER_NAME}
@@ -39,6 +36,26 @@ export SYNAPSE_WORKER_TYPES="\
appservice, \
pusher"
+# Add Complement's appservice registration directory, if there is one
+# (It can be absent when there are no application services in this test!)
+if [ -d /complement/appservice ]; then
+ export SYNAPSE_AS_REGISTRATION_DIR=/complement/appservice
+fi
+
+# Generate a TLS key, then generate a certificate by having Complement's CA sign it
+# Note that both the key and certificate are in PEM format (not DER).
+openssl genrsa -out /conf/server.tls.key 2048
+
+openssl req -new -key /conf/server.tls.key -out /conf/server.tls.csr \
+ -subj "/CN=${SERVER_NAME}"
+
+openssl x509 -req -in /conf/server.tls.csr \
+ -CA /complement/ca/ca.crt -CAkey /complement/ca/ca.key -set_serial 1 \
+ -out /conf/server.tls.crt
+
+export SYNAPSE_TLS_CERT=/conf/server.tls.crt
+export SYNAPSE_TLS_KEY=/conf/server.tls.key
+
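+# As an optional sanity check (a sketch only; not required for the setup to work),
+# verify that the freshly signed certificate chains back to Complement's CA:
+openssl verify -CAfile /complement/ca/ca.crt /conf/server.tls.crt
+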
# Run the script that writes the necessary config files and starts supervisord, which in turn
# starts everything else
exec /configure_workers_and_start.py
diff --git a/docker/complement/conf-workers/workers-shared.yaml b/docker/complement/conf-workers/workers-shared.yaml
index 8b698703..cd7b50c6 100644
--- a/docker/complement/conf-workers/workers-shared.yaml
+++ b/docker/complement/conf-workers/workers-shared.yaml
@@ -5,6 +5,12 @@ enable_registration: true
enable_registration_without_verification: true
bcrypt_rounds: 4
+## Registration ##
+
+# Needed by Complement to register admin users
+# DO NOT USE in a production configuration! This should be a random secret.
+registration_shared_secret: complement
+
## Federation ##
# trust certs signed by Complement's CA
@@ -53,6 +59,18 @@ rc_joins:
per_second: 9999
burst_count: 9999
+rc_3pid_validation:
+ per_second: 1000
+ burst_count: 1000
+
+rc_invites:
+ per_room:
+ per_second: 1000
+ burst_count: 1000
+ per_user:
+ per_second: 1000
+ burst_count: 1000
+
federation_rr_transactions_per_room_per_second: 9999
## Experimental Features ##
diff --git a/docker/complement/conf/homeserver.yaml b/docker/complement/conf/homeserver.yaml
index 174f87f5..e2be540b 100644
--- a/docker/complement/conf/homeserver.yaml
+++ b/docker/complement/conf/homeserver.yaml
@@ -87,6 +87,18 @@ rc_joins:
per_second: 9999
burst_count: 9999
+rc_3pid_validation:
+ per_second: 1000
+ burst_count: 1000
+
+rc_invites:
+ per_room:
+ per_second: 1000
+ burst_count: 1000
+ per_user:
+ per_second: 1000
+ burst_count: 1000
+
federation_rr_transactions_per_room_per_second: 9999
## API Configuration ##
diff --git a/docker/conf-workers/nginx.conf.j2 b/docker/conf-workers/nginx.conf.j2
index 1081979e..967fc65e 100644
--- a/docker/conf-workers/nginx.conf.j2
+++ b/docker/conf-workers/nginx.conf.j2
@@ -9,6 +9,22 @@ server {
listen 8008;
listen [::]:8008;
+ {% if tls_cert_path is not none and tls_key_path is not none %}
+ listen 8448 ssl;
+ listen [::]:8448 ssl;
+
+ ssl_certificate {{ tls_cert_path }};
+ ssl_certificate_key {{ tls_key_path }};
+
+ # Some directives from cipherlist.eu (fka cipherli.st):
+ ssl_protocols TLSv1 TLSv1.1 TLSv1.2 TLSv1.3;
+ ssl_prefer_server_ciphers on;
+ ssl_ciphers "EECDH+AESGCM:EDH+AESGCM:AES256+EECDH:AES256+EDH";
+ ssl_ecdh_curve secp384r1; # Requires nginx >= 1.1.0
+ ssl_session_cache shared:SSL:10m;
+ ssl_session_tickets off; # Requires nginx >= 1.5.9
+ {% endif %}
+
server_name localhost;
# Nginx by default only allows file uploads up to 1M in size
diff --git a/docker/conf-workers/shared.yaml.j2 b/docker/conf-workers/shared.yaml.j2
index f94b8c6a..644ed788 100644
--- a/docker/conf-workers/shared.yaml.j2
+++ b/docker/conf-workers/shared.yaml.j2
@@ -6,4 +6,13 @@
redis:
enabled: true
-{{ shared_worker_config }}
\ No newline at end of file
+{% if appservice_registrations is not none %}
+## Application Services ##
+# A list of application service config files to use.
+app_service_config_files:
+{%- for path in appservice_registrations %}
+ - "{{ path }}"
+{%- endfor %}
+{%- endif %}
+
+{{ shared_worker_config }}
diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py
index b2b7938a..f7dac902 100755
--- a/docker/configure_workers_and_start.py
+++ b/docker/configure_workers_and_start.py
@@ -21,6 +21,11 @@
# * SYNAPSE_REPORT_STATS: Whether to report stats.
# * SYNAPSE_WORKER_TYPES: A comma separated list of worker names as specified in WORKER_CONFIG
# below. Leave empty for no workers, or set to '*' for all possible workers.
+# * SYNAPSE_AS_REGISTRATION_DIR: If specified, a directory in which .yaml and .yml files
+# will be treated as Application Service registration files.
+# * SYNAPSE_TLS_CERT: Path to a TLS certificate in PEM format.
+# * SYNAPSE_TLS_KEY: Path to a TLS key. If this and SYNAPSE_TLS_CERT are specified,
+# Nginx will be configured to serve TLS on port 8448.
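+#
+# An illustrative (hypothetical) invocation supplying some of these variables:
+#
+#   docker run -e SYNAPSE_SERVER_NAME=my.matrix.host -e SYNAPSE_REPORT_STATS=no \
+#       -e SYNAPSE_WORKER_TYPES="event_persister, federation_sender" \
+#       -e SYNAPSE_TLS_CERT=/conf/server.tls.crt -e SYNAPSE_TLS_KEY=/conf/server.tls.key \
+#       matrixdotorg/synapse-workers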
#
# NOTE: According to Complement's ENTRYPOINT expectations for a homeserver image (as defined
# in the project's README), this script may be run multiple times, and functionality should
@@ -29,6 +34,7 @@
import os
import subprocess
import sys
+from pathlib import Path
from typing import Any, Dict, List, Mapping, MutableMapping, NoReturn, Set
import jinja2
@@ -152,6 +158,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
"^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$",
"^/_matrix/client/(api/v1|r0|v3|unstable)/join/",
"^/_matrix/client/(api/v1|r0|v3|unstable)/profile/",
+ "^/_matrix/client/(v1|unstable/org.matrix.msc2716)/rooms/.*/batch_send",
],
"shared_extra_conf": {},
"worker_extra_conf": "",
@@ -488,11 +495,23 @@ def generate_worker_files(
master_log_config = generate_worker_log_config(environ, "master", data_dir)
shared_config["log_config"] = master_log_config
+ # Find application service registrations
+ appservice_registrations = None
+ appservice_registration_dir = os.environ.get("SYNAPSE_AS_REGISTRATION_DIR")
+ if appservice_registration_dir:
+ # Scan for all YAML files that should be application service registrations.
+ appservice_registrations = [
+ str(reg_path.resolve())
+ for reg_path in Path(appservice_registration_dir).iterdir()
+ if reg_path.suffix.lower() in (".yaml", ".yml")
+ ]
+
# Shared homeserver config
convert(
"/conf/shared.yaml.j2",
"/conf/workers/shared.yaml",
shared_worker_config=yaml.dump(shared_config),
+ appservice_registrations=appservice_registrations,
)
# Nginx config
@@ -501,6 +520,8 @@ def generate_worker_files(
"/etc/nginx/conf.d/matrix-synapse.conf",
worker_locations=nginx_location_config,
upstream_directives=nginx_upstream_config,
+ tls_cert_path=os.environ.get("SYNAPSE_TLS_CERT"),
+ tls_key_path=os.environ.get("SYNAPSE_TLS_KEY"),
)
# Supervisord config
diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md
index 65570cef..8400a653 100644
--- a/docs/SUMMARY.md
+++ b/docs/SUMMARY.md
@@ -89,6 +89,7 @@
- [Database Schemas](development/database_schema.md)
- [Experimental features](development/experimental_features.md)
- [Synapse Architecture]()
+ - [Cancellation](development/synapse_architecture/cancellation.md)
- [Log Contexts](log_contexts.md)
- [Replication](replication.md)
- [TCP Replication](tcp_replication.md)
diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md
index 96b3668f..d57c5aed 100644
--- a/docs/admin_api/media_admin_api.md
+++ b/docs/admin_api/media_admin_api.md
@@ -289,7 +289,7 @@ POST /_synapse/admin/v1/purge_media_cache?before_ts=<unix_timestamp_in_ms>
URL Parameters
-* `unix_timestamp_in_ms`: string representing a positive integer - Unix timestamp in milliseconds.
+* `before_ts`: string representing a positive integer - Unix timestamp in milliseconds.
All cached media that was last accessed before this timestamp will be removed.
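+
+For example (a sketch; the host, the admin access token and the timestamp below are placeholders):
+
+```shell
+curl --header "Authorization: Bearer <access_token>" -X POST \
+  "http://localhost:8008/_synapse/admin/v1/purge_media_cache?before_ts=1652400000000"
+```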
Response:
diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md
index c8794299..62f89e8c 100644
--- a/docs/admin_api/user_admin_api.md
+++ b/docs/admin_api/user_admin_api.md
@@ -115,7 +115,9 @@ URL parameters:
Body parameters:
- `password` - string, optional. If provided, the user's password is updated and all
- devices are logged out.
+ devices are logged out, unless `logout_devices` is set to `false`.
+- `logout_devices` - bool, optional, defaults to `true`. If set to false, devices aren't
+ logged out even when `password` is provided.
- `displayname` - string, optional, defaults to the value of `user_id`.
- `threepids` - array, optional, allows setting the third-party IDs (email, msisdn)
- `medium` - string. Kind of third-party ID, either `email` or `msisdn`.
diff --git a/docs/development/contributing_guide.md b/docs/development/contributing_guide.md
index d356c72b..2b3714df 100644
--- a/docs/development/contributing_guide.md
+++ b/docs/development/contributing_guide.md
@@ -206,7 +206,32 @@ This means that we need to run our unit tests against PostgreSQL too. Our CI doe
this automatically for pull requests and release candidates, but it's sometimes
useful to reproduce this locally.
-To do so, [configure Postgres](../postgres.md) and run `trial` with the
+#### Using Docker
+
+The easiest way to do so is to run Postgres via a docker container. In one
+terminal:
+
+```shell
+docker run --rm -e POSTGRES_PASSWORD=mysecretpassword -e POSTGRES_USER=postgres -e POSTGRES_DB=postgres -p 5432:5432 postgres:14
+```
+
+If you see an error like
+
+```
+docker: Error response from daemon: driver failed programming external connectivity on endpoint nice_ride (b57bbe2e251b70015518d00c9981e8cb8346b5c785250341a6c53e3c899875f1): Error starting userland proxy: listen tcp4 0.0.0.0:5432: bind: address already in use.
+```
+
+then something is already bound to port 5432. You're probably already running postgres locally.
+
+Once you have a postgres server running, invoke `trial` in a second terminal:
+
+```shell
+SYNAPSE_POSTGRES=1 SYNAPSE_POSTGRES_HOST=127.0.0.1 SYNAPSE_POSTGRES_USER=postgres SYNAPSE_POSTGRES_PASSWORD=mysecretpassword poetry run trial tests
+```
+
+#### Using an existing Postgres installation
+
+If you have postgres already installed on your system, you can run `trial` with the
following environment variables matching your configuration:
- `SYNAPSE_POSTGRES` to anything nonempty
@@ -229,8 +254,8 @@ You don't need to specify the host, user, port or password if your Postgres
server is set to authenticate you over the UNIX socket (i.e. if the `psql` command
works without further arguments).
-Your Postgres account needs to be able to create databases.
-
+Your Postgres account needs to be able to create databases; see the postgres
+docs for [`ALTER ROLE`](https://www.postgresql.org/docs/current/sql-alterrole.html).
## Run the integration tests ([Sytest](https://github.com/matrix-org/sytest)).
@@ -397,8 +422,8 @@ same lightweight approach that the Linux Kernel
[submitting patches process](
https://www.kernel.org/doc/html/latest/process/submitting-patches.html#sign-your-work-the-developer-s-certificate-of-origin),
[Docker](https://github.com/docker/docker/blob/master/CONTRIBUTING.md), and many other
-projects use: the DCO (Developer Certificate of Origin:
-http://developercertificate.org/). This is a simple declaration that you wrote
+projects use: the DCO ([Developer Certificate of Origin](http://developercertificate.org/)).
+This is a simple declaration that you wrote
the contribution or otherwise have the right to contribute it to Matrix:
```
diff --git a/docs/development/demo.md b/docs/development/demo.md
index 4277252c..893ed699 100644
--- a/docs/development/demo.md
+++ b/docs/development/demo.md
@@ -5,7 +5,7 @@
Requires you to have a [Synapse development environment setup](https://matrix-org.github.io/synapse/develop/development/contributing_guide.html#4-install-the-dependencies).
The demo setup allows running three federation Synapse servers, with server
-names `localhost:8080`, `localhost:8081`, and `localhost:8082`.
+names `localhost:8480`, `localhost:8481`, and `localhost:8482`.
You can access them via any Matrix client over HTTP at `localhost:8080`,
`localhost:8081`, and `localhost:8082` or over HTTPS at `localhost:8480`,
@@ -20,9 +20,10 @@ and the servers are configured in a highly insecure way, including:
The servers are configured to store their data under `demo/8080`, `demo/8081`, and
`demo/8082`. This includes configuration, logs, SQLite databases, and media.
-Note that when joining a public room on a different HS via "#foo:bar.net", then
-you are (in the current impl) joining a room with room_id "foo". This means that
-it won't work if your HS already has a room with that name.
+Note that when joining a public room on a different homeserver via "#foo:bar.net",
+then you are (in the current implementation) joining a room with room_id "foo".
+This means that it won't work if your homeserver already has a room with that
+name.
## Using the demo scripts
diff --git a/docs/development/synapse_architecture/cancellation.md b/docs/development/synapse_architecture/cancellation.md
new file mode 100644
index 00000000..ef9e0226
--- /dev/null
+++ b/docs/development/synapse_architecture/cancellation.md
@@ -0,0 +1,392 @@
+# Cancellation
+Sometimes, requests take a long time to service and clients disconnect
+before Synapse produces a response. To avoid wasting resources, Synapse
+can cancel request processing for select endpoints marked with the
+`@cancellable` decorator.
+
+Synapse makes use of Twisted's `Deferred.cancel()` feature to make
+cancellation work. The `@cancellable` decorator does nothing by itself
+and merely acts as a flag, signalling to developers and other code alike
+that a method can be cancelled.
+
+## Enabling cancellation for an endpoint
+1. Check that the endpoint method, and any `async` functions in its call
+ tree handle cancellation correctly. See
+ [Handling cancellation correctly](#handling-cancellation-correctly)
+ for a list of things to look out for.
+2. Add the `@cancellable` decorator to the `on_GET/POST/PUT/DELETE`
+ method. It's not recommended to make non-`GET` methods cancellable,
+ since cancellation midway through some database updates is less
+ likely to be handled correctly.
+
+## Mechanics
+There are two stages to cancellation: downward propagation of a
+`cancel()` call, followed by upwards propagation of a `CancelledError`
+out of a blocked `await`.
+Both Twisted and asyncio have a cancellation mechanism.
+
+| | Method | Exception | Exception inherits from |
+|---------------|---------------------|-----------------------------------------|-------------------------|
+| Twisted | `Deferred.cancel()` | `twisted.internet.defer.CancelledError` | `Exception` (!) |
+| asyncio | `Task.cancel()` | `asyncio.CancelledError` | `BaseException` |
+
+### Deferred.cancel()
+When Synapse starts handling a request, it runs the async method
+responsible for handling it using `defer.ensureDeferred`, which returns
+a `Deferred`. For example:
+
+```python
+def do_something() -> Deferred[None]:
+ ...
+
+@cancellable
+async def on_GET() -> Tuple[int, JsonDict]:
+ d = make_deferred_yieldable(do_something())
+ await d
+ return 200, {}
+
+request = defer.ensureDeferred(on_GET())
+```
+
+When a client disconnects early, Synapse checks for the presence of the
+`@cancellable` decorator on `on_GET`. Since `on_GET` is cancellable,
+`Deferred.cancel()` is called on the `Deferred` from
+`defer.ensureDeferred`, ie. `request`. Twisted knows which `Deferred`
+`request` is waiting on and passes the `cancel()` call on to `d`.
+
+The `Deferred` being waited on, `d`, may have its own handling for
+`cancel()` and pass the call on to other `Deferred`s.
+
+Eventually, a `Deferred` handles the `cancel()` call by resolving itself
+with a `CancelledError`.
+
+### CancelledError
+The `CancelledError` gets raised out of the `await` and bubbles up, as
+per normal Python exception handling.
+
+## Handling cancellation correctly
+In general, when writing code that might be subject to cancellation, two
+things must be considered:
+ * The effect of `CancelledError`s raised out of `await`s.
+ * The effect of `Deferred`s being `cancel()`ed.
+
+Examples of code that handles cancellation incorrectly include:
+ * `try-except` blocks which swallow `CancelledError`s.
+ * Code that shares the same `Deferred`, which may be cancelled, between
+ multiple requests.
+ * Code that starts some processing that's exempt from cancellation, but
+ uses a logging context from cancellable code. The logging context
+ will be finished upon cancellation, while the uncancelled processing
+ is still using it.
+
+Some common patterns are listed below in more detail.
+
+### `async` function calls
+Most functions in Synapse are relatively straightforward from a
+cancellation standpoint: they don't do anything with `Deferred`s and
+purely call and `await` other `async` functions.
+
+An `async` function handles cancellation correctly if its own code
+handles cancellation correctly and all the `async` functions it calls
+handle cancellation correctly. For example:
+```python
+async def do_two_things() -> None:
+ check_something()
+ await do_something()
+ await do_something_else()
+```
+`do_two_things` handles cancellation correctly if `do_something` and
+`do_something_else` handle cancellation correctly.
+
+That is, when checking whether a function handles cancellation
+correctly, its implementation and all its `async` function calls need to
+be checked, recursively.
+
+As `check_something` is not `async`, it does not need to be checked.
+
+### CancelledErrors
+Because Twisted's `CancelledError`s are `Exception`s, it's easy to
+accidentally catch and suppress them. Care must be taken to ensure that
+`CancelledError`s are allowed to propagate upwards.
+
+<table width="100%">
+<tr>
+<td width="50%" valign="top">
+
+**Bad**:
+```python
+try:
+ await do_something()
+except Exception:
+ # `CancelledError` gets swallowed here.
+ logger.info(...)
+```
+</td>
+<td width="50%" valign="top">
+
+**Good**:
+```python
+try:
+ await do_something()
+except CancelledError:
+ raise
+except Exception:
+ logger.info(...)
+```
+</td>
+</tr>
+<tr>
+<td width="50%" valign="top">
+
+**OK**:
+```python
+try:
+ check_something()
+ # A `CancelledError` won't ever be raised here.
+except Exception:
+ logger.info(...)
+```
+</td>
+<td width="50%" valign="top">
+
+**Good**:
+```python
+try:
+ await do_something()
+except ValueError:
+ logger.info(...)
+```
+</td>
+</tr>
+</table>
+
+#### defer.gatherResults
+`defer.gatherResults` produces a `Deferred` which:
+ * broadcasts `cancel()` calls to every `Deferred` being waited on.
+ * wraps the first exception it sees in a `FirstError`.
+
+Together, this means that `CancelledError`s will be wrapped in
+a `FirstError` unless unwrapped. Such `FirstError`s are liable to be
+swallowed, so they must be unwrapped.
+
+<table width="100%">
+<tr>
+<td width="50%" valign="top">
+
+**Bad**:
+```python
+async def do_something() -> None:
+ await make_deferred_yieldable(
+ defer.gatherResults([...], consumeErrors=True)
+ )
+
+try:
+ await do_something()
+except CancelledError:
+ raise
+except Exception:
+ # `FirstError(CancelledError)` gets swallowed here.
+ logger.info(...)
+```
+
+</td>
+<td width="50%" valign="top">
+
+**Good**:
+```python
+async def do_something() -> None:
+ await make_deferred_yieldable(
+ defer.gatherResults([...], consumeErrors=True)
+ ).addErrback(unwrapFirstError)
+
+try:
+ await do_something()
+except CancelledError:
+ raise
+except Exception:
+ logger.info(...)
+```
+</td>
+</tr>
+</table>
+
+### Creation of `Deferred`s
+If a function creates a `Deferred`, the effect of cancelling it must be considered. `Deferred`s that get shared are likely to have unintended behaviour when cancelled.
+
+<table width="100%">
+<tr>
+<td width="50%" valign="top">
+
+**Bad**:
+```python
+cache: Dict[str, Deferred[None]] = {}
+
+def wait_for_room(room_id: str) -> Deferred[None]:
+ deferred = cache.get(room_id)
+ if deferred is None:
+ deferred = Deferred()
+ cache[room_id] = deferred
+ # `deferred` can have multiple waiters.
+ # All of them will observe a `CancelledError`
+ # if any one of them is cancelled.
+ return make_deferred_yieldable(deferred)
+
+# Request 1
+await wait_for_room("!aAAaaAaaaAAAaAaAA:matrix.org")
+# Request 2
+await wait_for_room("!aAAaaAaaaAAAaAaAA:matrix.org")
+```
+</td>
+<td width="50%" valign="top">
+
+**Good**:
+```python
+cache: Dict[str, Deferred[None]] = {}
+
+def wait_for_room(room_id: str) -> Deferred[None]:
+ deferred = cache.get(room_id)
+ if deferred is None:
+ deferred = Deferred()
+ cache[room_id] = deferred
+ # `deferred` will never be cancelled now.
+ # A `CancelledError` will still come out of
+ # the `await`.
+ # `delay_cancellation` may also be used.
+ return make_deferred_yieldable(stop_cancellation(deferred))
+
+# Request 1
+await wait_for_room("!aAAaaAaaaAAAaAaAA:matrix.org")
+# Request 2
+await wait_for_room("!aAAaaAaaaAAAaAaAA:matrix.org")
+```
+</td>
+</tr>
+<tr>
+<td width="50%" valign="top">
+</td>
+<td width="50%" valign="top">
+
+**Good**:
+```python
+cache: Dict[str, List[Deferred[None]]] = {}
+
+def wait_for_room(room_id: str) -> Deferred[None]:
+ if room_id not in cache:
+ cache[room_id] = []
+ # Each request gets its own `Deferred` to wait on.
+ deferred = Deferred()
+ cache[room_id].append(deferred)
+ return make_deferred_yieldable(deferred)
+
+# Request 1
+await wait_for_room("!aAAaaAaaaAAAaAaAA:matrix.org")
+# Request 2
+await wait_for_room("!aAAaaAaaaAAAaAaAA:matrix.org")
+```
+</td>
+</table>
+
+### Uncancelled processing
+Some `async` functions may kick off some `async` processing which is
+intentionally protected from cancellation, by `stop_cancellation` or
+other means. If the `async` processing inherits the logcontext of the
+request which initiated it, care must be taken to ensure that the
+logcontext is not finished before the `async` processing completes.
+
+<table width="100%">
+<tr>
+<td width="50%" valign="top">
+
+**Bad**:
+```python
+cache: Optional[ObservableDeferred[None]] = None
+
+async def do_something_else(
+ to_resolve: Deferred[None]
+) -> None:
+ await ...
+ logger.info("done!")
+ to_resolve.callback(None)
+
+async def do_something() -> None:
+ if not cache:
+ to_resolve = Deferred()
+ cache = ObservableDeferred(to_resolve)
+ # `do_something_else` will never be cancelled and
+ # can outlive the `request-1` logging context.
+ run_in_background(do_something_else, to_resolve)
+
+ await make_deferred_yieldable(cache.observe())
+
+with LoggingContext("request-1"):
+ await do_something()
+```
+</td>
+<td width="50%" valign="top">
+
+**Good**:
+```python
+cache: Optional[ObservableDeferred[None]] = None
+
+async def do_something_else(
+ to_resolve: Deferred[None]
+) -> None:
+ await ...
+ logger.info("done!")
+ to_resolve.callback(None)
+
+async def do_something() -> None:
+ if not cache:
+ to_resolve = Deferred()
+ cache = ObservableDeferred(to_resolve)
+ run_in_background(do_something_else, to_resolve)
+ # We'll wait until `do_something_else` is
+ # done before raising a `CancelledError`.
+ await make_deferred_yieldable(
+ delay_cancellation(cache.observe())
+ )
+ else:
+ await make_deferred_yieldable(cache.observe())
+
+with LoggingContext("request-1"):
+ await do_something()
+```
+</td>
+</tr>
+<tr>
+<td width="50%">
+
+**OK**:
+```python
+cache: Optional[ObservableDeferred[None]] = None
+
+async def do_something_else(
+ to_resolve: Deferred[None]
+) -> None:
+ await ...
+ logger.info("done!")
+ to_resolve.callback(None)
+
+async def do_something() -> None:
+ if not cache:
+ to_resolve = Deferred()
+ cache = ObservableDeferred(to_resolve)
+ # `do_something_else` will get its own independent
+ # logging context. `request-1` will not count any
+ # metrics from `do_something_else`.
+ run_as_background_process(
+ "do_something_else",
+ do_something_else,
+ to_resolve,
+ )
+
+ await make_deferred_yieldable(cache.observe())
+
+with LoggingContext("request-1"):
+ await do_something()
+```
+</td>
+<td width="50%">
+</td>
+</tr>
+</table>
diff --git a/docs/message_retention_policies.md b/docs/message_retention_policies.md
index 9214d6d7..b52c4aaa 100644
--- a/docs/message_retention_policies.md
+++ b/docs/message_retention_policies.md
@@ -117,7 +117,7 @@ In this example, we define three jobs:
Note that this example is tailored to show different configurations and
features slightly more jobs than is probably necessary (in practice, a
server admin would probably consider it better to replace the two last
-jobs with one that runs once a day and handles rooms which which
+jobs with one that runs once a day and handles rooms whose
policy's `max_lifetime` is greater than 3 days).
Keep in mind, when configuring these jobs, that a purge job can become
diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md
index 472d9571..ad35e667 100644
--- a/docs/modules/spam_checker_callbacks.md
+++ b/docs/modules/spam_checker_callbacks.md
@@ -12,21 +12,27 @@ The available spam checker callbacks are:
_First introduced in Synapse v1.37.0_
+_Changed in Synapse v1.60.0: `synapse.module_api.NOT_SPAM` and `synapse.module_api.errors.Codes` can be returned by this callback. Returning a boolean or a string is now deprecated._
+
```python
-async def check_event_for_spam(event: "synapse.events.EventBase") -> Union[bool, str]
+async def check_event_for_spam(event: "synapse.module_api.EventBase") -> Union["synapse.module_api.NOT_SPAM", "synapse.module_api.errors.Codes", str, bool]
```
-Called when receiving an event from a client or via federation. The callback must return
-either:
-- an error message string, to indicate the event must be rejected because of spam and
- give a rejection reason to forward to clients;
-- the boolean `True`, to indicate that the event is spammy, but not provide further details; or
-- the booelan `False`, to indicate that the event is not considered spammy.
+Called when receiving an event from a client or via federation. The callback must return one of:
+ - `synapse.module_api.NOT_SPAM`, to allow the operation. Other callbacks may still
+ decide to reject it.
+ - `synapse.module_api.errors.Codes` to reject the operation with an error code. In case
+ of doubt, `synapse.module_api.errors.Codes.FORBIDDEN` is a good error code.
+ - (deprecated) a non-`Codes` `str` to reject the operation and specify an error message. Note that clients
+ typically will not localize the error message to the user's preferred locale.
+ - (deprecated) `False`, which is the same as returning `synapse.module_api.NOT_SPAM`.
+ - (deprecated) `True`, which is the same as returning `synapse.module_api.errors.Codes.FORBIDDEN`.
If multiple modules implement this callback, they will be considered in order. If a
-callback returns `False`, Synapse falls through to the next one. The value of the first
-callback that does not return `False` will be used. If this happens, Synapse will not call
-any of the subsequent implementations of this callback.
+callback returns `synapse.module_api.NOT_SPAM`, Synapse falls through to the next one.
+The value of the first callback that does not return `synapse.module_api.NOT_SPAM` will
+be used. If this happens, Synapse will not call any of the subsequent implementations of
+this callback.
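+
+For example, a module callback using the new return values might look roughly like
+this (the imports shown and the spam heuristic are purely illustrative):
+
+```python
+import synapse.module_api
+import synapse.module_api.errors
+
+async def check_event_for_spam(event):
+    # Toy heuristic: treat any message advertising a (hypothetical) spam domain as spam.
+    if "https://spam.example.com" in event.content.get("body", ""):
+        return synapse.module_api.errors.Codes.FORBIDDEN
+    # Allow the event; other callbacks may still reject it.
+    return synapse.module_api.NOT_SPAM
+```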
### `user_may_join_room`
@@ -249,6 +255,24 @@ callback returns `False`, Synapse falls through to the next one. The value of th
callback that does not return `False` will be used. If this happens, Synapse will not call
any of the subsequent implementations of this callback.
+### `should_drop_federated_event`
+
+_First introduced in Synapse v1.60.0_
+
+```python
+async def should_drop_federated_event(event: "synapse.events.EventBase") -> bool
+```
+
+Called when checking whether a remote server can federate an event with us. **Returning
+`True` from this function will silently drop a federated event and split-brain our view
+of a room's DAG, and thus you shouldn't use this callback unless you know what you are
+doing.**
+
+If multiple modules implement this callback, they will be considered in order. If a
+callback returns `False`, Synapse falls through to the next one. The value of the first
+callback that does not return `False` will be used. If this happens, Synapse will not call
+any of the subsequent implementations of this callback.
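+
+As a minimal illustration (the blocked server name is hypothetical), a module that
+silently drops all events sent by users on one particular remote server might be:
+
+```python
+async def should_drop_federated_event(event: "synapse.events.EventBase") -> bool:
+    # Drop (and split-brain!) anything sent from the hypothetical bad.example.com.
+    return event.sender.endswith(":bad.example.com")
+```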
+
## Example
The example below is a module that implements the spam checker callback
diff --git a/docs/openid.md b/docs/openid.md
index 19cacaaf..9d615a57 100644
--- a/docs/openid.md
+++ b/docs/openid.md
@@ -159,7 +159,7 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to
oidc_providers:
- idp_id: keycloak
idp_name: "My KeyCloak server"
- issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}"
+ issuer: "https://127.0.0.1:8443/realms/{realm_name}"
client_id: "synapse"
client_secret: "copy secret generated from above"
scopes: ["openid", "profile"]
@@ -293,7 +293,7 @@ can be used to retrieve information on the authenticated user. As the Synapse
login mechanism needs an attribute to uniquely identify users, and that endpoint
does not return a `sub` property, an alternative `subject_claim` has to be set.
-1. Create a new OAuth application: https://github.com/settings/applications/new.
+1. Create a new OAuth application: [https://github.com/settings/applications/new](https://github.com/settings/applications/new).
2. Set the callback URL to `[synapse public baseurl]/_synapse/client/oidc/callback`.
Synapse config:
@@ -322,10 +322,10 @@ oidc_providers:
[Google][google-idp] is an OpenID certified authentication and authorisation provider.
-1. Set up a project in the Google API Console (see
- https://developers.google.com/identity/protocols/oauth2/openid-connect#appsetup).
-2. Add an "OAuth Client ID" for a Web Application under "Credentials".
-3. Copy the Client ID and Client Secret, and add the following to your synapse config:
+1. Set up a project in the Google API Console (see
+ [documentation](https://developers.google.com/identity/protocols/oauth2/openid-connect#appsetup)).
+3. Add an "OAuth Client ID" for a Web Application under "Credentials".
+4. Copy the Client ID and Client Secret, and add the following to your synapse config:
```yaml
oidc_providers:
- idp_id: google
@@ -501,8 +501,8 @@ As well as the private key file, you will need:
* Team ID: a 10-character ID associated with your developer account.
* Key ID: the 10-character identifier for the key.
-https://help.apple.com/developer-account/?lang=en#/dev77c875b7e has more
-documentation on setting up SiWA.
+[Apple's developer documentation](https://help.apple.com/developer-account/?lang=en#/dev77c875b7e)
+has more information on setting up SiWA.
The synapse config will look like this:
@@ -535,8 +535,8 @@ needed to add OAuth2 capabilities to your Django projects. It supports
Configuration on Django's side:
-1. Add an application: https://example.com/admin/oauth2_provider/application/add/ and choose parameters like this:
-* `Redirect uris`: https://synapse.example.com/_synapse/client/oidc/callback
+1. Add an application: `https://example.com/admin/oauth2_provider/application/add/` and choose parameters like this:
+* `Redirect uris`: `https://synapse.example.com/_synapse/client/oidc/callback`
* `Client type`: `Confidential`
* `Authorization grant type`: `Authorization code`
* `Algorithm`: `HMAC with SHA-2 256`
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index a803b826..56a25c53 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -289,7 +289,7 @@ presence:
# federation: the server-server API (/_matrix/federation). Also implies
# 'media', 'keys', 'openid'
#
-# keys: the key discovery API (/_matrix/keys).
+# keys: the key discovery API (/_matrix/key).
#
# media: the media API (/_matrix/media).
#
@@ -730,6 +730,12 @@ retention:
# A cache 'factor' is a multiplier that can be applied to each of
# Synapse's caches in order to increase or decrease the maximum
# number of entries that can be stored.
+#
+# The configuration for cache factors (caches.global_factor and
+# caches.per_cache_factors) can be reloaded while the application is running,
+# by sending a SIGHUP signal to the Synapse process. Changes to other parts of
+# the caching config will NOT be applied after a SIGHUP is received; a restart
+# is necessary.
# The number of events to cache in memory. Not affected by
# caches.global_factor.
@@ -778,6 +784,24 @@ caches:
#
#cache_entry_ttl: 30m
+ # This flag enables cache autotuning, and is further specified by the sub-options `max_cache_memory_usage`,
+ # `target_cache_memory_usage`, and `min_cache_ttl`. These flags work in conjunction with each other to maintain
+ # a balance between cache memory usage and cache entry availability. You must be using jemalloc to utilize
+ # this option, and all three of the options must be specified for this feature to work.
+ #cache_autotuning:
+ # This flag sets a ceiling on how much memory the cache can use before caches begin to be continuously evicted.
+ # They will continue to be evicted until the memory usage drops below the `target_cache_memory_usage`, set in
+ # the flag below, or until the `min_cache_ttl` is hit.
+ #max_cache_memory_usage: 1024M
+
+ # This flag sets a rough target for the desired memory usage of the caches.
+ #target_cache_memory_usage: 758M
+
+ # `min_cache_ttl` sets a limit under which newer cache entries are not evicted and is only applied when
+ # caches are actively being evicted/`max_cache_memory_usage` has been exceeded. This is to protect hot caches
+ # from being emptied while Synapse is evicting due to memory.
+ #min_cache_ttl: 5m
+
# Controls how long the results of a /sync request are cached for after
# a successful response is returned. A higher duration can help clients with
# intermittent connections, at the cost of higher memory usage.
@@ -2192,7 +2216,9 @@ sso:
password_config:
- # Uncomment to disable password login
+ # Uncomment to disable password login.
+ # Set to `only_for_reauth` to permit reauthentication for users that
+ # have passwords and are already logged in.
#
#enabled: false
@@ -2462,15 +2488,39 @@ push:
#
#encryption_enabled_by_default_for_room_type: invite
-
-# Uncomment to allow non-server-admin users to create groups on this server
-#
-#enable_group_creation: true
-
-# If enabled, non server admins can only create groups with local parts
-# starting with this prefix
-#
-#group_creation_prefix: "unofficial_"
+# Override the default power levels for rooms created on this server, per
+# room creation preset.
+#
+# The appropriate dictionary for the room preset will be applied on top
+# of the existing power levels content.
+#
+# Useful if you know that your users need special permissions in rooms
+# that they create (e.g. to send particular types of state events without
+# needing an elevated power level). This takes the same shape as the
+# `power_level_content_override` parameter in the /createRoom API, but
+# is applied before that parameter.
+#
+# Valid keys are some or all of `private_chat`, `trusted_private_chat`
+# and `public_chat`. Inside each of those should be any of the
+# properties allowed in `power_level_content_override` in the
+# /createRoom API. If any property is missing, its default value will
+# continue to be used. If any property is present, it will overwrite
+# the existing default completely (so if the `events` property exists,
+# the default event power levels will be ignored).
+#
+#default_power_level_content_override:
+# private_chat:
+# "events":
+# "com.example.myeventtype" : 0
+# "m.room.avatar": 50
+# "m.room.canonical_alias": 50
+# "m.room.encryption": 100
+# "m.room.history_visibility": 100
+# "m.room.name": 50
+# "m.room.power_levels": 100
+# "m.room.server_acl": 100
+# "m.room.tombstone": 100
+# "events_default": 1
diff --git a/docs/structured_logging.md b/docs/structured_logging.md
index a6667e1a..d43dc9eb 100644
--- a/docs/structured_logging.md
+++ b/docs/structured_logging.md
@@ -43,7 +43,7 @@ loggers:
The above logging config will set Synapse as 'INFO' logging level by default,
with the SQL layer at 'WARNING', and will log to a file, stored as JSON.
-It is also possible to figure Synapse to log to a remote endpoint by using the
+It is also possible to configure Synapse to log to a remote endpoint by using the
`synapse.logging.RemoteHandler` class included with Synapse. It takes the
following arguments:
diff --git a/docs/upgrade.md b/docs/upgrade.md
index fa4b3ef5..5ac29abb 100644
--- a/docs/upgrade.md
+++ b/docs/upgrade.md
@@ -89,6 +89,143 @@ process, for example:
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
```
+# Upgrading to v1.61.0
+
+## Removal of deprecated community/groups
+
+This release of Synapse will remove the deprecated community/groups feature from the codebase.
+
+### Worker endpoints
+
+For those who have deployed workers, the following worker endpoints will no longer
+exist and can be removed from the reverse proxy configuration:
+
+- `^/_matrix/federation/v1/get_groups_publicised$`
+- `^/_matrix/client/(r0|v3|unstable)/joined_groups$`
+- `^/_matrix/client/(r0|v3|unstable)/publicised_groups$`
+- `^/_matrix/client/(r0|v3|unstable)/publicised_groups/`
+- `^/_matrix/federation/v1/groups/`
+- `^/_matrix/client/(r0|v3|unstable)/groups/`
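+
+As a rough illustration only (the exact entries depend on how your reverse proxy was
+set up, and the worker address below is a placeholder), an nginx deployment might have
+had blocks along these lines which can now be dropped:
+
+```nginx
+location ~ ^/_matrix/client/(r0|v3|unstable)/groups/ {
+    proxy_pass http://localhost:8083;
+}
+```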
+
+# Upgrading to v1.60.0
+
+## Adding a new unique index to `state_group_edges` could fail if your database is corrupted
+
+This release of Synapse will add a unique index to the `state_group_edges` table, in order
+to prevent accidentally introducing duplicate information (for example, because a database
+backup was restored multiple times).
+
+Duplicate rows being present in this table could cause drastic performance problems; see
+[issue 11779](https://github.com/matrix-org/synapse/issues/11779) for more details.
+
+If your Synapse database already has had duplicate rows introduced into this table,
+this could fail, with either of these errors:
+
+
+**On Postgres:**
+```
+synapse.storage.background_updates - 623 - INFO - background_updates-0 - Adding index state_group_edges_unique_idx to state_group_edges
+synapse.storage.background_updates - 282 - ERROR - background_updates-0 - Error doing update
+...
+psycopg2.errors.UniqueViolation: could not create unique index "state_group_edges_unique_idx"
+DETAIL: Key (state_group, prev_state_group)=(2, 1) is duplicated.
+```
+(The numbers may be different.)
+
+**On SQLite:**
+```
+synapse.storage.background_updates - 623 - INFO - background_updates-0 - Adding index state_group_edges_unique_idx to state_group_edges
+synapse.storage.background_updates - 282 - ERROR - background_updates-0 - Error doing update
+...
+sqlite3.IntegrityError: UNIQUE constraint failed: state_group_edges.state_group, state_group_edges.prev_state_group
+```
+
+
+<details>
+<summary><b>Expand this section for steps to resolve this problem</b></summary>
+
+### On Postgres
+
+Connect to your database with `psql`.
+
+```sql
+BEGIN;
+DELETE FROM state_group_edges WHERE (ctid, state_group, prev_state_group) IN (
+ SELECT row_id, state_group, prev_state_group
+ FROM (
+ SELECT
+ ctid AS row_id,
+ MIN(ctid) OVER (PARTITION BY state_group, prev_state_group) AS min_row_id,
+ state_group,
+ prev_state_group
+ FROM state_group_edges
+ ) AS t1
+ WHERE row_id <> min_row_id
+);
+COMMIT;
+```
+
+
+### On SQLite
+
+At the command-line, use `sqlite3 path/to/your-homeserver-database.db`:
+
+```sql
+BEGIN;
+DELETE FROM state_group_edges WHERE (rowid, state_group, prev_state_group) IN (
+ SELECT row_id, state_group, prev_state_group
+ FROM (
+ SELECT
+ rowid AS row_id,
+ MIN(rowid) OVER (PARTITION BY state_group, prev_state_group) AS min_row_id,
+ state_group,
+ prev_state_group
+ FROM state_group_edges
+ )
+ WHERE row_id <> min_row_id
+);
+COMMIT;
+```
+
+
+### For more details
+
+[This comment on issue 11779](https://github.com/matrix-org/synapse/issues/11779#issuecomment-1131545970)
+has queries that can be used to check a database for this problem in advance.
+
+</details>
+
+## New signature for the spam checker callback `check_event_for_spam`
+
+The previous signature has been deprecated.
+
+Whereas `check_event_for_spam` callbacks used to return `Union[str, bool]`, they should now return `Union["synapse.module_api.NOT_SPAM", "synapse.module_api.errors.Codes"]`.
+
+This is part of an ongoing refactoring of the SpamChecker API to make it less ambiguous and more powerful.
+
+If your module implements `check_event_for_spam` as follows:
+
+```python
+async def check_event_for_spam(event):
+ if ...:
+ # Event is spam
+ return True
+ # Event is not spam
+ return False
+```
+
+you should rewrite it as follows:
+
+```python
+async def check_event_for_spam(event):
+ if ...:
+ # Event is spam, mark it as forbidden (you may use some more precise error
+ # code if it is useful).
+ return synapse.module_api.errors.Codes.FORBIDDEN
+ # Event is not spam, mark it as such.
+ return synapse.module_api.NOT_SPAM
+```
+
# Upgrading to v1.59.0
## Device name lookup over federation has been disabled by default
diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md
index 21dad0ac..392ae80a 100644
--- a/docs/usage/configuration/config_documentation.md
+++ b/docs/usage/configuration/config_documentation.md
@@ -23,6 +23,14 @@ followed by a letter. Letters have the following meanings:
For example, setting `redaction_retention_period: 5m` would remove redacted
messages from the database after 5 minutes, rather than 5 months.
+In addition, configuration options referring to size use the following suffixes:
+
+* `M` = MiB, or 1,048,576 bytes
+* `K` = KiB, or 1024 bytes
+
+For example, setting `max_avatar_size: 10M` means that Synapse will not accept files larger than 10,485,760 bytes
+for a user avatar.
+
### YAML
The configuration file is a [YAML](https://yaml.org/) file, which means that certain syntax rules
apply if you want your config file to be read properly. A few helpful things to know:
@@ -467,13 +475,13 @@ Sub-options for each listener include:
Valid resource names are:
-* `client`: the client-server API (/_matrix/client), and the synapse admin API (/_synapse/admin). Also implies 'media' and 'static'.
+* `client`: the client-server API (/_matrix/client), and the synapse admin API (/_synapse/admin). Also implies `media` and `static`.
* `consent`: user consent forms (/_matrix/consent). See [here](../../consent_tracking.md) for more.
* `federation`: the server-server API (/_matrix/federation). Also implies `media`, `keys`, `openid`
-* `keys`: the key discovery API (/_matrix/keys).
+* `keys`: the key discovery API (/_matrix/key).
* `media`: the media API (/_matrix/media).
@@ -567,6 +575,18 @@ Example configuration:
dummy_events_threshold: 5
```
---
+Config option: `delete_stale_devices_after`
+
+An optional duration. If set, Synapse will run a daily background task to log out and
+delete any device that hasn't been accessed for more than the specified amount of time.
+
+Defaults to no duration, which means devices are never pruned.
+
+Example configuration:
+```yaml
+delete_stale_devices_after: 1y
+```
+
## Homeserver blocking ##
Useful options for Synapse admins.
@@ -1119,7 +1139,22 @@ Caching can be configured through the following sub-options:
with intermittent connections, at the cost of higher memory usage.
By default, this is zero, which means that sync responses are not cached
at all.
-
+* `cache_autotuning` and its sub-options `max_cache_memory_usage`, `target_cache_memory_usage`, and
+ `min_cache_ttl` work in conjunction with each other to maintain a balance between cache memory
+ usage and cache entry availability. You must be using [jemalloc](https://github.com/matrix-org/synapse#help-synapse-is-slow-and-eats-all-my-ramcpu)
+ to utilize this option, and all three of the options must be specified for this feature to work. This option
+ defaults to off; enable it by providing values for the sub-options listed below. Please note that the feature will not work
+ and may cause unstable behavior (such as excessive emptying of caches or exceptions) if all of the values are not provided.
+ Please see the [Config Conventions](#config-conventions) for information on how to specify memory size and cache expiry
+ durations.
+ * `max_cache_memory_usage` sets a ceiling on how much memory the cache can use before caches begin to be continuously evicted.
+ They will continue to be evicted until the memory usage drops below the `target_cache_memory_usage`, set in
+ the setting below, or until the `min_cache_ttl` is hit. There is no default value for this option.
+ * `target_cache_memory_usage` sets a rough target for the desired memory usage of the caches. There is no default value
+ for this option.
+ * `min_cache_ttl` sets a limit under which newer cache entries are not evicted and is only applied when
+ caches are actively being evicted/`max_cache_memory_usage` has been exceeded. This is to protect hot caches
+ from being emptied while Synapse is evicting due to memory. There is no default value for this option.
Example configuration:
```yaml
@@ -1127,9 +1162,29 @@ caches:
global_factor: 1.0
per_cache_factors:
get_users_who_share_room_with_user: 2.0
- expire_caches: false
sync_response_cache_duration: 2m
+ cache_autotuning:
+ max_cache_memory_usage: 1024M
+ target_cache_memory_usage: 758M
+ min_cache_ttl: 5m
+```
+
+### Reloading cache factors
+
+The cache factors (i.e. `caches.global_factor` and `caches.per_cache_factors`) may be reloaded at any time by sending a
+[`SIGHUP`](https://en.wikipedia.org/wiki/SIGHUP) signal to Synapse using e.g.
+
+```commandline
+kill -HUP [PID_OF_SYNAPSE_PROCESS]
```
+
+If you are running multiple workers, you must individually update the worker
+config file and send this signal to each worker process.
+
+If you're using the [example systemd service](https://github.com/matrix-org/synapse/blob/develop/contrib/systemd/matrix-synapse.service)
+file in Synapse's `contrib` directory, you can send a `SIGHUP` signal by using
+`systemctl reload matrix-synapse`.
+
---
## Database ##
Config options related to database settings.
@@ -1164,7 +1219,7 @@ For more information on using Synapse with Postgres,
see [here](../../postgres.md).
Example SQLite configuration:
-```
+```yaml
database:
name: sqlite3
args:
@@ -1172,7 +1227,7 @@ database:
```
Example Postgres configuration:
-```
+```yaml
database:
name: psycopg2
txn_limit: 10000
@@ -1327,6 +1382,20 @@ This option sets ratelimiting how often invites can be sent in a room or to a
specific user. `per_room` defaults to `per_second: 0.3`, `burst_count: 10` and
`per_user` defaults to `per_second: 0.003`, `burst_count: 5`.
+Client requests that invite user(s) when [creating a
+room](https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom)
+will count against the `rc_invites.per_room` limit, whereas
+client requests to [invite a single user to a
+room](https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidinvite)
+will count against both the `rc_invites.per_user` and `rc_invites.per_room` limits.
+
+Federation requests to invite a user will count against the `rc_invites.per_user`
+limit only, as Synapse presumes ratelimiting by room will be done by the sending server.
+
+The `rc_invites.per_user` limit applies to the *receiver* of the invite, rather than the
+sender, meaning that a `rc_invites.per_user.burst_count` of 5 mandates that a single user
+cannot *receive* more than a burst of 5 invites at a time.
+
Example configuration:
```yaml
rc_invites:
@@ -1390,7 +1459,7 @@ federation_rr_transactions_per_room_per_second: 40
```
---
## Media Store ##
-Config options relating to Synapse media store.
+Config options related to Synapse's media store.
---
Config option: `enable_media_repo`
@@ -1494,6 +1563,39 @@ thumbnail_sizes:
height: 600
method: scale
```
+---
+Config option: `media_retention`
+
+Controls whether local media and entries in the remote media cache
+(media that is downloaded from other homeservers) should be removed
+under certain conditions, typically for the purpose of saving space.
+
+Purging media files will be carried out by the media worker
+(that is, the worker that has the `enable_media_repo` homeserver config
+option set to 'true'). This may be the main process.
+
+The `media_retention.local_media_lifetime` and
+`media_retention.remote_media_lifetime` config options control whether
+media will be purged if it has not been accessed in a given amount of
+time. Note that media is 'accessed' when loaded in a room in a client, or
+otherwise downloaded by a local or remote user. If the media has never
+been accessed, the media's creation time is used instead. Both thumbnails
+and the original media will be removed. If either of these options are unset,
+then media of that type will not be purged.
+
+Local or cached remote media that has been
+[quarantined](../../admin_api/media_admin_api.md#quarantining-media-in-a-room)
+will not be deleted. Similarly, local media that has been marked as
+[protected from quarantine](../../admin_api/media_admin_api.md#protecting-media-from-being-quarantined)
+will not be deleted.
+
+Example configuration:
+```yaml
+media_retention:
+ local_media_lifetime: 90d
+ remote_media_lifetime: 14d
+```
+---
Config option: `url_preview_enabled`
This setting determines whether the preview URL API is enabled.
@@ -1635,10 +1737,10 @@ Defaults to "en".
Example configuration:
```yaml
url_preview_accept_language:
- - en-UK
- - en-US;q=0.9
- - fr;q=0.8
- - *;q=0.7
+ - 'en-UK'
+ - 'en-US;q=0.9'
+ - 'fr;q=0.8'
+ - '*;q=0.7'
```
----
Config option: `oembed`
@@ -2873,6 +2975,9 @@ Use this setting to enable password-based logins.
This setting has the following sub-options:
* `enabled`: Defaults to true.
+ Set to false to disable password authentication.
+ Set to `only_for_reauth` to allow users with existing passwords to use them
+ to log in and reauthenticate, whilst preventing new users from setting passwords.
* `localdb_enabled`: Set to false to disable authentication against the local password
database. This is ignored if `enabled` is false, and is only useful
if you have other `password_providers`. Defaults to true.
@@ -3088,25 +3193,6 @@ Example configuration:
encryption_enabled_by_default_for_room_type: invite
```
---
-Config option: `enable_group_creation`
-
-Set to true to allow non-server-admin users to create groups on this server
-
-Example configuration:
-```yaml
-enable_group_creation: true
-```
----
-Config option: `group_creation_prefix`
-
-If enabled/present, non-server admins can only create groups with local parts
-starting with this prefix.
-
-Example configuration:
-```yaml
-group_creation_prefix: "unofficial_"
-```
----
Config option: `user_directory`
This setting defines options related to the user directory.
@@ -3298,6 +3384,32 @@ room_list_publication_rules:
room_id: "*"
action: allow
```
+
+---
+Config option: `default_power_level_content_override`
+
+The `default_power_level_content_override` option controls the default power
+levels for rooms.
+
+Useful if you know that your users need special permissions in rooms
+that they create (e.g. to send particular types of state events without
+needing an elevated power level). This takes the same shape as the
+`power_level_content_override` parameter in the /createRoom API, but
+is applied before that parameter.
+
+Note that each key provided inside a preset (for example `events` in the example
+below) will overwrite all existing defaults inside that key. So in the example
+below, newly-created private_chat rooms will have no rules for any event types
+except `com.example.foo`.
+
+Example configuration:
+```yaml
+default_power_level_content_override:
+ private_chat: { "events": { "com.example.foo" : 0 } }
+ trusted_private_chat: null
+ public_chat: null
+```
+
---
## Opentracing ##
Configuration options related to Opentracing support.
@@ -3398,7 +3510,7 @@ stream_writers:
typing: worker1
```
---
-Config option: `run_background_task_on`
+Config option: `run_background_tasks_on`
The worker that is used to run background tasks (e.g. cleaning up expired
data). If not provided this defaults to the main process.
diff --git a/docs/welcome_and_overview.md b/docs/welcome_and_overview.md
index aab2d6b4..451759f0 100644
--- a/docs/welcome_and_overview.md
+++ b/docs/welcome_and_overview.md
@@ -7,10 +7,10 @@ team.
## Installing and using Synapse
This documentation covers topics for **installation**, **configuration** and
-**maintainence** of your Synapse process:
+**maintenance** of your Synapse process:
* Learn how to [install](setup/installation.md) and
- [configure](usage/configuration/index.html) your own instance, perhaps with [Single
+ [configure](usage/configuration/config_documentation.md) your own instance, perhaps with [Single
Sign-On](usage/configuration/user_authentication/index.html).
* See how to [upgrade](upgrade.md) between Synapse versions.
@@ -65,7 +65,7 @@ following documentation:
Want to help keep Synapse going but don't know how to code? Synapse is a
[Matrix.org Foundation](https://matrix.org) project. Consider becoming a
-supportor on [Liberapay](https://liberapay.com/matrixdotorg),
+supporter on [Liberapay](https://liberapay.com/matrixdotorg),
[Patreon](https://patreon.com/matrixdotorg) or through
[PayPal](https://paypal.me/matrixdotorg) via a one-time donation.
diff --git a/docs/workers.md b/docs/workers.md
index 553792d2..6969c424 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -1,6 +1,6 @@
# Scaling synapse via workers
-For small instances it recommended to run Synapse in the default monolith mode.
+For small instances it is recommended to run Synapse in the default monolith mode.
For larger instances where performance is a concern it can be helpful to split
out functionality into multiple separate python processes. These processes are
called 'workers', and are (eventually) intended to scale horizontally
@@ -191,9 +191,8 @@ information.
^/_matrix/federation/v1/event_auth/
^/_matrix/federation/v1/exchange_third_party_invite/
^/_matrix/federation/v1/user/devices/
- ^/_matrix/federation/v1/get_groups_publicised$
^/_matrix/key/v2/query
- ^/_matrix/federation/(v1|unstable/org.matrix.msc2946)/hierarchy/
+ ^/_matrix/federation/v1/hierarchy/
# Inbound federation transaction request
^/_matrix/federation/v1/send/
@@ -205,15 +204,14 @@ information.
^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/context/.*$
^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/members$
^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state$
- ^/_matrix/client/(v1|unstable/org.matrix.msc2946)/rooms/.*/hierarchy$
+ ^/_matrix/client/v1/rooms/.*/hierarchy$
+ ^/_matrix/client/unstable/org.matrix.msc2716/rooms/.*/batch_send$
^/_matrix/client/unstable/im.nheko.summary/rooms/.*/summary$
^/_matrix/client/(r0|v3|unstable)/account/3pid$
+ ^/_matrix/client/(r0|v3|unstable)/account/whoami$
^/_matrix/client/(r0|v3|unstable)/devices$
^/_matrix/client/versions$
^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$
- ^/_matrix/client/(r0|v3|unstable)/joined_groups$
- ^/_matrix/client/(r0|v3|unstable)/publicised_groups$
- ^/_matrix/client/(r0|v3|unstable)/publicised_groups/
^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event/
^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms$
^/_matrix/client/(api/v1|r0|v3|unstable)/search$
@@ -237,9 +235,6 @@ information.
^/_matrix/client/(api/v1|r0|v3|unstable)/join/
^/_matrix/client/(api/v1|r0|v3|unstable)/profile/
- # Device requests
- ^/_matrix/client/(r0|v3|unstable)/sendToDevice/
-
# Account data requests
^/_matrix/client/(r0|v3|unstable)/.*/tags
^/_matrix/client/(r0|v3|unstable)/.*/account_data
@@ -251,12 +246,12 @@ information.
# Presence requests
^/_matrix/client/(api/v1|r0|v3|unstable)/presence/
+ # User directory search requests
+ ^/_matrix/client/(r0|v3|unstable)/user_directory/search$
Additionally, the following REST endpoints can be handled for GET requests:
- ^/_matrix/federation/v1/groups/
^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/
- ^/_matrix/client/(r0|v3|unstable)/groups/
Pagination requests can also be handled, but all requests for a given
room must be routed to the same instance. Additionally, care must be taken to
@@ -448,6 +443,14 @@ update_user_directory_from_worker: worker_name
This work cannot be load-balanced; please ensure the main process is restarted
after setting this option in the shared configuration!
+User directory updates allow REST endpoints matching the following regular
+expressions to work:
+
+ ^/_matrix/client/(r0|v3|unstable)/user_directory/search$
+
+The above endpoints can be routed to any worker, though you may choose to route
+them to the chosen user directory worker.
+
This style of configuration supersedes the legacy `synapse.app.user_dir`
worker application type.
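
As a rough illustration (an assumption about how the listed pattern is meant to be
matched, not code from this change), the regular expression above can be checked
against a request path with Python's `re` module:

```python
import re

# One of the patterns listed above, matched against a sample request path.
pattern = re.compile(r"^/_matrix/client/(r0|v3|unstable)/user_directory/search$")

assert pattern.match("/_matrix/client/v3/user_directory/search")
assert not pattern.match("/_matrix/client/v3/user_directory/search/extra")
```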
diff --git a/mypy.ini b/mypy.ini
index ba0de419..fe3e3f9b 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -10,6 +10,7 @@ warn_unreachable = True
warn_unused_ignores = True
local_partial_types = True
no_implicit_optional = True
+disallow_untyped_defs = True
files =
docker/,
@@ -27,9 +28,6 @@ exclude = (?x)
|synapse/storage/databases/__init__.py
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py
- |synapse/storage/databases/main/event_federation.py
- |synapse/storage/databases/main/push_rule.py
- |synapse/storage/databases/main/roommember.py
|synapse/storage/schema/
|tests/api/test_auth.py
@@ -43,16 +41,11 @@ exclude = (?x)
|tests/events/test_utils.py
|tests/federation/test_federation_catch_up.py
|tests/federation/test_federation_sender.py
- |tests/federation/test_federation_server.py
|tests/federation/transport/test_knocking.py
- |tests/federation/transport/test_server.py
|tests/handlers/test_typing.py
|tests/http/federation/test_matrix_federation_agent.py
|tests/http/federation/test_srv_resolver.py
- |tests/http/test_fedclient.py
|tests/http/test_proxyagent.py
- |tests/http/test_servlet.py
- |tests/http/test_site.py
|tests/logging/__init__.py
|tests/logging/test_terse_json.py
|tests/module_api/test_api.py
@@ -61,12 +54,9 @@ exclude = (?x)
|tests/push/test_push_rule_evaluator.py
|tests/rest/client/test_transactions.py
|tests/rest/media/v1/test_media_storage.py
- |tests/scripts/test_new_matrix_user.py
|tests/server.py
|tests/server_notices/test_resource_limits_server_notices.py
|tests/state/test_v2.py
- |tests/storage/test_base.py
- |tests/storage/test_roommember.py
|tests/test_metrics.py
|tests/test_server.py
|tests/test_state.py
@@ -89,129 +79,37 @@ exclude = (?x)
|tests/utils.py
)$
-[mypy-synapse._scripts.*]
-disallow_untyped_defs = True
-
-[mypy-synapse.api.*]
-disallow_untyped_defs = True
-
-[mypy-synapse.app.*]
-disallow_untyped_defs = True
-
-[mypy-synapse.appservice.*]
-disallow_untyped_defs = True
-
-[mypy-synapse.config.*]
-disallow_untyped_defs = True
-
-[mypy-synapse.crypto.*]
-disallow_untyped_defs = True
-
-[mypy-synapse.event_auth]
-disallow_untyped_defs = True
-
-[mypy-synapse.events.*]
-disallow_untyped_defs = True
-
-[mypy-synapse.federation.*]
-disallow_untyped_defs = True
-
[mypy-synapse.federation.transport.client]
disallow_untyped_defs = False
-[mypy-synapse.handlers.*]
-disallow_untyped_defs = True
+[mypy-synapse.http.client]
+disallow_untyped_defs = False
-[mypy-synapse.http.server]
-disallow_untyped_defs = True
+[mypy-synapse.http.matrixfederationclient]
+disallow_untyped_defs = False
-[mypy-synapse.logging.context]
-disallow_untyped_defs = True
+[mypy-synapse.logging.opentracing]
+disallow_untyped_defs = False
-[mypy-synapse.metrics.*]
-disallow_untyped_defs = True
+[mypy-synapse.logging.scopecontextmanager]
+disallow_untyped_defs = False
[mypy-synapse.metrics._reactor_metrics]
+disallow_untyped_defs = False
# This module imports select.epoll. That exists on Linux, but doesn't on macOS.
# See https://github.com/matrix-org/synapse/pull/11771.
warn_unused_ignores = False
-[mypy-synapse.module_api.*]
-disallow_untyped_defs = True
-
-[mypy-synapse.notifier]
-disallow_untyped_defs = True
-
-[mypy-synapse.push.*]
-disallow_untyped_defs = True
-
-[mypy-synapse.replication.*]
-disallow_untyped_defs = True
-
-[mypy-synapse.rest.*]
-disallow_untyped_defs = True
-
-[mypy-synapse.server_notices.*]
-disallow_untyped_defs = True
-
-[mypy-synapse.state.*]
-disallow_untyped_defs = True
-
-[mypy-synapse.storage.databases.main.account_data]
-disallow_untyped_defs = True
-
-[mypy-synapse.storage.databases.main.client_ips]
-disallow_untyped_defs = True
-
-[mypy-synapse.storage.databases.main.directory]
-disallow_untyped_defs = True
-
-[mypy-synapse.storage.databases.main.e2e_room_keys]
-disallow_untyped_defs = True
-
-[mypy-synapse.storage.databases.main.end_to_end_keys]
-disallow_untyped_defs = True
-
-[mypy-synapse.storage.databases.main.event_push_actions]
-disallow_untyped_defs = True
-
-[mypy-synapse.storage.databases.main.events_bg_updates]
-disallow_untyped_defs = True
-
-[mypy-synapse.storage.databases.main.events_worker]
-disallow_untyped_defs = True
-
-[mypy-synapse.storage.databases.main.room]
-disallow_untyped_defs = True
-
-[mypy-synapse.storage.databases.main.room_batch]
-disallow_untyped_defs = True
-
-[mypy-synapse.storage.databases.main.profile]
-disallow_untyped_defs = True
-
-[mypy-synapse.storage.databases.main.stats]
-disallow_untyped_defs = True
-
-[mypy-synapse.storage.databases.main.state_deltas]
-disallow_untyped_defs = True
-
-[mypy-synapse.storage.databases.main.transactions]
-disallow_untyped_defs = True
-
-[mypy-synapse.storage.databases.main.user_erasure_store]
-disallow_untyped_defs = True
-
-[mypy-synapse.storage.util.*]
-disallow_untyped_defs = True
+[mypy-synapse.util.caches.treecache]
+disallow_untyped_defs = False
-[mypy-synapse.streams.*]
-disallow_untyped_defs = True
+[mypy-synapse.server]
+disallow_untyped_defs = False
-[mypy-synapse.util.*]
-disallow_untyped_defs = True
+[mypy-synapse.storage.database]
+disallow_untyped_defs = False
-[mypy-synapse.util.caches.treecache]
+[mypy-tests.*]
disallow_untyped_defs = False
[mypy-tests.handlers.test_user_directory]
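
For context, `disallow_untyped_defs = True` makes mypy reject any function definition
without type annotations; the per-module `False` entries above opt the listed modules
back out. A small illustrative example (not from this diff):

```python
# With disallow_untyped_defs = True, mypy reports an error for this function
# because its parameters and return type are unannotated:
def scale(factor, value):
    return factor * value


# This fully annotated version passes the check:
def scale_typed(factor: float, value: float) -> float:
    return factor * value
```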
diff --git a/poetry.lock b/poetry.lock
index 49a912a5..7c561e31 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -813,7 +813,7 @@ python-versions = ">=3.5"
[[package]]
name = "pyjwt"
-version = "2.3.0"
+version = "2.4.0"
description = "JSON Web Token implementation in Python"
category = "main"
optional = false
@@ -1355,7 +1355,7 @@ python-versions = "*"
[[package]]
name = "types-jsonschema"
-version = "4.4.1"
+version = "4.4.6"
description = "Typing stubs for jsonschema"
category = "dev"
optional = false
@@ -1563,7 +1563,7 @@ url_preview = ["lxml"]
[metadata]
lock-version = "1.1"
python-versions = "^3.7.1"
-content-hash = "d39d5ac5d51c014581186b7691999b861058b569084c525523baf70b77f292b1"
+content-hash = "539e5326f401472d1ffc8325d53d72e544cd70156b3f43f32f1285c4c131f831"
[metadata.files]
attrs = [
@@ -2264,8 +2264,8 @@ pygments = [
{file = "Pygments-2.11.2.tar.gz", hash = "sha256:4e426f72023d88d03b2fa258de560726ce890ff3b630f88c21cbb8b2503b8c6a"},
]
pyjwt = [
- {file = "PyJWT-2.3.0-py3-none-any.whl", hash = "sha256:e0c4bb8d9f0af0c7f5b1ec4c5036309617d03d56932877f2f7a0beeb5318322f"},
- {file = "PyJWT-2.3.0.tar.gz", hash = "sha256:b888b4d56f06f6dcd777210c334e69c737be74755d3e5e9ee3fe67dc18a0ee41"},
+ {file = "PyJWT-2.4.0-py3-none-any.whl", hash = "sha256:72d1d253f32dbd4f5c88eaf1fdc62f3a19f676ccbadb9dbc5d07e951b2b26daf"},
+ {file = "PyJWT-2.4.0.tar.gz", hash = "sha256:d42908208c699b3b973cbeb01a969ba6a96c821eefb1c5bfe4c390c01d67abba"},
]
pymacaroons = [
{file = "pymacaroons-0.13.0-py2.py3-none-any.whl", hash = "sha256:3e14dff6a262fdbf1a15e769ce635a8aea72e6f8f91e408f9a97166c53b91907"},
@@ -2618,8 +2618,8 @@ types-ipaddress = [
{file = "types_ipaddress-1.0.8-py3-none-any.whl", hash = "sha256:4933b74da157ba877b1a705d64f6fa7742745e9ffd65e51011f370c11ebedb55"},
]
types-jsonschema = [
- {file = "types-jsonschema-4.4.1.tar.gz", hash = "sha256:bd68b75217ebbb33b0242db10047581dad3b061a963a46ee80d4a9044080663e"},
- {file = "types_jsonschema-4.4.1-py3-none-any.whl", hash = "sha256:ab3ecfdc912d6091cc82f4b7556cfbf1a7cbabc26da0ceaa1cbbc232d1d09971"},
+ {file = "types-jsonschema-4.4.6.tar.gz", hash = "sha256:7f2a804618756768c7c0616f8c794b61fcfe3077c7ee1ad47dcf01c5e5f692bb"},
+ {file = "types_jsonschema-4.4.6-py3-none-any.whl", hash = "sha256:1db9031ca49a8444d01bd2ce8cf2f89318382b04610953b108321e6f8fb03390"},
]
types-opentracing = [
{file = "types-opentracing-2.4.7.tar.gz", hash = "sha256:be60e9618355aa892571ace002e6b353702538b1c0dc4fbc1c921219d6658830"},
diff --git a/pyproject.toml b/pyproject.toml
index 5a5a2eab..8b21bdc8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -54,7 +54,7 @@ skip_gitignore = true
[tool.poetry]
name = "matrix-synapse"
-version = "1.59.1"
+version = "1.61.0"
description = "Homeserver for the Matrix decentralised comms protocol"
authors = ["Matrix.org Team and Contributors <packages@matrix.org>"]
license = "Apache-2.0"
@@ -113,7 +113,6 @@ unpaddedbase64 = ">=2.1.0"
canonicaljson = ">=1.4.0"
# we use the type definitions added in signedjson 1.1.
signedjson = ">=1.1.0"
-PyNaCl = ">=1.2.1"
# validating SSL certs for IP addresses requires service_identity 18.1.
service-identity = ">=18.1.0"
# Twisted 18.9 introduces some logger improvements that the structured
diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh
index 190df690..3c472c57 100755
--- a/scripts-dev/complement.sh
+++ b/scripts-dev/complement.sh
@@ -45,6 +45,8 @@ docker build -t matrixdotorg/synapse -f "docker/Dockerfile" .
extra_test_args=()
+test_tags="synapse_blacklist,msc2716,msc3030,msc3787"
+
# If we're using workers, modify the docker files slightly.
if [[ -n "$WORKERS" ]]; then
# Build the workers docker image (from the base Synapse image).
@@ -65,6 +67,10 @@ if [[ -n "$WORKERS" ]]; then
else
export COMPLEMENT_BASE_IMAGE=complement-synapse
COMPLEMENT_DOCKERFILE=Dockerfile
+
+ # We only test faster room joins on monoliths, because they are purposefully
+ # being developed without worker support to start with.
+ test_tags="$test_tags,faster_joins"
fi
# Build the Complement image from the Synapse image we just built.
@@ -73,4 +79,5 @@ docker build -t $COMPLEMENT_BASE_IMAGE -f "docker/complement/$COMPLEMENT_DOCKERF
# Run the tests!
echo "Images built; running complement"
cd "$COMPLEMENT_DIR"
-go test -v -tags synapse_blacklist,msc2716,msc3030,faster_joins -count=1 "${extra_test_args[@]}" "$@" ./tests/...
+
+go test -v -tags $test_tags -count=1 "${extra_test_args[@]}" "$@" ./tests/...
diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py
index c7758652..d08517a9 100644
--- a/scripts-dev/mypy_synapse_plugin.py
+++ b/scripts-dev/mypy_synapse_plugin.py
@@ -21,7 +21,7 @@ from typing import Callable, Optional, Type
from mypy.nodes import ARG_NAMED_OPT
from mypy.plugin import MethodSigContext, Plugin
from mypy.typeops import bind_self
-from mypy.types import CallableType, NoneType
+from mypy.types import CallableType, NoneType, UnionType
class SynapsePlugin(Plugin):
@@ -72,13 +72,20 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
# Third, we add an optional "on_invalidate" argument.
#
- # This is a callable which accepts no input and returns nothing.
- calltyp = CallableType(
- arg_types=[],
- arg_kinds=[],
- arg_names=[],
- ret_type=NoneType(),
- fallback=ctx.api.named_generic_type("builtins.function", []),
+ # This is either
+ # - a callable which accepts no input and returns nothing, or
+ # - None.
+ calltyp = UnionType(
+ [
+ NoneType(),
+ CallableType(
+ arg_types=[],
+ arg_kinds=[],
+ arg_names=[],
+ ret_type=NoneType(),
+ fallback=ctx.api.named_generic_type("builtins.function", []),
+ ),
+ ]
)
arg_types.append(calltyp)
@@ -95,7 +102,7 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
def plugin(version: str) -> Type[SynapsePlugin]:
- # This is the entry point of the plugin, and let's us deal with the fact
+ # This is the entry point of the plugin, and lets us deal with the fact
# that the mypy plugin interface is *not* stable by looking at the version
# string.
#
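
In effect, the plugin now types the injected `on_invalidate` argument of
`@cached`-decorated methods as `Optional[Callable[[], None]]` instead of a bare
callable. A hedged sketch of the signature callers would see (illustrative names
only, not generated code):

```python
from typing import Callable, Optional

# Illustrative signature only: the real signature is synthesised by the plugin
# for @cached-decorated methods; the function and argument names are placeholders.
def get_thing(key: str, on_invalidate: Optional[Callable[[], None]] = None) -> str:
    ...

get_thing("some-key")                      # omit the callback entirely
get_thing("some-key", on_invalidate=None)  # now also accepted by the type checker
```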
diff --git a/synapse/_scripts/hash_password.py b/synapse/_scripts/hash_password.py
index 3aa29de5..3bed367b 100755
--- a/synapse/_scripts/hash_password.py
+++ b/synapse/_scripts/hash_password.py
@@ -46,14 +46,14 @@ def main() -> None:
"Path to server config file. "
"Used to read in bcrypt_rounds and password_pepper."
),
+ required=True,
)
args = parser.parse_args()
- if "config" in args and args.config:
- config = yaml.safe_load(args.config)
- bcrypt_rounds = config.get("bcrypt_rounds", bcrypt_rounds)
- password_config = config.get("password_config", None) or {}
- password_pepper = password_config.get("pepper", password_pepper)
+ config = yaml.safe_load(args.config)
+ bcrypt_rounds = config.get("bcrypt_rounds", bcrypt_rounds)
+ password_config = config.get("password_config", None) or {}
+ password_pepper = password_config.get("pepper", password_pepper)
password = args.password
if not password:
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 12ff79f6..361b51d2 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -62,7 +62,7 @@ from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyBackground
from synapse.storage.databases.main.events_bg_updates import (
EventsBackgroundUpdatesStore,
)
-from synapse.storage.databases.main.group_server import GroupServerWorkerStore
+from synapse.storage.databases.main.group_server import GroupServerStore
from synapse.storage.databases.main.media_repository import (
MediaRepositoryBackgroundUpdateStore,
)
@@ -102,14 +102,6 @@ BOOLEAN_COLUMNS = {
"devices": ["hidden"],
"device_lists_outbound_pokes": ["sent"],
"users_who_share_rooms": ["share_private"],
- "groups": ["is_public"],
- "group_rooms": ["is_public"],
- "group_users": ["is_public", "is_admin"],
- "group_summary_rooms": ["is_public"],
- "group_room_categories": ["is_public"],
- "group_summary_users": ["is_public"],
- "group_roles": ["is_public"],
- "local_group_membership": ["is_publicised", "is_admin"],
"e2e_room_keys": ["is_verified"],
"account_validity": ["email_sent"],
"redactions": ["have_censored"],
@@ -175,6 +167,22 @@ IGNORED_TABLES = {
"ui_auth_sessions",
"ui_auth_sessions_credentials",
"ui_auth_sessions_ips",
+ # Groups/communities is no longer supported.
+ "group_attestations_remote",
+ "group_attestations_renewals",
+ "group_invites",
+ "group_roles",
+ "group_room_categories",
+ "group_rooms",
+ "group_summary_roles",
+ "group_summary_room_categories",
+ "group_summary_rooms",
+ "group_summary_users",
+ "group_users",
+ "groups",
+ "local_group_membership",
+ "local_group_updates",
+ "remote_profile_cache",
}
@@ -211,7 +219,7 @@ class Store(
PushRuleStore,
PusherWorkerStore,
PresenceBackgroundUpdateStore,
- GroupServerWorkerStore,
+ GroupServerStore,
):
def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 93175066..5a410f80 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -29,12 +29,11 @@ from synapse.api.errors import (
MissingClientTokenError,
)
from synapse.appservice import ApplicationService
-from synapse.events import EventBase
from synapse.http import get_request_user_agent
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import active_span, force_tracing, start_active_span
from synapse.storage.databases.main.registration import TokenLookupResult
-from synapse.types import Requester, StateMap, UserID, create_requester
+from synapse.types import Requester, UserID, create_requester
from synapse.util.caches.lrucache import LruCache
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
@@ -61,8 +60,8 @@ class Auth:
self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
- self.state = hs.get_state_handler()
self._account_validity_handler = hs.get_account_validity_handler()
+ self._storage_controllers = hs.get_storage_controllers()
self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache(
10000, "token_cache"
@@ -79,9 +78,8 @@ class Auth:
self,
room_id: str,
user_id: str,
- current_state: Optional[StateMap[EventBase]] = None,
allow_departed_users: bool = False,
- ) -> EventBase:
+ ) -> Tuple[str, Optional[str]]:
"""Check if the user is in the room, or was at some point.
Args:
room_id: The room to check.
@@ -99,29 +97,28 @@ class Auth:
Raises:
AuthError if the user is/was not in the room.
Returns:
- Membership event for the user if the user was in the
- room. This will be the join event if they are currently joined to
- the room. This will be the leave event if they have left the room.
+ The current membership of the user in the room and the
+ membership event ID of the user.
"""
- if current_state:
- member = current_state.get((EventTypes.Member, user_id), None)
- else:
- member = await self.state.get_current_state(
- room_id=room_id, event_type=EventTypes.Member, state_key=user_id
- )
- if member:
- membership = member.membership
+ (
+ membership,
+ member_event_id,
+ ) = await self.store.get_local_current_membership_for_user_in_room(
+ user_id=user_id,
+ room_id=room_id,
+ )
+ if membership:
if membership == Membership.JOIN:
- return member
+ return membership, member_event_id
# XXX this looks totally bogus. Why do we not allow users who have been banned,
# or those who were members previously and have been re-invited?
if allow_departed_users and membership == Membership.LEAVE:
forgot = await self.store.did_forget(user_id, room_id)
if not forgot:
- return member
+ return membership, member_event_id
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
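
Callers therefore now unpack a (membership, event ID) pair instead of receiving the
membership event itself; a hedged usage sketch (the `auth` variable and wrapper
function are assumptions for illustration):

```python
# Illustrative caller, assuming `auth` is the Auth instance from this module.
async def ensure_member(auth, room_id: str, user_id: str) -> None:
    membership, member_event_id = await auth.check_user_in_room(
        room_id, user_id, allow_departed_users=True
    )
    # membership is e.g. "join" or "leave"; member_event_id may be None.
    print(membership, member_event_id)
```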
@@ -602,8 +599,11 @@ class Auth:
# We currently require the user is a "moderator" in the room. We do this
# by checking if they would (theoretically) be able to change the
# m.room.canonical_alias events
- power_level_event = await self.state.get_current_state(
- room_id, EventTypes.PowerLevels, ""
+
+ power_level_event = (
+ await self._storage_controllers.state.get_current_state_event(
+ room_id, EventTypes.PowerLevels, ""
+ )
)
auth_events = {}
@@ -693,12 +693,11 @@ class Auth:
# * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room
# else it will throw.
- member_event = await self.check_user_in_room(
+ return await self.check_user_in_room(
room_id, user_id, allow_departed_users=allow_departed_users
)
- return member_event.membership, member_event.event_id
except AuthError:
- visibility = await self.state.get_current_state(
+ visibility = await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 0ccd4c95..e1d31cab 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -31,11 +31,6 @@ MAX_ALIAS_LENGTH = 255
# the maximum length for a user id is 255 characters
MAX_USERID_LENGTH = 255
-# The maximum length for a group id is 255 characters
-MAX_GROUPID_LENGTH = 255
-MAX_GROUP_CATEGORYID_LENGTH = 255
-MAX_GROUP_ROLEID_LENGTH = 255
-
class Membership:
@@ -65,6 +60,8 @@ class JoinRules:
PRIVATE: Final = "private"
# As defined for MSC3083.
RESTRICTED: Final = "restricted"
+ # As defined for MSC3787.
+ KNOCK_RESTRICTED: Final = "knock_restricted"
class RestrictedJoinRuleTypes:
@@ -98,7 +95,6 @@ class EventTypes:
Aliases: Final = "m.room.aliases"
Redaction: Final = "m.room.redaction"
ThirdPartyInvite: Final = "m.room.third_party_invite"
- RelatedGroups: Final = "m.room.related_groups"
RoomHistoryVisibility: Final = "m.room.history_visibility"
CanonicalAlias: Final = "m.room.canonical_alias"
@@ -140,7 +136,13 @@ class DeviceKeyAlgorithms:
class EduTypes:
- Presence: Final = "m.presence"
+ PRESENCE: Final = "m.presence"
+ TYPING: Final = "m.typing"
+ RECEIPT: Final = "m.receipt"
+ DEVICE_LIST_UPDATE: Final = "m.device_list_update"
+ SIGNING_KEY_UPDATE: Final = "m.signing_key_update"
+ UNSTABLE_SIGNING_KEY_UPDATE: Final = "org.matrix.signing_key_update"
+ DIRECT_TO_DEVICE: Final = "m.direct_to_device"
class RejectedReason:
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index cb3b7323..cc7b7854 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -17,6 +17,7 @@
import logging
import typing
+from enum import Enum
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Union
@@ -30,7 +31,11 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class Codes:
+class Codes(str, Enum):
+ """
+ All known error codes, as an enum of strings.
+ """
+
UNRECOGNIZED = "M_UNRECOGNIZED"
UNAUTHORIZED = "M_UNAUTHORIZED"
FORBIDDEN = "M_FORBIDDEN"
@@ -74,6 +79,13 @@ class Codes:
WEAK_PASSWORD = "M_WEAK_PASSWORD"
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
USER_DEACTIVATED = "M_USER_DEACTIVATED"
+
+ # The account has been suspended on the server.
+ # In contrast to `USER_DEACTIVATED`, this is a reversible measure
+ # that can possibly be appealed and reverted.
+ # Part of MSC3823.
+ USER_ACCOUNT_SUSPENDED = "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
+
BAD_ALIAS = "M_BAD_ALIAS"
# For restricted join rules.
UNABLE_AUTHORISE_JOIN = "M_UNABLE_TO_AUTHORISE_JOIN"
@@ -134,7 +146,13 @@ class SynapseError(CodeMessageException):
errcode: Matrix error code e.g 'M_FORBIDDEN'
"""
- def __init__(self, code: int, msg: str, errcode: str = Codes.UNKNOWN):
+ def __init__(
+ self,
+ code: int,
+ msg: str,
+ errcode: str = Codes.UNKNOWN,
+ additional_fields: Optional[Dict] = None,
+ ):
"""Constructs a synapse error.
Args:
@@ -144,9 +162,13 @@ class SynapseError(CodeMessageException):
"""
super().__init__(code, msg)
self.errcode = errcode
+ if additional_fields is None:
+ self._additional_fields: Dict = {}
+ else:
+ self._additional_fields = dict(additional_fields)
def error_dict(self) -> "JsonDict":
- return cs_error(self.msg, self.errcode)
+ return cs_error(self.msg, self.errcode, **self._additional_fields)
class InvalidAPICallError(SynapseError):
@@ -171,14 +193,7 @@ class ProxiedRequestError(SynapseError):
errcode: str = Codes.UNKNOWN,
additional_fields: Optional[Dict] = None,
):
- super().__init__(code, msg, errcode)
- if additional_fields is None:
- self._additional_fields: Dict = {}
- else:
- self._additional_fields = dict(additional_fields)
-
- def error_dict(self) -> "JsonDict":
- return cs_error(self.msg, self.errcode, **self._additional_fields)
+ super().__init__(code, msg, errcode, additional_fields)
class ConsentNotGivenError(SynapseError):
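
With `additional_fields` hoisted into `SynapseError`, any Synapse error can now carry
extra keys in its JSON body, not just `ProxiedRequestError`. A minimal sketch of the
intended effect (the extra field name is illustrative, not defined by the spec or this
change):

```python
# Sketch only: shows how additional_fields end up in the serialised error body.
from synapse.api.errors import Codes, SynapseError

err = SynapseError(
    403,
    "Account suspended",
    errcode=Codes.USER_ACCOUNT_SUSPENDED,
    additional_fields={"org.example.reason": "abuse report"},  # illustrative key
)

# error_dict() merges the extra fields into the usual errcode/error payload.
body = err.error_dict()
```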
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 4a808e33..b0071475 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -19,6 +19,7 @@ from typing import (
TYPE_CHECKING,
Awaitable,
Callable,
+ Collection,
Dict,
Iterable,
List,
@@ -32,7 +33,7 @@ from typing import (
import jsonschema
from jsonschema import FormatChecker
-from synapse.api.constants import EventContentFields
+from synapse.api.constants import EduTypes, EventContentFields
from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
from synapse.events import EventBase
@@ -346,7 +347,7 @@ class Filter:
user_id = event.user_id
field_matchers = {
"senders": lambda v: user_id == v,
- "types": lambda v: "m.presence" == v,
+ "types": lambda v: EduTypes.PRESENCE == v,
}
return self._check_fields(field_matchers)
else:
@@ -444,9 +445,9 @@ class Filter:
return room_ids
async def _check_event_relations(
- self, events: Iterable[FilterEvent]
+ self, events: Collection[FilterEvent]
) -> List[FilterEvent]:
- # The event IDs to check, mypy doesn't understand the ifinstance check.
+ # The event IDs to check, mypy doesn't understand the isinstance check.
event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined]
event_ids_to_keep = set(
await self._store.events_have_relations(
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index a747a408..3f85d61b 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -81,6 +81,9 @@ class RoomVersion:
msc2716_historical: bool
# MSC2716: Adds support for redacting "insertion", "chunk", and "marker" events
msc2716_redactions: bool
+ # MSC3787: Adds support for a `knock_restricted` join rule, mixing concepts of
+ # knocks and restricted join rules into the same join condition.
+ msc3787_knock_restricted_join_rule: bool
class RoomVersions:
@@ -99,6 +102,7 @@ class RoomVersions:
msc2403_knocking=False,
msc2716_historical=False,
msc2716_redactions=False,
+ msc3787_knock_restricted_join_rule=False,
)
V2 = RoomVersion(
"2",
@@ -115,6 +119,7 @@ class RoomVersions:
msc2403_knocking=False,
msc2716_historical=False,
msc2716_redactions=False,
+ msc3787_knock_restricted_join_rule=False,
)
V3 = RoomVersion(
"3",
@@ -131,6 +136,7 @@ class RoomVersions:
msc2403_knocking=False,
msc2716_historical=False,
msc2716_redactions=False,
+ msc3787_knock_restricted_join_rule=False,
)
V4 = RoomVersion(
"4",
@@ -147,6 +153,7 @@ class RoomVersions:
msc2403_knocking=False,
msc2716_historical=False,
msc2716_redactions=False,
+ msc3787_knock_restricted_join_rule=False,
)
V5 = RoomVersion(
"5",
@@ -163,6 +170,7 @@ class RoomVersions:
msc2403_knocking=False,
msc2716_historical=False,
msc2716_redactions=False,
+ msc3787_knock_restricted_join_rule=False,
)
V6 = RoomVersion(
"6",
@@ -179,6 +187,7 @@ class RoomVersions:
msc2403_knocking=False,
msc2716_historical=False,
msc2716_redactions=False,
+ msc3787_knock_restricted_join_rule=False,
)
MSC2176 = RoomVersion(
"org.matrix.msc2176",
@@ -195,6 +204,7 @@ class RoomVersions:
msc2403_knocking=False,
msc2716_historical=False,
msc2716_redactions=False,
+ msc3787_knock_restricted_join_rule=False,
)
V7 = RoomVersion(
"7",
@@ -211,6 +221,7 @@ class RoomVersions:
msc2403_knocking=True,
msc2716_historical=False,
msc2716_redactions=False,
+ msc3787_knock_restricted_join_rule=False,
)
V8 = RoomVersion(
"8",
@@ -227,6 +238,7 @@ class RoomVersions:
msc2403_knocking=True,
msc2716_historical=False,
msc2716_redactions=False,
+ msc3787_knock_restricted_join_rule=False,
)
V9 = RoomVersion(
"9",
@@ -243,6 +255,7 @@ class RoomVersions:
msc2403_knocking=True,
msc2716_historical=False,
msc2716_redactions=False,
+ msc3787_knock_restricted_join_rule=False,
)
MSC2716v3 = RoomVersion(
"org.matrix.msc2716v3",
@@ -259,6 +272,24 @@ class RoomVersions:
msc2403_knocking=True,
msc2716_historical=True,
msc2716_redactions=True,
+ msc3787_knock_restricted_join_rule=False,
+ )
+ MSC3787 = RoomVersion(
+ "org.matrix.msc3787",
+ RoomDisposition.UNSTABLE,
+ EventFormatVersions.V3,
+ StateResolutionVersions.V2,
+ enforce_key_validity=True,
+ special_case_aliases_auth=False,
+ strict_canonicaljson=True,
+ limit_notifications_power_levels=True,
+ msc2176_redaction_rules=False,
+ msc3083_join_rules=True,
+ msc3375_redaction_rules=True,
+ msc2403_knocking=True,
+ msc2716_historical=False,
+ msc2716_redactions=False,
+ msc3787_knock_restricted_join_rule=True,
)
@@ -276,6 +307,7 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
RoomVersions.V8,
RoomVersions.V9,
RoomVersions.MSC2716v3,
+ RoomVersions.MSC3787,
)
}
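
A brief sketch (assumed usage, not taken from this change) of how the new capability
flag can gate acceptance of the `knock_restricted` join rule:

```python
# Assumed usage sketch: gate the new join rule on the room version flag.
from synapse.api.constants import JoinRules
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS

room_version = KNOWN_ROOM_VERSIONS["org.matrix.msc3787"]

def join_rule_allowed(join_rule: str) -> bool:
    if join_rule == JoinRules.KNOCK_RESTRICTED:
        return room_version.msc3787_knock_restricted_join_rule
    return True

assert join_rule_allowed(JoinRules.KNOCK_RESTRICTED)
```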
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 3623c172..a3446ac6 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -49,9 +49,12 @@ from twisted.logger import LoggingFile, LogLevel
from twisted.protocols.tls import TLSMemoryBIOFactory
from twisted.python.threadpool import ThreadPool
+import synapse.util.caches
from synapse.api.constants import MAX_PDU_SIZE
from synapse.app import check_bind_error
from synapse.app.phone_stats_home import start_phone_stats_home
+from synapse.config import ConfigError
+from synapse.config._base import format_config_error
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import ManholeConfig
from synapse.crypto import context_factory
@@ -432,6 +435,10 @@ async def start(hs: "HomeServer") -> None:
signal.signal(signal.SIGHUP, run_sighup)
register_sighup(refresh_certificate, hs)
+ register_sighup(reload_cache_config, hs.config)
+
+ # Apply the cache config.
+ hs.config.caches.resize_all_caches()
# Load the certificate from disk.
refresh_certificate(hs)
@@ -486,6 +493,43 @@ async def start(hs: "HomeServer") -> None:
atexit.register(gc.freeze)
+def reload_cache_config(config: HomeServerConfig) -> None:
+ """Reload cache config from disk and immediately apply it.resize caches accordingly.
+
+ If the config is invalid, a `ConfigError` is logged and no changes are made.
+
+ Otherwise, this:
+ - replaces the `caches` section on the given `config` object, and
+ - resizes all caches according to the new cache factors.
+
+ Note that the following cache config keys are read, but not applied:
+ - event_cache_size: used to set a max_size and _original_max_size on
+ EventsWorkerStore._get_event_cache when it is created. We'd have to update
+ the _original_max_size (and maybe
+ - sync_response_cache_duration: would have to update the timeout_sec attribute on
+ HomeServer -> SyncHandler -> ResponseCache.
+ - track_memory_usage. This affects synapse.util.caches.TRACK_MEMORY_USAGE which
+ influences Synapse's self-reported metrics.
+
+ Also, the HTTPConnectionPool in SimpleHTTPClient sets its maxPersistentPerHost
+ parameter based on the global_factor. This won't be applied on a config reload.
+ """
+ try:
+ previous_cache_config = config.reload_config_section("caches")
+ except ConfigError as e:
+ logger.warning("Failed to reload cache config")
+ for f in format_config_error(e):
+ logger.warning(f)
+ else:
+ logger.debug(
+ "New cache config. Was:\n %s\nNow:\n",
+ previous_cache_config.__dict__,
+ config.caches.__dict__,
+ )
+ synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage
+ config.caches.resize_all_caches()
+
+
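+
+Since the handler is registered for SIGHUP, cache factors can be reapplied at runtime
+by signalling the main process; a hedged sketch (the pid-file path is an assumption
+that depends on your deployment):
+
+```python
+import os
+import signal
+
+# Illustrative only: send SIGHUP to a running Synapse to trigger the
+# registered handlers, including reload_cache_config above.
+# The pid-file location is an assumption, not a Synapse default.
+with open("/var/run/matrix-synapse.pid") as f:
+    pid = int(f.read().strip())
+
+os.kill(pid, signal.SIGHUP)
+```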
def setup_sentry(hs: "HomeServer") -> None:
"""Enable sentry integration, if enabled in configuration"""
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 2a4c2e59..6fedf681 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -37,7 +37,6 @@ from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.filtering import SlavedFilteringStore
-from synapse.replication.slave.storage.groups import SlavedGroupServerStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
@@ -55,7 +54,6 @@ class AdminCmdSlavedStore(
SlavedApplicationServiceStore,
SlavedRegistrationStore,
SlavedFilteringStore,
- SlavedGroupServerStore,
SlavedDeviceInboxStore,
SlavedDeviceStore,
SlavedPushRuleStore,
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 2a9480a5..89f8998f 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -58,7 +58,6 @@ from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.filtering import SlavedFilteringStore
-from synapse.replication.slave.storage.groups import SlavedGroupServerStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.profile import SlavedProfileStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
@@ -69,7 +68,6 @@ from synapse.rest.admin import register_servlets_for_media_repo
from synapse.rest.client import (
account_data,
events,
- groups,
initial_sync,
login,
presence,
@@ -78,6 +76,7 @@ from synapse.rest.client import (
read_marker,
receipts,
room,
+ room_batch,
room_keys,
sendtodevice,
sync,
@@ -87,7 +86,7 @@ from synapse.rest.client import (
voip,
)
from synapse.rest.client._base import client_patterns
-from synapse.rest.client.account import ThreepidRestServlet
+from synapse.rest.client.account import ThreepidRestServlet, WhoamiRestServlet
from synapse.rest.client.devices import DevicesRestServlet
from synapse.rest.client.keys import (
KeyChangesServlet,
@@ -233,7 +232,6 @@ class GenericWorkerSlavedStore(
SlavedDeviceStore,
SlavedReceiptsStore,
SlavedPushRuleStore,
- SlavedGroupServerStore,
SlavedAccountDataStore,
SlavedPusherStore,
CensorEventsStore,
@@ -289,6 +287,7 @@ class GenericWorkerServer(HomeServer):
RegistrationTokenValidityRestServlet(self).register(resource)
login.register_servlets(self, resource)
ThreepidRestServlet(self).register(resource)
+ WhoamiRestServlet(self).register(resource)
DevicesRestServlet(self).register(resource)
# Read-only
@@ -308,6 +307,7 @@ class GenericWorkerServer(HomeServer):
room.register_servlets(self, resource, is_worker=True)
room.register_deprecated_servlets(self, resource)
initial_sync.register_servlets(self, resource)
+ room_batch.register_servlets(self, resource)
room_keys.register_servlets(self, resource)
tags.register_servlets(self, resource)
account_data.register_servlets(self, resource)
@@ -320,9 +320,6 @@ class GenericWorkerServer(HomeServer):
presence.register_servlets(self, resource)
- if self.config.experimental.groups_enabled:
- groups.register_servlets(self, resource)
-
resources.update({CLIENT_API_PREFIX: resource})
resources.update(build_synapse_client_resource_tree(self))
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 0f75e7b9..4c6c0658 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -16,7 +16,7 @@
import logging
import os
import sys
-from typing import Dict, Iterable, Iterator, List
+from typing import Dict, Iterable, List
from matrix_common.versionstring import get_distribution_version_string
@@ -45,7 +45,7 @@ from synapse.app._base import (
redirect_stdio_to_logs,
register_start,
)
-from synapse.config._base import ConfigError
+from synapse.config._base import ConfigError, format_config_error
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import ListenerConfig
@@ -399,38 +399,6 @@ def setup(config_options: List[str]) -> SynapseHomeServer:
return hs
-def format_config_error(e: ConfigError) -> Iterator[str]:
- """
- Formats a config error neatly
-
- The idea is to format the immediate error, plus the "causes" of those errors,
- hopefully in a way that makes sense to the user. For example:
-
- Error in configuration at 'oidc_config.user_mapping_provider.config.display_name_template':
- Failed to parse config for module 'JinjaOidcMappingProvider':
- invalid jinja template:
- unexpected end of template, expected 'end of print statement'.
-
- Args:
- e: the error to be formatted
-
- Returns: An iterator which yields string fragments to be formatted
- """
- yield "Error in configuration"
-
- if e.path:
- yield " at '%s'" % (".".join(e.path),)
-
- yield ":\n %s" % (e.msg,)
-
- parent_e = e.__cause__
- indent = 1
- while parent_e:
- indent += 1
- yield ":\n%s%s" % (" " * indent, str(parent_e))
- parent_e = parent_e.__cause__
-
-
def run(hs: HomeServer) -> None:
_base.start_reactor(
"synapse-homeserver",
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index a610fb78..0dfa00df 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -23,13 +23,7 @@ from netaddr import IPSet
from synapse.api.constants import EventTypes
from synapse.events import EventBase
-from synapse.types import (
- DeviceListUpdates,
- GroupID,
- JsonDict,
- UserID,
- get_domain_from_id,
-)
+from synapse.types import DeviceListUpdates, JsonDict, UserID
from synapse.util.caches.descriptors import _CacheContext, cached
if TYPE_CHECKING:
@@ -55,7 +49,6 @@ class ApplicationServiceState(Enum):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class Namespace:
exclusive: bool
- group_id: Optional[str]
regex: Pattern[str]
@@ -77,7 +70,6 @@ class ApplicationService:
def __init__(
self,
token: str,
- hostname: str,
id: str,
sender: str,
url: Optional[str] = None,
@@ -95,7 +87,6 @@ class ApplicationService:
) # url must not end with a slash
self.hs_token = hs_token
self.sender = sender
- self.server_name = hostname
self.namespaces = self._check_namespaces(namespaces)
self.id = id
self.ip_range_whitelist = ip_range_whitelist
@@ -141,30 +132,13 @@ class ApplicationService:
exclusive = regex_obj.get("exclusive")
if not isinstance(exclusive, bool):
raise ValueError("Expected bool for 'exclusive' in ns '%s'" % ns)
- group_id = regex_obj.get("group_id")
- if group_id:
- if not isinstance(group_id, str):
- raise ValueError(
- "Expected string for 'group_id' in ns '%s'" % ns
- )
- try:
- GroupID.from_string(group_id)
- except Exception:
- raise ValueError(
- "Expected valid group ID for 'group_id' in ns '%s'" % ns
- )
-
- if get_domain_from_id(group_id) != self.server_name:
- raise ValueError(
- "Expected 'group_id' to be this host in ns '%s'" % ns
- )
regex = regex_obj.get("regex")
if not isinstance(regex, str):
raise ValueError("Expected string for 'regex' in ns '%s'" % ns)
# Pre-compile regex.
- result[ns].append(Namespace(exclusive, group_id, re.compile(regex)))
+ result[ns].append(Namespace(exclusive, re.compile(regex)))
return result
@@ -369,21 +343,6 @@ class ApplicationService:
if namespace.exclusive
]
- def get_groups_for_user(self, user_id: str) -> Iterable[str]:
- """Get the groups that this user is associated with by this AS
-
- Args:
- user_id: The ID of the user.
-
- Returns:
- An iterable that yields group_id strings.
- """
- return (
- namespace.group_id
- for namespace in self.namespaces[ApplicationService.NS_USERS]
- if namespace.group_id and namespace.regex.match(user_id)
- )
-
def is_rate_limited(self) -> bool:
return self.rate_limited
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index d19f8dd9..df1c2144 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import urllib.parse
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
from prometheus_client import Counter
from typing_extensions import TypeGuard
@@ -155,6 +155,9 @@ class ApplicationServiceApi(SimpleHttpClient):
if service.url is None:
return []
+ # This is required by the configuration.
+ assert service.hs_token is not None
+
uri = "%s%s/thirdparty/%s/%s" % (
service.url,
APP_SERVICE_PREFIX,
@@ -162,7 +165,11 @@ class ApplicationServiceApi(SimpleHttpClient):
urllib.parse.quote(protocol),
)
try:
- response = await self.get_json(uri, fields)
+ args: Mapping[Any, Any] = {
+ **fields,
+ b"access_token": service.hs_token,
+ }
+ response = await self.get_json(uri, args=args)
if not isinstance(response, list):
logger.warning(
"query_3pe to %s returned an invalid response %r", uri, response
@@ -190,13 +197,15 @@ class ApplicationServiceApi(SimpleHttpClient):
return {}
async def _get() -> Optional[JsonDict]:
+ # This is required by the configuration.
+ assert service.hs_token is not None
uri = "%s%s/thirdparty/protocol/%s" % (
service.url,
APP_SERVICE_PREFIX,
urllib.parse.quote(protocol),
)
try:
- info = await self.get_json(uri)
+ info = await self.get_json(uri, {"access_token": service.hs_token})
if not _is_valid_3pe_metadata(info):
logger.warning(
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 3b49e607..de5e5216 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -384,6 +384,11 @@ class _TransactionController:
device_list_summary: The device list summary to include in the transaction.
"""
try:
+ service_is_up = await self._is_service_up(service)
+ # Don't create empty txns when in recovery mode (ephemeral events are dropped)
+ if not service_is_up and not events:
+ return
+
txn = await self.store.create_appservice_txn(
service=service,
events=events,
@@ -393,7 +398,6 @@ class _TransactionController:
unused_fallback_keys=unused_fallback_keys or {},
device_list_summary=device_list_summary or DeviceListUpdates(),
)
- service_is_up = await self._is_service_up(service)
if service_is_up:
sent = await txn.send(self.as_api)
if sent:
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 179aa7ff..42364fc1 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -16,14 +16,18 @@
import argparse
import errno
+import logging
import os
from collections import OrderedDict
from hashlib import sha256
from textwrap import dedent
from typing import (
Any,
+ ClassVar,
+ Collection,
Dict,
Iterable,
+ Iterator,
List,
MutableMapping,
Optional,
@@ -40,6 +44,8 @@ import yaml
from synapse.util.templates import _create_mxc_to_http_filter, _format_ts_filter
+logger = logging.getLogger(__name__)
+
class ConfigError(Exception):
"""Represents a problem parsing the configuration
@@ -55,6 +61,38 @@ class ConfigError(Exception):
self.path = path
+def format_config_error(e: ConfigError) -> Iterator[str]:
+ """
+ Formats a config error neatly
+
+ The idea is to format the immediate error, plus the "causes" of those errors,
+ hopefully in a way that makes sense to the user. For example:
+
+ Error in configuration at 'oidc_config.user_mapping_provider.config.display_name_template':
+ Failed to parse config for module 'JinjaOidcMappingProvider':
+ invalid jinja template:
+ unexpected end of template, expected 'end of print statement'.
+
+ Args:
+ e: the error to be formatted
+
+ Returns: An iterator which yields string fragments to be formatted
+ """
+ yield "Error in configuration"
+
+ if e.path:
+ yield " at '%s'" % (".".join(e.path),)
+
+ yield ":\n %s" % (e.msg,)
+
+ parent_e = e.__cause__
+ indent = 1
+ while parent_e:
+ indent += 1
+ yield ":\n%s%s" % (" " * indent, str(parent_e))
+ parent_e = parent_e.__cause__
+
+
# We split these messages out to allow packages to override with package
# specific instructions.
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS = """\
@@ -119,7 +157,7 @@ class Config:
defined in subclasses.
"""
- section: str
+ section: ClassVar[str]
def __init__(self, root_config: "RootConfig" = None):
self.root = root_config
@@ -309,9 +347,12 @@ class RootConfig:
class, lower-cased and with "Config" removed.
"""
- config_classes = []
+ config_classes: List[Type[Config]] = []
+
+ def __init__(self, config_files: Collection[str] = ()):
+ # Capture absolute paths here, so we can reload config after we daemonize.
+ self.config_files = [os.path.abspath(path) for path in config_files]
- def __init__(self):
for config_class in self.config_classes:
if config_class.section is None:
raise ValueError("%r requires a section name" % (config_class,))
@@ -512,12 +553,10 @@ class RootConfig:
object from parser.parse_args(..)`
"""
- obj = cls()
-
config_args = parser.parse_args(argv)
config_files = find_config_files(search_paths=config_args.config_path)
-
+ obj = cls(config_files)
if not config_files:
parser.error("Must supply a config file.")
@@ -627,7 +666,7 @@ class RootConfig:
generate_missing_configs = config_args.generate_missing_configs
- obj = cls()
+ obj = cls(config_files)
if config_args.generate_config:
if config_args.report_stats is None:
@@ -727,6 +766,34 @@ class RootConfig:
) -> None:
self.invoke_all("generate_files", config_dict, config_dir_path)
+ def reload_config_section(self, section_name: str) -> Config:
+ """Reconstruct the given config section, leaving all others unchanged.
+
+ This works in three steps:
+
+ 1. Create a new instance of the relevant `Config` subclass.
+ 2. Call `read_config` on that instance to parse the new config.
+ 3. Replace the existing config instance with the new one.
+
+ :raises ValueError: if the given `section_name` does not exist.
+ :raises ConfigError: for any other problems reloading config.
+
+ :returns: the previous config object, which no longer has a reference to this
+ RootConfig.
+ """
+ existing_config: Optional[Config] = getattr(self, section_name, None)
+ if existing_config is None:
+ raise ValueError(f"Unknown config section '{section_name}'")
+ logger.info("Reloading config section '%s'", section_name)
+
+ new_config_data = read_config_files(self.config_files)
+ new_config = type(existing_config)(self)
+ new_config.read_config(new_config_data)
+ setattr(self, section_name, new_config)
+
+ existing_config.root = None
+ return existing_config
+
def read_config_files(config_files: Iterable[str]) -> Dict[str, Any]:
"""Read the config files into a dict
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index bd092f95..01ea2b4d 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -1,15 +1,19 @@
import argparse
from typing import (
Any,
+ Collection,
Dict,
Iterable,
+ Iterator,
List,
+ Literal,
MutableMapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
+ overload,
)
import jinja2
@@ -28,7 +32,6 @@ from synapse.config import (
emailconfig,
experimental,
federation,
- groups,
jwt,
key,
logger,
@@ -64,6 +67,8 @@ class ConfigError(Exception):
self.msg = msg
self.path = path
+def format_config_error(e: ConfigError) -> Iterator[str]: ...
+
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str
MISSING_REPORT_STATS_SPIEL: str
MISSING_SERVER_NAME: str
@@ -101,7 +106,6 @@ class RootConfig:
push: push.PushConfig
spamchecker: spam_checker.SpamCheckerConfig
room: room.RoomConfig
- groups: groups.GroupsConfig
userdirectory: user_directory.UserDirectoryConfig
consent: consent.ConsentConfig
stats: stats.StatsConfig
@@ -117,7 +121,8 @@ class RootConfig:
background_updates: background_updates.BackgroundUpdateConfig
config_classes: List[Type["Config"]] = ...
- def __init__(self) -> None: ...
+ config_files: List[str]
+ def __init__(self, config_files: Collection[str] = ...) -> None: ...
def invoke_all(
self, func_name: str, *args: Any, **kwargs: Any
) -> MutableMapping[str, Any]: ...
@@ -157,6 +162,12 @@ class RootConfig:
def generate_missing_files(
self, config_dict: dict, config_dir_path: str
) -> None: ...
+ @overload
+ def reload_config_section(
+ self, section_name: Literal["caches"]
+ ) -> cache.CacheConfig: ...
+ @overload
+ def reload_config_section(self, section_name: str) -> Config: ...
class Config:
root: RootConfig
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 24498e79..16f93273 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -179,7 +179,6 @@ def _load_appservice(
return ApplicationService(
token=as_info["as_token"],
- hostname=hostname,
url=as_info["url"],
namespaces=as_info["namespaces"],
hs_token=as_info["hs_token"],
diff --git a/synapse/config/auth.py b/synapse/config/auth.py
index bb417a23..265a554a 100644
--- a/synapse/config/auth.py
+++ b/synapse/config/auth.py
@@ -29,7 +29,18 @@ class AuthConfig(Config):
if password_config is None:
password_config = {}
- self.password_enabled = password_config.get("enabled", True)
+ passwords_enabled = password_config.get("enabled", True)
+ # 'only_for_reauth' allows users who have previously set a password to use it,
+ # even though passwords would otherwise be disabled.
+ passwords_for_reauth_only = passwords_enabled == "only_for_reauth"
+
+ self.password_enabled_for_login = (
+ passwords_enabled and not passwords_for_reauth_only
+ )
+ self.password_enabled_for_reauth = (
+ passwords_for_reauth_only or passwords_enabled
+ )
+
self.password_localdb_enabled = password_config.get("localdb_enabled", True)
self.password_pepper = password_config.get("pepper", "")
@@ -46,7 +57,9 @@ class AuthConfig(Config):
def generate_config_section(self, **kwargs: Any) -> str:
return """\
password_config:
- # Uncomment to disable password login
+ # Uncomment to disable password login.
+ # Set to `only_for_reauth` to permit reauthentication for users that
+ # have passwords and are already logged in.
#
#enabled: false
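
The three accepted values for `password_config.enabled` therefore map onto the two
derived flags as follows (a standalone sketch mirroring the logic above, not Synapse
code):

```python
# Standalone sketch of the mapping computed in read_config above.
def derive_flags(enabled):
    for_reauth_only = enabled == "only_for_reauth"
    enabled_for_login = bool(enabled) and not for_reauth_only
    enabled_for_reauth = for_reauth_only or bool(enabled)
    return enabled_for_login, enabled_for_reauth

assert derive_flags(True) == (True, True)
assert derive_flags(False) == (False, False)
assert derive_flags("only_for_reauth") == (False, True)
```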
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index 94d852f4..d2f55534 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -69,11 +69,11 @@ def _canonicalise_cache_name(cache_name: str) -> str:
def add_resizable_cache(
cache_name: str, cache_resize_callback: Callable[[float], None]
) -> None:
- """Register a cache that's size can dynamically change
+ """Register a cache whose size can dynamically change
Args:
cache_name: A reference to the cache
- cache_resize_callback: A callback function that will be ran whenever
+ cache_resize_callback: A callback function that will run whenever
the cache needs to be resized
"""
# Some caches have '*' in them which we strip out.
@@ -96,6 +96,13 @@ class CacheConfig(Config):
section = "caches"
_environ = os.environ
+ event_cache_size: int
+ cache_factors: Dict[str, float]
+ global_factor: float
+ track_memory_usage: bool
+ expiry_time_msec: Optional[int]
+ sync_response_cache_duration: int
+
@staticmethod
def reset() -> None:
"""Resets the caches to their defaults. Used for tests."""
@@ -115,6 +122,12 @@ class CacheConfig(Config):
# A cache 'factor' is a multiplier that can be applied to each of
# Synapse's caches in order to increase or decrease the maximum
# number of entries that can be stored.
+ #
+ # The configuration for cache factors (caches.global_factor and
+ # caches.per_cache_factors) can be reloaded while the application is running,
+ # by sending a SIGHUP signal to the Synapse process. Changes to other parts of
+ # the caching config will NOT be applied after a SIGHUP is received; a restart
+ # is necessary.
# The number of events to cache in memory. Not affected by
# caches.global_factor.
@@ -163,6 +176,24 @@ class CacheConfig(Config):
#
#cache_entry_ttl: 30m
+ # This flag enables cache autotuning, and is further specified by the sub-options `max_cache_memory_usage`,
+ # `target_cache_memory_usage`, `min_cache_ttl`. These flags work in conjunction with each other to maintain
+ # a balance between cache memory usage and cache entry availability. You must be using jemalloc to utilize
+ # this option, and all three of the options must be specified for this feature to work.
+ #cache_autotuning:
+ # This flag sets a ceiling on how much memory the cache can use before caches begin to be continuously evicted.
+ # They will continue to be evicted until the memory usage drops below the `target_cache_memory_usage`, set in
+ # the flag below, or until the `min_cache_ttl` is hit.
+ #max_cache_memory_usage: 1024M
+
+ # This flag sets a rough target for the desired memory usage of the caches.
+ #target_cache_memory_usage: 758M
+
+ # `min_cache_ttl` sets a limit under which newer cache entries are not evicted and is only applied when
+ # caches are actively being evicted/`max_cache_memory_usage` has been exceeded. This is to protect hot caches
+ # from being emptied while Synapse is evicting due to memory.
+ #min_cache_ttl: 5m
+
# Controls how long the results of a /sync request are cached for after
# a successful response is returned. A higher duration can help clients with
# intermittent connections, at the cost of higher memory usage.
@@ -174,21 +205,21 @@ class CacheConfig(Config):
"""
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
+ """Populate this config object with values from `config`.
+
+ This method does NOT resize existing or future caches: use `resize_all_caches`.
+ We use two separate methods so that we can reject bad config before applying it.
+ """
self.event_cache_size = self.parse_size(
config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE)
)
- self.cache_factors: Dict[str, float] = {}
+ self.cache_factors = {}
cache_config = config.get("caches") or {}
- self.global_factor = cache_config.get(
- "global_factor", properties.default_factor_size
- )
+ self.global_factor = cache_config.get("global_factor", _DEFAULT_FACTOR_SIZE)
if not isinstance(self.global_factor, (int, float)):
raise ConfigError("caches.global_factor must be a number.")
- # Set the global one so that it's reflected in new caches
- properties.default_factor_size = self.global_factor
-
# Load cache factors from the config
individual_factors = cache_config.get("per_cache_factors") or {}
if not isinstance(individual_factors, dict):
@@ -230,7 +261,7 @@ class CacheConfig(Config):
cache_entry_ttl = cache_config.get("cache_entry_ttl", "30m")
if expire_caches:
- self.expiry_time_msec: Optional[int] = self.parse_duration(cache_entry_ttl)
+ self.expiry_time_msec = self.parse_duration(cache_entry_ttl)
else:
self.expiry_time_msec = None
@@ -250,23 +281,38 @@ class CacheConfig(Config):
)
self.expiry_time_msec = self.parse_duration(expiry_time)
+ self.cache_autotuning = cache_config.get("cache_autotuning")
+ if self.cache_autotuning:
+ max_memory_usage = self.cache_autotuning.get("max_cache_memory_usage")
+ self.cache_autotuning["max_cache_memory_usage"] = self.parse_size(
+ max_memory_usage
+ )
+
+ target_mem_size = self.cache_autotuning.get("target_cache_memory_usage")
+ self.cache_autotuning["target_cache_memory_usage"] = self.parse_size(
+ target_mem_size
+ )
+
+ min_cache_ttl = self.cache_autotuning.get("min_cache_ttl")
+ self.cache_autotuning["min_cache_ttl"] = self.parse_duration(min_cache_ttl)
+
self.sync_response_cache_duration = self.parse_duration(
cache_config.get("sync_response_cache_duration", 0)
)
- # Resize all caches (if necessary) with the new factors we've loaded
- self.resize_all_caches()
-
- # Store this function so that it can be called from other classes without
- # needing an instance of Config
- properties.resize_all_caches_func = self.resize_all_caches
-
def resize_all_caches(self) -> None:
- """Ensure all cache sizes are up to date
+ """Ensure all cache sizes are up-to-date.
For each cache, run the mapped callback function with either
a specific cache factor or the default, global one.
"""
+ # Set the global factor size, so that new caches are appropriately sized.
+ properties.default_factor_size = self.global_factor
+
+ # Store this function so that it can be called from other classes without
+ # needing an instance of CacheConfig
+ properties.resize_all_caches_func = self.resize_all_caches
+
# block other threads from modifying _CACHES while we iterate it.
with _CACHES_LOCK:
for cache_name, callback in _CACHES.items():
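
To make the new `cache_autotuning` options above concrete, here is a minimal, self-contained sketch of how the three fields are normalised: the two size fields are converted to bytes and `min_cache_ttl` to milliseconds, as in the `read_config` hunk above. The helper functions below are simplified stand-ins for illustration only, not Synapse's actual `Config.parse_size`/`parse_duration` implementations.

def _parse_size(value: str) -> int:
    # Simplified stand-in: accepts plain integers or a K/M/G suffix.
    units = {"K": 1024, "M": 1024 ** 2, "G": 1024 ** 3}
    suffix = value[-1].upper()
    if suffix in units:
        return int(float(value[:-1]) * units[suffix])
    return int(value)

def _parse_duration(value: str) -> int:
    # Simplified stand-in: accepts an s/m/h suffix, returns milliseconds.
    units = {"s": 1_000, "m": 60_000, "h": 3_600_000}
    return int(value[:-1]) * units[value[-1]]

cache_autotuning = {
    "max_cache_memory_usage": "1024M",
    "target_cache_memory_usage": "758M",
    "min_cache_ttl": "5m",
}
cache_autotuning["max_cache_memory_usage"] = _parse_size(cache_autotuning["max_cache_memory_usage"])
cache_autotuning["target_cache_memory_usage"] = _parse_size(cache_autotuning["target_cache_memory_usage"])
cache_autotuning["min_cache_ttl"] = _parse_duration(cache_autotuning["min_cache_ttl"])
# Values are now in bytes / milliseconds, mirroring what read_config() stores.
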
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index b20d9496..f2dfd49b 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -73,9 +73,6 @@ class ExperimentalConfig(Config):
# MSC3720 (Account status endpoint)
self.msc3720_enabled: bool = experimental.get("msc3720_enabled", False)
- # The deprecated groups feature.
- self.groups_enabled: bool = experimental.get("groups_enabled", False)
-
# MSC2654: Unread counts
self.msc2654_enabled: bool = experimental.get("msc2654_enabled", False)
@@ -84,3 +81,6 @@ class ExperimentalConfig(Config):
# MSC3786 (Add a default push rule to ignore m.room.server_acl events)
self.msc3786_enabled: bool = experimental.get("msc3786_enabled", False)
+
+ # MSC3772: A push rule for mutual relations.
+ self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False)
diff --git a/synapse/config/groups.py b/synapse/config/groups.py
deleted file mode 100644
index c9b9c6da..00000000
--- a/synapse/config/groups.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright 2017 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Any
-
-from synapse.types import JsonDict
-
-from ._base import Config
-
-
-class GroupsConfig(Config):
- section = "groups"
-
- def read_config(self, config: JsonDict, **kwargs: Any) -> None:
- self.enable_group_creation = config.get("enable_group_creation", False)
- self.group_creation_prefix = config.get("group_creation_prefix", "")
-
- def generate_config_section(self, **kwargs: Any) -> str:
- return """\
- # Uncomment to allow non-server-admin users to create groups on this server
- #
- #enable_group_creation: true
-
- # If enabled, non server admins can only create groups with local parts
- # starting with this prefix
- #
- #group_creation_prefix: "unofficial_"
- """
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index a4ec7069..4d2b298a 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -25,7 +25,6 @@ from .database import DatabaseConfig
from .emailconfig import EmailConfig
from .experimental import ExperimentalConfig
from .federation import FederationConfig
-from .groups import GroupsConfig
from .jwt import JWTConfig
from .key import KeyConfig
from .logger import LoggingConfig
@@ -89,7 +88,6 @@ class HomeServerConfig(RootConfig):
PushConfig,
SpamCheckerConfig,
RoomConfig,
- GroupsConfig,
UserDirectoryConfig,
ConsentConfig,
StatsConfig,
diff --git a/synapse/config/oembed.py b/synapse/config/oembed.py
index 690ffb52..e9edea07 100644
--- a/synapse/config/oembed.py
+++ b/synapse/config/oembed.py
@@ -57,9 +57,9 @@ class OembedConfig(Config):
"""
# Whether to use the packaged providers.json file.
if not oembed_config.get("disable_default_providers") or False:
- providers = json.load(
- pkg_resources.resource_stream("synapse", "res/providers.json")
- )
+ with pkg_resources.resource_stream("synapse", "res/providers.json") as s:
+ providers = json.load(s)
+
yield from self._parse_and_validate_provider(
providers, config_path=("oembed",)
)
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 98d8a166..f9c55143 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -223,6 +223,22 @@ class ContentRepositoryConfig(Config):
"url_preview_accept_language"
) or ["en"]
+ media_retention = config.get("media_retention") or {}
+
+ self.media_retention_local_media_lifetime_ms = None
+ local_media_lifetime = media_retention.get("local_media_lifetime")
+ if local_media_lifetime is not None:
+ self.media_retention_local_media_lifetime_ms = self.parse_duration(
+ local_media_lifetime
+ )
+
+ self.media_retention_remote_media_lifetime_ms = None
+ remote_media_lifetime = media_retention.get("remote_media_lifetime")
+ if remote_media_lifetime is not None:
+ self.media_retention_remote_media_lifetime_ms = self.parse_duration(
+ remote_media_lifetime
+ )
+
def generate_config_section(self, data_dir_path: str, **kwargs: Any) -> str:
assert data_dir_path is not None
media_store = os.path.join(data_dir_path, "media_store")
diff --git a/synapse/config/room.py b/synapse/config/room.py
index e18a87ea..462d85ac 100644
--- a/synapse/config/room.py
+++ b/synapse/config/room.py
@@ -63,6 +63,19 @@ class RoomConfig(Config):
"Invalid value for encryption_enabled_by_default_for_room_type"
)
+ self.default_power_level_content_override = config.get(
+ "default_power_level_content_override",
+ None,
+ )
+ if self.default_power_level_content_override is not None:
+ for preset in self.default_power_level_content_override:
+ if preset not in vars(RoomCreationPreset).values():
+ raise ConfigError(
+ "Unrecognised room preset %s in default_power_level_content_override"
+ % preset
+ )
+ # We validate the actual overrides when we try to apply them.
+
def generate_config_section(self, **kwargs: Any) -> str:
return """\
## Rooms ##
@@ -83,4 +96,38 @@ class RoomConfig(Config):
# will also not affect rooms created by other servers.
#
#encryption_enabled_by_default_for_room_type: invite
+
+ # Override the default power levels for rooms created on this server, per
+ # room creation preset.
+ #
+ # The appropriate dictionary for the room preset will be applied on top
+ # of the existing power levels content.
+ #
+ # Useful if you know that your users need special permissions in rooms
+ # that they create (e.g. to send particular types of state events without
+ # needing an elevated power level). This takes the same shape as the
+ # `power_level_content_override` parameter in the /createRoom API, but
+ # is applied before that parameter.
+ #
+ # Valid keys are some or all of `private_chat`, `trusted_private_chat`
+ # and `public_chat`. Inside each of those should be any of the
+ # properties allowed in `power_level_content_override` in the
+ # /createRoom API. If any property is missing, its default value will
+ # continue to be used. If any property is present, it will overwrite
+ # the existing default completely (so if the `events` property exists,
+ # the default event power levels will be ignored).
+ #
+ #default_power_level_content_override:
+ # private_chat:
+ # "events":
+ # "com.example.myeventtype" : 0
+ # "m.room.avatar": 50
+ # "m.room.canonical_alias": 50
+ # "m.room.encryption": 100
+ # "m.room.history_visibility": 100
+ # "m.room.name": 50
+ # "m.room.power_levels": 100
+ # "m.room.server_acl": 100
+ # "m.room.tombstone": 100
+ # "events_default": 1
"""
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 005a3ee4..657322cb 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -679,6 +679,17 @@ class ServerConfig(Config):
config.get("exclude_rooms_from_sync") or []
)
+ delete_stale_devices_after: Optional[str] = (
+ config.get("delete_stale_devices_after") or None
+ )
+
+ if delete_stale_devices_after is not None:
+ self.delete_stale_devices_after: Optional[int] = self.parse_duration(
+ delete_stale_devices_after
+ )
+ else:
+ self.delete_stale_devices_after = None
+
def has_tls_listener(self) -> bool:
return any(listener.tls for listener in self.listeners)
@@ -996,7 +1007,7 @@ class ServerConfig(Config):
# federation: the server-server API (/_matrix/federation). Also implies
# 'media', 'keys', 'openid'
#
- # keys: the key discovery API (/_matrix/keys).
+ # keys: the key discovery API (/_matrix/key).
#
# media: the media API (/_matrix/media).
#
diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py
index 3472a9a0..ae68a3dd 100644
--- a/synapse/config/tracer.py
+++ b/synapse/config/tracer.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Set
+from typing import Any, List, Set
from synapse.types import JsonDict
from synapse.util.check_dependencies import DependencyException, check_requirements
@@ -49,7 +49,9 @@ class TracerConfig(Config):
# The tracer is enabled so sanitize the config
- self.opentracer_whitelist = opentracing_config.get("homeserver_whitelist", [])
+ self.opentracer_whitelist: List[str] = opentracing_config.get(
+ "homeserver_whitelist", []
+ )
if not isinstance(self.opentracer_whitelist, list):
raise ConfigError("Tracer homeserver_whitelist config is malformed")
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 621a3efc..4c0b587a 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -414,7 +414,12 @@ def _is_membership_change_allowed(
raise AuthError(403, "You are banned from this room")
elif join_rule == JoinRules.PUBLIC:
pass
- elif room_version.msc3083_join_rules and join_rule == JoinRules.RESTRICTED:
+ elif (
+ room_version.msc3083_join_rules and join_rule == JoinRules.RESTRICTED
+ ) or (
+ room_version.msc3787_knock_restricted_join_rule
+ and join_rule == JoinRules.KNOCK_RESTRICTED
+ ):
# This is the same as public, but the event must contain a reference
# to the server who authorised the join. If the event does not contain
# the proper content it is rejected.
@@ -440,8 +445,13 @@ def _is_membership_change_allowed(
if authorising_user_level < invite_level:
raise AuthError(403, "Join event authorised by invalid server.")
- elif join_rule == JoinRules.INVITE or (
- room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
+ elif (
+ join_rule == JoinRules.INVITE
+ or (room_version.msc2403_knocking and join_rule == JoinRules.KNOCK)
+ or (
+ room_version.msc3787_knock_restricted_join_rule
+ and join_rule == JoinRules.KNOCK_RESTRICTED
+ )
):
if not caller_in_room and not caller_invited:
raise AuthError(403, "You are not invited to this room.")
@@ -462,7 +472,10 @@ def _is_membership_change_allowed(
if user_level < ban_level or user_level <= target_level:
raise AuthError(403, "You don't have permission to ban")
elif room_version.msc2403_knocking and Membership.KNOCK == membership:
- if join_rule != JoinRules.KNOCK:
+ if join_rule != JoinRules.KNOCK and (
+ not room_version.msc3787_knock_restricted_join_rule
+ or join_rule != JoinRules.KNOCK_RESTRICTED
+ ):
raise AuthError(403, "You don't have permission to knock")
elif target_user_id != event.user_id:
raise AuthError(403, "You cannot knock for other users")
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index c238376c..39ad2793 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -15,6 +15,7 @@
# limitations under the License.
import abc
+import collections.abc
import os
from typing import (
TYPE_CHECKING,
@@ -32,9 +33,11 @@ from typing import (
overload,
)
+import attr
from typing_extensions import Literal
from unpaddedbase64 import encode_base64
+from synapse.api.constants import RelationTypes
from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
from synapse.types import JsonDict, RoomStreamToken
from synapse.util.caches import intern_dict
@@ -615,3 +618,45 @@ def make_event_from_dict(
return event_type(
event_dict, room_version, internal_metadata_dict or {}, rejected_reason
)
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _EventRelation:
+ # The target event of the relation.
+ parent_id: str
+ # The relation type.
+ rel_type: str
+ # The aggregation key. Will be None if the rel_type is not m.annotation or is
+ # not a string.
+ aggregation_key: Optional[str]
+
+
+def relation_from_event(event: EventBase) -> Optional[_EventRelation]:
+ """
+ Attempt to parse relation information from an event.
+
+ Returns:
+ The event relation information, if it is valid. None, otherwise.
+ """
+ relation = event.content.get("m.relates_to")
+ if not relation or not isinstance(relation, collections.abc.Mapping):
+ # No relation information.
+ return None
+
+ # Relations must have a type and parent event ID.
+ rel_type = relation.get("rel_type")
+ if not isinstance(rel_type, str):
+ return None
+
+ parent_id = relation.get("event_id")
+ if not isinstance(parent_id, str):
+ return None
+
+ # Annotations have a key field.
+ aggregation_key = None
+ if rel_type == RelationTypes.ANNOTATION:
+ aggregation_key = relation.get("key")
+ if not isinstance(aggregation_key, str):
+ aggregation_key = None
+
+ return _EventRelation(parent_id, rel_type, aggregation_key)
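
As a usage illustration for `relation_from_event` above, the following standalone snippet shows the shape of event content it accepts and condenses the same checks; the event ID and reaction key are made up.

import collections.abc

# Example content for an annotation (reaction) event; values are invented.
content = {
    "m.relates_to": {
        "rel_type": "m.annotation",
        "event_id": "$parent_event_id",
        "key": "👍",
    }
}

relation = content.get("m.relates_to")
assert isinstance(relation, collections.abc.Mapping)
rel_type = relation.get("rel_type")       # "m.annotation"
parent_id = relation.get("event_id")      # "$parent_event_id"
# Only annotations carry an aggregation key.
aggregation_key = relation.get("key") if rel_type == "m.annotation" else None
print(parent_id, rel_type, aggregation_key)
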
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 46042b2b..b700cbbf 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -15,17 +15,16 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import attr
from frozendict import frozendict
-
-from twisted.internet.defer import Deferred
+from typing_extensions import Literal
from synapse.appservice import ApplicationService
from synapse.events import EventBase
-from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import JsonDict, StateMap
if TYPE_CHECKING:
- from synapse.storage import Storage
+ from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore
+ from synapse.storage.state import StateFilter
@attr.s(slots=True, auto_attribs=True)
@@ -60,6 +59,9 @@ class EventContext:
If ``state_group`` is None (ie, the event is an outlier),
``state_group_before_event`` will always also be ``None``.
+ state_delta_due_to_event: If `state_group` and `state_group_before_event` are not None
+ then this is the delta of the state between the two groups.
+
prev_group: If it is known, ``state_group``'s prev_group. Note that this being
None does not necessarily mean that ``state_group`` does not have
a prev_group!
@@ -78,73 +80,47 @@ class EventContext:
app_service: If this event is being sent by a (local) application service, that
app service.
- _current_state_ids: The room state map, including this event - ie, the state
- in ``state_group``.
-
- (type, state_key) -> event_id
-
- For an outlier, this is {}
-
- Note that this is a private attribute: it should be accessed via
- ``get_current_state_ids``. _AsyncEventContext impl calculates this
- on-demand: it will be None until that happens.
-
- _prev_state_ids: The room state map, excluding this event - ie, the state
- in ``state_group_before_event``. For a non-state
- event, this will be the same as _current_state_events.
-
- Note that it is a completely different thing to prev_group!
-
- (type, state_key) -> event_id
-
- For an outlier, this is {}
-
- As with _current_state_ids, this is a private attribute. It should be
- accessed via get_prev_state_ids.
-
partial_state: if True, we may be storing this event with a temporary,
incomplete state.
"""
- rejected: Union[bool, str] = False
+ _storage: "StorageControllers"
+ rejected: Union[Literal[False], str] = False
_state_group: Optional[int] = None
state_group_before_event: Optional[int] = None
+ _state_delta_due_to_event: Optional[StateMap[str]] = None
prev_group: Optional[int] = None
delta_ids: Optional[StateMap[str]] = None
app_service: Optional[ApplicationService] = None
- _current_state_ids: Optional[StateMap[str]] = None
- _prev_state_ids: Optional[StateMap[str]] = None
-
partial_state: bool = False
@staticmethod
def with_state(
+ storage: "StorageControllers",
state_group: Optional[int],
state_group_before_event: Optional[int],
- current_state_ids: Optional[StateMap[str]],
- prev_state_ids: Optional[StateMap[str]],
+ state_delta_due_to_event: Optional[StateMap[str]],
partial_state: bool,
prev_group: Optional[int] = None,
delta_ids: Optional[StateMap[str]] = None,
) -> "EventContext":
return EventContext(
- current_state_ids=current_state_ids,
- prev_state_ids=prev_state_ids,
+ storage=storage,
state_group=state_group,
state_group_before_event=state_group_before_event,
+ state_delta_due_to_event=state_delta_due_to_event,
prev_group=prev_group,
delta_ids=delta_ids,
partial_state=partial_state,
)
@staticmethod
- def for_outlier() -> "EventContext":
+ def for_outlier(
+ storage: "StorageControllers",
+ ) -> "EventContext":
"""Return an EventContext instance suitable for persisting an outlier event"""
- return EventContext(
- current_state_ids={},
- prev_state_ids={},
- )
+ return EventContext(storage=storage)
async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
"""Converts self to a type that can be serialized as JSON, and then
@@ -157,31 +133,21 @@ class EventContext:
The serialized event.
"""
- # We don't serialize the full state dicts, instead they get pulled out
- # of the DB on the other side. However, the other side can't figure out
- # the prev_state_ids, so if we're a state event we include the event
- # id that we replaced in the state.
- if event.is_state():
- prev_state_ids = await self.get_prev_state_ids()
- prev_state_id = prev_state_ids.get((event.type, event.state_key))
- else:
- prev_state_id = None
-
return {
- "prev_state_id": prev_state_id,
- "event_type": event.type,
- "event_state_key": event.get_state_key(),
"state_group": self._state_group,
"state_group_before_event": self.state_group_before_event,
"rejected": self.rejected,
"prev_group": self.prev_group,
+ "state_delta_due_to_event": _encode_state_dict(
+ self._state_delta_due_to_event
+ ),
"delta_ids": _encode_state_dict(self.delta_ids),
"app_service_id": self.app_service.id if self.app_service else None,
"partial_state": self.partial_state,
}
@staticmethod
- def deserialize(storage: "Storage", input: JsonDict) -> "EventContext":
+ def deserialize(storage: "StorageControllers", input: JsonDict) -> "EventContext":
"""Converts a dict that was produced by `serialize` back into a
EventContext.
@@ -192,16 +158,16 @@ class EventContext:
Returns:
The event context.
"""
- context = _AsyncEventContextImpl(
+ context = EventContext(
# We use the state_group and prev_state_id stuff to pull the
# current_state_ids out of the DB and construct prev_state_ids.
storage=storage,
- prev_state_id=input["prev_state_id"],
- event_type=input["event_type"],
- event_state_key=input["event_state_key"],
state_group=input["state_group"],
state_group_before_event=input["state_group_before_event"],
prev_group=input["prev_group"],
+ state_delta_due_to_event=_decode_state_dict(
+ input["state_delta_due_to_event"]
+ ),
delta_ids=_decode_state_dict(input["delta_ids"]),
rejected=input["rejected"],
partial_state=input.get("partial_state", False),
@@ -231,7 +197,9 @@ class EventContext:
return self._state_group
- async def get_current_state_ids(self) -> Optional[StateMap[str]]:
+ async def get_current_state_ids(
+ self, state_filter: Optional["StateFilter"] = None
+ ) -> Optional[StateMap[str]]:
"""
Gets the room state map, including this event - ie, the state in ``state_group``
@@ -239,6 +207,9 @@ class EventContext:
not make it into the room state. This method will raise an exception if
``rejected`` is set.
+ Args:
+ state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
+
Returns:
Returns None if state_group is None, which happens when the associated
event is an outlier.
@@ -249,15 +220,27 @@ class EventContext:
if self.rejected:
raise RuntimeError("Attempt to access state_ids of rejected event")
- await self._ensure_fetched()
- return self._current_state_ids
+ assert self._state_delta_due_to_event is not None
+
+ prev_state_ids = await self.get_prev_state_ids(state_filter)
+
+ if self._state_delta_due_to_event:
+ prev_state_ids = dict(prev_state_ids)
+ prev_state_ids.update(self._state_delta_due_to_event)
- async def get_prev_state_ids(self) -> StateMap[str]:
+ return prev_state_ids
+
+ async def get_prev_state_ids(
+ self, state_filter: Optional["StateFilter"] = None
+ ) -> StateMap[str]:
"""
Gets the room state map, excluding this event.
For a non-state event, this will be the same as get_current_state_ids().
+ Args:
+ state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
+
Returns:
Returns {} if state_group is None, which happens when the associated
event is an outlier.
@@ -265,94 +248,10 @@ class EventContext:
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
- await self._ensure_fetched()
- # There *should* be previous state IDs now.
- assert self._prev_state_ids is not None
- return self._prev_state_ids
-
- def get_cached_current_state_ids(self) -> Optional[StateMap[str]]:
- """Gets the current state IDs if we have them already cached.
-
- It is an error to access this for a rejected event, since rejected state should
- not make it into the room state. This method will raise an exception if
- ``rejected`` is set.
-
- Returns:
- Returns None if we haven't cached the state or if state_group is None
- (which happens when the associated event is an outlier).
-
- Otherwise, returns the the current state IDs.
- """
- if self.rejected:
- raise RuntimeError("Attempt to access state_ids of rejected event")
-
- return self._current_state_ids
-
- async def _ensure_fetched(self) -> None:
- return None
-
-
-@attr.s(slots=True)
-class _AsyncEventContextImpl(EventContext):
- """
- An implementation of EventContext which fetches _current_state_ids and
- _prev_state_ids from the database on demand.
-
- Attributes:
-
- _storage
-
- _fetching_state_deferred: Resolves when *_state_ids have been calculated.
- None if we haven't started calculating yet
-
- _event_type: The type of the event the context is associated with.
-
- _event_state_key: The state_key of the event the context is associated with.
-
- _prev_state_id: If the event associated with the context is a state event,
- then `_prev_state_id` is the event_id of the state that was replaced.
- """
-
- # This needs to have a default as we're inheriting
- _storage: "Storage" = attr.ib(default=None)
- _prev_state_id: Optional[str] = attr.ib(default=None)
- _event_type: str = attr.ib(default=None)
- _event_state_key: Optional[str] = attr.ib(default=None)
- _fetching_state_deferred: Optional["Deferred[None]"] = attr.ib(default=None)
-
- async def _ensure_fetched(self) -> None:
- if not self._fetching_state_deferred:
- self._fetching_state_deferred = run_in_background(self._fill_out_state)
-
- await make_deferred_yieldable(self._fetching_state_deferred)
-
- async def _fill_out_state(self) -> None:
- """Called to populate the _current_state_ids and _prev_state_ids
- attributes by loading from the database.
- """
- if self.state_group is None:
- # No state group means the event is an outlier. Usually the state_ids dicts are also
- # pre-set to empty dicts, but they get reset when the context is serialized, so set
- # them to empty dicts again here.
- self._current_state_ids = {}
- self._prev_state_ids = {}
- return
-
- current_state_ids = await self._storage.state.get_state_ids_for_group(
- self.state_group
+ assert self.state_group_before_event is not None
+ return await self._storage.state.get_state_ids_for_group(
+ self.state_group_before_event, state_filter
)
- # Set this separately so mypy knows current_state_ids is not None.
- self._current_state_ids = current_state_ids
- if self._event_state_key is not None:
- self._prev_state_ids = dict(current_state_ids)
-
- key = (self._event_type, self._event_state_key)
- if self._prev_state_id:
- self._prev_state_ids[key] = self._prev_state_id
- else:
- self._prev_state_ids.pop(key, None)
- else:
- self._prev_state_ids = current_state_ids
def _encode_state_dict(
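
The refactored `get_current_state_ids` above no longer relies on cached full state maps; it derives the post-event state by applying `state_delta_due_to_event` on top of the state before the event. A standalone sketch of that derivation, with invented state keys and event IDs:

# Both maps are (event type, state key) -> event ID; the values are invented.
prev_state_ids = {
    ("m.room.create", ""): "$create_event",
    ("m.room.member", "@alice:example.com"): "$alice_join",
}
state_delta_due_to_event = {
    ("m.room.member", "@bob:example.com"): "$bob_join",
}

current_state_ids = dict(prev_state_ids)
current_state_ids.update(state_delta_due_to_event)
assert ("m.room.member", "@bob:example.com") in current_state_ids
assert len(current_state_ids) == 3
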
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 3b6795d4..d2e06c75 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -21,17 +21,20 @@ from typing import (
Awaitable,
Callable,
Collection,
+ Dict,
List,
Optional,
Tuple,
Union,
)
+from synapse.api.errors import Codes
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.media_storage import ReadableFileWrapper
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, UserProfile
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
+from synapse.util.metrics import Measure
if TYPE_CHECKING:
import synapse.events
@@ -41,6 +44,22 @@ logger = logging.getLogger(__name__)
CHECK_EVENT_FOR_SPAM_CALLBACK = Callable[
["synapse.events.EventBase"],
+ Awaitable[
+ Union[
+ str,
+ Codes,
+ # Highly experimental, not officially part of the spamchecker API, may
+ # disappear without warning depending on the results of ongoing
+ # experiments.
+ # Use this to return additional information as part of an error.
+ Tuple[Codes, Dict],
+ # Deprecated
+ bool,
+ ]
+ ],
+]
+SHOULD_DROP_FEDERATED_EVENT_CALLBACK = Callable[
+ ["synapse.events.EventBase"],
Awaitable[Union[bool, str]],
]
USER_MAY_JOIN_ROOM_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
@@ -162,8 +181,16 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
class SpamChecker:
- def __init__(self) -> None:
+ NOT_SPAM = "NOT_SPAM"
+
+ def __init__(self, hs: "synapse.server.HomeServer") -> None:
+ self.hs = hs
+ self.clock = hs.get_clock()
+
self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = []
+ self._should_drop_federated_event_callbacks: List[
+ SHOULD_DROP_FEDERATED_EVENT_CALLBACK
+ ] = []
self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = []
self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
self._user_may_send_3pid_invite_callbacks: List[
@@ -187,6 +214,9 @@ class SpamChecker:
def register_callbacks(
self,
check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None,
+ should_drop_federated_event: Optional[
+ SHOULD_DROP_FEDERATED_EVENT_CALLBACK
+ ] = None,
user_may_join_room: Optional[USER_MAY_JOIN_ROOM_CALLBACK] = None,
user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None,
user_may_send_3pid_invite: Optional[USER_MAY_SEND_3PID_INVITE_CALLBACK] = None,
@@ -205,6 +235,11 @@ class SpamChecker:
if check_event_for_spam is not None:
self._check_event_for_spam_callbacks.append(check_event_for_spam)
+ if should_drop_federated_event is not None:
+ self._should_drop_federated_event_callbacks.append(
+ should_drop_federated_event
+ )
+
if user_may_join_room is not None:
self._user_may_join_room_callbacks.append(user_may_join_room)
@@ -240,7 +275,7 @@ class SpamChecker:
async def check_event_for_spam(
self, event: "synapse.events.EventBase"
- ) -> Union[bool, str]:
+ ) -> Union[Tuple[Codes, Dict], str]:
"""Checks if a given event is considered "spammy" by this server.
If the server considers an event spammy, then it will be rejected if
@@ -251,11 +286,65 @@ class SpamChecker:
event: the event to be checked
Returns:
- True or a string if the event is spammy. If a string is returned it
- will be used as the error message returned to the user.
+ - `NOT_SPAM` if the event is considered good (non-spammy) and should be let
+ through. Other spamcheck filters may still reject it.
+ - A `Codes` value if the event is considered spammy and is rejected with a specific
+ error message/code.
+ - A string that isn't `NOT_SPAM` if the event is considered spammy and the
+ string should be used as the client-facing error message. This usage is
+ generally discouraged as it doesn't support internationalization.
"""
for callback in self._check_event_for_spam_callbacks:
- res: Union[bool, str] = await delay_cancellation(callback(event))
+ with Measure(
+ self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+ ):
+ res = await delay_cancellation(callback(event))
+ if res is False or res == self.NOT_SPAM:
+ # This spam-checker accepts the event.
+ # Other spam-checkers may reject it, though.
+ continue
+ elif res is True:
+ # This spam-checker rejects the event with deprecated
+ # return value `True`
+ return Codes.FORBIDDEN
+ elif not isinstance(res, str):
+ # mypy complains that we can't reach this code because of the
+ # return type in CHECK_EVENT_FOR_SPAM_CALLBACK, but we don't know
+ # for sure that the module actually returns it.
+ logger.warning(
+ "Module returned invalid value, rejecting message as spam"
+ )
+ res = "This message has been rejected as probable spam"
+ else:
+ # The module rejected the event either with a `Codes`
+ # or some other `str`. In either case, we stop here.
+ pass
+
+ return res
+
+ # No spam-checker has rejected the event, let it pass.
+ return self.NOT_SPAM
+
+ async def should_drop_federated_event(
+ self, event: "synapse.events.EventBase"
+ ) -> Union[bool, str]:
+ """Checks if a given federated event is considered "spammy" by this
+ server.
+
+ If the server considers an event spammy, it will be silently dropped,
+ and in doing so will split-brain our view of the room's DAG.
+
+ Args:
+ event: the event to be checked
+
+ Returns:
+ True if the event should be silently dropped
+ """
+ for callback in self._should_drop_federated_event_callbacks:
+ with Measure(
+ self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+ ):
+ res: Union[bool, str] = await delay_cancellation(callback(event))
if res:
return res
@@ -276,9 +365,12 @@ class SpamChecker:
Whether the user may join the room
"""
for callback in self._user_may_join_room_callbacks:
- may_join_room = await delay_cancellation(
- callback(user_id, room_id, is_invited)
- )
+ with Measure(
+ self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+ ):
+ may_join_room = await delay_cancellation(
+ callback(user_id, room_id, is_invited)
+ )
if may_join_room is False:
return False
@@ -300,9 +392,12 @@ class SpamChecker:
True if the user may send an invite, otherwise False
"""
for callback in self._user_may_invite_callbacks:
- may_invite = await delay_cancellation(
- callback(inviter_userid, invitee_userid, room_id)
- )
+ with Measure(
+ self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+ ):
+ may_invite = await delay_cancellation(
+ callback(inviter_userid, invitee_userid, room_id)
+ )
if may_invite is False:
return False
@@ -328,9 +423,12 @@ class SpamChecker:
True if the user may send the invite, otherwise False
"""
for callback in self._user_may_send_3pid_invite_callbacks:
- may_send_3pid_invite = await delay_cancellation(
- callback(inviter_userid, medium, address, room_id)
- )
+ with Measure(
+ self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+ ):
+ may_send_3pid_invite = await delay_cancellation(
+ callback(inviter_userid, medium, address, room_id)
+ )
if may_send_3pid_invite is False:
return False
@@ -348,7 +446,10 @@ class SpamChecker:
True if the user may create a room, otherwise False
"""
for callback in self._user_may_create_room_callbacks:
- may_create_room = await delay_cancellation(callback(userid))
+ with Measure(
+ self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+ ):
+ may_create_room = await delay_cancellation(callback(userid))
if may_create_room is False:
return False
@@ -369,9 +470,12 @@ class SpamChecker:
True if the user may create a room alias, otherwise False
"""
for callback in self._user_may_create_room_alias_callbacks:
- may_create_room_alias = await delay_cancellation(
- callback(userid, room_alias)
- )
+ with Measure(
+ self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+ ):
+ may_create_room_alias = await delay_cancellation(
+ callback(userid, room_alias)
+ )
if may_create_room_alias is False:
return False
@@ -390,7 +494,10 @@ class SpamChecker:
True if the user may publish the room, otherwise False
"""
for callback in self._user_may_publish_room_callbacks:
- may_publish_room = await delay_cancellation(callback(userid, room_id))
+ with Measure(
+ self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+ ):
+ may_publish_room = await delay_cancellation(callback(userid, room_id))
if may_publish_room is False:
return False
@@ -412,9 +519,13 @@ class SpamChecker:
True if the user is spammy.
"""
for callback in self._check_username_for_spam_callbacks:
- # Make a copy of the user profile object to ensure the spam checker cannot
- # modify it.
- if await delay_cancellation(callback(user_profile.copy())):
+ with Measure(
+ self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+ ):
+ # Make a copy of the user profile object to ensure the spam checker cannot
+ # modify it.
+ res = await delay_cancellation(callback(user_profile.copy()))
+ if res:
return True
return False
@@ -442,9 +553,12 @@ class SpamChecker:
"""
for callback in self._check_registration_for_spam_callbacks:
- behaviour = await delay_cancellation(
- callback(email_threepid, username, request_info, auth_provider_id)
- )
+ with Measure(
+ self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+ ):
+ behaviour = await delay_cancellation(
+ callback(email_threepid, username, request_info, auth_provider_id)
+ )
assert isinstance(behaviour, RegistrationBehaviour)
if behaviour != RegistrationBehaviour.ALLOW:
return behaviour
@@ -486,7 +600,10 @@ class SpamChecker:
"""
for callback in self._check_media_file_for_spam_callbacks:
- spam = await delay_cancellation(callback(file_wrapper, file_info))
+ with Measure(
+ self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+ ):
+ spam = await delay_cancellation(callback(file_wrapper, file_info))
if spam:
return True
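
To illustrate the widened return contract of `check_event_for_spam` above, here is a hypothetical module callback: returning the string "NOT_SPAM" (mirroring `SpamChecker.NOT_SPAM`) or the deprecated `False` lets the event through, while a `synapse.api.errors.Codes` value or (discouraged) a free-form string rejects it. The keyword check is purely illustrative.

# Hypothetical spam-checker callback; the filtering rule is invented.
async def maybe_spammy(event) -> str:
    body = getattr(event, "content", {}).get("body", "")
    if "buy cheap spam" in body:
        # A real module would more typically return a Codes constant here
        # (e.g. Codes.FORBIDDEN) so that clients receive a structured error.
        return "This message has been rejected as probable spam"
    return "NOT_SPAM"
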
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 9f4ff979..35f3f369 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -152,6 +152,7 @@ class ThirdPartyEventRules:
self.third_party_rules = None
self.store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = []
self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = []
@@ -463,7 +464,7 @@ class ThirdPartyEventRules:
Returns:
A dict mapping (event type, state key) to state event.
"""
- state_ids = await self.store.get_filtered_current_state_ids(room_id)
+ state_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
room_state_events = await self.store.get_events(state_ids.values())
state_events = {}
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 360d2427..29fa9b38 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc
-from typing import Iterable, Type, Union
+from typing import Iterable, Type, Union, cast
import jsonschema
@@ -103,7 +103,12 @@ class EventValidator:
except jsonschema.ValidationError as e:
if e.path:
# example: "users_default": '0' is not of type 'integer'
- message = '"' + e.path[-1] + '": ' + e.message # noqa: B306
+ # cast safety: path entries can be integers, if we fail to validate
+ # items in an array. However the POWER_LEVELS_SCHEMA doesn't expect
+ # to see any arrays.
+ message = (
+ '"' + cast(str, e.path[-1]) + '": ' + e.message # noqa: B306
+ )
# jsonschema.ValidationError.message is a valid attribute
else:
# example: '0' is not of type 'integer'
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 41ac49fd..2522bf78 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -32,6 +32,18 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+class InvalidEventSignatureError(RuntimeError):
+ """Raised when the signature on an event is invalid.
+
+ The stringification of this exception is just the error message without reference
+ to the event id. The event id is available as a property.
+ """
+
+ def __init__(self, message: str, event_id: str):
+ super().__init__(message)
+ self.event_id = event_id
+
+
class FederationBase:
def __init__(self, hs: "HomeServer"):
self.hs = hs
@@ -41,6 +53,7 @@ class FederationBase:
self.spam_checker = hs.get_spam_checker()
self.store = hs.get_datastores().main
self._clock = hs.get_clock()
+ self._storage_controllers = hs.get_storage_controllers()
async def _check_sigs_and_hash(
self, room_version: RoomVersion, pdu: EventBase
@@ -59,20 +72,13 @@ class FederationBase:
Returns:
* the original event if the checks pass
* a redacted version of the event (if the signature
- matched but the hash did not)
+ matched but the hash did not). In this case a warning will be logged.
Raises:
- SynapseError if the signature check failed.
+ InvalidEventSignatureError if the signature check failed. Nothing
+ will be logged in this case.
"""
- try:
- await _check_sigs_on_pdu(self.keyring, room_version, pdu)
- except SynapseError as e:
- logger.warning(
- "Signature check failed for %s: %s",
- pdu.event_id,
- e,
- )
- raise
+ await _check_sigs_on_pdu(self.keyring, room_version, pdu)
if not check_event_content_hash(pdu):
# let's try to distinguish between failures because the event was
@@ -87,7 +93,7 @@ class FederationBase:
if set(redacted_event.keys()) == set(pdu.keys()) and set(
redacted_event.content.keys()
) == set(pdu.content.keys()):
- logger.info(
+ logger.debug(
"Event %s seems to have been redacted; using our redacted copy",
pdu.event_id,
)
@@ -98,9 +104,9 @@ class FederationBase:
)
return redacted_event
- result = await self.spam_checker.check_event_for_spam(pdu)
+ spam_check = await self.spam_checker.check_event_for_spam(pdu)
- if result:
+ if spam_check != self.spam_checker.NOT_SPAM:
logger.warning("Event contains spam, soft-failing %s", pdu.event_id)
# we redact (to save disk space) as well as soft-failing (to stop
# using the event in prev_events).
@@ -116,12 +122,13 @@ async def _check_sigs_on_pdu(
) -> None:
"""Check that the given events are correctly signed
- Raise a SynapseError if the event wasn't correctly signed.
-
Args:
keyring: keyring object to do the checks
room_version: the room version of the PDUs
pdus: the events to be checked
+
+ Raises:
+ InvalidEventSignatureError if the event wasn't correctly signed.
"""
# we want to check that the event is signed by:
@@ -147,44 +154,38 @@ async def _check_sigs_on_pdu(
# First we check that the sender event is signed by the sender's domain
# (except if its a 3pid invite, in which case it may be sent by any server)
+ sender_domain = get_domain_from_id(pdu.sender)
if not _is_invite_via_3pid(pdu):
try:
await keyring.verify_event_for_server(
- get_domain_from_id(pdu.sender),
+ sender_domain,
pdu,
pdu.origin_server_ts if room_version.enforce_key_validity else 0,
)
except Exception as e:
- errmsg = "event id %s: unable to verify signature for sender %s: %s" % (
+ raise InvalidEventSignatureError(
+ f"unable to verify signature for sender domain {sender_domain}: {e}",
pdu.event_id,
- get_domain_from_id(pdu.sender),
- e,
- )
- raise SynapseError(403, errmsg, Codes.FORBIDDEN)
+ ) from None
# now let's look for events where the sender's domain is different to the
# event id's domain (normally only the case for joins/leaves), and add additional
# checks. Only do this if the room version has a concept of event ID domain
# (ie, the room version uses old-style non-hash event IDs).
- if room_version.event_format == EventFormatVersions.V1 and get_domain_from_id(
- pdu.event_id
- ) != get_domain_from_id(pdu.sender):
- try:
- await keyring.verify_event_for_server(
- get_domain_from_id(pdu.event_id),
- pdu,
- pdu.origin_server_ts if room_version.enforce_key_validity else 0,
- )
- except Exception as e:
- errmsg = (
- "event id %s: unable to verify signature for event id domain %s: %s"
- % (
- pdu.event_id,
- get_domain_from_id(pdu.event_id),
- e,
+ if room_version.event_format == EventFormatVersions.V1:
+ event_domain = get_domain_from_id(pdu.event_id)
+ if event_domain != sender_domain:
+ try:
+ await keyring.verify_event_for_server(
+ event_domain,
+ pdu,
+ pdu.origin_server_ts if room_version.enforce_key_validity else 0,
)
- )
- raise SynapseError(403, errmsg, Codes.FORBIDDEN)
+ except Exception as e:
+ raise InvalidEventSignatureError(
+ f"unable to verify signature for event domain {event_domain}: {e}",
+ pdu.event_id,
+ ) from None
# If this is a join event for a restricted room it may have been authorised
# via a different server from the sending server. Check those signatures.
@@ -204,15 +205,10 @@ async def _check_sigs_on_pdu(
pdu.origin_server_ts if room_version.enforce_key_validity else 0,
)
except Exception as e:
- errmsg = (
- "event id %s: unable to verify signature for authorising server %s: %s"
- % (
- pdu.event_id,
- authorising_server,
- e,
- )
- )
- raise SynapseError(403, errmsg, Codes.FORBIDDEN)
+ raise InvalidEventSignatureError(
+ f"unable to verify signature for authorising serve {authorising_server}: {e}",
+ pdu.event_id,
+ ) from None
def _is_invite_via_3pid(event: EventBase) -> bool:
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 17eff609..ad475a91 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -54,7 +54,11 @@ from synapse.api.room_versions import (
RoomVersions,
)
from synapse.events import EventBase, builder
-from synapse.federation.federation_base import FederationBase, event_from_pdu_json
+from synapse.federation.federation_base import (
+ FederationBase,
+ InvalidEventSignatureError,
+ event_from_pdu_json,
+)
from synapse.federation.transport.client import SendJoinResponse
from synapse.http.types import QueryParams
from synapse.types import JsonDict, UserID, get_domain_from_id
@@ -319,7 +323,13 @@ class FederationClient(FederationBase):
pdu = pdu_list[0]
# Check signatures are correct.
- signed_pdu = await self._check_sigs_and_hash(room_version, pdu)
+ try:
+ signed_pdu = await self._check_sigs_and_hash(room_version, pdu)
+ except InvalidEventSignatureError as e:
+ errmsg = f"event id {pdu.event_id}: {e}"
+ logger.warning("%s", errmsg)
+ raise SynapseError(403, errmsg, Codes.FORBIDDEN)
+
return signed_pdu
return None
@@ -405,6 +415,9 @@ class FederationClient(FederationBase):
Returns:
a tuple of (state event_ids, auth event_ids)
+
+ Raises:
+ InvalidResponseError: if fields in the response have the wrong type.
"""
result = await self.transport_layer.get_room_state_ids(
destination, room_id, event_id=event_id
@@ -416,7 +429,7 @@ class FederationClient(FederationBase):
if not isinstance(state_event_ids, list) or not isinstance(
auth_event_ids, list
):
- raise Exception("invalid response from /state_ids")
+ raise InvalidResponseError("invalid response from /state_ids")
return state_event_ids, auth_event_ids
@@ -552,20 +565,24 @@ class FederationClient(FederationBase):
Returns:
The PDU (possibly redacted) if it has valid signatures and hashes.
+ None if no valid copy could be found.
"""
- res = None
try:
- res = await self._check_sigs_and_hash(room_version, pdu)
- except SynapseError:
- pass
-
- if not res:
- # Check local db.
- res = await self.store.get_event(
- pdu.event_id, allow_rejected=True, allow_none=True
+ return await self._check_sigs_and_hash(room_version, pdu)
+ except InvalidEventSignatureError as e:
+ logger.warning(
+ "Signature on retrieved event %s was invalid (%s). "
+ "Checking local store/orgin server",
+ pdu.event_id,
+ e,
)
+ # Check local db.
+ res = await self.store.get_event(
+ pdu.event_id, allow_rejected=True, allow_none=True
+ )
+
pdu_origin = get_domain_from_id(pdu.sender)
if not res and pdu_origin != origin:
try:
@@ -1040,9 +1057,14 @@ class FederationClient(FederationBase):
pdu = event_from_pdu_json(pdu_dict, room_version)
# Check signatures are correct.
- pdu = await self._check_sigs_and_hash(room_version, pdu)
+ try:
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
+ except InvalidEventSignatureError as e:
+ errmsg = f"event id {pdu.event_id}: {e}"
+ logger.warning("%s", errmsg)
+ raise SynapseError(403, errmsg, Codes.FORBIDDEN)
- # FIXME: We should handle signature failures more gracefully.
+ # FIXME: We should handle signature failures more gracefully.
return pdu
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 884b5d60..3e1518f1 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -48,7 +48,11 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.crypto.event_signing import compute_event_signature
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.federation.federation_base import FederationBase, event_from_pdu_json
+from synapse.federation.federation_base import (
+ FederationBase,
+ InvalidEventSignatureError,
+ event_from_pdu_json,
+)
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
from synapse.http.servlet import assert_params_in_dict
@@ -109,11 +113,13 @@ class FederationServer(FederationBase):
super().__init__(hs)
self.handler = hs.get_federation_handler()
- self.storage = hs.get_storage()
+ self._spam_checker = hs.get_spam_checker()
self._federation_event_handler = hs.get_federation_event_handler()
self.state = hs.get_state_handler()
self._event_auth_handler = hs.get_event_auth_handler()
+ self._state_storage_controller = hs.get_storage_controllers().state
+
self.device_handler = hs.get_device_handler()
# Ensure the following handlers are loaded since they register callbacks
@@ -631,7 +637,12 @@ class FederationServer(FederationBase):
pdu = event_from_pdu_json(content, room_version)
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, pdu.room_id)
- pdu = await self._check_sigs_and_hash(room_version, pdu)
+ try:
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
+ except InvalidEventSignatureError as e:
+ errmsg = f"event id {pdu.event_id}: {e}"
+ logger.warning("%s", errmsg)
+ raise SynapseError(403, errmsg, Codes.FORBIDDEN)
ret_pdu = await self.handler.on_invite_request(origin, pdu, room_version)
time_now = self._clock.time_msec()
return {"event": ret_pdu.get_pdu_json(time_now)}
@@ -864,7 +875,12 @@ class FederationServer(FederationBase):
)
)
- event = await self._check_sigs_and_hash(room_version, event)
+ try:
+ event = await self._check_sigs_and_hash(room_version, event)
+ except InvalidEventSignatureError as e:
+ errmsg = f"event id {event.event_id}: {e}"
+ logger.warning("%s", errmsg)
+ raise SynapseError(403, errmsg, Codes.FORBIDDEN)
return await self._federation_event_handler.on_send_membership_event(
origin, event
@@ -1016,8 +1032,15 @@ class FederationServer(FederationBase):
# Check signature.
try:
pdu = await self._check_sigs_and_hash(room_version, pdu)
- except SynapseError as e:
- raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id)
+ except InvalidEventSignatureError as e:
+ logger.warning("event id %s: %s", pdu.event_id, e)
+ raise FederationError("ERROR", 403, str(e), affected=pdu.event_id)
+
+ if await self._spam_checker.should_drop_federated_event(pdu):
+ logger.warning(
+ "Unstaged federated event contains spam, dropping %s", pdu.event_id
+ )
+ return
# Add the event to our staging area
await self.store.insert_received_event_to_staging(origin, pdu)
@@ -1032,6 +1055,41 @@ class FederationServer(FederationBase):
pdu.room_id, room_version, lock, origin, pdu
)
+ async def _get_next_nonspam_staged_event_for_room(
+ self, room_id: str, room_version: RoomVersion
+ ) -> Optional[Tuple[str, EventBase]]:
+ """Fetch the first non-spam event from staging queue.
+
+ Args:
+ room_id: the room to fetch the first non-spam event in.
+ room_version: the version of the room.
+
+ Returns:
+ The first non-spam event in that room.
+ """
+
+ while True:
+ # We need to do this check outside the lock to avoid a race between
+ # a new event being inserted by another instance and it attempting
+ # to acquire the lock.
+ next = await self.store.get_next_staged_event_for_room(
+ room_id, room_version
+ )
+
+ if next is None:
+ return None
+
+ origin, event = next
+
+ if await self._spam_checker.should_drop_federated_event(event):
+ logger.warning(
+ "Staged federated event contains spam, dropping %s",
+ event.event_id,
+ )
+ continue
+
+ return next
+
@wrap_as_background_process("_process_incoming_pdus_in_room_inner")
async def _process_incoming_pdus_in_room_inner(
self,
@@ -1109,12 +1167,10 @@ class FederationServer(FederationBase):
(self._clock.time_msec() - received_ts) / 1000
)
- # We need to do this check outside the lock to avoid a race between
- # a new event being inserted by another instance and it attempting
- # to acquire the lock.
- next = await self.store.get_next_staged_event_for_room(
+ next = await self._get_next_nonspam_staged_event_for_room(
room_id, room_version
)
+
if not next:
break
@@ -1167,14 +1223,10 @@ class FederationServer(FederationBase):
Raises:
AuthError if the server does not match the ACL
"""
- state_ids = await self.store.get_current_state_ids(room_id)
- acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
-
- if not acl_event_id:
- return
-
- acl_event = await self.store.get_event(acl_event_id)
- if server_matches_acl_event(server_name, acl_event):
+ acl_event = await self._storage_controllers.state.get_current_state_event(
+ room_id, EventTypes.ServerACL, ""
+ )
+ if not acl_event or server_matches_acl_event(server_name, acl_event):
return
raise AuthError(code=403, msg="Server is banned from room")
@@ -1313,7 +1365,7 @@ class FederationHandlerRegistry:
self._edu_type_to_instance[edu_type] = instance_names
async def on_edu(self, edu_type: str, origin: str, content: dict) -> None:
- if not self.config.server.use_presence and edu_type == EduTypes.Presence:
+ if not self.config.server.use_presence and edu_type == EduTypes.PRESENCE:
return
# Check if we have a handler on this instance
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 6d2f4631..99a794c0 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -15,7 +15,17 @@
import abc
import logging
from collections import OrderedDict
-from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ Hashable,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+)
import attr
from prometheus_client import Counter
@@ -235,6 +245,8 @@ class FederationSender(AbstractFederationSender):
self.store = hs.get_datastores().main
self.state = hs.get_state_handler()
+ self._storage_controllers = hs.get_storage_controllers()
+
self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id
@@ -409,7 +421,7 @@ class FederationSender(AbstractFederationSender):
)
return
- destinations: Optional[Set[str]] = None
+ destinations: Optional[Collection[str]] = None
if not event.prev_event_ids():
# If there are no prev event IDs then the state is empty
# and so no remote servers in the room
@@ -444,7 +456,7 @@ class FederationSender(AbstractFederationSender):
)
return
- destinations = {
+ sharded_destinations = {
d
for d in destinations
if self._federation_shard_config.should_handle(
@@ -456,12 +468,12 @@ class FederationSender(AbstractFederationSender):
# If we are sending the event on behalf of another server
# then it already has the event and there is no reason to
# send the event to it.
- destinations.discard(send_on_behalf_of)
+ sharded_destinations.discard(send_on_behalf_of)
- logger.debug("Sending %s to %r", event, destinations)
+ logger.debug("Sending %s to %r", event, sharded_destinations)
- if destinations:
- await self._send_pdu(event, destinations)
+ if sharded_destinations:
+ await self._send_pdu(event, sharded_destinations)
now = self.clock.time_msec()
ts = await self.store.get_received_ts(event.event_id)
@@ -592,7 +604,9 @@ class FederationSender(AbstractFederationSender):
room_id = receipt.room_id
# Work out which remote servers should be poked and poke them.
- domains_set = await self.state.get_current_hosts_in_room(room_id)
+ domains_set = await self._storage_controllers.state.get_current_hosts_in_room(
+ room_id
+ )
domains = [
d
for d in domains_set
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index d80f0ac5..333ca9a9 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tupl
import attr
from prometheus_client import Counter
+from synapse.api.constants import EduTypes
from synapse.api.errors import (
FederationDeniedError,
HttpResponseException,
@@ -223,7 +224,7 @@ class PerDestinationQueue:
"""Marks that the destination has new data to send, without starting a
new transaction.
- If a transaction loop is already in progress then a new transcation will
+ If a transaction loop is already in progress then a new transaction will
be attempted when the current one finishes.
"""
@@ -542,7 +543,7 @@ class PerDestinationQueue:
edu = Edu(
origin=self._server_name,
destination=self._destination,
- edu_type="m.receipt",
+ edu_type=EduTypes.RECEIPT,
content=self._pending_rrs,
)
self._pending_rrs = {}
@@ -592,7 +593,7 @@ class PerDestinationQueue:
Edu(
origin=self._server_name,
destination=self._destination,
- edu_type="m.direct_to_device",
+ edu_type=EduTypes.DIRECT_TO_DEVICE,
content=content,
)
for content in contents
@@ -670,7 +671,7 @@ class _TransactionQueueManager:
Edu(
origin=self.queue._server_name,
destination=self.queue._destination,
- edu_type="m.presence",
+ edu_type=EduTypes.PRESENCE,
content={
"push": [
format_user_presence_state(
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 0c1cad86..75081810 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, List
from prometheus_client import Gauge
+from synapse.api.constants import EduTypes
from synapse.api.errors import HttpResponseException
from synapse.events import EventBase
from synapse.federation.persistence import TransactionActions
@@ -126,7 +127,10 @@ class TransactionManager:
len(edus),
)
if issue_8631_logger.isEnabledFor(logging.DEBUG):
- DEVICE_UPDATE_EDUS = {"m.device_list_update", "m.signing_key_update"}
+ DEVICE_UPDATE_EDUS = {
+ EduTypes.DEVICE_LIST_UPDATE,
+ EduTypes.SIGNING_KEY_UPDATE,
+ }
device_list_updates = [
edu.content for edu in edus if edu.edu_type in DEVICE_UPDATE_EDUS
]
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 9ce06dfa..9e84bd67 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -17,7 +17,6 @@ import logging
import urllib
from typing import (
Any,
- Awaitable,
Callable,
Collection,
Dict,
@@ -49,11 +48,6 @@ from synapse.types import JsonDict
logger = logging.getLogger(__name__)
-# Send join responses can be huge, so we set a separate limit here. The response
-# is parsed in a streaming manner, which helps alleviate the issue of memory
-# usage a bit.
-MAX_RESPONSE_SIZE_SEND_JOIN = 500 * 1024 * 1024
-
class TransportLayerClient:
"""Sends federation HTTP requests to other servers"""
@@ -349,7 +343,6 @@ class TransportLayerClient:
path=path,
data=content,
parser=SendJoinParser(room_version, v1_api=True),
- max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
)
async def send_join_v2(
@@ -372,7 +365,6 @@ class TransportLayerClient:
args=query_params,
data=content,
parser=SendJoinParser(room_version, v1_api=False),
- max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
)
async def send_leave_v1(
@@ -688,488 +680,6 @@ class TransportLayerClient:
timeout=timeout,
)
- async def get_group_profile(
- self, destination: str, group_id: str, requester_user_id: str
- ) -> JsonDict:
- """Get a group profile"""
- path = _create_v1_path("/groups/%s/profile", group_id)
-
- return await self.client.get_json(
- destination=destination,
- path=path,
- args={"requester_user_id": requester_user_id},
- ignore_backoff=True,
- )
-
- async def update_group_profile(
- self, destination: str, group_id: str, requester_user_id: str, content: JsonDict
- ) -> JsonDict:
- """Update a remote group profile
-
- Args:
- destination
- group_id
- requester_user_id
- content: The new profile of the group
- """
- path = _create_v1_path("/groups/%s/profile", group_id)
-
- return self.client.post_json(
- destination=destination,
- path=path,
- args={"requester_user_id": requester_user_id},
- data=content,
- ignore_backoff=True,
- )
-
- async def get_group_summary(
- self, destination: str, group_id: str, requester_user_id: str
- ) -> JsonDict:
- """Get a group summary"""
- path = _create_v1_path("/groups/%s/summary", group_id)
-
- return await self.client.get_json(
- destination=destination,
- path=path,
- args={"requester_user_id": requester_user_id},
- ignore_backoff=True,
- )
-
- async def get_rooms_in_group(
- self, destination: str, group_id: str, requester_user_id: str
- ) -> JsonDict:
- """Get all rooms in a group"""
- path = _create_v1_path("/groups/%s/rooms", group_id)
-
- return await self.client.get_json(
- destination=destination,
- path=path,
- args={"requester_user_id": requester_user_id},
- ignore_backoff=True,
- )
-
- async def add_room_to_group(
- self,
- destination: str,
- group_id: str,
- requester_user_id: str,
- room_id: str,
- content: JsonDict,
- ) -> JsonDict:
- """Add a room to a group"""
- path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
-
- return await self.client.post_json(
- destination=destination,
- path=path,
- args={"requester_user_id": requester_user_id},
- data=content,
- ignore_backoff=True,
- )
-
- async def update_room_in_group(
- self,
- destination: str,
- group_id: str,
- requester_user_id: str,
- room_id: str,
- config_key: str,
- content: JsonDict,
- ) -> JsonDict:
- """Update room in group"""
- path = _create_v1_path(
- "/groups/%s/room/%s/config/%s", group_id, room_id, config_key
- )
-
- return await self.client.post_json(
- destination=destination,
- path=path,
- args={"requester_user_id": requester_user_id},
- data=content,
- ignore_backoff=True,
- )
-
- async def remove_room_from_group(
- self, destination: str, group_id: str, requester_user_id: str, room_id: str
- ) -> JsonDict:
- """Remove a room from a group"""
- path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
-
- return await self.client.delete_json(
- destination=destination,
- path=path,
- args={"requester_user_id": requester_user_id},
- ignore_backoff=True,
- )
-
- async def get_users_in_group(
- self, destination: str, group_id: str, requester_user_id: str
- ) -> JsonDict:
- """Get users in a group"""
- path = _create_v1_path("/groups/%s/users", group_id)
-
- return await self.client.get_json(
- destination=destination,
- path=path,
- args={"requester_user_id": requester_user_id},
- ignore_backoff=True,
- )
-
- async def get_invited_users_in_group(
- self, destination: str, group_id: str, requester_user_id: str
- ) -> JsonDict:
- """Get users that have been invited to a group"""
- path = _create_v1_path("/groups/%s/invited_users", group_id)
-
- return await self.client.get_json(
- destination=destination,
- path=path,
- args={"requester_user_id": requester_user_id},
- ignore_backoff=True,
- )
-
- async def accept_group_invite(
- self, destination: str, group_id: str, user_id: str, content: JsonDict
- ) -> JsonDict:
- """Accept a group invite"""
- path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id)
-
- return await self.client.post_json(
- destination=destination, path=path, data=content, ignore_backoff=True
- )
-
- def join_group(
- self, destination: str, group_id: str, user_id: str, content: JsonDict
- ) -> Awaitable[JsonDict]:
- """Attempts to join a group"""
- path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id)
-
- return self.client.post_json(
- destination=destination, path=path, data=content, ignore_backoff=True
- )
-
- async def invite_to_group(
- self,
- destination: str,
- group_id: str,
- user_id: str,
- requester_user_id: str,
- content: JsonDict,
- ) -> JsonDict:
- """Invite a user to a group"""
- path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id)
-
- return await self.client.post_json(
- destination=destination,
- path=path,
- args={"requester_user_id": requester_user_id},
- data=content,
- ignore_backoff=True,
- )
-
- async def invite_to_group_notification(
- self, destination: str, group_id: str, user_id: str, content: JsonDict
- ) -> JsonDict:
- """Sent by group server to inform a user's server that they have been
- invited.
- """
-
- path = _create_v1_path("/groups/local/%s/users/%s/invite", group_id, user_id)
-
- return await self.client.post_json(
- destination=destination, path=path, data=content, ignore_backoff=True
- )
-
- async def remove_user_from_group(
- self,
- destination: str,
- group_id: str,
- requester_user_id: str,
- user_id: str,
- content: JsonDict,
- ) -> JsonDict:
- """Remove a user from a group"""
- path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)
-
- return await self.client.post_json(
- destination=destination,
- path=path,
- args={"requester_user_id": requester_user_id},
- data=content,
- ignore_backoff=True,
- )
-
- async def remove_user_from_group_notification(
- self, destination: str, group_id: str, user_id: str, content: JsonDict
- ) -> JsonDict:
- """Sent by group server to inform a user's server that they have been
- kicked from the group.
- """
-
- path = _create_v1_path("/groups/local/%s/users/%s/remove", group_id, user_id)
-
- return await self.client.post_json(
- destination=destination, path=path, data=content, ignore_backoff=True
- )
-
- async def renew_group_attestation(
- self, destination: str, group_id: str, user_id: str, content: JsonDict
- ) -> JsonDict:
- """Sent by either a group server or a user's server to periodically update
- the attestations
- """
-
- path = _create_v1_path("/groups/%s/renew_attestation/%s", group_id, user_id)
-
- return await self.client.post_json(
- destination=destination, path=path, data=content, ignore_backoff=True
- )
-
- async def update_group_summary_room(
- self,
- destination: str,
- group_id: str,
- user_id: str,
- room_id: str,
- category_id: str,
- content: JsonDict,
- ) -> JsonDict:
- """Update a room entry in a group summary"""
- if category_id:
- path = _create_v1_path(
- "/groups/%s/summary/categories/%s/rooms/%s",
- group_id,
- category_id,
- room_id,
- )
- else:
- path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id)
-
- return await self.client.post_json(
- destination=destination,
- path=path,
- args={"requester_user_id": user_id},
- data=content,
- ignore_backoff=True,
- )
-
- async def delete_group_summary_room(
- self,
- destination: str,
- group_id: str,
- user_id: str,
- room_id: str,
- category_id: str,
- ) -> JsonDict:
- """Delete a room entry in a group summary"""
- if category_id:
- path = _create_v1_path(
- "/groups/%s/summary/categories/%s/rooms/%s",
- group_id,
- category_id,
- room_id,
- )
- else:
- path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id)
-
- return await self.client.delete_json(
- destination=destination,
- path=path,
- args={"requester_user_id": user_id},
- ignore_backoff=True,
- )
-
- async def get_group_categories(
- self, destination: str, group_id: str, requester_user_id: str
- ) -> JsonDict:
- """Get all categories in a group"""
- path = _create_v1_path("/groups/%s/categories", group_id)
-
- return await self.client.get_json(
- destination=destination,
- path=path,
- args={"requester_user_id": requester_user_id},
- ignore_backoff=True,
- )
-
- async def get_group_category(
- self, destination: str, group_id: str, requester_user_id: str, category_id: str
- ) -> JsonDict:
- """Get category info in a group"""
- path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
-
- return await self.client.get_json(
- destination=destination,
- path=path,
- args={"requester_user_id": requester_user_id},
- ignore_backoff=True,
- )
-
- async def update_group_category(
- self,
- destination: str,
- group_id: str,
- requester_user_id: str,
- category_id: str,
- content: JsonDict,
- ) -> JsonDict:
- """Update a category in a group"""
- path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
-
- return await self.client.post_json(
- destination=destination,
- path=path,
- args={"requester_user_id": requester_user_id},
- data=content,
- ignore_backoff=True,
- )
-
- async def delete_group_category(
- self, destination: str, group_id: str, requester_user_id: str, category_id: str
- ) -> JsonDict:
- """Delete a category in a group"""
- path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
-
- return await self.client.delete_json(
- destination=destination,
- path=path,
- args={"requester_user_id": requester_user_id},
- ignore_backoff=True,
- )
-
- async def get_group_roles(
- self, destination: str, group_id: str, requester_user_id: str
- ) -> JsonDict:
- """Get all roles in a group"""
- path = _create_v1_path("/groups/%s/roles", group_id)
-
- return await self.client.get_json(
- destination=destination,
- path=path,
- args={"requester_user_id": requester_user_id},
- ignore_backoff=True,
- )
-
- async def get_group_role(
- self, destination: str, group_id: str, requester_user_id: str, role_id: str
- ) -> JsonDict:
- """Get a roles info"""
- path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
-
- return await self.client.get_json(
- destination=destination,
- path=path,
- args={"requester_user_id": requester_user_id},
- ignore_backoff=True,
- )
-
- async def update_group_role(
- self,
- destination: str,
- group_id: str,
- requester_user_id: str,
- role_id: str,
- content: JsonDict,
- ) -> JsonDict:
- """Update a role in a group"""
- path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
-
- return await self.client.post_json(
- destination=destination,
- path=path,
- args={"requester_user_id": requester_user_id},
- data=content,
- ignore_backoff=True,
- )
-
- async def delete_group_role(
- self, destination: str, group_id: str, requester_user_id: str, role_id: str
- ) -> JsonDict:
- """Delete a role in a group"""
- path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
-
- return await self.client.delete_json(
- destination=destination,
- path=path,
- args={"requester_user_id": requester_user_id},
- ignore_backoff=True,
- )
-
- async def update_group_summary_user(
- self,
- destination: str,
- group_id: str,
- requester_user_id: str,
- user_id: str,
- role_id: str,
- content: JsonDict,
- ) -> JsonDict:
- """Update a users entry in a group"""
- if role_id:
- path = _create_v1_path(
- "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
- )
- else:
- path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id)
-
- return await self.client.post_json(
- destination=destination,
- path=path,
- args={"requester_user_id": requester_user_id},
- data=content,
- ignore_backoff=True,
- )
-
- async def set_group_join_policy(
- self, destination: str, group_id: str, requester_user_id: str, content: JsonDict
- ) -> JsonDict:
- """Sets the join policy for a group"""
- path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id)
-
- return await self.client.put_json(
- destination=destination,
- path=path,
- args={"requester_user_id": requester_user_id},
- data=content,
- ignore_backoff=True,
- )
-
- async def delete_group_summary_user(
- self,
- destination: str,
- group_id: str,
- requester_user_id: str,
- user_id: str,
- role_id: str,
- ) -> JsonDict:
- """Delete a users entry in a group"""
- if role_id:
- path = _create_v1_path(
- "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
- )
- else:
- path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id)
-
- return await self.client.delete_json(
- destination=destination,
- path=path,
- args={"requester_user_id": requester_user_id},
- ignore_backoff=True,
- )
-
- async def bulk_get_publicised_groups(
- self, destination: str, user_ids: Iterable[str]
- ) -> JsonDict:
- """Get the groups a list of users are publicising"""
-
- path = _create_v1_path("/get_groups_publicised")
-
- content = {"user_ids": user_ids}
-
- return await self.client.post_json(
- destination=destination, path=path, data=content, ignore_backoff=True
- )
-
async def get_room_complexity(self, destination: str, room_id: str) -> JsonDict:
"""
Args:
@@ -1360,10 +870,15 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
CONTENT_TYPE = "application/json"
+ # /send_join responses can be huge, so we override the size limit here. The response
+ # is parsed in a streaming manner, which helps alleviate the issue of memory
+ # usage a bit.
+ MAX_RESPONSE_SIZE = 500 * 1024 * 1024
+
def __init__(self, room_version: RoomVersion, v1_api: bool):
self._response = SendJoinResponse([], [], event_dict={})
self._room_version = room_version
- self._coros = []
+ self._coros: List[Generator[None, bytes, None]] = []
# The V1 API has the shape of `[200, {...}]`, which we handle by
# prefixing with `item.*`.
@@ -1411,6 +926,9 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
return len(data)
def finish(self) -> SendJoinResponse:
+ for c in self._coros:
+ c.close()
+
if self._response.event_dict:
self._response.event = make_event_from_dict(
self._response.event_dict, self._room_version
@@ -1427,10 +945,13 @@ class _StateParser(ByteParser[StateRequestResponse]):
CONTENT_TYPE = "application/json"
+ # As with /send_join, /state responses can be huge.
+ MAX_RESPONSE_SIZE = 500 * 1024 * 1024
+
def __init__(self, room_version: RoomVersion):
self._response = StateRequestResponse([], [])
self._room_version = room_version
- self._coros = [
+ self._coros: List[Generator[None, bytes, None]] = [
ijson.items_coro(
_event_list_parser(room_version, self._response.state),
"pdus.item",
@@ -1449,4 +970,6 @@ class _StateParser(ByteParser[StateRequestResponse]):
return len(data)
def finish(self) -> StateRequestResponse:
+ for c in self._coros:
+ c.close()
return self._response
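The two parsers above now carry their own MAX_RESPONSE_SIZE class attribute (replacing the max_response_size argument removed from the /send_join calls earlier in this file) and explicitly close their coroutines in finish(). A toy stand-in for that lifecycle, with the ijson-based event decoding of the real classes replaced by a trivial collector so the write/finish flow stays self-contained:

    from typing import Generator, List

    class _ChunkCollectingParser:
        """Streams a large response chunk by chunk instead of buffering it."""

        CONTENT_TYPE = "application/json"
        # Per-parser override; /send_join and /state responses can be huge.
        MAX_RESPONSE_SIZE = 500 * 1024 * 1024

        def __init__(self) -> None:
            self._sizes: List[int] = []
            self._coros: List[Generator[None, bytes, None]] = [self._collector()]
            for c in self._coros:
                next(c)  # prime the generator so it can accept send()

        def _collector(self) -> Generator[None, bytes, None]:
            while True:
                chunk = yield
                self._sizes.append(len(chunk))

        def write(self, data: bytes) -> int:
            # Called repeatedly as response bytes arrive.
            for c in self._coros:
                c.send(data)
            return len(data)

        def finish(self) -> int:
            # Closing the coroutines releases whatever state they hold,
            # even if the response ended early or was oversized.
            for c in self._coros:
                c.close()
            return sum(self._sizes)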
diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py
index 71b2f90e..50623cd3 100644
--- a/synapse/federation/transport/server/__init__.py
+++ b/synapse/federation/transport/server/__init__.py
@@ -27,10 +27,6 @@ from synapse.federation.transport.server.federation import (
FederationAccountStatusServlet,
FederationTimestampLookupServlet,
)
-from synapse.federation.transport.server.groups_local import GROUP_LOCAL_SERVLET_CLASSES
-from synapse.federation.transport.server.groups_server import (
- GROUP_SERVER_SERVLET_CLASSES,
-)
from synapse.http.server import HttpServer, JsonResource
from synapse.http.servlet import (
parse_boolean_from_args,
@@ -199,38 +195,6 @@ class PublicRoomList(BaseFederationServlet):
return 200, data
-class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
- """A group or user's server renews their attestation"""
-
- PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)"
-
- def __init__(
- self,
- hs: "HomeServer",
- authenticator: Authenticator,
- ratelimiter: FederationRateLimiter,
- server_name: str,
- ):
- super().__init__(hs, authenticator, ratelimiter, server_name)
- self.handler = hs.get_groups_attestation_renewer()
-
- async def on_POST(
- self,
- origin: str,
- content: JsonDict,
- query: Dict[bytes, List[bytes]],
- group_id: str,
- user_id: str,
- ) -> Tuple[int, JsonDict]:
- # We don't need to check auth here as we check the attestation signatures
-
- new_content = await self.handler.on_renew_attestation(
- group_id, user_id, content
- )
-
- return 200, new_content
-
-
class OpenIdUserInfo(BaseFederationServlet):
"""
Exchange a bearer token for information about a user.
@@ -292,16 +256,9 @@ class OpenIdUserInfo(BaseFederationServlet):
SERVLET_GROUPS: Dict[str, Iterable[Type[BaseFederationServlet]]] = {
"federation": FEDERATION_SERVLET_CLASSES,
"room_list": (PublicRoomList,),
- "group_server": GROUP_SERVER_SERVLET_CLASSES,
- "group_local": GROUP_LOCAL_SERVLET_CLASSES,
- "group_attestation": (FederationGroupsRenewAttestaionServlet,),
"openid": (OpenIdUserInfo,),
}
-DEFAULT_SERVLET_GROUPS = ("federation", "room_list", "openid")
-
-GROUP_SERVLET_GROUPS = ("group_server", "group_local", "group_attestation")
-
def register_servlets(
hs: "HomeServer",
@@ -324,10 +281,7 @@ def register_servlets(
Defaults to ``DEFAULT_SERVLET_GROUPS``.
"""
if not servlet_groups:
- servlet_groups = DEFAULT_SERVLET_GROUPS
- # Only allow the groups servlets if the deprecated groups feature is enabled.
- if hs.config.experimental.groups_enabled:
- servlet_groups = servlet_groups + GROUP_SERVLET_GROUPS
+ servlet_groups = SERVLET_GROUPS.keys()
for servlet_group in servlet_groups:
# Skip unknown servlet groups.
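With the groups servlets gone there is no longer a separate default list: when no servlet_groups are configured, registration simply falls back to every key in SERVLET_GROUPS. A rough sketch of that fallback under the constructor signature shown above; the real function also builds the Authenticator and wires up the rate limiter, which is omitted here:

    def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=None):
        if not servlet_groups:
            # No explicit configuration: register every known group
            # ("federation", "room_list", "openid").
            servlet_groups = SERVLET_GROUPS.keys()

        for servlet_group in servlet_groups:
            # Skip unknown servlet groups (e.g. stale config entries).
            if servlet_group not in SERVLET_GROUPS:
                continue
            for servletclass in SERVLET_GROUPS[servlet_group]:
                servletclass(
                    hs=hs,
                    authenticator=authenticator,
                    ratelimiter=ratelimiter,
                    server_name=hs.hostname,
                ).register(resource)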
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
index d629a3ec..84100a5a 100644
--- a/synapse/federation/transport/server/_base.py
+++ b/synapse/federation/transport/server/_base.py
@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tupl
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.urls import FEDERATION_V1_PREFIX
-from synapse.http.server import HttpServer, ServletCallback
+from synapse.http.server import HttpServer, ServletCallback, is_method_cancellable
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.logging.context import run_in_background
@@ -169,14 +169,16 @@ def _parse_auth_header(header_bytes: bytes) -> Tuple[str, str, str, Optional[str
"""
try:
header_str = header_bytes.decode("utf-8")
- params = header_str.split(" ")[1].split(",")
+ params = re.split(" +", header_str)[1].split(",")
param_dict: Dict[str, str] = {
- k: v for k, v in [param.split("=", maxsplit=1) for param in params]
+ k.lower(): v for k, v in [param.split("=", maxsplit=1) for param in params]
}
def strip_quotes(value: str) -> str:
if value.startswith('"'):
- return value[1:-1]
+ return re.sub(
+ "\\\\(.)", lambda matchobj: matchobj.group(1), value[1:-1]
+ )
else:
return value
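The revised parsing splits on runs of spaces, lower-cases parameter names, and unescapes backslash-escaped characters inside quoted values, making the Authorization parser tolerant of headers produced by other implementations. A small standalone sketch of the same logic; the X-Matrix header values below are made up, and the optional destination parameter handled by the real function is left out:

    import re

    def parse_x_matrix_header(header_str: str):
        # header_str looks like:
        #   X-Matrix origin=their.server,key="ed25519:1",sig="dGVzdA"
        params = re.split(" +", header_str)[1].split(",")
        param_dict = {
            k.lower(): v for k, v in (p.split("=", maxsplit=1) for p in params)
        }

        def strip_quotes(value: str) -> str:
            if value.startswith('"'):
                # Drop the surrounding quotes and resolve \" style escapes.
                return re.sub(r"\\(.)", lambda m: m.group(1), value[1:-1])
            return value

        origin = strip_quotes(param_dict["origin"])
        key = strip_quotes(param_dict["key"])
        sig = strip_quotes(param_dict["sig"])
        return origin, key, sig

    # parse_x_matrix_header(
    #     'X-Matrix origin=their.server,key="ed25519:1",sig="dGVzdA"')
    # -> ("their.server", "ed25519:1", "dGVzdA")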
@@ -373,6 +375,17 @@ class BaseFederationServlet:
if code is None:
continue
+ if is_method_cancellable(code):
+ # The wrapper added by `self._wrap` will inherit the cancellable flag,
+ # but the wrapper itself does not support cancellation yet.
+ # Once resolved, the cancellation tests in
+ # `tests/federation/transport/server/test__base.py` can be re-enabled.
+ raise Exception(
+ f"{self.__class__.__name__}.on_{method} has been marked as "
+ "cancellable, but federation servlets do not support cancellation "
+ "yet."
+ )
+
server.register_paths(
method,
(pattern,),
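Registration now refuses to wire up any federation handler that has been marked cancellable, because the wrapper added by _wrap would inherit the flag without actually supporting cancellation. A hedged sketch of how such a guard behaves; the cancellable decorator and is_method_cancellable helper are assumed to simply set and read a boolean attribute, mirroring synapse.http.server:

    def cancellable(func):
        # Assumed behaviour: mark the handler so the HTTP layer knows the
        # request may be cancelled mid-flight.
        func.cancellable = True
        return func

    def is_method_cancellable(func) -> bool:
        return getattr(func, "cancellable", False)

    class ExampleServlet:
        @cancellable
        async def on_GET(self, origin, content, query):
            ...

    def register(servlet) -> None:
        for method in ("GET", "PUT", "POST"):
            code = getattr(servlet, "on_%s" % method, None)
            if code is None:
                continue
            if is_method_cancellable(code):
                # Fail loudly at startup rather than silently dropping
                # cancellation support at request time.
                raise Exception(
                    "%s.on_%s has been marked as cancellable, but federation "
                    "servlets do not support cancellation yet."
                    % (servlet.__class__.__name__, method)
                )

    # register(ExampleServlet())  # raises, because on_GET is marked cancellable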
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index 6fbc7b5f..7dfb8906 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -27,6 +27,7 @@ from typing import (
from matrix_common.versionstring import get_distribution_version_string
from typing_extensions import Literal
+from synapse.api.constants import EduTypes
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.api.urls import FEDERATION_UNSTABLE_PREFIX, FEDERATION_V2_PREFIX
@@ -108,7 +109,10 @@ class FederationSendServlet(BaseFederationServerServlet):
)
if issue_8631_logger.isEnabledFor(logging.DEBUG):
- DEVICE_UPDATE_EDUS = ["m.device_list_update", "m.signing_key_update"]
+ DEVICE_UPDATE_EDUS = [
+ EduTypes.DEVICE_LIST_UPDATE,
+ EduTypes.SIGNING_KEY_UPDATE,
+ ]
device_list_updates = [
edu.get("content", {})
for edu in transaction_data.get("edus", [])
@@ -650,10 +654,6 @@ class FederationRoomHierarchyServlet(BaseFederationServlet):
)
-class FederationRoomHierarchyUnstableServlet(FederationRoomHierarchyServlet):
- PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
-
-
class RoomComplexityServlet(BaseFederationServlet):
"""
Indicates to other servers how complex (and therefore likely
@@ -752,7 +752,6 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationVersionServlet,
RoomComplexityServlet,
FederationRoomHierarchyServlet,
- FederationRoomHierarchyUnstableServlet,
FederationV1SendKnockServlet,
FederationMakeKnockServlet,
FederationAccountStatusServlet,
diff --git a/synapse/federation/transport/server/groups_local.py b/synapse/federation/transport/server/groups_local.py
deleted file mode 100644
index 496472e1..00000000
--- a/synapse/federation/transport/server/groups_local.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import TYPE_CHECKING, Dict, List, Tuple, Type
-
-from synapse.api.errors import SynapseError
-from synapse.federation.transport.server._base import (
- Authenticator,
- BaseFederationServlet,
-)
-from synapse.handlers.groups_local import GroupsLocalHandler
-from synapse.types import JsonDict, get_domain_from_id
-from synapse.util.ratelimitutils import FederationRateLimiter
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-
-class BaseGroupsLocalServlet(BaseFederationServlet):
- """Abstract base class for federation servlet classes which provides a groups local handler.
-
- See BaseFederationServlet for more information.
- """
-
- def __init__(
- self,
- hs: "HomeServer",
- authenticator: Authenticator,
- ratelimiter: FederationRateLimiter,
- server_name: str,
- ):
- super().__init__(hs, authenticator, ratelimiter, server_name)
- self.handler = hs.get_groups_local_handler()
-
-
-class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet):
- """A group server has invited a local user"""
-
- PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
-
- async def on_POST(
- self,
- origin: str,
- content: JsonDict,
- query: Dict[bytes, List[bytes]],
- group_id: str,
- user_id: str,
- ) -> Tuple[int, JsonDict]:
- if get_domain_from_id(group_id) != origin:
- raise SynapseError(403, "group_id doesn't match origin")
-
- assert isinstance(
- self.handler, GroupsLocalHandler
- ), "Workers cannot handle group invites."
-
- new_content = await self.handler.on_invite(group_id, user_id, content)
-
- return 200, new_content
-
-
-class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet):
- """A group server has removed a local user"""
-
- PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
-
- async def on_POST(
- self,
- origin: str,
- content: JsonDict,
- query: Dict[bytes, List[bytes]],
- group_id: str,
- user_id: str,
- ) -> Tuple[int, None]:
- if get_domain_from_id(group_id) != origin:
- raise SynapseError(403, "user_id doesn't match origin")
-
- assert isinstance(
- self.handler, GroupsLocalHandler
- ), "Workers cannot handle group removals."
-
- await self.handler.user_removed_from_group(group_id, user_id, content)
-
- return 200, None
-
-
-class FederationGroupsBulkPublicisedServlet(BaseGroupsLocalServlet):
- """Get roles in a group"""
-
- PATH = "/get_groups_publicised"
-
- async def on_POST(
- self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
- ) -> Tuple[int, JsonDict]:
- resp = await self.handler.bulk_get_publicised_groups(
- content["user_ids"], proxy=False
- )
-
- return 200, resp
-
-
-GROUP_LOCAL_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
- FederationGroupsLocalInviteServlet,
- FederationGroupsRemoveLocalUserServlet,
- FederationGroupsBulkPublicisedServlet,
-)
diff --git a/synapse/federation/transport/server/groups_server.py b/synapse/federation/transport/server/groups_server.py
deleted file mode 100644
index 851b5015..00000000
--- a/synapse/federation/transport/server/groups_server.py
+++ /dev/null
@@ -1,755 +0,0 @@
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import TYPE_CHECKING, Dict, List, Tuple, Type
-
-from typing_extensions import Literal
-
-from synapse.api.constants import MAX_GROUP_CATEGORYID_LENGTH, MAX_GROUP_ROLEID_LENGTH
-from synapse.api.errors import Codes, SynapseError
-from synapse.federation.transport.server._base import (
- Authenticator,
- BaseFederationServlet,
-)
-from synapse.http.servlet import parse_string_from_args
-from synapse.types import JsonDict, get_domain_from_id
-from synapse.util.ratelimitutils import FederationRateLimiter
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-
-class BaseGroupsServerServlet(BaseFederationServlet):
- """Abstract base class for federation servlet classes which provides a groups server handler.
-
- See BaseFederationServlet for more information.
- """
-
- def __init__(
- self,
- hs: "HomeServer",
- authenticator: Authenticator,
- ratelimiter: FederationRateLimiter,
- server_name: str,
- ):
- super().__init__(hs, authenticator, ratelimiter, server_name)
- self.handler = hs.get_groups_server_handler()
-
-
-class FederationGroupsProfileServlet(BaseGroupsServerServlet):
- """Get/set the basic profile of a group on behalf of a user"""
-
- PATH = "/groups/(?P<group_id>[^/]*)/profile"
-
- async def on_GET(
- self,
- origin: str,
- content: Literal[None],
- query: Dict[bytes, List[bytes]],
- group_id: str,
- ) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(
- query, "requester_user_id", required=True
- )
- if get_domain_from_id(requester_user_id) != origin:
- raise SynapseError(403, "requester_user_id doesn't match origin")
-
- new_content = await self.handler.get_group_profile(group_id, requester_user_id)
-
- return 200, new_content
-
- async def on_POST(
- self,
- origin: str,
- content: JsonDict,
- query: Dict[bytes, List[bytes]],
- group_id: str,
- ) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(
- query, "requester_user_id", required=True
- )
- if get_domain_from_id(requester_user_id) != origin:
- raise SynapseError(403, "requester_user_id doesn't match origin")
-
- new_content = await self.handler.update_group_profile(
- group_id, requester_user_id, content
- )
-
- return 200, new_content
-
-
-class FederationGroupsSummaryServlet(BaseGroupsServerServlet):
- PATH = "/groups/(?P<group_id>[^/]*)/summary"
-
- async def on_GET(
- self,
- origin: str,
- content: Literal[None],
- query: Dict[bytes, List[bytes]],
- group_id: str,
- ) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(
- query, "requester_user_id", required=True
- )
- if get_domain_from_id(requester_user_id) != origin:
- raise SynapseError(403, "requester_user_id doesn't match origin")
-
- new_content = await self.handler.get_group_summary(group_id, requester_user_id)
-
- return 200, new_content
-
-
-class FederationGroupsRoomsServlet(BaseGroupsServerServlet):
- """Get the rooms in a group on behalf of a user"""
-
- PATH = "/groups/(?P<group_id>[^/]*)/rooms"
-
- async def on_GET(
- self,
- origin: str,
- content: Literal[None],
- query: Dict[bytes, List[bytes]],
- group_id: str,
- ) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(
- query, "requester_user_id", required=True
- )
- if get_domain_from_id(requester_user_id) != origin:
- raise SynapseError(403, "requester_user_id doesn't match origin")
-
- new_content = await self.handler.get_rooms_in_group(group_id, requester_user_id)
-
- return 200, new_content
-
-
-class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet):
- """Add/remove room from group"""
-
- PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
-
- async def on_POST(
- self,
- origin: str,
- content: JsonDict,
- query: Dict[bytes, List[bytes]],
- group_id: str,
- room_id: str,
- ) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(
- query, "requester_user_id", required=True
- )
- if get_domain_from_id(requester_user_id) != origin:
- raise SynapseError(403, "requester_user_id doesn't match origin")
-
- new_content = await self.handler.add_room_to_group(
- group_id, requester_user_id, room_id, content
- )
-
- return 200, new_content
-
- async def on_DELETE(
- self,
- origin: str,
- content: Literal[None],
- query: Dict[bytes, List[bytes]],
- group_id: str,
- room_id: str,
- ) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(
- query, "requester_user_id", required=True
- )
- if get_domain_from_id(requester_user_id) != origin:
- raise SynapseError(403, "requester_user_id doesn't match origin")
-
- new_content = await self.handler.remove_room_from_group(
- group_id, requester_user_id, room_id
- )
-
- return 200, new_content
-
-
-class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet):
- """Update room config in group"""
-
- PATH = (
- "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
- "/config/(?P<config_key>[^/]*)"
- )
-
- async def on_POST(
- self,
- origin: str,
- content: JsonDict,
- query: Dict[bytes, List[bytes]],
- group_id: str,
- room_id: str,
- config_key: str,
- ) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(
- query, "requester_user_id", required=True
- )
- if get_domain_from_id(requester_user_id) != origin:
- raise SynapseError(403, "requester_user_id doesn't match origin")
-
- result = await self.handler.update_room_in_group(
- group_id, requester_user_id, room_id, config_key, content
- )
-
- return 200, result
-
-
-class FederationGroupsUsersServlet(BaseGroupsServerServlet):
- """Get the users in a group on behalf of a user"""
-
- PATH = "/groups/(?P<group_id>[^/]*)/users"
-
- async def on_GET(
- self,
- origin: str,
- content: Literal[None],
- query: Dict[bytes, List[bytes]],
- group_id: str,
- ) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(
- query, "requester_user_id", required=True
- )
- if get_domain_from_id(requester_user_id) != origin:
- raise SynapseError(403, "requester_user_id doesn't match origin")
-
- new_content = await self.handler.get_users_in_group(group_id, requester_user_id)
-
- return 200, new_content
-
-
-class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet):
- """Get the users that have been invited to a group"""
-
- PATH = "/groups/(?P<group_id>[^/]*)/invited_users"
-
- async def on_GET(
- self,
- origin: str,
- content: Literal[None],
- query: Dict[bytes, List[bytes]],
- group_id: str,
- ) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(
- query, "requester_user_id", required=True
- )
- if get_domain_from_id(requester_user_id) != origin:
- raise SynapseError(403, "requester_user_id doesn't match origin")
-
- new_content = await self.handler.get_invited_users_in_group(
- group_id, requester_user_id
- )
-
- return 200, new_content
-
-
-class FederationGroupsInviteServlet(BaseGroupsServerServlet):
- """Ask a group server to invite someone to the group"""
-
- PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
-
- async def on_POST(
- self,
- origin: str,
- content: JsonDict,
- query: Dict[bytes, List[bytes]],
- group_id: str,
- user_id: str,
- ) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(
- query, "requester_user_id", required=True
- )
- if get_domain_from_id(requester_user_id) != origin:
- raise SynapseError(403, "requester_user_id doesn't match origin")
-
- new_content = await self.handler.invite_to_group(
- group_id, user_id, requester_user_id, content
- )
-
- return 200, new_content
-
-
-class FederationGroupsAcceptInviteServlet(BaseGroupsServerServlet):
- """Accept an invitation from the group server"""
-
- PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite"
-
- async def on_POST(
- self,
- origin: str,
- content: JsonDict,
- query: Dict[bytes, List[bytes]],
- group_id: str,
- user_id: str,
- ) -> Tuple[int, JsonDict]:
- if get_domain_from_id(user_id) != origin:
- raise SynapseError(403, "user_id doesn't match origin")
-
- new_content = await self.handler.accept_invite(group_id, user_id, content)
-
- return 200, new_content
-
-
-class FederationGroupsJoinServlet(BaseGroupsServerServlet):
- """Attempt to join a group"""
-
- PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join"
-
- async def on_POST(
- self,
- origin: str,
- content: JsonDict,
- query: Dict[bytes, List[bytes]],
- group_id: str,
- user_id: str,
- ) -> Tuple[int, JsonDict]:
- if get_domain_from_id(user_id) != origin:
- raise SynapseError(403, "user_id doesn't match origin")
-
- new_content = await self.handler.join_group(group_id, user_id, content)
-
- return 200, new_content
-
-
-class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet):
- """Leave or kick a user from the group"""
-
- PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
-
- async def on_POST(
- self,
- origin: str,
- content: JsonDict,
- query: Dict[bytes, List[bytes]],
- group_id: str,
- user_id: str,
- ) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(
- query, "requester_user_id", required=True
- )
- if get_domain_from_id(requester_user_id) != origin:
- raise SynapseError(403, "requester_user_id doesn't match origin")
-
- new_content = await self.handler.remove_user_from_group(
- group_id, user_id, requester_user_id, content
- )
-
- return 200, new_content
-
-
-class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet):
- """Add/remove a room from the group summary, with optional category.
-
- Matches both:
- - /groups/:group/summary/rooms/:room_id
- - /groups/:group/summary/categories/:category/rooms/:room_id
- """
-
- PATH = (
- "/groups/(?P<group_id>[^/]*)/summary"
- "(/categories/(?P<category_id>[^/]+))?"
- "/rooms/(?P<room_id>[^/]*)"
- )
-
- async def on_POST(
- self,
- origin: str,
- content: JsonDict,
- query: Dict[bytes, List[bytes]],
- group_id: str,
- category_id: str,
- room_id: str,
- ) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(
- query, "requester_user_id", required=True
- )
- if get_domain_from_id(requester_user_id) != origin:
- raise SynapseError(403, "requester_user_id doesn't match origin")
-
- if category_id == "":
- raise SynapseError(
- 400, "category_id cannot be empty string", Codes.INVALID_PARAM
- )
-
- if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
- raise SynapseError(
- 400,
- "category_id may not be longer than %s characters"
- % (MAX_GROUP_CATEGORYID_LENGTH,),
- Codes.INVALID_PARAM,
- )
-
- resp = await self.handler.update_group_summary_room(
- group_id,
- requester_user_id,
- room_id=room_id,
- category_id=category_id,
- content=content,
- )
-
- return 200, resp
-
- async def on_DELETE(
- self,
- origin: str,
- content: Literal[None],
- query: Dict[bytes, List[bytes]],
- group_id: str,
- category_id: str,
- room_id: str,
- ) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(
- query, "requester_user_id", required=True
- )
- if get_domain_from_id(requester_user_id) != origin:
- raise SynapseError(403, "requester_user_id doesn't match origin")
-
- if category_id == "":
- raise SynapseError(400, "category_id cannot be empty string")
-
- resp = await self.handler.delete_group_summary_room(
- group_id, requester_user_id, room_id=room_id, category_id=category_id
- )
-
- return 200, resp
-
-
-class FederationGroupsCategoriesServlet(BaseGroupsServerServlet):
- """Get all categories for a group"""
-
- PATH = "/groups/(?P<group_id>[^/]*)/categories/?"
-
- async def on_GET(
- self,
- origin: str,
- content: Literal[None],
- query: Dict[bytes, List[bytes]],
- group_id: str,
- ) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(
- query, "requester_user_id", required=True
- )
- if get_domain_from_id(requester_user_id) != origin:
- raise SynapseError(403, "requester_user_id doesn't match origin")
-
- resp = await self.handler.get_group_categories(group_id, requester_user_id)
-
- return 200, resp
-
-
-class FederationGroupsCategoryServlet(BaseGroupsServerServlet):
- """Add/remove/get a category in a group"""
-
- PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
-
- async def on_GET(
- self,
- origin: str,
- content: Literal[None],
- query: Dict[bytes, List[bytes]],
- group_id: str,
- category_id: str,
- ) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(
- query, "requester_user_id", required=True
- )
- if get_domain_from_id(requester_user_id) != origin:
- raise SynapseError(403, "requester_user_id doesn't match origin")
-
- resp = await self.handler.get_group_category(
- group_id, requester_user_id, category_id
- )
-
- return 200, resp
-
- async def on_POST(
- self,
- origin: str,
- content: JsonDict,
- query: Dict[bytes, List[bytes]],
- group_id: str,
- category_id: str,
- ) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(
- query, "requester_user_id", required=True
- )
- if get_domain_from_id(requester_user_id) != origin:
- raise SynapseError(403, "requester_user_id doesn't match origin")
-
- if category_id == "":
- raise SynapseError(400, "category_id cannot be empty string")
-
- if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
- raise SynapseError(
- 400,
- "category_id may not be longer than %s characters"
- % (MAX_GROUP_CATEGORYID_LENGTH,),
- Codes.INVALID_PARAM,
- )
-
- resp = await self.handler.upsert_group_category(
- group_id, requester_user_id, category_id, content
- )
-
- return 200, resp
-
- async def on_DELETE(
- self,
- origin: str,
- content: Literal[None],
- query: Dict[bytes, List[bytes]],
- group_id: str,
- category_id: str,
- ) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(
- query, "requester_user_id", required=True
- )
- if get_domain_from_id(requester_user_id) != origin:
- raise SynapseError(403, "requester_user_id doesn't match origin")
-
- if category_id == "":
- raise SynapseError(400, "category_id cannot be empty string")
-
- resp = await self.handler.delete_group_category(
- group_id, requester_user_id, category_id
- )
-
- return 200, resp
-
-
-class FederationGroupsRolesServlet(BaseGroupsServerServlet):
- """Get roles in a group"""
-
- PATH = "/groups/(?P<group_id>[^/]*)/roles/?"
-
- async def on_GET(
- self,
- origin: str,
- content: Literal[None],
- query: Dict[bytes, List[bytes]],
- group_id: str,
- ) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(
- query, "requester_user_id", required=True
- )
- if get_domain_from_id(requester_user_id) != origin:
- raise SynapseError(403, "requester_user_id doesn't match origin")
-
- resp = await self.handler.get_group_roles(group_id, requester_user_id)
-
- return 200, resp
-
-
-class FederationGroupsRoleServlet(BaseGroupsServerServlet):
- """Add/remove/get a role in a group"""
-
- PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
-
- async def on_GET(
- self,
- origin: str,
- content: Literal[None],
- query: Dict[bytes, List[bytes]],
- group_id: str,
- role_id: str,
- ) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(
- query, "requester_user_id", required=True
- )
- if get_domain_from_id(requester_user_id) != origin:
- raise SynapseError(403, "requester_user_id doesn't match origin")
-
- resp = await self.handler.get_group_role(group_id, requester_user_id, role_id)
-
- return 200, resp
-
- async def on_POST(
- self,
- origin: str,
- content: JsonDict,
- query: Dict[bytes, List[bytes]],
- group_id: str,
- role_id: str,
- ) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(
- query, "requester_user_id", required=True
- )
- if get_domain_from_id(requester_user_id) != origin:
- raise SynapseError(403, "requester_user_id doesn't match origin")
-
- if role_id == "":
- raise SynapseError(
- 400, "role_id cannot be empty string", Codes.INVALID_PARAM
- )
-
- if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
- raise SynapseError(
- 400,
- "role_id may not be longer than %s characters"
- % (MAX_GROUP_ROLEID_LENGTH,),
- Codes.INVALID_PARAM,
- )
-
- resp = await self.handler.update_group_role(
- group_id, requester_user_id, role_id, content
- )
-
- return 200, resp
-
- async def on_DELETE(
- self,
- origin: str,
- content: Literal[None],
- query: Dict[bytes, List[bytes]],
- group_id: str,
- role_id: str,
- ) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(
- query, "requester_user_id", required=True
- )
- if get_domain_from_id(requester_user_id) != origin:
- raise SynapseError(403, "requester_user_id doesn't match origin")
-
- if role_id == "":
- raise SynapseError(400, "role_id cannot be empty string")
-
- resp = await self.handler.delete_group_role(
- group_id, requester_user_id, role_id
- )
-
- return 200, resp
-
-
-class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet):
- """Add/remove a user from the group summary, with optional role.
-
- Matches both:
- - /groups/:group/summary/users/:user_id
- - /groups/:group/summary/roles/:role/users/:user_id
- """
-
- PATH = (
- "/groups/(?P<group_id>[^/]*)/summary"
- "(/roles/(?P<role_id>[^/]+))?"
- "/users/(?P<user_id>[^/]*)"
- )
-
- async def on_POST(
- self,
- origin: str,
- content: JsonDict,
- query: Dict[bytes, List[bytes]],
- group_id: str,
- role_id: str,
- user_id: str,
- ) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(
- query, "requester_user_id", required=True
- )
- if get_domain_from_id(requester_user_id) != origin:
- raise SynapseError(403, "requester_user_id doesn't match origin")
-
- if role_id == "":
- raise SynapseError(400, "role_id cannot be empty string")
-
- if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
- raise SynapseError(
- 400,
- "role_id may not be longer than %s characters"
- % (MAX_GROUP_ROLEID_LENGTH,),
- Codes.INVALID_PARAM,
- )
-
- resp = await self.handler.update_group_summary_user(
- group_id,
- requester_user_id,
- user_id=user_id,
- role_id=role_id,
- content=content,
- )
-
- return 200, resp
-
- async def on_DELETE(
- self,
- origin: str,
- content: Literal[None],
- query: Dict[bytes, List[bytes]],
- group_id: str,
- role_id: str,
- user_id: str,
- ) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(
- query, "requester_user_id", required=True
- )
- if get_domain_from_id(requester_user_id) != origin:
- raise SynapseError(403, "requester_user_id doesn't match origin")
-
- if role_id == "":
- raise SynapseError(400, "role_id cannot be empty string")
-
- resp = await self.handler.delete_group_summary_user(
- group_id, requester_user_id, user_id=user_id, role_id=role_id
- )
-
- return 200, resp
-
-
-class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet):
- """Sets whether a group is joinable without an invite or knock"""
-
- PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy"
-
- async def on_PUT(
- self,
- origin: str,
- content: JsonDict,
- query: Dict[bytes, List[bytes]],
- group_id: str,
- ) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(
- query, "requester_user_id", required=True
- )
- if get_domain_from_id(requester_user_id) != origin:
- raise SynapseError(403, "requester_user_id doesn't match origin")
-
- new_content = await self.handler.set_group_join_policy(
- group_id, requester_user_id, content
- )
-
- return 200, new_content
-
-
-GROUP_SERVER_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
- FederationGroupsProfileServlet,
- FederationGroupsSummaryServlet,
- FederationGroupsRoomsServlet,
- FederationGroupsUsersServlet,
- FederationGroupsInvitedUsersServlet,
- FederationGroupsInviteServlet,
- FederationGroupsAcceptInviteServlet,
- FederationGroupsJoinServlet,
- FederationGroupsRemoveUserServlet,
- FederationGroupsSummaryRoomsServlet,
- FederationGroupsCategoriesServlet,
- FederationGroupsCategoryServlet,
- FederationGroupsRolesServlet,
- FederationGroupsRoleServlet,
- FederationGroupsSummaryUsersServlet,
- FederationGroupsAddRoomsServlet,
- FederationGroupsAddRoomsConfigServlet,
- FederationGroupsSettingJoinPolicyServlet,
-)
diff --git a/synapse/groups/__init__.py b/synapse/groups/__init__.py
deleted file mode 100644
index e69de29b..00000000
--- a/synapse/groups/__init__.py
+++ /dev/null
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
deleted file mode 100644
index ed26d6a6..00000000
--- a/synapse/groups/attestations.py
+++ /dev/null
@@ -1,218 +0,0 @@
-# Copyright 2017 Vector Creations Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Attestations ensure that users and groups can't lie about their memberships.
-
-When a user joins a group the HS and GS swap attestations, which allow them
-both to independently prove to third parties their membership.These
-attestations have a validity period so need to be periodically renewed.
-
-If a user leaves (or gets kicked out of) a group, either side can still use
-their attestation to "prove" their membership, until the attestation expires.
-Therefore attestations shouldn't be relied on to prove membership in important
-cases, but can for less important situations, e.g. showing a users membership
-of groups on their profile, showing flairs, etc.
-
-An attestation is a signed blob of json that looks like:
-
- {
- "user_id": "@foo:a.example.com",
- "group_id": "+bar:b.example.com",
- "valid_until_ms": 1507994728530,
- "signatures":{"matrix.org":{"ed25519:auto":"..."}}
- }
-"""
-
-import logging
-import random
-from typing import TYPE_CHECKING, Optional, Tuple
-
-from signedjson.sign import sign_json
-
-from twisted.internet.defer import Deferred
-
-from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import JsonDict, get_domain_from_id
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-# Default validity duration for new attestations we create
-DEFAULT_ATTESTATION_LENGTH_MS = 3 * 24 * 60 * 60 * 1000
-
-# We add some jitter to the validity duration of attestations so that if we
-# add lots of users at once we don't need to renew them all at once.
-# The jitter is a multiplier picked randomly between the first and second number
-DEFAULT_ATTESTATION_JITTER = (0.9, 1.3)
-
-# Start trying to update our attestations when they come this close to expiring
-UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
-
-
-class GroupAttestationSigning:
- """Creates and verifies group attestations."""
-
- def __init__(self, hs: "HomeServer"):
- self.keyring = hs.get_keyring()
- self.clock = hs.get_clock()
- self.server_name = hs.hostname
- self.signing_key = hs.signing_key
-
- async def verify_attestation(
- self,
- attestation: JsonDict,
- group_id: str,
- user_id: str,
- server_name: Optional[str] = None,
- ) -> None:
- """Verifies that the given attestation matches the given parameters.
-
- An optional server_name can be supplied to explicitly set which server's
- signature is expected. Otherwise assumes that either the group_id or user_id
- is local and uses the other's server as the one to check.
- """
-
- if not server_name:
- if get_domain_from_id(group_id) == self.server_name:
- server_name = get_domain_from_id(user_id)
- elif get_domain_from_id(user_id) == self.server_name:
- server_name = get_domain_from_id(group_id)
- else:
- raise Exception("Expected either group_id or user_id to be local")
-
- if user_id != attestation["user_id"]:
- raise SynapseError(400, "Attestation has incorrect user_id")
-
- if group_id != attestation["group_id"]:
- raise SynapseError(400, "Attestation has incorrect group_id")
- valid_until_ms = attestation["valid_until_ms"]
-
- # TODO: We also want to check that *new* attestations that people give
- # us to store are valid for at least a little while.
- now = self.clock.time_msec()
- if valid_until_ms < now:
- raise SynapseError(400, "Attestation expired")
-
- assert server_name is not None
- await self.keyring.verify_json_for_server(
- server_name,
- attestation,
- now,
- )
-
- def create_attestation(self, group_id: str, user_id: str) -> JsonDict:
- """Create an attestation for the group_id and user_id with default
- validity length.
- """
- validity_period = DEFAULT_ATTESTATION_LENGTH_MS * random.uniform(
- *DEFAULT_ATTESTATION_JITTER
- )
- valid_until_ms = int(self.clock.time_msec() + validity_period)
-
- return sign_json(
- {
- "group_id": group_id,
- "user_id": user_id,
- "valid_until_ms": valid_until_ms,
- },
- self.server_name,
- self.signing_key,
- )
-
-
-class GroupAttestionRenewer:
- """Responsible for sending and receiving attestation updates."""
-
- def __init__(self, hs: "HomeServer"):
- self.clock = hs.get_clock()
- self.store = hs.get_datastores().main
- self.assestations = hs.get_groups_attestation_signing()
- self.transport_client = hs.get_federation_transport_client()
- self.is_mine_id = hs.is_mine_id
- self.attestations = hs.get_groups_attestation_signing()
-
- if not hs.config.worker.worker_app:
- self._renew_attestations_loop = self.clock.looping_call(
- self._start_renew_attestations, 30 * 60 * 1000
- )
-
- async def on_renew_attestation(
- self, group_id: str, user_id: str, content: JsonDict
- ) -> JsonDict:
- """When a remote updates an attestation"""
- attestation = content["attestation"]
-
- if not self.is_mine_id(group_id) and not self.is_mine_id(user_id):
- raise SynapseError(400, "Neither user not group are on this server")
-
- await self.attestations.verify_attestation(
- attestation, user_id=user_id, group_id=group_id
- )
-
- await self.store.update_remote_attestion(group_id, user_id, attestation)
-
- return {}
-
- def _start_renew_attestations(self) -> "Deferred[None]":
- return run_as_background_process("renew_attestations", self._renew_attestations)
-
- async def _renew_attestations(self) -> None:
- """Called periodically to check if we need to update any of our attestations"""
-
- now = self.clock.time_msec()
-
- rows = await self.store.get_attestations_need_renewals(
- now + UPDATE_ATTESTATION_TIME_MS
- )
-
- async def _renew_attestation(group_user: Tuple[str, str]) -> None:
- group_id, user_id = group_user
- try:
- if not self.is_mine_id(group_id):
- destination = get_domain_from_id(group_id)
- elif not self.is_mine_id(user_id):
- destination = get_domain_from_id(user_id)
- else:
- logger.warning(
- "Incorrectly trying to do attestations for user: %r in %r",
- user_id,
- group_id,
- )
- await self.store.remove_attestation_renewal(group_id, user_id)
- return
-
- attestation = self.attestations.create_attestation(group_id, user_id)
-
- await self.transport_client.renew_group_attestation(
- destination, group_id, user_id, content={"attestation": attestation}
- )
-
- await self.store.update_attestation_renewal(
- group_id, user_id, attestation
- )
- except (RequestSendFailed, HttpResponseException) as e:
- logger.warning(
- "Failed to renew attestation of %r in %r: %s", user_id, group_id, e
- )
- except Exception:
- logger.exception(
- "Error renewing attestation of %r in %r", user_id, group_id
- )
-
- for row in rows:
- await _renew_attestation((row["group_id"], row["user_id"]))
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
deleted file mode 100644
index 4c3a5a6e..00000000
--- a/synapse/groups/groups_server.py
+++ /dev/null
@@ -1,1019 +0,0 @@
-# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
-# Copyright 2019 Michael Telatynski <7t3chguy@gmail.com>
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from typing import TYPE_CHECKING, Optional
-
-from synapse.api.errors import Codes, SynapseError
-from synapse.handlers.groups_local import GroupsLocalHandler
-from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
-from synapse.types import GroupID, JsonDict, RoomID, UserID, get_domain_from_id
-from synapse.util.async_helpers import concurrently_execute
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-# TODO: Allow users to "knock" or simply join depending on rules
-# TODO: Federation admin APIs
-# TODO: is_privileged flag to users and is_public to users and rooms
-# TODO: Audit log for admins (profile updates, membership changes, users who tried
-# to join but were rejected, etc)
-# TODO: Flairs
-
-
-# Note that the maximum lengths are somewhat arbitrary.
-MAX_SHORT_DESC_LEN = 1000
-MAX_LONG_DESC_LEN = 10000
-
-
-class GroupsServerWorkerHandler:
- def __init__(self, hs: "HomeServer"):
- self.hs = hs
- self.store = hs.get_datastores().main
- self.room_list_handler = hs.get_room_list_handler()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.keyring = hs.get_keyring()
- self.is_mine_id = hs.is_mine_id
- self.signing_key = hs.signing_key
- self.server_name = hs.hostname
- self.attestations = hs.get_groups_attestation_signing()
- self.transport_client = hs.get_federation_transport_client()
- self.profile_handler = hs.get_profile_handler()
-
- async def check_group_is_ours(
- self,
- group_id: str,
- requester_user_id: str,
- and_exists: bool = False,
- and_is_admin: Optional[str] = None,
- ) -> Optional[dict]:
- """Check that the group is ours, and optionally if it exists.
-
- If group does exist then return group.
-
- Args:
- group_id: The group ID to check.
- requester_user_id: The user ID of the requester.
- and_exists: whether to also check if group exists
- and_is_admin: whether to also check if given str is a user_id
- that is an admin
- """
- if not self.is_mine_id(group_id):
- raise SynapseError(400, "Group not on this server")
-
- group = await self.store.get_group(group_id)
- if and_exists and not group:
- raise SynapseError(404, "Unknown group")
-
- is_user_in_group = await self.store.is_user_in_group(
- requester_user_id, group_id
- )
- if group and not is_user_in_group and not group["is_public"]:
- raise SynapseError(404, "Unknown group")
-
- if and_is_admin:
- is_admin = await self.store.is_user_admin_in_group(group_id, and_is_admin)
- if not is_admin:
- raise SynapseError(403, "User is not admin in group")
-
- return group
-
- async def get_group_summary(
- self, group_id: str, requester_user_id: str
- ) -> JsonDict:
- """Get the summary for a group as seen by requester_user_id.
-
- The group summary consists of the profile of the room, and a curated
- list of users and rooms. These list *may* be organised by role/category.
- The roles/categories are ordered, and so are the users/rooms within them.
-
- A user/room may appear in multiple roles/categories.
- """
- await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
- is_user_in_group = await self.store.is_user_in_group(
- requester_user_id, group_id
- )
-
- profile = await self.get_group_profile(group_id, requester_user_id)
-
- users, roles = await self.store.get_users_for_summary_by_role(
- group_id, include_private=is_user_in_group
- )
-
- # TODO: Add profiles to users
-
- rooms, categories = await self.store.get_rooms_for_summary_by_category(
- group_id, include_private=is_user_in_group
- )
-
- for room_entry in rooms:
- room_id = room_entry["room_id"]
- joined_users = await self.store.get_users_in_room(room_id)
- entry = await self.room_list_handler.generate_room_entry(
- room_id, len(joined_users), with_alias=False, allow_private=True
- )
- if entry is None:
- continue
- entry = dict(entry) # so we don't change what's cached
- entry.pop("room_id", None)
-
- room_entry["profile"] = entry
-
- rooms.sort(key=lambda e: e.get("order", 0))
-
- for user in users:
- user_id = user["user_id"]
-
- if not self.is_mine_id(requester_user_id):
- attestation = await self.store.get_remote_attestation(group_id, user_id)
- if not attestation:
- continue
-
- user["attestation"] = attestation
- else:
- user["attestation"] = self.attestations.create_attestation(
- group_id, user_id
- )
-
- user_profile = await self.profile_handler.get_profile_from_cache(user_id)
- user.update(user_profile)
-
- users.sort(key=lambda e: e.get("order", 0))
-
- membership_info = await self.store.get_users_membership_info_in_group(
- group_id, requester_user_id
- )
-
- return {
- "profile": profile,
- "users_section": {
- "users": users,
- "roles": roles,
- "total_user_count_estimate": 0, # TODO
- },
- "rooms_section": {
- "rooms": rooms,
- "categories": categories,
- "total_room_count_estimate": 0, # TODO
- },
- "user": membership_info,
- }
-
- async def get_group_categories(
- self, group_id: str, requester_user_id: str
- ) -> JsonDict:
- """Get all categories in a group (as seen by user)"""
- await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
- categories = await self.store.get_group_categories(group_id=group_id)
- return {"categories": categories}
-
- async def get_group_category(
- self, group_id: str, requester_user_id: str, category_id: str
- ) -> JsonDict:
- """Get a specific category in a group (as seen by user)"""
- await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
- return await self.store.get_group_category(
- group_id=group_id, category_id=category_id
- )
-
- async def get_group_roles(self, group_id: str, requester_user_id: str) -> JsonDict:
- """Get all roles in a group (as seen by user)"""
- await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
- roles = await self.store.get_group_roles(group_id=group_id)
- return {"roles": roles}
-
- async def get_group_role(
- self, group_id: str, requester_user_id: str, role_id: str
- ) -> JsonDict:
- """Get a specific role in a group (as seen by user)"""
- await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
- return await self.store.get_group_role(group_id=group_id, role_id=role_id)
-
- async def get_group_profile(
- self, group_id: str, requester_user_id: str
- ) -> JsonDict:
- """Get the group profile as seen by requester_user_id"""
-
- await self.check_group_is_ours(group_id, requester_user_id)
-
- group = await self.store.get_group(group_id)
-
- if group:
- cols = [
- "name",
- "short_description",
- "long_description",
- "avatar_url",
- "is_public",
- ]
- group_description = {key: group[key] for key in cols}
- group_description["is_openly_joinable"] = group["join_policy"] == "open"
-
- return group_description
- else:
- raise SynapseError(404, "Unknown group")
-
- async def get_users_in_group(
- self, group_id: str, requester_user_id: str
- ) -> JsonDict:
- """Get the users in group as seen by requester_user_id.
-
- The ordering is arbitrary at the moment
- """
-
- await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
- is_user_in_group = await self.store.is_user_in_group(
- requester_user_id, group_id
- )
-
- user_results = await self.store.get_users_in_group(
- group_id, include_private=is_user_in_group
- )
-
- chunk = []
- for user_result in user_results:
- g_user_id = user_result["user_id"]
- is_public = user_result["is_public"]
- is_privileged = user_result["is_admin"]
-
- entry = {"user_id": g_user_id}
-
- profile = await self.profile_handler.get_profile_from_cache(g_user_id)
- entry.update(profile)
-
- entry["is_public"] = bool(is_public)
- entry["is_privileged"] = bool(is_privileged)
-
- if not self.is_mine_id(g_user_id):
- attestation = await self.store.get_remote_attestation(
- group_id, g_user_id
- )
- if not attestation:
- continue
-
- entry["attestation"] = attestation
- else:
- entry["attestation"] = self.attestations.create_attestation(
- group_id, g_user_id
- )
-
- chunk.append(entry)
-
- # TODO: If admin add lists of users whose attestations have timed out
-
- return {"chunk": chunk, "total_user_count_estimate": len(user_results)}
-
- async def get_invited_users_in_group(
- self, group_id: str, requester_user_id: str
- ) -> JsonDict:
- """Get the users that have been invited to a group as seen by requester_user_id.
-
- The ordering is arbitrary at the moment
- """
-
- await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
- is_user_in_group = await self.store.is_user_in_group(
- requester_user_id, group_id
- )
-
- if not is_user_in_group:
- raise SynapseError(403, "User not in group")
-
- invited_users = await self.store.get_invited_users_in_group(group_id)
-
- user_profiles = []
-
- for user_id in invited_users:
- user_profile = {"user_id": user_id}
- try:
- profile = await self.profile_handler.get_profile_from_cache(user_id)
- user_profile.update(profile)
- except Exception as e:
- logger.warning("Error getting profile for %s: %s", user_id, e)
- user_profiles.append(user_profile)
-
- return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)}
-
- async def get_rooms_in_group(
- self, group_id: str, requester_user_id: str
- ) -> JsonDict:
- """Get the rooms in group as seen by requester_user_id
-
- This returns rooms in order of decreasing number of joined users
- """
-
- await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
- is_user_in_group = await self.store.is_user_in_group(
- requester_user_id, group_id
- )
-
- # Note! room_results["is_public"] is about whether the room is considered
- # public from the group's point of view. (i.e. whether non-group members
- # should be able to see the room is in the group).
- # This is not the same as whether the room itself is public (in the sense
- # of being visible in the room directory).
- # As such, room_results["is_public"] itself is not sufficient to determine
- # whether any given user is permitted to see the room's metadata.
- room_results = await self.store.get_rooms_in_group(
- group_id, include_private=is_user_in_group
- )
-
- chunk = []
- for room_result in room_results:
- room_id = room_result["room_id"]
-
- joined_users = await self.store.get_users_in_room(room_id)
-
- # check the user is actually allowed to see the room before showing it to them
- allow_private = requester_user_id in joined_users
-
- entry = await self.room_list_handler.generate_room_entry(
- room_id,
- len(joined_users),
- with_alias=False,
- allow_private=allow_private,
- )
-
- if not entry:
- continue
-
- entry["is_public"] = bool(room_result["is_public"])
-
- chunk.append(entry)
-
- chunk.sort(key=lambda e: -e["num_joined_members"])
-
- return {"chunk": chunk, "total_room_count_estimate": len(chunk)}
-
-
-class GroupsServerHandler(GroupsServerWorkerHandler):
- def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
- # Ensure attestations get renewed
- hs.get_groups_attestation_renewer()
-
- async def update_group_summary_room(
- self,
- group_id: str,
- requester_user_id: str,
- room_id: str,
- category_id: str,
- content: JsonDict,
- ) -> JsonDict:
- """Add/update a room to the group summary"""
- await self.check_group_is_ours(
- group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
- )
-
- RoomID.from_string(room_id) # Ensure valid room id
-
- order = content.get("order", None)
-
- is_public = _parse_visibility_from_contents(content)
-
- await self.store.add_room_to_summary(
- group_id=group_id,
- room_id=room_id,
- category_id=category_id,
- order=order,
- is_public=is_public,
- )
-
- return {}
-
- async def delete_group_summary_room(
- self, group_id: str, requester_user_id: str, room_id: str, category_id: str
- ) -> JsonDict:
- """Remove a room from the summary"""
- await self.check_group_is_ours(
- group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
- )
-
- await self.store.remove_room_from_summary(
- group_id=group_id, room_id=room_id, category_id=category_id
- )
-
- return {}
-
- async def set_group_join_policy(
- self, group_id: str, requester_user_id: str, content: JsonDict
- ) -> JsonDict:
- """Sets the group join policy.
-
- Currently supported policies are:
- - "invite": an invite must be received and accepted in order to join.
- - "open": anyone can join.
- """
- await self.check_group_is_ours(
- group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
- )
-
- join_policy = _parse_join_policy_from_contents(content)
- if join_policy is None:
- raise SynapseError(400, "No value specified for 'm.join_policy'")
-
- await self.store.set_group_join_policy(group_id, join_policy=join_policy)
-
- return {}
-
- async def update_group_category(
- self, group_id: str, requester_user_id: str, category_id: str, content: JsonDict
- ) -> JsonDict:
- """Add/Update a group category"""
- await self.check_group_is_ours(
- group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
- )
-
- is_public = _parse_visibility_from_contents(content)
- profile = content.get("profile")
-
- await self.store.upsert_group_category(
- group_id=group_id,
- category_id=category_id,
- is_public=is_public,
- profile=profile,
- )
-
- return {}
-
- async def delete_group_category(
- self, group_id: str, requester_user_id: str, category_id: str
- ) -> JsonDict:
- """Delete a group category"""
- await self.check_group_is_ours(
- group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
- )
-
- await self.store.remove_group_category(
- group_id=group_id, category_id=category_id
- )
-
- return {}
-
- async def update_group_role(
- self, group_id: str, requester_user_id: str, role_id: str, content: JsonDict
- ) -> JsonDict:
- """Add/update a role in a group"""
- await self.check_group_is_ours(
- group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
- )
-
- is_public = _parse_visibility_from_contents(content)
-
- profile = content.get("profile")
-
- await self.store.upsert_group_role(
- group_id=group_id, role_id=role_id, is_public=is_public, profile=profile
- )
-
- return {}
-
- async def delete_group_role(
- self, group_id: str, requester_user_id: str, role_id: str
- ) -> JsonDict:
- """Remove role from group"""
- await self.check_group_is_ours(
- group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
- )
-
- await self.store.remove_group_role(group_id=group_id, role_id=role_id)
-
- return {}
-
- async def update_group_summary_user(
- self,
- group_id: str,
- requester_user_id: str,
- user_id: str,
- role_id: str,
- content: JsonDict,
- ) -> JsonDict:
- """Add/update a user's entry in the group summary"""
- await self.check_group_is_ours(
- group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
- )
-
- order = content.get("order", None)
-
- is_public = _parse_visibility_from_contents(content)
-
- await self.store.add_user_to_summary(
- group_id=group_id,
- user_id=user_id,
- role_id=role_id,
- order=order,
- is_public=is_public,
- )
-
- return {}
-
- async def delete_group_summary_user(
- self, group_id: str, requester_user_id: str, user_id: str, role_id: str
- ) -> JsonDict:
- """Remove a user from the group summary"""
- await self.check_group_is_ours(
- group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
- )
-
- await self.store.remove_user_from_summary(
- group_id=group_id, user_id=user_id, role_id=role_id
- )
-
- return {}
-
- async def update_group_profile(
- self, group_id: str, requester_user_id: str, content: JsonDict
- ) -> None:
- """Update the group profile"""
- await self.check_group_is_ours(
- group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
- )
-
- profile = {}
- for keyname, max_length in (
- ("name", MAX_DISPLAYNAME_LEN),
- ("avatar_url", MAX_AVATAR_URL_LEN),
- ("short_description", MAX_SHORT_DESC_LEN),
- ("long_description", MAX_LONG_DESC_LEN),
- ):
- if keyname in content:
- value = content[keyname]
- if not isinstance(value, str):
- raise SynapseError(
- 400,
- "%r value is not a string" % (keyname,),
- errcode=Codes.INVALID_PARAM,
- )
- if len(value) > max_length:
- raise SynapseError(
- 400,
- "Invalid %s parameter" % (keyname,),
- errcode=Codes.INVALID_PARAM,
- )
- profile[keyname] = value
-
- await self.store.update_group_profile(group_id, profile)
-
- async def add_room_to_group(
- self, group_id: str, requester_user_id: str, room_id: str, content: JsonDict
- ) -> JsonDict:
- """Add room to group"""
- RoomID.from_string(room_id) # Ensure valid room id
-
- await self.check_group_is_ours(
- group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
- )
-
- is_public = _parse_visibility_from_contents(content)
-
- await self.store.add_room_to_group(group_id, room_id, is_public=is_public)
-
- return {}
-
- async def update_room_in_group(
- self,
- group_id: str,
- requester_user_id: str,
- room_id: str,
- config_key: str,
- content: JsonDict,
- ) -> JsonDict:
- """Update room in group"""
- RoomID.from_string(room_id) # Ensure valid room id
-
- await self.check_group_is_ours(
- group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
- )
-
- if config_key == "m.visibility":
- is_public = _parse_visibility_dict(content)
-
- await self.store.update_room_in_group_visibility(
- group_id, room_id, is_public=is_public
- )
- else:
- raise SynapseError(400, "Unknown config option")
-
- return {}
-
- async def remove_room_from_group(
- self, group_id: str, requester_user_id: str, room_id: str
- ) -> JsonDict:
- """Remove room from group"""
- await self.check_group_is_ours(
- group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
- )
-
- await self.store.remove_room_from_group(group_id, room_id)
-
- return {}
-
- async def invite_to_group(
- self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
- ) -> JsonDict:
- """Invite user to group"""
-
- group = await self.check_group_is_ours(
- group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
- )
- if not group:
- raise SynapseError(400, "Group does not exist", errcode=Codes.BAD_STATE)
-
- # TODO: Check if user knocked
-
- invited_users = await self.store.get_invited_users_in_group(group_id)
- if user_id in invited_users:
- raise SynapseError(
- 400, "User already invited to group", errcode=Codes.BAD_STATE
- )
-
- user_results = await self.store.get_users_in_group(
- group_id, include_private=True
- )
- if user_id in (user_result["user_id"] for user_result in user_results):
- raise SynapseError(400, "User already in group")
-
- content = {
- "profile": {"name": group["name"], "avatar_url": group["avatar_url"]},
- "inviter": requester_user_id,
- }
-
- if self.hs.is_mine_id(user_id):
- groups_local = self.hs.get_groups_local_handler()
- assert isinstance(
- groups_local, GroupsLocalHandler
- ), "Workers cannot invite users to groups."
- res = await groups_local.on_invite(group_id, user_id, content)
- local_attestation = None
- else:
- local_attestation = self.attestations.create_attestation(group_id, user_id)
- content.update({"attestation": local_attestation})
-
- res = await self.transport_client.invite_to_group_notification(
- get_domain_from_id(user_id), group_id, user_id, content
- )
-
- user_profile = res.get("user_profile", {})
- await self.store.add_remote_profile_cache(
- user_id,
- displayname=user_profile.get("displayname"),
- avatar_url=user_profile.get("avatar_url"),
- )
-
- if res["state"] == "join":
- if not self.hs.is_mine_id(user_id):
- remote_attestation = res["attestation"]
-
- await self.attestations.verify_attestation(
- remote_attestation, user_id=user_id, group_id=group_id
- )
- else:
- remote_attestation = None
-
- await self.store.add_user_to_group(
- group_id,
- user_id,
- is_admin=False,
- is_public=False, # TODO
- local_attestation=local_attestation,
- remote_attestation=remote_attestation,
- )
- return {"state": "join"}
- elif res["state"] == "invite":
- await self.store.add_group_invite(group_id, user_id)
- return {"state": "invite"}
- elif res["state"] == "reject":
- return {"state": "reject"}
- else:
- raise SynapseError(502, "Unknown state returned by HS")
-
- async def _add_user(
- self, group_id: str, user_id: str, content: JsonDict
- ) -> Optional[JsonDict]:
- """Add a user to a group based on a content dict.
-
- See accept_invite, join_group.
- """
- if not self.hs.is_mine_id(user_id):
- local_attestation: Optional[
- JsonDict
- ] = self.attestations.create_attestation(group_id, user_id)
-
- remote_attestation = content["attestation"]
-
- await self.attestations.verify_attestation(
- remote_attestation, user_id=user_id, group_id=group_id
- )
- else:
- local_attestation = None
- remote_attestation = None
-
- is_public = _parse_visibility_from_contents(content)
-
- await self.store.add_user_to_group(
- group_id,
- user_id,
- is_admin=False,
- is_public=is_public,
- local_attestation=local_attestation,
- remote_attestation=remote_attestation,
- )
-
- return local_attestation
-
- async def accept_invite(
- self, group_id: str, requester_user_id: str, content: JsonDict
- ) -> JsonDict:
- """User tries to accept an invite to the group.
-
- This is different from them asking to join, and so should error if no
- invite exists (and they're not a member of the group)
- """
-
- await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
- is_invited = await self.store.is_user_invited_to_local_group(
- group_id, requester_user_id
- )
- if not is_invited:
- raise SynapseError(403, "User not invited to group")
-
- local_attestation = await self._add_user(group_id, requester_user_id, content)
-
- return {"state": "join", "attestation": local_attestation}
-
- async def join_group(
- self, group_id: str, requester_user_id: str, content: JsonDict
- ) -> JsonDict:
- """User tries to join the group.
-
- This will error if the group requires an invite/knock to join
- """
-
- group_info = await self.check_group_is_ours(
- group_id, requester_user_id, and_exists=True
- )
- if not group_info:
- raise SynapseError(404, "Group does not exist", errcode=Codes.NOT_FOUND)
- if group_info["join_policy"] != "open":
- raise SynapseError(403, "Group is not publicly joinable")
-
- local_attestation = await self._add_user(group_id, requester_user_id, content)
-
- return {"state": "join", "attestation": local_attestation}
-
- async def remove_user_from_group(
- self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
- ) -> JsonDict:
- """Remove a user from the group; either a user is leaving or an admin
- kicked them.
- """
-
- await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
- is_kick = False
- if requester_user_id != user_id:
- is_admin = await self.store.is_user_admin_in_group(
- group_id, requester_user_id
- )
- if not is_admin:
- raise SynapseError(403, "User is not admin in group")
-
- is_kick = True
-
- await self.store.remove_user_from_group(group_id, user_id)
-
- if is_kick:
- if self.hs.is_mine_id(user_id):
- groups_local = self.hs.get_groups_local_handler()
- assert isinstance(
- groups_local, GroupsLocalHandler
- ), "Workers cannot remove users from groups."
- await groups_local.user_removed_from_group(group_id, user_id, {})
- else:
- await self.transport_client.remove_user_from_group_notification(
- get_domain_from_id(user_id), group_id, user_id, {}
- )
-
- if not self.hs.is_mine_id(user_id):
- await self.store.maybe_delete_remote_profile_cache(user_id)
-
- # Delete group if the last user has left
- users = await self.store.get_users_in_group(group_id, include_private=True)
- if not users:
- await self.store.delete_group(group_id)
-
- return {}
-
- async def create_group(
- self, group_id: str, requester_user_id: str, content: JsonDict
- ) -> JsonDict:
- logger.info("Attempting to create group with ID: %r", group_id)
-
- # parsing the id into a GroupID validates it.
- group_id_obj = GroupID.from_string(group_id)
-
- group = await self.check_group_is_ours(group_id, requester_user_id)
- if group:
- raise SynapseError(400, "Group already exists")
-
- is_admin = await self.auth.is_server_admin(
- UserID.from_string(requester_user_id)
- )
- if not is_admin:
- if not self.hs.config.groups.enable_group_creation:
- raise SynapseError(
- 403, "Only a server admin can create groups on this server"
- )
- localpart = group_id_obj.localpart
- if not localpart.startswith(self.hs.config.groups.group_creation_prefix):
- raise SynapseError(
- 400,
- "Can only create groups with prefix %r on this server"
- % (self.hs.config.groups.group_creation_prefix,),
- )
-
- profile = content.get("profile", {})
- name = profile.get("name")
- avatar_url = profile.get("avatar_url")
- short_description = profile.get("short_description")
- long_description = profile.get("long_description")
- user_profile = content.get("user_profile", {})
-
- await self.store.create_group(
- group_id,
- requester_user_id,
- name=name,
- avatar_url=avatar_url,
- short_description=short_description,
- long_description=long_description,
- )
-
- if not self.hs.is_mine_id(requester_user_id):
- remote_attestation = content["attestation"]
-
- await self.attestations.verify_attestation(
- remote_attestation, user_id=requester_user_id, group_id=group_id
- )
-
- local_attestation: Optional[
- JsonDict
- ] = self.attestations.create_attestation(group_id, requester_user_id)
- else:
- local_attestation = None
- remote_attestation = None
-
- await self.store.add_user_to_group(
- group_id,
- requester_user_id,
- is_admin=True,
- is_public=True, # TODO
- local_attestation=local_attestation,
- remote_attestation=remote_attestation,
- )
-
- if not self.hs.is_mine_id(requester_user_id):
- await self.store.add_remote_profile_cache(
- requester_user_id,
- displayname=user_profile.get("displayname"),
- avatar_url=user_profile.get("avatar_url"),
- )
-
- return {"group_id": group_id}
-
- async def delete_group(self, group_id: str, requester_user_id: str) -> None:
- """Deletes a group, kicking out all current members.
-
- Only group admins or server admins can make this request.
-
- Args:
- group_id: The group ID to delete.
- requester_user_id: The user requesting to delete the group.
- """
-
- await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
- # Only server admins or group admins can delete groups.
-
- is_admin = await self.store.is_user_admin_in_group(group_id, requester_user_id)
-
- if not is_admin:
- is_admin = await self.auth.is_server_admin(
- UserID.from_string(requester_user_id)
- )
-
- if not is_admin:
- raise SynapseError(403, "User is not an admin")
-
- # Before deleting the group lets kick everyone out of it
- users = await self.store.get_users_in_group(group_id, include_private=True)
-
- async def _kick_user_from_group(user_id):
- if self.hs.is_mine_id(user_id):
- groups_local = self.hs.get_groups_local_handler()
- assert isinstance(
- groups_local, GroupsLocalHandler
- ), "Workers cannot kick users from groups."
- await groups_local.user_removed_from_group(group_id, user_id, {})
- else:
- await self.transport_client.remove_user_from_group_notification(
- get_domain_from_id(user_id), group_id, user_id, {}
- )
- await self.store.maybe_delete_remote_profile_cache(user_id)
-
- # We kick users out in the order of:
- # 1. Non-admins
- # 2. Other admins
- # 3. The requester
- #
- # This is so that if the deletion fails for some reason other admins or
- # the requester still has auth to retry.
- non_admins = []
- admins = []
- for u in users:
- if u["user_id"] == requester_user_id:
- continue
- if u["is_admin"]:
- admins.append(u["user_id"])
- else:
- non_admins.append(u["user_id"])
-
- await concurrently_execute(_kick_user_from_group, non_admins, 10)
- await concurrently_execute(_kick_user_from_group, admins, 10)
- await _kick_user_from_group(requester_user_id)
-
- await self.store.delete_group(group_id)
-
-
-def _parse_join_policy_from_contents(content: JsonDict) -> Optional[str]:
- """Given the content of a request, return the specified join policy or None"""
-
- join_policy_dict = content.get("m.join_policy")
- if join_policy_dict:
- return _parse_join_policy_dict(join_policy_dict)
- else:
- return None
-
-
-def _parse_join_policy_dict(join_policy_dict: JsonDict) -> str:
- """Given a dict for the "m.join_policy" config return the join policy specified"""
- join_policy_type = join_policy_dict.get("type")
- if not join_policy_type:
- return "invite"
-
- if join_policy_type not in ("invite", "open"):
- raise SynapseError(400, "Synapse only supports 'invite'/'open' join rule")
- return join_policy_type
-
-
-def _parse_visibility_from_contents(content: JsonDict) -> bool:
- """Given the content of a request, parse out whether the entity should be
- public or not
- """
-
- visibility = content.get("m.visibility")
- if visibility:
- return _parse_visibility_dict(visibility)
- else:
- is_public = True
-
- return is_public
-
-
-def _parse_visibility_dict(visibility: JsonDict) -> bool:
- """Given a dict for the "m.visibility" config return if the entity should
- be public or not
- """
- vis_type = visibility.get("type")
- if not vis_type:
- return True
-
- if vis_type not in ("public", "private"):
- raise SynapseError(400, "Synapse only supports 'public'/'private' visibility")
- return vis_type == "public"
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 4af9fbc5..0478448b 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -23,7 +23,7 @@ from synapse.replication.http.account_data import (
ReplicationUserAccountDataRestServlet,
)
from synapse.streams import EventSource
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, StreamKeyType, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -105,7 +105,7 @@ class AccountDataHandler:
)
self._notifier.on_new_event(
- "account_data_key", max_stream_id, users=[user_id]
+ StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
)
await self._notify_modules(user_id, room_id, account_data_type, content)
@@ -141,7 +141,7 @@ class AccountDataHandler:
)
self._notifier.on_new_event(
- "account_data_key", max_stream_id, users=[user_id]
+ StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
)
await self._notify_modules(user_id, None, account_data_type, content)
@@ -176,7 +176,7 @@ class AccountDataHandler:
)
self._notifier.on_new_event(
- "account_data_key", max_stream_id, users=[user_id]
+ StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
)
return max_stream_id
else:
@@ -201,7 +201,7 @@ class AccountDataHandler:
)
self._notifier.on_new_event(
- "account_data_key", max_stream_id, users=[user_id]
+ StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
)
return max_stream_id
else:
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 96376963..d4fe7df5 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -30,8 +30,8 @@ logger = logging.getLogger(__name__)
class AdminHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
- self.state_store = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
async def get_whois(self, user: UserID) -> JsonDict:
connections = []
@@ -197,7 +197,9 @@ class AdminHandler:
from_key = events[-1].internal_metadata.after
- events = await filter_events_for_client(self.storage, user_id, events)
+ events = await filter_events_for_client(
+ self._storage_controllers, user_id, events
+ )
writer.write_events(room_id, events)
@@ -233,7 +235,9 @@ class AdminHandler:
for event_id in extremities:
if not event_to_unseen_prevs[event_id]:
continue
- state = await self.state_store.get_state_for_event(event_id)
+ state = await self._state_storage_controller.get_state_for_event(
+ event_id
+ )
writer.write_state(room_id, event_id, state)
return writer.finished()
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 85bd5e47..814553e0 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -19,7 +19,7 @@ from prometheus_client import Counter
from twisted.internet import defer
import synapse
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EduTypes, EventTypes
from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state
@@ -38,6 +38,7 @@ from synapse.types import (
JsonDict,
RoomAlias,
RoomStreamToken,
+ StreamKeyType,
UserID,
)
from synapse.util.async_helpers import Linearizer
@@ -213,8 +214,8 @@ class ApplicationServicesHandler:
Args:
stream_key: The stream the event came from.
- `stream_key` can be "typing_key", "receipt_key", "presence_key",
- "to_device_key" or "device_list_key". Any other value for `stream_key`
+ `stream_key` can be StreamKeyType.TYPING, StreamKeyType.RECEIPT, StreamKeyType.PRESENCE,
+ StreamKeyType.TO_DEVICE or StreamKeyType.DEVICE_LIST. Any other value for `stream_key`
will cause this function to return early.
Ephemeral events will only be pushed to appservices that have opted into
@@ -235,11 +236,11 @@ class ApplicationServicesHandler:
# Only the following streams are currently supported.
# FIXME: We should use constants for these values.
if stream_key not in (
- "typing_key",
- "receipt_key",
- "presence_key",
- "to_device_key",
- "device_list_key",
+ StreamKeyType.TYPING,
+ StreamKeyType.RECEIPT,
+ StreamKeyType.PRESENCE,
+ StreamKeyType.TO_DEVICE,
+ StreamKeyType.DEVICE_LIST,
):
return
@@ -258,14 +259,14 @@ class ApplicationServicesHandler:
# Ignore to-device messages if the feature flag is not enabled
if (
- stream_key == "to_device_key"
+ stream_key == StreamKeyType.TO_DEVICE
and not self._msc2409_to_device_messages_enabled
):
return
# Ignore device lists if the feature flag is not enabled
if (
- stream_key == "device_list_key"
+ stream_key == StreamKeyType.DEVICE_LIST
and not self._msc3202_transaction_extensions_enabled
):
return
@@ -283,15 +284,15 @@ class ApplicationServicesHandler:
if (
stream_key
in (
- "typing_key",
- "receipt_key",
- "presence_key",
- "to_device_key",
+ StreamKeyType.TYPING,
+ StreamKeyType.RECEIPT,
+ StreamKeyType.PRESENCE,
+ StreamKeyType.TO_DEVICE,
)
and service.supports_ephemeral
)
or (
- stream_key == "device_list_key"
+ stream_key == StreamKeyType.DEVICE_LIST
and service.msc3202_transaction_extensions
)
]
@@ -317,7 +318,7 @@ class ApplicationServicesHandler:
logger.debug("Checking interested services for %s", stream_key)
with Measure(self.clock, "notify_interested_services_ephemeral"):
for service in services:
- if stream_key == "typing_key":
+ if stream_key == StreamKeyType.TYPING:
# Note that we don't persist the token (via set_appservice_stream_type_pos)
# for typing_key due to performance reasons and due to their highly
# ephemeral nature.
@@ -333,7 +334,7 @@ class ApplicationServicesHandler:
async with self._ephemeral_events_linearizer.queue(
(service.id, stream_key)
):
- if stream_key == "receipt_key":
+ if stream_key == StreamKeyType.RECEIPT:
events = await self._handle_receipts(service, new_token)
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
@@ -342,7 +343,7 @@ class ApplicationServicesHandler:
service, "read_receipt", new_token
)
- elif stream_key == "presence_key":
+ elif stream_key == StreamKeyType.PRESENCE:
events = await self._handle_presence(service, users, new_token)
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
@@ -351,7 +352,7 @@ class ApplicationServicesHandler:
service, "presence", new_token
)
- elif stream_key == "to_device_key":
+ elif stream_key == StreamKeyType.TO_DEVICE:
# Retrieve a list of to-device message events, as well as the
# maximum stream token of the messages we were able to retrieve.
to_device_messages = await self._get_to_device_messages(
@@ -366,7 +367,7 @@ class ApplicationServicesHandler:
service, "to_device", new_token
)
- elif stream_key == "device_list_key":
+ elif stream_key == StreamKeyType.DEVICE_LIST:
device_list_summary = await self._get_device_list_summary(
service, new_token
)
@@ -502,7 +503,7 @@ class ApplicationServicesHandler:
time_now = self.clock.time_msec()
events.extend(
{
- "type": "m.presence",
+ "type": EduTypes.PRESENCE,
"sender": event.user_id,
"content": format_user_presence_state(
event, time_now, include_user_id=False
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 1b9050ea..fbafbbee 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -210,7 +210,8 @@ class AuthHandler:
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
- self._password_enabled = hs.config.auth.password_enabled
+ self._password_enabled_for_login = hs.config.auth.password_enabled_for_login
+ self._password_enabled_for_reauth = hs.config.auth.password_enabled_for_reauth
self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
self._third_party_rules = hs.get_third_party_event_rules()
@@ -387,13 +388,13 @@ class AuthHandler:
return params, session_id
async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
- """Get a list of the authentication types this user can use"""
+ """Get a list of the user-interactive authentication types this user can use."""
ui_auth_types = set()
# if the HS supports password auth, and the user has a non-null password, we
# support password auth
- if self._password_localdb_enabled and self._password_enabled:
+ if self._password_localdb_enabled and self._password_enabled_for_reauth:
lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
if lookupres:
_, password_hash = lookupres
@@ -402,7 +403,7 @@ class AuthHandler:
# also allow auth from password providers
for t in self.password_auth_provider.get_supported_login_types().keys():
- if t == LoginType.PASSWORD and not self._password_enabled:
+ if t == LoginType.PASSWORD and not self._password_enabled_for_reauth:
continue
ui_auth_types.add(t)
@@ -710,7 +711,7 @@ class AuthHandler:
return res
# fall back to the v1 login flow
- canonical_id, _ = await self.validate_login(authdict)
+ canonical_id, _ = await self.validate_login(authdict, is_reauth=True)
return canonical_id
def _get_params_recaptcha(self) -> dict:
@@ -1064,7 +1065,7 @@ class AuthHandler:
Returns:
Whether users on this server are allowed to change or set a password
"""
- return self._password_enabled and self._password_localdb_enabled
+ return self._password_enabled_for_login and self._password_localdb_enabled
def get_supported_login_types(self) -> Iterable[str]:
"""Get the login types supported for the /login API
@@ -1089,9 +1090,9 @@ class AuthHandler:
# that comes first, where it's present.
if LoginType.PASSWORD in types:
types.remove(LoginType.PASSWORD)
- if self._password_enabled:
+ if self._password_enabled_for_login:
types.insert(0, LoginType.PASSWORD)
- elif self._password_localdb_enabled and self._password_enabled:
+ elif self._password_localdb_enabled and self._password_enabled_for_login:
types.insert(0, LoginType.PASSWORD)
return types
@@ -1100,6 +1101,7 @@ class AuthHandler:
self,
login_submission: Dict[str, Any],
ratelimit: bool = False,
+ is_reauth: bool = False,
) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
"""Authenticates the user for the /login API
@@ -1110,6 +1112,9 @@ class AuthHandler:
login_submission: the whole of the login submission
(including 'type' and other relevant fields)
ratelimit: whether to apply the failed_login_attempt ratelimiter
+ is_reauth: whether this is part of a User-Interactive Authorisation
+ flow to reauthenticate for a privileged action (rather than a
+ new login)
Returns:
A tuple of the canonical user id, and optional callback
to be called once the access token and device id are issued
@@ -1132,8 +1137,14 @@ class AuthHandler:
# special case to check for "password" for the check_password interface
# for the auth providers
password = login_submission.get("password")
+
if login_type == LoginType.PASSWORD:
- if not self._password_enabled:
+ if is_reauth:
+ passwords_allowed_here = self._password_enabled_for_reauth
+ else:
+ passwords_allowed_here = self._password_enabled_for_login
+
+ if not passwords_allowed_here:
raise SynapseError(400, "Password login has been disabled.")
if not isinstance(password, str):
raise SynapseError(400, "Bad parameter: password", Codes.INVALID_PARAM)
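The auth.py hunks above split the single password_enabled flag into password_enabled_for_login and password_enabled_for_reauth, with validate_login choosing between them via is_reauth. A minimal sketch of that selection logic (AuthConfig here is an illustrative stand-in, not Synapse's config class):

from dataclasses import dataclass


@dataclass
class AuthConfig:
    password_enabled_for_login: bool
    password_enabled_for_reauth: bool


def password_allowed(config: AuthConfig, is_reauth: bool) -> bool:
    """Mirror the branch added to validate_login: pick the flag matching the flow."""
    if is_reauth:
        return config.password_enabled_for_reauth
    return config.password_enabled_for_login


if __name__ == "__main__":
    # e.g. a deployment that delegates logins to SSO but still accepts a
    # password when reauthenticating for a privileged action.
    cfg = AuthConfig(password_enabled_for_login=False, password_enabled_for_reauth=True)
    print(password_allowed(cfg, is_reauth=False))  # False: /login rejects passwords
    print(password_allowed(cfg, is_reauth=True))   # True: UI auth accepts them

This shape lets the two flows be configured independently without touching the rest of the login code.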
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index a91b1ee4..a0cbeedc 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -28,7 +28,7 @@ from typing import (
)
from synapse.api import errors
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EduTypes, EventTypes
from synapse.api.errors import (
Codes,
FederationDeniedError,
@@ -43,6 +43,7 @@ from synapse.metrics.background_process_metrics import (
)
from synapse.types import (
JsonDict,
+ StreamKeyType,
StreamToken,
UserID,
get_domain_from_id,
@@ -60,6 +61,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
MAX_DEVICE_DISPLAY_NAME_LEN = 100
+DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000
class DeviceWorkerHandler:
@@ -69,7 +71,7 @@ class DeviceWorkerHandler:
self.store = hs.get_datastores().main
self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
- self.state_store = hs.get_storage().state
+ self._state_storage = hs.get_storage_controllers().state
self._auth_handler = hs.get_auth_handler()
self.server_name = hs.hostname
@@ -164,7 +166,7 @@ class DeviceWorkerHandler:
possibly_changed = set(changed)
possibly_left = set()
for room_id in rooms_changed:
- current_state_ids = await self.store.get_current_state_ids(room_id)
+ current_state_ids = await self._state_storage.get_current_state_ids(room_id)
# The user may have left the room
# TODO: Check if they actually did or if we were just invited.
@@ -202,7 +204,9 @@ class DeviceWorkerHandler:
continue
# mapping from event_id -> state_dict
- prev_state_ids = await self.state_store.get_state_ids_for_events(event_ids)
+ prev_state_ids = await self._state_storage.get_state_ids_for_events(
+ event_ids
+ )
# Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users.
@@ -276,7 +280,8 @@ class DeviceHandler(DeviceWorkerHandler):
federation_registry = hs.get_federation_registry()
federation_registry.register_edu_handler(
- "m.device_list_update", self.device_list_updater.incoming_device_list_update
+ EduTypes.DEVICE_LIST_UPDATE,
+ self.device_list_updater.incoming_device_list_update,
)
hs.get_distributor().observe("user_left_room", self.user_left_room)
@@ -291,6 +296,19 @@ class DeviceHandler(DeviceWorkerHandler):
# On start up check if there are any updates pending.
hs.get_reactor().callWhenRunning(self._handle_new_device_update_async)
+ self._delete_stale_devices_after = hs.config.server.delete_stale_devices_after
+
+ # Ideally we would run this on a worker and condition this on the
+ # "run_background_tasks_on" setting, but this would mean making the notification
+ # of device list changes over federation work on workers, which is nontrivial.
+ if self._delete_stale_devices_after is not None:
+ self.clock.looping_call(
+ run_as_background_process,
+ DELETE_STALE_DEVICES_INTERVAL_MS,
+ "delete_stale_devices",
+ self._delete_stale_devices,
+ )
+
def _check_device_name_length(self, name: Optional[str]) -> None:
"""
Checks whether a device name is longer than the maximum allowed length.
@@ -366,6 +384,19 @@ class DeviceHandler(DeviceWorkerHandler):
raise errors.StoreError(500, "Couldn't generate a device ID.")
+ async def _delete_stale_devices(self) -> None:
+ """Background task that deletes devices which haven't been accessed for more than
+ a configured time period.
+ """
+ # We should only be running this job if the config option is defined.
+ assert self._delete_stale_devices_after is not None
+ now_ms = self.clock.time_msec()
+ since_ms = now_ms - self._delete_stale_devices_after
+ devices = await self.store.get_local_devices_not_accessed_since(since_ms)
+
+ for user_id, user_devices in devices.items():
+ await self.delete_devices(user_id, user_devices)
+
@trace
async def delete_device(self, user_id: str, device_id: str) -> None:
"""Delete the given device
@@ -502,7 +533,7 @@ class DeviceHandler(DeviceWorkerHandler):
# specify the user ID too since the user should always get their own device list
# updates, even if they aren't in any rooms.
self.notifier.on_new_event(
- "device_list_key", position, users={user_id}, rooms=room_ids
+ StreamKeyType.DEVICE_LIST, position, users={user_id}, rooms=room_ids
)
# We may need to do some processing asynchronously for local user IDs.
@@ -523,7 +554,9 @@ class DeviceHandler(DeviceWorkerHandler):
from_user_id, user_ids
)
- self.notifier.on_new_event("device_list_key", position, users=[from_user_id])
+ self.notifier.on_new_event(
+ StreamKeyType.DEVICE_LIST, position, users=[from_user_id]
+ )
async def user_left_room(self, user: UserID, room_id: str) -> None:
user_id = user.to_string()
@@ -686,7 +719,8 @@ class DeviceHandler(DeviceWorkerHandler):
)
# TODO: when called, this isn't in a logging context.
# This leads to log spam, sentry event spam, and massive
- # memory usage. See #12552.
+ # memory usage.
+ # See https://github.com/matrix-org/synapse/issues/12552.
# log_kv(
# {"message": "sent device update to host", "host": host}
# )
@@ -760,6 +794,10 @@ class DeviceListUpdater:
device_id = edu_content.pop("device_id")
stream_id = str(edu_content.pop("stream_id")) # They may come as ints
prev_ids = edu_content.pop("prev_id", [])
+ if not isinstance(prev_ids, list):
+ raise SynapseError(
+ 400, "Device list update had an invalid 'prev_ids' field"
+ )
prev_ids = [str(p) for p in prev_ids] # They may come as ints
if get_domain_from_id(user_id) != origin:
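The device.py hunks above add a daily background job that deletes devices not accessed within a configured window (delete_stale_devices_after). A minimal, self-contained sketch of the selection step, assuming an in-memory device list rather than the real datastore:

import time
from dataclasses import dataclass
from typing import Dict, List

DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000  # run once a day


@dataclass
class Device:
    device_id: str
    last_seen_ms: int


def find_stale_devices(
    devices_by_user: Dict[str, List[Device]], delete_stale_devices_after_ms: int
) -> Dict[str, List[str]]:
    """Return, per user, the device IDs not accessed within the retention window."""
    since_ms = int(time.time() * 1000) - delete_stale_devices_after_ms
    return {
        user_id: [d.device_id for d in devices if d.last_seen_ms < since_ms]
        for user_id, devices in devices_by_user.items()
        if any(d.last_seen_ms < since_ms for d in devices)
    }


if __name__ == "__main__":
    now_ms = int(time.time() * 1000)
    devices = {
        "@alice:example.com": [
            Device("OLDPHONE", now_ms - 90 * 24 * 60 * 60 * 1000),
            Device("LAPTOP", now_ms),
        ]
    }
    # With a 30-day retention window, only OLDPHONE is selected for deletion.
    print(find_stale_devices(devices, 30 * 24 * 60 * 60 * 1000))

The real handler runs the equivalent of this selection inside a looping background process and then calls delete_devices for each user.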
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 4cb725d0..444c08bc 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING, Any, Dict
-from synapse.api.constants import ToDeviceEventTypes
+from synapse.api.constants import EduTypes, ToDeviceEventTypes
from synapse.api.errors import SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.logging.context import run_in_background
@@ -26,7 +26,7 @@ from synapse.logging.opentracing import (
set_tag,
)
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
-from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.stringutils import random_string
@@ -59,11 +59,11 @@ class DeviceMessageHandler:
# to the appropriate worker.
if hs.get_instance_name() in hs.config.worker.writers.to_device:
hs.get_federation_registry().register_edu_handler(
- "m.direct_to_device", self.on_direct_to_device_edu
+ EduTypes.DIRECT_TO_DEVICE, self.on_direct_to_device_edu
)
else:
hs.get_federation_registry().register_instances_for_edu(
- "m.direct_to_device",
+ EduTypes.DIRECT_TO_DEVICE,
hs.config.worker.writers.to_device,
)
@@ -151,7 +151,7 @@ class DeviceMessageHandler:
# Notify listeners that there are new to-device messages to process,
# handing them the latest stream id.
self.notifier.on_new_event(
- "to_device_key", last_stream_id, users=local_messages.keys()
+ StreamKeyType.TO_DEVICE, last_stream_id, users=local_messages.keys()
)
async def _check_for_unknown_devices(
@@ -285,7 +285,7 @@ class DeviceMessageHandler:
# Notify listeners that there are new to-device messages to process,
# handing them the latest stream id.
self.notifier.on_new_event(
- "to_device_key", last_stream_id, users=local_messages.keys()
+ StreamKeyType.TO_DEVICE, last_stream_id, users=local_messages.keys()
)
if self.federation_sender:
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 33d827a4..1459a046 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -45,6 +45,7 @@ class DirectoryHandler:
self.appservice_handler = hs.get_application_service_handler()
self.event_creation_handler = hs.get_event_creation_handler()
self.store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self.config = hs.config
self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
self.require_membership = hs.config.server.require_membership_for_aliases
@@ -71,6 +72,9 @@ class DirectoryHandler:
if wchar in room_alias.localpart:
raise SynapseError(400, "Invalid characters in room alias")
+ if ":" in room_alias.localpart:
+ raise SynapseError(400, "Invalid character in room alias localpart: ':'.")
+
if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local")
# TODO(erikj): Change this.
@@ -316,7 +320,7 @@ class DirectoryHandler:
Raises:
ShadowBanError if the requester has been shadow-banned.
"""
- alias_event = await self.state.get_current_state(
+ alias_event = await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.CanonicalAlias, ""
)
@@ -460,7 +464,11 @@ class DirectoryHandler:
making_public = visibility == "public"
if making_public:
room_aliases = await self.store.get_aliases_for_room(room_id)
- canonical_alias = await self.store.get_canonical_alias_for_room(room_id)
+ canonical_alias = (
+ await self._storage_controllers.state.get_canonical_alias_for_room(
+ room_id
+ )
+ )
if canonical_alias:
room_aliases.append(canonical_alias)
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d6714228..52bb5c9c 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json
@@ -25,6 +25,7 @@ from unpaddedbase64 import decode_base64
from twisted.internet import defer
+from synapse.api.constants import EduTypes
from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
@@ -66,13 +67,13 @@ class E2eKeysHandler:
# Only register this edu handler on master as it requires writing
# device updates to the db
federation_registry.register_edu_handler(
- "m.signing_key_update",
+ EduTypes.SIGNING_KEY_UPDATE,
self._edu_updater.incoming_signing_key_update,
)
# also handle the unstable version
# FIXME: remove this when enough servers have upgraded
federation_registry.register_edu_handler(
- "org.matrix.signing_key_update",
+ EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
self._edu_updater.incoming_signing_key_update,
)
@@ -1105,22 +1106,19 @@ class E2eKeysHandler:
# can request over federation
raise NotFoundError("No %s key found for %s" % (key_type, user_id))
- (
- key,
- key_id,
- verify_key,
- ) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
-
- if key is None:
+ cross_signing_keys = await self._retrieve_cross_signing_keys_for_remote_user(
+ user, key_type
+ )
+ if cross_signing_keys is None:
raise NotFoundError("No %s key found for %s" % (key_type, user_id))
- return key, key_id, verify_key
+ return cross_signing_keys
async def _retrieve_cross_signing_keys_for_remote_user(
self,
user: UserID,
desired_key_type: str,
- ) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]:
+ ) -> Optional[Tuple[Dict[str, Any], str, VerifyKey]]:
"""Queries cross-signing keys for a remote user and saves them to the database
Only the key specified by `key_type` will be returned, while all retrieved keys
@@ -1146,12 +1144,10 @@ class E2eKeysHandler:
type(e),
e,
)
- return None, None, None
+ return None
# Process each of the retrieved cross-signing keys
- desired_key = None
- desired_key_id = None
- desired_verify_key = None
+ desired_key_data = None
retrieved_device_ids = []
for key_type in ["master", "self_signing"]:
key_content = remote_result.get(key_type + "_key")
@@ -1196,9 +1192,7 @@ class E2eKeysHandler:
# If this is the desired key type, save it and its ID/VerifyKey
if key_type == desired_key_type:
- desired_key = key_content
- desired_verify_key = verify_key
- desired_key_id = key_id
+ desired_key_data = key_content, key_id, verify_key
# At the same time, store this key in the db for subsequent queries
await self.store.set_e2e_cross_signing_key(
@@ -1212,7 +1206,7 @@ class E2eKeysHandler:
user.to_string(), retrieved_device_ids
)
- return desired_key, desired_key_id, desired_verify_key
+ return desired_key_data
def _check_cross_signing_key(
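The e2e_keys.py hunk above changes _retrieve_cross_signing_keys_for_remote_user to return either the whole (key, key ID, verify key) triple or None, instead of a tuple of three Optionals. A sketch of why that shape is easier on callers (the names below are illustrative, not Synapse's API):

from typing import Optional, Tuple

KeyData = Tuple[dict, str, str]  # (key content, key ID, verify key) stand-ins


def fetch_remote_key(available: bool) -> Optional[KeyData]:
    # Either everything is present, or nothing is: no mixed half-results.
    if not available:
        return None
    return {"keys": {"ed25519:abc": "a-base64-key"}}, "ed25519:abc", "a-verify-key"


def lookup(available: bool) -> KeyData:
    key_data = fetch_remote_key(available)
    if key_data is None:
        # One check covers all three values; type checkers narrow the rest.
        raise LookupError("No key found")
    return key_data


if __name__ == "__main__":
    print(lookup(available=True))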
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index d441ebb0..6bed4643 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -241,7 +241,15 @@ class EventAuthHandler:
# If the join rule is not restricted, this doesn't apply.
join_rules_event = await self._store.get_event(join_rules_event_id)
- return join_rules_event.content.get("join_rule") == JoinRules.RESTRICTED
+ content_join_rule = join_rules_event.content.get("join_rule")
+ if content_join_rule == JoinRules.RESTRICTED:
+ return True
+
+ # also check for MSC3787 behaviour
+ if room_version.msc3787_knock_restricted_join_rule:
+ return content_join_rule == JoinRules.KNOCK_RESTRICTED
+
+ return False
async def get_rooms_that_allow_join(
self, state_ids: StateMap[str]
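The event_auth.py hunk above extends the restricted-join-rule check to also accept MSC3787's knock_restricted rule when the room version supports it. A standalone sketch of that decision (the RoomVersion dataclass is a stand-in for illustration, not Synapse's real class):

from dataclasses import dataclass
from typing import Any, Mapping


@dataclass
class RoomVersion:
    msc3787_knock_restricted_join_rule: bool


def has_restricted_join_rule(
    join_rules_content: Mapping[str, Any], room_version: RoomVersion
) -> bool:
    join_rule = join_rules_content.get("join_rule")
    if join_rule == "restricted":
        return True
    # MSC3787 adds a combined knock+restricted rule; only honour it when the
    # room version opts in.
    if room_version.msc3787_knock_restricted_join_rule:
        return join_rule == "knock_restricted"
    return False


if __name__ == "__main__":
    v = RoomVersion(msc3787_knock_restricted_join_rule=True)
    print(has_restricted_join_rule({"join_rule": "knock_restricted"}, v))  # True
    print(has_restricted_join_rule({"join_rule": "public"}, v))  # False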
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 82a5aac3..ac13340d 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -113,7 +113,7 @@ class EventStreamHandler:
states = await presence_handler.get_states(users)
to_add.extend(
{
- "type": EduTypes.Presence,
+ "type": EduTypes.PRESENCE,
"content": format_user_presence_state(state, time_now),
}
for state in states
@@ -139,7 +139,7 @@ class EventStreamHandler:
class EventHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
async def get_event(
self,
@@ -177,7 +177,7 @@ class EventHandler:
is_peeking = user.to_string() not in users
filtered = await filter_events_for_client(
- self.storage, user.to_string(), [event], is_peeking=is_peeking
+ self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking
)
if not filtered:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 38dc5b1f..6a143440 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -20,7 +20,16 @@ import itertools
import logging
from enum import Enum
from http import HTTPStatus
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
import attr
from signedjson.key import decode_verify_key_bytes
@@ -34,6 +43,7 @@ from synapse.api.errors import (
CodeMessageException,
Codes,
FederationDeniedError,
+ FederationError,
HttpResponseException,
NotFoundError,
RequestSendFailed,
@@ -54,6 +64,7 @@ from synapse.replication.http.federation import (
ReplicationStoreRoomOnOutlierMembershipRestServlet,
)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.storage.state import StateFilter
from synapse.types import JsonDict, StateMap, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
@@ -124,8 +135,8 @@ class FederationHandler:
self.hs = hs
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
- self.state_store = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname
@@ -158,6 +169,14 @@ class FederationHandler:
self.third_party_event_rules = hs.get_third_party_event_rules()
+ # if this is the main process, fire off a background process to resume
+ # any partial-state-resync operations which were in flight when we
+ # were shut down.
+ if not hs.config.worker.worker_app:
+ run_as_background_process(
+ "resume_sync_partial_state_room", self._resume_sync_partial_state_room
+ )
+
async def maybe_backfill(
self, room_id: str, current_depth: int, limit: int
) -> bool:
@@ -323,7 +342,7 @@ class FederationHandler:
# We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased.
filtered_extremities = await filter_events_for_server(
- self.storage,
+ self._storage_controllers,
self.server_name,
events_to_check,
redact=False,
@@ -352,7 +371,7 @@ class FederationHandler:
# First we try hosts that are already in the room
# TODO: HEURISTIC ALERT.
- curr_state = await self.state_handler.get_current_state(room_id)
+ curr_state = await self._storage_controllers.state.get_current_state(room_id)
curr_domains = get_domains_from_state(curr_state)
@@ -459,6 +478,8 @@ class FederationHandler:
"""
# TODO: We should be able to call this on workers, but the upgrading of
# room stuff after join currently doesn't work on workers.
+ # TODO: Before we relax this condition, we need to allow re-syncing of
+ # partial room state to happen on workers.
assert self.config.worker.worker_app is None
logger.debug("Joining %s to %s", joinee, room_id)
@@ -539,12 +560,11 @@ class FederationHandler:
if ret.partial_state:
# Kick off the process of asynchronously fetching the state for this
# room.
- #
- # TODO(faster_joins): pick this up again on restart
run_as_background_process(
desc="sync_partial_state_room",
func=self._sync_partial_state_room,
- destination=origin,
+ initial_destination=origin,
+ other_destinations=ret.servers_in_room,
room_id=room_id,
)
@@ -659,7 +679,7 @@ class FederationHandler:
# in the invitee's sync stream. It is stripped out for all other local users.
event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
- context = EventContext.for_outlier()
+ context = EventContext.for_outlier(self._storage_controllers)
stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
@@ -730,7 +750,9 @@ class FederationHandler:
# Note that this requires the /send_join request to come back to the
# same server.
if room_version.msc3083_join_rules:
- state_ids = await self.store.get_current_state_ids(room_id)
+ state_ids = await self._state_storage_controller.get_current_state_ids(
+ room_id
+ )
if await self._event_auth_handler.has_restricted_join_rules(
state_ids, room_version
):
@@ -848,7 +870,7 @@ class FederationHandler:
)
)
- context = EventContext.for_outlier()
+ context = EventContext.for_outlier(self._storage_controllers)
await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
@@ -877,7 +899,7 @@ class FederationHandler:
await self.federation_client.send_leave(host_list, event)
- context = EventContext.for_outlier()
+ context = EventContext.for_outlier(self._storage_controllers)
stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
@@ -1026,7 +1048,9 @@ class FederationHandler:
if event.internal_metadata.outlier:
raise NotFoundError("State not known at event %s" % (event_id,))
- state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
+ state_groups = await self._state_storage_controller.get_state_groups_ids(
+ room_id, [event_id]
+ )
# get_state_groups_ids should return exactly one result
assert len(state_groups) == 1
@@ -1075,7 +1099,9 @@ class FederationHandler:
],
)
- events = await filter_events_for_server(self.storage, origin, events)
+ events = await filter_events_for_server(
+ self._storage_controllers, origin, events
+ )
return events
@@ -1106,7 +1132,9 @@ class FederationHandler:
if not in_room:
raise AuthError(403, "Host not in room.")
- events = await filter_events_for_server(self.storage, origin, [event])
+ events = await filter_events_for_server(
+ self._storage_controllers, origin, [event]
+ )
event = events[0]
return event
else:
@@ -1135,7 +1163,7 @@ class FederationHandler:
)
missing_events = await filter_events_for_server(
- self.storage, origin, missing_events
+ self._storage_controllers, origin, missing_events
)
return missing_events
@@ -1259,7 +1287,9 @@ class FederationHandler:
event.content["third_party_invite"]["signed"]["token"],
)
original_invite = None
- prev_state_ids = await context.get_prev_state_ids()
+ prev_state_ids = await context.get_prev_state_ids(
+ StateFilter.from_types([(EventTypes.ThirdPartyInvite, None)])
+ )
original_invite_id = prev_state_ids.get(key)
if original_invite_id:
original_invite = await self.store.get_event(
@@ -1308,7 +1338,9 @@ class FederationHandler:
signed = event.content["third_party_invite"]["signed"]
token = signed["token"]
- prev_state_ids = await context.get_prev_state_ids()
+ prev_state_ids = await context.get_prev_state_ids(
+ StateFilter.from_types([(EventTypes.ThirdPartyInvite, None)])
+ )
invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
invite_event = None
@@ -1441,17 +1473,35 @@ class FederationHandler:
# well.
return None
+ async def _resume_sync_partial_state_room(self) -> None:
+ """Resumes resyncing of all partial-state rooms after a restart."""
+ assert not self.config.worker.worker_app
+
+ partial_state_rooms = await self.store.get_partial_state_rooms_and_servers()
+ for room_id, servers_in_room in partial_state_rooms.items():
+ run_as_background_process(
+ desc="sync_partial_state_room",
+ func=self._sync_partial_state_room,
+ initial_destination=None,
+ other_destinations=servers_in_room,
+ room_id=room_id,
+ )
+
async def _sync_partial_state_room(
self,
- destination: str,
+ initial_destination: Optional[str],
+ other_destinations: Collection[str],
room_id: str,
) -> None:
"""Background process to resync the state of a partial-state room
Args:
- destination: homeserver to pull the state from
+ initial_destination: the initial homeserver to pull the state from
+ other_destinations: other homeservers to try to pull the state from, if
+ `initial_destination` is unavailable
room_id: room to be resynced
"""
+ assert not self.config.worker.worker_app
# TODO(faster_joins): do we need to lock to avoid races? What happens if other
# worker processes kick off a resync in parallel? Perhaps we should just elect
@@ -1461,8 +1511,29 @@ class FederationHandler:
# really leave, that might mean we have difficulty getting the room state over
# federation.
#
- # TODO(faster_joins): try other destinations if the one we have fails
+ # TODO(faster_joins): we need some way of prioritising which homeservers in
+ # `other_destinations` to try first, otherwise we'll spend ages trying dead
+ # homeservers for large rooms.
+
+ if initial_destination is None and len(other_destinations) == 0:
+ raise ValueError(
+ f"Cannot resync state of {room_id}: no destinations provided"
+ )
+
+ # Make an infinite iterator of destinations to try. Once we find a working
+ # destination, we'll stick with it until it flakes.
+ if initial_destination is not None:
+ # Move `initial_destination` to the front of the list.
+ destinations = list(other_destinations)
+ if initial_destination in destinations:
+ destinations.remove(initial_destination)
+ destinations = [initial_destination] + destinations
+ destination_iter = itertools.cycle(destinations)
+ else:
+ destination_iter = itertools.cycle(other_destinations)
+ # `destination` is the current remote homeserver we're pulling from.
+ destination = next(destination_iter)
logger.info("Syncing state for room %s via %s", room_id, destination)
# we work through the queue in order of increasing stream ordering.
@@ -1473,14 +1544,19 @@ class FederationHandler:
# clear the lazy-loading flag.
logger.info("Updating current state for %s", room_id)
assert (
- self.storage.persistence is not None
+ self._storage_controllers.persistence is not None
), "TODO(faster_joins): support for workers"
- await self.storage.persistence.update_current_state(room_id)
+ await self._storage_controllers.persistence.update_current_state(
+ room_id
+ )
logger.info("Clearing partial-state flag for %s", room_id)
success = await self.store.clear_partial_state_room(room_id)
if success:
logger.info("State resync complete for %s", room_id)
+ self._storage_controllers.state.notify_room_un_partial_stated(
+ room_id
+ )
# TODO(faster_joins) update room stats and user directory?
return
@@ -1498,6 +1574,41 @@ class FederationHandler:
allow_rejected=True,
)
for event in events:
- await self._federation_event_handler.update_state_for_partial_state_event(
- destination, event
- )
+ for attempt in itertools.count():
+ try:
+ await self._federation_event_handler.update_state_for_partial_state_event(
+ destination, event
+ )
+ break
+ except FederationError as e:
+ if attempt == len(destinations) - 1:
+ # We have tried every remote server for this event. Give up.
+ # TODO(faster_joins) giving up isn't the right thing to do
+ # if there's a temporary network outage. retrying
+ # indefinitely is also not the right thing to do if we can
+ # reach all homeservers and they all claim they don't have
+ # the state we want.
+ logger.error(
+ "Failed to get state for %s at %s from %s because %s, "
+ "giving up!",
+ room_id,
+ event,
+ destination,
+ e,
+ )
+ raise
+
+ # Try the next remote server.
+ logger.info(
+ "Failed to get state for %s at %s from %s because %s",
+ room_id,
+ event,
+ destination,
+ e,
+ )
+ destination = next(destination_iter)
+ logger.info(
+ "Syncing state for room %s via %s instead",
+ room_id,
+ destination,
+ )
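
Note: the hunk above swaps the single `destination` argument for an `initial_destination` plus `other_destinations`, and cycles through candidate homeservers, staying on one while it works and moving on when a request fails. A minimal, self-contained sketch of that retry pattern follows; it is an illustration only, not the upstream handler, and `pull_state` and `FederationError` here are stand-ins rather than real Synapse APIs.

    import itertools
    from typing import Callable, Iterable, Optional


    class FederationError(Exception):
        """Stand-in for a failed request to a remote homeserver."""


    def resync_with_fallback(
        initial_destination: Optional[str],
        other_destinations: Iterable[str],
        event_ids: Iterable[str],
        pull_state: Callable[[str, str], None],
    ) -> None:
        # Put the initial destination (if any) at the front, keep the rest.
        destinations = [d for d in other_destinations if d != initial_destination]
        if initial_destination is not None:
            destinations.insert(0, initial_destination)
        if not destinations:
            raise ValueError("no destinations provided")

        # Infinite iterator of destinations; stick with one until it flakes.
        destination_iter = itertools.cycle(destinations)
        destination = next(destination_iter)

        for event_id in event_ids:
            for attempt in range(len(destinations)):
                try:
                    pull_state(destination, event_id)
                    break
                except FederationError:
                    if attempt == len(destinations) - 1:
                        # Every destination failed for this event: give up.
                        raise
                    # Otherwise try the next remote homeserver.
                    destination = next(destination_iter)
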
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 6cf927e4..87a06083 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -30,6 +30,7 @@ from typing import (
from prometheus_client import Counter
+from synapse import event_auth
from synapse.api.constants import (
EventContentFields,
EventTypes,
@@ -63,6 +64,7 @@ from synapse.replication.http.federation import (
)
from synapse.state import StateResolutionStore
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.storage.state import StateFilter
from synapse.types import (
PersistedEventPosition,
RoomStreamToken,
@@ -96,14 +98,14 @@ class FederationEventHandler:
def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastores().main
- self._storage = hs.get_storage()
- self._state_store = self._storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
self._state_handler = hs.get_state_handler()
self._event_creation_handler = hs.get_event_creation_handler()
self._event_auth_handler = hs.get_event_auth_handler()
self._message_handler = hs.get_message_handler()
- self._action_generator = hs.get_action_generator()
+ self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
self._state_resolution_handler = hs.get_state_resolution_handler()
# avoid a circular dependency by deferring execution here
self._get_room_member_handler = hs.get_room_member_handler
@@ -272,7 +274,7 @@ class FederationEventHandler:
affected=pdu.event_id,
)
- await self._process_received_pdu(origin, pdu, state=None)
+ await self._process_received_pdu(origin, pdu, state_ids=None)
async def on_send_membership_event(
self, origin: str, event: EventBase
@@ -461,7 +463,9 @@ class FederationEventHandler:
with nested_logging_context(suffix=event.event_id):
context = await self._state_handler.compute_event_context(
event,
- old_state=state,
+ state_ids_before_event={
+ (e.type, e.state_key): e.event_id for e in state
+ },
partial_state=partial_state,
)
@@ -475,7 +479,23 @@ class FederationEventHandler:
# and discover that we do not have it.
event.internal_metadata.proactively_send = False
- return await self.persist_events_and_notify(room_id, [(event, context)])
+ stream_id_after_persist = await self.persist_events_and_notify(
+ room_id, [(event, context)]
+ )
+
+ # If we're joining the room again, check if there is new marker
+ # state indicating that there is new history imported somewhere in
+ # the DAG. Multiple markers can exist in the current state with
+ # unique state_keys.
+ #
+ # Do this after the state from the remote join was persisted (via
+ # `persist_events_and_notify`). Otherwise we can run into a
+ # situation where the create event doesn't exist yet in the
+ # `current_state_events`
+ for e in state:
+ await self._handle_marker_event(origin, e)
+
+ return stream_id_after_persist
async def update_state_for_partial_state_event(
self, destination: str, event: EventBase
@@ -485,6 +505,9 @@ class FederationEventHandler:
Args:
destination: server to request full state from
event: partial-state event to be de-partial-stated
+
+ Raises:
+ FederationError if we fail to request state from the remote server.
"""
logger.info("Updating state for %s", event.event_id)
with nested_logging_context(suffix=event.event_id):
@@ -494,12 +517,12 @@ class FederationEventHandler:
#
# This is the same operation as we do when we receive a regular event
# over federation.
- state = await self._resolve_state_at_missing_prevs(destination, event)
+ state_ids = await self._resolve_state_at_missing_prevs(destination, event)
# build a new state group for it if need be
context = await self._state_handler.compute_event_context(
event,
- old_state=state,
+ state_ids_before_event=state_ids,
)
if context.partial_state:
# this can happen if some or all of the event's prev_events still have
@@ -515,7 +538,9 @@ class FederationEventHandler:
)
return
await self._store.update_state_for_partial_state_event(event, context)
- self._state_store.notify_event_un_partial_stated(event.event_id)
+ self._state_storage_controller.notify_event_un_partial_stated(
+ event.event_id
+ )
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Collection[str]
@@ -749,11 +774,12 @@ class FederationEventHandler:
return
try:
- state = await self._resolve_state_at_missing_prevs(origin, event)
+ state_ids = await self._resolve_state_at_missing_prevs(origin, event)
# TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
# not return partial state
+
await self._process_received_pdu(
- origin, event, state=state, backfilled=backfilled
+ origin, event, state_ids=state_ids, backfilled=backfilled
)
except FederationError as e:
if e.code == 403:
@@ -763,7 +789,7 @@ class FederationEventHandler:
async def _resolve_state_at_missing_prevs(
self, dest: str, event: EventBase
- ) -> Optional[Iterable[EventBase]]:
+ ) -> Optional[StateMap[str]]:
"""Calculate the state at an event with missing prev_events.
This is used when we have pulled a batch of events from a remote server, and
@@ -790,8 +816,12 @@ class FederationEventHandler:
event: an event to check for missing prevs.
Returns:
- if we already had all the prev events, `None`. Otherwise, returns a list of
- the events in the state at `event`.
+ if we already had all the prev events, `None`. Otherwise, returns
+ the event ids of the state at `event`.
+
+ Raises:
+ FederationError if we fail to get the state from the remote server after any
+ missing `prev_event`s.
"""
room_id = event.room_id
event_id = event.event_id
@@ -811,10 +841,12 @@ class FederationEventHandler:
)
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
- event_map = {event_id: event}
+
try:
# Get the state of the events we know about
- ours = await self._state_store.get_state_groups_ids(room_id, seen)
+ ours = await self._state_storage_controller.get_state_groups_ids(
+ room_id, seen
+ )
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps: List[StateMap[str]] = list(ours.values())
@@ -831,40 +863,23 @@ class FederationEventHandler:
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
- remote_state = await self._get_state_after_missing_prev_event(
- dest, room_id, p
+ remote_state_map = (
+ await self._get_state_ids_after_missing_prev_event(
+ dest, room_id, p
+ )
)
- remote_state_map = {
- (x.type, x.state_key): x.event_id for x in remote_state
- }
state_maps.append(remote_state_map)
- for x in remote_state:
- event_map[x.event_id] = x
-
room_version = await self._store.get_room_version_id(room_id)
state_map = await self._state_resolution_handler.resolve_events_with_store(
room_id,
room_version,
state_maps,
- event_map,
+ event_map={event_id: event},
state_res_store=StateResolutionStore(self._store),
)
- # We need to give _process_received_pdu the actual state events
- # rather than event ids, so generate that now.
-
- # First though we need to fetch all the events that are in
- # state_map, so we can build up the state below.
- evs = await self._store.get_events(
- list(state_map.values()),
- get_prev_content=False,
- redact_behaviour=EventRedactBehaviour.as_is,
- )
- event_map.update(evs)
-
- state = [event_map[e] for e in state_map.values()]
except Exception:
logger.warning(
"Error attempting to resolve state at missing prev_events",
@@ -876,14 +891,14 @@ class FederationEventHandler:
"We can't get valid state history.",
affected=event_id,
)
- return state
+ return state_map
- async def _get_state_after_missing_prev_event(
+ async def _get_state_ids_after_missing_prev_event(
self,
destination: str,
room_id: str,
event_id: str,
- ) -> List[EventBase]:
+ ) -> StateMap[str]:
"""Requests all of the room state at a given event from a remote homeserver.
Args:
@@ -892,7 +907,11 @@ class FederationEventHandler:
event_id: The id of the event we want the state at.
Returns:
- A list of events in the state, including the event itself
+ The event ids of the state *after* the given event.
+
+ Raises:
+ InvalidResponseError: if the remote homeserver's response contains fields
+ of the wrong type.
"""
(
state_event_ids,
@@ -907,19 +926,17 @@ class FederationEventHandler:
len(auth_event_ids),
)
- # start by just trying to fetch the events from the store
+ # Start by checking events we already have in the DB
desired_events = set(state_event_ids)
desired_events.add(event_id)
logger.debug("Fetching %i events from cache/store", len(desired_events))
- fetched_events = await self._store.get_events(
- desired_events, allow_rejected=True
- )
+ have_events = await self._store.have_seen_events(room_id, desired_events)
- missing_desired_events = desired_events - fetched_events.keys()
+ missing_desired_events = desired_events - have_events
logger.debug(
"We are missing %i events (got %i)",
len(missing_desired_events),
- len(fetched_events),
+ len(have_events),
)
# We probably won't need most of the auth events, so let's just check which
@@ -930,7 +947,7 @@ class FederationEventHandler:
# already have a bunch of the state events. It would be nice if the
# federation api gave us a way of finding out which we actually need.
- missing_auth_events = set(auth_event_ids) - fetched_events.keys()
+ missing_auth_events = set(auth_event_ids) - have_events
missing_auth_events.difference_update(
await self._store.have_seen_events(room_id, missing_auth_events)
)
@@ -956,47 +973,51 @@ class FederationEventHandler:
destination=destination, room_id=room_id, event_ids=missing_events
)
- # we need to make sure we re-load from the database to get the rejected
- # state correct.
- fetched_events.update(
- await self._store.get_events(missing_desired_events, allow_rejected=True)
- )
-
- # check for events which were in the wrong room.
- #
- # this can happen if a remote server claims that the state or
- # auth_events at an event in room A are actually events in room B
+ # We now need to fill out the state map, which involves fetching the
+ # type and state key for each event ID in the state.
+ state_map = {}
- bad_events = [
- (event_id, event.room_id)
- for event_id, event in fetched_events.items()
- if event.room_id != room_id
- ]
+ event_metadata = await self._store.get_metadata_for_events(state_event_ids)
+ for state_event_id, metadata in event_metadata.items():
+ if metadata.room_id != room_id:
+ # This is a bogus situation, but since we may only discover it a long time
+ # after it happened, we try our best to carry on, by just omitting the
+ # bad events from the returned state set.
+ #
+ # This can happen if a remote server claims that the state or
+ # auth_events at an event in room A are actually events in room B
+ logger.warning(
+ "Remote server %s claims event %s in room %s is an auth/state "
+ "event in room %s",
+ destination,
+ state_event_id,
+ metadata.room_id,
+ room_id,
+ )
+ continue
- for bad_event_id, bad_room_id in bad_events:
- # This is a bogus situation, but since we may only discover it a long time
- # after it happened, we try our best to carry on, by just omitting the
- # bad events from the returned state set.
- logger.warning(
- "Remote server %s claims event %s in room %s is an auth/state "
- "event in room %s",
- destination,
- bad_event_id,
- bad_room_id,
- room_id,
- )
+ if metadata.state_key is None:
+ logger.warning(
+ "Remote server gave us non-state event in state: %s", state_event_id
+ )
+ continue
- del fetched_events[bad_event_id]
+ state_map[(metadata.event_type, metadata.state_key)] = state_event_id
# if we couldn't get the prev event in question, that's a problem.
- remote_event = fetched_events.get(event_id)
+ remote_event = await self._store.get_event(
+ event_id,
+ allow_none=True,
+ allow_rejected=True,
+ redact_behaviour=EventRedactBehaviour.as_is,
+ )
if not remote_event:
raise Exception("Unable to get missing prev_event %s" % (event_id,))
# missing state at that event is a warning, not a blocker
# XXX: this doesn't sound right? it means that we'll end up with incomplete
# state.
- failed_to_fetch = desired_events - fetched_events.keys()
+ failed_to_fetch = desired_events - event_metadata.keys()
if failed_to_fetch:
logger.warning(
"Failed to fetch missing state events for %s %s",
@@ -1004,14 +1025,12 @@ class FederationEventHandler:
failed_to_fetch,
)
- remote_state = [
- fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
- ]
-
if remote_event.is_state() and remote_event.rejected_reason is None:
- remote_state.append(remote_event)
+ state_map[
+ (remote_event.type, remote_event.state_key)
+ ] = remote_event.event_id
- return remote_state
+ return state_map
async def _get_state_and_persist(
self, destination: str, room_id: str, event_id: str
@@ -1038,7 +1057,7 @@ class FederationEventHandler:
self,
origin: str,
event: EventBase,
- state: Optional[Iterable[EventBase]],
+ state_ids: Optional[StateMap[str]],
backfilled: bool = False,
) -> None:
"""Called when we have a new non-outlier event.
@@ -1060,7 +1079,7 @@ class FederationEventHandler:
event: event to be persisted
- state: Normally None, but if we are handling a gap in the graph
+ state_ids: Normally None, but if we are handling a gap in the graph
(ie, we are missing one or more prev_events), the resolved state at the
event
@@ -1072,7 +1091,8 @@ class FederationEventHandler:
try:
context = await self._state_handler.compute_event_context(
- event, old_state=state
+ event,
+ state_ids_before_event=state_ids,
)
context = await self._check_event_auth(
origin,
@@ -1089,7 +1109,7 @@ class FederationEventHandler:
# For new (non-backfilled and non-outlier) events we check if the event
# passes auth based on the current state. If it doesn't then we
# "soft-fail" the event.
- await self._check_for_soft_fail(event, state, origin=origin)
+ await self._check_for_soft_fail(event, state_ids, origin=origin)
await self._run_push_actions_and_persist_event(event, context, backfilled)
@@ -1228,6 +1248,14 @@ class FederationEventHandler:
# Nothing to retrieve then (invalid marker)
return
+ already_seen_insertion_event = await self._store.have_seen_event(
+ marker_event.room_id, insertion_event_id
+ )
+ if already_seen_insertion_event:
+ # No need to process a marker again if we have already seen the
+ # insertion event that it was pointing to
+ return
+
logger.debug(
"_handle_marker_event: backfilling insertion event %s", insertion_event_id
)
@@ -1423,7 +1451,7 @@ class FederationEventHandler:
# we're not bothering about room state, so flag the event as an outlier.
event.internal_metadata.outlier = True
- context = EventContext.for_outlier()
+ context = EventContext.for_outlier(self._storage_controllers)
try:
validate_event_for_room_version(room_version_obj, event)
check_auth_rules_for_event(room_version_obj, event, auth)
@@ -1500,7 +1528,11 @@ class FederationEventHandler:
return context
# now check auth against what we think the auth events *should* be.
- prev_state_ids = await context.get_prev_state_ids()
+ event_types = event_auth.auth_types_for_event(event.room_version, event)
+ prev_state_ids = await context.get_prev_state_ids(
+ StateFilter.from_types(event_types)
+ )
+
auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=True
)
@@ -1552,14 +1584,16 @@ class FederationEventHandler:
if guest_access == GuestAccess.CAN_JOIN:
return
- current_state_map = await self._state_handler.get_current_state(event.room_id)
- current_state = list(current_state_map.values())
- await self._get_room_member_handler().kick_guest_users(current_state)
+ current_state = await self._storage_controllers.state.get_current_state(
+ event.room_id
+ )
+ current_state_list = list(current_state.values())
+ await self._get_room_member_handler().kick_guest_users(current_state_list)
async def _check_for_soft_fail(
self,
event: EventBase,
- state: Optional[Iterable[EventBase]],
+ state_ids: Optional[StateMap[str]],
origin: str,
) -> None:
"""Checks if we should soft fail the event; if so, marks the event as
@@ -1567,7 +1601,7 @@ class FederationEventHandler:
Args:
event
- state: The state at the event if we don't have all the event's prev events
+ state_ids: The state at the event if we don't have all the event's prev events
origin: The host the event originates from.
"""
extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
@@ -1582,8 +1616,11 @@ class FederationEventHandler:
room_version = await self._store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+ # The event types we want to pull from the "current" state.
+ auth_types = auth_types_for_event(room_version_obj, event)
+
# Calculate the "current state".
- if state is not None:
+ if state_ids is not None:
# If we're explicitly given the state then we won't have all the
# prev events, and so we have a gap in the graph. In this case
# we want to be a little careful as we might have been down for
@@ -1596,20 +1633,25 @@ class FederationEventHandler:
# given state at the event. This should correctly handle cases
# like bans, especially with state res v2.
- state_sets_d = await self._state_store.get_state_groups(
+ state_sets_d = await self._state_storage_controller.get_state_groups_ids(
event.room_id, extrem_ids
)
- state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
- state_sets.append(state)
- current_states = await self._state_handler.resolve_events(
- room_version, state_sets, event
+ state_sets: List[StateMap[str]] = list(state_sets_d.values())
+ state_sets.append(state_ids)
+ current_state_ids = (
+ await self._state_resolution_handler.resolve_events_with_store(
+ event.room_id,
+ room_version,
+ state_sets,
+ event_map=None,
+ state_res_store=StateResolutionStore(self._store),
+ )
)
- current_state_ids: StateMap[str] = {
- k: e.event_id for k, e in current_states.items()
- }
else:
- current_state_ids = await self._state_handler.get_current_state_ids(
- event.room_id, latest_event_ids=extrem_ids
+ current_state_ids = (
+ await self._state_storage_controller.get_current_state_ids(
+ event.room_id, StateFilter.from_types(auth_types)
+ )
)
logger.debug(
@@ -1619,7 +1661,6 @@ class FederationEventHandler:
)
# Now check if event pass auth against said current state
- auth_types = auth_types_for_event(room_version_obj, event)
current_state_ids_list = [
e for k, e in current_state_ids.items() if k in auth_types
]
@@ -1865,7 +1906,7 @@ class FederationEventHandler:
# create a new state group as a delta from the existing one.
prev_group = context.state_group
- state_group = await self._state_store.store_state_group(
+ state_group = await self._state_storage_controller.store_state_group(
event.event_id,
event.room_id,
prev_group=prev_group,
@@ -1874,10 +1915,10 @@ class FederationEventHandler:
)
return EventContext.with_state(
+ storage=self._storage_controllers,
state_group=state_group,
state_group_before_event=context.state_group_before_event,
- current_state_ids=current_state_ids,
- prev_state_ids=prev_state_ids,
+ state_delta_due_to_event=state_updates,
prev_group=prev_group,
delta_ids=state_updates,
partial_state=context.partial_state,
@@ -1913,7 +1954,7 @@ class FederationEventHandler:
min_depth,
)
else:
- await self._action_generator.handle_push_actions_for_event(
+ await self._bulk_push_rule_evaluator.action_for_event_by_user(
event, context
)
@@ -1964,11 +2005,14 @@ class FederationEventHandler:
)
return result["max_stream_id"]
else:
- assert self._storage.persistence
+ assert self._storage_controllers.persistence
# Note that this returns the events that were persisted, which may not be
# the same as were passed in if some were deduplicated due to transaction IDs.
- events, max_stream_token = await self._storage.persistence.persist_events(
+ (
+ events,
+ max_stream_token,
+ ) = await self._storage_controllers.persistence.persist_events(
event_and_contexts, backfilled=backfilled
)
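
Note: `_get_state_ids_after_missing_prev_event` above now returns a `StateMap[str]` of `(type, state_key) -> event_id` built from event metadata, instead of a list of full events, and drops entries the remote server mis-attributed to another room or that are not state events at all. A rough sketch of that construction, under the assumption that metadata exposes a room id, event type and optional state key (`EventMetadata` below is a stand-in, not the real return type of `get_metadata_for_events`):

    from dataclasses import dataclass
    from typing import Dict, Optional, Tuple

    StateMap = Dict[Tuple[str, str], str]


    @dataclass
    class EventMetadata:
        room_id: str
        event_type: str
        state_key: Optional[str]


    def build_state_map(
        room_id: str, event_metadata: Dict[str, EventMetadata]
    ) -> StateMap:
        state_map: StateMap = {}
        for event_id, metadata in event_metadata.items():
            if metadata.room_id != room_id:
                # The remote server claims this event belongs to another room:
                # omit it from the returned state set.
                continue
            if metadata.state_key is None:
                # Not a state event at all: omit it.
                continue
            state_map[(metadata.event_type, metadata.state_key)] = event_id
        return state_map
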
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
deleted file mode 100644
index e7a39978..00000000
--- a/synapse/handlers/groups_local.py
+++ /dev/null
@@ -1,503 +0,0 @@
-# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Iterable, List, Set
-
-from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
-from synapse.types import GroupID, JsonDict, get_domain_from_id
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-def _create_rerouter(func_name: str) -> Callable[..., Awaitable[JsonDict]]:
- """Returns an async function that looks at the group id and calls the function
- on federation or the local group server if the group is local
- """
-
- async def f(
- self: "GroupsLocalWorkerHandler", group_id: str, *args: Any, **kwargs: Any
- ) -> JsonDict:
- if not GroupID.is_valid(group_id):
- raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
-
- if self.is_mine_id(group_id):
- return await getattr(self.groups_server_handler, func_name)(
- group_id, *args, **kwargs
- )
- else:
- destination = get_domain_from_id(group_id)
-
- try:
- return await getattr(self.transport_client, func_name)(
- destination, group_id, *args, **kwargs
- )
- except HttpResponseException as e:
- # Capture errors returned by the remote homeserver and
- # re-throw specific errors as SynapseErrors. This is so
- # when the remote end responds with things like 403 Not
- # In Group, we can communicate that to the client instead
- # of a 500.
- raise e.to_synapse_error()
- except RequestSendFailed:
- raise SynapseError(502, "Failed to contact group server")
-
- return f
-
-
-class GroupsLocalWorkerHandler:
- def __init__(self, hs: "HomeServer"):
- self.hs = hs
- self.store = hs.get_datastores().main
- self.room_list_handler = hs.get_room_list_handler()
- self.groups_server_handler = hs.get_groups_server_handler()
- self.transport_client = hs.get_federation_transport_client()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.keyring = hs.get_keyring()
- self.is_mine_id = hs.is_mine_id
- self.signing_key = hs.signing_key
- self.server_name = hs.hostname
- self.notifier = hs.get_notifier()
- self.attestations = hs.get_groups_attestation_signing()
-
- self.profile_handler = hs.get_profile_handler()
-
- # The following functions merely route the query to the local groups server
- # or federation depending on if the group is local or remote
-
- get_group_profile = _create_rerouter("get_group_profile")
- get_rooms_in_group = _create_rerouter("get_rooms_in_group")
- get_invited_users_in_group = _create_rerouter("get_invited_users_in_group")
- get_group_category = _create_rerouter("get_group_category")
- get_group_categories = _create_rerouter("get_group_categories")
- get_group_role = _create_rerouter("get_group_role")
- get_group_roles = _create_rerouter("get_group_roles")
-
- async def get_group_summary(
- self, group_id: str, requester_user_id: str
- ) -> JsonDict:
- """Get the group summary for a group.
-
- If the group is remote we check that the users have valid attestations.
- """
- if self.is_mine_id(group_id):
- res = await self.groups_server_handler.get_group_summary(
- group_id, requester_user_id
- )
- else:
- try:
- res = await self.transport_client.get_group_summary(
- get_domain_from_id(group_id), group_id, requester_user_id
- )
- except HttpResponseException as e:
- raise e.to_synapse_error()
- except RequestSendFailed:
- raise SynapseError(502, "Failed to contact group server")
-
- group_server_name = get_domain_from_id(group_id)
-
- # Loop through the users and validate the attestations.
- chunk = res["users_section"]["users"]
- valid_users = []
- for entry in chunk:
- g_user_id = entry["user_id"]
- attestation = entry.pop("attestation", {})
- try:
- if get_domain_from_id(g_user_id) != group_server_name:
- await self.attestations.verify_attestation(
- attestation,
- group_id=group_id,
- user_id=g_user_id,
- server_name=get_domain_from_id(g_user_id),
- )
- valid_users.append(entry)
- except Exception as e:
- logger.info("Failed to verify user is in group: %s", e)
-
- res["users_section"]["users"] = valid_users
-
- res["users_section"]["users"].sort(key=lambda e: e.get("order", 0))
- res["rooms_section"]["rooms"].sort(key=lambda e: e.get("order", 0))
-
- # Add `is_publicised` flag to indicate whether the user has publicised their
- # membership of the group on their profile
- result = await self.store.get_publicised_groups_for_user(requester_user_id)
- is_publicised = group_id in result
-
- res.setdefault("user", {})["is_publicised"] = is_publicised
-
- return res
-
- async def get_users_in_group(
- self, group_id: str, requester_user_id: str
- ) -> JsonDict:
- """Get users in a group"""
- if self.is_mine_id(group_id):
- return await self.groups_server_handler.get_users_in_group(
- group_id, requester_user_id
- )
-
- group_server_name = get_domain_from_id(group_id)
-
- try:
- res = await self.transport_client.get_users_in_group(
- get_domain_from_id(group_id), group_id, requester_user_id
- )
- except HttpResponseException as e:
- raise e.to_synapse_error()
- except RequestSendFailed:
- raise SynapseError(502, "Failed to contact group server")
-
- chunk = res["chunk"]
- valid_entries = []
- for entry in chunk:
- g_user_id = entry["user_id"]
- attestation = entry.pop("attestation", {})
- try:
- if get_domain_from_id(g_user_id) != group_server_name:
- await self.attestations.verify_attestation(
- attestation,
- group_id=group_id,
- user_id=g_user_id,
- server_name=get_domain_from_id(g_user_id),
- )
- valid_entries.append(entry)
- except Exception as e:
- logger.info("Failed to verify user is in group: %s", e)
-
- res["chunk"] = valid_entries
-
- return res
-
- async def get_joined_groups(self, user_id: str) -> JsonDict:
- group_ids = await self.store.get_joined_groups(user_id)
- return {"groups": group_ids}
-
- async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict:
- if self.hs.is_mine_id(user_id):
- result = await self.store.get_publicised_groups_for_user(user_id)
-
- # Check AS associated groups for this user - this depends on the
- # RegExps in the AS registration file (under `users`)
- for app_service in self.store.get_app_services():
- result.extend(app_service.get_groups_for_user(user_id))
-
- return {"groups": result}
- else:
- try:
- bulk_result = await self.transport_client.bulk_get_publicised_groups(
- get_domain_from_id(user_id), [user_id]
- )
- except HttpResponseException as e:
- raise e.to_synapse_error()
- except RequestSendFailed:
- raise SynapseError(502, "Failed to contact group server")
-
- result = bulk_result.get("users", {}).get(user_id)
- # TODO: Verify attestations
- return {"groups": result}
-
- async def bulk_get_publicised_groups(
- self, user_ids: Iterable[str], proxy: bool = True
- ) -> JsonDict:
- destinations: Dict[str, Set[str]] = {}
- local_users = set()
-
- for user_id in user_ids:
- if self.hs.is_mine_id(user_id):
- local_users.add(user_id)
- else:
- destinations.setdefault(get_domain_from_id(user_id), set()).add(user_id)
-
- if not proxy and destinations:
- raise SynapseError(400, "Some user_ids are not local")
-
- results = {}
- failed_results: List[str] = []
- for destination, dest_user_ids in destinations.items():
- try:
- r = await self.transport_client.bulk_get_publicised_groups(
- destination, list(dest_user_ids)
- )
- results.update(r["users"])
- except Exception:
- failed_results.extend(dest_user_ids)
-
- for uid in local_users:
- results[uid] = await self.store.get_publicised_groups_for_user(uid)
-
- # Check AS associated groups for this user - this depends on the
- # RegExps in the AS registration file (under `users`)
- for app_service in self.store.get_app_services():
- results[uid].extend(app_service.get_groups_for_user(uid))
-
- return {"users": results}
-
-
-class GroupsLocalHandler(GroupsLocalWorkerHandler):
- def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
- # Ensure attestations get renewed
- hs.get_groups_attestation_renewer()
-
- # The following functions merely route the query to the local groups server
- # or federation depending on if the group is local or remote
-
- update_group_profile = _create_rerouter("update_group_profile")
-
- add_room_to_group = _create_rerouter("add_room_to_group")
- update_room_in_group = _create_rerouter("update_room_in_group")
- remove_room_from_group = _create_rerouter("remove_room_from_group")
-
- update_group_summary_room = _create_rerouter("update_group_summary_room")
- delete_group_summary_room = _create_rerouter("delete_group_summary_room")
-
- update_group_category = _create_rerouter("update_group_category")
- delete_group_category = _create_rerouter("delete_group_category")
-
- update_group_summary_user = _create_rerouter("update_group_summary_user")
- delete_group_summary_user = _create_rerouter("delete_group_summary_user")
-
- update_group_role = _create_rerouter("update_group_role")
- delete_group_role = _create_rerouter("delete_group_role")
-
- set_group_join_policy = _create_rerouter("set_group_join_policy")
-
- async def create_group(
- self, group_id: str, user_id: str, content: JsonDict
- ) -> JsonDict:
- """Create a group"""
-
- logger.info("Asking to create group with ID: %r", group_id)
-
- if self.is_mine_id(group_id):
- res = await self.groups_server_handler.create_group(
- group_id, user_id, content
- )
- local_attestation = None
- remote_attestation = None
- else:
- raise SynapseError(400, "Unable to create remote groups")
-
- is_publicised = content.get("publicise", False)
- token = await self.store.register_user_group_membership(
- group_id,
- user_id,
- membership="join",
- is_admin=True,
- local_attestation=local_attestation,
- remote_attestation=remote_attestation,
- is_publicised=is_publicised,
- )
- self.notifier.on_new_event("groups_key", token, users=[user_id])
-
- return res
-
- async def join_group(
- self, group_id: str, user_id: str, content: JsonDict
- ) -> JsonDict:
- """Request to join a group"""
- if self.is_mine_id(group_id):
- await self.groups_server_handler.join_group(group_id, user_id, content)
- local_attestation = None
- remote_attestation = None
- else:
- local_attestation = self.attestations.create_attestation(group_id, user_id)
- content["attestation"] = local_attestation
-
- try:
- res = await self.transport_client.join_group(
- get_domain_from_id(group_id), group_id, user_id, content
- )
- except HttpResponseException as e:
- raise e.to_synapse_error()
- except RequestSendFailed:
- raise SynapseError(502, "Failed to contact group server")
-
- remote_attestation = res["attestation"]
-
- await self.attestations.verify_attestation(
- remote_attestation,
- group_id=group_id,
- user_id=user_id,
- server_name=get_domain_from_id(group_id),
- )
-
- # TODO: Check that the group is public and we're being added publicly
- is_publicised = content.get("publicise", False)
-
- token = await self.store.register_user_group_membership(
- group_id,
- user_id,
- membership="join",
- is_admin=False,
- local_attestation=local_attestation,
- remote_attestation=remote_attestation,
- is_publicised=is_publicised,
- )
- self.notifier.on_new_event("groups_key", token, users=[user_id])
-
- return {}
-
- async def accept_invite(
- self, group_id: str, user_id: str, content: JsonDict
- ) -> JsonDict:
- """Accept an invite to a group"""
- if self.is_mine_id(group_id):
- await self.groups_server_handler.accept_invite(group_id, user_id, content)
- local_attestation = None
- remote_attestation = None
- else:
- local_attestation = self.attestations.create_attestation(group_id, user_id)
- content["attestation"] = local_attestation
-
- try:
- res = await self.transport_client.accept_group_invite(
- get_domain_from_id(group_id), group_id, user_id, content
- )
- except HttpResponseException as e:
- raise e.to_synapse_error()
- except RequestSendFailed:
- raise SynapseError(502, "Failed to contact group server")
-
- remote_attestation = res["attestation"]
-
- await self.attestations.verify_attestation(
- remote_attestation,
- group_id=group_id,
- user_id=user_id,
- server_name=get_domain_from_id(group_id),
- )
-
- # TODO: Check that the group is public and we're being added publicly
- is_publicised = content.get("publicise", False)
-
- token = await self.store.register_user_group_membership(
- group_id,
- user_id,
- membership="join",
- is_admin=False,
- local_attestation=local_attestation,
- remote_attestation=remote_attestation,
- is_publicised=is_publicised,
- )
- self.notifier.on_new_event("groups_key", token, users=[user_id])
-
- return {}
-
- async def invite(
- self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
- ) -> JsonDict:
- """Invite a user to a group"""
- content = {"requester_user_id": requester_user_id, "config": config}
- if self.is_mine_id(group_id):
- res = await self.groups_server_handler.invite_to_group(
- group_id, user_id, requester_user_id, content
- )
- else:
- try:
- res = await self.transport_client.invite_to_group(
- get_domain_from_id(group_id),
- group_id,
- user_id,
- requester_user_id,
- content,
- )
- except HttpResponseException as e:
- raise e.to_synapse_error()
- except RequestSendFailed:
- raise SynapseError(502, "Failed to contact group server")
-
- return res
-
- async def on_invite(
- self, group_id: str, user_id: str, content: JsonDict
- ) -> JsonDict:
- """One of our users were invited to a group"""
- # TODO: Support auto join and rejection
-
- if not self.is_mine_id(user_id):
- raise SynapseError(400, "User not on this server")
-
- local_profile = {}
- if "profile" in content:
- if "name" in content["profile"]:
- local_profile["name"] = content["profile"]["name"]
- if "avatar_url" in content["profile"]:
- local_profile["avatar_url"] = content["profile"]["avatar_url"]
-
- token = await self.store.register_user_group_membership(
- group_id,
- user_id,
- membership="invite",
- content={"profile": local_profile, "inviter": content["inviter"]},
- )
- self.notifier.on_new_event("groups_key", token, users=[user_id])
- try:
- user_profile = await self.profile_handler.get_profile(user_id)
- except Exception as e:
- logger.warning("No profile for user %s: %s", user_id, e)
- user_profile = {}
-
- return {"state": "invite", "user_profile": user_profile}
-
- async def remove_user_from_group(
- self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
- ) -> JsonDict:
- """Remove a user from a group"""
- if user_id == requester_user_id:
- token = await self.store.register_user_group_membership(
- group_id, user_id, membership="leave"
- )
- self.notifier.on_new_event("groups_key", token, users=[user_id])
-
- # TODO: Should probably remember that we tried to leave so that we can
- # retry if the group server is currently down.
-
- if self.is_mine_id(group_id):
- res = await self.groups_server_handler.remove_user_from_group(
- group_id, user_id, requester_user_id, content
- )
- else:
- content["requester_user_id"] = requester_user_id
- try:
- res = await self.transport_client.remove_user_from_group(
- get_domain_from_id(group_id),
- group_id,
- requester_user_id,
- user_id,
- content,
- )
- except HttpResponseException as e:
- raise e.to_synapse_error()
- except RequestSendFailed:
- raise SynapseError(502, "Failed to contact group server")
-
- return res
-
- async def user_removed_from_group(
- self, group_id: str, user_id: str, content: JsonDict
- ) -> None:
- """One of our users was removed/kicked from a group"""
- # TODO: Check if user in group
- token = await self.store.register_user_group_membership(
- group_id, user_id, membership="leave"
- )
- self.notifier.on_new_event("groups_key", token, users=[user_id])
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 7b94770f..85b472f2 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -30,6 +30,7 @@ from synapse.types import (
Requester,
RoomStreamToken,
StateMap,
+ StreamKeyType,
StreamToken,
UserID,
)
@@ -66,8 +67,8 @@ class InitialSyncHandler:
]
] = ResponseCache(hs.get_clock(), "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer()
- self.storage = hs.get_storage()
- self.state_store = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
async def snapshot_all_rooms(
self,
@@ -143,7 +144,7 @@ class InitialSyncHandler:
to_key=int(now_token.receipt_key),
)
if self.hs.config.experimental.msc2285_enabled:
- receipt = ReceiptEventSource.filter_out_private(receipt, user_id)
+ receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id)
tags_by_room = await self.store.get_tags_for_user(user_id)
@@ -189,7 +190,7 @@ class InitialSyncHandler:
if event.membership == Membership.JOIN:
room_end_token = now_token.room_key
deferred_room_state = run_in_background(
- self.state_handler.get_current_state, event.room_id
+ self._state_storage_controller.get_current_state, event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = RoomStreamToken(
@@ -197,7 +198,8 @@ class InitialSyncHandler:
event.stream_ordering,
)
deferred_room_state = run_in_background(
- self.state_store.get_state_for_events, [event.event_id]
+ self._state_storage_controller.get_state_for_events,
+ [event.event_id],
).addCallback(
lambda states: cast(StateMap[EventBase], states[event.event_id])
)
@@ -217,11 +219,13 @@ class InitialSyncHandler:
).addErrback(unwrapFirstError)
messages = await filter_events_for_client(
- self.storage, user_id, messages
+ self._storage_controllers, user_id, messages
)
- start_token = now_token.copy_and_replace("room_key", token)
- end_token = now_token.copy_and_replace("room_key", room_end_token)
+ start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
+ end_token = now_token.copy_and_replace(
+ StreamKeyType.ROOM, room_end_token
+ )
time_now = self.clock.time_msec()
d["messages"] = {
@@ -271,7 +275,7 @@ class InitialSyncHandler:
"rooms": rooms_ret,
"presence": [
{
- "type": "m.presence",
+ "type": EduTypes.PRESENCE,
"content": format_user_presence_state(event, now),
}
for event in presence
@@ -352,7 +356,9 @@ class InitialSyncHandler:
member_event_id: str,
is_peeking: bool,
) -> JsonDict:
- room_state = await self.state_store.get_state_for_event(member_event_id)
+ room_state = await self._state_storage_controller.get_state_for_event(
+ member_event_id
+ )
limit = pagin_config.limit if pagin_config else None
if limit is None:
@@ -366,11 +372,11 @@ class InitialSyncHandler:
)
messages = await filter_events_for_client(
- self.storage, user_id, messages, is_peeking=is_peeking
+ self._storage_controllers, user_id, messages, is_peeking=is_peeking
)
- start_token = StreamToken.START.copy_and_replace("room_key", token)
- end_token = StreamToken.START.copy_and_replace("room_key", stream_token)
+ start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token)
+ end_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, stream_token)
time_now = self.clock.time_msec()
@@ -401,7 +407,9 @@ class InitialSyncHandler:
membership: str,
is_peeking: bool,
) -> JsonDict:
- current_state = await self.state.get_current_state(room_id=room_id)
+ current_state = await self._storage_controllers.state.get_current_state(
+ room_id=room_id
+ )
# TODO: These concurrently
time_now = self.clock.time_msec()
@@ -436,7 +444,7 @@ class InitialSyncHandler:
return [
{
- "type": EduTypes.Presence,
+ "type": EduTypes.PRESENCE,
"content": format_user_presence_state(s, time_now),
}
for s in states
@@ -449,7 +457,9 @@ class InitialSyncHandler:
if not receipts:
return []
if self.hs.config.experimental.msc2285_enabled:
- receipts = ReceiptEventSource.filter_out_private(receipts, user_id)
+ receipts = ReceiptEventSource.filter_out_private_receipts(
+ receipts, user_id
+ )
return receipts
presence, receipts, (messages, token) = await make_deferred_yieldable(
@@ -469,10 +479,10 @@ class InitialSyncHandler:
)
messages = await filter_events_for_client(
- self.storage, user_id, messages, is_peeking=is_peeking
+ self._storage_controllers, user_id, messages, is_peeking=is_peeking
)
- start_token = now_token.copy_and_replace("room_key", token)
+ start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
end_token = now_token
time_now = self.clock.time_msec()
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index c28b792e..f455158a 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -28,6 +28,7 @@ from synapse.api.constants import (
EventContentFields,
EventTypes,
GuestAccess,
+ HistoryVisibility,
Membership,
RelationTypes,
UserTypes,
@@ -44,7 +45,7 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.api.urls import ConsentURIBuilder
from synapse.event_auth import validate_event_for_room_version
-from synapse.events import EventBase
+from synapse.events import EventBase, relation_from_event
from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
@@ -54,12 +55,19 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
-from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
+from synapse.types import (
+ MutableStateMap,
+ Requester,
+ RoomAlias,
+ StreamToken,
+ UserID,
+ create_requester,
+)
from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
from synapse.util.async_helpers import Linearizer, gather_results
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import measure_func
-from synapse.visibility import filter_events_for_client
+from synapse.visibility import get_effective_room_visibility_from_state
if TYPE_CHECKING:
from synapse.events.third_party_rules import ThirdPartyEventRules
@@ -76,8 +84,8 @@ class MessageHandler:
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
- self.state_store = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
self._event_serializer = hs.get_event_client_serializer()
self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages
@@ -117,14 +125,16 @@ class MessageHandler:
)
if membership == Membership.JOIN:
- data = await self.state.get_current_state(room_id, event_type, state_key)
+ data = await self._storage_controllers.state.get_current_state_event(
+ room_id, event_type, state_key
+ )
elif membership == Membership.LEAVE:
key = (event_type, state_key)
# If the membership is not JOIN, then the event ID should exist.
assert (
membership_event_id is not None
), "check_user_in_room_or_world_readable returned invalid data"
- room_state = await self.state_store.get_state_for_events(
+ room_state = await self._state_storage_controller.get_state_for_events(
[membership_event_id], StateFilter.from_types([key])
)
data = room_state[membership_event_id].get(key)
@@ -175,49 +185,31 @@ class MessageHandler:
state_filter = state_filter or StateFilter.all()
if at_token:
- last_event = await self.store.get_last_event_in_room_before_stream_ordering(
- room_id,
- end_token=at_token.room_key,
+ last_event_id = (
+ await self.store.get_last_event_in_room_before_stream_ordering(
+ room_id,
+ end_token=at_token.room_key,
+ )
)
- if not last_event:
+ if not last_event_id:
raise NotFoundError("Can't find event for token %s" % (at_token,))
- # check whether the user is in the room at that time to determine
- # whether they should be treated as peeking.
- state_map = await self.state_store.get_state_for_event(
- last_event.event_id,
- StateFilter.from_types([(EventTypes.Member, user_id)]),
- )
-
- joined = False
- membership_event = state_map.get((EventTypes.Member, user_id))
- if membership_event:
- joined = membership_event.membership == Membership.JOIN
-
- is_peeking = not joined
-
- visible_events = await filter_events_for_client(
- self.storage,
- user_id,
- [last_event],
- filter_send_to_client=False,
- is_peeking=is_peeking,
- )
-
- if visible_events:
- room_state_events = await self.state_store.get_state_for_events(
- [last_event.event_id], state_filter=state_filter
- )
- room_state: Mapping[Any, EventBase] = room_state_events[
- last_event.event_id
- ]
- else:
+ if not await self._user_can_see_state_at_event(
+ user_id, room_id, last_event_id
+ ):
raise AuthError(
403,
"User %s not allowed to view events in room %s at token %s"
% (user_id, room_id, at_token),
)
+
+ room_state_events = (
+ await self._state_storage_controller.get_state_for_events(
+ [last_event_id], state_filter=state_filter
+ )
+ )
+ room_state: Mapping[Any, EventBase] = room_state_events[last_event_id]
else:
(
membership,
@@ -227,7 +219,7 @@ class MessageHandler:
)
if membership == Membership.JOIN:
- state_ids = await self.store.get_filtered_current_state_ids(
+ state_ids = await self._state_storage_controller.get_current_state_ids(
room_id, state_filter=state_filter
)
room_state = await self.store.get_events(state_ids.values())
@@ -236,8 +228,10 @@ class MessageHandler:
assert (
membership_event_id is not None
), "check_user_in_room_or_world_readable returned invalid data"
- room_state_events = await self.state_store.get_state_for_events(
- [membership_event_id], state_filter=state_filter
+ room_state_events = (
+ await self._state_storage_controller.get_state_for_events(
+ [membership_event_id], state_filter=state_filter
+ )
)
room_state = room_state_events[membership_event_id]
@@ -245,6 +239,65 @@ class MessageHandler:
events = self._event_serializer.serialize_events(room_state.values(), now)
return events
+ async def _user_can_see_state_at_event(
+ self, user_id: str, room_id: str, event_id: str
+ ) -> bool:
+ # check whether the user was in the room, and the history visibility,
+ # at that time.
+ state_map = await self._state_storage_controller.get_state_for_event(
+ event_id,
+ StateFilter.from_types(
+ [
+ (EventTypes.Member, user_id),
+ (EventTypes.RoomHistoryVisibility, ""),
+ ]
+ ),
+ )
+
+ membership = None
+ membership_event = state_map.get((EventTypes.Member, user_id))
+ if membership_event:
+ membership = membership_event.membership
+
+ # if the user was a member of the room at the time of the event,
+ # they can see it.
+ if membership == Membership.JOIN:
+ return True
+
+ # otherwise, it depends on the history visibility.
+ visibility = get_effective_room_visibility_from_state(state_map)
+
+ if visibility == HistoryVisibility.JOINED:
+ # we weren't a member at the time of the event, so we can't see this event.
+ return False
+
+ # otherwise *invited* is good enough
+ if membership == Membership.INVITE:
+ return True
+
+ if visibility == HistoryVisibility.INVITED:
+ # we weren't invited, so we can't see this event.
+ return False
+
+ if visibility == HistoryVisibility.WORLD_READABLE:
+ return True
+
+ # So it's SHARED, and the user was not a member at the time. The user cannot
+ # see history, unless they have *subsequently* joined the room.
+ #
+ # XXX: if the user has subsequently joined and then left again,
+ # ideally we would share history up to the point they left. But
+ # we don't know when they left. We just treat it as though they
+ # never joined, and restrict access.
+
+ (
+ current_membership,
+ _,
+ ) = await self.store.get_local_current_membership_for_user_in_room(
+ user_id, event_id
+ )
+ return current_membership == Membership.JOIN
+
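
Note: the decision implemented by `_user_can_see_state_at_event` above can be reduced to a pure function over the user's membership at the event, the room's history visibility, and whether the user has since joined. The sketch below is an illustration of that logic, not the handler itself; the string constants mirror the Matrix `history_visibility` values.

    from typing import Optional

    JOIN, INVITE = "join", "invite"
    JOINED, INVITED, SHARED, WORLD_READABLE = (
        "joined",
        "invited",
        "shared",
        "world_readable",
    )


    def can_see_state_at_event(
        membership_at_event: Optional[str],
        visibility: str,
        currently_joined: bool,
    ) -> bool:
        # Members at the time of the event can always see it.
        if membership_at_event == JOIN:
            return True
        if visibility == JOINED:
            return False
        # Otherwise an invite at the time is good enough.
        if membership_at_event == INVITE:
            return True
        if visibility == INVITED:
            return False
        if visibility == WORLD_READABLE:
            return True
        # SHARED: only visible if the user has subsequently joined the room.
        return currently_joined
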
async def get_joined_members(self, requester: Requester, room_id: str) -> dict:
"""Get all the joined members in the room and their profile information.
@@ -394,7 +447,7 @@ class EventCreationHandler:
self.auth = hs.get_auth()
self._event_auth_handler = hs.get_event_auth_handler()
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
@@ -426,7 +479,7 @@ class EventCreationHandler:
# This is to stop us from diverging history *too* much.
self.limiter = Linearizer(max_count=5, name="room_event_creation_limit")
- self.action_generator = hs.get_action_generator()
+ self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
self.spam_checker = hs.get_spam_checker()
self.third_party_event_rules: "ThirdPartyEventRules" = (
@@ -634,7 +687,9 @@ class EventCreationHandler:
# federation as well as those created locally. As of room v3, aliases events
# can be created by users that are not in the room, therefore we have to
# tolerate them in event_auth.check().
- prev_state_ids = await context.get_prev_state_ids()
+ prev_state_ids = await context.get_prev_state_ids(
+ StateFilter.from_types([(EventTypes.Member, None)])
+ )
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
prev_event = (
await self.store.get_event(prev_event_id, allow_none=True)
@@ -757,7 +812,13 @@ class EventCreationHandler:
The previous version of the event is returned, if it is found in the
event context. Otherwise, None is returned.
"""
- prev_state_ids = await context.get_prev_state_ids()
+ if event.internal_metadata.is_outlier():
+ # This can happen due to out of band memberships
+ return None
+
+ prev_state_ids = await context.get_prev_state_ids(
+ StateFilter.from_types([(event.type, None)])
+ )
prev_event_id = prev_state_ids.get((event.type, event.state_key))
if not prev_event_id:
return None
@@ -877,11 +938,39 @@ class EventCreationHandler:
event.sender,
)
- spam_error = await self.spam_checker.check_event_for_spam(event)
- if spam_error:
- if not isinstance(spam_error, str):
- spam_error = "Spam is not permitted here"
- raise SynapseError(403, spam_error, Codes.FORBIDDEN)
+ spam_check_result = await self.spam_checker.check_event_for_spam(event)
+ if spam_check_result != self.spam_checker.NOT_SPAM:
+ if isinstance(spam_check_result, tuple):
+ try:
+ [code, dict] = spam_check_result
+ raise SynapseError(
+ 403,
+ "This message had been rejected as probable spam",
+ code,
+ dict,
+ )
+ except ValueError:
+ logger.error(
+ "Spam-check module returned invalid error value. Expecting [code, dict], got %s",
+ spam_check_result,
+ )
+ spam_check_result = Codes.FORBIDDEN
+
+ if isinstance(spam_check_result, Codes):
+ raise SynapseError(
+ 403,
+ "This message has been rejected as probable spam",
+ spam_check_result,
+ )
+
+ # Backwards compatibility: if the return value is not an error code, it
+ # means the module returned an error message to be included in the
+ # SynapseError (which is now deprecated).
+ raise SynapseError(
+ 403,
+ spam_check_result,
+ Codes.FORBIDDEN,
+ )
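
The block above accommodates every return shape the spam-checker module API allows: the NOT_SPAM sentinel, a (code, dict) tuple, a bare Codes value, or (deprecated) a plain error string. A rough standalone sketch of that normalisation follows, raising a generic PermissionError where the handler raises SynapseError(403, ...); the NOT_SPAM sentinel and error code below are stand-ins, not the real module API objects.

    from typing import Any, Tuple, Union

    NOT_SPAM = "NOT_SPAM"  # stand-in for the spam-checker module's sentinel

    def enforce_spam_check(result: Union[str, Tuple[str, dict], Any]) -> None:
        if result == NOT_SPAM:
            return  # event is allowed through

        if isinstance(result, tuple):
            try:
                code, details = result  # expected shape: (error code, extra dict)
            except ValueError:
                # Malformed tuple: fall back to a generic "forbidden" code.
                code, details = "M_FORBIDDEN", {}
            raise PermissionError(f"rejected as probable spam: {code} {details}")

        # A bare error code, or (deprecated) a free-form error message.
        raise PermissionError(f"rejected as probable spam: {result}")
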
ev = await self.handle_new_client_event(
requester=requester,
@@ -1001,7 +1090,7 @@ class EventCreationHandler:
# after it is created
if builder.internal_metadata.outlier:
event.internal_metadata.outlier = True
- context = EventContext.for_outlier()
+ context = EventContext.for_outlier(self._storage_controllers)
elif (
event.type == EventTypes.MSC2716_INSERTION
and state_event_ids
@@ -1013,8 +1102,35 @@ class EventCreationHandler:
#
# TODO(faster_joins): figure out how this works, and make sure that the
# old state is complete.
- old_state = await self.store.get_events_as_list(state_event_ids)
- context = await self.state.compute_event_context(event, old_state=old_state)
+ metadata = await self.store.get_metadata_for_events(state_event_ids)
+
+ state_map_for_event: MutableStateMap[str] = {}
+ for state_id in state_event_ids:
+ data = metadata.get(state_id)
+ if data is None:
+ # We're trying to persist a new historical batch of events
+ # with the given state, e.g. via
+ # `RoomBatchSendEventRestServlet`. The state can be inferred
+ # by Synapse or set directly by the client.
+ #
+ # Either way, we should have persisted all the state before
+ # getting here.
+ raise Exception(
+ f"State event {state_id} not found in DB,"
+ " Synapse should have persisted it before using it."
+ )
+
+ if data.state_key is None:
+ raise Exception(
+ f"Trying to set non-state event {state_id} as state"
+ )
+
+ state_map_for_event[(data.event_type, data.state_key)] = state_id
+
+ context = await self.state.compute_event_context(
+ event,
+ state_ids_before_event=state_map_for_event,
+ )
else:
context = await self.state.compute_event_context(event)
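
The loop above converts a flat list of state event IDs into the (event_type, state_key) -> event_id map that compute_event_context now accepts, rejecting IDs that were never persisted and events that are not state events. A condensed sketch of the same transformation over plain data; the EventMetadata shape mirrors how the metadata is used above but is defined here only for illustration.

    from typing import Dict, List, NamedTuple, Optional, Tuple

    class EventMetadata(NamedTuple):
        event_type: str
        state_key: Optional[str]

    def build_state_map(
        state_event_ids: List[str], metadata: Dict[str, EventMetadata]
    ) -> Dict[Tuple[str, str], str]:
        state_map: Dict[Tuple[str, str], str] = {}
        for state_id in state_event_ids:
            data = metadata.get(state_id)
            if data is None:
                raise Exception(f"State event {state_id} not found in DB")
            if data.state_key is None:
                raise Exception(f"Trying to set non-state event {state_id} as state")
            state_map[(data.event_type, data.state_key)] = state_id
        return state_map
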
@@ -1056,20 +1172,11 @@ class EventCreationHandler:
SynapseError if the event is invalid.
"""
- relation = event.content.get("m.relates_to")
+ relation = relation_from_event(event)
if not relation:
return
- relation_type = relation.get("rel_type")
- if not relation_type:
- return
-
- # Ensure the parent is real.
- relates_to = relation.get("event_id")
- if not relates_to:
- return
-
- parent_event = await self.store.get_event(relates_to, allow_none=True)
+ parent_event = await self.store.get_event(relation.parent_id, allow_none=True)
if parent_event:
# And in the same room.
if parent_event.room_id != event.room_id:
@@ -1078,28 +1185,31 @@ class EventCreationHandler:
else:
# There must be some reason that the client knows the event exists,
# see if there are existing relations. If so, assume everything is fine.
- if not await self.store.event_is_target_of_relation(relates_to):
+ if not await self.store.event_is_target_of_relation(relation.parent_id):
# Otherwise, the client can't know about the parent event!
raise SynapseError(400, "Can't send relation to unknown event")
# If this event is an annotation then we check that the sender
# can't annotate the same way twice (e.g. stops users from liking an
# event multiple times).
- if relation_type == RelationTypes.ANNOTATION:
- aggregation_key = relation["key"]
+ if relation.rel_type == RelationTypes.ANNOTATION:
+ aggregation_key = relation.aggregation_key
+
+ if aggregation_key is None:
+ raise SynapseError(400, "Missing aggregation key")
if len(aggregation_key) > 500:
raise SynapseError(400, "Aggregation key is too long")
already_exists = await self.store.has_user_annotated_event(
- relates_to, event.type, aggregation_key, event.sender
+ relation.parent_id, event.type, aggregation_key, event.sender
)
if already_exists:
raise SynapseError(400, "Can't send same reaction twice")
# Don't attempt to start a thread if the parent event is a relation.
- elif relation_type == RelationTypes.THREAD:
- if await self.store.event_includes_relation(relates_to):
+ elif relation.rel_type == RelationTypes.THREAD:
+ if await self.store.event_includes_relation(relation.parent_id):
raise SynapseError(
400, "Cannot start threads from an event with a relation"
)
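
The rewritten validation parses m.relates_to once via relation_from_event() and then enforces two rules: an annotation must carry a bounded aggregation key and may not duplicate an existing annotation from the same sender, and a thread may not be rooted at an event that is itself a relation. A simplified, self-contained sketch of those rules; the Relation dataclass and the two callables stand in for Synapse's parsed relation and store lookups.

    from dataclasses import dataclass
    from typing import Callable, Optional

    @dataclass
    class Relation:
        parent_id: str
        rel_type: str
        aggregation_key: Optional[str] = None

    def validate_relation(
        relation: Optional[Relation],
        sender: str,
        parent_is_relation: Callable[[str], bool],
        already_annotated: Callable[[str, str], bool],
    ) -> None:
        if relation is None:
            return  # nothing to validate

        if relation.rel_type == "m.annotation":
            key = relation.aggregation_key
            if key is None:
                raise ValueError("Missing aggregation key")
            if len(key) > 500:
                raise ValueError("Aggregation key is too long")
            if already_annotated(relation.parent_id, sender):
                raise ValueError("Can't send same reaction twice")
        elif relation.rel_type == "m.thread":
            if parent_is_relation(relation.parent_id):
                raise ValueError("Cannot start threads from an event with a relation")
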
@@ -1245,7 +1355,9 @@ class EventCreationHandler:
# and `state_groups` because they have `prev_events` that aren't persisted yet
# (historical messages persisted in reverse-chronological order).
if not event.internal_metadata.is_historical():
- await self.action_generator.handle_push_actions_for_event(event, context)
+ await self._bulk_push_rule_evaluator.action_for_event_by_user(
+ event, context
+ )
try:
# If we're a worker we need to hit out to the master.
@@ -1391,7 +1503,7 @@ class EventCreationHandler:
"""
extra_users = extra_users or []
- assert self.storage.persistence is not None
+ assert self._storage_controllers.persistence is not None
assert self._events_shard_config.should_handle(
self._instance_name, event.room_id
)
@@ -1547,7 +1659,11 @@ class EventCreationHandler:
"Redacting MSC2716 events is not supported in this room version",
)
- prev_state_ids = await context.get_prev_state_ids()
+ event_types = event_auth.auth_types_for_event(event.room_version, event)
+ prev_state_ids = await context.get_prev_state_ids(
+ StateFilter.from_types(event_types)
+ )
+
auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=True
)
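
Several hunks in this file make the same change: get_prev_state_ids() is now passed a StateFilter so that only the state actually needed for the check is loaded, instead of the full room state. A minimal sketch of the pattern as used above for membership lookups; context and event are assumed to be an EventContext and an EventBase respectively.

    from synapse.api.constants import EventTypes
    from synapse.storage.state import StateFilter

    async def sender_membership_event_id(context, event):
        # Only fetch membership state, not the entire state of the room.
        prev_state_ids = await context.get_prev_state_ids(
            StateFilter.from_types([(EventTypes.Member, None)])
        )
        return prev_state_ids.get((EventTypes.Member, event.sender))
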
@@ -1621,7 +1737,7 @@ class EventCreationHandler:
event,
event_pos,
max_stream_token,
- ) = await self.storage.persistence.persist_event(
+ ) = await self._storage_controllers.persistence.persist_event(
event, context=context, backfilled=backfilled
)
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index f6ffb7d1..9de61d55 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -224,7 +224,7 @@ class OidcHandler:
self._sso_handler.render_error(request, "invalid_session", str(e))
return
except MacaroonInvalidSignatureException as e:
- logger.exception("Could not verify session for OIDC callback")
+ logger.warning("Could not verify session for OIDC callback: %s", e)
self._sso_handler.render_error(request, "mismatching_session", str(e))
return
@@ -827,7 +827,7 @@ class OidcProvider:
logger.debug("Exchanging OAuth2 code for a token")
token = await self._exchange_code(code)
except OidcError as e:
- logger.exception("Could not exchange OAuth2 code")
+ logger.warning("Could not exchange OAuth2 code: %s", e)
self._sso_handler.render_error(request, e.error, e.error_description)
return
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 7ee33403..6262a358 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -27,7 +27,7 @@ from synapse.handlers.room import ShutdownRoomResponse
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, Requester
+from synapse.types import JsonDict, Requester, StreamKeyType
from synapse.util.async_helpers import ReadWriteLock
from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client
@@ -129,8 +129,8 @@ class PaginationHandler:
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
- self.state_store = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
self.clock = hs.get_clock()
self._server_name = hs.hostname
self._room_shutdown_handler = hs.get_room_shutdown_handler()
@@ -239,7 +239,7 @@ class PaginationHandler:
# defined in the server's configuration, we can safely assume that's the
# case and use it for this room.
max_lifetime = (
- retention_policy["max_lifetime"] or self._retention_default_max_lifetime
+ retention_policy.max_lifetime or self._retention_default_max_lifetime
)
# Cap the effective max_lifetime to be within the range allowed in the
@@ -352,7 +352,7 @@ class PaginationHandler:
self._purges_in_progress_by_room.add(room_id)
try:
async with self.pagination_lock.write(room_id):
- await self.storage.purge_events.purge_history(
+ await self._storage_controllers.purge_events.purge_history(
room_id, token, delete_local_events
)
logger.info("[purge] complete")
@@ -414,7 +414,7 @@ class PaginationHandler:
if joined:
raise SynapseError(400, "Users are still joined to this room")
- await self.storage.purge_events.purge_room(room_id)
+ await self._storage_controllers.purge_events.purge_room(room_id)
async def get_messages(
self,
@@ -448,7 +448,7 @@ class PaginationHandler:
)
# We expect `/messages` to use historic pagination tokens by default but
# `/messages` should still work with live tokens when manually provided.
- assert from_token.room_key.topological
+ assert from_token.room_key.topological is not None
if pagin_config.limit is None:
# This shouldn't happen as we've set a default limit before this
@@ -491,7 +491,7 @@ class PaginationHandler:
if leave_token.topological < curr_topo:
from_token = from_token.copy_and_replace(
- "room_key", leave_token
+ StreamKeyType.ROOM, leave_token
)
await self.hs.get_federation_handler().maybe_backfill(
@@ -513,16 +513,30 @@ class PaginationHandler:
event_filter=event_filter,
)
- next_token = from_token.copy_and_replace("room_key", next_key)
+ next_token = from_token.copy_and_replace(StreamKeyType.ROOM, next_key)
- if events:
- if event_filter:
- events = await event_filter.filter(events)
+ # if no events are returned from pagination, that implies
+ # we have reached the end of the available events.
+ # In that case we do not return end, to tell the client
+ # there is no need for further queries.
+ if not events:
+ return {
+ "chunk": [],
+ "start": await from_token.to_string(self.store),
+ }
- events = await filter_events_for_client(
- self.storage, user_id, events, is_peeking=(member_event_id is None)
- )
+ if event_filter:
+ events = await event_filter.filter(events)
+
+ events = await filter_events_for_client(
+ self._storage_controllers,
+ user_id,
+ events,
+ is_peeking=(member_event_id is None),
+ )
+ # if no events remain after the filter is applied, return immediately -
+ # but there might be more events in the next_token batch
if not events:
return {
"chunk": [],
@@ -539,7 +553,7 @@ class PaginationHandler:
(EventTypes.Member, event.sender) for event in events
)
- state_ids = await self.state_store.get_state_ids_for_event(
+ state_ids = await self._state_storage_controller.get_state_ids_for_event(
events[0].event_id, state_filter=state_filter
)
@@ -653,7 +667,7 @@ class PaginationHandler:
400, "Users are still joined to this room"
)
- await self.storage.purge_events.purge_room(room_id)
+ await self._storage_controllers.purge_events.purge_room(room_id)
logger.info("complete")
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_COMPLETE
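
The behavioural change to get_messages above is that a page with no events - whether none came back from the store or all were removed by visibility filtering - is now answered without an "end" token, which tells clients that pagination is exhausted. A rough sketch of how a caller might rely on that contract; fetch_page is a hypothetical coroutine wrapping GET /_matrix/client/v3/rooms/{roomId}/messages.

    async def paginate_until_done(fetch_page, from_token: str) -> list:
        """Keep requesting pages until the server omits the "end" token."""
        events = []
        token = from_token
        while True:
            page = await fetch_page(token)
            events.extend(page.get("chunk", []))
            token = page.get("end")
            if token is None:
                # No "end" token: nothing further to fetch.
                return events
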
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 268481ec..895ea63e 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -49,7 +49,7 @@ from prometheus_client import Counter
from typing_extensions import ContextManager
import synapse.metrics
-from synapse.api.constants import EventTypes, Membership, PresenceState
+from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState
from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
from synapse.appservice import ApplicationService
@@ -66,7 +66,7 @@ from synapse.replication.tcp.commands import ClearUserSyncsCommand
from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
from synapse.storage.databases.main import DataStore
from synapse.streams import EventSource
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import JsonDict, StreamKeyType, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.metrics import Measure
@@ -134,6 +134,7 @@ class BasePresenceHandler(abc.ABC):
def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self.presence_router = hs.get_presence_router()
self.state = hs.get_state_handler()
self.is_mine_id = hs.is_mine_id
@@ -394,7 +395,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
# Route presence EDUs to the right worker
hs.get_federation_registry().register_instances_for_edu(
- "m.presence",
+ EduTypes.PRESENCE,
hs.config.worker.writers.presence,
)
@@ -522,7 +523,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
- "presence_key",
+ StreamKeyType.PRESENCE,
stream_id,
rooms=room_ids_to_states.keys(),
users=users_to_states.keys(),
@@ -649,7 +650,9 @@ class PresenceHandler(BasePresenceHandler):
federation_registry = hs.get_federation_registry()
- federation_registry.register_edu_handler("m.presence", self.incoming_presence)
+ federation_registry.register_edu_handler(
+ EduTypes.PRESENCE, self.incoming_presence
+ )
LaterGauge(
"synapse_handlers_presence_user_to_current_state_size",
@@ -1145,7 +1148,7 @@ class PresenceHandler(BasePresenceHandler):
room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
- "presence_key",
+ StreamKeyType.PRESENCE,
stream_id,
rooms=room_ids_to_states.keys(),
users=[UserID.from_string(u) for u in users_to_states],
@@ -1346,7 +1349,10 @@ class PresenceHandler(BasePresenceHandler):
self._event_pos,
room_max_stream_ordering,
)
- max_pos, deltas = await self.store.get_current_state_deltas(
+ (
+ max_pos,
+ deltas,
+ ) = await self._storage_controllers.state.get_current_state_deltas(
self._event_pos, room_max_stream_ordering
)
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 239b0aa7..6eed3826 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -23,14 +23,7 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
-from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.types import (
- JsonDict,
- Requester,
- UserID,
- create_requester,
- get_domain_from_id,
-)
+from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util.caches.descriptors import cached
from synapse.util.stringutils import parse_and_validate_mxc_uri
@@ -50,9 +43,6 @@ class ProfileHandler:
delegate to master when necessary.
"""
- PROFILE_UPDATE_MS = 60 * 1000
- PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
-
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.clock = hs.get_clock()
@@ -73,11 +63,6 @@ class ProfileHandler:
self._third_party_rules = hs.get_third_party_event_rules()
- if hs.config.worker.run_background_tasks:
- self.clock.looping_call(
- self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
- )
-
async def get_profile(self, user_id: str) -> JsonDict:
target_user = UserID.from_string(user_id)
@@ -116,30 +101,6 @@ class ProfileHandler:
raise SynapseError(502, "Failed to fetch profile")
raise e.to_synapse_error()
- async def get_profile_from_cache(self, user_id: str) -> JsonDict:
- """Get the profile information from our local cache. If the user is
- ours then the profile information will always be correct. Otherwise,
- it may be out of date/missing.
- """
- target_user = UserID.from_string(user_id)
- if self.hs.is_mine(target_user):
- try:
- displayname = await self.store.get_profile_displayname(
- target_user.localpart
- )
- avatar_url = await self.store.get_profile_avatar_url(
- target_user.localpart
- )
- except StoreError as e:
- if e.code == 404:
- raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
- raise
-
- return {"displayname": displayname, "avatar_url": avatar_url}
- else:
- profile = await self.store.get_from_remote_profile_cache(user_id)
- return profile or {}
-
async def get_displayname(self, target_user: UserID) -> Optional[str]:
if self.hs.is_mine(target_user):
try:
@@ -509,45 +470,3 @@ class ProfileHandler:
# so we act as if we couldn't find the profile.
raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN)
raise
-
- @wrap_as_background_process("Update remote profile")
- async def _update_remote_profile_cache(self) -> None:
- """Called periodically to check profiles of remote users we haven't
- checked in a while.
- """
- entries = await self.store.get_remote_profile_cache_entries_that_expire(
- last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS
- )
-
- for user_id, displayname, avatar_url in entries:
- is_subscribed = await self.store.is_subscribed_remote_profile_for_user(
- user_id
- )
- if not is_subscribed:
- await self.store.maybe_delete_remote_profile_cache(user_id)
- continue
-
- try:
- profile = await self.federation.make_query(
- destination=get_domain_from_id(user_id),
- query_type="profile",
- args={"user_id": user_id},
- ignore_backoff=True,
- )
- except Exception:
- logger.exception("Failed to get avatar_url")
-
- await self.store.update_remote_profile_cache(
- user_id, displayname, avatar_url
- )
- continue
-
- new_name = profile.get("displayname")
- if not isinstance(new_name, str):
- new_name = None
- new_avatar = profile.get("avatar_url")
- if not isinstance(new_avatar, str):
- new_avatar = None
-
- # We always hit update to update the last_check timestamp
- await self.store.update_remote_profile_cache(user_id, new_name, new_avatar)
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 43d61535..43d2882b 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -14,10 +14,16 @@
import logging
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
-from synapse.api.constants import ReceiptTypes
+from synapse.api.constants import EduTypes, ReceiptTypes
from synapse.appservice import ApplicationService
from synapse.streams import EventSource
-from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
+from synapse.types import (
+ JsonDict,
+ ReadReceipt,
+ StreamKeyType,
+ UserID,
+ get_domain_from_id,
+)
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -46,11 +52,11 @@ class ReceiptsHandler:
# to the appropriate worker.
if hs.get_instance_name() in hs.config.worker.writers.receipts:
hs.get_federation_registry().register_edu_handler(
- "m.receipt", self._received_remote_receipt
+ EduTypes.RECEIPT, self._received_remote_receipt
)
else:
hs.get_federation_registry().register_instances_for_edu(
- "m.receipt",
+ EduTypes.RECEIPT,
hs.config.worker.writers.receipts,
)
@@ -129,7 +135,9 @@ class ReceiptsHandler:
affected_room_ids = list({r.room_id for r in receipts})
- self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
+ self.notifier.on_new_event(
+ StreamKeyType.RECEIPT, max_batch_id, rooms=affected_room_ids
+ )
# Note that the min here shouldn't be relied upon to be accurate.
await self.hs.get_pusherpool().on_new_receipts(
min_batch_id, max_batch_id, affected_room_ids
@@ -165,43 +173,69 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
self.config = hs.config
@staticmethod
- def filter_out_private(events: List[JsonDict], user_id: str) -> List[JsonDict]:
- """
- This method takes in what is returned by
- get_linearized_receipts_for_rooms() and goes through read receipts
- filtering out m.read.private receipts if they were not sent by the
- current user.
+ def filter_out_private_receipts(
+ rooms: List[JsonDict], user_id: str
+ ) -> List[JsonDict]:
"""
+ Filters a list of serialized receipts (as returned by /sync and /initialSync)
+ and removes private read receipts of other users.
- visible_events = []
-
- # filter out private receipts the user shouldn't see
- for event in events:
- content = event.get("content", {})
- new_event = event.copy()
- new_event["content"] = {}
-
- for event_id, event_content in content.items():
- receipt_event = {}
- for receipt_type, receipt_content in event_content.items():
- if receipt_type == ReceiptTypes.READ_PRIVATE:
- user_rr = receipt_content.get(user_id, None)
- if user_rr:
- receipt_event[ReceiptTypes.READ_PRIVATE] = {
- user_id: user_rr.copy()
- }
- else:
- receipt_event[receipt_type] = receipt_content.copy()
+ This operates on the return value of get_linearized_receipts_for_rooms(),
+ which is wrapped in a cache. Care must be taken to ensure that the input
+ values are not modified.
- # Only include the receipt event if it is non-empty.
- if receipt_event:
- new_event["content"][event_id] = receipt_event
+ Args:
+ rooms: A list of mappings, each mapping has a `content` field, which
+ is a map of event ID -> receipt type -> user ID -> receipt information.
- # Append new_event to visible_events unless empty
- if len(new_event["content"].keys()) > 0:
- visible_events.append(new_event)
+ Returns:
+ The same as rooms, but filtered.
+ """
- return visible_events
+ result = []
+
+ # Iterate through each room's receipt content.
+ for room in rooms:
+ # The receipt content with other users' private read receipts removed.
+ content = {}
+
+ # Iterate over each event ID / receipts for that event.
+ for event_id, orig_event_content in room.get("content", {}).items():
+ event_content = orig_event_content
+ # If there are private read receipts, additional logic is necessary.
+ if ReceiptTypes.READ_PRIVATE in event_content:
+ # Make a copy without private read receipts to avoid leaking
+ # other users' private read receipts.
+ event_content = {
+ receipt_type: receipt_value
+ for receipt_type, receipt_value in event_content.items()
+ if receipt_type != ReceiptTypes.READ_PRIVATE
+ }
+
+ # Copy the current user's private read receipt from the
+ # original content, if it exists.
+ user_private_read_receipt = orig_event_content[
+ ReceiptTypes.READ_PRIVATE
+ ].get(user_id, None)
+ if user_private_read_receipt:
+ event_content[ReceiptTypes.READ_PRIVATE] = {
+ user_id: user_private_read_receipt
+ }
+
+ # Include the event if there is at least one non-private read
+ # receipt or the current user has a private read receipt.
+ if event_content:
+ content[event_id] = event_content
+
+ # Include the room if any of its events still have receipt content
+ # after filtering.
+ if content:
+ # Build a new event to avoid mutating the cache.
+ new_room = {k: v for k, v in room.items() if k != "content"}
+ new_room["content"] = content
+ result.append(new_room)
+
+ return result
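
A worked example of the new filter may help: given the serialized shape described in the docstring (event ID -> receipt type -> user ID -> data), filtering for @alice keeps all non-private receipts, keeps only Alice's own private read receipt, and drops any event or room left empty. The sketch below imitates that behaviour on plain dicts and is not the upstream method; the MSC2285 receipt type string is an assumption.

    READ_PRIVATE = "org.matrix.msc2285.read.private"  # assumed MSC2285 receipt type

    def filter_private_receipts(rooms, user_id):
        result = []
        for room in rooms:
            content = {}
            for event_id, receipts in room.get("content", {}).items():
                kept = {t: v for t, v in receipts.items() if t != READ_PRIVATE}
                own = receipts.get(READ_PRIVATE, {}).get(user_id)
                if own is not None:
                    kept[READ_PRIVATE] = {user_id: own}
                if kept:
                    content[event_id] = kept
            if content:
                new_room = {k: v for k, v in room.items() if k != "content"}
                new_room["content"] = content
                result.append(new_room)
        return result

    rooms = [{"room_id": "!r:example.org", "type": "m.receipt", "content": {
        "$evt": {READ_PRIVATE: {"@bob:example.org": {"ts": 1},
                                "@alice:example.org": {"ts": 2}}},
    }}]
    # Bob's private receipt is dropped; Alice's own private receipt survives.
    assert filter_private_receipts(rooms, "@alice:example.org") == [
        {"room_id": "!r:example.org", "type": "m.receipt",
         "content": {"$evt": {READ_PRIVATE: {"@alice:example.org": {"ts": 2}}}}}
    ]
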
async def get_new_events(
self,
@@ -223,7 +257,9 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
)
if self.config.experimental.msc2285_enabled:
- events = ReceiptEventSource.filter_out_private(events, user.to_string())
+ events = ReceiptEventSource.filter_out_private_receipts(
+ events, user.to_string()
+ )
return events, to_key
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 05bb1e02..33820428 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -87,6 +87,7 @@ class LoginDict(TypedDict):
class RegistrationHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self.clock = hs.get_clock()
self.hs = hs
self.auth = hs.get_auth()
@@ -528,7 +529,7 @@ class RegistrationHandler:
if requires_invite:
# If the server is in the room, check if the room is public.
- state = await self.store.get_filtered_current_state_ids(
+ state = await self._storage_controllers.state.get_current_state_ids(
room_id, StateFilter.from_types([(EventTypes.JoinRules, "")])
)
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index c2754ec9..0b63cd21 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -11,24 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import collections.abc
import logging
-from typing import (
- TYPE_CHECKING,
- Collection,
- Dict,
- FrozenSet,
- Iterable,
- List,
- Optional,
- Tuple,
-)
+from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
import attr
from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError
-from synapse.events import EventBase
+from synapse.events import EventBase, relation_from_event
from synapse.storage.databases.main.relations import _RelatedEvent
from synapse.types import JsonDict, Requester, StreamToken, UserID
from synapse.visibility import filter_events_for_client
@@ -70,7 +60,7 @@ class BundledAggregations:
class RelationsHandler:
def __init__(self, hs: "HomeServer"):
self._main_store = hs.get_datastores().main
- self._storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self._auth = hs.get_auth()
self._clock = hs.get_clock()
self._event_handler = hs.get_event_handler()
@@ -144,7 +134,10 @@ class RelationsHandler:
)
events = await filter_events_for_client(
- self._storage, user_id, events, is_peeking=(member_event_id is None)
+ self._storage_controllers,
+ user_id,
+ events,
+ is_peeking=(member_event_id is None),
)
now = self._clock.time_msec()
@@ -254,13 +247,19 @@ class RelationsHandler:
return filtered_results
- async def get_threads_for_events(
- self, event_ids: Collection[str], user_id: str, ignored_users: FrozenSet[str]
+ async def _get_threads_for_events(
+ self,
+ events_by_id: Dict[str, EventBase],
+ relations_by_id: Dict[str, str],
+ user_id: str,
+ ignored_users: FrozenSet[str],
) -> Dict[str, _ThreadAggregation]:
"""Get the bundled aggregations for threads for the requested events.
Args:
- event_ids: Events to get aggregations for threads.
+ events_by_id: A map of event_id to events to get aggregations for threads.
+ relations_by_id: A map of event_id to the relation type, if one exists
+ for that event.
user_id: The user requesting the bundled aggregations.
ignored_users: The users ignored by the requesting user.
@@ -271,16 +270,34 @@ class RelationsHandler:
"""
user = UserID.from_string(user_id)
+ # It is not valid to start a thread on an event which itself relates to another event.
+ event_ids = [eid for eid in events_by_id.keys() if eid not in relations_by_id]
+
# Fetch thread summaries.
summaries = await self._main_store.get_thread_summaries(event_ids)
- # Only fetch participated for a limited selection based on what had
- # summaries.
+ # Limit fetching whether the requester has participated in a thread to
+ # events which are thread roots.
thread_event_ids = [
event_id for event_id, summary in summaries.items() if summary
]
- participated = await self._main_store.get_threads_participated(
- thread_event_ids, user_id
+
+ # Pre-seed thread participation with whether the requester sent the event.
+ participated = {
+ event_id: events_by_id[event_id].sender == user_id
+ for event_id in thread_event_ids
+ }
+ # For events the requester did not send, check the database for whether
+ # the requester sent a threaded reply.
+ participated.update(
+ await self._main_store.get_threads_participated(
+ [
+ event_id
+ for event_id in thread_event_ids
+ if not participated[event_id]
+ ],
+ user_id,
+ )
)
# Then subtract off the results for any ignored users.
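
The renamed _get_threads_for_events avoids a database round-trip for thread roots the requester sent themselves: participation is pre-seeded from the event sender, and the store is only consulted for the remaining roots. A small sketch of that two-step fill; get_threads_participated stands in for the store method of the same name.

    from typing import Awaitable, Callable, Dict, List

    async def thread_participation(
        thread_root_senders: Dict[str, str],  # thread root event_id -> sender
        user_id: str,
        get_threads_participated: Callable[[List[str], str], Awaitable[Dict[str, bool]]],
    ) -> Dict[str, bool]:
        # Pre-seed with "did the requester send the thread root themselves?".
        participated = {
            event_id: sender == user_id
            for event_id, sender in thread_root_senders.items()
        }
        # Only ask the database about the roots that are still unknown.
        unknown = [eid for eid, seen in participated.items() if not seen]
        participated.update(await get_threads_participated(unknown, user_id))
        return participated
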
@@ -341,7 +358,8 @@ class RelationsHandler:
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
- current_user_participated=participated[event_id],
+ current_user_participated=events_by_id[event_id].sender == user_id
+ or participated[event_id],
)
return results
@@ -373,20 +391,21 @@ class RelationsHandler:
if event.is_state():
continue
- relates_to = event.content.get("m.relates_to")
- relation_type = None
- if isinstance(relates_to, collections.abc.Mapping):
- relation_type = relates_to.get("rel_type")
+ relates_to = relation_from_event(event)
+ if relates_to:
# An event which is a replacement (ie edit) or annotation (ie,
# reaction) may not have any other event related to it.
- if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
+ if relates_to.rel_type in (
+ RelationTypes.ANNOTATION,
+ RelationTypes.REPLACE,
+ ):
continue
+ # Track the event's relation information for later.
+ relations_by_id[event.event_id] = relates_to.rel_type
+
# The event should get bundled aggregations.
events_by_id[event.event_id] = event
- # Track the event's relation information for later.
- if isinstance(relation_type, str):
- relations_by_id[event.event_id] = relation_type
# event ID -> bundled aggregation in non-serialized form.
results: Dict[str, BundledAggregations] = {}
@@ -398,9 +417,9 @@ class RelationsHandler:
# events to be fetched. Thus, we check those first!
# Fetch thread summaries (but only for the directly requested events).
- threads = await self.get_threads_for_events(
- # It is not valid to start a thread on an event which itself relates to another event.
- [eid for eid in events_by_id.keys() if eid not in relations_by_id],
+ threads = await self._get_threads_for_events(
+ events_by_id,
+ relations_by_id,
user_id,
ignored_users,
)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 604eb6ec..520663f1 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -33,6 +33,7 @@ from typing import (
import attr
from typing_extensions import TypedDict
+import synapse.events.snapshot
from synapse.api.constants import (
EventContentFields,
EventTypes,
@@ -72,12 +73,12 @@ from synapse.types import (
RoomID,
RoomStreamToken,
StateMap,
+ StreamKeyType,
StreamToken,
UserID,
create_requester,
)
from synapse.util import stringutils
-from synapse.util.async_helpers import Linearizer
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import parse_and_validate_server_name
from synapse.visibility import filter_events_for_client
@@ -106,6 +107,7 @@ class EventContext:
class RoomCreationHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.hs = hs
@@ -149,10 +151,11 @@ class RoomCreationHandler:
)
preset_config["encrypted"] = encrypted
- self._replication = hs.get_replication_data_handler()
+ self._default_power_level_content_override = (
+ self.config.room.default_power_level_content_override
+ )
- # linearizer to stop two upgrades happening at once
- self._upgrade_linearizer = Linearizer("room_upgrade_linearizer")
+ self._replication = hs.get_replication_data_handler()
# If a user tries to update the same room multiple times in quick
# succession, only process the first attempt and return its result to
@@ -196,6 +199,39 @@ class RoomCreationHandler:
400, "An upgrade for this room is currently in progress"
)
+ # Check whether the room exists and 404 if it doesn't.
+ # We could go straight for the auth check, but that will raise a 403 instead.
+ old_room = await self.store.get_room(old_room_id)
+ if old_room is None:
+ raise NotFoundError("Unknown room id %s" % (old_room_id,))
+
+ new_room_id = self._generate_room_id()
+
+ # Check whether the user has the power level to carry out the upgrade.
+ # `check_auth_rules_from_context` will check that they are in the room and have
+ # the required power level to send the tombstone event.
+ (
+ tombstone_event,
+ tombstone_context,
+ ) = await self.event_creation_handler.create_event(
+ requester,
+ {
+ "type": EventTypes.Tombstone,
+ "state_key": "",
+ "room_id": old_room_id,
+ "sender": user_id,
+ "content": {
+ "body": "This room has been replaced",
+ "replacement_room": new_room_id,
+ },
+ },
+ )
+ old_room_version = await self.store.get_room_version(old_room_id)
+ validate_event_for_room_version(old_room_version, tombstone_event)
+ await self._event_auth_handler.check_auth_rules_from_context(
+ old_room_version, tombstone_event, tombstone_context
+ )
+
# Upgrade the room
#
# If this user has sent multiple upgrade requests for the same room
@@ -206,19 +242,35 @@ class RoomCreationHandler:
self._upgrade_room,
requester,
old_room_id,
- new_version, # args for _upgrade_room
+ old_room, # args for _upgrade_room
+ new_room_id,
+ new_version,
+ tombstone_event,
+ tombstone_context,
)
return ret
async def _upgrade_room(
- self, requester: Requester, old_room_id: str, new_version: RoomVersion
+ self,
+ requester: Requester,
+ old_room_id: str,
+ old_room: Dict[str, Any],
+ new_room_id: str,
+ new_version: RoomVersion,
+ tombstone_event: EventBase,
+ tombstone_context: synapse.events.snapshot.EventContext,
) -> str:
"""
Args:
requester: the user requesting the upgrade
old_room_id: the id of the room to be replaced
- new_versions: the version to upgrade the room to
+ old_room: a dict containing room information for the room to be replaced,
+ as returned by `RoomWorkerStore.get_room`.
+ new_room_id: the id of the replacement room
+ new_version: the version to upgrade the room to
+ tombstone_event: the tombstone event to send to the old room
+ tombstone_context: the context for the tombstone event
Raises:
ShadowBanError if the requester is shadow-banned.
@@ -226,40 +278,15 @@ class RoomCreationHandler:
user_id = requester.user.to_string()
assert self.hs.is_mine_id(user_id), "User must be our own: %s" % (user_id,)
- # start by allocating a new room id
- r = await self.store.get_room(old_room_id)
- if r is None:
- raise NotFoundError("Unknown room id %s" % (old_room_id,))
- new_room_id = await self._generate_room_id(
- creator_id=user_id,
- is_public=r["is_public"],
- room_version=new_version,
- )
-
logger.info("Creating new room %s to replace %s", new_room_id, old_room_id)
- # we create and auth the tombstone event before properly creating the new
- # room, to check our user has perms in the old room.
- (
- tombstone_event,
- tombstone_context,
- ) = await self.event_creation_handler.create_event(
- requester,
- {
- "type": EventTypes.Tombstone,
- "state_key": "",
- "room_id": old_room_id,
- "sender": user_id,
- "content": {
- "body": "This room has been replaced",
- "replacement_room": new_room_id,
- },
- },
- )
- old_room_version = await self.store.get_room_version(old_room_id)
- validate_event_for_room_version(old_room_version, tombstone_event)
- await self._event_auth_handler.check_auth_rules_from_context(
- old_room_version, tombstone_event, tombstone_context
+ # create the new room. may raise a `StoreError` in the exceedingly unlikely
+ # event of a room ID collision.
+ await self.store.store_room(
+ room_id=new_room_id,
+ room_creator_user_id=user_id,
+ is_public=old_room["is_public"],
+ room_version=new_version,
)
await self.clone_existing_room(
@@ -277,7 +304,10 @@ class RoomCreationHandler:
context=tombstone_context,
)
- old_room_state = await tombstone_context.get_current_state_ids()
+ state_filter = StateFilter.from_types(
+ [(EventTypes.CanonicalAlias, ""), (EventTypes.PowerLevels, "")]
+ )
+ old_room_state = await tombstone_context.get_current_state_ids(state_filter)
# We know the tombstone event isn't an outlier so it has current state.
assert old_room_state is not None
@@ -401,7 +431,7 @@ class RoomCreationHandler:
requester: the user requesting the upgrade
old_room_id : the id of the room to be replaced
new_room_id: the id to give the new room (should already have been
- created with _gemerate_room_id())
+ created with _generate_room_id())
new_room_version: the new room version to use
tombstone_event_id: the ID of the tombstone event in the old room.
"""
@@ -439,21 +469,22 @@ class RoomCreationHandler:
(EventTypes.RoomAvatar, ""),
(EventTypes.RoomEncryption, ""),
(EventTypes.ServerACL, ""),
- (EventTypes.RelatedGroups, ""),
(EventTypes.PowerLevels, ""),
]
- # If the old room was a space, copy over the room type and the rooms in
- # the space.
- if (
- old_room_create_event.content.get(EventContentFields.ROOM_TYPE)
- == RoomTypes.SPACE
- ):
- creation_content[EventContentFields.ROOM_TYPE] = RoomTypes.SPACE
- types_to_copy.append((EventTypes.SpaceChild, None))
+ # Copy the room type as per MSC3818.
+ room_type = old_room_create_event.content.get(EventContentFields.ROOM_TYPE)
+ if room_type is not None:
+ creation_content[EventContentFields.ROOM_TYPE] = room_type
- old_room_state_ids = await self.store.get_filtered_current_state_ids(
- old_room_id, StateFilter.from_types(types_to_copy)
+ # If the old room was a space, copy over the rooms in the space.
+ if room_type == RoomTypes.SPACE:
+ types_to_copy.append((EventTypes.SpaceChild, None))
+
+ old_room_state_ids = (
+ await self._storage_controllers.state.get_current_state_ids(
+ old_room_id, StateFilter.from_types(types_to_copy)
+ )
)
# map from event_id to BaseEvent
old_room_state_events = await self.store.get_events(old_room_state_ids.values())
@@ -530,8 +561,10 @@ class RoomCreationHandler:
)
# Transfer membership events
- old_room_member_state_ids = await self.store.get_filtered_current_state_ids(
- old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
+ old_room_member_state_ids = (
+ await self._storage_controllers.state.get_current_state_ids(
+ old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
+ )
)
# map from event_id to BaseEvent
@@ -725,6 +758,21 @@ class RoomCreationHandler:
if wchar in config["room_alias_name"]:
raise SynapseError(400, "Invalid characters in room alias")
+ if ":" in config["room_alias_name"]:
+ # Prevent someone from trying to pass in a full alias here.
+ # Note that it's permissible for a room alias to have multiple
+ # hash symbols at the start (notably bridged over from IRC, too),
+ # but the first colon in the alias is defined to separate the local
+ # part from the server name.
+ # (remember server names can contain port numbers, also separated
+ # by a colon. But under no circumstances should the local part be
+ # allowed to contain a colon!)
+ raise SynapseError(
+ 400,
+ "':' is not permitted in the room alias name. "
+ "Please note this expects a local part — 'wombat', not '#wombat:example.com'.",
+ )
+
room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname)
mapping = await self.store.get_association_from_room_alias(room_alias)
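
The new check rejects a colon because room_alias_name is defined to be the local part only; in a full alias such as #wombat:example.com the first colon separates the local part from the server name, which may itself contain a port. A tiny illustrative splitter for the full-alias form, for contrast (not upstream code):

    def split_room_alias(alias: str) -> tuple:
        """Split "#local:server[:port]" into (local_part, server_name)."""
        if not alias.startswith("#") or ":" not in alias:
            raise ValueError("not a full room alias")
        # Only the first colon is significant; later colons belong to the
        # server name (e.g. a port number).
        local_part, server_name = alias[1:].split(":", 1)
        return local_part, server_name

    assert split_room_alias("#wombat:example.com:8448") == ("wombat", "example.com:8448")
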
@@ -778,7 +826,7 @@ class RoomCreationHandler:
visibility = config.get("visibility", "private")
is_public = visibility == "public"
- room_id = await self._generate_room_id(
+ room_id = await self._generate_and_create_room_id(
creator_id=user_id,
is_public=is_public,
room_version=room_version,
@@ -1042,9 +1090,19 @@ class RoomCreationHandler:
for invitee in invite_list:
power_level_content["users"][invitee] = 100
- # Power levels overrides are defined per chat preset
+ # If the user supplied a preset name e.g. "private_chat",
+ # we apply that preset
power_level_content.update(config["power_level_content_override"])
+ # If the server config contains default_power_level_content_override,
+ # and that contains information for this room preset, apply it.
+ if self._default_power_level_content_override:
+ override = self._default_power_level_content_override.get(preset_config)
+ if override is not None:
+ power_level_content.update(override)
+
+ # Finally, if the user supplied specific permissions for this room,
+ # apply those.
if power_level_content_override:
power_level_content.update(power_level_content_override)
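
The ordering above matters: the preset's power-level overrides are applied first, then the homeserver's default_power_level_content_override for that preset, and finally any power_level_content_override supplied in the create request, so later layers win. A condensed sketch of that layering with plain dicts (the keys are illustrative):

    def effective_power_levels(preset_override, server_override, user_override):
        content = dict(preset_override)        # 1. chat-preset defaults
        if server_override:
            content.update(server_override)    # 2. homeserver config override
        if user_override:
            content.update(user_override)      # 3. per-request override wins last
        return content

    assert effective_power_levels(
        {"events_default": 0, "state_default": 50},
        {"events_default": 50},
        {"state_default": 100},
    ) == {"events_default": 50, "state_default": 100}
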
@@ -1090,7 +1148,26 @@ class RoomCreationHandler:
return last_sent_stream_id
- async def _generate_room_id(
+ def _generate_room_id(self) -> str:
+ """Generates a random room ID.
+
+ Room IDs look like "!opaque_id:domain" and are case-sensitive as per the spec
+ at https://spec.matrix.org/v1.2/appendices/#room-ids-and-event-ids.
+
+ Does not check for collisions with existing rooms or prevent future calls from
+ returning the same room ID. To ensure the uniqueness of a new room ID, use
+ `_generate_and_create_room_id` instead.
+
+ Synapse's room IDs are 18 [a-zA-Z] characters long, which comes out to around
+ 102 bits.
+
+ Returns:
+ A random room ID of the form "!opaque_id:domain".
+ """
+ random_string = stringutils.random_string(18)
+ return RoomID(random_string, self.hs.hostname).to_string()
+
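
The docstring above quantifies the collision risk: 18 characters drawn from [a-zA-Z] give 52^18 possibilities, roughly 102 bits of entropy. A quick check of that figure plus a minimal generator in the same spirit; random_string here is a local stand-in, not synapse.util.stringutils.random_string.

    import math
    import secrets
    import string

    ALPHABET = string.ascii_letters  # [a-zA-Z], 52 characters

    def random_string(length: int) -> str:
        return "".join(secrets.choice(ALPHABET) for _ in range(length))

    def opaque_room_id(server_name: str) -> str:
        # "!opaque_id:domain", matching the format described in the docstring.
        return f"!{random_string(18)}:{server_name}"

    # 18 * log2(52) is about 102.6 bits, as stated above.
    assert 102 < 18 * math.log2(52) < 103
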
+ async def _generate_and_create_room_id(
self,
creator_id: str,
is_public: bool,
@@ -1101,8 +1178,7 @@ class RoomCreationHandler:
attempts = 0
while attempts < 5:
try:
- random_string = stringutils.random_string(18)
- gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
+ gen_room_id = self._generate_room_id()
await self.store.store_room(
room_id=gen_room_id,
room_creator_user_id=creator_id,
@@ -1120,8 +1196,8 @@ class RoomContextHandler:
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
- self.state_store = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
self._relations_handler = hs.get_relations_handler()
async def get_event_context(
@@ -1164,7 +1240,10 @@ class RoomContextHandler:
if use_admin_priviledge:
return events
return await filter_events_for_client(
- self.storage, user.to_string(), events, is_peeking=is_peeking
+ self._storage_controllers,
+ user.to_string(),
+ events,
+ is_peeking=is_peeking,
)
event = await self.store.get_event(
@@ -1221,7 +1300,7 @@ class RoomContextHandler:
# first? Shouldn't we be consistent with /sync?
# https://github.com/matrix-org/matrix-doc/issues/687
- state = await self.state_store.get_state_for_events(
+ state = await self._state_storage_controller.get_state_for_events(
[last_event_id], state_filter=state_filter
)
@@ -1239,10 +1318,10 @@ class RoomContextHandler:
events_after=events_after,
state=await filter_evts(state_events),
aggregations=aggregations,
- start=await token.copy_and_replace("room_key", results.start).to_string(
- self.store
- ),
- end=await token.copy_and_replace("room_key", results.end).to_string(
+ start=await token.copy_and_replace(
+ StreamKeyType.ROOM, results.start
+ ).to_string(self.store),
+ end=await token.copy_and_replace(StreamKeyType.ROOM, results.end).to_string(
self.store
),
)
@@ -1254,6 +1333,7 @@ class TimestampLookupHandler:
self.store = hs.get_datastores().main
self.state_handler = hs.get_state_handler()
self.federation_client = hs.get_federation_client()
+ self._storage_controllers = hs.get_storage_controllers()
async def get_event_for_timestamp(
self,
@@ -1327,7 +1407,9 @@ class TimestampLookupHandler:
)
# Find other homeservers from the given state in the room
- curr_state = await self.state_handler.get_current_state(room_id)
+ curr_state = await self._storage_controllers.state.get_current_state(
+ room_id
+ )
curr_domains = get_domains_from_state(curr_state)
likely_domains = [
domain for domain, depth in curr_domains if domain != self.server_name
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index 29de7e5b..1414e575 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -17,7 +17,7 @@ class RoomBatchHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastores().main
- self.state_store = hs.get_storage().state
+ self._state_storage_controller = hs.get_storage_controllers().state
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
@@ -53,6 +53,7 @@ class RoomBatchHandler:
# We want to use the successor event depth so they appear after `prev_event` because
# it has a larger `depth` but before the successor event because the `stream_ordering`
# is negative before the successor event.
+ assert most_recent_prev_event_id is not None
successor_event_ids = await self.store.get_successor_events(
most_recent_prev_event_id
)
@@ -139,7 +140,8 @@ class RoomBatchHandler:
_,
) = await self.store.get_max_depth_of(event_ids)
# mapping from (type, state_key) -> state_event_id
- prev_state_map = await self.state_store.get_state_ids_for_event(
+ assert most_recent_event_id is not None
+ prev_state_map = await self._state_storage_controller.get_state_ids_for_event(
most_recent_event_id
)
# List of state event ID's
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index f3577b5d..183d4ae3 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -50,6 +50,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
class RoomListHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self.hs = hs
self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
self.response_cache: ResponseCache[
@@ -274,7 +275,7 @@ class RoomListHandler:
if aliases:
result["aliases"] = aliases
- current_state_ids = await self.store.get_current_state_ids(
+ current_state_ids = await self._storage_controllers.state.get_current_state_ids(
room_id, on_invalidate=cache_context.invalidate
)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 802e57c4..d1199a06 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -38,6 +38,7 @@ from synapse.event_auth import get_named_level, get_power_level_event
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
+from synapse.storage.state import StateFilter
from synapse.types import (
JsonDict,
Requester,
@@ -67,6 +68,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self.auth = hs.get_auth()
self.state_handler = hs.get_state_handler()
self.config = hs.config
@@ -362,7 +364,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
historical=historical,
)
- prev_state_ids = await context.get_prev_state_ids()
+ prev_state_ids = await context.get_prev_state_ids(
+ StateFilter.from_types([(EventTypes.Member, None)])
+ )
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
@@ -991,7 +995,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# If the host is in the room, but not one of the authorised hosts
# for restricted join rules, a remote join must be used.
room_version = await self.store.get_room_version(room_id)
- current_state_ids = await self.store.get_current_state_ids(room_id)
+ current_state_ids = await self._storage_controllers.state.get_current_state_ids(
+ room_id
+ )
# If restricted join rules are not being used, a local join can always
# be used.
@@ -1078,17 +1084,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Transfer alias mappings in the room directory
await self.store.update_aliases_for_room(old_room_id, room_id)
- # Check if any groups we own contain the predecessor room
- local_group_ids = await self.store.get_local_groups_for_room(old_room_id)
- for group_id in local_group_ids:
- # Add new the new room to those groups
- await self.store.add_room_to_group(
- group_id, room_id, old_room is not None and old_room["is_public"]
- )
-
- # Remove the old room from those groups
- await self.store.remove_room_from_group(group_id, old_room_id)
-
async def copy_user_state_on_room_upgrade(
self, old_room_id: str, new_room_id: str, user_ids: Iterable[str]
) -> None:
@@ -1160,7 +1155,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
else:
requester = types.create_requester(target_user)
- prev_state_ids = await context.get_prev_state_ids()
+ prev_state_ids = await context.get_prev_state_ids(
+ StateFilter.from_types([(EventTypes.GuestAccess, None)])
+ )
if event.membership == Membership.JOIN:
if requester.is_guest:
guest_can_join = await self._can_guest_join(prev_state_ids)
@@ -1404,7 +1401,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
txn_id: Optional[str],
id_access_token: Optional[str] = None,
) -> int:
- room_state = await self.state_handler.get_current_state(room_id)
+ room_state = await self._storage_controllers.state.get_current_state(
+ room_id,
+ StateFilter.from_types(
+ [
+ (EventTypes.Member, user.to_string()),
+ (EventTypes.CanonicalAlias, ""),
+ (EventTypes.Name, ""),
+ (EventTypes.Create, ""),
+ (EventTypes.JoinRules, ""),
+ (EventTypes.RoomAvatar, ""),
+ ]
+ ),
+ )
inviter_display_name = ""
inviter_avatar_url = ""
@@ -1800,7 +1809,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
async def forget(self, user: UserID, room_id: str) -> None:
user_id = user.to_string()
- member = await self.state_handler.get_current_state(
+ member = await self._storage_controllers.state.get_current_state_event(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
membership = member.membership if member else None
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index ff24ec80..13098f56 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -90,6 +90,7 @@ class RoomSummaryHandler:
def __init__(self, hs: "HomeServer"):
self._event_auth_handler = hs.get_event_auth_handler()
self._store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self._event_serializer = hs.get_event_client_serializer()
self._server_name = hs.hostname
self._federation_client = hs.get_federation_client()
@@ -537,7 +538,7 @@ class RoomSummaryHandler:
Returns:
True if the room is accessible to the requesting user or server.
"""
- state_ids = await self._store.get_current_state_ids(room_id)
+ state_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
# If there's no state for the room, it isn't known.
if not state_ids:
@@ -562,8 +563,13 @@ class RoomSummaryHandler:
if join_rules_event_id:
join_rules_event = await self._store.get_event(join_rules_event_id)
join_rule = join_rules_event.content.get("join_rule")
- if join_rule == JoinRules.PUBLIC or (
- room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
+ if (
+ join_rule == JoinRules.PUBLIC
+ or (room_version.msc2403_knocking and join_rule == JoinRules.KNOCK)
+ or (
+ room_version.msc3787_knock_restricted_join_rule
+ and join_rule == JoinRules.KNOCK_RESTRICTED
+ )
):
return True
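
The widened condition treats knock_restricted (MSC3787) like knock when deciding whether a room is potentially accessible, but only for room versions that support it. A small sketch of the check, with string literals standing in for the JoinRules constants and booleans for the room-version capability flags:

    def join_rule_allows_summary(join_rule: str, supports_knock: bool,
                                 supports_knock_restricted: bool) -> bool:
        if join_rule == "public":
            return True
        if join_rule == "knock" and supports_knock:
            return True
        if join_rule == "knock_restricted" and supports_knock_restricted:
            return True
        return False

    assert join_rule_allows_summary("knock_restricted", True, True)
    assert not join_rule_allows_summary("knock_restricted", True, False)
    assert not join_rule_allows_summary("invite", True, True)
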
@@ -657,7 +663,8 @@ class RoomSummaryHandler:
# The API doesn't return the room version so assume that a
# join rule of knock is valid.
if (
- room.get("join_rules") in (JoinRules.PUBLIC, JoinRules.KNOCK)
+ room.get("join_rule")
+ in (JoinRules.PUBLIC, JoinRules.KNOCK, JoinRules.KNOCK_RESTRICTED)
or room.get("world_readable") is True
):
return True
@@ -696,7 +703,9 @@ class RoomSummaryHandler:
# there should always be an entry
assert stats is not None, "unable to retrieve stats for %s" % (room_id,)
- current_state_ids = await self._store.get_current_state_ids(room_id)
+ current_state_ids = await self._storage_controllers.state.get_current_state_ids(
+ room_id
+ )
create_event = await self._store.get_event(
current_state_ids[(EventTypes.Create, "")]
)
@@ -708,9 +717,6 @@ class RoomSummaryHandler:
"canonical_alias": stats["canonical_alias"],
"num_joined_members": stats["joined_members"],
"avatar_url": stats["avatar"],
- # plural join_rules is a documentation error but kept for historical
- # purposes. Should match /publicRooms.
- "join_rules": stats["join_rules"],
"join_rule": stats["join_rules"],
"world_readable": (
stats["history_visibility"] == HistoryVisibility.WORLD_READABLE
@@ -757,7 +763,9 @@ class RoomSummaryHandler:
"""
# look for child rooms/spaces.
- current_state_ids = await self._store.get_current_state_ids(room_id)
+ current_state_ids = await self._storage_controllers.state.get_current_state_ids(
+ room_id
+ )
events = await self._store.get_events_as_list(
[
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 5619f8f5..bcab98c6 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -24,7 +24,7 @@ from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, StreamKeyType, UserID
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
@@ -55,8 +55,8 @@ class SearchHandler:
self.hs = hs
self._event_serializer = hs.get_event_client_serializer()
self._relations_handler = hs.get_relations_handler()
- self.storage = hs.get_storage()
- self.state_store = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
self.auth = hs.get_auth()
async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]:
@@ -348,7 +348,7 @@ class SearchHandler:
state_results = {}
if include_state:
for room_id in {e.room_id for e in search_result.allowed_events}:
- state = await self.state_handler.get_current_state(room_id)
+ state = await self._storage_controllers.state.get_current_state(room_id)
state_results[room_id] = list(state.values())
aggregations = await self._relations_handler.get_bundled_aggregations(
@@ -460,7 +460,7 @@ class SearchHandler:
filtered_events = await search_filter.filter([r["event"] for r in results])
events = await filter_events_for_client(
- self.storage, user.to_string(), filtered_events
+ self._storage_controllers, user.to_string(), filtered_events
)
events.sort(key=lambda e: -rank_map[e.event_id])
@@ -559,7 +559,7 @@ class SearchHandler:
filtered_events = await search_filter.filter([r["event"] for r in results])
events = await filter_events_for_client(
- self.storage, user.to_string(), filtered_events
+ self._storage_controllers, user.to_string(), filtered_events
)
room_events.extend(events)
@@ -644,22 +644,22 @@ class SearchHandler:
)
events_before = await filter_events_for_client(
- self.storage, user.to_string(), res.events_before
+ self._storage_controllers, user.to_string(), res.events_before
)
events_after = await filter_events_for_client(
- self.storage, user.to_string(), res.events_after
+ self._storage_controllers, user.to_string(), res.events_after
)
context: JsonDict = {
"events_before": events_before,
"events_after": events_after,
"start": await now_token.copy_and_replace(
- "room_key", res.start
+ StreamKeyType.ROOM, res.start
+ ).to_string(self.store),
+ "end": await now_token.copy_and_replace(
+ StreamKeyType.ROOM, res.end
).to_string(self.store),
- "end": await now_token.copy_and_replace("room_key", res.end).to_string(
- self.store
- ),
}
if include_profile:
@@ -677,7 +677,7 @@ class SearchHandler:
[(EventTypes.Member, sender) for sender in senders]
)
- state = await self.state_store.get_state_for_event(
+ state = await self._state_storage_controller.get_state_for_event(
last_event_id, state_filter
)
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 436cd971..f45e06eb 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -40,6 +40,7 @@ class StatsHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self.state = hs.get_state_handler()
self.server_name = hs.hostname
self.clock = hs.get_clock()
@@ -105,7 +106,10 @@ class StatsHandler:
logger.debug(
"Processing room stats %s->%s", self.pos, room_max_stream_ordering
)
- max_pos, deltas = await self.store.get_current_state_deltas(
+ (
+ max_pos,
+ deltas,
+ ) = await self._storage_controllers.state.get_current_state_deltas(
self.pos, room_max_stream_ordering
)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 2c555a66..b4ead79f 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -37,6 +37,7 @@ from synapse.types import (
Requester,
RoomStreamToken,
StateMap,
+ StreamKeyType,
StreamToken,
UserID,
)
@@ -165,16 +166,6 @@ class KnockedSyncResult:
return True
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class GroupsSyncResult:
- join: JsonDict
- invite: JsonDict
- leave: JsonDict
-
- def __bool__(self) -> bool:
- return bool(self.join or self.invite or self.leave)
-
-
@attr.s(slots=True, auto_attribs=True)
class _RoomChanges:
"""The set of room entries to include in the sync, plus the set of joined
@@ -205,7 +196,6 @@ class SyncResult:
for this device
device_unused_fallback_key_types: List of key types that have an unused fallback
key
- groups: Group updates, if any
"""
next_batch: StreamToken
@@ -219,7 +209,6 @@ class SyncResult:
device_lists: DeviceListUpdates
device_one_time_keys_count: JsonDict
device_unused_fallback_key_types: List[str]
- groups: Optional[GroupsSyncResult]
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -235,7 +224,6 @@ class SyncResult:
or self.account_data
or self.to_device
or self.device_lists
- or self.groups
)
@@ -250,8 +238,8 @@ class SyncHandler:
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
- self.storage = hs.get_storage()
- self.state_store = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
# TODO: flush cache entries on subsequent sync request.
# Once we get the next /sync request (ie, one with the same access token
@@ -410,10 +398,10 @@ class SyncHandler:
set_tag(SynapseTags.SYNC_RESULT, bool(sync_result))
return sync_result
- async def push_rules_for_user(self, user: UserID) -> JsonDict:
+ async def push_rules_for_user(self, user: UserID) -> Dict[str, Dict[str, list]]:
user_id = user.to_string()
- rules = await self.store.get_push_rules_for_user(user_id)
- rules = format_push_rules_for_user(user, rules)
+ rules_raw = await self.store.get_push_rules_for_user(user_id)
+ rules = format_push_rules_for_user(user, rules_raw)
return rules
async def ephemeral_by_room(
@@ -449,7 +437,7 @@ class SyncHandler:
room_ids=room_ids,
is_guest=sync_config.is_guest,
)
- now_token = now_token.copy_and_replace("typing_key", typing_key)
+ now_token = now_token.copy_and_replace(StreamKeyType.TYPING, typing_key)
ephemeral_by_room: JsonDict = {}
@@ -471,7 +459,7 @@ class SyncHandler:
room_ids=room_ids,
is_guest=sync_config.is_guest,
)
- now_token = now_token.copy_and_replace("receipt_key", receipt_key)
+ now_token = now_token.copy_and_replace(StreamKeyType.RECEIPT, receipt_key)
for event in receipts:
room_id = event["room_id"]
@@ -518,13 +506,15 @@ class SyncHandler:
# ensure that we always include current state in the timeline
current_state_ids: FrozenSet[str] = frozenset()
if any(e.is_state() for e in recents):
- current_state_ids_map = await self.store.get_current_state_ids(
- room_id
+ current_state_ids_map = (
+ await self._state_storage_controller.get_current_state_ids(
+ room_id
+ )
)
current_state_ids = frozenset(current_state_ids_map.values())
recents = await filter_events_for_client(
- self.storage,
+ self._storage_controllers,
sync_config.user.to_string(),
recents,
always_include_ids=current_state_ids,
@@ -537,7 +527,9 @@ class SyncHandler:
prev_batch_token = now_token
if recents:
room_key = recents[0].internal_metadata.before
- prev_batch_token = now_token.copy_and_replace("room_key", room_key)
+ prev_batch_token = now_token.copy_and_replace(
+ StreamKeyType.ROOM, room_key
+ )
return TimelineBatch(
events=recents, prev_batch=prev_batch_token, limited=False
@@ -584,13 +576,16 @@ class SyncHandler:
# ensure that we always include current state in the timeline
current_state_ids = frozenset()
if any(e.is_state() for e in loaded_recents):
- current_state_ids_map = await self.store.get_current_state_ids(
- room_id
+ # FIXME(faster_joins): We use the partial state here as
+ # we don't want to block `/sync` on finishing a lazy join.
+ # Is this the correct way of doing it?
+ current_state_ids_map = (
+ await self.store.get_partial_current_state_ids(room_id)
)
current_state_ids = frozenset(current_state_ids_map.values())
loaded_recents = await filter_events_for_client(
- self.storage,
+ self._storage_controllers,
sync_config.user.to_string(),
loaded_recents,
always_include_ids=current_state_ids,
@@ -611,7 +606,7 @@ class SyncHandler:
recents = recents[-timeline_limit:]
room_key = recents[0].internal_metadata.before
- prev_batch_token = now_token.copy_and_replace("room_key", room_key)
+ prev_batch_token = now_token.copy_and_replace(StreamKeyType.ROOM, room_key)
# Don't bother to bundle aggregations if the timeline is unlimited,
# as clients will have all the necessary information.
@@ -631,21 +626,32 @@ class SyncHandler:
)
async def get_state_after_event(
- self, event: EventBase, state_filter: Optional[StateFilter] = None
+ self, event_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
"""
Get the room state after the given event
Args:
- event: event of interest
+ event_id: event of interest
state_filter: The state filter used to fetch state from the database.
"""
- state_ids = await self.state_store.get_state_ids_for_event(
- event.event_id, state_filter=state_filter or StateFilter.all()
+ state_ids = await self._state_storage_controller.get_state_ids_for_event(
+ event_id, state_filter=state_filter or StateFilter.all()
)
- if event.is_state():
+
+ # using get_metadata_for_events here (instead of get_event) sidesteps an issue
+ # with redactions: if `event_id` is a redaction event, and we don't have the
+ # original (possibly because it got purged), get_event will refuse to return
+ # the redaction event, which isn't terribly helpful here.
+ #
+ # (To be fair, in that case we could assume it's *not* a state event, and
+ # therefore we don't need to worry about it. But still, it seems cleaner just
+ # to pull the metadata.)
+ m = (await self.store.get_metadata_for_events([event_id]))[event_id]
+ if m.state_key is not None and m.rejection_reason is None:
state_ids = dict(state_ids)
- state_ids[(event.type, event.state_key)] = event.event_id
+ state_ids[(m.event_type, m.state_key)] = event_id
+
return state_ids
async def get_state_at(
@@ -664,14 +670,14 @@ class SyncHandler:
# FIXME: This gets the state at the latest event before the stream ordering,
# which might not be the same as the "current state" of the room at the time
# of the stream token if there were multiple forward extremities at the time.
- last_event = await self.store.get_last_event_in_room_before_stream_ordering(
+ last_event_id = await self.store.get_last_event_in_room_before_stream_ordering(
room_id,
end_token=stream_position.room_key,
)
- if last_event:
+ if last_event_id:
state = await self.get_state_after_event(
- last_event, state_filter=state_filter or StateFilter.all()
+ last_event_id, state_filter=state_filter or StateFilter.all()
)
else:
@@ -720,7 +726,7 @@ class SyncHandler:
return None
last_event = last_events[-1]
- state_ids = await self.state_store.get_state_ids_for_event(
+ state_ids = await self._state_storage_controller.get_state_ids_for_event(
last_event.event_id,
state_filter=StateFilter.from_types(
[(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
@@ -898,12 +904,16 @@ class SyncHandler:
if full_state:
if batch:
- current_state_ids = await self.state_store.get_state_ids_for_event(
- batch.events[-1].event_id, state_filter=state_filter
+ current_state_ids = (
+ await self._state_storage_controller.get_state_ids_for_event(
+ batch.events[-1].event_id, state_filter=state_filter
+ )
)
- state_ids = await self.state_store.get_state_ids_for_event(
- batch.events[0].event_id, state_filter=state_filter
+ state_ids = (
+ await self._state_storage_controller.get_state_ids_for_event(
+ batch.events[0].event_id, state_filter=state_filter
+ )
)
else:
@@ -923,7 +933,7 @@ class SyncHandler:
elif batch.limited:
if batch:
state_at_timeline_start = (
- await self.state_store.get_state_ids_for_event(
+ await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter
)
)
@@ -957,8 +967,10 @@ class SyncHandler:
)
if batch:
- current_state_ids = await self.state_store.get_state_ids_for_event(
- batch.events[-1].event_id, state_filter=state_filter
+ current_state_ids = (
+ await self._state_storage_controller.get_state_ids_for_event(
+ batch.events[-1].event_id, state_filter=state_filter
+ )
)
else:
# It's not clear how we get here, but empirically we do
@@ -988,7 +1000,7 @@ class SyncHandler:
# So we fish out all the member events corresponding to the
# timeline here, and then dedupe any redundant ones below.
- state_ids = await self.state_store.get_state_ids_for_event(
+ state_ids = await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id,
# we only want members!
state_filter=StateFilter.from_types(
@@ -1154,10 +1166,6 @@ class SyncHandler:
await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
)
- if self.hs_config.experimental.groups_enabled:
- logger.debug("Fetching group data")
- await self._generate_sync_entry_for_groups(sync_result_builder)
-
num_events = 0
# debug for https://github.com/matrix-org/synapse/issues/9424
@@ -1181,57 +1189,11 @@ class SyncHandler:
archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device,
device_lists=device_lists,
- groups=sync_result_builder.groups,
device_one_time_keys_count=one_time_key_counts,
device_unused_fallback_key_types=unused_fallback_key_types,
next_batch=sync_result_builder.now_token,
)
- @measure_func("_generate_sync_entry_for_groups")
- async def _generate_sync_entry_for_groups(
- self, sync_result_builder: "SyncResultBuilder"
- ) -> None:
- user_id = sync_result_builder.sync_config.user.to_string()
- since_token = sync_result_builder.since_token
- now_token = sync_result_builder.now_token
-
- if since_token and since_token.groups_key:
- results = await self.store.get_groups_changes_for_user(
- user_id, since_token.groups_key, now_token.groups_key
- )
- else:
- results = await self.store.get_all_groups_for_user(
- user_id, now_token.groups_key
- )
-
- invited = {}
- joined = {}
- left = {}
- for result in results:
- membership = result["membership"]
- group_id = result["group_id"]
- gtype = result["type"]
- content = result["content"]
-
- if membership == "join":
- if gtype == "membership":
- # TODO: Add profile
- content.pop("membership", None)
- joined[group_id] = content["content"]
- else:
- joined.setdefault(group_id, {})[gtype] = content
- elif membership == "invite":
- if gtype == "membership":
- content.pop("membership", None)
- invited[group_id] = content["content"]
- else:
- if gtype == "membership":
- left[group_id] = content["content"]
-
- sync_result_builder.groups = GroupsSyncResult(
- join=joined, invite=invited, leave=left
- )
-
@measure_func("_generate_sync_entry_for_device_list")
async def _generate_sync_entry_for_device_list(
self,
@@ -1398,7 +1360,7 @@ class SyncHandler:
now_token.to_device_key,
)
sync_result_builder.now_token = now_token.copy_and_replace(
- "to_device_key", stream_id
+ StreamKeyType.TO_DEVICE, stream_id
)
sync_result_builder.to_device = messages
else:
@@ -1503,7 +1465,7 @@ class SyncHandler:
)
assert presence_key
sync_result_builder.now_token = now_token.copy_and_replace(
- "presence_key", presence_key
+ StreamKeyType.PRESENCE, presence_key
)
extra_users_ids = set(newly_joined_or_invited_users)
@@ -1826,7 +1788,7 @@ class SyncHandler:
# stream token as it'll only be used in the context of this
# room. (c.f. the docstring of `to_room_stream_token`).
leave_token = since_token.copy_and_replace(
- "room_key", leave_position.to_room_stream_token()
+ StreamKeyType.ROOM, leave_position.to_room_stream_token()
)
# If this is an out of band message, like a remote invite
@@ -1875,7 +1837,9 @@ class SyncHandler:
if room_entry:
events, start_key = room_entry
- prev_batch_token = now_token.copy_and_replace("room_key", start_key)
+ prev_batch_token = now_token.copy_and_replace(
+ StreamKeyType.ROOM, start_key
+ )
entry = RoomSyncResultBuilder(
room_id=room_id,
@@ -1972,7 +1936,7 @@ class SyncHandler:
continue
leave_token = now_token.copy_and_replace(
- "room_key", RoomStreamToken(None, event.stream_ordering)
+ StreamKeyType.ROOM, RoomStreamToken(None, event.stream_ordering)
)
room_entries.append(
RoomSyncResultBuilder(
@@ -2328,7 +2292,6 @@ class SyncResultBuilder:
invited
knocked
archived
- groups
to_device
"""
@@ -2344,7 +2307,6 @@ class SyncResultBuilder:
invited: List[InvitedSyncResult] = attr.Factory(list)
knocked: List[KnockedSyncResult] = attr.Factory(list)
archived: List[ArchivedSyncResult] = attr.Factory(list)
- groups: Optional[GroupsSyncResult] = None
to_device: List[JsonDict] = attr.Factory(list)
def calculate_user_changes(self) -> Tuple[Set[str], Set[str]]:
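get_state_after_event above now takes an event ID (looked up via get_metadata_for_events so it stays robust against purged redactions) instead of a full event. A hedged sketch of a caller, reusing the members-only StateFilter idiom seen elsewhere in this file; sync_handler, event_id and user_id are hypothetical:

    from synapse.api.constants import EventTypes
    from synapse.storage.state import StateFilter

    state_ids = await sync_handler.get_state_after_event(
        event_id,
        state_filter=StateFilter.from_types([(EventTypes.Member, user_id)]),
    )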
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 6854428b..d104ea07 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
import attr
+from synapse.api.constants import EduTypes
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
from synapse.appservice import ApplicationService
from synapse.metrics.background_process_metrics import (
@@ -25,7 +26,7 @@ from synapse.metrics.background_process_metrics import (
)
from synapse.replication.tcp.streams import TypingStream
from synapse.streams import EventSource
-from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
@@ -58,6 +59,7 @@ class FollowerTypingHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self.server_name = hs.config.server.server_name
self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id
@@ -68,7 +70,7 @@ class FollowerTypingHandler:
if hs.get_instance_name() not in hs.config.worker.writers.typing:
hs.get_federation_registry().register_instances_for_edu(
- "m.typing",
+ EduTypes.TYPING,
hs.config.worker.writers.typing,
)
@@ -130,7 +132,6 @@ class FollowerTypingHandler:
return
try:
- users = await self.store.get_users_in_room(member.room_id)
self._member_last_federation_poke[member] = self.clock.time_msec()
now = self.clock.time_msec()
@@ -138,12 +139,15 @@ class FollowerTypingHandler:
now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
)
- for domain in {get_domain_from_id(u) for u in users}:
+ hosts = await self._storage_controllers.state.get_current_hosts_in_room(
+ member.room_id
+ )
+ for domain in hosts:
if domain != self.server_name:
logger.debug("sending typing update to %s", domain)
self.federation.build_and_send_edu(
destination=domain,
- edu_type="m.typing",
+ edu_type=EduTypes.TYPING,
content={
"room_id": member.room_id,
"user_id": member.user_id,
@@ -218,7 +222,9 @@ class TypingWriterHandler(FollowerTypingHandler):
self.hs = hs
- hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
+ hs.get_federation_registry().register_edu_handler(
+ EduTypes.TYPING, self._recv_edu
+ )
hs.get_distributor().observe("user_left_room", self.user_left_room)
@@ -382,7 +388,7 @@ class TypingWriterHandler(FollowerTypingHandler):
)
self.notifier.on_new_event(
- "typing_key", self._latest_room_serial, rooms=[member.room_id]
+ StreamKeyType.TYPING, self._latest_room_serial, rooms=[member.room_id]
)
async def get_all_typing_updates(
@@ -458,7 +464,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
def _make_event_for(self, room_id: str) -> JsonDict:
typing = self.get_typing_handler()._room_typing[room_id]
return {
- "type": "m.typing",
+ "type": EduTypes.TYPING,
"room_id": room_id,
"content": {"user_ids": list(typing)},
}
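The typing handler now refers to the typing EDU via a constant instead of the literal "m.typing". A minimal sketch of registering a handler with it (on_typing_edu is a hypothetical callback):

    from synapse.api.constants import EduTypes

    # EduTypes.TYPING replaces the hard-coded "m.typing" string
    hs.get_federation_registry().register_edu_handler(EduTypes.TYPING, on_typing_edu)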
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 74f7fdfe..8c3c52e1 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -56,6 +56,7 @@ class UserDirectoryHandler(StateDeltasHandler):
super().__init__(hs)
self.store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self.server_name = hs.hostname
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
@@ -174,7 +175,10 @@ class UserDirectoryHandler(StateDeltasHandler):
logger.debug(
"Processing user stats %s->%s", self.pos, room_max_stream_ordering
)
- max_pos, deltas = await self.store.get_current_state_deltas(
+ (
+ max_pos,
+ deltas,
+ ) = await self._storage_controllers.state.get_current_state_deltas(
self.pos, room_max_stream_ordering
)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 8310fb46..084d0a5b 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -43,8 +43,10 @@ from twisted.internet import defer, error as twisted_error, protocol, ssl
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import (
IAddress,
+ IDelayedCall,
IHostResolution,
IReactorPluggableNameResolver,
+ IReactorTime,
IResolutionReceiver,
ITCPTransport,
)
@@ -121,13 +123,15 @@ def check_against_blacklist(
_EPSILON = 0.00000001
-def _make_scheduler(reactor):
+def _make_scheduler(
+ reactor: IReactorTime,
+) -> Callable[[Callable[[], object]], IDelayedCall]:
"""Makes a schedular suitable for a Cooperator using the given reactor.
(This is effectively just a copy from `twisted.internet.task`)
"""
- def _scheduler(x):
+ def _scheduler(x: Callable[[], object]) -> IDelayedCall:
return reactor.callLater(_EPSILON, x)
return _scheduler
@@ -348,7 +352,7 @@ class SimpleHttpClient:
# XXX: The justification for using the cache factor here is that larger instances
# will need both more cache and more connections.
# Still, this should probably be a separate dial
- pool.maxPersistentPerHost = max((100 * hs.config.caches.global_factor, 5))
+ pool.maxPersistentPerHost = max(int(100 * hs.config.caches.global_factor), 5)
pool.cachedConnectionTimeout = 2 * 60
self.agent: IAgent = ProxyAgent(
@@ -775,7 +779,7 @@ class SimpleHttpClient:
)
-def _timeout_to_request_timed_out_error(f: Failure):
+def _timeout_to_request_timed_out_error(f: Failure) -> Failure:
if f.check(twisted_error.TimeoutError, twisted_error.ConnectingCancelledError):
# The TCP connection has its own timeout (set by the 'connectTimeout' param
# on the Agent), which raises twisted_error.TimeoutError exception.
@@ -809,7 +813,7 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
def __init__(self, deferred: defer.Deferred):
self.deferred = deferred
- def _maybe_fail(self):
+ def _maybe_fail(self) -> None:
"""
Report a max size exceed error and disconnect the first time this is called.
"""
@@ -933,12 +937,12 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
Do not use this since it allows an attacker to intercept your communications.
"""
- def __init__(self):
+ def __init__(self) -> None:
self._context = SSL.Context(SSL.SSLv23_METHOD)
self._context.set_verify(VERIFY_NONE, lambda *_: False)
def getContext(self, hostname=None, port=None):
return self._context
- def creatorForNetloc(self, hostname, port):
+ def creatorForNetloc(self, hostname: bytes, port: int):
return self
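_make_scheduler now carries proper type hints and is reused by the federation client further down, which builds its Cooperator from it. A small sketch under those assumptions:

    from twisted.internet import reactor, task

    from synapse.http.client import _make_scheduler

    # run cooperative tasks, scheduling each step shortly after the previous one
    cooperator = task.Cooperator(scheduler=_make_scheduler(reactor))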
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index 203e995b..23a60af1 100644
--- a/synapse/http/connectproxyclient.py
+++ b/synapse/http/connectproxyclient.py
@@ -14,15 +14,22 @@
import base64
import logging
-from typing import Optional
+from typing import Optional, Union
import attr
from zope.interface import implementer
from twisted.internet import defer, protocol
from twisted.internet.error import ConnectError
-from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
+from twisted.internet.interfaces import (
+ IAddress,
+ IConnector,
+ IProtocol,
+ IReactorCore,
+ IStreamClientEndpoint,
+)
from twisted.internet.protocol import ClientFactory, Protocol, connectionDone
+from twisted.python.failure import Failure
from twisted.web import http
logger = logging.getLogger(__name__)
@@ -81,14 +88,14 @@ class HTTPConnectProxyEndpoint:
self._port = port
self._proxy_creds = proxy_creds
- def __repr__(self):
+ def __repr__(self) -> str:
return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
# Mypy encounters a false positive here: it complains that ClientFactory
# is incompatible with IProtocolFactory. But ClientFactory inherits from
# Factory, which implements IProtocolFactory. So I think this is a bug
# in mypy-zope.
- def connect(self, protocolFactory: ClientFactory): # type: ignore[override]
+ def connect(self, protocolFactory: ClientFactory) -> "defer.Deferred[IProtocol]": # type: ignore[override]
f = HTTPProxiedClientFactory(
self._host, self._port, protocolFactory, self._proxy_creds
)
@@ -125,10 +132,10 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
self.proxy_creds = proxy_creds
self.on_connection: "defer.Deferred[None]" = defer.Deferred()
- def startedConnecting(self, connector):
+ def startedConnecting(self, connector: IConnector) -> None:
return self.wrapped_factory.startedConnecting(connector)
- def buildProtocol(self, addr):
+ def buildProtocol(self, addr: IAddress) -> "HTTPConnectProtocol":
wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
if wrapped_protocol is None:
raise TypeError("buildProtocol produced None instead of a Protocol")
@@ -141,13 +148,13 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
self.proxy_creds,
)
- def clientConnectionFailed(self, connector, reason):
+ def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
logger.debug("Connection to proxy failed: %s", reason)
if not self.on_connection.called:
self.on_connection.errback(reason)
return self.wrapped_factory.clientConnectionFailed(connector, reason)
- def clientConnectionLost(self, connector, reason):
+ def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
logger.debug("Connection to proxy lost: %s", reason)
if not self.on_connection.called:
self.on_connection.errback(reason)
@@ -191,10 +198,10 @@ class HTTPConnectProtocol(protocol.Protocol):
)
self.http_setup_client.on_connected.addCallback(self.proxyConnected)
- def connectionMade(self):
+ def connectionMade(self) -> None:
self.http_setup_client.makeConnection(self.transport)
- def connectionLost(self, reason=connectionDone):
+ def connectionLost(self, reason: Failure = connectionDone) -> None:
if self.wrapped_protocol.connected:
self.wrapped_protocol.connectionLost(reason)
@@ -203,7 +210,7 @@ class HTTPConnectProtocol(protocol.Protocol):
if not self.connected_deferred.called:
self.connected_deferred.errback(reason)
- def proxyConnected(self, _):
+ def proxyConnected(self, _: Union[None, "defer.Deferred[None]"]) -> None:
self.wrapped_protocol.makeConnection(self.transport)
self.connected_deferred.callback(self.wrapped_protocol)
@@ -213,7 +220,7 @@ class HTTPConnectProtocol(protocol.Protocol):
if buf:
self.wrapped_protocol.dataReceived(buf)
- def dataReceived(self, data: bytes):
+ def dataReceived(self, data: bytes) -> None:
# if we've set up the HTTP protocol, we can send the data there
if self.wrapped_protocol.connected:
return self.wrapped_protocol.dataReceived(data)
@@ -243,7 +250,7 @@ class HTTPConnectSetupClient(http.HTTPClient):
self.proxy_creds = proxy_creds
self.on_connected: "defer.Deferred[None]" = defer.Deferred()
- def connectionMade(self):
+ def connectionMade(self) -> None:
logger.debug("Connected to proxy, sending CONNECT")
self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
@@ -257,14 +264,14 @@ class HTTPConnectSetupClient(http.HTTPClient):
self.endHeaders()
- def handleStatus(self, version: bytes, status: bytes, message: bytes):
+ def handleStatus(self, version: bytes, status: bytes, message: bytes) -> None:
logger.debug("Got Status: %s %s %s", status, message, version)
if status != b"200":
raise ProxyConnectError(f"Unexpected status on CONNECT: {status!s}")
- def handleEndHeaders(self):
+ def handleEndHeaders(self) -> None:
logger.debug("End Headers")
self.on_connected.callback(None)
- def handleResponse(self, body):
+ def handleResponse(self, body: bytes) -> None:
pass
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index a8a520f8..2f0177f1 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -239,7 +239,7 @@ class MatrixHostnameEndpointFactory:
self._srv_resolver = srv_resolver
- def endpointForURI(self, parsed_uri: URI):
+ def endpointForURI(self, parsed_uri: URI) -> "MatrixHostnameEndpoint":
return MatrixHostnameEndpoint(
self._reactor,
self._proxy_reactor,
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index f68646fd..de0e882b 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -16,7 +16,7 @@
import logging
import random
import time
-from typing import Callable, Dict, List
+from typing import Any, Callable, Dict, List
import attr
@@ -109,7 +109,7 @@ class SrvResolver:
def __init__(
self,
- dns_client=client,
+ dns_client: Any = client,
cache: Dict[bytes, List[Server]] = SERVER_CACHE,
get_time: Callable[[], float] = time.time,
):
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 43f21404..71b685fa 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -74,9 +74,9 @@ _well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known")
_had_valid_well_known_cache: TTLCache[bytes, bool] = TTLCache("had-valid-well-known")
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class WellKnownLookupResult:
- delegated_server = attr.ib()
+ delegated_server: Optional[bytes]
class WellKnownResolver:
@@ -336,4 +336,4 @@ def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
class _FetchWellKnownFailure(Exception):
# True if we didn't get a non-5xx HTTP response, i.e. this may or may not be
# a temporary failure.
- temporary = attr.ib()
+ temporary: bool = attr.ib()
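WellKnownLookupResult here (and JemallocStats further down) switch from attr.ib() assignments to auto_attribs, where annotated class attributes become fields automatically. A generic sketch of the idiom with a hypothetical class:

    import attr
    from typing import Optional

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class LookupOutcome:
        # each annotated attribute is turned into an attr.ib() by auto_attribs
        delegated_server: Optional[bytes]
        temporary: bool = False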
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index c2ec3caa..776ed43f 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -23,6 +23,8 @@ from http import HTTPStatus
from io import BytesIO, StringIO
from typing import (
TYPE_CHECKING,
+ Any,
+ BinaryIO,
Callable,
Dict,
Generic,
@@ -44,7 +46,7 @@ from typing_extensions import Literal
from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorTime
-from twisted.internet.task import _EPSILON, Cooperator
+from twisted.internet.task import Cooperator
from twisted.web.client import ResponseFailed
from twisted.web.http_headers import Headers
from twisted.web.iweb import IBodyProducer, IResponse
@@ -58,11 +60,13 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
+from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http import QuieterFileBodyProducer
from synapse.http.client import (
BlacklistingAgentWrapper,
BodyExceededMaxSize,
ByteWriteable,
+ _make_scheduler,
encode_query_args,
read_body_with_max_size,
)
@@ -88,9 +92,6 @@ incoming_responses_counter = Counter(
"synapse_http_matrixfederationclient_responses", "", ["method", "code"]
)
-# a federation response can be rather large (eg a big state_ids is 50M or so), so we
-# need a generous limit here.
-MAX_RESPONSE_SIZE = 100 * 1024 * 1024
MAX_LONG_RETRIES = 10
MAX_SHORT_RETRIES = 3
@@ -112,6 +113,11 @@ class ByteParser(ByteWriteable, Generic[T], abc.ABC):
the content type doesn't match we fail the request.
"""
+ # a federation response can be rather large (eg a big state_ids is 50M or so), so we
+ # need a generous limit here.
+ MAX_RESPONSE_SIZE: int = 100 * 1024 * 1024
+ """The largest response this parser will accept."""
+
@abc.abstractmethod
def finish(self) -> T:
"""Called when response has finished streaming and the parser should
@@ -181,7 +187,7 @@ class JsonParser(ByteParser[Union[JsonDict, list]]):
CONTENT_TYPE = "application/json"
- def __init__(self):
+ def __init__(self) -> None:
self._buffer = StringIO()
self._binary_wrapper = BinaryIOWrapper(self._buffer)
@@ -199,7 +205,6 @@ async def _handle_response(
response: IResponse,
start_ms: int,
parser: ByteParser[T],
- max_response_size: Optional[int] = None,
) -> T:
"""
Reads the body of a response with a timeout and sends it to a parser
@@ -211,16 +216,14 @@ async def _handle_response(
response: response to the request
start_ms: Timestamp when request was made
parser: The parser for the response
- max_response_size: The maximum size to read from the response, if None
- uses the default.
Returns:
The parsed response
"""
- if max_response_size is None:
- max_response_size = MAX_RESPONSE_SIZE
+ max_response_size = parser.MAX_RESPONSE_SIZE
+ finished = False
try:
check_content_type_is(response.headers, parser.CONTENT_TYPE)
@@ -229,6 +232,7 @@ async def _handle_response(
length = await make_deferred_yieldable(d)
+ finished = True
value = parser.finish()
except BodyExceededMaxSize as e:
# The response was too big.
@@ -236,7 +240,7 @@ async def _handle_response(
"{%s} [%s] JSON response exceeded max size %i - %s %s",
request.txn_id,
request.destination,
- MAX_RESPONSE_SIZE,
+ max_response_size,
request.method,
request.uri.decode("ascii"),
)
@@ -279,6 +283,15 @@ async def _handle_response(
e,
)
raise
+ finally:
+ if not finished:
+ # There was an exception and we didn't `finish()` the parse.
+ # Let the parser know that it can free up any resources.
+ try:
+ parser.finish()
+ except Exception:
+ # Ignore any additional exceptions.
+ pass
time_taken_secs = reactor.seconds() - start_ms / 1000
@@ -299,7 +312,9 @@ async def _handle_response(
class BinaryIOWrapper:
"""A wrapper for a TextIO which converts from bytes on the fly."""
- def __init__(self, file: typing.TextIO, encoding="utf-8", errors="strict"):
+ def __init__(
+ self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict"
+ ):
self.decoder = codecs.getincrementaldecoder(encoding)(errors)
self.file = file
@@ -317,7 +332,11 @@ class MatrixFederationHttpClient:
requests.
"""
- def __init__(self, hs: "HomeServer", tls_client_options_factory):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ tls_client_options_factory: Optional[FederationPolicyForHTTPS],
+ ):
self.hs = hs
self.signing_key = hs.signing_key
self.server_name = hs.hostname
@@ -348,10 +367,7 @@ class MatrixFederationHttpClient:
self.version_string_bytes = hs.version_string.encode("ascii")
self.default_timeout = 60
- def schedule(x):
- self.reactor.callLater(_EPSILON, x)
-
- self._cooperator = Cooperator(scheduler=schedule)
+ self._cooperator = Cooperator(scheduler=_make_scheduler(self.reactor))
self._sleeper = AwakenableSleeper(self.reactor)
@@ -364,7 +380,7 @@ class MatrixFederationHttpClient:
self,
request: MatrixFederationRequest,
try_trailing_slash_on_400: bool = False,
- **send_request_args,
+ **send_request_args: Any,
) -> IResponse:
"""Wrapper for _send_request which can optionally retry the request
upon receiving a combination of a 400 HTTP response code and a
@@ -740,7 +756,7 @@ class MatrixFederationHttpClient:
for key, sig in request["signatures"][self.server_name].items():
auth_headers.append(
(
- 'X-Matrix origin=%s,key="%s",sig="%s",destination="%s"'
+ 'X-Matrix origin="%s",key="%s",sig="%s",destination="%s"'
% (
self.server_name,
key,
@@ -765,7 +781,6 @@ class MatrixFederationHttpClient:
backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Literal[None] = None,
- max_response_size: Optional[int] = None,
) -> Union[JsonDict, list]:
...
@@ -783,7 +798,6 @@ class MatrixFederationHttpClient:
backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Optional[ByteParser[T]] = None,
- max_response_size: Optional[int] = None,
) -> T:
...
@@ -800,7 +814,6 @@ class MatrixFederationHttpClient:
backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Optional[ByteParser] = None,
- max_response_size: Optional[int] = None,
):
"""Sends the specified json data using PUT
@@ -836,8 +849,6 @@ class MatrixFederationHttpClient:
enabled.
parser: The parser to use to decode the response. Defaults to
parsing as JSON.
- max_response_size: The maximum size to read from the response, if None
- uses the default.
Returns:
Succeeds when we get a 2xx HTTP response. The
@@ -888,7 +899,6 @@ class MatrixFederationHttpClient:
response,
start_ms,
parser=parser,
- max_response_size=max_response_size,
)
return body
@@ -977,7 +987,6 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Literal[None] = None,
- max_response_size: Optional[int] = None,
) -> Union[JsonDict, list]:
...
@@ -992,7 +1001,6 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = ...,
try_trailing_slash_on_400: bool = ...,
parser: ByteParser[T] = ...,
- max_response_size: Optional[int] = ...,
) -> T:
...
@@ -1006,7 +1014,6 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Optional[ByteParser] = None,
- max_response_size: Optional[int] = None,
):
"""GETs some json from the given host homeserver and path
@@ -1036,9 +1043,6 @@ class MatrixFederationHttpClient:
parser: The parser to use to decode the response. Defaults to
parsing as JSON.
- max_response_size: The maximum size to read from the response. If None,
- uses the default.
-
Returns:
Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
@@ -1083,7 +1087,6 @@ class MatrixFederationHttpClient:
response,
start_ms,
parser=parser,
- max_response_size=max_response_size,
)
return body
@@ -1159,7 +1162,7 @@ class MatrixFederationHttpClient:
self,
destination: str,
path: str,
- output_stream,
+ output_stream: BinaryIO,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
max_size: Optional[int] = None,
@@ -1250,10 +1253,10 @@ class MatrixFederationHttpClient:
return length, headers
-def _flatten_response_never_received(e):
+def _flatten_response_never_received(e: BaseException) -> str:
if hasattr(e, "reasons"):
reasons = ", ".join(
- _flatten_response_never_received(f.value) for f in e.reasons
+ _flatten_response_never_received(f.value) for f in e.reasons # type: ignore[attr-defined]
)
return "%s:[%s]" % (type(e).__name__, reasons)
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index a16dde23..b2a50c91 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -245,7 +245,7 @@ def http_proxy_endpoint(
proxy: Optional[bytes],
reactor: IReactorCore,
tls_options_factory: Optional[IPolicyForHTTPS],
- **kwargs,
+ **kwargs: object,
) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]:
"""Parses an http proxy setting and returns an endpoint for the proxy
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 4886626d..2b6d1135 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -162,7 +162,7 @@ class RequestMetrics:
with _in_flight_requests_lock:
_in_flight_requests.add(self)
- def stop(self, time_sec, response_code, sent_bytes):
+ def stop(self, time_sec: float, response_code: int, sent_bytes: int) -> None:
with _in_flight_requests_lock:
_in_flight_requests.discard(self)
@@ -186,13 +186,13 @@ class RequestMetrics:
)
return
- response_code = str(response_code)
+ response_code_str = str(response_code)
- outgoing_responses_counter.labels(self.method, response_code).inc()
+ outgoing_responses_counter.labels(self.method, response_code_str).inc()
response_count.labels(self.method, self.name, tag).inc()
- response_timer.labels(self.method, self.name, tag, response_code).observe(
+ response_timer.labels(self.method, self.name, tag, response_code_str).observe(
time_sec - self.start_ts
)
@@ -221,7 +221,7 @@ class RequestMetrics:
# flight.
self.update_metrics()
- def update_metrics(self):
+ def update_metrics(self) -> None:
"""Updates the in flight metrics with values from this request."""
if not self.start_context:
logger.error(
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 657bffcd..e3dcc3f3 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -33,6 +33,7 @@ from typing import (
Optional,
Pattern,
Tuple,
+ TypeVar,
Union,
)
@@ -92,6 +93,68 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
HTTP_STATUS_REQUEST_CANCELLED = 499
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+_cancellable_method_names = frozenset(
+ {
+ # `RestServlet`, `BaseFederationServlet` and `BaseFederationServerServlet`
+ # methods
+ "on_GET",
+ "on_PUT",
+ "on_POST",
+ "on_DELETE",
+ # `_AsyncResource`, `DirectServeHtmlResource` and `DirectServeJsonResource`
+ # methods
+ "_async_render_GET",
+ "_async_render_PUT",
+ "_async_render_POST",
+ "_async_render_DELETE",
+ "_async_render_OPTIONS",
+ # `ReplicationEndpoint` methods
+ "_handle_request",
+ }
+)
+
+
+def cancellable(method: F) -> F:
+ """Marks a servlet method as cancellable.
+
+ Methods with this decorator will be cancelled if the client disconnects before we
+ finish processing the request.
+
+ During cancellation, `Deferred.cancel()` will be invoked on the `Deferred` wrapping
+ the method. The `cancel()` call will propagate down to the `Deferred` that is
+ currently being waited on. That `Deferred` will raise a `CancelledError`, which will
+ propagate up, as per normal exception handling.
+
+ Before applying this decorator to a new endpoint, you MUST recursively check
+ that all `await`s in the function are on `async` functions or `Deferred`s that
+ handle cancellation cleanly, otherwise a variety of bugs may occur, ranging from
+ premature logging context closure, to stuck requests, to database corruption.
+
+ Usage:
+ class SomeServlet(RestServlet):
+ @cancellable
+ async def on_GET(self, request: SynapseRequest) -> ...:
+ ...
+ """
+ if method.__name__ not in _cancellable_method_names and not any(
+ method.__name__.startswith(prefix) for prefix in _cancellable_method_names
+ ):
+ raise ValueError(
+ "@cancellable decorator can only be applied to servlet methods."
+ )
+
+ method.cancellable = True # type: ignore[attr-defined]
+ return method
+
+
+def is_method_cancellable(method: Callable[..., Any]) -> bool:
+ """Checks whether a servlet method has the `@cancellable` flag."""
+ return getattr(method, "cancellable", False)
+
+
def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
"""Sends a JSON error response to clients."""
@@ -253,6 +316,9 @@ class HttpServer(Protocol):
If the regex contains groups these gets passed to the callback via
an unpacked tuple.
+ The callback may be marked with the `@cancellable` decorator, which will
+ cause request processing to be cancelled when clients disconnect early.
+
Args:
method: The HTTP method to listen to.
path_patterns: The regex used to match requests.
@@ -283,7 +349,9 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
def render(self, request: SynapseRequest) -> int:
"""This gets called by twisted every time someone sends us a request."""
- defer.ensureDeferred(self._async_render_wrapper(request))
+ request.render_deferred = defer.ensureDeferred(
+ self._async_render_wrapper(request)
+ )
return NOT_DONE_YET
@wrap_async_request_handler
@@ -319,6 +387,8 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
method_handler = getattr(self, "_async_render_%s" % (request_method,), None)
if method_handler:
+ request.is_render_cancellable = is_method_cancellable(method_handler)
+
raw_callback_return = method_handler(request)
# Is it synchronous? We'll allow this for now.
@@ -479,6 +549,8 @@ class JsonResource(DirectServeJsonResource):
async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
+ request.is_render_cancellable = is_method_cancellable(callback)
+
# Make sure we have an appropriate name for this handler in prometheus
# (rather than the default of JsonResource).
request.request_metrics.name = servlet_classname
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 0b85a57d..eeec74b7 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Any, Generator, Optional, Tuple, Union
import attr
from zope.interface import implementer
+from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IAddress, IReactorTime
from twisted.python.failure import Failure
from twisted.web.http import HTTPChannel
@@ -91,6 +92,14 @@ class SynapseRequest(Request):
# we can't yet create the logcontext, as we don't know the method.
self.logcontext: Optional[LoggingContext] = None
+ # The `Deferred` to cancel if the client disconnects early and
+ # `is_render_cancellable` is set. Expected to be set by `Resource.render`.
+ self.render_deferred: Optional["Deferred[None]"] = None
+ # A boolean indicating whether `render_deferred` should be cancelled if the
+ # client disconnects early. Expected to be set by the coroutine started by
+ # `Resource.render`, if rendering is asynchronous.
+ self.is_render_cancellable = False
+
global _next_request_seq
self.request_seq = _next_request_seq
_next_request_seq += 1
@@ -357,7 +366,21 @@ class SynapseRequest(Request):
{"event": "client connection lost", "reason": str(reason.value)}
)
- if not self._is_processing:
+ if self._is_processing:
+ if self.is_render_cancellable:
+ if self.render_deferred is not None:
+ # Throw a cancellation into the request processing, in the hope
+ # that it will finish up sooner than it normally would.
+ # The `self.processing()` context manager will call
+ # `_finished_processing()` when done.
+ with PreserveLoggingContext():
+ self.render_deferred.cancel()
+ else:
+ logger.error(
+ "Connection from client lost, but have no Deferred to "
+ "cancel even though the request is marked as cancellable."
+ )
+ else:
self._finished_processing()
def _started_processing(self, servlet_name: str) -> None:
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index 475756f1..5a61b21e 100644
--- a/synapse/logging/_remote.py
+++ b/synapse/logging/_remote.py
@@ -31,7 +31,11 @@ from twisted.internet.endpoints import (
TCP4ClientEndpoint,
TCP6ClientEndpoint,
)
-from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint
+from twisted.internet.interfaces import (
+ IPushProducer,
+ IReactorTCP,
+ IStreamClientEndpoint,
+)
from twisted.internet.protocol import Factory, Protocol
from twisted.internet.tcp import Connection
from twisted.python.failure import Failure
@@ -59,14 +63,14 @@ class LogProducer:
_buffer: Deque[logging.LogRecord]
_paused: bool = attr.ib(default=False, init=False)
- def pauseProducing(self):
+ def pauseProducing(self) -> None:
self._paused = True
- def stopProducing(self):
+ def stopProducing(self) -> None:
self._paused = True
self._buffer = deque()
- def resumeProducing(self):
+ def resumeProducing(self) -> None:
# If we're already producing, nothing to do.
self._paused = False
@@ -102,8 +106,8 @@ class RemoteHandler(logging.Handler):
host: str,
port: int,
maximum_buffer: int = 1000,
- level=logging.NOTSET,
- _reactor=None,
+ level: int = logging.NOTSET,
+ _reactor: Optional[IReactorTCP] = None,
):
super().__init__(level=level)
self.host = host
@@ -118,7 +122,7 @@ class RemoteHandler(logging.Handler):
if _reactor is None:
from twisted.internet import reactor
- _reactor = reactor
+ _reactor = reactor # type: ignore[assignment]
try:
ip = ip_address(self.host)
@@ -139,7 +143,7 @@ class RemoteHandler(logging.Handler):
self._stopping = False
self._connect()
- def close(self):
+ def close(self) -> None:
self._stopping = True
self._service.stopService()
diff --git a/synapse/logging/formatter.py b/synapse/logging/formatter.py
index c0f12ecd..c88b8ae5 100644
--- a/synapse/logging/formatter.py
+++ b/synapse/logging/formatter.py
@@ -16,6 +16,8 @@
import logging
import traceback
from io import StringIO
+from types import TracebackType
+from typing import Optional, Tuple, Type
class LogFormatter(logging.Formatter):
@@ -28,10 +30,14 @@ class LogFormatter(logging.Formatter):
where it was caught are logged).
"""
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- def formatException(self, ei):
+ def formatException(
+ self,
+ ei: Tuple[
+ Optional[Type[BaseException]],
+ Optional[BaseException],
+ Optional[TracebackType],
+ ],
+ ) -> str:
sio = StringIO()
(typ, val, tb) = ei
diff --git a/synapse/logging/handlers.py b/synapse/logging/handlers.py
index 478b5274..dec2a2c3 100644
--- a/synapse/logging/handlers.py
+++ b/synapse/logging/handlers.py
@@ -49,7 +49,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler):
)
self._flushing_thread.start()
- def on_reactor_running():
+ def on_reactor_running() -> None:
self._reactor_started = True
reactor_to_use: IReactorCore
@@ -74,7 +74,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler):
else:
return True
- def _flush_periodically(self):
+ def _flush_periodically(self) -> None:
"""
Whilst this handler is active, flush the handler periodically.
"""
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index a02b5bf6..903ec40c 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -168,9 +168,24 @@ import inspect
import logging
import re
from functools import wraps
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Pattern, Type
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Collection,
+ Dict,
+ Generator,
+ Iterable,
+ List,
+ Optional,
+ Pattern,
+ Type,
+ TypeVar,
+ Union,
+)
import attr
+from typing_extensions import ParamSpec
from twisted.internet import defer
from twisted.web.http import Request
@@ -256,7 +271,7 @@ try:
def set_process(self, *args, **kwargs):
return self._reporter.set_process(*args, **kwargs)
- def report_span(self, span):
+ def report_span(self, span: "opentracing.Span") -> None:
try:
return self._reporter.report_span(span)
except Exception:
@@ -307,15 +322,19 @@ _homeserver_whitelist: Optional[Pattern[str]] = None
Sentinel = object()
-def only_if_tracing(func):
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
"""Executes the function only if we're tracing. Otherwise returns None."""
@wraps(func)
- def _only_if_tracing_inner(*args, **kwargs):
+ def _only_if_tracing_inner(*args: P.args, **kwargs: P.kwargs) -> Optional[R]:
if opentracing:
return func(*args, **kwargs)
else:
- return
+ return None
return _only_if_tracing_inner
@@ -356,17 +375,10 @@ def ensure_active_span(message, ret=None):
return ensure_active_span_inner_1
-@contextlib.contextmanager
-def noop_context_manager(*args, **kwargs):
- """Does exactly what it says on the tin"""
- # TODO: replace with contextlib.nullcontext once we drop support for Python 3.6
- yield
-
-
# Setup
-def init_tracer(hs: "HomeServer"):
+def init_tracer(hs: "HomeServer") -> None:
"""Set the whitelists and initialise the JaegerClient tracer"""
global opentracing
if not hs.config.tracing.opentracer_enabled:
@@ -408,11 +420,11 @@ def init_tracer(hs: "HomeServer"):
@only_if_tracing
-def set_homeserver_whitelist(homeserver_whitelist):
+def set_homeserver_whitelist(homeserver_whitelist: Iterable[str]) -> None:
"""Sets the homeserver whitelist
Args:
- homeserver_whitelist (Iterable[str]): regex of whitelisted homeservers
+ homeserver_whitelist: regexes specifying whitelisted homeservers
"""
global _homeserver_whitelist
if homeserver_whitelist:
@@ -423,15 +435,15 @@ def set_homeserver_whitelist(homeserver_whitelist):
@only_if_tracing
-def whitelisted_homeserver(destination):
+def whitelisted_homeserver(destination: str) -> bool:
"""Checks if a destination matches the whitelist
Args:
- destination (str)
+ destination
"""
if _homeserver_whitelist:
- return _homeserver_whitelist.match(destination)
+ return _homeserver_whitelist.match(destination) is not None
return False
@@ -457,11 +469,11 @@ def start_active_span(
Args:
See opentracing.tracer
Returns:
- scope (Scope) or noop_context_manager
+ scope (Scope) or contextlib.nullcontext
"""
if opentracing is None:
- return noop_context_manager() # type: ignore[unreachable]
+ return contextlib.nullcontext() # type: ignore[unreachable]
if tracer is None:
# use the global tracer by default
@@ -505,7 +517,7 @@ def start_active_span_follows_from(
tracer: override the opentracing tracer. By default the global tracer is used.
"""
if opentracing is None:
- return noop_context_manager() # type: ignore[unreachable]
+ return contextlib.nullcontext() # type: ignore[unreachable]
references = [opentracing.follows_from(context) for context in contexts]
scope = start_active_span(
@@ -525,19 +537,19 @@ def start_active_span_follows_from(
def start_active_span_from_edu(
- edu_content,
- operation_name,
- references: Optional[list] = None,
- tags=None,
- start_time=None,
- ignore_active_span=False,
- finish_on_close=True,
-):
+ edu_content: Dict[str, Any],
+ operation_name: str,
+ references: Optional[List["opentracing.Reference"]] = None,
+ tags: Optional[Dict] = None,
+ start_time: Optional[float] = None,
+ ignore_active_span: bool = False,
+ finish_on_close: bool = True,
+) -> "opentracing.Scope":
"""
Extracts a span context from an edu and uses it to start a new active span
Args:
- edu_content (dict): and edu_content with a `context` field whose value is
+ edu_content: an edu_content with a `context` field whose value is
canonical json for a dict which contains opentracing information.
For the other args see opentracing.tracer
@@ -545,7 +557,7 @@ def start_active_span_from_edu(
references = references or []
if opentracing is None:
- return noop_context_manager() # type: ignore[unreachable]
+ return contextlib.nullcontext() # type: ignore[unreachable]
carrier = json_decoder.decode(edu_content.get("context", "{}")).get(
"opentracing", {}
@@ -578,27 +590,27 @@ def start_active_span_from_edu(
# Opentracing setters for tags, logs, etc
@only_if_tracing
-def active_span():
+def active_span() -> Optional["opentracing.Span"]:
"""Get the currently active span, if any"""
return opentracing.tracer.active_span
@ensure_active_span("set a tag")
-def set_tag(key, value):
+def set_tag(key: str, value: Union[str, bool, int, float]) -> None:
"""Sets a tag on the active span"""
assert opentracing.tracer.active_span is not None
opentracing.tracer.active_span.set_tag(key, value)
@ensure_active_span("log")
-def log_kv(key_values, timestamp=None):
+def log_kv(key_values: Dict[str, Any], timestamp: Optional[float] = None) -> None:
"""Log to the active span"""
assert opentracing.tracer.active_span is not None
opentracing.tracer.active_span.log_kv(key_values, timestamp)
@ensure_active_span("set the traces operation name")
-def set_operation_name(operation_name):
+def set_operation_name(operation_name: str) -> None:
"""Sets the operation name of the active span"""
assert opentracing.tracer.active_span is not None
opentracing.tracer.active_span.set_operation_name(operation_name)
@@ -624,7 +636,9 @@ def force_tracing(span=Sentinel) -> None:
span.set_baggage_item(SynapseBaggage.FORCE_TRACING, "1")
-def is_context_forced_tracing(span_context) -> bool:
+def is_context_forced_tracing(
+ span_context: Optional["opentracing.SpanContext"],
+) -> bool:
"""Check if sampling has been force for the given span context."""
if span_context is None:
return False
@@ -696,13 +710,13 @@ def inject_response_headers(response_headers: Headers) -> None:
@ensure_active_span("get the active span context as a dict", ret={})
-def get_active_span_text_map(destination=None):
+def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]:
"""
Gets a span context as a dict. This can be used instead of manually
injecting a span into an empty carrier.
Args:
- destination (str): the name of the remote server.
+ destination: the name of the remote server.
Returns:
dict: the active span's context if opentracing is enabled, otherwise empty.
@@ -721,7 +735,7 @@ def get_active_span_text_map(destination=None):
@ensure_active_span("get the span context as a string.", ret={})
-def active_span_context_as_string():
+def active_span_context_as_string() -> str:
"""
Returns:
The active span context encoded as a string.
@@ -750,21 +764,21 @@ def span_context_from_request(request: Request) -> "Optional[opentracing.SpanCon
@only_if_tracing
-def span_context_from_string(carrier):
+def span_context_from_string(carrier: str) -> Optional["opentracing.SpanContext"]:
"""
Returns:
The active span context decoded from a string.
"""
- carrier = json_decoder.decode(carrier)
- return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
+ payload: Dict[str, str] = json_decoder.decode(carrier)
+ return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, payload)
@only_if_tracing
-def extract_text_map(carrier):
+def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanContext"]:
"""
Wrapper method for opentracing's tracer.extract for TEXT_MAP.
Args:
- carrier (dict): a dict possibly containing a span context.
+ carrier: a dict possibly containing a span context.
Returns:
The active span context extracted from carrier.
@@ -843,7 +857,7 @@ def trace(func=None, opname=None):
return decorator
-def tag_args(func):
+def tag_args(func: Callable[P, R]) -> Callable[P, R]:
"""
Tags all of the args to the active span.
"""
@@ -852,11 +866,11 @@ def tag_args(func):
return func
@wraps(func)
- def _tag_args_inner(*args, **kwargs):
+ def _tag_args_inner(*args: P.args, **kwargs: P.kwargs) -> R:
argspec = inspect.getfullargspec(func)
for i, arg in enumerate(argspec.args[1:]):
- set_tag("ARG_" + arg, args[i])
- set_tag("args", args[len(argspec.args) :])
+ set_tag("ARG_" + arg, args[i]) # type: ignore[index]
+ set_tag("args", args[len(argspec.args) :]) # type: ignore[index]
set_tag("kwargs", kwargs)
return func(*args, **kwargs)
@@ -864,7 +878,9 @@ def tag_args(func):
@contextlib.contextmanager
-def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
+def trace_servlet(
+ request: "SynapseRequest", extract_context: bool = False
+) -> Generator[None, None, None]:
"""Returns a context manager which traces a request. It starts a span
with some servlet specific tags such as the request metrics name and
request information.
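only_if_tracing and tag_args are now typed with ParamSpec so that decorated functions keep their signatures for type checkers. A self-contained sketch of the same idiom (only_when_enabled is a hypothetical decorator, not from this patch):

    from functools import wraps
    from typing import Callable, Optional, TypeVar

    from typing_extensions import ParamSpec

    P = ParamSpec("P")
    R = TypeVar("R")

    def only_when_enabled(func: Callable[P, R]) -> Callable[P, Optional[R]]:
        enabled = True  # hypothetical stand-in for "is tracing configured?"

        @wraps(func)
        def inner(*args: P.args, **kwargs: P.kwargs) -> Optional[R]:
            return func(*args, **kwargs) if enabled else None

        return inner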
diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
index d57e7c53..a26a1a58 100644
--- a/synapse/logging/scopecontextmanager.py
+++ b/synapse/logging/scopecontextmanager.py
@@ -13,6 +13,8 @@
# limitations under the License.import logging
import logging
+from types import TracebackType
+from typing import Optional, Type
from opentracing import Scope, ScopeManager
@@ -107,19 +109,26 @@ class _LogContextScope(Scope):
and - if enter_logcontext was set - the logcontext is finished too.
"""
- def __init__(self, manager, span, logcontext, enter_logcontext, finish_on_close):
+ def __init__(
+ self,
+ manager: LogContextScopeManager,
+ span,
+ logcontext,
+ enter_logcontext: bool,
+ finish_on_close: bool,
+ ):
"""
Args:
- manager (LogContextScopeManager):
+ manager:
the manager that is responsible for this scope.
span (Span):
the opentracing span which this scope represents the local
lifetime for.
logcontext (LogContext):
the logcontext to which this scope is attached.
- enter_logcontext (Boolean):
+ enter_logcontext:
if True the logcontext will be exited when the scope is finished
- finish_on_close (Boolean):
+ finish_on_close:
if True finish the span when the scope is closed
"""
super().__init__(manager, span)
@@ -127,16 +136,21 @@ class _LogContextScope(Scope):
self._finish_on_close = finish_on_close
self._enter_logcontext = enter_logcontext
- def __exit__(self, exc_type, value, traceback):
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ value: Optional[BaseException],
+ traceback: Optional[TracebackType],
+ ) -> None:
if exc_type == twisted.internet.defer._DefGen_Return:
# filter out defer.returnValue() calls
exc_type = value = traceback = None
super().__exit__(exc_type, value, traceback)
- def __str__(self):
+ def __str__(self) -> str:
return f"Scope<{self.span}>"
- def close(self):
+ def close(self) -> None:
active_scope = self.manager.active
if active_scope is not self:
logger.error(
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 29880974..eef3462e 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -14,6 +14,7 @@
import logging
import threading
+from contextlib import nullcontext
from functools import wraps
from types import TracebackType
from typing import (
@@ -41,11 +42,7 @@ from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
)
-from synapse.logging.opentracing import (
- SynapseTags,
- noop_context_manager,
- start_active_span,
-)
+from synapse.logging.opentracing import SynapseTags, start_active_span
from synapse.metrics._types import Collector
if TYPE_CHECKING:
@@ -238,7 +235,7 @@ def run_as_background_process(
f"bgproc.{desc}", tags={SynapseTags.REQUEST_ID: str(context)}
)
else:
- ctx = noop_context_manager()
+ ctx = nullcontext()
with ctx:
return await func(*args, **kwargs)
except Exception:
diff --git a/synapse/metrics/jemalloc.py b/synapse/metrics/jemalloc.py
index 6bc329f0..1fc8a0e8 100644
--- a/synapse/metrics/jemalloc.py
+++ b/synapse/metrics/jemalloc.py
@@ -18,6 +18,7 @@ import os
import re
from typing import Iterable, Optional, overload
+import attr
from prometheus_client import REGISTRY, Metric
from typing_extensions import Literal
@@ -27,52 +28,24 @@ from synapse.metrics._types import Collector
logger = logging.getLogger(__name__)
-def _setup_jemalloc_stats() -> None:
- """Checks to see if jemalloc is loaded, and hooks up a collector to record
- statistics exposed by jemalloc.
- """
-
- # Try to find the loaded jemalloc shared library, if any. We need to
- # introspect into what is loaded, rather than loading whatever is on the
- # path, as if we load a *different* jemalloc version things will seg fault.
-
- # We look in `/proc/self/maps`, which only exists on linux.
- if not os.path.exists("/proc/self/maps"):
- logger.debug("Not looking for jemalloc as no /proc/self/maps exist")
- return
-
- # We're looking for a path at the end of the line that includes
- # "libjemalloc".
- regex = re.compile(r"/\S+/libjemalloc.*$")
-
- jemalloc_path = None
- with open("/proc/self/maps") as f:
- for line in f:
- match = regex.search(line.strip())
- if match:
- jemalloc_path = match.group()
-
- if not jemalloc_path:
- # No loaded jemalloc was found.
- logger.debug("jemalloc not found")
- return
-
- logger.debug("Found jemalloc at %s", jemalloc_path)
-
- jemalloc = ctypes.CDLL(jemalloc_path)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class JemallocStats:
+ jemalloc: ctypes.CDLL
@overload
def _mallctl(
- name: str, read: Literal[True] = True, write: Optional[int] = None
+ self, name: str, read: Literal[True] = True, write: Optional[int] = None
) -> int:
...
@overload
- def _mallctl(name: str, read: Literal[False], write: Optional[int] = None) -> None:
+ def _mallctl(
+ self, name: str, read: Literal[False], write: Optional[int] = None
+ ) -> None:
...
def _mallctl(
- name: str, read: bool = True, write: Optional[int] = None
+ self, name: str, read: bool = True, write: Optional[int] = None
) -> Optional[int]:
"""Wrapper around `mallctl` for reading and writing integers to
jemalloc.
@@ -120,7 +93,7 @@ def _setup_jemalloc_stats() -> None:
# Where oldp/oldlenp is a buffer where the old value will be written to
# (if not null), and newp/newlen is the buffer with the new value to set
# (if not null). Note that they're all references *except* newlen.
- result = jemalloc.mallctl(
+ result = self.jemalloc.mallctl(
name.encode("ascii"),
input_var_ref,
input_len_ref,
@@ -136,21 +109,80 @@ def _setup_jemalloc_stats() -> None:
return input_var.value
- def _jemalloc_refresh_stats() -> None:
+ def refresh_stats(self) -> None:
"""Request that jemalloc updates its internal statistics. This needs to
be called before querying for stats, otherwise it will return stale
values.
"""
try:
- _mallctl("epoch", read=False, write=1)
+ self._mallctl("epoch", read=False, write=1)
except Exception as e:
logger.warning("Failed to reload jemalloc stats: %s", e)
+ def get_stat(self, name: str) -> int:
+ """Request the stat of the given name at the time of the last
+ `refresh_stats` call. This may throw if we fail to read
+ the stat.
+ """
+ return self._mallctl(f"stats.{name}")
+
+
+_JEMALLOC_STATS: Optional[JemallocStats] = None
+
+
+def get_jemalloc_stats() -> Optional[JemallocStats]:
+ """Returns an interface to jemalloc, if it is being used.
+
+ Note that this will always return None until `setup_jemalloc_stats` has been
+ called.
+ """
+ return _JEMALLOC_STATS
+
+
+def _setup_jemalloc_stats() -> None:
+ """Checks to see if jemalloc is loaded, and hooks up a collector to record
+ statistics exposed by jemalloc.
+ """
+
+ global _JEMALLOC_STATS
+
+ # Try to find the loaded jemalloc shared library, if any. We need to
+ # introspect into what is loaded, rather than loading whatever is on the
+ # path, as if we load a *different* jemalloc version things will seg fault.
+
+ # We look in `/proc/self/maps`, which only exists on linux.
+ if not os.path.exists("/proc/self/maps"):
+ logger.debug("Not looking for jemalloc as no /proc/self/maps exist")
+ return
+
+ # We're looking for a path at the end of the line that includes
+ # "libjemalloc".
+ regex = re.compile(r"/\S+/libjemalloc.*$")
+
+ jemalloc_path = None
+ with open("/proc/self/maps") as f:
+ for line in f:
+ match = regex.search(line.strip())
+ if match:
+ jemalloc_path = match.group()
+
+ if not jemalloc_path:
+ # No loaded jemalloc was found.
+ logger.debug("jemalloc not found")
+ return
+
+ logger.debug("Found jemalloc at %s", jemalloc_path)
+
+ jemalloc_dll = ctypes.CDLL(jemalloc_path)
+
+ stats = JemallocStats(jemalloc_dll)
+ _JEMALLOC_STATS = stats
+
class JemallocCollector(Collector):
"""Metrics for internal jemalloc stats."""
def collect(self) -> Iterable[Metric]:
- _jemalloc_refresh_stats()
+ stats.refresh_stats()
g = GaugeMetricFamily(
"jemalloc_stats_app_memory_bytes",
@@ -184,7 +216,7 @@ def _setup_jemalloc_stats() -> None:
"metadata",
):
try:
- value = _mallctl(f"stats.{t}")
+ value = stats.get_stat(t)
except Exception as e:
# There was an error fetching the value, skip.
logger.warning("Failed to read jemalloc stats.%s: %s", t, e)
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 73f92d2d..a8ad575f 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -47,12 +47,14 @@ from synapse.events.spamcheck import (
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK,
CHECK_REGISTRATION_FOR_SPAM_CALLBACK,
CHECK_USERNAME_FOR_SPAM_CALLBACK,
+ SHOULD_DROP_FEDERATED_EVENT_CALLBACK,
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK,
USER_MAY_CREATE_ROOM_CALLBACK,
USER_MAY_INVITE_CALLBACK,
USER_MAY_JOIN_ROOM_CALLBACK,
USER_MAY_PUBLISH_ROOM_CALLBACK,
USER_MAY_SEND_3PID_INVITE_CALLBACK,
+ SpamChecker,
)
from synapse.events.third_party_rules import (
CHECK_CAN_DEACTIVATE_USER_CALLBACK,
@@ -138,6 +140,7 @@ are loaded into Synapse.
"""
PRESENCE_ALL_USERS = PresenceRouter.ALL_USERS
+NOT_SPAM = SpamChecker.NOT_SPAM
__all__ = [
"errors",
@@ -146,6 +149,7 @@ __all__ = [
"respond_with_html",
"run_in_background",
"cached",
+ "NOT_SPAM",
"UserID",
"DatabasePool",
"LoggingTransaction",
@@ -190,6 +194,7 @@ class ModuleApi:
self._store: Union[
DataStore, "GenericWorkerSlavedStore"
] = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self._auth = hs.get_auth()
self._auth_handler = auth_handler
self._server_name = hs.hostname
@@ -234,6 +239,9 @@ class ModuleApi:
self,
*,
check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None,
+ should_drop_federated_event: Optional[
+ SHOULD_DROP_FEDERATED_EVENT_CALLBACK
+ ] = None,
user_may_join_room: Optional[USER_MAY_JOIN_ROOM_CALLBACK] = None,
user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None,
user_may_send_3pid_invite: Optional[USER_MAY_SEND_3PID_INVITE_CALLBACK] = None,
@@ -254,6 +262,7 @@ class ModuleApi:
"""
return self._spam_checker.register_callbacks(
check_event_for_spam=check_event_for_spam,
+ should_drop_federated_event=should_drop_federated_event,
user_may_join_room=user_may_join_room,
user_may_invite=user_may_invite,
user_may_send_3pid_invite=user_may_send_3pid_invite,
@@ -903,7 +912,7 @@ class ModuleApi:
The filtered state events in the room.
"""
state_ids = yield defer.ensureDeferred(
- self._store.get_filtered_current_state_ids(
+ self._storage_controllers.state.get_current_state_ids(
room_id=room_id, state_filter=StateFilter.from_types(types)
)
)
@@ -1139,7 +1148,10 @@ class ModuleApi:
)
async def sleep(self, seconds: float) -> None:
- """Sleeps for the given number of seconds."""
+ """Sleeps for the given number of seconds.
+
+ Added in Synapse v1.49.0.
+ """
await self._clock.sleep(seconds)
@@ -1278,20 +1290,16 @@ class ModuleApi:
# regardless of their state key
]
"""
+ state_filter = None
if event_filter:
# If a filter was provided, turn it into a StateFilter and retrieve a filtered
# view of the state.
state_filter = StateFilter.from_types(event_filter)
- state_ids = await self._store.get_filtered_current_state_ids(
- room_id,
- state_filter,
- )
- else:
- # If no filter was provided, get the whole state. We could also reuse the call
- # to get_filtered_current_state_ids above, with `state_filter = StateFilter.all()`,
- # but get_filtered_current_state_ids isn't cached and `get_current_state_ids`
- # is, so using the latter when we can is better for perf.
- state_ids = await self._store.get_current_state_ids(room_id)
+
+ state_ids = await self._storage_controllers.state.get_current_state_ids(
+ room_id,
+ state_filter,
+ )
state_events = await self._store.get_events(state_ids.values())
@@ -1419,6 +1427,28 @@ class ModuleApi:
user_id, spec, {"actions": actions}
)
+ async def get_monthly_active_users_by_service(
+ self, start_timestamp: Optional[int] = None, end_timestamp: Optional[int] = None
+ ) -> List[Tuple[str, str]]:
+ """Generates list of monthly active users and their services.
+ Please see corresponding storage docstring for more details.
+
+ Added in Synapse v1.61.0.
+
+ Arguments:
+ start_timestamp: If specified, only include users that were first active
+ at or after this point
+ end_timestamp: If specified, only include users that were first active
+ at or before this point
+
+ Returns:
+ A list of tuples (appservice_id, user_id)
+
+ """
+ return await self._store.get_monthly_active_users_by_service(
+ start_timestamp, end_timestamp
+ )
+
class PublicRoomListManager:
"""Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/module_api/errors.py b/synapse/module_api/errors.py
index e58e0e60..bedd045d 100644
--- a/synapse/module_api/errors.py
+++ b/synapse/module_api/errors.py
@@ -15,6 +15,7 @@
"""Exception types which are exposed as part of the stable module API"""
from synapse.api.errors import (
+ Codes,
InvalidClientCredentialsError,
RedirectException,
SynapseError,
@@ -24,6 +25,7 @@ from synapse.handlers.push_rules import InvalidRuleException
from synapse.storage.push_rule import RuleNotFoundException
__all__ = [
+ "Codes",
"InvalidClientCredentialsError",
"RedirectException",
"SynapseError",
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 01a50b9d..54b0ec4b 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -33,7 +33,7 @@ from prometheus_client import Counter
from twisted.internet import defer
-from synapse.api.constants import EventTypes, HistoryVisibility, Membership
+from synapse.api.constants import EduTypes, EventTypes, HistoryVisibility, Membership
from synapse.api.errors import AuthError
from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state
@@ -46,6 +46,7 @@ from synapse.types import (
JsonDict,
PersistedEventPosition,
RoomStreamToken,
+ StreamKeyType,
StreamToken,
UserID,
)
@@ -220,7 +221,7 @@ class Notifier:
self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {}
self.hs = hs
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastores().main
self.pending_new_room_events: List[_PendingRoomEventEntry] = []
@@ -370,7 +371,7 @@ class Notifier:
if users or rooms:
self.on_new_event(
- "room_key",
+ StreamKeyType.ROOM,
max_room_stream_token,
users=users,
rooms=rooms,
@@ -440,7 +441,7 @@ class Notifier:
for room in rooms:
user_streams |= self.room_to_user_streams.get(room, set())
- if stream_key == "to_device_key":
+ if stream_key == StreamKeyType.TO_DEVICE:
issue9533_logger.debug(
"to-device messages stream id %s, awaking streams for %s",
new_token,
@@ -622,7 +623,7 @@ class Notifier:
if name == "room":
new_events = await filter_events_for_client(
- self.storage,
+ self._storage_controllers,
user.to_string(),
new_events,
is_peeking=is_peeking,
@@ -631,7 +632,7 @@ class Notifier:
now = self.clock.time_msec()
new_events[:] = [
{
- "type": "m.presence",
+ "type": EduTypes.PRESENCE,
"content": format_user_presence_state(event, now),
}
for event in new_events
@@ -680,7 +681,7 @@ class Notifier:
return joined_room_ids, True
async def _is_world_readable(self, room_id: str) -> bool:
- state = await self.state_handler.get_current_state(
+ state = await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if state and "history_visibility" in state.content:
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index a1b77110..57c4d704 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -12,6 +12,80 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+This module implements the push rules & notifications portion of the Matrix
+specification.
+
+There's a few related features:
+
+* Push notifications (i.e. email or outgoing requests to a Push Gateway).
+* Calculation of unread notifications (for /sync and /notifications).
+
+When Synapse receives a new event (locally, via the Client-Server API, or via
+federation), the following occurs:
+
+1. The push rules get evaluated to generate a set of per-user actions.
+2. The event is persisted into the database.
+3. (In the background) The notifier is notified about the new event.
+
+The per-user actions are initially stored in the event_push_actions_staging table,
+before getting moved into the event_push_actions table when the event is persisted.
+The event_push_actions table is periodically summarised into the event_push_summary
+and event_push_summary_stream_ordering tables.
+
+Since push actions block an event from being persisted, the generation of push
+actions is performance sensitive.
+
+The general interaction of the classes is:
+
+ +---------------------------------------------+
+ | FederationEventHandler/EventCreationHandler |
+ +---------------------------------------------+
+ |
+ v
+ +-----------------------+ +---------------------------+
+ | BulkPushRuleEvaluator |---->| PushRuleEvaluatorForEvent |
+ +-----------------------+ +---------------------------+
+ |
+ v
+ +-----------------------------+
+ | EventPushActionsWorkerStore |
+ +-----------------------------+
+
+The notifier notifies the pusher pool of the new event, which checks for affected
+users. Each user-configured pusher of the affected users then performs the
+previously calculated action.
+
+The general interaction of the classes is:
+
+ +----------+
+ | Notifier |
+ +----------+
+ |
+ v
+ +------------+ +--------------+
+ | PusherPool |---->| PusherConfig |
+ +------------+ +--------------+
+ |
+ | +---------------+
+ +<--->| PusherFactory |
+ | +---------------+
+ v
+ +------------------------+ +-----------------------------------------------+
+ | EmailPusher/HttpPusher |---->| EventPushActionsWorkerStore/PusherWorkerStore |
+ +------------------------+ +-----------------------------------------------+
+ |
+ v
+ +-------------------------+
+ | Mailer/SimpleHttpClient |
+ +-------------------------+
+
+The Pusher instance also calls out to various utilities for generating payloads
+(or email templates), but those interactions are not detailed in this diagram
+(and are specific to the type of pusher).
+
+"""
+
import abc
from typing import TYPE_CHECKING, Any, Dict, Optional
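A rough sketch of the first stage described in the docstring above, i.e. how push actions are generated before an event is persisted. The surrounding plumbing here is simplified and assumed; only BulkPushRuleEvaluator.action_for_event_by_user is taken from this diff:

from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator

async def persist_with_push_actions(hs, event, context) -> None:
    bulk_evaluator = BulkPushRuleEvaluator(hs)

    # 1. Evaluate push rules: per-user actions land in event_push_actions_staging.
    await bulk_evaluator.action_for_event_by_user(event, context)

    # 2. Persist the event: staged actions are moved into event_push_actions.
    await hs.get_storage_controllers().persistence.persist_event(event, context)

    # 3. (In the background) the notifier picks up the new event and the pusher
    #    pool sends out any configured pushes.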
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
deleted file mode 100644
index 60758df0..00000000
--- a/synapse/push/action_generator.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# Copyright 2015 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from typing import TYPE_CHECKING
-
-from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
-from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
-from synapse.util.metrics import Measure
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-class ActionGenerator:
- def __init__(self, hs: "HomeServer"):
- self.clock = hs.get_clock()
- self.bulk_evaluator = BulkPushRuleEvaluator(hs)
- # really we want to get all user ids and all profile tags too,
- # since we want the actions for each profile tag for every user and
- # also actions for a client with no profile tag for each user.
- # Currently the event stream doesn't support profile tags on an
- # event stream, so we just run the rules for a client with no profile
- # tag (ie. we just need all the users).
-
- async def handle_push_actions_for_event(
- self, event: EventBase, context: EventContext
- ) -> None:
- with Measure(self.clock, "action_for_event_by_user"):
- await self.bulk_evaluator.action_for_event_by_user(event, context)
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index a17b35a6..819bc9e9 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -139,6 +139,7 @@ BASE_APPEND_CONTENT_RULES: List[Dict[str, Any]] = [
{
"kind": "event_match",
"key": "content.body",
+ # Match the localpart of the requester's MXID.
"pattern_type": "user_localpart",
}
],
@@ -191,6 +192,7 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"pattern": "invite",
"_cache_key": "_invite_member",
},
+ # Match the requester's MXID.
{"kind": "event_match", "key": "state_key", "pattern_type": "user_id"},
],
"actions": [
@@ -290,7 +292,7 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_room_server_acl",
}
],
- "actions": ["dont_notify"],
+ "actions": [],
},
]
@@ -351,6 +353,18 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
],
},
{
+ "rule_id": "global/underride/.org.matrix.msc3772.thread_reply",
+ "conditions": [
+ {
+ "kind": "org.matrix.msc3772.relation_match",
+ "rel_type": "m.thread",
+ # Match the requester's MXID.
+ "sender_type": "user_id",
+ }
+ ],
+ "actions": ["notify", {"set_tweak": "highlight", "value": False}],
+ },
+ {
"rule_id": "global/underride/.m.rule.message",
"conditions": [
{
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index b07cf2ee..7791b289 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -13,15 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import itertools
import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import attr
from prometheus_client import Counter
from synapse.api.constants import EventTypes, Membership, RelationTypes
-from synapse.event_auth import get_user_power_level
-from synapse.events import EventBase
+from synapse.event_auth import auth_types_for_event, get_user_power_level
+from synapse.events import EventBase, relation_from_event
from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.storage.databases.main.roommember import EventIdMembership
@@ -29,7 +30,9 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.descriptors import lru_cache
from synapse.util.caches.lrucache import LruCache
+from synapse.util.metrics import measure_func
+from ..storage.state import StateFilter
from .push_rule_evaluator import PushRuleEvaluatorForEvent
if TYPE_CHECKING:
@@ -77,8 +80,8 @@ def _should_count_as_unread(event: EventBase, context: EventContext) -> bool:
return False
# Exclude edits.
- relates_to = event.content.get("m.relates_to", {})
- if relates_to.get("rel_type") == RelationTypes.REPLACE:
+ relates_to = relation_from_event(event)
+ if relates_to and relates_to.rel_type == RelationTypes.REPLACE:
return False
# Mark events that have a non-empty string body as unread.
@@ -105,6 +108,7 @@ class BulkPushRuleEvaluator:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastores().main
+ self.clock = hs.get_clock()
self._event_auth_handler = hs.get_event_auth_handler()
# Used by `RulesForRoom` to ensure only one thing mutates the cache at a
@@ -118,6 +122,9 @@ class BulkPushRuleEvaluator:
resizable=False,
)
+ # Whether MSC3772 support is enabled.
+ self._relations_match_enabled = self.hs.config.experimental.msc3772_enabled
+
async def _get_rules_for_event(
self, event: EventBase, context: EventContext
) -> Dict[str, List[Dict[str, Any]]]:
@@ -146,12 +153,10 @@ class BulkPushRuleEvaluator:
if event.type == "m.room.member" and event.content["membership"] == "invite":
invited = event.state_key
if invited and self.hs.is_mine_id(invited):
- has_pusher = await self.store.user_has_pusher(invited)
- if has_pusher:
- rules_by_user = dict(rules_by_user)
- rules_by_user[invited] = await self.store.get_push_rules_for_user(
- invited
- )
+ rules_by_user = dict(rules_by_user)
+ rules_by_user[invited] = await self.store.get_push_rules_for_user(
+ invited
+ )
return rules_by_user
@@ -166,8 +171,12 @@ class BulkPushRuleEvaluator:
async def _get_power_levels_and_sender_level(
self, event: EventBase, context: EventContext
) -> Tuple[dict, int]:
- prev_state_ids = await context.get_prev_state_ids()
+ event_types = auth_types_for_event(event.room_version, event)
+ prev_state_ids = await context.get_prev_state_ids(
+ StateFilter.from_types(event_types)
+ )
pl_event_id = prev_state_ids.get(POWER_KEY)
+
if pl_event_id:
# fastpath: if there's a power level event, that's all we need, and
# not having a power level event is an extreme edge case
@@ -185,6 +194,61 @@ class BulkPushRuleEvaluator:
return pl_event.content if pl_event else {}, sender_level
+ async def _get_mutual_relations(
+ self, event: EventBase, rules: Iterable[Dict[str, Any]]
+ ) -> Dict[str, Set[Tuple[str, str]]]:
+ """
+ Fetch event metadata for events which relate to the same event as the given event.
+
+ If the given event has no relation information, returns an empty dictionary.
+
+ Args:
+ event: The event to find mutual relations for.
+ rules: The push rules which will be processed for this event.
+
+ Returns:
+ A dictionary of relation type to:
+ A set of tuples of:
+ The sender
+ The event type
+ """
+
+ # If the experimental feature is not enabled, skip fetching relations.
+ if not self._relations_match_enabled:
+ return {}
+
+ # If the event does not have a relation, then cannot have any mutual
+ # relations.
+ relation = relation_from_event(event)
+ if not relation:
+ return {}
+
+ # Pre-filter to figure out which relation types are interesting.
+ rel_types = set()
+ for rule in rules:
+ # Skip disabled rules.
+ if "enabled" in rule and not rule["enabled"]:
+ continue
+
+ for condition in rule["conditions"]:
+ if condition["kind"] != "org.matrix.msc3772.relation_match":
+ continue
+
+ # rel_type is required.
+ rel_type = condition.get("rel_type")
+ if rel_type:
+ rel_types.add(rel_type)
+
+ # If no valid rules were found, no mutual relations.
+ if not rel_types:
+ return {}
+
+ # If any valid rules were found, fetch the mutual relations.
+ return await self.store.get_mutual_event_relations(
+ relation.parent_id, rel_types
+ )
+
+ @measure_func("action_for_event_by_user")
async def action_for_event_by_user(
self, event: EventBase, context: EventContext
) -> None:
@@ -192,6 +256,10 @@ class BulkPushRuleEvaluator:
should increment the unread count, and insert the results into the
event_push_actions_staging table.
"""
+ if event.internal_metadata.is_outlier():
+ # This can happen due to out of band memberships
+ return
+
count_as_unread = _should_count_as_unread(event, context)
rules_by_user = await self._get_rules_for_event(event, context)
@@ -204,11 +272,18 @@ class BulkPushRuleEvaluator:
sender_power_level,
) = await self._get_power_levels_and_sender_level(event, context)
- evaluator = PushRuleEvaluatorForEvent(
- event, len(room_members), sender_power_level, power_levels
+ relations = await self._get_mutual_relations(
+ event, itertools.chain(*rules_by_user.values())
)
- condition_cache: Dict[str, bool] = {}
+ evaluator = PushRuleEvaluatorForEvent(
+ event,
+ len(room_members),
+ sender_power_level,
+ power_levels,
+ relations,
+ self._relations_match_enabled,
+ )
# If the event is not a state event check if any users ignore the sender.
if not event.is_state():
@@ -247,8 +322,8 @@ class BulkPushRuleEvaluator:
if "enabled" in rule and not rule["enabled"]:
continue
- matches = _condition_checker(
- evaluator, rule["conditions"], uid, display_name, condition_cache
+ matches = evaluator.check_conditions(
+ rule["conditions"], uid, display_name
)
if matches:
actions = [x for x in rule["actions"] if x != "dont_notify"]
@@ -267,32 +342,6 @@ class BulkPushRuleEvaluator:
)
-def _condition_checker(
- evaluator: PushRuleEvaluatorForEvent,
- conditions: List[dict],
- uid: str,
- display_name: Optional[str],
- cache: Dict[str, bool],
-) -> bool:
- for cond in conditions:
- _cache_key = cond.get("_cache_key", None)
- if _cache_key:
- res = cache.get(_cache_key, None)
- if res is False:
- return False
- elif res is True:
- continue
-
- res = evaluator.matches(cond, uid, display_name)
- if _cache_key:
- cache[_cache_key] = bool(res)
-
- if not res:
- return False
-
- return True
-
-
MemberMap = Dict[str, Optional[EventIdMembership]]
Rule = Dict[str, dict]
RulesByUser = Dict[str, List[Rule]]
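As a standalone rendition of the pre-filtering step performed by _get_mutual_relations above (assuming rules shaped like the dicts in baserules.py; this mirrors the diff code rather than adding behaviour):

from typing import Any, Dict, Iterable, Set

def interesting_rel_types(rules: Iterable[Dict[str, Any]]) -> Set[str]:
    # Collect the rel_types referenced by enabled MSC3772 conditions, so that
    # only those relation types need to be fetched from the database.
    rel_types: Set[str] = set()
    for rule in rules:
        if "enabled" in rule and not rule["enabled"]:
            continue
        for condition in rule["conditions"]:
            if condition["kind"] != "org.matrix.msc3772.relation_match":
                continue
            rel_type = condition.get("rel_type")
            if rel_type:
                rel_types.add(rel_type)
    return rel_types

# e.g. the MSC3772 thread-reply base rule above yields {"m.thread"}.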
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index 63b22d50..5117ef68 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -48,6 +48,10 @@ def format_push_rules_for_user(
elif pattern_type == "user_localpart":
c["pattern"] = user.localpart
+ sender_type = c.pop("sender_type", None)
+ if sender_type == "user_id":
+ c["sender"] = user.to_string()
+
rulearray = rules["global"][template_name]
template_rule = _rule_to_template(r)
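For illustration, how the new sender_type handling plays out when rules are formatted for a client (a hedged standalone example; the actual conversion happens inside format_push_rules_for_user):

# Server-side stored condition (template form):
stored = {
    "kind": "org.matrix.msc3772.relation_match",
    "rel_type": "m.thread",
    "sender_type": "user_id",
}

# After format_push_rules_for_user runs for @alice:example.com, the client
# sees the concrete sender instead of the template key:
client_facing = {
    "kind": "org.matrix.msc3772.relation_match",
    "rel_type": "m.thread",
    "sender": "@alice:example.com",
}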
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 58183445..e96fb45e 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -65,7 +65,7 @@ class HttpPusher(Pusher):
def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
super().__init__(hs, pusher_config)
- self.storage = self.hs.get_storage()
+ self._storage_controllers = self.hs.get_storage_controllers()
self.app_display_name = pusher_config.app_display_name
self.device_display_name = pusher_config.device_display_name
self.pushkey_ts = pusher_config.ts
@@ -343,7 +343,9 @@ class HttpPusher(Pusher):
}
return d
- ctx = await push_tools.get_context_for_event(self.storage, event, self.user_id)
+ ctx = await push_tools.get_context_for_event(
+ self._storage_controllers, event, self.user_id
+ )
d = {
"notification": {
@@ -405,7 +407,7 @@ class HttpPusher(Pusher):
rejected = []
if "rejected" in resp:
rejected = resp["rejected"]
- else:
+ if not rejected:
self.badge_count_last_call = badge
return rejected
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 5ccdd883..015c19b2 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -114,10 +114,10 @@ class Mailer:
self.send_email_handler = hs.get_send_email_handler()
self.store = self.hs.get_datastores().main
- self.state_store = self.hs.get_storage().state
+ self._state_storage_controller = self.hs.get_storage_controllers().state
self.macaroon_gen = self.hs.get_macaroon_generator()
self.state_handler = self.hs.get_state_handler()
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self.app_name = app_name
self.email_subjects: EmailSubjectConfig = hs.config.email.email_subjects
@@ -255,7 +255,9 @@ class Mailer:
user_display_name = user_id
async def _fetch_room_state(room_id: str) -> None:
- room_state = await self.store.get_current_state_ids(room_id)
+ room_state = await self._state_storage_controller.get_current_state_ids(
+ room_id
+ )
state_by_room[room_id] = room_state
# Run at most 3 of these at once: sync does 10 at a time but email
@@ -456,7 +458,7 @@ class Mailer:
}
the_events = await filter_events_for_client(
- self.storage, user_id, results.events_before
+ self._storage_controllers, user_id, results.events_before
)
the_events.append(notif_event)
@@ -494,7 +496,7 @@ class Mailer:
)
else:
# Attempt to check the historical state for the room.
- historical_state = await self.state_store.get_state_for_event(
+ historical_state = await self._state_storage_controller.get_state_for_event(
event.event_id, StateFilter.from_types((type_state_key,))
)
sender_state_event = historical_state.get(type_state_key)
@@ -767,8 +769,10 @@ class Mailer:
member_event_ids.append(sender_state_event_id)
else:
# Attempt to check the historical state for the room.
- historical_state = await self.state_store.get_state_for_event(
- event_id, StateFilter.from_types((type_state_key,))
+ historical_state = (
+ await self._state_storage_controller.get_state_for_event(
+ event_id, StateFilter.from_types((type_state_key,))
+ )
)
sender_state_event = historical_state.get(type_state_key)
if sender_state_event:
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index f617c759..2e8a017a 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -15,7 +15,7 @@
import logging
import re
-from typing import Any, Dict, List, Mapping, Optional, Pattern, Tuple, Union
+from typing import Any, Dict, List, Mapping, Optional, Pattern, Set, Tuple, Union
from matrix_common.regex import glob_to_regex, to_word_pattern
@@ -120,18 +120,68 @@ class PushRuleEvaluatorForEvent:
room_member_count: int,
sender_power_level: int,
power_levels: Dict[str, Union[int, Dict[str, int]]],
+ relations: Dict[str, Set[Tuple[str, str]]],
+ relations_match_enabled: bool,
):
self._event = event
self._room_member_count = room_member_count
self._sender_power_level = sender_power_level
self._power_levels = power_levels
+ self._relations = relations
+ self._relations_match_enabled = relations_match_enabled
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event)
+ # Maps cache keys to final values.
+ self._condition_cache: Dict[str, bool] = {}
+
+ def check_conditions(
+ self, conditions: List[dict], uid: str, display_name: Optional[str]
+ ) -> bool:
+ """
+ Returns true if a user's conditions/user ID/display name match the event.
+
+ Args:
+ conditions: The user's conditions to match.
+ uid: The user's MXID.
+ display_name: The display name.
+
+ Returns:
+ True if all conditions match the event, False otherwise.
+ """
+ for cond in conditions:
+ _cache_key = cond.get("_cache_key", None)
+ if _cache_key:
+ res = self._condition_cache.get(_cache_key, None)
+ if res is False:
+ return False
+ elif res is True:
+ continue
+
+ res = self.matches(cond, uid, display_name)
+ if _cache_key:
+ self._condition_cache[_cache_key] = bool(res)
+
+ if not res:
+ return False
+
+ return True
+
def matches(
self, condition: Dict[str, Any], user_id: str, display_name: Optional[str]
) -> bool:
+ """
+ Returns true if a user's condition/user ID/display name match the event.
+
+ Args:
+ condition: The user's condition to match.
+ user_id: The user's MXID.
+ display_name: The display name, or None if there is not one.
+
+ Returns:
+ True if the condition matches the event, False otherwise.
+ """
if condition["kind"] == "event_match":
return self._event_match(condition, user_id)
elif condition["kind"] == "contains_display_name":
@@ -142,10 +192,29 @@ class PushRuleEvaluatorForEvent:
return _sender_notification_permission(
self._event, condition, self._sender_power_level, self._power_levels
)
+ elif (
+ condition["kind"] == "org.matrix.msc3772.relation_match"
+ and self._relations_match_enabled
+ ):
+ return self._relation_match(condition, user_id)
else:
+ # XXX This looks incorrect -- we have reached an unknown condition
+ # kind and are unconditionally returning that it matches. Note
+ # that it seems possible to provide a condition to the /pushrules
+ # endpoint with an unknown kind, see _rule_tuple_from_request_object.
return True
def _event_match(self, condition: dict, user_id: str) -> bool:
+ """
+ Check an "event_match" push rule condition.
+
+ Args:
+ condition: The "event_match" push rule condition to match.
+ user_id: The user's MXID.
+
+ Returns:
+ True if the condition matches the event, False otherwise.
+ """
pattern = condition.get("pattern", None)
if not pattern:
@@ -167,13 +236,22 @@ class PushRuleEvaluatorForEvent:
return _glob_matches(pattern, body, word_boundary=True)
else:
- haystack = self._get_value(condition["key"])
+ haystack = self._value_cache.get(condition["key"], None)
if haystack is None:
return False
return _glob_matches(pattern, haystack)
def _contains_display_name(self, display_name: Optional[str]) -> bool:
+ """
+ Check an "event_match" push rule condition.
+
+ Args:
+ display_name: The display name, or None if there is not one.
+
+ Returns:
+ True if the display name is found in the event body, False otherwise.
+ """
if not display_name:
return False
@@ -191,8 +269,40 @@ class PushRuleEvaluatorForEvent:
return bool(r.search(body))
- def _get_value(self, dotted_key: str) -> Optional[str]:
- return self._value_cache.get(dotted_key, None)
+ def _relation_match(self, condition: dict, user_id: str) -> bool:
+ """
+ Check an "relation_match" push rule condition.
+
+ Args:
+ condition: The "event_match" push rule condition to match.
+ user_id: The user's MXID.
+
+ Returns:
+ True if the condition matches the event, False otherwise.
+ """
+ rel_type = condition.get("rel_type")
+ if not rel_type:
+ logger.warning("relation_match condition missing rel_type")
+ return False
+
+ sender_pattern = condition.get("sender")
+ if sender_pattern is None:
+ sender_type = condition.get("sender_type")
+ if sender_type == "user_id":
+ sender_pattern = user_id
+ type_pattern = condition.get("type")
+
+ # If any other relation matches, return True.
+ for sender, event_type in self._relations.get(rel_type, ()):
+ if sender_pattern and not _glob_matches(sender_pattern, sender):
+ continue
+ if type_pattern and not _glob_matches(type_pattern, event_type):
+ continue
+ # All values must have matched.
+ return True
+
+ # No relations matched.
+ return False
# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
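To make the _relation_match semantics above concrete, a hedged example of the inputs involved (the values are illustrative, not taken from this diff):

# Mutual relations fetched by _get_mutual_relations, keyed by rel_type:
relations = {
    "m.thread": {
        ("@alice:example.com", "m.room.message"),
        ("@bob:example.com", "m.room.message"),
    }
}

# An MSC3772 condition with sender_type "user_id" is resolved to the local
# user, so for @alice:example.com the condition below matches because one of
# the mutual m.thread relations was sent by her.
condition = {
    "kind": "org.matrix.msc3772.relation_match",
    "rel_type": "m.thread",
    "sender_type": "user_id",
}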
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index a1bf5b20..8397229c 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -16,7 +16,7 @@ from typing import Dict
from synapse.api.constants import ReceiptTypes
from synapse.events import EventBase
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
-from synapse.storage import Storage
+from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore
@@ -52,7 +52,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
async def get_context_for_event(
- storage: Storage, ev: EventBase, user_id: str
+ storage: StorageControllers, ev: EventBase, user_id: str
) -> Dict[str, str]:
ctx = {}
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 2bd244ed..a4ae4040 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -26,7 +26,8 @@ from twisted.web.server import Request
from synapse.api.errors import HttpResponseException, SynapseError
from synapse.http import RequestTimedOutError
-from synapse.http.server import HttpServer
+from synapse.http.server import HttpServer, is_method_cancellable
+from synapse.http.site import SynapseRequest
from synapse.logging import opentracing
from synapse.logging.opentracing import trace
from synapse.types import JsonDict
@@ -310,6 +311,12 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
url_args = list(self.PATH_ARGS)
method = self.METHOD
+ if self.CACHE and is_method_cancellable(self._handle_request):
+ raise Exception(
+ f"{self.__class__.__name__} has been marked as cancellable, but CACHE "
+ "is set. The cancellable flag would have no effect."
+ )
+
if self.CACHE:
url_args.append("txn_id")
@@ -324,7 +331,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
)
async def _check_auth_and_handle(
- self, request: Request, **kwargs: Any
+ self, request: SynapseRequest, **kwargs: Any
) -> Tuple[int, JsonDict]:
"""Called on new incoming requests when caching is enabled. Checks
if there is a cached response for the request and returns that,
@@ -340,8 +347,18 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
if self.CACHE:
txn_id = kwargs.pop("txn_id")
+ # We ignore the `@cancellable` flag, since cancellation wouldn't interrupt
+ # `_handle_request` and `ResponseCache` does not handle cancellation
+ # correctly yet. In particular, there may be issues to do with logging
+ # context lifetimes.
+
return await self.response_cache.wrap(
txn_id, self._handle_request, request, **kwargs
)
+ # The `@cancellable` decorator may be applied to `_handle_request`. But we
+ # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
+ # so we have to set up the cancellable flag ourselves.
+ request.is_render_cancellable = is_method_cancellable(self._handle_request)
+
return await self._handle_request(request, **kwargs)
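A hedged sketch of a replication endpoint that opts into cancellation under the new checks above. The endpoint itself is hypothetical, and the `@cancellable` decorator is assumed to live alongside `is_method_cancellable` in synapse.http.server; the key point is that CACHE must stay False, since the combination now raises at registration time:

from typing import Any, Tuple

from synapse.http.server import cancellable  # assumed companion to is_method_cancellable
from synapse.http.site import SynapseRequest
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict

class ExampleCancellableEndpoint(ReplicationEndpoint):
    NAME = "example"    # hypothetical endpoint name
    PATH_ARGS = ()
    CACHE = False       # must stay False for the cancellable flag to have effect

    @staticmethod
    async def _serialize_payload(**kwargs: Any) -> JsonDict:
        return {}

    @cancellable
    async def _handle_request(
        self, request: SynapseRequest, **kwargs: Any
    ) -> Tuple[int, JsonDict]:
        return 200, {}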
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 3e7300b4..eed29cd5 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -69,7 +69,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
super().__init__(hs)
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self.clock = hs.get_clock()
self.federation_event_handler = hs.get_federation_event_handler()
@@ -133,7 +133,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
event.internal_metadata.outlier = event_payload["outlier"]
context = EventContext.deserialize(
- self.storage, event_payload["context"]
+ self._storage_controllers, event_payload["context"]
)
event_and_contexts.append((event, context))
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index ce781768..c2b2588e 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -70,7 +70,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
self.event_creation_handler = hs.get_event_creation_handler()
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self.clock = hs.get_clock()
@staticmethod
@@ -127,7 +127,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
event.internal_metadata.outlier = content["outlier"]
requester = Requester.deserialize(self.store, content["requester"])
- context = EventContext.deserialize(self.storage, content["context"])
+ context = EventContext.deserialize(
+ self._storage_controllers, content["context"]
+ )
ratelimit = content["ratelimit"]
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
deleted file mode 100644
index d6f37d74..00000000
--- a/synapse/replication/slave/storage/groups.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import TYPE_CHECKING, Any, Iterable
-
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.replication.tcp.streams import GroupServerStream
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
-from synapse.storage.databases.main.group_server import GroupServerWorkerStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-
-class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
- def __init__(
- self,
- database: DatabasePool,
- db_conn: LoggingDatabaseConnection,
- hs: "HomeServer",
- ):
- super().__init__(database, db_conn, hs)
-
- self.hs = hs
-
- self._group_updates_id_gen = SlavedIdTracker(
- db_conn, "local_group_updates", "stream_id"
- )
- self._group_updates_stream_cache = StreamChangeCache(
- "_group_updates_stream_cache",
- self._group_updates_id_gen.get_current_token(),
- )
-
- def get_group_stream_token(self) -> int:
- return self._group_updates_id_gen.get_current_token()
-
- def process_replication_rows(
- self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
- ) -> None:
- if stream_name == GroupServerStream.NAME:
- self._group_updates_id_gen.advance(instance_name, token)
- for row in rows:
- self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
-
- return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 350762f4..2f592450 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -30,7 +30,6 @@ from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.streams import (
AccountDataStream,
DeviceListsStream,
- GroupServerStream,
PushersStream,
PushRulesStream,
ReceiptsStream,
@@ -43,7 +42,7 @@ from synapse.replication.tcp.streams.events import (
EventsStreamEventRow,
EventsStreamRow,
)
-from synapse.types import PersistedEventPosition, ReadReceipt, UserID
+from synapse.types import PersistedEventPosition, ReadReceipt, StreamKeyType, UserID
from synapse.util.async_helpers import Linearizer, timeout_deferred
from synapse.util.metrics import Measure
@@ -153,19 +152,19 @@ class ReplicationDataHandler:
if stream_name == TypingStream.NAME:
self._typing_handler.process_replication_rows(token, rows)
self.notifier.on_new_event(
- "typing_key", token, rooms=[row.room_id for row in rows]
+ StreamKeyType.TYPING, token, rooms=[row.room_id for row in rows]
)
elif stream_name == PushRulesStream.NAME:
self.notifier.on_new_event(
- "push_rules_key", token, users=[row.user_id for row in rows]
+ StreamKeyType.PUSH_RULES, token, users=[row.user_id for row in rows]
)
elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME):
self.notifier.on_new_event(
- "account_data_key", token, users=[row.user_id for row in rows]
+ StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows]
)
elif stream_name == ReceiptsStream.NAME:
self.notifier.on_new_event(
- "receipt_key", token, rooms=[row.room_id for row in rows]
+ StreamKeyType.RECEIPT, token, rooms=[row.room_id for row in rows]
)
await self._pusher_pool.on_new_receipts(
token, token, {row.room_id for row in rows}
@@ -173,17 +172,17 @@ class ReplicationDataHandler:
elif stream_name == ToDeviceStream.NAME:
entities = [row.entity for row in rows if row.entity.startswith("@")]
if entities:
- self.notifier.on_new_event("to_device_key", token, users=entities)
+ self.notifier.on_new_event(
+ StreamKeyType.TO_DEVICE, token, users=entities
+ )
elif stream_name == DeviceListsStream.NAME:
all_room_ids: Set[str] = set()
for row in rows:
if row.entity.startswith("@"):
room_ids = await self.store.get_rooms_for_user(row.entity)
all_room_ids.update(room_ids)
- self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
- elif stream_name == GroupServerStream.NAME:
self.notifier.on_new_event(
- "groups_key", token, users=[row.user_id for row in rows]
+ StreamKeyType.DEVICE_LIST, token, rooms=all_room_ids
)
elif stream_name == PushersStream.NAME:
for row in rows:
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index fe349481..32f52e54 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -58,6 +58,15 @@ class Command(metaclass=abc.ABCMeta):
# by default, we just use the command name.
return self.NAME
+ def redis_channel_name(self, prefix: str) -> str:
+ """
+ Returns the Redis channel name upon which to publish this command.
+
+ Args:
+ prefix: The prefix for the channel.
+ """
+ return prefix
+
SC = TypeVar("SC", bound="_SimpleCommand")
@@ -395,6 +404,9 @@ class UserIpCommand(Command):
f"{self.user_agent!r}, {self.device_id!r}, {self.last_seen})"
)
+ def redis_channel_name(self, prefix: str) -> str:
+ return f"{prefix}/USER_IP"
+
class RemoteServerUpCommand(_SimpleCommand):
"""Sent when a worker has detected that a remote server is no longer
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 9aba1cd4..e1cbfa50 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -1,5 +1,5 @@
# Copyright 2017 Vector Creations Ltd
-# Copyright 2020 The Matrix.org Foundation C.I.C.
+# Copyright 2020, 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -101,6 +101,9 @@ class ReplicationCommandHandler:
self._instance_id = hs.get_instance_id()
self._instance_name = hs.get_instance_name()
+ # Additional Redis channel suffixes to subscribe to.
+ self._channels_to_subscribe_to: List[str] = []
+
self._is_presence_writer = (
hs.get_instance_name() in hs.config.worker.writers.presence
)
@@ -243,6 +246,31 @@ class ReplicationCommandHandler:
# If we're NOT using Redis, this must be handled by the master
self._should_insert_client_ips = hs.get_instance_name() == "master"
+ if self._is_master or self._should_insert_client_ips:
+ self.subscribe_to_channel("USER_IP")
+
+ def subscribe_to_channel(self, channel_name: str) -> None:
+ """
+ Indicates that we wish to subscribe to a Redis channel by name.
+
+ (The name will later be prefixed with the server name; i.e. subscribing
+ to the 'ABC' channel actually subscribes to 'example.com/ABC' Redis-side.)
+
+ Raises:
+ RuntimeError: if replication has already started, as it is then too late
+ to subscribe to new channels.
+ """
+
+ if self._factory is not None:
+ # We don't allow subscribing after the fact to avoid the chance
+ # of missing an important message because we didn't subscribe in time.
+ raise RuntimeError(
+ "Cannot subscribe to more channels after replication started."
+ )
+
+ if channel_name not in self._channels_to_subscribe_to:
+ self._channels_to_subscribe_to.append(channel_name)
+
def _add_command_to_stream_queue(
self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
) -> None:
@@ -321,7 +349,9 @@ class ReplicationCommandHandler:
# Now create the factory/connection for the subscription stream.
self._factory = RedisDirectTcpReplicationClientFactory(
- hs, outbound_redis_connection
+ hs,
+ outbound_redis_connection,
+ channel_names=self._channels_to_subscribe_to,
)
hs.get_reactor().connectTCP(
hs.config.redis.redis_host,
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 989c5be0..fd1c0ec6 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -14,7 +14,7 @@
import logging
from inspect import isawaitable
-from typing import TYPE_CHECKING, Any, Generic, Optional, Type, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Generic, List, Optional, Type, TypeVar, cast
import attr
import txredisapi
@@ -85,14 +85,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
Attributes:
synapse_handler: The command handler to handle incoming commands.
- synapse_stream_name: The *redis* stream name to subscribe to and publish
+ synapse_stream_prefix: The *redis* stream name to subscribe to and publish
from (not anything to do with Synapse replication streams).
synapse_outbound_redis_connection: The connection to redis to use to send
commands.
"""
synapse_handler: "ReplicationCommandHandler"
- synapse_stream_name: str
+ synapse_stream_prefix: str
+ synapse_channel_names: List[str]
synapse_outbound_redis_connection: txredisapi.ConnectionHandler
def __init__(self, *args: Any, **kwargs: Any):
@@ -117,8 +118,13 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
# it's important to make sure that we only send the REPLICATE command once we
# have successfully subscribed to the stream - otherwise we might miss the
# POSITION response sent back by the other end.
- logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
- await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
+ fully_qualified_stream_names = [
+ f"{self.synapse_stream_prefix}/{stream_suffix}"
+ for stream_suffix in self.synapse_channel_names
+ ] + [self.synapse_stream_prefix]
+ logger.info("Sending redis SUBSCRIBE for %r", fully_qualified_stream_names)
+ await make_deferred_yieldable(self.subscribe(fully_qualified_stream_names))
+
logger.info(
"Successfully subscribed to redis stream, sending REPLICATE command"
)
@@ -215,10 +221,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
# remote instances.
tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
+ channel_name = cmd.redis_channel_name(self.synapse_stream_prefix)
+
await make_deferred_yieldable(
- self.synapse_outbound_redis_connection.publish(
- self.synapse_stream_name, encoded_string
- )
+ self.synapse_outbound_redis_connection.publish(channel_name, encoded_string)
)
@@ -300,20 +306,27 @@ def format_address(address: IAddress) -> str:
class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
"""This is a reconnecting factory that connects to redis and immediately
- subscribes to a stream.
+ subscribes to some streams.
Args:
hs
outbound_redis_connection: A connection to redis that will be used to
send outbound commands (this is separate to the redis connection
used to subscribe).
+ channel_names: A list of channel names to append to the base channel name
+ to additionally subscribe to.
+ e.g. if ['ABC', 'DEF'] is specified then we'll listen to:
+ example.com; example.com/ABC; and example.com/DEF.
"""
maxDelay = 5
protocol = RedisSubscriber
def __init__(
- self, hs: "HomeServer", outbound_redis_connection: txredisapi.ConnectionHandler
+ self,
+ hs: "HomeServer",
+ outbound_redis_connection: txredisapi.ConnectionHandler,
+ channel_names: List[str],
):
super().__init__(
@@ -326,7 +339,8 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
)
self.synapse_handler = hs.get_replication_command_handler()
- self.synapse_stream_name = hs.hostname
+ self.synapse_stream_prefix = hs.hostname
+ self.synapse_channel_names = channel_names
self.synapse_outbound_redis_connection = outbound_redis_connection
@@ -340,7 +354,8 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
# protocol.
p.synapse_handler = self.synapse_handler
p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
- p.synapse_stream_name = self.synapse_stream_name
+ p.synapse_stream_prefix = self.synapse_stream_prefix
+ p.synapse_channel_names = self.synapse_channel_names
return p
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index f41eabd8..b1cd55bf 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -29,7 +29,6 @@ from synapse.replication.tcp.streams._base import (
BackfillStream,
CachesStream,
DeviceListsStream,
- GroupServerStream,
PresenceFederationStream,
PresenceStream,
PushersStream,
@@ -61,7 +60,6 @@ STREAMS_MAP = {
FederationStream,
TagAccountDataStream,
AccountDataStream,
- GroupServerStream,
UserSignatureStream,
)
}
@@ -81,6 +79,5 @@ __all__ = [
"ToDeviceStream",
"TagAccountDataStream",
"AccountDataStream",
- "GroupServerStream",
"UserSignatureStream",
]
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 495f2f02..398bebea 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -585,26 +585,6 @@ class AccountDataStream(Stream):
return updates, to_token, limited
-class GroupServerStream(Stream):
- @attr.s(slots=True, frozen=True, auto_attribs=True)
- class GroupsStreamRow:
- group_id: str
- user_id: str
- type: str
- content: JsonDict
-
- NAME = "groups"
- ROW_TYPE = GroupsStreamRow
-
- def __init__(self, hs: "HomeServer"):
- store = hs.get_datastores().main
- super().__init__(
- hs.get_instance_name(),
- current_token_without_instance(store.get_group_stream_token),
- store.get_all_groups_changes,
- )
-
-
class UserSignatureStream(Stream):
"""A user has signed their own device with their user-signing key"""
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 57c4773e..b7122151 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -26,7 +26,6 @@ from synapse.rest.client import (
directory,
events,
filter,
- groups,
initial_sync,
keys,
knock,
@@ -118,8 +117,6 @@ class ClientRestResource(JsonResource):
thirdparty.register_servlets(hs, client_resource)
sendtodevice.register_servlets(hs, client_resource)
user_directory.register_servlets(hs, client_resource)
- if hs.config.experimental.groups_enabled:
- groups.register_servlets(hs, client_resource)
room_upgrade_rest_servlet.register_servlets(hs, client_resource)
room_batch.register_servlets(hs, client_resource)
capabilities.register_servlets(hs, client_resource)
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index cb4d55c8..1aa08f8d 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -47,7 +47,6 @@ from synapse.rest.admin.federation import (
DestinationRestServlet,
ListDestinationsRestServlet,
)
-from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.registration_tokens import (
ListRegistrationTokensRestServlet,
@@ -293,8 +292,6 @@ def register_servlets_for_client_rest_resource(
ResetPasswordRestServlet(hs).register(http_server)
SearchUsersRestServlet(hs).register(http_server)
UserRegisterServlet(hs).register(http_server)
- if hs.config.experimental.groups_enabled:
- DeleteGroupAdminRestServlet(hs).register(http_server)
AccountValidityRenewServlet(hs).register(http_server)
# Load the media repo ones if we're using them. Otherwise load the servlets which
diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py
deleted file mode 100644
index cd697e18..00000000
--- a/synapse/rest/admin/groups.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import logging
-from http import HTTPStatus
-from typing import TYPE_CHECKING, Tuple
-
-from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet
-from synapse.http.site import SynapseRequest
-from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
-from synapse.types import JsonDict
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-class DeleteGroupAdminRestServlet(RestServlet):
- """Allows deleting of local groups"""
-
- PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)$")
-
- def __init__(self, hs: "HomeServer"):
- self.group_server = hs.get_groups_server_handler()
- self.is_mine_id = hs.is_mine_id
- self.auth = hs.get_auth()
-
- async def on_POST(
- self, request: SynapseRequest, group_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
-
- if not self.is_mine_id(group_id):
- raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local groups")
-
- await self.group_server.delete_group(group_id, requester.user.to_string())
- return HTTPStatus.OK, {}
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 8ca57bdb..19d4a008 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -83,7 +83,7 @@ class QuarantineMediaByUser(RestServlet):
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
- logging.info("Quarantining local media by user: %s", user_id)
+ logging.info("Quarantining media by user: %s", user_id)
# Quarantine all media this user has uploaded
num_quarantined = await self.store.quarantine_media_ids_by_user(
@@ -112,7 +112,7 @@ class QuarantineMediaByID(RestServlet):
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
- logging.info("Quarantining local media by ID: %s/%s", server_name, media_id)
+ logging.info("Quarantining media by ID: %s/%s", server_name, media_id)
# Quarantine this media id
await self.store.quarantine_media_by_id(
@@ -140,9 +140,7 @@ class UnquarantineMediaByID(RestServlet):
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
- logging.info(
- "Remove from quarantine local media by ID: %s/%s", server_name, media_id
- )
+ logging.info("Remove from quarantine media by ID: %s/%s", server_name, media_id)
# Remove from quarantine this media id
await self.store.quarantine_media_by_id(server_name, media_id, None)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 356d6f74..9d953d58 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -34,6 +34,7 @@ from synapse.rest.admin._base import (
assert_user_is_admin,
)
from synapse.storage.databases.main.room import RoomSortOrder
+from synapse.storage.state import StateFilter
from synapse.types import JsonDict, RoomID, UserID, create_requester
from synapse.util import json_decoder
@@ -418,6 +419,7 @@ class RoomStateRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer()
@@ -430,7 +432,7 @@ class RoomStateRestServlet(RestServlet):
if not ret:
raise NotFoundError("Room not found")
- event_ids = await self.store.get_current_state_ids(room_id)
+ event_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
events = await self.store.get_events(event_ids.values())
now = self.clock.time_msec()
room_state = self._event_serializer.serialize_events(events.values(), now)
@@ -447,7 +449,8 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
super().__init__(hs)
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
- self.state_handler = hs.get_state_handler()
+ self.store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self.is_mine = hs.is_mine
async def on_POST(
@@ -489,8 +492,11 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
)
# send invite if room has "JoinRules.INVITE"
- room_state = await self.state_handler.get_current_state(room_id)
- join_rules_event = room_state.get((EventTypes.JoinRules, ""))
+ join_rules_event = (
+ await self._storage_controllers.state.get_current_state_event(
+ room_id, EventTypes.JoinRules, ""
+ )
+ )
if join_rules_event:
if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
# update_membership with an action of "invite" can raise a
@@ -535,6 +541,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
super().__init__(hs)
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
+ self._state_storage_controller = hs.get_storage_controllers().state
self.event_creation_handler = hs.get_event_creation_handler()
self.state_handler = hs.get_state_handler()
self.is_mine_id = hs.is_mine_id
@@ -552,12 +559,22 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
user_to_add = content.get("user_id", requester.user.to_string())
# Figure out which local users currently have power in the room, if any.
- room_state = await self.state_handler.get_current_state(room_id)
- if not room_state:
+ filtered_room_state = await self._state_storage_controller.get_current_state(
+ room_id,
+ StateFilter.from_types(
+ [
+ (EventTypes.Create, ""),
+ (EventTypes.PowerLevels, ""),
+ (EventTypes.JoinRules, ""),
+ (EventTypes.Member, user_to_add),
+ ]
+ ),
+ )
+ if not filtered_room_state:
raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room")
- create_event = room_state[(EventTypes.Create, "")]
- power_levels = room_state.get((EventTypes.PowerLevels, ""))
+ create_event = filtered_room_state[(EventTypes.Create, "")]
+ power_levels = filtered_room_state.get((EventTypes.PowerLevels, ""))
if power_levels is not None:
# We pick the local user with the highest power.
@@ -633,7 +650,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
# Now we check if the user we're granting admin rights to is already in
# the room. If not and it's not a public room we invite them.
- member_event = room_state.get((EventTypes.Member, user_to_add))
+ member_event = filtered_room_state.get((EventTypes.Member, user_to_add))
is_joined = False
if member_event:
is_joined = member_event.content["membership"] in (
@@ -644,7 +661,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
if is_joined:
return HTTPStatus.OK, {}
- join_rules = room_state.get((EventTypes.JoinRules, ""))
+ join_rules = filtered_room_state.get((EventTypes.JoinRules, ""))
is_public = False
if join_rules:
is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC
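
The MakeRoomAdminRestServlet hunk above replaces a full current-state fetch with a filtered one. A minimal sketch of that pattern, reusing only the API surface visible in this diff (get_storage_controllers().state.get_current_state with StateFilter.from_types); the helper name is illustrative:

    from synapse.api.constants import EventTypes
    from synapse.storage.state import StateFilter

    async def fetch_admin_relevant_state(hs, room_id: str):
        # Only pull the handful of state events the admin endpoint actually
        # needs, rather than the room's entire current-state map.
        state_filter = StateFilter.from_types(
            [(EventTypes.Create, ""), (EventTypes.PowerLevels, "")]
        )
        state = await hs.get_storage_controllers().state.get_current_state(
            room_id, state_filter
        )
        return state.get((EventTypes.Create, "")), state.get(
            (EventTypes.PowerLevels, "")
        )
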
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 8e29ada8..f0614a28 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -226,6 +226,13 @@ class UserRestServletV2(RestServlet):
if not isinstance(password, str) or len(password) > 512:
raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password")
+ logout_devices = body.get("logout_devices", True)
+ if not isinstance(logout_devices, bool):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "'logout_devices' parameter is not of type boolean",
+ )
+
deactivate = body.get("deactivated", False)
if not isinstance(deactivate, bool):
raise SynapseError(
@@ -305,7 +312,6 @@ class UserRestServletV2(RestServlet):
await self.store.set_server_admin(target_user, set_admin_to)
if password is not None:
- logout_devices = True
new_password_hash = await self.auth_handler.hash(password)
await self.set_password_handler.set_password(
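
For context on the new logout_devices validation above, a sketch of an admin API call that resets a password without logging out the user's other devices; the hostname, user ID and token are placeholders, and the endpoint shape assumed here is the existing v2 user admin API served by UserRestServletV2:

    import requests

    resp = requests.put(
        "https://homeserver.example/_synapse/admin/v2/users/@alice:example.com",
        headers={"Authorization": "Bearer <admin_access_token>"},
        json={
            "password": "correct horse battery staple",
            # Must be a boolean; when omitted it defaults to True (devices are logged out).
            "logout_devices": False,
        },
    )
    resp.raise_for_status()
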
diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py
deleted file mode 100644
index 7e1149c7..00000000
--- a/synapse/rest/client/groups.py
+++ /dev/null
@@ -1,962 +0,0 @@
-# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from functools import wraps
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple
-
-from twisted.web.server import Request
-
-from synapse.api.constants import (
- MAX_GROUP_CATEGORYID_LENGTH,
- MAX_GROUP_ROLEID_LENGTH,
- MAX_GROUPID_LENGTH,
-)
-from synapse.api.errors import Codes, SynapseError
-from synapse.handlers.groups_local import GroupsLocalHandler
-from synapse.http.server import HttpServer
-from synapse.http.servlet import (
- RestServlet,
- assert_params_in_dict,
- parse_json_object_from_request,
-)
-from synapse.http.site import SynapseRequest
-from synapse.types import GroupID, JsonDict
-
-from ._base import client_patterns
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-def _validate_group_id(
- f: Callable[..., Awaitable[Tuple[int, JsonDict]]]
-) -> Callable[..., Awaitable[Tuple[int, JsonDict]]]:
- """Wrapper to validate the form of the group ID.
-
- Can be applied to any on_FOO methods that accepts a group ID as a URL parameter.
- """
-
- @wraps(f)
- def wrapper(
- self: RestServlet, request: Request, group_id: str, *args: Any, **kwargs: Any
- ) -> Awaitable[Tuple[int, JsonDict]]:
- if not GroupID.is_valid(group_id):
- raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
-
- return f(self, request, group_id, *args, **kwargs)
-
- return wrapper
-
-
-class GroupServlet(RestServlet):
- """Get the group profile"""
-
- PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.groups_handler = hs.get_groups_local_handler()
-
- @_validate_group_id
- async def on_GET(
- self, request: SynapseRequest, group_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request, allow_guest=True)
- requester_user_id = requester.user.to_string()
-
- group_description = await self.groups_handler.get_group_profile(
- group_id, requester_user_id
- )
-
- return 200, group_description
-
- @_validate_group_id
- async def on_POST(
- self, request: SynapseRequest, group_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- requester_user_id = requester.user.to_string()
-
- content = parse_json_object_from_request(request)
- assert_params_in_dict(
- content, ("name", "avatar_url", "short_description", "long_description")
- )
- assert isinstance(
- self.groups_handler, GroupsLocalHandler
- ), "Workers cannot create group profiles."
- await self.groups_handler.update_group_profile(
- group_id, requester_user_id, content
- )
-
- return 200, {}
-
-
-class GroupSummaryServlet(RestServlet):
- """Get the full group summary"""
-
- PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.groups_handler = hs.get_groups_local_handler()
-
- @_validate_group_id
- async def on_GET(
- self, request: SynapseRequest, group_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request, allow_guest=True)
- requester_user_id = requester.user.to_string()
-
- get_group_summary = await self.groups_handler.get_group_summary(
- group_id, requester_user_id
- )
-
- return 200, get_group_summary
-
-
-class GroupSummaryRoomsCatServlet(RestServlet):
- """Update/delete a rooms entry in the summary.
-
- Matches both:
- - /groups/:group/summary/rooms/:room_id
- - /groups/:group/summary/categories/:category/rooms/:room_id
- """
-
- PATTERNS = client_patterns(
- "/groups/(?P<group_id>[^/]*)/summary"
- "(/categories/(?P<category_id>[^/]+))?"
- "/rooms/(?P<room_id>[^/]*)$"
- )
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.groups_handler = hs.get_groups_local_handler()
-
- @_validate_group_id
- async def on_PUT(
- self,
- request: SynapseRequest,
- group_id: str,
- category_id: Optional[str],
- room_id: str,
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- requester_user_id = requester.user.to_string()
-
- if category_id == "":
- raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM)
-
- if category_id and len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
- raise SynapseError(
- 400,
- "category_id may not be longer than %s characters"
- % (MAX_GROUP_CATEGORYID_LENGTH,),
- Codes.INVALID_PARAM,
- )
-
- content = parse_json_object_from_request(request)
- assert isinstance(
- self.groups_handler, GroupsLocalHandler
- ), "Workers cannot modify group summaries."
- resp = await self.groups_handler.update_group_summary_room(
- group_id,
- requester_user_id,
- room_id=room_id,
- category_id=category_id,
- content=content,
- )
-
- return 200, resp
-
- @_validate_group_id
- async def on_DELETE(
- self, request: SynapseRequest, group_id: str, category_id: str, room_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- requester_user_id = requester.user.to_string()
-
- assert isinstance(
- self.groups_handler, GroupsLocalHandler
- ), "Workers cannot modify group profiles."
- resp = await self.groups_handler.delete_group_summary_room(
- group_id, requester_user_id, room_id=room_id, category_id=category_id
- )
-
- return 200, resp
-
-
-class GroupCategoryServlet(RestServlet):
- """Get/add/update/delete a group category"""
-
- PATTERNS = client_patterns(
- "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
- )
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.groups_handler = hs.get_groups_local_handler()
-
- @_validate_group_id
- async def on_GET(
- self, request: SynapseRequest, group_id: str, category_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request, allow_guest=True)
- requester_user_id = requester.user.to_string()
-
- category = await self.groups_handler.get_group_category(
- group_id, requester_user_id, category_id=category_id
- )
-
- return 200, category
-
- @_validate_group_id
- async def on_PUT(
- self, request: SynapseRequest, group_id: str, category_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- requester_user_id = requester.user.to_string()
-
- if not category_id:
- raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM)
-
- if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
- raise SynapseError(
- 400,
- "category_id may not be longer than %s characters"
- % (MAX_GROUP_CATEGORYID_LENGTH,),
- Codes.INVALID_PARAM,
- )
-
- content = parse_json_object_from_request(request)
- assert isinstance(
- self.groups_handler, GroupsLocalHandler
- ), "Workers cannot modify group categories."
- resp = await self.groups_handler.update_group_category(
- group_id, requester_user_id, category_id=category_id, content=content
- )
-
- return 200, resp
-
- @_validate_group_id
- async def on_DELETE(
- self, request: SynapseRequest, group_id: str, category_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- requester_user_id = requester.user.to_string()
-
- assert isinstance(
- self.groups_handler, GroupsLocalHandler
- ), "Workers cannot modify group categories."
- resp = await self.groups_handler.delete_group_category(
- group_id, requester_user_id, category_id=category_id
- )
-
- return 200, resp
-
-
-class GroupCategoriesServlet(RestServlet):
- """Get all group categories"""
-
- PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$")
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.groups_handler = hs.get_groups_local_handler()
-
- @_validate_group_id
- async def on_GET(
- self, request: SynapseRequest, group_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request, allow_guest=True)
- requester_user_id = requester.user.to_string()
-
- category = await self.groups_handler.get_group_categories(
- group_id, requester_user_id
- )
-
- return 200, category
-
-
-class GroupRoleServlet(RestServlet):
- """Get/add/update/delete a group role"""
-
- PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$")
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.groups_handler = hs.get_groups_local_handler()
-
- @_validate_group_id
- async def on_GET(
- self, request: SynapseRequest, group_id: str, role_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request, allow_guest=True)
- requester_user_id = requester.user.to_string()
-
- category = await self.groups_handler.get_group_role(
- group_id, requester_user_id, role_id=role_id
- )
-
- return 200, category
-
- @_validate_group_id
- async def on_PUT(
- self, request: SynapseRequest, group_id: str, role_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- requester_user_id = requester.user.to_string()
-
- if not role_id:
- raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM)
-
- if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
- raise SynapseError(
- 400,
- "role_id may not be longer than %s characters"
- % (MAX_GROUP_ROLEID_LENGTH,),
- Codes.INVALID_PARAM,
- )
-
- content = parse_json_object_from_request(request)
- assert isinstance(
- self.groups_handler, GroupsLocalHandler
- ), "Workers cannot modify group roles."
- resp = await self.groups_handler.update_group_role(
- group_id, requester_user_id, role_id=role_id, content=content
- )
-
- return 200, resp
-
- @_validate_group_id
- async def on_DELETE(
- self, request: SynapseRequest, group_id: str, role_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- requester_user_id = requester.user.to_string()
-
- assert isinstance(
- self.groups_handler, GroupsLocalHandler
- ), "Workers cannot modify group roles."
- resp = await self.groups_handler.delete_group_role(
- group_id, requester_user_id, role_id=role_id
- )
-
- return 200, resp
-
-
-class GroupRolesServlet(RestServlet):
- """Get all group roles"""
-
- PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$")
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.groups_handler = hs.get_groups_local_handler()
-
- @_validate_group_id
- async def on_GET(
- self, request: SynapseRequest, group_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request, allow_guest=True)
- requester_user_id = requester.user.to_string()
-
- category = await self.groups_handler.get_group_roles(
- group_id, requester_user_id
- )
-
- return 200, category
-
-
-class GroupSummaryUsersRoleServlet(RestServlet):
- """Update/delete a user's entry in the summary.
-
- Matches both:
- - /groups/:group/summary/users/:user_id
- - /groups/:group/summary/roles/:role/users/:user_id
- """
-
- PATTERNS = client_patterns(
- "/groups/(?P<group_id>[^/]*)/summary"
- "(/roles/(?P<role_id>[^/]+))?"
- "/users/(?P<user_id>[^/]*)$"
- )
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.groups_handler = hs.get_groups_local_handler()
-
- @_validate_group_id
- async def on_PUT(
- self,
- request: SynapseRequest,
- group_id: str,
- role_id: Optional[str],
- user_id: str,
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- requester_user_id = requester.user.to_string()
-
- if role_id == "":
- raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM)
-
- if role_id and len(role_id) > MAX_GROUP_ROLEID_LENGTH:
- raise SynapseError(
- 400,
- "role_id may not be longer than %s characters"
- % (MAX_GROUP_ROLEID_LENGTH,),
- Codes.INVALID_PARAM,
- )
-
- content = parse_json_object_from_request(request)
- assert isinstance(
- self.groups_handler, GroupsLocalHandler
- ), "Workers cannot modify group summaries."
- resp = await self.groups_handler.update_group_summary_user(
- group_id,
- requester_user_id,
- user_id=user_id,
- role_id=role_id,
- content=content,
- )
-
- return 200, resp
-
- @_validate_group_id
- async def on_DELETE(
- self, request: SynapseRequest, group_id: str, role_id: str, user_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- requester_user_id = requester.user.to_string()
-
- assert isinstance(
- self.groups_handler, GroupsLocalHandler
- ), "Workers cannot modify group summaries."
- resp = await self.groups_handler.delete_group_summary_user(
- group_id, requester_user_id, user_id=user_id, role_id=role_id
- )
-
- return 200, resp
-
-
-class GroupRoomServlet(RestServlet):
- """Get all rooms in a group"""
-
- PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.groups_handler = hs.get_groups_local_handler()
-
- @_validate_group_id
- async def on_GET(
- self, request: SynapseRequest, group_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request, allow_guest=True)
- requester_user_id = requester.user.to_string()
-
- result = await self.groups_handler.get_rooms_in_group(
- group_id, requester_user_id
- )
-
- return 200, result
-
-
-class GroupUsersServlet(RestServlet):
- """Get all users in a group"""
-
- PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.groups_handler = hs.get_groups_local_handler()
-
- @_validate_group_id
- async def on_GET(
- self, request: SynapseRequest, group_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request, allow_guest=True)
- requester_user_id = requester.user.to_string()
-
- result = await self.groups_handler.get_users_in_group(
- group_id, requester_user_id
- )
-
- return 200, result
-
-
-class GroupInvitedUsersServlet(RestServlet):
- """Get users invited to a group"""
-
- PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.groups_handler = hs.get_groups_local_handler()
-
- @_validate_group_id
- async def on_GET(
- self, request: SynapseRequest, group_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- requester_user_id = requester.user.to_string()
-
- result = await self.groups_handler.get_invited_users_in_group(
- group_id, requester_user_id
- )
-
- return 200, result
-
-
-class GroupSettingJoinPolicyServlet(RestServlet):
- """Set group join policy"""
-
- PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.groups_handler = hs.get_groups_local_handler()
-
- @_validate_group_id
- async def on_PUT(
- self, request: SynapseRequest, group_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- requester_user_id = requester.user.to_string()
-
- content = parse_json_object_from_request(request)
-
- assert isinstance(
- self.groups_handler, GroupsLocalHandler
- ), "Workers cannot modify group join policy."
- result = await self.groups_handler.set_group_join_policy(
- group_id, requester_user_id, content
- )
-
- return 200, result
-
-
-class GroupCreateServlet(RestServlet):
- """Create a group"""
-
- PATTERNS = client_patterns("/create_group$")
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.groups_handler = hs.get_groups_local_handler()
- self.server_name = hs.hostname
-
- async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- requester_user_id = requester.user.to_string()
-
- # TODO: Create group on remote server
- content = parse_json_object_from_request(request)
- localpart = content.pop("localpart")
- group_id = GroupID(localpart, self.server_name).to_string()
-
- if not localpart:
- raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM)
-
- if len(group_id) > MAX_GROUPID_LENGTH:
- raise SynapseError(
- 400,
- "Group ID may not be longer than %s characters" % (MAX_GROUPID_LENGTH,),
- Codes.INVALID_PARAM,
- )
-
- assert isinstance(
- self.groups_handler, GroupsLocalHandler
- ), "Workers cannot create groups."
- result = await self.groups_handler.create_group(
- group_id, requester_user_id, content
- )
-
- return 200, result
-
-
-class GroupAdminRoomsServlet(RestServlet):
- """Add a room to the group"""
-
- PATTERNS = client_patterns(
- "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
- )
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.groups_handler = hs.get_groups_local_handler()
-
- @_validate_group_id
- async def on_PUT(
- self, request: SynapseRequest, group_id: str, room_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- requester_user_id = requester.user.to_string()
-
- content = parse_json_object_from_request(request)
- assert isinstance(
- self.groups_handler, GroupsLocalHandler
- ), "Workers cannot modify rooms in a group."
- result = await self.groups_handler.add_room_to_group(
- group_id, requester_user_id, room_id, content
- )
-
- return 200, result
-
- @_validate_group_id
- async def on_DELETE(
- self, request: SynapseRequest, group_id: str, room_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- requester_user_id = requester.user.to_string()
-
- assert isinstance(
- self.groups_handler, GroupsLocalHandler
- ), "Workers cannot modify group categories."
- result = await self.groups_handler.remove_room_from_group(
- group_id, requester_user_id, room_id
- )
-
- return 200, result
-
-
-class GroupAdminRoomsConfigServlet(RestServlet):
- """Update the config of a room in a group"""
-
- PATTERNS = client_patterns(
- "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
- "/config/(?P<config_key>[^/]*)$"
- )
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.groups_handler = hs.get_groups_local_handler()
-
- @_validate_group_id
- async def on_PUT(
- self, request: SynapseRequest, group_id: str, room_id: str, config_key: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- requester_user_id = requester.user.to_string()
-
- content = parse_json_object_from_request(request)
- assert isinstance(
- self.groups_handler, GroupsLocalHandler
- ), "Workers cannot modify group categories."
- result = await self.groups_handler.update_room_in_group(
- group_id, requester_user_id, room_id, config_key, content
- )
-
- return 200, result
-
-
-class GroupAdminUsersInviteServlet(RestServlet):
- """Invite a user to the group"""
-
- PATTERNS = client_patterns(
- "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
- )
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.groups_handler = hs.get_groups_local_handler()
- self.store = hs.get_datastores().main
- self.is_mine_id = hs.is_mine_id
-
- @_validate_group_id
- async def on_PUT(
- self, request: SynapseRequest, group_id: str, user_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- requester_user_id = requester.user.to_string()
-
- content = parse_json_object_from_request(request)
- config = content.get("config", {})
- assert isinstance(
- self.groups_handler, GroupsLocalHandler
- ), "Workers cannot invite users to a group."
- result = await self.groups_handler.invite(
- group_id, user_id, requester_user_id, config
- )
-
- return 200, result
-
-
-class GroupAdminUsersKickServlet(RestServlet):
- """Kick a user from the group"""
-
- PATTERNS = client_patterns(
- "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
- )
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.groups_handler = hs.get_groups_local_handler()
-
- @_validate_group_id
- async def on_PUT(
- self, request: SynapseRequest, group_id: str, user_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- requester_user_id = requester.user.to_string()
-
- content = parse_json_object_from_request(request)
- assert isinstance(
- self.groups_handler, GroupsLocalHandler
- ), "Workers cannot kick users from a group."
- result = await self.groups_handler.remove_user_from_group(
- group_id, user_id, requester_user_id, content
- )
-
- return 200, result
-
-
-class GroupSelfLeaveServlet(RestServlet):
- """Leave a joined group"""
-
- PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$")
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.groups_handler = hs.get_groups_local_handler()
-
- @_validate_group_id
- async def on_PUT(
- self, request: SynapseRequest, group_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- requester_user_id = requester.user.to_string()
-
- content = parse_json_object_from_request(request)
- assert isinstance(
- self.groups_handler, GroupsLocalHandler
- ), "Workers cannot leave a group for a users."
- result = await self.groups_handler.remove_user_from_group(
- group_id, requester_user_id, requester_user_id, content
- )
-
- return 200, result
-
-
-class GroupSelfJoinServlet(RestServlet):
- """Attempt to join a group, or knock"""
-
- PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$")
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.groups_handler = hs.get_groups_local_handler()
-
- @_validate_group_id
- async def on_PUT(
- self, request: SynapseRequest, group_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- requester_user_id = requester.user.to_string()
-
- content = parse_json_object_from_request(request)
- assert isinstance(
- self.groups_handler, GroupsLocalHandler
- ), "Workers cannot join a user to a group."
- result = await self.groups_handler.join_group(
- group_id, requester_user_id, content
- )
-
- return 200, result
-
-
-class GroupSelfAcceptInviteServlet(RestServlet):
- """Accept a group invite"""
-
- PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$")
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.groups_handler = hs.get_groups_local_handler()
-
- @_validate_group_id
- async def on_PUT(
- self, request: SynapseRequest, group_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- requester_user_id = requester.user.to_string()
-
- content = parse_json_object_from_request(request)
- assert isinstance(
- self.groups_handler, GroupsLocalHandler
- ), "Workers cannot accept an invite to a group."
- result = await self.groups_handler.accept_invite(
- group_id, requester_user_id, content
- )
-
- return 200, result
-
-
-class GroupSelfUpdatePublicityServlet(RestServlet):
- """Update whether we publicise a users membership of a group"""
-
- PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$")
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.store = hs.get_datastores().main
-
- @_validate_group_id
- async def on_PUT(
- self, request: SynapseRequest, group_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- requester_user_id = requester.user.to_string()
-
- content = parse_json_object_from_request(request)
- publicise = content["publicise"]
- await self.store.update_group_publicity(group_id, requester_user_id, publicise)
-
- return 200, {}
-
-
-class PublicisedGroupsForUserServlet(RestServlet):
- """Get the list of groups a user is advertising"""
-
- PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$")
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.store = hs.get_datastores().main
- self.groups_handler = hs.get_groups_local_handler()
-
- async def on_GET(
- self, request: SynapseRequest, user_id: str
- ) -> Tuple[int, JsonDict]:
- await self.auth.get_user_by_req(request, allow_guest=True)
-
- result = await self.groups_handler.get_publicised_groups_for_user(user_id)
-
- return 200, result
-
-
-class PublicisedGroupsForUsersServlet(RestServlet):
- """Get the list of groups a user is advertising"""
-
- PATTERNS = client_patterns("/publicised_groups$")
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.store = hs.get_datastores().main
- self.groups_handler = hs.get_groups_local_handler()
-
- async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- await self.auth.get_user_by_req(request, allow_guest=True)
-
- content = parse_json_object_from_request(request)
- user_ids = content["user_ids"]
-
- result = await self.groups_handler.bulk_get_publicised_groups(user_ids)
-
- return 200, result
-
-
-class GroupsForUserServlet(RestServlet):
- """Get all groups the logged in user is joined to"""
-
- PATTERNS = client_patterns("/joined_groups$")
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.groups_handler = hs.get_groups_local_handler()
-
- async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request, allow_guest=True)
- requester_user_id = requester.user.to_string()
-
- result = await self.groups_handler.get_joined_groups(requester_user_id)
-
- return 200, result
-
-
-def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
- GroupServlet(hs).register(http_server)
- GroupSummaryServlet(hs).register(http_server)
- GroupInvitedUsersServlet(hs).register(http_server)
- GroupUsersServlet(hs).register(http_server)
- GroupRoomServlet(hs).register(http_server)
- GroupSettingJoinPolicyServlet(hs).register(http_server)
- GroupCreateServlet(hs).register(http_server)
- GroupAdminRoomsServlet(hs).register(http_server)
- GroupAdminRoomsConfigServlet(hs).register(http_server)
- GroupAdminUsersInviteServlet(hs).register(http_server)
- GroupAdminUsersKickServlet(hs).register(http_server)
- GroupSelfLeaveServlet(hs).register(http_server)
- GroupSelfJoinServlet(hs).register(http_server)
- GroupSelfAcceptInviteServlet(hs).register(http_server)
- GroupsForUserServlet(hs).register(http_server)
- GroupCategoryServlet(hs).register(http_server)
- GroupCategoriesServlet(hs).register(http_server)
- GroupSummaryRoomsCatServlet(hs).register(http_server)
- GroupRoleServlet(hs).register(http_server)
- GroupRolesServlet(hs).register(http_server)
- GroupSelfUpdatePublicityServlet(hs).register(http_server)
- GroupSummaryUsersRoleServlet(hs).register(http_server)
- PublicisedGroupsForUserServlet(hs).register(http_server)
- PublicisedGroupsForUsersServlet(hs).register(http_server)
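
With the client groups servlets deleted above, nothing registers the /groups/... and /create_group routes any more. An illustrative probe of that; the hostname, group ID and token are placeholders, and the exact error returned by the unrecognised-request fallback is deliberately not asserted:

    import requests

    resp = requests.get(
        "https://homeserver.example/_matrix/client/r0/groups/+example:example.org/profile",
        headers={"Authorization": "Bearer <access_token>"},
    )
    # The request is no longer handled by a groups servlet.
    print(resp.status_code, resp.json().get("errcode"))
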
diff --git a/synapse/rest/client/mutual_rooms.py b/synapse/rest/client/mutual_rooms.py
index 27bfaf0b..38ef4e45 100644
--- a/synapse/rest/client/mutual_rooms.py
+++ b/synapse/rest/client/mutual_rooms.py
@@ -42,21 +42,10 @@ class UserMutualRoomsServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
- self.user_directory_search_enabled = (
- hs.config.userdirectory.user_directory_search_enabled
- )
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
-
- if not self.user_directory_search_enabled:
- raise SynapseError(
- code=400,
- msg="User directory searching is disabled. Cannot determine shared rooms.",
- errcode=Codes.UNKNOWN,
- )
-
UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)
@@ -67,8 +56,8 @@ class UserMutualRoomsServlet(RestServlet):
errcode=Codes.FORBIDDEN,
)
- rooms = await self.store.get_mutual_rooms_for_users(
- requester.user.to_string(), user_id
+ rooms = await self.store.get_mutual_rooms_between_users(
+ frozenset((requester.user.to_string(), user_id))
)
return 200, {"joined": list(rooms)}
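
The switch above to get_mutual_rooms_between_users taking a frozenset makes the user pair order-insensitive, which suits lookups keyed on the arguments. A standalone sketch of that idea (not Synapse's actual storage code; the query helper is a stub):

    from typing import Dict, FrozenSet, Set

    _mutual_rooms_cache: Dict[FrozenSet[str], Set[str]] = {}

    async def _query_mutual_rooms(user_ids: FrozenSet[str]) -> Set[str]:
        # Stand-in for the real database query.
        return set()

    async def get_mutual_rooms_between_users(user_ids: FrozenSet[str]) -> Set[str]:
        # frozenset((a, b)) == frozenset((b, a)), so both argument orders map
        # to the same cache entry and the same underlying query.
        if user_ids not in _mutual_rooms_cache:
            _mutual_rooms_cache[user_ids] = await _query_mutual_rooms(user_ids)
        return _mutual_rooms_cache[user_ids]
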
diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py
index b98640b1..8191b4e3 100644
--- a/synapse/rest/client/push_rule.py
+++ b/synapse/rest/client/push_rule.py
@@ -148,9 +148,9 @@ class PushRuleRestServlet(RestServlet):
# we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference
- rules = await self.store.get_push_rules_for_user(user_id)
+ rules_raw = await self.store.get_push_rules_for_user(user_id)
- rules = format_push_rules_for_user(requester.user, rules)
+ rules = format_push_rules_for_user(requester.user, rules_raw)
path_parts = path.split("/")[1:]
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index f9caab66..4b03eb87 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -13,12 +13,10 @@
# limitations under the License.
import logging
-import re
from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import ReceiptTypes
from synapse.api.errors import SynapseError
-from synapse.http import get_request_user_agent
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
@@ -26,8 +24,6 @@ from synapse.types import JsonDict
from ._base import client_patterns
-pattern = re.compile(r"(?:Element|SchildiChat)/1\.[012]\.")
-
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -69,14 +65,7 @@ class ReceiptRestServlet(RestServlet):
):
raise SynapseError(400, "Receipt type must be 'm.read'")
- # Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body.
- user_agent = get_request_user_agent(request)
- allow_empty_body = False
- if "Android" in user_agent:
- if pattern.match(user_agent) or "Riot" in user_agent:
- allow_empty_body = True
- # This call makes sure possible empty body is handled correctly
- parse_json_object_from_request(request, allow_empty_body)
+ parse_json_object_from_request(request, allow_empty_body=False)
await self.presence_handler.bump_presence_active_time(requester.user)
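
After the removal above, the receipts servlet always calls parse_json_object_from_request with allow_empty_body=False, so clients must send a JSON body (at minimum {}) with read receipts. A sketch of a compliant request, assuming the standard client-server receipt endpoint; IDs, hostname and token are placeholders:

    import requests

    room_id = "!room:example.org"
    event_id = "$someevent"
    resp = requests.post(
        f"https://homeserver.example/_matrix/client/v3/rooms/{room_id}"
        f"/receipt/m.read/{event_id}",
        headers={"Authorization": "Bearer <access_token>"},
        json={},  # an explicit JSON body is now required, even if empty
    )
    resp.raise_for_status()
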
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 12fed856..a26e9764 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -34,7 +34,7 @@ from synapse.api.errors import (
)
from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2
-from synapse.http.server import HttpServer
+from synapse.http.server import HttpServer, cancellable
from synapse.http.servlet import (
ResolveRoomIdMixin,
RestServlet,
@@ -143,6 +143,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
self.__class__.__name__,
)
+ @cancellable
def on_GET_no_state_key(
self, request: SynapseRequest, room_id: str, event_type: str
) -> Awaitable[Tuple[int, JsonDict]]:
@@ -153,6 +154,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
) -> Awaitable[Tuple[int, JsonDict]]:
return self.on_PUT(request, room_id, event_type, "")
+ @cancellable
async def on_GET(
self, request: SynapseRequest, room_id: str, event_type: str, state_key: str
) -> Tuple[int, JsonDict]:
@@ -481,6 +483,7 @@ class RoomMemberListRestServlet(RestServlet):
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
+ @cancellable
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
@@ -602,6 +605,7 @@ class RoomStateRestServlet(RestServlet):
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
+ @cancellable
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, List[JsonDict]]:
@@ -646,6 +650,7 @@ class RoomEventServlet(RestServlet):
self.clock = hs.get_clock()
self._store = hs.get_datastores().main
self._state = hs.get_state_handler()
+ self._storage_controllers = hs.get_storage_controllers()
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
self._relations_handler = hs.get_relations_handler()
@@ -669,8 +674,10 @@ class RoomEventServlet(RestServlet):
if include_unredacted_content and not await self.auth.is_server_admin(
requester.user
):
- power_level_event = await self._state.get_current_state(
- room_id, EventTypes.PowerLevels, ""
+ power_level_event = (
+ await self._storage_controllers.state.get_current_state_event(
+ room_id, EventTypes.PowerLevels, ""
+ )
)
auth_events = {}
@@ -1189,12 +1196,7 @@ class TimestampLookupRestServlet(RestServlet):
class RoomHierarchyRestServlet(RestServlet):
- PATTERNS = (
- re.compile(
- "^/_matrix/client/(v1|unstable/org.matrix.msc2946)"
- "/rooms/(?P<room_id>[^/]*)/hierarchy$"
- ),
- )
+ PATTERNS = (re.compile("^/_matrix/client/v1/rooms/(?P<room_id>[^/]*)/hierarchy$"),)
def __init__(self, hs: "HomeServer"):
super().__init__()
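
The @cancellable markers added above let Synapse abandon these read-only handlers if the client disconnects mid-request. A minimal sketch of how a handler is marked, using a hypothetical servlet; the decorator and base classes are the ones imported in this diff:

    from typing import Tuple

    from synapse.http.server import cancellable
    from synapse.http.servlet import RestServlet
    from synapse.http.site import SynapseRequest
    from synapse.types import JsonDict

    class ExampleRoomInfoServlet(RestServlet):
        # Hypothetical servlet, for illustration only.
        @cancellable
        async def on_GET(
            self, request: SynapseRequest, room_id: str
        ) -> Tuple[int, JsonDict]:
            # A cancellable handler should only do read-only work, so that
            # dropping it part-way through cannot leave state half-written.
            return 200, {"room_id": room_id}
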
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index e8772f86..8bbf3514 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -16,7 +16,7 @@ import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
-from synapse.api.constants import Membership, PresenceState
+from synapse.api.constants import EduTypes, Membership, PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
@@ -298,14 +298,6 @@ class SyncRestServlet(RestServlet):
if archived:
response["rooms"][Membership.LEAVE] = archived
- if sync_result.groups is not None:
- if sync_result.groups.join:
- response["groups"][Membership.JOIN] = sync_result.groups.join
- if sync_result.groups.invite:
- response["groups"][Membership.INVITE] = sync_result.groups.invite
- if sync_result.groups.leave:
- response["groups"][Membership.LEAVE] = sync_result.groups.leave
-
return response
@staticmethod
@@ -313,7 +305,7 @@ class SyncRestServlet(RestServlet):
return {
"events": [
{
- "type": "m.presence",
+ "type": EduTypes.PRESENCE,
"sender": event.user_id,
"content": format_user_presence_state(
event, time_now, include_user_id=False
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 3e5d6c62..7435fd91 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -65,7 +65,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
+# How often to run the background job to update the "recently accessed"
+# attribute of local and remote media.
+UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000 # 1 minute
+# How often to run the background job to check for local and remote media
+# that should be purged according to the configured media retention settings.
+MEDIA_RETENTION_CHECK_PERIOD_MS = 60 * 60 * 1000 # 1 hour
class MediaRepository:
@@ -122,11 +127,36 @@ class MediaRepository:
self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS
)
+ # Media retention configuration options
+ self._media_retention_local_media_lifetime_ms = (
+ hs.config.media.media_retention_local_media_lifetime_ms
+ )
+ self._media_retention_remote_media_lifetime_ms = (
+ hs.config.media.media_retention_remote_media_lifetime_ms
+ )
+
+ # Check whether local or remote media retention is configured
+ if (
+ hs.config.media.media_retention_local_media_lifetime_ms is not None
+ or hs.config.media.media_retention_remote_media_lifetime_ms is not None
+ ):
+ # Run the background job to apply media retention rules routinely,
+ # with the duration between runs dictated by the homeserver config.
+ self.clock.looping_call(
+ self._start_apply_media_retention_rules,
+ MEDIA_RETENTION_CHECK_PERIOD_MS,
+ )
+
def _start_update_recently_accessed(self) -> Deferred:
return run_as_background_process(
"update_recently_accessed_media", self._update_recently_accessed
)
+ def _start_apply_media_retention_rules(self) -> Deferred:
+ return run_as_background_process(
+ "apply_media_retention_rules", self._apply_media_retention_rules
+ )
+
async def _update_recently_accessed(self) -> None:
remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set()
@@ -557,15 +587,16 @@ class MediaRepository:
)
return None
- t_byte_source = await defer_to_thread(
- self.hs.get_reactor(),
- self._generate_thumbnail,
- thumbnailer,
- t_width,
- t_height,
- t_method,
- t_type,
- )
+ with thumbnailer:
+ t_byte_source = await defer_to_thread(
+ self.hs.get_reactor(),
+ self._generate_thumbnail,
+ thumbnailer,
+ t_width,
+ t_height,
+ t_method,
+ t_type,
+ )
if t_byte_source:
try:
@@ -627,15 +658,16 @@ class MediaRepository:
)
return None
- t_byte_source = await defer_to_thread(
- self.hs.get_reactor(),
- self._generate_thumbnail,
- thumbnailer,
- t_width,
- t_height,
- t_method,
- t_type,
- )
+ with thumbnailer:
+ t_byte_source = await defer_to_thread(
+ self.hs.get_reactor(),
+ self._generate_thumbnail,
+ thumbnailer,
+ t_width,
+ t_height,
+ t_method,
+ t_type,
+ )
if t_byte_source:
try:
@@ -719,124 +751,182 @@ class MediaRepository:
)
return None
- m_width = thumbnailer.width
- m_height = thumbnailer.height
+ with thumbnailer:
+ m_width = thumbnailer.width
+ m_height = thumbnailer.height
- if m_width * m_height >= self.max_image_pixels:
- logger.info(
- "Image too large to thumbnail %r x %r > %r",
- m_width,
- m_height,
- self.max_image_pixels,
- )
- return None
-
- if thumbnailer.transpose_method is not None:
- m_width, m_height = await defer_to_thread(
- self.hs.get_reactor(), thumbnailer.transpose
- )
-
- # We deduplicate the thumbnail sizes by ignoring the cropped versions if
- # they have the same dimensions of a scaled one.
- thumbnails: Dict[Tuple[int, int, str], str] = {}
- for requirement in requirements:
- if requirement.method == "crop":
- thumbnails.setdefault(
- (requirement.width, requirement.height, requirement.media_type),
- requirement.method,
- )
- elif requirement.method == "scale":
- t_width, t_height = thumbnailer.aspect(
- requirement.width, requirement.height
+ if m_width * m_height >= self.max_image_pixels:
+ logger.info(
+ "Image too large to thumbnail %r x %r > %r",
+ m_width,
+ m_height,
+ self.max_image_pixels,
)
- t_width = min(m_width, t_width)
- t_height = min(m_height, t_height)
- thumbnails[
- (t_width, t_height, requirement.media_type)
- ] = requirement.method
-
- # Now we generate the thumbnails for each dimension, store it
- for (t_width, t_height, t_type), t_method in thumbnails.items():
- # Generate the thumbnail
- if t_method == "crop":
- t_byte_source = await defer_to_thread(
- self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type
+ return None
+
+ if thumbnailer.transpose_method is not None:
+ m_width, m_height = await defer_to_thread(
+ self.hs.get_reactor(), thumbnailer.transpose
)
- elif t_method == "scale":
- t_byte_source = await defer_to_thread(
- self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type
+
+ # We deduplicate the thumbnail sizes by ignoring the cropped versions if
+ # they have the same dimensions of a scaled one.
+ thumbnails: Dict[Tuple[int, int, str], str] = {}
+ for requirement in requirements:
+ if requirement.method == "crop":
+ thumbnails.setdefault(
+ (requirement.width, requirement.height, requirement.media_type),
+ requirement.method,
+ )
+ elif requirement.method == "scale":
+ t_width, t_height = thumbnailer.aspect(
+ requirement.width, requirement.height
+ )
+ t_width = min(m_width, t_width)
+ t_height = min(m_height, t_height)
+ thumbnails[
+ (t_width, t_height, requirement.media_type)
+ ] = requirement.method
+
+ # Now we generate the thumbnails for each dimension, store it
+ for (t_width, t_height, t_type), t_method in thumbnails.items():
+ # Generate the thumbnail
+ if t_method == "crop":
+ t_byte_source = await defer_to_thread(
+ self.hs.get_reactor(),
+ thumbnailer.crop,
+ t_width,
+ t_height,
+ t_type,
+ )
+ elif t_method == "scale":
+ t_byte_source = await defer_to_thread(
+ self.hs.get_reactor(),
+ thumbnailer.scale,
+ t_width,
+ t_height,
+ t_type,
+ )
+ else:
+ logger.error("Unrecognized method: %r", t_method)
+ continue
+
+ if not t_byte_source:
+ continue
+
+ file_info = FileInfo(
+ server_name=server_name,
+ file_id=file_id,
+ url_cache=url_cache,
+ thumbnail=ThumbnailInfo(
+ width=t_width,
+ height=t_height,
+ method=t_method,
+ type=t_type,
+ ),
)
- else:
- logger.error("Unrecognized method: %r", t_method)
- continue
-
- if not t_byte_source:
- continue
-
- file_info = FileInfo(
- server_name=server_name,
- file_id=file_id,
- url_cache=url_cache,
- thumbnail=ThumbnailInfo(
- width=t_width,
- height=t_height,
- method=t_method,
- type=t_type,
- ),
- )
- with self.media_storage.store_into_file(file_info) as (f, fname, finish):
- try:
- await self.media_storage.write_to_file(t_byte_source, f)
- await finish()
- finally:
- t_byte_source.close()
-
- t_len = os.path.getsize(fname)
-
- # Write to database
- if server_name:
- # Multiple remote media download requests can race (when
- # using multiple media repos), so this may throw a violation
- # constraint exception. If it does we'll delete the newly
- # generated thumbnail from disk (as we're in the ctx
- # manager).
- #
- # However: we've already called `finish()` so we may have
- # also written to the storage providers. This is preferable
- # to the alternative where we call `finish()` *after* this,
- # where we could end up having an entry in the DB but fail
- # to write the files to the storage providers.
+ with self.media_storage.store_into_file(file_info) as (
+ f,
+ fname,
+ finish,
+ ):
try:
- await self.store.store_remote_media_thumbnail(
- server_name,
- media_id,
- file_id,
- t_width,
- t_height,
- t_type,
- t_method,
- t_len,
- )
- except Exception as e:
- thumbnail_exists = await self.store.get_remote_media_thumbnail(
- server_name,
- media_id,
- t_width,
- t_height,
- t_type,
+ await self.media_storage.write_to_file(t_byte_source, f)
+ await finish()
+ finally:
+ t_byte_source.close()
+
+ t_len = os.path.getsize(fname)
+
+ # Write to database
+ if server_name:
+ # Multiple remote media download requests can race (when
+ # using multiple media repos), so this may throw a violation
+ # constraint exception. If it does we'll delete the newly
+ # generated thumbnail from disk (as we're in the ctx
+ # manager).
+ #
+ # However: we've already called `finish()` so we may have
+ # also written to the storage providers. This is preferable
+ # to the alternative where we call `finish()` *after* this,
+ # where we could end up having an entry in the DB but fail
+ # to write the files to the storage providers.
+ try:
+ await self.store.store_remote_media_thumbnail(
+ server_name,
+ media_id,
+ file_id,
+ t_width,
+ t_height,
+ t_type,
+ t_method,
+ t_len,
+ )
+ except Exception as e:
+ thumbnail_exists = (
+ await self.store.get_remote_media_thumbnail(
+ server_name,
+ media_id,
+ t_width,
+ t_height,
+ t_type,
+ )
+ )
+ if not thumbnail_exists:
+ raise e
+ else:
+ await self.store.store_local_thumbnail(
+ media_id, t_width, t_height, t_type, t_method, t_len
)
- if not thumbnail_exists:
- raise e
- else:
- await self.store.store_local_thumbnail(
- media_id, t_width, t_height, t_type, t_method, t_len
- )
return {"width": m_width, "height": m_height}
+ async def _apply_media_retention_rules(self) -> None:
+ """
+ Purge old local and remote media according to the media retention rules
+ defined in the homeserver config.
+ """
+ # Purge remote media
+ if self._media_retention_remote_media_lifetime_ms is not None:
+ # Calculate a threshold timestamp derived from the configured lifetime. Any
+ # media that has not been accessed since this timestamp will be removed.
+ remote_media_threshold_timestamp_ms = (
+ self.clock.time_msec() - self._media_retention_remote_media_lifetime_ms
+ )
+
+ logger.info(
+ "Purging remote media last accessed before"
+ f" {remote_media_threshold_timestamp_ms}"
+ )
+
+ await self.delete_old_remote_media(
+ before_ts=remote_media_threshold_timestamp_ms
+ )
+
+ # And now do the same for local media
+ if self._media_retention_local_media_lifetime_ms is not None:
+ # This works the same as the remote media threshold
+ local_media_threshold_timestamp_ms = (
+ self.clock.time_msec() - self._media_retention_local_media_lifetime_ms
+ )
+
+ logger.info(
+ "Purging local media last accessed before"
+ f" {local_media_threshold_timestamp_ms}"
+ )
+
+ await self.delete_old_local_media(
+ before_ts=local_media_threshold_timestamp_ms,
+ keep_profiles=True,
+ delete_quarantined_media=False,
+ delete_protected_media=False,
+ )
+
async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]:
- old_media = await self.store.get_remote_media_before(before_ts)
+ old_media = await self.store.get_remote_media_ids(
+ before_ts, include_quarantined_media=False
+ )
deleted = 0
@@ -889,6 +979,8 @@ class MediaRepository:
before_ts: int,
size_gt: int = 0,
keep_profiles: bool = True,
+ delete_quarantined_media: bool = False,
+ delete_protected_media: bool = False,
) -> Tuple[List[str], int]:
"""
Delete local or remote media from this server by size and timestamp. Removes
@@ -896,18 +988,22 @@ class MediaRepository:
Args:
before_ts: Unix timestamp in ms.
- Files that were last used before this timestamp will be deleted
- size_gt: Size of the media in bytes. Files that are larger will be deleted
+ Files that were last used before this timestamp will be deleted.
+ size_gt: Size of the media in bytes. Files that are larger will be deleted.
keep_profiles: Switch to delete also files that are still used in image data
- (e.g user profile, room avatar)
- If false these files will be deleted
+ (e.g user profile, room avatar). If false these files will be deleted.
+ delete_quarantined_media: If True, media marked as quarantined will be deleted.
+ delete_protected_media: If True, media marked as protected will be deleted.
+
Returns:
A tuple of (list of deleted media IDs, total deleted media IDs).
"""
- old_media = await self.store.get_local_media_before(
+ old_media = await self.store.get_local_media_ids(
before_ts,
size_gt,
keep_profiles,
+ include_quarantined_media=delete_quarantined_media,
+ include_protected_media=delete_protected_media,
)
return await self._remove_local_media_from_disk(old_media)
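
A condensed sketch of the local-media purge flow added above, assuming a MediaRepository instance and an already-configured lifetime; the wrapper function and the lifetime value are illustrative:

    async def purge_old_local_media(media_repo, lifetime_ms: int) -> None:
        # Anything last accessed before (now - lifetime) is eligible for
        # removal; quarantined and protected media are kept by default.
        threshold_ms = media_repo.clock.time_msec() - lifetime_ms
        await media_repo.delete_old_local_media(
            before_ts=threshold_ms,
            keep_profiles=True,
            delete_quarantined_media=False,
            delete_protected_media=False,
        )

    # e.g. a 90-day lifetime (illustrative value):
    # await purge_old_local_media(media_repo, 90 * 24 * 60 * 60 * 1000)
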
diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py
index ca73965f..ed8f21a4 100644
--- a/synapse/rest/media/v1/preview_html.py
+++ b/synapse/rest/media/v1/preview_html.py
@@ -30,6 +30,9 @@ _xml_encoding_match = re.compile(
)
_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
+# Certain elements aren't meant for display.
+ARIA_ROLES_TO_IGNORE = {"directory", "menu", "menubar", "toolbar"}
+
def _normalise_encoding(encoding: str) -> Optional[str]:
"""Use the Python codec's name as the normalised entry."""
@@ -174,13 +177,15 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
og: Dict[str, Optional[str]] = {}
- for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
- if "content" in tag.attrib:
- # if we've got more than 50 tags, someone is taking the piss
- if len(og) >= 50:
- logger.warning("Skipping OG for page with too many 'og:' tags")
- return {}
- og[tag.attrib["property"]] = tag.attrib["content"]
+ for tag in tree.xpath(
+ "//*/meta[starts-with(@property, 'og:')][@content][not(@content='')]"
+ ):
+ # if we've got more than 50 tags, someone is taking the piss
+ if len(og) >= 50:
+ logger.warning("Skipping OG for page with too many 'og:' tags")
+ return {}
+
+ og[tag.attrib["property"]] = tag.attrib["content"]
# TODO: grab article: meta tags too, e.g.:
@@ -192,21 +197,23 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
if "og:title" not in og:
- # do some basic spidering of the HTML
- title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
- if title and title[0].text is not None:
- og["og:title"] = title[0].text.strip()
+ # Attempt to find a title from the title tag, or the biggest header on the page.
+ title = tree.xpath("((//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1])/text()")
+ if title:
+ og["og:title"] = title[0].strip()
else:
og["og:title"] = None
if "og:image" not in og:
- # TODO: extract a favicon failing all else
meta_image = tree.xpath(
- "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
+ "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image'][not(@content='')]/@content[1]"
)
+ # If a meta image is found, use it.
if meta_image:
og["og:image"] = meta_image[0]
else:
+ # Try to find images which are larger than 10px by 10px.
+ #
# TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
images = sorted(
@@ -215,17 +222,24 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
-1 * float(i.attrib["width"]) * float(i.attrib["height"])
),
)
+ # If no images were found, try to find *any* images.
if not images:
- images = tree.xpath("//img[@src]")
+ images = tree.xpath("//img[@src][1]")
if images:
og["og:image"] = images[0].attrib["src"]
+ # Finally, fall back to the favicon if nothing else was found.
+ else:
+ favicons = tree.xpath("//link[@href][contains(@rel, 'icon')]/@href[1]")
+ if favicons:
+ og["og:image"] = favicons[0]
+
if "og:description" not in og:
+ # Check the first meta description tag for content.
meta_description = tree.xpath(
- "//*/meta"
- "[translate(@name, 'DESCRIPTION', 'description')='description']"
- "/@content"
+ "//*/meta[translate(@name, 'DESCRIPTION', 'description')='description'][not(@content='')]/@content[1]"
)
+ # If a meta description is found with content, use it.
if meta_description:
og["og:description"] = meta_description[0]
else:
@@ -246,7 +260,9 @@ def parse_html_description(tree: "etree.Element") -> Optional[str]:
Grabs any text nodes which are inside the <body/> tag, unless they are within
an HTML5 semantic markup tag (<header/>, <nav/>, <aside/>, <footer/>), or
- if they are within a <script/> or <style/> tag.
+ if they are within a <script/>, <svg/> or <style/> tag, or if they are within
+ a tag whose content is usually only shown to old browsers
+ (<iframe/>, <video/>, <canvas/>, <picture/>).
This is a very very very coarse approximation to a plain text render of the page.
@@ -268,6 +284,12 @@ def parse_html_description(tree: "etree.Element") -> Optional[str]:
"script",
"noscript",
"style",
+ "svg",
+ "iframe",
+ "video",
+ "canvas",
+ "img",
+ "picture",
etree.Comment,
)
@@ -281,7 +303,7 @@ def parse_html_description(tree: "etree.Element") -> Optional[str]:
def _iterate_over_text(
- tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
+ tree: "etree.Element", *tags_to_ignore: Union[str, "etree.Comment"]
) -> Generator[str, None, None]:
"""Iterate over the tree returning text nodes in a depth first fashion,
skipping text nodes inside certain tags.
@@ -298,6 +320,10 @@ def _iterate_over_text(
if isinstance(el, str):
yield el
elif el.tag not in tags_to_ignore:
+ # If the element isn't meant for display, ignore it.
+ if el.get("role") in ARIA_ROLES_TO_IGNORE:
+ continue
+
# el.text is the text before the first child, so we can immediately
# return it if the text exists.
if el.text:
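The tightened XPath in parse_html_to_open_graph above filters out og: meta tags with a missing or empty content attribute inside the query itself, which is why the per-tag content check could be dropped. A minimal, self-contained sketch of that behaviour using lxml directly (the HTML snippet is invented for illustration):

from lxml import etree

html = b"""<html><head>
<meta property="og:title" content="Example page" />
<meta property="og:image" content="" />
<meta property="og:description" />
</head><body></body></html>"""

tree = etree.HTML(html)
# Only og: meta tags with a non-empty content attribute match, so the blank
# og:image and the attribute-less og:description are skipped.
tags = tree.xpath(
    "//*/meta[starts-with(@property, 'og:')][@content][not(@content='')]"
)
print({tag.attrib["property"]: tag.attrib["content"] for tag in tags})
# {'og:title': 'Example page'}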
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 50383bdb..54a849ea 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -586,12 +586,16 @@ class PreviewUrlResource(DirectServeJsonResource):
og: The Open Graph dictionary. This is modified with image information.
"""
# If there's no image or it is blank, there's nothing to do.
- if "og:image" not in og or not og["og:image"]:
+ if "og:image" not in og:
+ return
+
+ # Remove the raw image URL; this will be replaced with an MXC URL, if successful.
+ image_url = og.pop("og:image")
+ if not image_url:
return
# The image URL from the HTML might be relative to the previewed page,
# convert it to an URL which can be requested directly.
- image_url = og["og:image"]
url_parts = urlparse(image_url)
if url_parts.scheme != "data":
image_url = urljoin(media_info.uri, image_url)
@@ -599,7 +603,16 @@ class PreviewUrlResource(DirectServeJsonResource):
# FIXME: it might be cleaner to use the same flow as the main /preview_url
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
- image_info = await self._handle_url(image_url, user, allow_data_urls=True)
+ try:
+ image_info = await self._handle_url(image_url, user, allow_data_urls=True)
+ except Exception as e:
+ # Pre-caching the image failed; don't block the entire URL preview.
+ logger.warning(
+ "Pre-caching image failed during URL preview: %s errored with %s",
+ image_url,
+ e,
+ )
+ return
if _is_media(image_info.media_type):
# TODO: make sure we don't choke on white-on-transparent images
@@ -611,13 +624,11 @@ class PreviewUrlResource(DirectServeJsonResource):
og["og:image:width"] = dims["width"]
og["og:image:height"] = dims["height"]
else:
- logger.warning("Couldn't get dims for %s", og["og:image"])
+ logger.warning("Couldn't get dims for %s", image_url)
og["og:image"] = f"mxc://{self.server_name}/{image_info.filesystem_id}"
og["og:image:type"] = image_info.media_type
og["matrix:image:size"] = image_info.media_length
- else:
- del og["og:image"]
async def _handle_oembed_response(
self, url: str, media_info: MediaInfo, expiration_ms: int
@@ -668,7 +679,7 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("Running url preview cache expiry")
if not (await self.store.db_pool.updates.has_completed_background_updates()):
- logger.info("Still running DB updates; skipping expiry")
+ logger.debug("Still running DB updates; skipping url preview cache expiry")
return
def try_remove_parent_dirs(dirs: Iterable[str]) -> None:
@@ -688,7 +699,9 @@ class PreviewUrlResource(DirectServeJsonResource):
# Failed, skip deleting the rest of the parent dirs
if e.errno != errno.ENOTEMPTY:
logger.warning(
- "Failed to remove media directory: %r: %s", dir, e
+ "Failed to remove media directory while clearing url preview cache: %r: %s",
+ dir,
+ e,
)
break
@@ -703,7 +716,11 @@ class PreviewUrlResource(DirectServeJsonResource):
except FileNotFoundError:
pass # If the path doesn't exist, meh
except OSError as e:
- logger.warning("Failed to remove media: %r: %s", media_id, e)
+ logger.warning(
+ "Failed to remove media while clearing url preview cache: %r: %s",
+ media_id,
+ e,
+ )
continue
removed_media.append(media_id)
@@ -714,9 +731,11 @@ class PreviewUrlResource(DirectServeJsonResource):
await self.store.delete_url_cache(removed_media)
if removed_media:
- logger.info("Deleted %d entries from url cache", len(removed_media))
+ logger.debug(
+ "Deleted %d entries from url preview cache", len(removed_media)
+ )
else:
- logger.debug("No entries removed from url cache")
+ logger.debug("No entries removed from url preview cache")
# Now we delete old images associated with the url cache.
# These may be cached for a bit on the client (i.e., they
@@ -733,7 +752,9 @@ class PreviewUrlResource(DirectServeJsonResource):
except FileNotFoundError:
pass # If the path doesn't exist, meh
except OSError as e:
- logger.warning("Failed to remove media: %r: %s", media_id, e)
+ logger.warning(
+ "Failed to remove media from url preview cache: %r: %s", media_id, e
+ )
continue
dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
@@ -745,7 +766,9 @@ class PreviewUrlResource(DirectServeJsonResource):
except FileNotFoundError:
pass # If the path doesn't exist, meh
except OSError as e:
- logger.warning("Failed to remove media: %r: %s", media_id, e)
+ logger.warning(
+ "Failed to remove media from url preview cache: %r: %s", media_id, e
+ )
continue
removed_media.append(media_id)
@@ -758,9 +781,9 @@ class PreviewUrlResource(DirectServeJsonResource):
await self.store.delete_url_cache_media(removed_media)
if removed_media:
- logger.info("Deleted %d media from url cache", len(removed_media))
+ logger.debug("Deleted %d media from url preview cache", len(removed_media))
else:
- logger.debug("No media removed from url cache")
+ logger.debug("No media removed from url preview cache")
def _is_media(content_type: str) -> bool:
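The pre-caching rework above pops og:image from the Open Graph data up front and only writes an mxc:// URL back on success, so a failed image fetch leaves the preview without an image instead of failing the whole request. A stripped-down sketch of that flow (the fetch callable, server name and helper name are invented):

async def precache_og_image(og: dict, fetch, server_name: str) -> None:
    # Take the raw image URL out of the OG data immediately; it is only put
    # back (as an MXC URL) if downloading and caching the image succeeds.
    image_url = og.pop("og:image", None)
    if not image_url:
        return
    try:
        image_info = await fetch(image_url)
    except Exception:
        # A broken image should not block the rest of the URL preview.
        return
    og["og:image"] = f"mxc://{server_name}/{image_info.filesystem_id}"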
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 390491eb..9b93b9b4 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -14,7 +14,8 @@
# limitations under the License.
import logging
from io import BytesIO
-from typing import Tuple
+from types import TracebackType
+from typing import Optional, Tuple, Type
from PIL import Image
@@ -45,6 +46,9 @@ class Thumbnailer:
Image.MAX_IMAGE_PIXELS = max_image_pixels
def __init__(self, input_path: str):
+ # Have we closed the image?
+ self._closed = False
+
try:
self.image = Image.open(input_path)
except OSError as e:
@@ -89,7 +93,8 @@ class Thumbnailer:
# Safety: `transpose` takes an int rather than e.g. an IntEnum.
# self.transpose_method is set above to be a value in
# EXIF_TRANSPOSE_MAPPINGS, and that only contains correct values.
- self.image = self.image.transpose(self.transpose_method) # type: ignore[arg-type]
+ with self.image:
+ self.image = self.image.transpose(self.transpose_method) # type: ignore[arg-type]
self.width, self.height = self.image.size
self.transpose_method = None
# We don't need EXIF any more
@@ -122,9 +127,11 @@ class Thumbnailer:
# If the image has transparency, use RGBA instead.
if self.image.mode in ["1", "L", "P"]:
if self.image.info.get("transparency", None) is not None:
- self.image = self.image.convert("RGBA")
+ with self.image:
+ self.image = self.image.convert("RGBA")
else:
- self.image = self.image.convert("RGB")
+ with self.image:
+ self.image = self.image.convert("RGB")
return self.image.resize((width, height), Image.ANTIALIAS)
def scale(self, width: int, height: int, output_type: str) -> BytesIO:
@@ -133,8 +140,8 @@ class Thumbnailer:
Returns:
BytesIO: the bytes of the encoded image ready to be written to disk
"""
- scaled = self._resize(width, height)
- return self._encode_image(scaled, output_type)
+ with self._resize(width, height) as scaled:
+ return self._encode_image(scaled, output_type)
def crop(self, width: int, height: int, output_type: str) -> BytesIO:
"""Rescales and crops the image to the given dimensions preserving
@@ -151,18 +158,21 @@ class Thumbnailer:
BytesIO: the bytes of the encoded image ready to be written to disk
"""
if width * self.height > height * self.width:
+ scaled_width = width
scaled_height = (width * self.height) // self.width
- scaled_image = self._resize(width, scaled_height)
crop_top = (scaled_height - height) // 2
crop_bottom = height + crop_top
- cropped = scaled_image.crop((0, crop_top, width, crop_bottom))
+ crop = (0, crop_top, width, crop_bottom)
else:
scaled_width = (height * self.width) // self.height
- scaled_image = self._resize(scaled_width, height)
+ scaled_height = height
crop_left = (scaled_width - width) // 2
crop_right = width + crop_left
- cropped = scaled_image.crop((crop_left, 0, crop_right, height))
- return self._encode_image(cropped, output_type)
+ crop = (crop_left, 0, crop_right, height)
+
+ with self._resize(scaled_width, scaled_height) as scaled_image:
+ with scaled_image.crop(crop) as cropped:
+ return self._encode_image(cropped, output_type)
def _encode_image(self, output_image: Image.Image, output_type: str) -> BytesIO:
output_bytes_io = BytesIO()
@@ -171,3 +181,42 @@ class Thumbnailer:
output_image = output_image.convert("RGB")
output_image.save(output_bytes_io, fmt, quality=80)
return output_bytes_io
+
+ def close(self) -> None:
+ """Closes the underlying image file.
+
+ Once closed no other functions can be called.
+
+ Can be called multiple times.
+ """
+
+ if self._closed:
+ return
+
+ self._closed = True
+
+ # Since we also run this from the finalizer, we need to handle `__init__`
+ # raising an exception before it can define `self.image`.
+ image = getattr(self, "image", None)
+ if image is None:
+ return
+
+ image.close()
+
+ def __enter__(self) -> "Thumbnailer":
+ """Make `Thumbnailer` a context manager that calls `close` on
+ `__exit__`.
+ """
+ return self
+
+ def __exit__(
+ self,
+ type: Optional[Type[BaseException]],
+ value: Optional[BaseException],
+ traceback: Optional[TracebackType],
+ ) -> None:
+ self.close()
+
+ def __del__(self) -> None:
+ # Make sure we actually do close the image, rather than leak data.
+ self.close()
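With close(), __enter__ and __exit__ added above, Thumbnailer can be scoped with a with-statement so the underlying Pillow image is released promptly; a minimal usage sketch (the input path is invented):

with Thumbnailer("/path/to/original.png") as thumbnailer:
    png_bytes = thumbnailer.scale(320, 240, "image/png")
# The image file handle is closed here even if scale() raised; __del__
# remains as a backstop if close() was never reached.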
diff --git a/synapse/server.py b/synapse/server.py
index d49c7651..a66ec228 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -21,17 +21,7 @@
import abc
import functools
import logging
-from typing import (
- TYPE_CHECKING,
- Any,
- Callable,
- Dict,
- List,
- Optional,
- TypeVar,
- Union,
- cast,
-)
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
from twisted.internet.interfaces import IOpenSSLContextFactory
from twisted.internet.tcp import Port
@@ -60,8 +50,6 @@ from synapse.federation.federation_server import (
from synapse.federation.send_queue import FederationRemoteSendQueue
from synapse.federation.sender import AbstractFederationSender, FederationSender
from synapse.federation.transport.client import TransportLayerClient
-from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
-from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler
from synapse.handlers.account import AccountHandler
from synapse.handlers.account_data import AccountDataHandler
from synapse.handlers.account_validity import AccountValidityHandler
@@ -79,7 +67,6 @@ from synapse.handlers.event_auth import EventAuthHandler
from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.handlers.federation import FederationHandler
from synapse.handlers.federation_event import FederationEventHandler
-from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerHandler
from synapse.handlers.identity import IdentityHandler
from synapse.handlers.initial_sync import InitialSyncHandler
from synapse.handlers.message import EventCreationHandler, MessageHandler
@@ -119,7 +106,7 @@ from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpC
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.module_api import ModuleApi
from synapse.notifier import Notifier
-from synapse.push.action_generator import ActionGenerator
+from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.external_cache import ExternalCache
@@ -136,7 +123,8 @@ from synapse.server_notices.worker_server_notices_sender import (
WorkerServerNoticesSender,
)
from synapse.state import StateHandler, StateResolutionHandler
-from synapse.storage import Databases, Storage
+from synapse.storage import Databases
+from synapse.storage.controllers import StorageControllers
from synapse.streams.events import EventSources
from synapse.types import DomainSpecificString, ISynapseReactor
from synapse.util import Clock
@@ -644,44 +632,20 @@ class HomeServer(metaclass=abc.ABCMeta):
return ReplicationCommandHandler(self)
@cache_in_self
- def get_action_generator(self) -> ActionGenerator:
- return ActionGenerator(self)
+ def get_bulk_push_rule_evaluator(self) -> BulkPushRuleEvaluator:
+ return BulkPushRuleEvaluator(self)
@cache_in_self
def get_user_directory_handler(self) -> UserDirectoryHandler:
return UserDirectoryHandler(self)
@cache_in_self
- def get_groups_local_handler(
- self,
- ) -> Union[GroupsLocalWorkerHandler, GroupsLocalHandler]:
- if self.config.worker.worker_app:
- return GroupsLocalWorkerHandler(self)
- else:
- return GroupsLocalHandler(self)
-
- @cache_in_self
- def get_groups_server_handler(self):
- if self.config.worker.worker_app:
- return GroupsServerWorkerHandler(self)
- else:
- return GroupsServerHandler(self)
-
- @cache_in_self
- def get_groups_attestation_signing(self) -> GroupAttestationSigning:
- return GroupAttestationSigning(self)
-
- @cache_in_self
- def get_groups_attestation_renewer(self) -> GroupAttestionRenewer:
- return GroupAttestionRenewer(self)
-
- @cache_in_self
def get_stats_handler(self) -> StatsHandler:
return StatsHandler(self)
@cache_in_self
def get_spam_checker(self) -> SpamChecker:
- return SpamChecker()
+ return SpamChecker(self)
@cache_in_self
def get_third_party_event_rules(self) -> ThirdPartyEventRules:
@@ -766,8 +730,8 @@ class HomeServer(metaclass=abc.ABCMeta):
return PasswordPolicyHandler(self)
@cache_in_self
- def get_storage(self) -> Storage:
- return Storage(self, self.get_datastores())
+ def get_storage_controllers(self) -> StorageControllers:
+ return StorageControllers(self, self.get_datastores())
@cache_in_self
def get_replication_streamer(self) -> ReplicationStreamer:
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index 015dd08f..68630207 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -21,7 +21,6 @@ from synapse.api.constants import (
ServerNoticeMsgType,
)
from synapse.api.errors import AuthError, ResourceLimitError, SynapseError
-from synapse.server_notices.server_notices_manager import SERVER_NOTICE_ROOM_TAG
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -37,6 +36,7 @@ class ResourceLimitsServerNotices:
def __init__(self, hs: "HomeServer"):
self._server_notices_manager = hs.get_server_notices_manager()
self._store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
self._auth = hs.get_auth()
self._config = hs.config
self._resouce_limited = False
@@ -71,18 +71,19 @@ class ResourceLimitsServerNotices:
# In practice, not sure we can ever get here
return
- room_id = await self._server_notices_manager.get_or_create_notice_room_for_user(
+ # Check if there's a server notice room for this user.
+ room_id = await self._server_notices_manager.maybe_get_notice_room_for_user(
user_id
)
- if not room_id:
- logger.warning("Failed to get server notices room")
- return
-
- await self._check_and_set_tags(user_id, room_id)
-
- # Determine current state of room
- currently_blocked, ref_events = await self._is_room_currently_blocked(room_id)
+ if room_id is not None:
+ # Determine current state of room
+ currently_blocked, ref_events = await self._is_room_currently_blocked(
+ room_id
+ )
+ else:
+ currently_blocked = False
+ ref_events = []
limit_msg = None
limit_type = None
@@ -161,26 +162,6 @@ class ResourceLimitsServerNotices:
user_id, content, EventTypes.Pinned, ""
)
- async def _check_and_set_tags(self, user_id: str, room_id: str) -> None:
- """
- Since server notices rooms were originally not with tags,
- important to check that tags have been set correctly
- Args:
- user_id(str): the user in question
- room_id(str): the server notices room for that user
- """
- tags = await self._store.get_tags_for_room(user_id, room_id)
- need_to_set_tag = True
- if tags:
- if SERVER_NOTICE_ROOM_TAG in tags:
- # tag already present, nothing to do here
- need_to_set_tag = False
- if need_to_set_tag:
- max_id = await self._account_data_handler.add_tag_to_room(
- user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
- )
- self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
-
async def _is_room_currently_blocked(self, room_id: str) -> Tuple[bool, List[str]]:
"""
Determines if the room is currently blocked
@@ -198,8 +179,10 @@ class ResourceLimitsServerNotices:
currently_blocked = False
pinned_state_event = None
try:
- pinned_state_event = await self._state.get_current_state(
- room_id, event_type=EventTypes.Pinned
+ pinned_state_event = (
+ await self._storage_controllers.state.get_current_state_event(
+ room_id, event_type=EventTypes.Pinned, state_key=""
+ )
)
except AuthError:
# The user has yet to join the server notices room
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index 48eae5fa..8ecab86e 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Optional
from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
from synapse.events import EventBase
-from synapse.types import Requester, UserID, create_requester
+from synapse.types import Requester, StreamKeyType, UserID, create_requester
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
@@ -91,6 +91,35 @@ class ServerNoticesManager:
return event
@cached()
+ async def maybe_get_notice_room_for_user(self, user_id: str) -> Optional[str]:
+ """Try to look up the server notice room for this user if it exists.
+
+ Does not create one if none can be found.
+
+ Args:
+ user_id: the user we want a server notice room for.
+
+ Returns:
+ The room's ID, or None if no room could be found.
+ """
+ rooms = await self._store.get_rooms_for_local_user_where_membership_is(
+ user_id, [Membership.INVITE, Membership.JOIN]
+ )
+ for room in rooms:
+ # it's worth noting that there is an asymmetry here in that we
+ # expect the user to be invited or joined, but the system user must
+ # be joined. This is kinda deliberate, in that if somebody somehow
+ # manages to invite the system user to a room, that doesn't make it
+ # the server notices room.
+ user_ids = await self._store.get_users_in_room(room.room_id)
+ if len(user_ids) <= 2 and self.server_notices_mxid in user_ids:
+ # we found a room which our user shares with the system notice
+ # user
+ return room.room_id
+
+ return None
+
+ @cached()
async def get_or_create_notice_room_for_user(self, user_id: str) -> str:
"""Get the room for notices for a given user
@@ -112,31 +141,20 @@ class ServerNoticesManager:
self.server_notices_mxid, authenticated_entity=self._server_name
)
- rooms = await self._store.get_rooms_for_local_user_where_membership_is(
- user_id, [Membership.INVITE, Membership.JOIN]
- )
- for room in rooms:
- # it's worth noting that there is an asymmetry here in that we
- # expect the user to be invited or joined, but the system user must
- # be joined. This is kinda deliberate, in that if somebody somehow
- # manages to invite the system user to a room, that doesn't make it
- # the server notices room.
- user_ids = await self._store.get_users_in_room(room.room_id)
- if len(user_ids) <= 2 and self.server_notices_mxid in user_ids:
- # we found a room which our user shares with the system notice
- # user
- logger.info(
- "Using existing server notices room %s for user %s",
- room.room_id,
- user_id,
- )
- await self._update_notice_user_profile_if_changed(
- requester,
- room.room_id,
- self._config.servernotices.server_notices_mxid_display_name,
- self._config.servernotices.server_notices_mxid_avatar_url,
- )
- return room.room_id
+ room_id = await self.maybe_get_notice_room_for_user(user_id)
+ if room_id is not None:
+ logger.info(
+ "Using existing server notices room %s for user %s",
+ room_id,
+ user_id,
+ )
+ await self._update_notice_user_profile_if_changed(
+ requester,
+ room_id,
+ self._config.servernotices.server_notices_mxid_display_name,
+ self._config.servernotices.server_notices_mxid_avatar_url,
+ )
+ return room_id
# apparently no existing notice room: create a new one
logger.info("Creating server notices room for %s", user_id)
@@ -166,10 +184,12 @@ class ServerNoticesManager:
)
room_id = info["room_id"]
+ self.maybe_get_notice_room_for_user.invalidate((user_id,))
+
max_id = await self._account_data_handler.add_tag_to_room(
user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
)
- self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
+ self._notifier.on_new_event(StreamKeyType.ACCOUNT_DATA, max_id, users=[user_id])
logger.info("Created server notices room %s for %s", room_id, user_id)
return room_id
diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py
index 73018f2d..75578270 100644
--- a/synapse/spam_checker_api/__init__.py
+++ b/synapse/spam_checker_api/__init__.py
@@ -16,7 +16,7 @@ from enum import Enum
class RegistrationBehaviour(Enum):
"""
- Enum to define whether a registration request should allowed, denied, or shadow-banned.
+ Enum to define whether a registration request should be allowed, denied, or shadow-banned.
"""
ALLOW = "allow"
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index cad3b426..da25f20a 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -32,13 +32,11 @@ from typing import (
Set,
Tuple,
Union,
- overload,
)
import attr
from frozendict import frozendict
from prometheus_client import Counter, Histogram
-from typing_extensions import Literal
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
@@ -127,89 +125,25 @@ class StateHandler:
def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
- self.state_store = hs.get_storage().state
+ self._state_storage_controller = hs.get_storage_controllers().state
self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
+ self._storage_controllers = hs.get_storage_controllers()
- @overload
- async def get_current_state(
- self,
- room_id: str,
- event_type: Literal[None] = None,
- state_key: str = "",
- latest_event_ids: Optional[List[str]] = None,
- ) -> StateMap[EventBase]:
- ...
-
- @overload
- async def get_current_state(
- self,
- room_id: str,
- event_type: str,
- state_key: str = "",
- latest_event_ids: Optional[List[str]] = None,
- ) -> Optional[EventBase]:
- ...
-
- async def get_current_state(
+ async def get_current_state_ids(
self,
room_id: str,
- event_type: Optional[str] = None,
- state_key: str = "",
- latest_event_ids: Optional[List[str]] = None,
- ) -> Union[Optional[EventBase], StateMap[EventBase]]:
- """Retrieves the current state for the room. This is done by
- calling `get_latest_events_in_room` to get the leading edges of the
- event graph and then resolving any of the state conflicts.
-
- This is equivalent to getting the state of an event that were to send
- next before receiving any new events.
-
- Returns:
- If `event_type` is specified, then the method returns only the one
- event (or None) with that `event_type` and `state_key`.
-
- Otherwise, a map from (type, state_key) to event.
- """
- if not latest_event_ids:
- latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
- assert latest_event_ids is not None
-
- logger.debug("calling resolve_state_groups from get_current_state")
- ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
- state = ret.state
-
- if event_type:
- event_id = state.get((event_type, state_key))
- event = None
- if event_id:
- event = await self.store.get_event(event_id, allow_none=True)
- return event
-
- state_map = await self.store.get_events(
- list(state.values()), get_prev_content=False
- )
- return {
- key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
- }
-
- async def get_current_state_ids(
- self, room_id: str, latest_event_ids: Optional[Collection[str]] = None
+ latest_event_ids: Collection[str],
) -> StateMap[str]:
"""Get the current state, or the state at a set of events, for a room
Args:
room_id:
- latest_event_ids: if given, the forward extremities to resolve. If
- None, we look them up from the database (via a cache).
+ latest_event_ids: The forward extremities to resolve.
Returns:
the state dict, mapping from (event_type, state_key) -> event_id
"""
- if not latest_event_ids:
- latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
- assert latest_event_ids is not None
-
logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return ret.state
@@ -238,13 +172,9 @@ class StateHandler:
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return await self.store.get_joined_users_from_state(room_id, entry)
- async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
- event_ids = await self.store.get_latest_event_ids_in_room(room_id)
- return await self.get_hosts_in_room_at_events(room_id, event_ids)
-
async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str]
- ) -> Set[str]:
+ ) -> FrozenSet[str]:
"""Get the hosts that were in a room at the given event ids
Args:
@@ -260,7 +190,7 @@ class StateHandler:
async def compute_event_context(
self,
event: EventBase,
- old_state: Optional[Iterable[EventBase]] = None,
+ state_ids_before_event: Optional[StateMap[str]] = None,
partial_state: bool = False,
) -> EventContext:
"""Build an EventContext structure for a non-outlier event.
@@ -272,12 +202,12 @@ class StateHandler:
Args:
event:
- old_state: The state at the event if it can't be
- calculated from existing events. This is normally only specified
- when receiving an event from federation where we don't have the
- prev events for, e.g. when backfilling.
- partial_state: True if `old_state` is partial and omits non-critical
- membership events
+ state_ids_before_event: The event ids of the state before the event if
+ it can't be calculated from existing events. This is normally
+ only specified when receiving an event from federation where we
+ don't have the prev events, e.g. when backfilling.
+ partial_state: True if `state_ids_before_event` is partial and omits
+ non-critical membership events
Returns:
The event context.
"""
@@ -285,14 +215,11 @@ class StateHandler:
assert not event.internal_metadata.is_outlier()
#
- # first of all, figure out the state before the event
+ # first of all, figure out the state before the event, unless we
+ # already have it.
#
-
- if old_state:
+ if state_ids_before_event:
# if we're given the state before the event, then we use that
- state_ids_before_event: StateMap[str] = {
- (s.type, s.state_key): s.event_id for s in old_state
- }
state_group_before_event = None
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None
@@ -339,12 +266,14 @@ class StateHandler:
#
if not state_group_before_event:
- state_group_before_event = await self.state_store.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=state_group_before_event_prev_group,
- delta_ids=deltas_to_state_group_before_event,
- current_state_ids=state_ids_before_event,
+ state_group_before_event = (
+ await self._state_storage_controller.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=state_group_before_event_prev_group,
+ delta_ids=deltas_to_state_group_before_event,
+ current_state_ids=state_ids_before_event,
+ )
)
# Assign the new state group to the cached state entry.
@@ -361,10 +290,10 @@ class StateHandler:
if not event.is_state():
return EventContext.with_state(
+ storage=self._storage_controllers,
state_group_before_event=state_group_before_event,
state_group=state_group_before_event,
- current_state_ids=state_ids_before_event,
- prev_state_ids=state_ids_before_event,
+ state_delta_due_to_event={},
prev_group=state_group_before_event_prev_group,
delta_ids=deltas_to_state_group_before_event,
partial_state=partial_state,
@@ -384,19 +313,21 @@ class StateHandler:
state_ids_after_event[key] = event.event_id
delta_ids = {key: event.event_id}
- state_group_after_event = await self.state_store.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=state_group_before_event,
- delta_ids=delta_ids,
- current_state_ids=state_ids_after_event,
+ state_group_after_event = (
+ await self._state_storage_controller.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=state_group_before_event,
+ delta_ids=delta_ids,
+ current_state_ids=state_ids_after_event,
+ )
)
return EventContext.with_state(
+ storage=self._storage_controllers,
state_group=state_group_after_event,
state_group_before_event=state_group_before_event,
- current_state_ids=state_ids_after_event,
- prev_state_ids=state_ids_before_event,
+ state_delta_due_to_event=delta_ids,
prev_group=state_group_before_event,
delta_ids=delta_ids,
partial_state=partial_state,
@@ -418,33 +349,44 @@ class StateHandler:
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
- # map from state group id to the state in that state group (where
- # 'state' is a map from state key to event id)
- # dict[int, dict[(str, str), str]]
- state_groups_ids = await self.state_store.get_state_groups_ids(
- room_id, event_ids
+ state_groups = await self._state_storage_controller.get_state_group_for_events(
+ event_ids
)
- if len(state_groups_ids) == 0:
- return _StateCacheEntry(state={}, state_group=None)
- elif len(state_groups_ids) == 1:
- name, state_list = list(state_groups_ids.items()).pop()
-
- prev_group, delta_ids = await self.state_store.get_state_group_delta(name)
+ state_group_ids = state_groups.values()
+ # check if each event has same state group id, if so there's no state to resolve
+ state_group_ids_set = set(state_group_ids)
+ if len(state_group_ids_set) == 1:
+ (state_group_id,) = state_group_ids_set
+ state = await self._state_storage_controller.get_state_for_groups(
+ state_group_ids_set
+ )
+ (
+ prev_group,
+ delta_ids,
+ ) = await self._state_storage_controller.get_state_group_delta(
+ state_group_id
+ )
return _StateCacheEntry(
- state=state_list,
- state_group=name,
+ state=state[state_group_id],
+ state_group=state_group_id,
prev_group=prev_group,
delta_ids=delta_ids,
)
+ elif len(state_group_ids_set) == 0:
+ return _StateCacheEntry(state={}, state_group=None)
room_version = await self.store.get_room_version_id(room_id)
+ state_to_resolve = await self._state_storage_controller.get_state_for_groups(
+ state_group_ids_set
+ )
+
result = await self._state_resolution_handler.resolve_state_groups(
room_id,
room_version,
- state_groups_ids,
+ state_to_resolve,
None,
state_res_store=StateResolutionStore(self.store),
)
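The reworked resolve_state_groups_for_events above short-circuits when every event maps to the same state group and only falls through to full state resolution otherwise. A stripped-down illustration of that control flow, with plain dicts standing in for the storage controller calls (all names invented):

from typing import Dict, Mapping, Set, Tuple

StateMap = Mapping[Tuple[str, str], str]

def resolve_current_state(
    event_to_group: Dict[str, int],
    group_to_state: Dict[int, StateMap],
) -> StateMap:
    group_ids: Set[int] = set(event_to_group.values())
    if len(group_ids) == 1:
        # Every event already shares one state group, so there is no
        # conflict to resolve: return that group's state directly.
        (only_group,) = group_ids
        return group_to_state[only_group]
    if not group_ids:
        return {}
    # Otherwise each group's state map is fetched and handed to the state
    # resolution algorithm (omitted in this sketch).
    raise NotImplementedError("full state resolution not sketched here")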
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 105e4e1f..bac21ecf 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -18,41 +18,20 @@ The storage layer is split up into multiple parts to allow Synapse to run
against different configurations of databases (e.g. single or multiple
databases). The `DatabasePool` class represents connections to a single physical
database. The `databases` are classes that talk directly to a `DatabasePool`
-instance and have associated schemas, background updates, etc. On top of those
-there are classes that provide high level interfaces that combine calls to
-multiple `databases`.
+instance and have associated schemas, background updates, etc.
+
+On top of the databases are the StorageControllers, located in the
+`synapse.storage.controllers` module. These classes provide high level
+interfaces that combine calls to multiple `databases`. They are bundled into the
+`StorageControllers` singleton for ease of use, and exposed via
+`HomeServer.get_storage_controllers()`.
There are also schemas that get applied to every database, regardless of the
data stores associated with them (e.g. the schema version tables), which are
stored in `synapse.storage.schema`.
"""
-from typing import TYPE_CHECKING
from synapse.storage.databases import Databases
from synapse.storage.databases.main import DataStore
-from synapse.storage.persist_events import EventsPersistenceStorage
-from synapse.storage.purge_events import PurgeEventsStorage
-from synapse.storage.state import StateGroupStorage
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
__all__ = ["Databases", "DataStore"]
-
-
-class Storage:
- """The high level interfaces for talking to various storage layers."""
-
- def __init__(self, hs: "HomeServer", stores: Databases):
- # We include the main data store here mainly so that we don't have to
- # rewrite all the existing code to split it into high vs low level
- # interfaces.
- self.main = stores.main
-
- self.purge_events = PurgeEventsStorage(hs, stores)
- self.state = StateGroupStorage(hs, stores)
-
- self.persistence = None
- if stores.persist_events:
- self.persistence = EventsPersistenceStorage(hs, stores)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 8df80664..abfc56b0 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -71,13 +71,14 @@ class SQLBaseStore(metaclass=ABCMeta):
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
if members_changed:
self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
+ self._attempt_to_invalidate_cache("get_current_hosts_in_room", (room_id,))
self._attempt_to_invalidate_cache(
"get_users_in_room_with_profiles", (room_id,)
)
# Purge other caches based on room state.
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
- self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
+ self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,))
def _attempt_to_invalidate_cache(
self, cache_name: str, key: Optional[Collection[Any]]
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 08c6eabc..b1e5208c 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,20 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from types import TracebackType
from typing import (
TYPE_CHECKING,
+ Any,
AsyncContextManager,
Awaitable,
Callable,
Dict,
Iterable,
+ List,
Optional,
+ Type,
)
import attr
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.types import Connection
+from synapse.storage.types import Connection, Cursor
from synapse.types import JsonDict
from synapse.util import Clock, json_encoder
@@ -74,7 +78,12 @@ class _BackgroundUpdateContextManager:
return self._update_duration_ms
- async def __aexit__(self, *exc) -> None:
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc: Optional[BaseException],
+ tb: Optional[TracebackType],
+ ) -> None:
pass
@@ -273,12 +282,20 @@ class BackgroundUpdater:
self._running = True
+ back_to_back_failures = 0
+
try:
logger.info("Starting background schema updates")
while self.enabled:
try:
result = await self.do_next_background_update(sleep)
+ back_to_back_failures = 0
except Exception:
+ back_to_back_failures += 1
+ if back_to_back_failures >= 5:
+ raise RuntimeError(
+ "5 back-to-back background update failures; aborting."
+ )
logger.exception("Error doing update")
else:
if result:
@@ -352,7 +369,7 @@ class BackgroundUpdater:
True if we have finished running all the background updates, otherwise False
"""
- def get_background_updates_txn(txn):
+ def get_background_updates_txn(txn: Cursor) -> List[Dict[str, Any]]:
txn.execute(
"""
SELECT update_name, depends_on FROM background_updates
@@ -469,7 +486,7 @@ class BackgroundUpdater:
self,
update_name: str,
update_handler: Callable[[JsonDict, int], Awaitable[int]],
- ):
+ ) -> None:
"""Register a handler for doing a background update.
The handler should take two arguments:
@@ -518,6 +535,7 @@ class BackgroundUpdater:
where_clause: Optional[str] = None,
unique: bool = False,
psql_only: bool = False,
+ replaces_index: Optional[str] = None,
) -> None:
"""Helper for store classes to do a background index addition
@@ -537,6 +555,8 @@ class BackgroundUpdater:
unique: true to make a UNIQUE index
psql_only: true to only create this index on psql databases (useful
for virtual sqlite tables)
+ replaces_index: The name of an index that this index replaces.
+ The named index will be dropped upon completion of the new index.
"""
def create_index_psql(conn: Connection) -> None:
@@ -568,6 +588,12 @@ class BackgroundUpdater:
}
logger.debug("[SQL] %s", sql)
c.execute(sql)
+
+ if replaces_index is not None:
+ # We drop the old index as the new index has now been created.
+ sql = f"DROP INDEX IF EXISTS {replaces_index}"
+ logger.debug("[SQL] %s", sql)
+ c.execute(sql)
finally:
conn.set_session(autocommit=False) # type: ignore
@@ -596,6 +622,12 @@ class BackgroundUpdater:
logger.debug("[SQL] %s", sql)
c.execute(sql)
+ if replaces_index is not None:
+ # We drop the old index as the new index has now been created.
+ sql = f"DROP INDEX IF EXISTS {replaces_index}"
+ logger.debug("[SQL] %s", sql)
+ c.execute(sql)
+
if isinstance(self.db_pool.engine, engines.PostgresEngine):
runner: Optional[Callable[[Connection], None]] = create_index_psql
elif psql_only:
@@ -603,7 +635,7 @@ class BackgroundUpdater:
else:
runner = create_index_sqlite
- async def updater(progress, batch_size):
+ async def updater(progress: JsonDict, batch_size: int) -> int:
if runner is not None:
logger.info("Adding index %s to %s", index_name, table)
await self.db_pool.runWithConnection(runner)
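The new replaces_index argument above lets a store create a replacement index in the background and drop the superseded index in the same update. A hypothetical registration sketch, as it would appear in a data store's __init__ (the table, column and index names are invented, and this assumes the helper shown above is register_background_index_update):

self.db_pool.updates.register_background_index_update(
    "replace_foo_room_index",          # background update name
    index_name="foo_room_id_new_idx",  # the index being created
    table="foo",
    columns=["room_id", "received_ts"],
    unique=False,
    replaces_index="foo_room_id_idx",  # dropped once the new index exists
)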
diff --git a/synapse/storage/controllers/__init__.py b/synapse/storage/controllers/__init__.py
new file mode 100644
index 00000000..55649719
--- /dev/null
+++ b/synapse/storage/controllers/__init__.py
@@ -0,0 +1,46 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from synapse.storage.controllers.persist_events import (
+ EventsPersistenceStorageController,
+)
+from synapse.storage.controllers.purge_events import PurgeEventsStorageController
+from synapse.storage.controllers.state import StateStorageController
+from synapse.storage.databases import Databases
+from synapse.storage.databases.main import DataStore
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+__all__ = ["Databases", "DataStore"]
+
+
+class StorageControllers:
+ """The high level interfaces for talking to various storage controller layers."""
+
+ def __init__(self, hs: "HomeServer", stores: Databases):
+ # We include the main data store here mainly so that we don't have to
+ # rewrite all the existing code to split it into high vs low level
+ # interfaces.
+ self.main = stores.main
+
+ self.purge_events = PurgeEventsStorageController(hs, stores)
+ self.state = StateStorageController(hs, stores)
+
+ self.persistence = None
+ if stores.persist_events:
+ self.persistence = EventsPersistenceStorageController(hs, stores)
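After this refactor, callers reach the high-level storage interfaces through the controllers bundle rather than the removed Storage class; a minimal sketch of the new access pattern (the handler is invented, the accessors are the ones added in this patch):

class ExampleHandler:
    def __init__(self, hs: "HomeServer"):
        # hs.get_storage_controllers() replaces hs.get_storage().
        self._storage_controllers = hs.get_storage_controllers()

    async def state_ids_at(self, event_id: str) -> "StateMap[str]":
        # High-level state lookups now go through the state controller.
        return await self._storage_controllers.state.get_state_ids_for_event(
            event_id
        )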
diff --git a/synapse/storage/persist_events.py b/synapse/storage/controllers/persist_events.py
index 97118045..4caaa818 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -25,6 +25,7 @@ from typing import (
Collection,
Deque,
Dict,
+ Generator,
Generic,
Iterable,
List,
@@ -207,7 +208,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
return res
- def _handle_queue(self, room_id):
+ def _handle_queue(self, room_id: str) -> None:
"""Attempts to handle the queue for a room if not already being handled.
The queue's callback will be invoked with for each item in the queue,
@@ -227,7 +228,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
self._currently_persisting_rooms.add(room_id)
- async def handle_queue_loop():
+ async def handle_queue_loop() -> None:
try:
queue = self._get_drainining_queue(room_id)
for item in queue:
@@ -250,15 +251,17 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
with PreserveLoggingContext():
item.deferred.callback(ret)
finally:
- queue = self._event_persist_queues.pop(room_id, None)
- if queue:
- self._event_persist_queues[room_id] = queue
+ remaining_queue = self._event_persist_queues.pop(room_id, None)
+ if remaining_queue:
+ self._event_persist_queues[room_id] = remaining_queue
self._currently_persisting_rooms.discard(room_id)
# set handle_queue_loop off in the background
run_as_background_process("persist_events", handle_queue_loop)
- def _get_drainining_queue(self, room_id):
+ def _get_drainining_queue(
+ self, room_id: str
+ ) -> Generator[_EventPersistQueueItem, None, None]:
queue = self._event_persist_queues.setdefault(room_id, deque())
try:
@@ -269,7 +272,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
pass
-class EventsPersistenceStorage:
+class EventsPersistenceStorageController:
"""High level interface for handling persisting newly received events.
Takes care of batching up events by room, and calculating the necessary
@@ -310,14 +313,16 @@ class EventsPersistenceStorage:
List of events persisted, the current position room stream position.
The list of events persisted may not be the same as those passed in
if they were deduplicated due to an event already existing that
- matched the transcation ID; the existing event is returned in such
+ matched the transaction ID; the existing event is returned in such
a case.
"""
partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx))
- async def enqueue(item):
+ async def enqueue(
+ item: Tuple[str, List[Tuple[EventBase, EventContext]]]
+ ) -> Dict[str, str]:
room_id, evs_ctxs = item
return await self._event_persist_queue.add_to_queue(
room_id, evs_ctxs, backfilled=backfilled
@@ -487,12 +492,6 @@ class EventsPersistenceStorage:
# extremities in each room
new_forward_extremities: Dict[str, Set[str]] = {}
- # map room_id->(type,state_key)->event_id tracking the full
- # state in each room after adding these events.
- # This is simply used to prefill the get_current_state_ids
- # cache
- current_state_for_room: Dict[str, StateMap[str]] = {}
-
# map room_id->(to_delete, to_insert) where to_delete is a list
# of type/state keys to remove from current state, and to_insert
# is a map (type,key)->event_id giving the state delta in each
@@ -628,14 +627,8 @@ class EventsPersistenceStorage:
state_delta_for_room[room_id] = delta
- # If we have the current_state then lets prefill
- # the cache with it.
- if current_state is not None:
- current_state_for_room[room_id] = current_state
-
await self.persist_events_store._persist_events_and_state_updates(
chunk,
- current_state_for_room=current_state_for_room,
state_delta_for_room=state_delta_for_room,
new_forward_extremities=new_forward_extremities,
use_negative_stream_ordering=backfilled,
@@ -733,7 +726,8 @@ class EventsPersistenceStorage:
The first state map is the full new current state and the second
is the delta to the existing current state. If both are None then
- there has been no change.
+ there has been no change. Either or neither can be None if there
+ has been a change.
The function may prune some old entries from the set of new
forward extremities if it's safe to do so.
@@ -743,9 +737,6 @@ class EventsPersistenceStorage:
the new current state is only returned if we've already calculated
it.
"""
- # map from state_group to ((type, key) -> event_id) state map
- state_groups_map = {}
-
# Map from (prev state group, new state group) -> delta state dict
state_group_deltas = {}
@@ -759,16 +750,6 @@ class EventsPersistenceStorage:
)
continue
- if ctx.state_group in state_groups_map:
- continue
-
- # We're only interested in pulling out state that has already
- # been cached in the context. We'll pull stuff out of the DB later
- # if necessary.
- current_state_ids = ctx.get_cached_current_state_ids()
- if current_state_ids is not None:
- state_groups_map[ctx.state_group] = current_state_ids
-
if ctx.prev_group:
state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
@@ -826,18 +807,14 @@ class EventsPersistenceStorage:
delta_ids = state_group_deltas.get((old_state_group, new_state_group), None)
if delta_ids is not None:
# We have a delta from the existing to new current state,
- # so lets just return that. If we happen to already have
- # the current state in memory then lets also return that,
- # but it doesn't matter if we don't.
- new_state = state_groups_map.get(new_state_group)
- return new_state, delta_ids, new_latest_event_ids
+ # so lets just return that.
+ return None, delta_ids, new_latest_event_ids
# Now that we have calculated new_state_groups we need to get
# their state IDs so we can resolve to a single state set.
- missing_state = new_state_groups - set(state_groups_map)
- if missing_state:
- group_to_state = await self.state_store._get_state_for_groups(missing_state)
- state_groups_map.update(group_to_state)
+ state_groups_map = await self.state_store._get_state_for_groups(
+ new_state_groups
+ )
if len(new_state_groups) == 1:
# If there is only one state group, then we know what the current
@@ -1017,7 +994,7 @@ class EventsPersistenceStorage:
Assumes that we are only persisting events for one room at a time.
"""
- existing_state = await self.main_store.get_current_state_ids(room_id)
+ existing_state = await self.main_store.get_partial_current_state_ids(room_id)
to_delete = [key for key in existing_state if key not in current_state]
@@ -1106,7 +1083,7 @@ class EventsPersistenceStorage:
# The server will leave the room, so we go and find out which remote
# users will still be joined when we leave.
if current_state is None:
- current_state = await self.main_store.get_current_state_ids(room_id)
+ current_state = await self.main_store.get_partial_current_state_ids(room_id)
current_state = dict(current_state)
for key in delta.to_delete:
current_state.pop(key, None)
@@ -1130,7 +1107,7 @@ class EventsPersistenceStorage:
return False
- async def _handle_potentially_left_users(self, user_ids: Set[str]):
+ async def _handle_potentially_left_users(self, user_ids: Set[str]) -> None:
"""Given a set of remote users check if the server still shares a room with
them. If not then mark those users' device cache as stale.
"""
diff --git a/synapse/storage/purge_events.py b/synapse/storage/controllers/purge_events.py
index 30669beb..9ca50d6a 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/controllers/purge_events.py
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class PurgeEventsStorage:
+class PurgeEventsStorageController:
"""High level interface for purging rooms and event history."""
def __init__(self, hs: "HomeServer", stores: Databases):
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
new file mode 100644
index 00000000..3b4cdb67
--- /dev/null
+++ b/synapse/storage/controllers/state.py
@@ -0,0 +1,492 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Set,
+ Tuple,
+)
+
+from synapse.api.constants import EventTypes
+from synapse.events import EventBase
+from synapse.storage.state import StateFilter
+from synapse.storage.util.partial_state_events_tracker import (
+ PartialCurrentStateTracker,
+ PartialStateEventsTracker,
+)
+from synapse.types import MutableStateMap, StateMap
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+ from synapse.storage.databases import Databases
+
+logger = logging.getLogger(__name__)
+
+
+class StateStorageController:
+ """High level interface to fetching state for an event, or the current state
+ in a room.
+ """
+
+ def __init__(self, hs: "HomeServer", stores: "Databases"):
+ self._is_mine_id = hs.is_mine_id
+ self.stores = stores
+ self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
+ self._partial_state_room_tracker = PartialCurrentStateTracker(stores.main)
+
+ def notify_event_un_partial_stated(self, event_id: str) -> None:
+ self._partial_state_events_tracker.notify_un_partial_stated(event_id)
+
+ def notify_room_un_partial_stated(self, room_id: str) -> None:
+ """Notify that the room no longer has any partial state.
+
+ Must be called after `DataStore.clear_partial_state_room`
+ """
+ self._partial_state_room_tracker.notify_un_partial_stated(room_id)
+
+ async def get_state_group_delta(
+ self, state_group: int
+ ) -> Tuple[Optional[int], Optional[StateMap[str]]]:
+ """Given a state group try to return a previous group and a delta between
+ the old and the new.
+
+ Args:
+ state_group: The state group used to retrieve state deltas.
+
+ Returns:
+ A tuple of the previous group and a state map of the event IDs which
+ make up the delta between the old and new state groups.
+ """
+
+ state_group_delta = await self.stores.state.get_state_group_delta(state_group)
+ return state_group_delta.prev_group, state_group_delta.delta_ids
+
+ async def get_state_groups_ids(
+ self, _room_id: str, event_ids: Collection[str]
+ ) -> Dict[int, MutableStateMap[str]]:
+ """Get the event IDs of all the state for the state groups for the given events
+
+ Args:
+ _room_id: id of the room for these events
+ event_ids: ids of the events
+
+ Returns:
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (ie they are outliers or unknown)
+ """
+ if not event_ids:
+ return {}
+
+ event_to_groups = await self.get_state_group_for_events(event_ids)
+
+ groups = set(event_to_groups.values())
+ group_to_state = await self.stores.state._get_state_for_groups(groups)
+
+ return group_to_state
+
+ async def get_state_ids_for_group(
+ self, state_group: int, state_filter: Optional[StateFilter] = None
+ ) -> StateMap[str]:
+ """Get the event IDs of all the state in the given state group
+
+ Args:
+ state_group: A state group for which we want to get the state IDs.
+ state_filter: specifies the type of state event to fetch from the DB, e.g. EventTypes.JoinRules
+
+ Returns:
+ Resolves to a map of (type, state_key) -> event_id
+ """
+ group_to_state = await self.get_state_for_groups((state_group,), state_filter)
+
+ return group_to_state[state_group]
+
+ async def get_state_groups(
+ self, room_id: str, event_ids: Collection[str]
+ ) -> Dict[int, List[EventBase]]:
+ """Get the state groups for the given list of event_ids
+
+ Args:
+ room_id: ID of the room for these events.
+ event_ids: The event IDs to retrieve state for.
+
+ Returns:
+ dict of state_group_id -> list of state events.
+ """
+ if not event_ids:
+ return {}
+
+ group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
+
+ state_event_map = await self.stores.main.get_events(
+ [
+ ev_id
+ for group_ids in group_to_ids.values()
+ for ev_id in group_ids.values()
+ ],
+ get_prev_content=False,
+ )
+
+ return {
+ group: [
+ state_event_map[v]
+ for v in event_id_map.values()
+ if v in state_event_map
+ ]
+ for group, event_id_map in group_to_ids.items()
+ }
+
+ def _get_state_groups_from_groups(
+ self, groups: List[int], state_filter: StateFilter
+ ) -> Awaitable[Dict[int, StateMap[str]]]:
+ """Returns the state groups for a given set of groups, filtering on
+ types of state events.
+
+ Args:
+ groups: list of state group IDs to query
+ state_filter: The state filter used to fetch state
+ from the database.
+
+ Returns:
+ Dict of state group to state map.
+ """
+
+ return self.stores.state._get_state_groups_from_groups(groups, state_filter)
+
+ async def get_state_for_events(
+ self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
+ ) -> Dict[str, StateMap[EventBase]]:
+ """Given a list of event_ids and type tuples, return a list of state
+ dicts for each event.
+
+ Args:
+ event_ids: The events to fetch the state of.
+ state_filter: The state filter used to fetch state.
+
+ Returns:
+            A dict of event_id -> (type, state_key) -> state_event
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (ie they are outliers or unknown)
+ """
+ await_full_state = True
+ if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
+ await_full_state = False
+
+ event_to_groups = await self.get_state_group_for_events(
+ event_ids, await_full_state=await_full_state
+ )
+
+ groups = set(event_to_groups.values())
+ group_to_state = await self.stores.state._get_state_for_groups(
+ groups, state_filter or StateFilter.all()
+ )
+
+ state_event_map = await self.stores.main.get_events(
+ [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
+ get_prev_content=False,
+ )
+
+ event_to_state = {
+ event_id: {
+ k: state_event_map[v]
+ for k, v in group_to_state[group].items()
+ if v in state_event_map
+ }
+ for event_id, group in event_to_groups.items()
+ }
+
+ return {event: event_to_state[event] for event in event_ids}
+
+ async def get_state_ids_for_events(
+ self,
+ event_ids: Collection[str],
+ state_filter: Optional[StateFilter] = None,
+ ) -> Dict[str, StateMap[str]]:
+ """
+ Get the state dicts corresponding to a list of events, containing the event_ids
+ of the state events (as opposed to the events themselves)
+
+ Args:
+ event_ids: events whose state should be returned
+ state_filter: The state filter used to fetch state from the database.
+
+ Returns:
+ A dict from event_id -> (type, state_key) -> event_id
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (ie they are outliers or unknown)
+ """
+ await_full_state = True
+ if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
+ await_full_state = False
+
+ event_to_groups = await self.get_state_group_for_events(
+ event_ids, await_full_state=await_full_state
+ )
+
+ groups = set(event_to_groups.values())
+ group_to_state = await self.stores.state._get_state_for_groups(
+ groups, state_filter or StateFilter.all()
+ )
+
+ event_to_state = {
+ event_id: group_to_state[group]
+ for event_id, group in event_to_groups.items()
+ }
+
+ return {event: event_to_state[event] for event in event_ids}
+
+ async def get_state_for_event(
+ self, event_id: str, state_filter: Optional[StateFilter] = None
+ ) -> StateMap[EventBase]:
+ """
+ Get the state dict corresponding to a particular event
+
+ Args:
+ event_id: event whose state should be returned
+ state_filter: The state filter used to fetch state from the database.
+
+ Returns:
+ A dict from (type, state_key) -> state_event
+
+ Raises:
+ RuntimeError if we don't have a state group for the event (ie it is an
+ outlier or is unknown)
+ """
+ state_map = await self.get_state_for_events(
+ [event_id], state_filter or StateFilter.all()
+ )
+ return state_map[event_id]
+
+ async def get_state_ids_for_event(
+ self, event_id: str, state_filter: Optional[StateFilter] = None
+ ) -> StateMap[str]:
+ """
+ Get the state dict corresponding to a particular event
+
+ Args:
+ event_id: event whose state should be returned
+ state_filter: The state filter used to fetch state from the database.
+
+ Returns:
+ A dict from (type, state_key) -> state_event_id
+
+ Raises:
+ RuntimeError if we don't have a state group for the event (ie it is an
+ outlier or is unknown)
+ """
+ state_map = await self.get_state_ids_for_events(
+ [event_id], state_filter or StateFilter.all()
+ )
+ return state_map[event_id]
+
+ def get_state_for_groups(
+ self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
+ ) -> Awaitable[Dict[int, MutableStateMap[str]]]:
+ """Gets the state at each of a list of state groups, optionally
+ filtering by type/state_key
+
+ Args:
+ groups: list of state groups for which we want to get the state.
+            state_filter: The state filter used to fetch state
+                from the database.
+
+ Returns:
+ Dict of state group to state map.
+ """
+ return self.stores.state._get_state_for_groups(
+ groups, state_filter or StateFilter.all()
+ )
+
+ async def get_state_group_for_events(
+ self,
+ event_ids: Collection[str],
+ await_full_state: bool = True,
+ ) -> Mapping[str, int]:
+ """Returns mapping event_id -> state_group
+
+ Args:
+ event_ids: events to get state groups for
+ await_full_state: if true, will block if we do not yet have complete
+ state at these events.
+ """
+ if await_full_state:
+ await self._partial_state_events_tracker.await_full_state(event_ids)
+
+ return await self.stores.main._get_state_group_for_events(event_ids)
+
+ async def store_state_group(
+ self,
+ event_id: str,
+ room_id: str,
+ prev_group: Optional[int],
+ delta_ids: Optional[StateMap[str]],
+ current_state_ids: StateMap[str],
+ ) -> int:
+ """Store a new set of state, returning a newly assigned state group.
+
+ Args:
+ event_id: The event ID for which the state was calculated.
+ room_id: ID of the room for which the state was calculated.
+ prev_group: A previous state group for the room, optional.
+ delta_ids: The delta between state at `prev_group` and
+ `current_state_ids`, if `prev_group` was given. Same format as
+ `current_state_ids`.
+ current_state_ids: The state to store. Map of (type, state_key)
+ to event_id.
+
+ Returns:
+ The state group ID
+ """
+ return await self.stores.state.store_state_group(
+ event_id, room_id, prev_group, delta_ids, current_state_ids
+ )
+
+ async def get_current_state_ids(
+ self,
+ room_id: str,
+ state_filter: Optional[StateFilter] = None,
+ on_invalidate: Optional[Callable[[], None]] = None,
+ ) -> StateMap[str]:
+ """Get the current state event ids for a room based on the
+ current_state_events table.
+
+ If a state filter is given (that is not `StateFilter.all()`) the query
+ result is *not* cached.
+
+ Args:
+            room_id: The room to get the state IDs of.
+            state_filter: The state filter used to fetch state from the database.
+ on_invalidate: Callback for when the `get_current_state_ids` cache
+ for the room gets invalidated.
+
+ Returns:
+ The current state of the room.
+ """
+ if not state_filter or state_filter.must_await_full_state(self._is_mine_id):
+ await self._partial_state_room_tracker.await_full_state(room_id)
+
+ if state_filter and not state_filter.is_full():
+ return await self.stores.main.get_partial_filtered_current_state_ids(
+ room_id, state_filter
+ )
+ else:
+ return await self.stores.main.get_partial_current_state_ids(
+ room_id, on_invalidate=on_invalidate
+ )
+
+ async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
+ """Get canonical alias for room, if any
+
+ Args:
+ room_id: The room ID
+
+ Returns:
+ The canonical alias, if any
+ """
+
+ state = await self.get_current_state_ids(
+ room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
+ )
+
+ event_id = state.get((EventTypes.CanonicalAlias, ""))
+ if not event_id:
+ return None
+
+ event = await self.stores.main.get_event(event_id, allow_none=True)
+ if not event:
+ return None
+
+ return event.content.get("canonical_alias")
+
+ async def get_current_state_deltas(
+ self, prev_stream_id: int, max_stream_id: int
+ ) -> Tuple[int, List[Dict[str, Any]]]:
+ """Fetch a list of room state changes since the given stream id
+
+ Each entry in the result contains the following fields:
+ - stream_id (int)
+ - room_id (str)
+ - type (str): event type
+ - state_key (str):
+ - event_id (str|None): new event_id for this state key. None if the
+ state has been deleted.
+ - prev_event_id (str|None): previous event_id for this state key. None
+ if it's new state.
+
+ Args:
+ prev_stream_id: point to get changes since (exclusive)
+ max_stream_id: the point that we know has been correctly persisted
+ - ie, an upper limit to return changes from.
+
+ Returns:
+ A tuple consisting of:
+ - the stream id which these results go up to
+ - list of current_state_delta_stream rows. If it is empty, we are
+ up to date.
+ """
+ # FIXME(faster_joins): what do we do here?
+
+ return await self.stores.main.get_partial_current_state_deltas(
+ prev_stream_id, max_stream_id
+ )
+
+ async def get_current_state(
+ self, room_id: str, state_filter: Optional[StateFilter] = None
+ ) -> StateMap[EventBase]:
+ """Same as `get_current_state_ids` but also fetches the events"""
+ state_map_ids = await self.get_current_state_ids(room_id, state_filter)
+
+ event_map = await self.stores.main.get_events(list(state_map_ids.values()))
+
+ state_map = {}
+ for key, event_id in state_map_ids.items():
+ event = event_map.get(event_id)
+ if event:
+ state_map[key] = event
+
+ return state_map
+
+ async def get_current_state_event(
+ self, room_id: str, event_type: str, state_key: str
+ ) -> Optional[EventBase]:
+ """Get the current state event for the given type/state_key."""
+
+ key = (event_type, state_key)
+ state_map = await self.get_current_state(
+ room_id, StateFilter.from_types((key,))
+ )
+ return state_map.get(key)
+
+ async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+ """Get current hosts in room based on current state."""
+
+ await self._partial_state_room_tracker.await_full_state(room_id)
+
+ return await self.stores.main.get_current_hosts_in_room(room_id)
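Many of the reads above first call await_full_state on a partial-state tracker and only proceed once notify_un_partial_stated has fired for the room (the faster-joins machinery). A rough asyncio sketch of that wait/notify pattern; the real trackers are Twisted-based and also deal with replication, so only those two method names are taken from the patch and the rest is illustrative:

    # Sketch of a wait/notify tracker, assuming asyncio rather than Twisted.
    import asyncio
    from typing import Dict

    class PartialStateRoomTracker:
        def __init__(self) -> None:
            # One event per room that is currently only partially stated.
            self._observers: Dict[str, asyncio.Event] = {}

        def mark_partial(self, room_id: str) -> None:
            # Made-up helper: record that a room's state is still incomplete.
            self._observers.setdefault(room_id, asyncio.Event())

        async def await_full_state(self, room_id: str) -> None:
            ev = self._observers.get(room_id)
            if ev is not None:
                # Block the caller until the room has been fully un-partial-stated.
                await ev.wait()

        def notify_un_partial_stated(self, room_id: str) -> None:
            ev = self._observers.pop(room_id, None)
            if ev is not None:
                ev.set()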
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 41f566b6..a78d68a9 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -31,6 +31,7 @@ from typing import (
List,
Optional,
Tuple,
+ Type,
TypeVar,
cast,
overload,
@@ -41,6 +42,7 @@ from prometheus_client import Histogram
from typing_extensions import Concatenate, Literal, ParamSpec
from twisted.enterprise import adbapi
+from twisted.internet.interfaces import IReactorCore
from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
@@ -88,11 +90,15 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
"device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
"device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
"event_search": "event_search_event_id_idx",
+ "local_media_repository_thumbnails": "local_media_repository_thumbnails_method_idx",
+ "remote_media_cache_thumbnails": "remote_media_repository_thumbnails_method_idx",
}
def make_pool(
- reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+ reactor: IReactorCore,
+ db_config: DatabaseConnectionConfig,
+ engine: BaseDatabaseEngine,
) -> adbapi.ConnectionPool:
"""Get the connection pool for the database."""
@@ -101,7 +107,7 @@ def make_pool(
db_args = dict(db_config.config.get("args", {}))
db_args.setdefault("cp_reconnect", True)
- def _on_new_connection(conn):
+ def _on_new_connection(conn: Connection) -> None:
# Ensure we have a logging context so we can correctly track queries,
# etc.
with LoggingContext("db.on_new_connection"):
@@ -157,7 +163,11 @@ class LoggingDatabaseConnection:
default_txn_name: str
def cursor(
- self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
+ self,
+ *,
+ txn_name: Optional[str] = None,
+ after_callbacks: Optional[List["_CallbackListEntry"]] = None,
+ exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
) -> "LoggingTransaction":
if not txn_name:
txn_name = self.default_txn_name
@@ -183,11 +193,16 @@ class LoggingDatabaseConnection:
self.conn.__enter__()
return self
- def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_value: Optional[BaseException],
+ traceback: Optional[types.TracebackType],
+ ) -> Optional[bool]:
return self.conn.__exit__(exc_type, exc_value, traceback)
# Proxy through any unknown lookups to the DB conn class.
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> Any:
return getattr(self.conn, name)
@@ -391,17 +406,22 @@ class LoggingTransaction:
def __enter__(self) -> "LoggingTransaction":
return self
- def __exit__(self, exc_type, exc_value, traceback):
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_value: Optional[BaseException],
+ traceback: Optional[types.TracebackType],
+ ) -> None:
self.close()
class PerformanceCounters:
- def __init__(self):
- self.current_counters = {}
- self.previous_counters = {}
+ def __init__(self) -> None:
+ self.current_counters: Dict[str, Tuple[int, float]] = {}
+ self.previous_counters: Dict[str, Tuple[int, float]] = {}
def update(self, key: str, duration_secs: float) -> None:
- count, cum_time = self.current_counters.get(key, (0, 0))
+ count, cum_time = self.current_counters.get(key, (0, 0.0))
count += 1
cum_time += duration_secs
self.current_counters[key] = (count, cum_time)
@@ -527,7 +547,7 @@ class DatabasePool:
def start_profiling(self) -> None:
self._previous_loop_ts = monotonic_time()
- def loop():
+ def loop() -> None:
curr = self._current_txn_total_time
prev = self._previous_txn_total_time
self._previous_txn_total_time = curr
@@ -1186,7 +1206,7 @@ class DatabasePool:
if lock:
self.engine.lock_table(txn, table)
- def _getwhere(key):
+ def _getwhere(key: str) -> str:
# If the value we're passing in is None (aka NULL), we need to use
# IS, not =, as NULL = NULL equals NULL (False).
if keyvalues[key] is None:
@@ -2258,7 +2278,7 @@ class DatabasePool:
term: Optional[str],
col: str,
retcols: Collection[str],
- desc="simple_search_list",
+ desc: str = "simple_search_list",
) -> Optional[List[Dict[str, Any]]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
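Most of the database.py hunks only add type annotations to previously untyped callables; the __exit__ signatures in particular now spell out the standard context-manager protocol. A stand-alone example of the same shape (not Synapse code):

    import types
    from typing import Optional, Type

    class Managed:
        def __enter__(self) -> "Managed":
            return self

        def __exit__(
            self,
            exc_type: Optional[Type[BaseException]],
            exc_value: Optional[BaseException],
            traceback: Optional[types.TracebackType],
        ) -> None:
            # Returning None (falsy) propagates any exception, matching the
            # behaviour of the untyped originals.
            pass

    with Managed():
        pass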
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 5895b892..11d9d16c 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -26,11 +26,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import (
- IdGenerator,
- MultiWriterIdGenerator,
- StreamIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -155,12 +151,6 @@ class DataStore(
],
)
- self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
- self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
- self._group_updates_id_gen = StreamIdGenerator(
- db_conn, "local_group_updates", "stream_id"
- )
-
self._cache_id_gen: Optional[MultiWriterIdGenerator]
if isinstance(self.database_engine, PostgresEngine):
# We set the `writers` to an empty list here as we don't care about
@@ -203,20 +193,6 @@ class DataStore(
prefilled_cache=curr_state_delta_prefill,
)
- _group_updates_prefill, min_group_updates_id = self.db_pool.get_cache_dict(
- db_conn,
- "local_group_updates",
- entity_column="user_id",
- stream_column="stream_id",
- max_value=self._group_updates_id_gen.get_current_token(),
- limit=1000,
- )
- self._group_updates_stream_cache = StreamChangeCache(
- "_group_updates_stream_cache",
- min_group_updates_id,
- prefilled_cache=_group_updates_prefill,
- )
-
self._stream_order_on_start = self.get_room_max_stream_ordering()
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 945707b0..e284454b 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -203,19 +203,29 @@ class ApplicationServiceTransactionWorkerStore(
"""Get the application service state.
Args:
- service: The service whose state to set.
+ service: The service whose state to get.
Returns:
- An ApplicationServiceState or none.
+ An ApplicationServiceState, or None if we have yet to attempt any
+ transactions to the AS.
"""
- result = await self.db_pool.simple_select_one(
+ # if we have created transactions for this AS but not yet attempted to send
+ # them, we will have a row in the table with state=NULL (recording the stream
+ # positions we have processed up to).
+ #
+ # On the other hand, if we have yet to create any transactions for this AS at
+ # all, then there will be no row for the AS.
+ #
+ # In either case, we return None to indicate "we don't yet know the state of
+ # this AS".
+ result = await self.db_pool.simple_select_one_onecol(
"application_services_state",
{"as_id": service.id},
- ["state"],
+ retcol="state",
allow_none=True,
desc="get_appservice_state",
)
if result:
- return ApplicationServiceState(result.get("state"))
+ return ApplicationServiceState(result)
return None
async def set_appservice_state(
@@ -296,14 +306,6 @@ class ApplicationServiceTransactionWorkerStore(
"""
def _complete_appservice_txn(txn: LoggingTransaction) -> None:
- # Set current txn_id for AS to 'txn_id'
- self.db_pool.simple_upsert_txn(
- txn,
- "application_services_state",
- {"as_id": service.id},
- {"last_txn": txn_id},
- )
-
# Delete txn
self.db_pool.simple_delete_txn(
txn,
@@ -452,16 +454,15 @@ class ApplicationServiceTransactionWorkerStore(
% (stream_type,)
)
- def set_appservice_stream_type_pos_txn(txn: LoggingTransaction) -> None:
- stream_id_type = "%s_stream_id" % stream_type
- txn.execute(
- "UPDATE application_services_state SET %s = ? WHERE as_id=?"
- % stream_id_type,
- (pos, service.id),
- )
-
- await self.db_pool.runInteraction(
- "set_appservice_stream_type_pos", set_appservice_stream_type_pos_txn
+ # this may be the first time that we're recording any state for this AS, so
+ # we don't yet know if a row for it exists; hence we have to upsert here.
+ await self.db_pool.simple_upsert(
+ table="application_services_state",
+ keyvalues={"as_id": service.id},
+ values={f"{stream_type}_stream_id": pos},
+ # no need to lock when emulating upsert: as_id is a unique key
+ lock=False,
+ desc="set_appservice_stream_type_pos",
)
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index dd4e83a2..1653a6a9 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -57,6 +57,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._instance_name = hs.get_instance_name()
+ self.db_pool.updates.register_background_index_update(
+ update_name="cache_invalidation_index_by_instance",
+ index_name="cache_invalidation_stream_by_instance_instance_index",
+ table="cache_invalidation_stream_by_instance",
+ columns=("instance_name", "stream_id"),
+ psql_only=True, # The table is only on postgres DBs.
+ )
+
async def get_all_updated_caches(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
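The registration above asks the background-updates machinery to build, on PostgreSQL only, an index over (instance_name, stream_id). Ignoring how Synapse schedules and batches that work, the end state corresponds roughly to this SQL (sketch only):

    # Sketch only: the eventual index, expressed as SQL an operator could run by
    # hand. Synapse's background updater normally creates it for you.
    CREATE_INDEX_SQL = """
    CREATE INDEX IF NOT EXISTS cache_invalidation_stream_by_instance_instance_index
        ON cache_invalidation_stream_by_instance (instance_name, stream_id)
    """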
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 2df4dd4e..d900064c 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -28,6 +28,7 @@ from typing import (
cast,
)
+from synapse.api.constants import EduTypes
from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (
get_active_span_text_map,
@@ -419,7 +420,7 @@ class DeviceWorkerStore(SQLBaseStore):
# Add the updated cross-signing keys to the results list
for user_id, result in cross_signing_keys_by_user.items():
result["user_id"] = user_id
- results.append(("m.signing_key_update", result))
+ results.append((EduTypes.SIGNING_KEY_UPDATE, result))
# also send the unstable version
# FIXME: remove this when enough servers have upgraded
# and remove the length budgeting above.
@@ -545,7 +546,7 @@ class DeviceWorkerStore(SQLBaseStore):
else:
result["deleted"] = True
- results.append(("m.device_list_update", result))
+ results.append((EduTypes.DEVICE_LIST_UPDATE, result))
return results
@@ -1153,6 +1154,45 @@ class DeviceWorkerStore(SQLBaseStore):
_prune_txn,
)
+ async def get_local_devices_not_accessed_since(
+ self, since_ms: int
+ ) -> Dict[str, List[str]]:
+ """Retrieves local devices that haven't been accessed since a given date.
+
+ Args:
+            since_ms: the timestamp to select on; every device whose last access date
+                is before that time is returned.
+
+ Returns:
+ A dictionary with an entry for each user with at least one device matching
+            the request, whose value is a list of the device ID(s) for the corresponding
+ device(s).
+ """
+
+ def get_devices_not_accessed_since_txn(
+ txn: LoggingTransaction,
+ ) -> List[Dict[str, str]]:
+ sql = """
+ SELECT user_id, device_id
+ FROM devices WHERE last_seen < ? AND hidden = FALSE
+ """
+ txn.execute(sql, (since_ms,))
+ return self.db_pool.cursor_to_dict(txn)
+
+ rows = await self.db_pool.runInteraction(
+ "get_devices_not_accessed_since",
+ get_devices_not_accessed_since_txn,
+ )
+
+ devices: Dict[str, List[str]] = {}
+ for row in rows:
+ # Remote devices are never stale from our point of view.
+ if self.hs.is_mine_id(row["user_id"]):
+ user_devices = devices.setdefault(row["user_id"], [])
+ user_devices.append(row["device_id"])
+
+ return devices
+
class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(
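The new query returns a plain {user_id: [device_id, ...]} mapping restricted to local users. A hypothetical caller might drive it like the sketch below; the 30-day threshold and report_stale_devices are made up, and store stands for any object exposing the new method:

    import time
    from typing import Dict, List

    THIRTY_DAYS_MS = 30 * 24 * 60 * 60 * 1000

    async def report_stale_devices(store) -> Dict[str, List[str]]:
        # Select devices whose last access was more than 30 days ago.
        since_ms = int(time.time() * 1000) - THIRTY_DAYS_MS
        stale = await store.get_local_devices_not_accessed_since(since_ms)
        for user_id, device_ids in stale.items():
            print(f"{user_id}: {len(device_ids)} stale device(s)")
        return stale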
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index b789a588..af59be6b 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -21,7 +21,7 @@ from synapse.api.errors import StoreError
from synapse.logging.opentracing import log_kv, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
-from synapse.types import JsonDict, JsonSerializable
+from synapse.types import JsonDict, JsonSerializable, StreamKeyType
from synapse.util import json_encoder
@@ -126,7 +126,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"message": "Set room key",
"room_id": room_id,
"session_id": session_id,
- "room_key": room_key,
+ StreamKeyType.ROOM: room_key,
}
)
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 47102247..eec55b64 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -14,7 +14,17 @@
import itertools
import logging
from queue import Empty, PriorityQueue
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ cast,
+)
import attr
from prometheus_client import Counter, Gauge
@@ -33,7 +43,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
@@ -135,7 +145,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Check if we have indexed the room so we can use the chain cover
# algorithm.
- room = await self.get_room(room_id)
+ room = await self.get_room(room_id) # type: ignore[attr-defined]
if room["has_auth_chain_index"]:
try:
return await self.db_pool.runInteraction(
@@ -158,7 +168,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
def _get_auth_chain_ids_using_cover_index_txn(
- self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ event_ids: Collection[str],
+ include_given: bool,
) -> Set[str]:
"""Calculates the auth chain IDs using the chain index."""
@@ -215,9 +229,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
chains: Dict[int, int] = {}
# Add all linked chains reachable from initial set of chains.
- for batch in batch_iter(event_chains, 1000):
+ for batch2 in batch_iter(event_chains, 1000):
clause, args = make_in_list_sql_clause(
- txn.database_engine, "origin_chain_id", batch
+ txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)
@@ -297,7 +311,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
front = set(event_ids)
while front:
- new_front = set()
+ new_front: Set[str] = set()
for chunk in batch_iter(front, 100):
# Pull the auth events either from the cache or DB.
to_fetch: List[str] = [] # Event IDs to fetch from DB
@@ -316,7 +330,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Note we need to batch up the results by event ID before
# adding to the cache.
- to_cache = {}
+ to_cache: Dict[str, List[Tuple[str, int]]] = {}
for event_id, auth_event_id, auth_event_depth in txn:
to_cache.setdefault(event_id, []).append(
(auth_event_id, auth_event_depth)
@@ -349,7 +363,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Check if we have indexed the room so we can use the chain cover
# algorithm.
- room = await self.get_room(room_id)
+ room = await self.get_room(room_id) # type: ignore[attr-defined]
if room["has_auth_chain_index"]:
try:
return await self.db_pool.runInteraction(
@@ -370,7 +384,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
def _get_auth_chain_difference_using_cover_index_txn(
- self, txn: Cursor, room_id: str, state_sets: List[Set[str]]
+ self, txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]]
) -> Set[str]:
"""Calculates the auth chain difference using the chain index.
@@ -444,9 +458,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# (We need to take a copy of `seen_chains` as we want to mutate it in
# the loop)
- for batch in batch_iter(set(seen_chains), 1000):
+ for batch2 in batch_iter(set(seen_chains), 1000):
clause, args = make_in_list_sql_clause(
- txn.database_engine, "origin_chain_id", batch
+ txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)
@@ -529,7 +543,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return result
def _get_auth_chain_difference_txn(
- self, txn, state_sets: List[Set[str]]
+ self, txn: LoggingTransaction, state_sets: List[Set[str]]
) -> Set[str]:
"""Calculates the auth chain difference using a breadth first search.
@@ -602,7 +616,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# I think building a temporary list with fetchall is more efficient than
# just `search.extend(txn)`, but this is unconfirmed
- search.extend(txn.fetchall())
+ search.extend(cast(List[Tuple[int, str]], txn.fetchall()))
# sort by depth
search.sort()
@@ -645,7 +659,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# We parse the results and add the to the `found` set and the
# cache (note we need to batch up the results by event ID before
# adding to the cache).
- to_cache = {}
+ to_cache: Dict[str, List[Tuple[str, int]]] = {}
for event_id, auth_event_id, auth_event_depth in txn:
to_cache.setdefault(event_id, []).append(
(auth_event_id, auth_event_depth)
@@ -696,7 +710,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return {eid for eid, n in event_to_missing_sets.items() if n}
async def get_oldest_event_ids_with_depth_in_room(
- self, room_id
+ self, room_id: str
) -> List[Tuple[str, int]]:
"""Gets the oldest events(backwards extremities) in the room along with the
aproximate depth.
@@ -713,7 +727,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
List of (event_id, depth) tuples
"""
- def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id):
+ def get_oldest_event_ids_with_depth_in_room_txn(
+ txn: LoggingTransaction, room_id: str
+ ) -> List[Tuple[str, int]]:
# Assemble a dictionary with event_id -> depth for the oldest events
# we know of in the room. Backwards extremeties are the oldest
# events we know of in the room but we only know of them because
@@ -743,7 +759,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
txn.execute(sql, (room_id, False))
- return txn.fetchall()
+ return cast(List[Tuple[str, int]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_oldest_event_ids_with_depth_in_room",
@@ -752,7 +768,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
async def get_insertion_event_backward_extremities_in_room(
- self, room_id
+ self, room_id: str
) -> List[Tuple[str, int]]:
"""Get the insertion events we know about that we haven't backfilled yet.
@@ -768,7 +784,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
List of (event_id, depth) tuples
"""
- def get_insertion_event_backward_extremities_in_room_txn(txn, room_id):
+ def get_insertion_event_backward_extremities_in_room_txn(
+ txn: LoggingTransaction, room_id: str
+ ) -> List[Tuple[str, int]]:
sql = """
SELECT b.event_id, MAX(e.depth) FROM insertion_events as i
/* We only want insertion events that are also marked as backwards extremities */
@@ -780,7 +798,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
txn.execute(sql, (room_id,))
- return txn.fetchall()
+ return cast(List[Tuple[str, int]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_insertion_event_backward_extremities_in_room",
@@ -788,7 +806,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
room_id,
)
- async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
+ async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
"""Returns the event ID and depth for the event that has the max depth from a set of event IDs
Args:
@@ -817,7 +835,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return max_depth_event_id, current_max_depth
- async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
+ async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
"""Returns the event ID and depth for the event that has the min depth from a set of event IDs
Args:
@@ -865,7 +883,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
)
- def _get_prev_events_for_room_txn(self, txn, room_id: str):
+ def _get_prev_events_for_room_txn(
+ self, txn: LoggingTransaction, room_id: str
+ ) -> List[str]:
# we just use the 10 newest events. Older events will become
# prev_events of future events.
@@ -896,7 +916,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
sorted by extremity count.
"""
- def _get_rooms_with_many_extremities_txn(txn):
+ def _get_rooms_with_many_extremities_txn(txn: LoggingTransaction) -> List[str]:
where_clause = "1=1"
if room_id_filter:
where_clause = "room_id NOT IN (%s)" % (
@@ -937,7 +957,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"get_min_depth", self._get_min_depth_interaction, room_id
)
- def _get_min_depth_interaction(self, txn, room_id):
+ def _get_min_depth_interaction(
+ self, txn: LoggingTransaction, room_id: str
+ ) -> Optional[int]:
min_depth = self.db_pool.simple_select_one_onecol_txn(
txn,
table="room_depth",
@@ -966,22 +988,24 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
# We want to make the cache more effective, so we clamp to the last
# change before the given ordering.
- last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)
+ last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) # type: ignore[attr-defined]
# We don't always have a full stream_to_exterm_id table, e.g. after
# the upgrade that introduced it, so we make sure we never ask for a
# stream_ordering from before a restart
- last_change = max(self._stream_order_on_start, last_change)
+ last_change = max(self._stream_order_on_start, last_change) # type: ignore[attr-defined]
# provided the last_change is recent enough, we now clamp the requested
# stream_ordering to it.
- if last_change > self.stream_ordering_month_ago:
+ if last_change > self.stream_ordering_month_ago: # type: ignore[attr-defined]
stream_ordering = min(last_change, stream_ordering)
return await self._get_forward_extremeties_for_room(room_id, stream_ordering)
@cached(max_entries=5000, num_args=2)
- async def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
+ async def _get_forward_extremeties_for_room(
+ self, room_id: str, stream_ordering: int
+ ) -> List[str]:
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".
@@ -989,7 +1013,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
stream_orderings from that point.
"""
- if stream_ordering <= self.stream_ordering_month_ago:
+ if stream_ordering <= self.stream_ordering_month_ago: # type: ignore[attr-defined]
raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
sql = """
@@ -1002,7 +1026,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
WHERE room_id = ?
"""
- def get_forward_extremeties_for_room_txn(txn):
+ def get_forward_extremeties_for_room_txn(txn: LoggingTransaction) -> List[str]:
txn.execute(sql, (stream_ordering, room_id))
return [event_id for event_id, in txn]
@@ -1033,7 +1057,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
INNER JOIN batch_events AS c
ON i.next_batch_id = c.batch_id
/* Get the depth of the batch start event from the events table */
- INNER JOIN events AS e USING (event_id)
+ INNER JOIN events AS e ON c.event_id = e.event_id
/* Find an insertion event which matches the given event_id */
WHERE i.event_id = ?
LIMIT ?
@@ -1104,8 +1128,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
]
async def get_backfill_events(
- self, room_id: str, seed_event_id_list: list, limit: int
- ):
+ self, room_id: str, seed_event_id_list: List[str], limit: int
+ ) -> List[EventBase]:
"""Get a list of Events for a given topic that occurred before (and
including) the events in seed_event_id_list. Return a list of max size `limit`
@@ -1123,10 +1147,19 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
events = await self.get_events_as_list(event_ids)
return sorted(
- events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering)
+ # type-ignore: mypy doesn't like negating the Optional[int] stream_ordering.
+ # But it's never None, because these events were previously persisted to the DB.
+ events,
+ key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering), # type: ignore[operator]
)
- def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit):
+ def _get_backfill_events(
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ seed_event_id_list: List[str],
+ limit: int,
+ ) -> Set[str]:
"""
We want to make sure that we do a breadth-first, "depth" ordered search.
We also handle navigating historical branches of history connected by
@@ -1139,7 +1172,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
limit,
)
- event_id_results = set()
+ event_id_results: Set[str] = set()
# In a PriorityQueue, the lowest valued entries are retrieved first.
# We're using depth as the priority in the queue and tie-break based on
@@ -1147,7 +1180,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# highest and newest-in-time message. We add events to the queue with a
# negative depth so that we process the newest-in-time messages first
# going backwards in time. stream_ordering follows the same pattern.
- queue = PriorityQueue()
+ queue: "PriorityQueue[Tuple[int, int, str, str]]" = PriorityQueue()
for seed_event_id in seed_event_id_list:
event_lookup_result = self.db_pool.simple_select_one_txn(
@@ -1253,7 +1286,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return event_id_results
- async def get_missing_events(self, room_id, earliest_events, latest_events, limit):
+ async def get_missing_events(
+ self,
+ room_id: str,
+ earliest_events: List[str],
+ latest_events: List[str],
+ limit: int,
+ ) -> List[EventBase]:
ids = await self.db_pool.runInteraction(
"get_missing_events",
self._get_missing_events,
@@ -1264,25 +1303,29 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
return await self.get_events_as_list(ids)
- def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
+ def _get_missing_events(
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ earliest_events: List[str],
+ latest_events: List[str],
+ limit: int,
+ ) -> List[str]:
seen_events = set(earliest_events)
front = set(latest_events) - seen_events
- event_results = []
+ event_results: List[str] = []
query = (
"SELECT prev_event_id FROM event_edges "
- "WHERE room_id = ? AND event_id = ? AND is_state = ? "
+ "WHERE event_id = ? AND NOT is_state "
"LIMIT ?"
)
while front and len(event_results) < limit:
new_front = set()
for event_id in front:
- txn.execute(
- query, (room_id, event_id, False, limit - len(event_results))
- )
-
+ txn.execute(query, (event_id, limit - len(event_results)))
new_results = {t[0] for t in txn} - seen_events
new_front |= new_results
@@ -1311,7 +1354,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
@wrap_as_background_process("delete_old_forward_extrem_cache")
async def _delete_old_forward_extrem_cache(self) -> None:
- def _delete_old_forward_extrem_cache_txn(txn):
+ def _delete_old_forward_extrem_cache_txn(txn: LoggingTransaction) -> None:
# Delete entries older than a month, while making sure we don't delete
# the only entries for a room.
sql = """
@@ -1324,7 +1367,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) AND stream_ordering < ?
"""
txn.execute(
- sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
+ sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) # type: ignore[attr-defined]
)
await self.db_pool.runInteraction(
@@ -1382,7 +1425,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
if self.db_pool.engine.supports_returning:
- def _remove_received_event_from_staging_txn(txn):
+ def _remove_received_event_from_staging_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[int]:
sql = """
DELETE FROM federation_inbound_events_staging
WHERE origin = ? AND event_id = ?
@@ -1390,21 +1435,24 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
txn.execute(sql, (origin, event_id))
- return txn.fetchone()
+ row = cast(Optional[Tuple[int]], txn.fetchone())
+
+ if row is None:
+ return None
- row = await self.db_pool.runInteraction(
+ return row[0]
+
+ return await self.db_pool.runInteraction(
"remove_received_event_from_staging",
_remove_received_event_from_staging_txn,
db_autocommit=True,
)
- if row is None:
- return None
-
- return row[0]
else:
- def _remove_received_event_from_staging_txn(txn):
+ def _remove_received_event_from_staging_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[int]:
received_ts = self.db_pool.simple_select_one_onecol_txn(
txn,
table="federation_inbound_events_staging",
@@ -1437,7 +1485,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) -> Optional[Tuple[str, str]]:
"""Get the next event ID in the staging area for the given room."""
- def _get_next_staged_event_id_for_room_txn(txn):
+ def _get_next_staged_event_id_for_room_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[Tuple[str, str]]:
sql = """
SELECT origin, event_id
FROM federation_inbound_events_staging
@@ -1448,7 +1498,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
txn.execute(sql, (room_id,))
- return txn.fetchone()
+ return cast(Optional[Tuple[str, str]], txn.fetchone())
return await self.db_pool.runInteraction(
"get_next_staged_event_id_for_room", _get_next_staged_event_id_for_room_txn
@@ -1461,7 +1511,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) -> Optional[Tuple[str, EventBase]]:
"""Get the next event in the staging area for the given room."""
- def _get_next_staged_event_for_room_txn(txn):
+ def _get_next_staged_event_for_room_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[Tuple[str, str, str]]:
sql = """
SELECT event_json, internal_metadata, origin
FROM federation_inbound_events_staging
@@ -1471,7 +1523,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
txn.execute(sql, (room_id,))
- return txn.fetchone()
+ return cast(Optional[Tuple[str, str, str]], txn.fetchone())
row = await self.db_pool.runInteraction(
"get_next_staged_event_for_room", _get_next_staged_event_for_room_txn
@@ -1599,18 +1651,20 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
@wrap_as_background_process("_get_stats_for_federation_staging")
- async def _get_stats_for_federation_staging(self):
+ async def _get_stats_for_federation_staging(self) -> None:
"""Update the prometheus metrics for the inbound federation staging area."""
- def _get_stats_for_federation_staging_txn(txn):
+ def _get_stats_for_federation_staging_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[int, int]:
txn.execute("SELECT count(*) FROM federation_inbound_events_staging")
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
txn.execute(
"SELECT min(received_ts) FROM federation_inbound_events_staging"
)
- (received_ts,) = txn.fetchone()
+ (received_ts,) = cast(Tuple[Optional[int]], txn.fetchone())
# If there is nothing in the staging area default it to 0.
age = 0
@@ -1651,19 +1705,21 @@ class EventFederationStore(EventFederationWorkerStore):
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
)
- async def clean_room_for_join(self, room_id):
- return await self.db_pool.runInteraction(
+ async def clean_room_for_join(self, room_id: str) -> None:
+ await self.db_pool.runInteraction(
"clean_room_for_join", self._clean_room_for_join_txn, room_id
)
- def _clean_room_for_join_txn(self, txn, room_id):
+ def _clean_room_for_join_txn(self, txn: LoggingTransaction, room_id: str) -> None:
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
- async def _background_delete_non_state_event_auth(self, progress, batch_size):
- def delete_event_auth(txn):
+ async def _background_delete_non_state_event_auth(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ def delete_event_auth(txn: LoggingTransaction) -> bool:
target_min_stream_id = progress.get("target_min_stream_id_inclusive")
max_stream_id = progress.get("max_stream_id_exclusive")
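The backfill walk earlier in this file pushes tuples with negated depth and stream_ordering onto a PriorityQueue so that the deepest, newest events pop first. A tiny self-contained demonstration of why the negation gives that ordering (event IDs invented):

    from queue import PriorityQueue
    from typing import Tuple

    # PriorityQueue pops the smallest tuple first, so negating depth and
    # stream_ordering makes the deepest (newest-in-time) events come out first.
    queue: "PriorityQueue[Tuple[int, int, str]]" = PriorityQueue()
    for depth, stream_ordering, event_id in [
        (3, 7, "$old"),
        (10, 2, "$newer"),
        (10, 9, "$newest"),
    ]:
        queue.put((-depth, -stream_ordering, event_id))

    order = [queue.get()[2] for _ in range(3)]
    assert order == ["$newest", "$newer", "$old"]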
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index b7c4c622..b0199793 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -938,7 +938,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
users can still get a list of recent highlights.
Args:
- txn: The transcation
+ txn: The transaction
room_id: Room ID to delete from
user_id: user ID to delete for
stream_ordering: The lowest stream ordering which will
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 2c86a870..17e35cf6 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -36,9 +36,8 @@ from prometheus_client import Counter
import synapse.metrics
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.room_versions import RoomVersions
-from synapse.crypto.event_signing import compute_event_reference_hash
-from synapse.events import EventBase # noqa: F401
-from synapse.events.snapshot import EventContext # noqa: F401
+from synapse.events import EventBase, relation_from_event
+from synapse.events.snapshot import EventContext
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -50,7 +49,7 @@ from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.engines.postgres import PostgresEngine
from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
-from synapse.types import StateMap, get_domain_from_id
+from synapse.types import JsonDict, StateMap, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter, sorted_topologically
from synapse.util.stringutils import non_null_str_or_none
@@ -130,7 +129,6 @@ class PersistEventsStore:
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
*,
- current_state_for_room: Dict[str, StateMap[str]],
state_delta_for_room: Dict[str, DeltaState],
new_forward_extremities: Dict[str, Set[str]],
use_negative_stream_ordering: bool = False,
@@ -141,8 +139,6 @@ class PersistEventsStore:
Args:
events_and_contexts:
- current_state_for_room: Map from room_id to the current state of
- the room based on forward extremities
state_delta_for_room: Map from room_id to the delta to apply to
room state
new_forward_extremities: Map from room_id to set of event IDs
@@ -217,9 +213,6 @@ class PersistEventsStore:
event_counter.labels(event.type, origin_type, origin_entity).inc()
- for room_id, new_state in current_state_for_room.items():
- self.store.get_current_state_ids.prefill((room_id,), new_state)
-
for room_id, latest_event_ids in new_forward_extremities.items():
self.store.get_latest_event_ids_in_room.prefill(
(room_id,), list(latest_event_ids)
@@ -237,7 +230,9 @@ class PersistEventsStore:
"""
results: List[str] = []
- def _get_events_which_are_prevs_txn(txn, batch):
+ def _get_events_which_are_prevs_txn(
+ txn: LoggingTransaction, batch: Collection[str]
+ ) -> None:
sql = """
SELECT prev_event_id, internal_metadata
FROM event_edges
@@ -287,7 +282,9 @@ class PersistEventsStore:
# and their prev events.
existing_prevs = set()
- def _get_prevs_before_rejected_txn(txn, batch):
+ def _get_prevs_before_rejected_txn(
+ txn: LoggingTransaction, batch: Collection[str]
+ ) -> None:
to_recursively_check = batch
while to_recursively_check:
@@ -517,7 +514,7 @@ class PersistEventsStore:
@classmethod
def _add_chain_cover_index(
cls,
- txn,
+ txn: LoggingTransaction,
db_pool: DatabasePool,
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
@@ -811,7 +808,7 @@ class PersistEventsStore:
@staticmethod
def _allocate_chain_ids(
- txn,
+ txn: LoggingTransaction,
db_pool: DatabasePool,
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
@@ -945,7 +942,7 @@ class PersistEventsStore:
self,
txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]],
- ):
+ ) -> None:
"""Persist the mapping from transaction IDs to event IDs (if defined)."""
to_insert = []
@@ -999,7 +996,7 @@ class PersistEventsStore:
txn: LoggingTransaction,
state_delta_by_room: Dict[str, DeltaState],
stream_id: int,
- ):
+ ) -> None:
for room_id, delta_state in state_delta_by_room.items():
to_delete = delta_state.to_delete
to_insert = delta_state.to_insert
@@ -1157,7 +1154,7 @@ class PersistEventsStore:
txn, room_id, members_changed
)
- def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str):
+ def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None:
"""Update the room version in the database based off current state
events.
@@ -1191,7 +1188,7 @@ class PersistEventsStore:
txn: LoggingTransaction,
new_forward_extremities: Dict[str, Set[str]],
max_stream_order: int,
- ):
+ ) -> None:
for room_id in new_forward_extremities.keys():
self.db_pool.simple_delete_txn(
txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
@@ -1256,9 +1253,9 @@ class PersistEventsStore:
def _update_room_depths_txn(
self,
- txn,
+ txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]],
- ):
+ ) -> None:
"""Update min_depth for each room
Args:
@@ -1387,7 +1384,7 @@ class PersistEventsStore:
# nothing to do here
return
- def event_dict(event):
+ def event_dict(event: EventBase) -> JsonDict:
d = event.get_dict()
d.pop("redacted", None)
d.pop("redacted_because", None)
@@ -1478,18 +1475,20 @@ class PersistEventsStore:
),
)
- def _store_rejected_events_txn(self, txn, events_and_contexts):
+ def _store_rejected_events_txn(
+ self,
+ txn: LoggingTransaction,
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
+ ) -> List[Tuple[EventBase, EventContext]]:
"""Add rows to the 'rejections' table for received events which were
rejected
Args:
- txn (twisted.enterprise.adbapi.Connection): db connection
- events_and_contexts (list[(EventBase, EventContext)]): events
- we are persisting
+ txn: db connection
+ events_and_contexts: events we are persisting
Returns:
- list[(EventBase, EventContext)] new list, without the rejected
- events.
+ new list, without the rejected events.
"""
# Remove the rejected events from the list now that we've added them
# to the events table and the events_json table.
@@ -1510,7 +1509,7 @@ class PersistEventsStore:
events_and_contexts: List[Tuple[EventBase, EventContext]],
all_events_and_contexts: List[Tuple[EventBase, EventContext]],
inhibit_local_membership_updates: bool = False,
- ):
+ ) -> None:
"""Update all the miscellaneous tables for new events
Args:
@@ -1601,15 +1600,14 @@ class PersistEventsStore:
inhibit_local_membership_updates=inhibit_local_membership_updates,
)
- # Insert event_reference_hashes table.
- self._store_event_reference_hashes_txn(
- txn, [event for event, _ in events_and_contexts]
- )
-
# Prefill the event cache
self._add_to_cache(txn, events_and_contexts)
- def _add_to_cache(self, txn, events_and_contexts):
+ def _add_to_cache(
+ self,
+ txn: LoggingTransaction,
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
+ ) -> None:
to_prefill = []
rows = []
@@ -1640,7 +1638,7 @@ class PersistEventsStore:
if not row["rejects"] and not row["redacts"]:
to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
- def prefill():
+ def prefill() -> None:
for cache_entry in to_prefill:
self.store._get_event_cache.set(
(cache_entry.event.event_id,), cache_entry
@@ -1670,19 +1668,24 @@ class PersistEventsStore:
)
def insert_labels_for_event_txn(
- self, txn, event_id, labels, room_id, topological_ordering
- ):
+ self,
+ txn: LoggingTransaction,
+ event_id: str,
+ labels: List[str],
+ room_id: str,
+ topological_ordering: int,
+ ) -> None:
"""Store the mapping between an event's ID and its labels, with one row per
(event_id, label) tuple.
Args:
- txn (LoggingTransaction): The transaction to execute.
- event_id (str): The event's ID.
- labels (list[str]): A list of text labels.
- room_id (str): The ID of the room the event was sent to.
- topological_ordering (int): The position of the event in the room's topology.
+ txn: The transaction to execute.
+ event_id: The event's ID.
+ labels: A list of text labels.
+ room_id: The ID of the room the event was sent to.
+ topological_ordering: The position of the event in the room's topology.
"""
- return self.db_pool.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn=txn,
table="event_labels",
keys=("event_id", "label", "room_id", "topological_ordering"),
@@ -1691,44 +1694,32 @@ class PersistEventsStore:
],
)
- def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
+ def _insert_event_expiry_txn(
+ self, txn: LoggingTransaction, event_id: str, expiry_ts: int
+ ) -> None:
"""Save the expiry timestamp associated with a given event ID.
Args:
- txn (LoggingTransaction): The database transaction to use.
- event_id (str): The event ID the expiry timestamp is associated with.
- expiry_ts (int): The timestamp at which to expire (delete) the event.
+ txn: The database transaction to use.
+ event_id: The event ID the expiry timestamp is associated with.
+ expiry_ts: The timestamp at which to expire (delete) the event.
"""
- return self.db_pool.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn=txn,
table="event_expiry",
values={"event_id": event_id, "expiry_ts": expiry_ts},
)
- def _store_event_reference_hashes_txn(self, txn, events):
- """Store a hash for a PDU
- Args:
- txn (cursor):
- events (list): list of Events.
- """
-
- vals = []
- for event in events:
- ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
- vals.append((event.event_id, ref_alg, memoryview(ref_hash_bytes)))
-
- self.db_pool.simple_insert_many_txn(
- txn,
- table="event_reference_hashes",
- keys=("event_id", "algorithm", "hash"),
- values=vals,
- )
-
def _store_room_members_txn(
- self, txn, events, *, inhibit_local_membership_updates: bool = False
- ):
+ self,
+ txn: LoggingTransaction,
+ events: List[EventBase],
+ *,
+ inhibit_local_membership_updates: bool = False,
+ ) -> None:
"""
Store a room member in the database.
+
Args:
txn: The transaction to use.
events: List of events to store.
@@ -1765,6 +1756,7 @@ class PersistEventsStore:
)
for event in events:
+ assert event.internal_metadata.stream_ordering is not None
txn.call_after(
self.store._membership_stream_cache.entity_has_changed,
event.state_key,
@@ -1813,55 +1805,54 @@ class PersistEventsStore:
txn: The current database transaction.
event: The event which might have relations.
"""
- relation = event.content.get("m.relates_to")
+ relation = relation_from_event(event)
if not relation:
- # No relations
- return
-
- # Relations must have a type and parent event ID.
- rel_type = relation.get("rel_type")
- if not isinstance(rel_type, str):
+ # No relation, nothing to do.
return
- parent_id = relation.get("event_id")
- if not isinstance(parent_id, str):
- return
-
- # Annotations have a key field.
- aggregation_key = None
- if rel_type == RelationTypes.ANNOTATION:
- aggregation_key = relation.get("key")
-
self.db_pool.simple_insert_txn(
txn,
table="event_relations",
values={
"event_id": event.event_id,
- "relates_to_id": parent_id,
- "relation_type": rel_type,
- "aggregation_key": aggregation_key,
+ "relates_to_id": relation.parent_id,
+ "relation_type": relation.rel_type,
+ "aggregation_key": relation.aggregation_key,
},
)
- txn.call_after(self.store.get_relations_for_event.invalidate, (parent_id,))
txn.call_after(
- self.store.get_aggregation_groups_for_event.invalidate, (parent_id,)
+ self.store.get_relations_for_event.invalidate, (relation.parent_id,)
+ )
+ txn.call_after(
+ self.store.get_aggregation_groups_for_event.invalidate,
+ (relation.parent_id,),
+ )
+ txn.call_after(
+ self.store.get_mutual_event_relations_for_rel_type.invalidate,
+ (relation.parent_id,),
)
- if rel_type == RelationTypes.REPLACE:
- txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
+ if relation.rel_type == RelationTypes.REPLACE:
+ txn.call_after(
+ self.store.get_applicable_edit.invalidate, (relation.parent_id,)
+ )
- if rel_type == RelationTypes.THREAD:
- txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
+ if relation.rel_type == RelationTypes.THREAD:
+ txn.call_after(
+ self.store.get_thread_summary.invalidate, (relation.parent_id,)
+ )
# It should be safe to only invalidate the cache if the user has not
# previously participated in the thread, but that's difficult (and
# potentially error-prone) so it is always invalidated.
txn.call_after(
self.store.get_thread_participated.invalidate,
- (parent_id, event.sender),
+ (relation.parent_id, event.sender),
)
- def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
+ def _handle_insertion_event(
+ self, txn: LoggingTransaction, event: EventBase
+ ) -> None:
"""Handles keeping track of insertion events and edges/connections.
Part of MSC2716.
@@ -1922,7 +1913,7 @@ class PersistEventsStore:
},
)
- def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase):
+ def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase) -> None:
"""Handles inserting the batch edges/connections between the batch event
and an insertion event. Part of MSC2716.
@@ -2017,30 +2008,39 @@ class PersistEventsStore:
self.store._invalidate_cache_and_stream(
txn, self.store.get_thread_participated, (redacted_relates_to,)
)
+ self.store._invalidate_cache_and_stream(
+ txn,
+ self.store.get_mutual_event_relations_for_rel_type,
+ (redacted_relates_to,),
+ )
self.db_pool.simple_delete_txn(
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
)
- def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase):
+ def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
if isinstance(event.content.get("topic"), str):
self.store_event_search_txn(
txn, event, "content.topic", event.content["topic"]
)
- def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase):
+ def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
if isinstance(event.content.get("name"), str):
self.store_event_search_txn(
txn, event, "content.name", event.content["name"]
)
- def _store_room_message_txn(self, txn: LoggingTransaction, event: EventBase):
+ def _store_room_message_txn(
+ self, txn: LoggingTransaction, event: EventBase
+ ) -> None:
if isinstance(event.content.get("body"), str):
self.store_event_search_txn(
txn, event, "content.body", event.content["body"]
)
- def _store_retention_policy_for_room_txn(self, txn, event):
+ def _store_retention_policy_for_room_txn(
+ self, txn: LoggingTransaction, event: EventBase
+ ) -> None:
if not event.is_state():
logger.debug("Ignoring non-state m.room.retention event")
return
@@ -2100,8 +2100,11 @@ class PersistEventsStore:
)
def _set_push_actions_for_event_and_users_txn(
- self, txn, events_and_contexts, all_events_and_contexts
- ):
+ self,
+ txn: LoggingTransaction,
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
+ all_events_and_contexts: List[Tuple[EventBase, EventContext]],
+ ) -> None:
"""Handles moving push actions from staging table to main
event_push_actions table for all events in `events_and_contexts`.
@@ -2109,12 +2112,10 @@ class PersistEventsStore:
from the push action staging area.
Args:
- events_and_contexts (list[(EventBase, EventContext)]): events
- we are persisting
- all_events_and_contexts (list[(EventBase, EventContext)]): all
- events that we were going to persist. This includes events
- we've already persisted, etc, that wouldn't appear in
- events_and_context.
+ events_and_contexts: events we are persisting
+ all_events_and_contexts: all events that we were going to persist.
+ This includes events we've already persisted, etc, that wouldn't
+ appear in events_and_context.
"""
# Only non outlier events will have push actions associated with them,
@@ -2183,7 +2184,9 @@ class PersistEventsStore:
),
)
- def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
+ def _remove_push_actions_for_event_id_txn(
+ self, txn: LoggingTransaction, room_id: str, event_id: str
+ ) -> None:
# Sad that we have to blow away the cache for the whole room here
txn.call_after(
self.store.get_unread_event_push_actions_by_room_for_user.invalidate,
@@ -2194,7 +2197,9 @@ class PersistEventsStore:
(room_id, event_id),
)
- def _store_rejections_txn(self, txn, event_id, reason):
+ def _store_rejections_txn(
+ self, txn: LoggingTransaction, event_id: str, reason: str
+ ) -> None:
self.db_pool.simple_insert_txn(
txn,
table="rejections",
@@ -2206,8 +2211,10 @@ class PersistEventsStore:
)
def _store_event_state_mappings_txn(
- self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
- ):
+ self,
+ txn: LoggingTransaction,
+ events_and_contexts: Collection[Tuple[EventBase, EventContext]],
+ ) -> None:
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
@@ -2264,7 +2271,9 @@ class PersistEventsStore:
state_group_id,
)
- def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+ def _update_min_depth_for_room_txn(
+ self, txn: LoggingTransaction, room_id: str, depth: int
+ ) -> None:
min_depth = self.store._get_min_depth_interaction(txn, room_id)
if min_depth is not None and depth >= min_depth:
@@ -2277,7 +2286,9 @@ class PersistEventsStore:
values={"min_depth": depth},
)
- def _handle_mult_prev_events(self, txn, events):
+ def _handle_mult_prev_events(
+ self, txn: LoggingTransaction, events: List[EventBase]
+ ) -> None:
"""
For the given event, update the event edges table and forward and
backward extremities tables.
@@ -2295,7 +2306,9 @@ class PersistEventsStore:
self._update_backward_extremeties(txn, events)
- def _update_backward_extremeties(self, txn, events):
+ def _update_backward_extremeties(
+ self, txn: LoggingTransaction, events: List[EventBase]
+ ) -> None:
"""Updates the event_backward_extremities tables based on the new/updated
events being persisted.
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index a4a604a4..b99b1077 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -14,6 +14,7 @@
import logging
import threading
+import weakref
from enum import Enum, auto
from typing import (
TYPE_CHECKING,
@@ -23,6 +24,7 @@ from typing import (
Dict,
Iterable,
List,
+ MutableMapping,
Optional,
Set,
Tuple,
@@ -248,6 +250,12 @@ class EventsWorkerStore(SQLBaseStore):
str, ObservableDeferred[Dict[str, EventCacheEntry]]
] = {}
+ # We keep track of the events we have currently loaded in memory so that
+ # we can reuse them even if they've been evicted from the cache. We only
+ # track events that don't need redacting in here (as then we don't need
+ # to track redaction status).
+ self._event_ref: MutableMapping[str, EventBase] = weakref.WeakValueDictionary()
+
self._event_fetch_lock = threading.Condition()
self._event_fetch_list: List[
Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
@@ -723,6 +731,8 @@ class EventsWorkerStore(SQLBaseStore):
def _invalidate_get_event_cache(self, event_id: str) -> None:
self._get_event_cache.invalidate((event_id,))
+ self._event_ref.pop(event_id, None)
+ self._current_event_fetches.pop(event_id, None)
def _get_events_from_cache(
self, events: Iterable[str], update_metrics: bool = True
@@ -738,13 +748,30 @@ class EventsWorkerStore(SQLBaseStore):
event_map = {}
for event_id in events:
+ # First check if it's in the event cache
ret = self._get_event_cache.get(
(event_id,), None, update_metrics=update_metrics
)
- if not ret:
+ if ret:
+ event_map[event_id] = ret
continue
- event_map[event_id] = ret
+ # Otherwise check if we still have the event in memory.
+ event = self._event_ref.get(event_id)
+ if event:
+ # Reconstruct an event cache entry
+
+ cache_entry = EventCacheEntry(
+ event=event,
+ # We don't cache weakrefs to redacted events, so we know
+ # this is None.
+ redacted_event=None,
+ )
+ event_map[event_id] = cache_entry
+
+ # We add the entry back into the cache as we want to keep
+ # recently queried events in the cache.
+ self._get_event_cache.set((event_id,), cache_entry)
return event_map
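
The _event_ref map introduced above is a weakref.WeakValueDictionary, so it costs nothing to keep populated: an entry survives only as long as something else (the LRU cache, an in-flight request) still holds a strong reference to the event. A minimal standalone sketch of that behaviour, assuming CPython's reference counting:

import weakref


class Event:
    def __init__(self, event_id: str) -> None:
        self.event_id = event_id


event_ref: "weakref.WeakValueDictionary[str, Event]" = weakref.WeakValueDictionary()

ev = Event("$example")
event_ref[ev.event_id] = ev
assert event_ref.get("$example") is ev  # retrievable while a strong ref exists

del ev  # drop the last strong reference
# On CPython the entry vanishes immediately; other interpreters may wait for GC.
assert event_ref.get("$example") is None
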
@@ -1124,6 +1151,10 @@ class EventsWorkerStore(SQLBaseStore):
self._get_event_cache.set((event_id,), cache_entry)
result_map[event_id] = cache_entry
+ if not redacted_event:
+ # We only cache references to unredacted events.
+ self._event_ref[event_id] = original_ev
+
return result_map
async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
@@ -1325,14 +1356,23 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
The set of events we have already seen.
"""
- res = await self._have_seen_events_dict(
- (room_id, event_id) for event_id in event_ids
- )
- return {eid for ((_rid, eid), have_event) in res.items() if have_event}
+
+ # @cachedList chomps lots of memory if you call it with a big list, so
+ # we break it down. However, each batch requires its own index scan, so we make
+ # the batches as big as possible.
+
+ results: Set[str] = set()
+ for chunk in batch_iter(event_ids, 500):
+ r = await self._have_seen_events_dict(
+ [(room_id, event_id) for event_id in chunk]
+ )
+ results.update(eid for ((_rid, eid), have_event) in r.items() if have_event)
+
+ return results
@cachedList(cached_method_name="have_seen_event", list_name="keys")
async def _have_seen_events_dict(
- self, keys: Iterable[Tuple[str, str]]
+ self, keys: Collection[Tuple[str, str]]
) -> Dict[Tuple[str, str], bool]:
"""Helper for have_seen_events
@@ -1344,11 +1384,12 @@ class EventsWorkerStore(SQLBaseStore):
cache_results = {
(rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,))
}
- results = {x: True for x in cache_results}
+ results = dict.fromkeys(cache_results, True)
+ remaining = [k for k in keys if k not in cache_results]
+ if not remaining:
+ return results
- def have_seen_events_txn(
- txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...]
- ) -> None:
+ def have_seen_events_txn(txn: LoggingTransaction) -> None:
# we deliberately do *not* query the database for room_id, to make the
# query an index-only lookup on `events_event_id_key`.
#
@@ -1356,21 +1397,17 @@ class EventsWorkerStore(SQLBaseStore):
sql = "SELECT event_id FROM events AS e WHERE "
clause, args = make_in_list_sql_clause(
- txn.database_engine, "e.event_id", [eid for (_rid, eid) in chunk]
+ txn.database_engine, "e.event_id", [eid for (_rid, eid) in remaining]
)
txn.execute(sql + clause, args)
found_events = {eid for eid, in txn}
- # ... and then we can update the results for each row in the batch
- results.update({(rid, eid): (eid in found_events) for (rid, eid) in chunk})
-
- # each batch requires its own index scan, so we make the batches as big as
- # possible.
- for chunk in batch_iter((k for k in keys if k not in cache_results), 500):
- await self.db_pool.runInteraction(
- "have_seen_events", have_seen_events_txn, chunk
+ # ... and then we can update the results for each key
+ results.update(
+ {(rid, eid): (eid in found_events) for (rid, eid) in remaining}
)
+ await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn)
return results
@cached(max_entries=100000, tree=True)
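
The comment above captures the trade-off behind the new batch_iter(event_ids, 500) call: @cachedList gets expensive when handed a huge key list, but every chunk pays for its own index scan, so chunks are kept as large as practical. A stand-in chunking helper (for illustration only, not Synapse's batch_iter):

from typing import Iterable, Iterator, List, Tuple


def chunked(items: Iterable[str], size: int) -> Iterator[Tuple[str, ...]]:
    """Yield successive tuples of at most `size` items."""
    batch: List[str] = []
    for item in items:
        batch.append(item)
        if len(batch) == size:
            yield tuple(batch)
            batch = []
    if batch:
        yield tuple(batch)


# 1200 event IDs become three DB round trips of 500, 500 and 200 keys.
sizes = [len(c) for c in chunked((f"$ev{i}" for i in range(1200)), 500)]
assert sizes == [500, 500, 200]
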
@@ -1891,6 +1928,18 @@ class EventsWorkerStore(SQLBaseStore):
LIMIT 1
"""
+ # We consider any forward extremity as the latest in the room and
+ # not a forward gap.
+ #
+ # To expand, even though there is technically a gap at the front of
+ # the room where the forward extremities are, we consider those the
+ # latest messages in the room so asking other homeservers for more
+ # is useless. The new latest messages will just be federated as
+ # usual.
+ txn.execute(forward_extremity_query, (event.room_id, event.event_id))
+ if txn.fetchone():
+ return False
+
# Check to see whether the event in question is already referenced
# by another event. If we don't see any edges, we're next to a
# forward gap.
@@ -1899,8 +1948,7 @@ class EventsWorkerStore(SQLBaseStore):
/* Check to make sure the event referencing our event in question is not rejected */
LEFT JOIN rejections ON event_edges.event_id = rejections.event_id
WHERE
- event_edges.room_id = ?
- AND event_edges.prev_event_id = ?
+ event_edges.prev_event_id = ?
/* It's not a valid edge if the event referencing our event in
* question is rejected.
*/
@@ -1908,25 +1956,11 @@ class EventsWorkerStore(SQLBaseStore):
LIMIT 1
"""
- # We consider any forward extremity as the latest in the room and
- # not a forward gap.
- #
- # To expand, even though there is technically a gap at the front of
- # the room where the forward extremities are, we consider those the
- # latest messages in the room so asking other homeservers for more
- # is useless. The new latest messages will just be federated as
- # usual.
- txn.execute(forward_extremity_query, (event.room_id, event.event_id))
- forward_extremities = txn.fetchall()
- if len(forward_extremities):
- return False
-
# If there are no forward edges to the event in question (another
# event hasn't referenced this event in their prev_events), then we
# assume there is a forward gap in the history.
- txn.execute(forward_edge_query, (event.room_id, event.event_id))
- forward_edges = txn.fetchall()
- if not len(forward_edges):
+ txn.execute(forward_edge_query, (event.event_id,))
+ if not txn.fetchone():
return True
return False
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 04efad9e..c15a7136 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -13,1417 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING
-from typing_extensions import TypedDict
-
-from synapse.api.errors import SynapseError
-from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import (
- DatabasePool,
- LoggingDatabaseConnection,
- LoggingTransaction,
-)
-from synapse.types import JsonDict
-from synapse.util import json_encoder
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
if TYPE_CHECKING:
from synapse.server import HomeServer
-# The category ID for the "default" category. We don't store as null in the
-# database to avoid the fun of null != null
-_DEFAULT_CATEGORY_ID = ""
-_DEFAULT_ROLE_ID = ""
-
-
-# A room in a group.
-class _RoomInGroup(TypedDict):
- room_id: str
- is_public: bool
-
-class GroupServerWorkerStore(SQLBaseStore):
+class GroupServerStore(SQLBaseStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
- database.updates.register_background_index_update(
- update_name="local_group_updates_index",
- index_name="local_group_updates_stream_id_index",
- table="local_group_updates",
- columns=("stream_id",),
- unique=True,
- )
+ # Register a legacy groups background update as a no-op.
+ database.updates.register_noop_background_update("local_group_updates_index")
super().__init__(database, db_conn, hs)
-
- async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
- return await self.db_pool.simple_select_one(
- table="groups",
- keyvalues={"group_id": group_id},
- retcols=(
- "name",
- "short_description",
- "long_description",
- "avatar_url",
- "is_public",
- "join_policy",
- ),
- allow_none=True,
- desc="get_group",
- )
-
- async def get_users_in_group(
- self, group_id: str, include_private: bool = False
- ) -> List[Dict[str, Any]]:
- # TODO: Pagination
-
- keyvalues: JsonDict = {"group_id": group_id}
- if not include_private:
- keyvalues["is_public"] = True
-
- return await self.db_pool.simple_select_list(
- table="group_users",
- keyvalues=keyvalues,
- retcols=("user_id", "is_public", "is_admin"),
- desc="get_users_in_group",
- )
-
- async def get_invited_users_in_group(self, group_id: str) -> List[str]:
- # TODO: Pagination
-
- return await self.db_pool.simple_select_onecol(
- table="group_invites",
- keyvalues={"group_id": group_id},
- retcol="user_id",
- desc="get_invited_users_in_group",
- )
-
- async def get_rooms_in_group(
- self, group_id: str, include_private: bool = False
- ) -> List[_RoomInGroup]:
- """Retrieve the rooms that belong to a given group. Does not return rooms that
- lack members.
-
- Args:
- group_id: The ID of the group to query for rooms
- include_private: Whether to return private rooms in results
-
- Returns:
- A list of dictionaries, each in the form of:
-
- {
- "room_id": "!a_room_id:example.com", # The ID of the room
- "is_public": False # Whether this is a public room or not
- }
- """
-
- # TODO: Pagination
-
- def _get_rooms_in_group_txn(txn: LoggingTransaction) -> List[_RoomInGroup]:
- sql = """
- SELECT room_id, is_public FROM group_rooms
- WHERE group_id = ?
- AND room_id IN (
- SELECT group_rooms.room_id FROM group_rooms
- LEFT JOIN room_stats_current ON
- group_rooms.room_id = room_stats_current.room_id
- AND joined_members > 0
- AND local_users_in_room > 0
- LEFT JOIN rooms ON
- group_rooms.room_id = rooms.room_id
- AND (room_version <> '') = ?
- )
- """
- args = [group_id, False]
-
- if not include_private:
- sql += " AND is_public = ?"
- args += [True]
-
- txn.execute(sql, args)
-
- return [
- {"room_id": room_id, "is_public": is_public}
- for room_id, is_public in txn
- ]
-
- return await self.db_pool.runInteraction(
- "get_rooms_in_group", _get_rooms_in_group_txn
- )
-
- async def get_rooms_for_summary_by_category(
- self,
- group_id: str,
- include_private: bool = False,
- ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
- """Get the rooms and categories that should be included in a summary request
-
- Args:
- group_id: The ID of the group to query the summary for
- include_private: Whether to return private rooms in results
-
- Returns:
- A tuple containing:
-
- * A list of dictionaries with the keys:
- * "room_id": str, the room ID
- * "is_public": bool, whether the room is public
- * "category_id": str|None, the category ID if set, else None
- * "order": int, the sort order of rooms
-
- * A dictionary with the key:
- * category_id (str): a dictionary with the keys:
- * "is_public": bool, whether the category is public
- * "profile": str, the category profile
- * "order": int, the sort order of rooms in this category
- """
-
- def _get_rooms_for_summary_txn(
- txn: LoggingTransaction,
- ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
- keyvalues: JsonDict = {"group_id": group_id}
- if not include_private:
- keyvalues["is_public"] = True
-
- sql = """
- SELECT room_id, is_public, category_id, room_order
- FROM group_summary_rooms
- WHERE group_id = ?
- AND room_id IN (
- SELECT group_rooms.room_id FROM group_rooms
- LEFT JOIN room_stats_current ON
- group_rooms.room_id = room_stats_current.room_id
- AND joined_members > 0
- AND local_users_in_room > 0
- LEFT JOIN rooms ON
- group_rooms.room_id = rooms.room_id
- AND (room_version <> '') = ?
- )
- """
-
- if not include_private:
- sql += " AND is_public = ?"
- txn.execute(sql, (group_id, False, True))
- else:
- txn.execute(sql, (group_id, False))
-
- rooms = [
- {
- "room_id": row[0],
- "is_public": row[1],
- "category_id": row[2] if row[2] != _DEFAULT_CATEGORY_ID else None,
- "order": row[3],
- }
- for row in txn
- ]
-
- sql = """
- SELECT category_id, is_public, profile, cat_order
- FROM group_summary_room_categories
- INNER JOIN group_room_categories USING (group_id, category_id)
- WHERE group_id = ?
- """
-
- if not include_private:
- sql += " AND is_public = ?"
- txn.execute(sql, (group_id, True))
- else:
- txn.execute(sql, (group_id,))
-
- categories = {
- row[0]: {
- "is_public": row[1],
- "profile": db_to_json(row[2]),
- "order": row[3],
- }
- for row in txn
- }
-
- return rooms, categories
-
- return await self.db_pool.runInteraction(
- "get_rooms_for_summary", _get_rooms_for_summary_txn
- )
-
- async def get_group_categories(self, group_id: str) -> JsonDict:
- rows = await self.db_pool.simple_select_list(
- table="group_room_categories",
- keyvalues={"group_id": group_id},
- retcols=("category_id", "is_public", "profile"),
- desc="get_group_categories",
- )
-
- return {
- row["category_id"]: {
- "is_public": row["is_public"],
- "profile": db_to_json(row["profile"]),
- }
- for row in rows
- }
-
- async def get_group_category(self, group_id: str, category_id: str) -> JsonDict:
- category = await self.db_pool.simple_select_one(
- table="group_room_categories",
- keyvalues={"group_id": group_id, "category_id": category_id},
- retcols=("is_public", "profile"),
- desc="get_group_category",
- )
-
- category["profile"] = db_to_json(category["profile"])
-
- return category
-
- async def get_group_roles(self, group_id: str) -> JsonDict:
- rows = await self.db_pool.simple_select_list(
- table="group_roles",
- keyvalues={"group_id": group_id},
- retcols=("role_id", "is_public", "profile"),
- desc="get_group_roles",
- )
-
- return {
- row["role_id"]: {
- "is_public": row["is_public"],
- "profile": db_to_json(row["profile"]),
- }
- for row in rows
- }
-
- async def get_group_role(self, group_id: str, role_id: str) -> JsonDict:
- role = await self.db_pool.simple_select_one(
- table="group_roles",
- keyvalues={"group_id": group_id, "role_id": role_id},
- retcols=("is_public", "profile"),
- desc="get_group_role",
- )
-
- role["profile"] = db_to_json(role["profile"])
-
- return role
-
- async def get_local_groups_for_room(self, room_id: str) -> List[str]:
- """Get all of the local group that contain a given room
- Args:
- room_id: The ID of a room
- Returns:
- A list of group ids containing this room
- """
- return await self.db_pool.simple_select_onecol(
- table="group_rooms",
- keyvalues={"room_id": room_id},
- retcol="group_id",
- desc="get_local_groups_for_room",
- )
-
- async def get_users_for_summary_by_role(
- self, group_id: str, include_private: bool = False
- ) -> Tuple[List[JsonDict], JsonDict]:
- """Get the users and roles that should be included in a summary request
-
- Returns:
- ([users], [roles])
- """
-
- def _get_users_for_summary_txn(
- txn: LoggingTransaction,
- ) -> Tuple[List[JsonDict], JsonDict]:
- keyvalues: JsonDict = {"group_id": group_id}
- if not include_private:
- keyvalues["is_public"] = True
-
- sql = """
- SELECT user_id, is_public, role_id, user_order
- FROM group_summary_users
- WHERE group_id = ?
- """
-
- if not include_private:
- sql += " AND is_public = ?"
- txn.execute(sql, (group_id, True))
- else:
- txn.execute(sql, (group_id,))
-
- users = [
- {
- "user_id": row[0],
- "is_public": row[1],
- "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None,
- "order": row[3],
- }
- for row in txn
- ]
-
- sql = """
- SELECT role_id, is_public, profile, role_order
- FROM group_summary_roles
- INNER JOIN group_roles USING (group_id, role_id)
- WHERE group_id = ?
- """
-
- if not include_private:
- sql += " AND is_public = ?"
- txn.execute(sql, (group_id, True))
- else:
- txn.execute(sql, (group_id,))
-
- roles = {
- row[0]: {
- "is_public": row[1],
- "profile": db_to_json(row[2]),
- "order": row[3],
- }
- for row in txn
- }
-
- return users, roles
-
- return await self.db_pool.runInteraction(
- "get_users_for_summary_by_role", _get_users_for_summary_txn
- )
-
- async def is_user_in_group(self, user_id: str, group_id: str) -> bool:
- result = await self.db_pool.simple_select_one_onecol(
- table="group_users",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcol="user_id",
- allow_none=True,
- desc="is_user_in_group",
- )
- return bool(result)
-
- async def is_user_admin_in_group(
- self, group_id: str, user_id: str
- ) -> Optional[bool]:
- return await self.db_pool.simple_select_one_onecol(
- table="group_users",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcol="is_admin",
- allow_none=True,
- desc="is_user_admin_in_group",
- )
-
- async def is_user_invited_to_local_group(
- self, group_id: str, user_id: str
- ) -> Optional[bool]:
- """Has the group server invited a user?"""
- return await self.db_pool.simple_select_one_onecol(
- table="group_invites",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcol="user_id",
- desc="is_user_invited_to_local_group",
- allow_none=True,
- )
-
- async def get_users_membership_info_in_group(
- self, group_id: str, user_id: str
- ) -> JsonDict:
- """Get a dict describing the membership of a user in a group.
-
- Example if joined:
-
- {
- "membership": "join",
- "is_public": True,
- "is_privileged": False,
- }
-
- Returns:
- An empty dict if the user is not join/invite/etc
- """
-
- def _get_users_membership_in_group_txn(txn: LoggingTransaction) -> JsonDict:
- row = self.db_pool.simple_select_one_txn(
- txn,
- table="group_users",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcols=("is_admin", "is_public"),
- allow_none=True,
- )
-
- if row:
- return {
- "membership": "join",
- "is_public": row["is_public"],
- "is_privileged": row["is_admin"],
- }
-
- row = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="group_invites",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcol="user_id",
- allow_none=True,
- )
-
- if row:
- return {"membership": "invite"}
-
- return {}
-
- return await self.db_pool.runInteraction(
- "get_users_membership_info_in_group", _get_users_membership_in_group_txn
- )
-
- async def get_publicised_groups_for_user(self, user_id: str) -> List[str]:
- """Get all groups a user is publicising"""
- return await self.db_pool.simple_select_onecol(
- table="local_group_membership",
- keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
- retcol="group_id",
- desc="get_publicised_groups_for_user",
- )
-
- async def get_attestations_need_renewals(
- self, valid_until_ms: int
- ) -> List[Dict[str, Any]]:
- """Get all attestations that need to be renewed until givent time"""
-
- def _get_attestations_need_renewals_txn(
- txn: LoggingTransaction,
- ) -> List[Dict[str, Any]]:
- sql = """
- SELECT group_id, user_id FROM group_attestations_renewals
- WHERE valid_until_ms <= ?
- """
- txn.execute(sql, (valid_until_ms,))
- return self.db_pool.cursor_to_dict(txn)
-
- return await self.db_pool.runInteraction(
- "get_attestations_need_renewals", _get_attestations_need_renewals_txn
- )
-
- async def get_remote_attestation(
- self, group_id: str, user_id: str
- ) -> Optional[JsonDict]:
- """Get the attestation that proves the remote agrees that the user is
- in the group.
- """
- row = await self.db_pool.simple_select_one(
- table="group_attestations_remote",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcols=("valid_until_ms", "attestation_json"),
- desc="get_remote_attestation",
- allow_none=True,
- )
-
- now = int(self._clock.time_msec())
- if row and now < row["valid_until_ms"]:
- return db_to_json(row["attestation_json"])
-
- return None
-
- async def get_joined_groups(self, user_id: str) -> List[str]:
- return await self.db_pool.simple_select_onecol(
- table="local_group_membership",
- keyvalues={"user_id": user_id, "membership": "join"},
- retcol="group_id",
- desc="get_joined_groups",
- )
-
- async def get_all_groups_for_user(
- self, user_id: str, now_token: int
- ) -> List[JsonDict]:
- def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
- sql = """
- SELECT group_id, type, membership, u.content
- FROM local_group_updates AS u
- INNER JOIN local_group_membership USING (group_id, user_id)
- WHERE user_id = ? AND membership != 'leave'
- AND stream_id <= ?
- """
- txn.execute(sql, (user_id, now_token))
- return [
- {
- "group_id": row[0],
- "type": row[1],
- "membership": row[2],
- "content": db_to_json(row[3]),
- }
- for row in txn
- ]
-
- return await self.db_pool.runInteraction(
- "get_all_groups_for_user", _get_all_groups_for_user_txn
- )
-
- async def get_groups_changes_for_user(
- self, user_id: str, from_token: int, to_token: int
- ) -> List[JsonDict]:
- has_changed = self._group_updates_stream_cache.has_entity_changed( # type: ignore[attr-defined]
- user_id, from_token
- )
- if not has_changed:
- return []
-
- def _get_groups_changes_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
- sql = """
- SELECT group_id, membership, type, u.content
- FROM local_group_updates AS u
- INNER JOIN local_group_membership USING (group_id, user_id)
- WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
- """
- txn.execute(sql, (user_id, from_token, to_token))
- return [
- {
- "group_id": group_id,
- "membership": membership,
- "type": gtype,
- "content": db_to_json(content_json),
- }
- for group_id, membership, gtype, content_json in txn
- ]
-
- return await self.db_pool.runInteraction(
- "get_groups_changes_for_user", _get_groups_changes_for_user_txn
- )
-
- async def get_all_groups_changes(
- self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
- """Get updates for groups replication stream.
-
- Args:
- instance_name: The writer we want to fetch updates from. Unused
- here since there is only ever one writer.
- last_id: The token to fetch updates from. Exclusive.
- current_id: The token to fetch updates up to. Inclusive.
- limit: The requested limit for the number of rows to return. The
- function may return more or fewer rows.
-
- Returns:
- A tuple consisting of: the updates, a token to use to fetch
- subsequent updates, and whether we returned fewer rows than exists
- between the requested tokens due to the limit.
-
- The token returned can be used in a subsequent call to this
- function to get further updates.
-
- The updates are a list of 2-tuples of stream ID and the row data
- """
-
- last_id = int(last_id)
- has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) # type: ignore[attr-defined]
-
- if not has_changed:
- return [], current_id, False
-
- def _get_all_groups_changes_txn(
- txn: LoggingTransaction,
- ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
- sql = """
- SELECT stream_id, group_id, user_id, type, content
- FROM local_group_updates
- WHERE ? < stream_id AND stream_id <= ?
- LIMIT ?
- """
- txn.execute(sql, (last_id, current_id, limit))
- updates = cast(
- List[Tuple[int, tuple]],
- [
- (stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
- for stream_id, group_id, user_id, gtype, content_json in txn
- ],
- )
-
- limited = False
- upto_token = current_id
- if len(updates) >= limit:
- upto_token = updates[-1][0]
- limited = True
-
- return updates, upto_token, limited
-
- return await self.db_pool.runInteraction(
- "get_all_groups_changes", _get_all_groups_changes_txn
- )
-
-
-class GroupServerStore(GroupServerWorkerStore):
- async def set_group_join_policy(self, group_id: str, join_policy: str) -> None:
- """Set the join policy of a group.
-
- join_policy can be one of:
- * "invite"
- * "open"
- """
- await self.db_pool.simple_update_one(
- table="groups",
- keyvalues={"group_id": group_id},
- updatevalues={"join_policy": join_policy},
- desc="set_group_join_policy",
- )
-
- async def add_room_to_summary(
- self,
- group_id: str,
- room_id: str,
- category_id: Optional[str],
- order: Optional[int],
- is_public: Optional[bool],
- ) -> None:
- """Add (or update) room's entry in summary.
-
- Args:
- group_id
- room_id
- category_id: If not None then adds the category to the end of
- the summary if its not already there.
- order: If not None inserts the room at that position, e.g. an order
- of 1 will put the room first. Otherwise, the room gets added to
- the end.
- is_public
- """
- await self.db_pool.runInteraction(
- "add_room_to_summary",
- self._add_room_to_summary_txn,
- group_id,
- room_id,
- category_id,
- order,
- is_public,
- )
-
- def _add_room_to_summary_txn(
- self,
- txn: LoggingTransaction,
- group_id: str,
- room_id: str,
- category_id: Optional[str],
- order: Optional[int],
- is_public: Optional[bool],
- ) -> None:
- """Add (or update) room's entry in summary.
-
- Args:
- txn
- group_id
- room_id
- category_id: If not None then adds the category to the end of
- the summary if its not already there.
- order: If not None inserts the room at that position, e.g. an order
- of 1 will put the room first. Otherwise, the room gets added to
- the end.
- is_public
- """
- room_in_group = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="group_rooms",
- keyvalues={"group_id": group_id, "room_id": room_id},
- retcol="room_id",
- allow_none=True,
- )
- if not room_in_group:
- raise SynapseError(400, "room not in group")
-
- if category_id is None:
- category_id = _DEFAULT_CATEGORY_ID
- else:
- cat_exists = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="group_room_categories",
- keyvalues={"group_id": group_id, "category_id": category_id},
- retcol="group_id",
- allow_none=True,
- )
- if not cat_exists:
- raise SynapseError(400, "Category doesn't exist")
-
- # TODO: Check category is part of summary already
- cat_exists = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="group_summary_room_categories",
- keyvalues={"group_id": group_id, "category_id": category_id},
- retcol="group_id",
- allow_none=True,
- )
- if not cat_exists:
- # If not, add it with an order larger than all others
- txn.execute(
- """
- INSERT INTO group_summary_room_categories
- (group_id, category_id, cat_order)
- SELECT ?, ?, COALESCE(MAX(cat_order), 0) + 1
- FROM group_summary_room_categories
- WHERE group_id = ? AND category_id = ?
- """,
- (group_id, category_id, group_id, category_id),
- )
-
- existing = self.db_pool.simple_select_one_txn(
- txn,
- table="group_summary_rooms",
- keyvalues={
- "group_id": group_id,
- "room_id": room_id,
- "category_id": category_id,
- },
- retcols=("room_order", "is_public"),
- allow_none=True,
- )
-
- if order is not None:
- # Shuffle other room orders that come after the given order
- sql = """
- UPDATE group_summary_rooms SET room_order = room_order + 1
- WHERE group_id = ? AND category_id = ? AND room_order >= ?
- """
- txn.execute(sql, (group_id, category_id, order))
- elif not existing:
- sql = """
- SELECT COALESCE(MAX(room_order), 0) + 1 FROM group_summary_rooms
- WHERE group_id = ? AND category_id = ?
- """
- txn.execute(sql, (group_id, category_id))
- (order,) = cast(Tuple[int], txn.fetchone())
-
- if existing:
- to_update = {}
- if order is not None:
- to_update["room_order"] = order
- if is_public is not None:
- to_update["is_public"] = is_public
- self.db_pool.simple_update_txn(
- txn,
- table="group_summary_rooms",
- keyvalues={
- "group_id": group_id,
- "category_id": category_id,
- "room_id": room_id,
- },
- updatevalues=to_update,
- )
- else:
- if is_public is None:
- is_public = True
-
- self.db_pool.simple_insert_txn(
- txn,
- table="group_summary_rooms",
- values={
- "group_id": group_id,
- "category_id": category_id,
- "room_id": room_id,
- "room_order": order,
- "is_public": is_public,
- },
- )
-
- async def remove_room_from_summary(
- self, group_id: str, room_id: str, category_id: Optional[str]
- ) -> int:
- if category_id is None:
- category_id = _DEFAULT_CATEGORY_ID
-
- return await self.db_pool.simple_delete(
- table="group_summary_rooms",
- keyvalues={
- "group_id": group_id,
- "category_id": category_id,
- "room_id": room_id,
- },
- desc="remove_room_from_summary",
- )
-
- async def upsert_group_category(
- self,
- group_id: str,
- category_id: str,
- profile: Optional[JsonDict],
- is_public: Optional[bool],
- ) -> None:
- """Add/update room category for group"""
- insertion_values: JsonDict = {}
- update_values: JsonDict = {"category_id": category_id} # This cannot be empty
-
- if profile is None:
- insertion_values["profile"] = "{}"
- else:
- update_values["profile"] = json_encoder.encode(profile)
-
- if is_public is None:
- insertion_values["is_public"] = True
- else:
- update_values["is_public"] = is_public
-
- await self.db_pool.simple_upsert(
- table="group_room_categories",
- keyvalues={"group_id": group_id, "category_id": category_id},
- values=update_values,
- insertion_values=insertion_values,
- desc="upsert_group_category",
- )
-
- async def remove_group_category(self, group_id: str, category_id: str) -> int:
- return await self.db_pool.simple_delete(
- table="group_room_categories",
- keyvalues={"group_id": group_id, "category_id": category_id},
- desc="remove_group_category",
- )
-
- async def upsert_group_role(
- self,
- group_id: str,
- role_id: str,
- profile: Optional[JsonDict],
- is_public: Optional[bool],
- ) -> None:
- """Add/remove user role"""
- insertion_values: JsonDict = {}
- update_values: JsonDict = {"role_id": role_id} # This cannot be empty
-
- if profile is None:
- insertion_values["profile"] = "{}"
- else:
- update_values["profile"] = json_encoder.encode(profile)
-
- if is_public is None:
- insertion_values["is_public"] = True
- else:
- update_values["is_public"] = is_public
-
- await self.db_pool.simple_upsert(
- table="group_roles",
- keyvalues={"group_id": group_id, "role_id": role_id},
- values=update_values,
- insertion_values=insertion_values,
- desc="upsert_group_role",
- )
-
- async def remove_group_role(self, group_id: str, role_id: str) -> int:
- return await self.db_pool.simple_delete(
- table="group_roles",
- keyvalues={"group_id": group_id, "role_id": role_id},
- desc="remove_group_role",
- )
-
- async def add_user_to_summary(
- self,
- group_id: str,
- user_id: str,
- role_id: Optional[str],
- order: Optional[int],
- is_public: Optional[bool],
- ) -> None:
- """Add (or update) user's entry in summary.
-
- Args:
- group_id
- user_id
- role_id: If not None then adds the role to the end of the summary if
- its not already there.
- order: If not None inserts the user at that position, e.g. an order
- of 1 will put the user first. Otherwise, the user gets added to
- the end.
- is_public
- """
- await self.db_pool.runInteraction(
- "add_user_to_summary",
- self._add_user_to_summary_txn,
- group_id,
- user_id,
- role_id,
- order,
- is_public,
- )
-
- def _add_user_to_summary_txn(
- self,
- txn: LoggingTransaction,
- group_id: str,
- user_id: str,
- role_id: Optional[str],
- order: Optional[int],
- is_public: Optional[bool],
- ) -> None:
- """Add (or update) user's entry in summary.
-
- Args:
- txn
- group_id
- user_id
- role_id: If not None then adds the role to the end of the summary if
- its not already there.
- order: If not None inserts the user at that position, e.g. an order
- of 1 will put the user first. Otherwise, the user gets added to
- the end.
- is_public
- """
- user_in_group = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="group_users",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcol="user_id",
- allow_none=True,
- )
- if not user_in_group:
- raise SynapseError(400, "user not in group")
-
- if role_id is None:
- role_id = _DEFAULT_ROLE_ID
- else:
- role_exists = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="group_roles",
- keyvalues={"group_id": group_id, "role_id": role_id},
- retcol="group_id",
- allow_none=True,
- )
- if not role_exists:
- raise SynapseError(400, "Role doesn't exist")
-
- # TODO: Check role is part of the summary already
- role_exists = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="group_summary_roles",
- keyvalues={"group_id": group_id, "role_id": role_id},
- retcol="group_id",
- allow_none=True,
- )
- if not role_exists:
- # If not, add it with an order larger than all others
- txn.execute(
- """
- INSERT INTO group_summary_roles
- (group_id, role_id, role_order)
- SELECT ?, ?, COALESCE(MAX(role_order), 0) + 1
- FROM group_summary_roles
- WHERE group_id = ? AND role_id = ?
- """,
- (group_id, role_id, group_id, role_id),
- )
-
- existing = self.db_pool.simple_select_one_txn(
- txn,
- table="group_summary_users",
- keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id},
- retcols=("user_order", "is_public"),
- allow_none=True,
- )
-
- if order is not None:
- # Shuffle other users orders that come after the given order
- sql = """
- UPDATE group_summary_users SET user_order = user_order + 1
- WHERE group_id = ? AND role_id = ? AND user_order >= ?
- """
- txn.execute(sql, (group_id, role_id, order))
- elif not existing:
- sql = """
- SELECT COALESCE(MAX(user_order), 0) + 1 FROM group_summary_users
- WHERE group_id = ? AND role_id = ?
- """
- txn.execute(sql, (group_id, role_id))
- (order,) = cast(Tuple[int], txn.fetchone())
-
- if existing:
- to_update = {}
- if order is not None:
- to_update["user_order"] = order
- if is_public is not None:
- to_update["is_public"] = is_public
- self.db_pool.simple_update_txn(
- txn,
- table="group_summary_users",
- keyvalues={
- "group_id": group_id,
- "role_id": role_id,
- "user_id": user_id,
- },
- updatevalues=to_update,
- )
- else:
- if is_public is None:
- is_public = True
-
- self.db_pool.simple_insert_txn(
- txn,
- table="group_summary_users",
- values={
- "group_id": group_id,
- "role_id": role_id,
- "user_id": user_id,
- "user_order": order,
- "is_public": is_public,
- },
- )
-
- async def remove_user_from_summary(
- self, group_id: str, user_id: str, role_id: Optional[str]
- ) -> int:
- if role_id is None:
- role_id = _DEFAULT_ROLE_ID
-
- return await self.db_pool.simple_delete(
- table="group_summary_users",
- keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id},
- desc="remove_user_from_summary",
- )
-
- async def add_group_invite(self, group_id: str, user_id: str) -> None:
- """Record that the group server has invited a user"""
- await self.db_pool.simple_insert(
- table="group_invites",
- values={"group_id": group_id, "user_id": user_id},
- desc="add_group_invite",
- )
-
- async def add_user_to_group(
- self,
- group_id: str,
- user_id: str,
- is_admin: bool = False,
- is_public: bool = True,
- local_attestation: Optional[dict] = None,
- remote_attestation: Optional[dict] = None,
- ) -> None:
- """Add a user to the group server.
-
- Args:
- group_id
- user_id
- is_admin
- is_public
- local_attestation: The attestation the GS created to give to the remote
- server. Optional if the user and group are on the same server
- remote_attestation: The attestation given to GS by remote server.
- Optional if the user and group are on the same server
- """
-
- def _add_user_to_group_txn(txn: LoggingTransaction) -> None:
- self.db_pool.simple_insert_txn(
- txn,
- table="group_users",
- values={
- "group_id": group_id,
- "user_id": user_id,
- "is_admin": is_admin,
- "is_public": is_public,
- },
- )
-
- self.db_pool.simple_delete_txn(
- txn,
- table="group_invites",
- keyvalues={"group_id": group_id, "user_id": user_id},
- )
-
- if local_attestation:
- self.db_pool.simple_insert_txn(
- txn,
- table="group_attestations_renewals",
- values={
- "group_id": group_id,
- "user_id": user_id,
- "valid_until_ms": local_attestation["valid_until_ms"],
- },
- )
- if remote_attestation:
- self.db_pool.simple_insert_txn(
- txn,
- table="group_attestations_remote",
- values={
- "group_id": group_id,
- "user_id": user_id,
- "valid_until_ms": remote_attestation["valid_until_ms"],
- "attestation_json": json_encoder.encode(remote_attestation),
- },
- )
-
- await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
-
- async def remove_user_from_group(self, group_id: str, user_id: str) -> None:
- def _remove_user_from_group_txn(txn: LoggingTransaction) -> None:
- self.db_pool.simple_delete_txn(
- txn,
- table="group_users",
- keyvalues={"group_id": group_id, "user_id": user_id},
- )
- self.db_pool.simple_delete_txn(
- txn,
- table="group_invites",
- keyvalues={"group_id": group_id, "user_id": user_id},
- )
- self.db_pool.simple_delete_txn(
- txn,
- table="group_attestations_renewals",
- keyvalues={"group_id": group_id, "user_id": user_id},
- )
- self.db_pool.simple_delete_txn(
- txn,
- table="group_attestations_remote",
- keyvalues={"group_id": group_id, "user_id": user_id},
- )
- self.db_pool.simple_delete_txn(
- txn,
- table="group_summary_users",
- keyvalues={"group_id": group_id, "user_id": user_id},
- )
-
- await self.db_pool.runInteraction(
- "remove_user_from_group", _remove_user_from_group_txn
- )
-
- async def add_room_to_group(
- self, group_id: str, room_id: str, is_public: bool
- ) -> None:
- await self.db_pool.simple_insert(
- table="group_rooms",
- values={"group_id": group_id, "room_id": room_id, "is_public": is_public},
- desc="add_room_to_group",
- )
-
- async def update_room_in_group_visibility(
- self, group_id: str, room_id: str, is_public: bool
- ) -> int:
- return await self.db_pool.simple_update(
- table="group_rooms",
- keyvalues={"group_id": group_id, "room_id": room_id},
- updatevalues={"is_public": is_public},
- desc="update_room_in_group_visibility",
- )
-
- async def remove_room_from_group(self, group_id: str, room_id: str) -> None:
- def _remove_room_from_group_txn(txn: LoggingTransaction) -> None:
- self.db_pool.simple_delete_txn(
- txn,
- table="group_rooms",
- keyvalues={"group_id": group_id, "room_id": room_id},
- )
-
- self.db_pool.simple_delete_txn(
- txn,
- table="group_summary_rooms",
- keyvalues={"group_id": group_id, "room_id": room_id},
- )
-
- await self.db_pool.runInteraction(
- "remove_room_from_group", _remove_room_from_group_txn
- )
-
- async def update_group_publicity(
- self, group_id: str, user_id: str, publicise: bool
- ) -> None:
- """Update whether the user is publicising their membership of the group"""
- await self.db_pool.simple_update_one(
- table="local_group_membership",
- keyvalues={"group_id": group_id, "user_id": user_id},
- updatevalues={"is_publicised": publicise},
- desc="update_group_publicity",
- )
-
- async def register_user_group_membership(
- self,
- group_id: str,
- user_id: str,
- membership: str,
- is_admin: bool = False,
- content: Optional[JsonDict] = None,
- local_attestation: Optional[dict] = None,
- remote_attestation: Optional[dict] = None,
- is_publicised: bool = False,
- ) -> int:
- """Registers that a local user is a member of a (local or remote) group.
-
- Args:
- group_id: The group the member is being added to.
- user_id: The user ID to add to the group.
- membership: The type of group membership.
- is_admin: Whether the user should be added as a group admin.
- content: Content of the membership, e.g. includes the inviter
- if the user has been invited.
- local_attestation: If remote group then store the fact that we
- have given out an attestation, else None.
- remote_attestation: If remote group then store the remote
- attestation from the group, else None.
- is_publicised: Whether this should be publicised.
- """
-
- content = content or {}
-
- def _register_user_group_membership_txn(
- txn: LoggingTransaction, next_id: int
- ) -> int:
- # TODO: Upsert?
- self.db_pool.simple_delete_txn(
- txn,
- table="local_group_membership",
- keyvalues={"group_id": group_id, "user_id": user_id},
- )
- self.db_pool.simple_insert_txn(
- txn,
- table="local_group_membership",
- values={
- "group_id": group_id,
- "user_id": user_id,
- "is_admin": is_admin,
- "membership": membership,
- "is_publicised": is_publicised,
- "content": json_encoder.encode(content),
- },
- )
-
- self.db_pool.simple_insert_txn(
- txn,
- table="local_group_updates",
- values={
- "stream_id": next_id,
- "group_id": group_id,
- "user_id": user_id,
- "type": "membership",
- "content": json_encoder.encode(
- {"membership": membership, "content": content}
- ),
- },
- )
- self._group_updates_stream_cache.entity_has_changed(user_id, next_id) # type: ignore[attr-defined]
-
- # TODO: Insert profile to ensure it comes down stream if its a join.
-
- if membership == "join":
- if local_attestation:
- self.db_pool.simple_insert_txn(
- txn,
- table="group_attestations_renewals",
- values={
- "group_id": group_id,
- "user_id": user_id,
- "valid_until_ms": local_attestation["valid_until_ms"],
- },
- )
- if remote_attestation:
- self.db_pool.simple_insert_txn(
- txn,
- table="group_attestations_remote",
- values={
- "group_id": group_id,
- "user_id": user_id,
- "valid_until_ms": remote_attestation["valid_until_ms"],
- "attestation_json": json_encoder.encode(remote_attestation),
- },
- )
- else:
- self.db_pool.simple_delete_txn(
- txn,
- table="group_attestations_renewals",
- keyvalues={"group_id": group_id, "user_id": user_id},
- )
- self.db_pool.simple_delete_txn(
- txn,
- table="group_attestations_remote",
- keyvalues={"group_id": group_id, "user_id": user_id},
- )
-
- return next_id
-
- async with self._group_updates_id_gen.get_next() as next_id: # type: ignore[attr-defined]
- res = await self.db_pool.runInteraction(
- "register_user_group_membership",
- _register_user_group_membership_txn,
- next_id,
- )
- return res
-
- async def create_group(
- self,
- group_id: str,
- user_id: str,
- name: str,
- avatar_url: str,
- short_description: str,
- long_description: str,
- ) -> None:
- await self.db_pool.simple_insert(
- table="groups",
- values={
- "group_id": group_id,
- "name": name,
- "avatar_url": avatar_url,
- "short_description": short_description,
- "long_description": long_description,
- "is_public": True,
- },
- desc="create_group",
- )
-
- async def update_group_profile(self, group_id: str, profile: JsonDict) -> None:
- await self.db_pool.simple_update_one(
- table="groups",
- keyvalues={"group_id": group_id},
- updatevalues=profile,
- desc="update_group_profile",
- )
-
- async def update_attestation_renewal(
- self, group_id: str, user_id: str, attestation: dict
- ) -> None:
- """Update an attestation that we have renewed"""
- await self.db_pool.simple_update_one(
- table="group_attestations_renewals",
- keyvalues={"group_id": group_id, "user_id": user_id},
- updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
- desc="update_attestation_renewal",
- )
-
- async def update_remote_attestion(
- self, group_id: str, user_id: str, attestation: dict
- ) -> None:
- """Update an attestation that a remote has renewed"""
- await self.db_pool.simple_update_one(
- table="group_attestations_remote",
- keyvalues={"group_id": group_id, "user_id": user_id},
- updatevalues={
- "valid_until_ms": attestation["valid_until_ms"],
- "attestation_json": json_encoder.encode(attestation),
- },
- desc="update_remote_attestion",
- )
-
- async def remove_attestation_renewal(self, group_id: str, user_id: str) -> int:
- """Remove an attestation that we thought we should renew, but actually
- shouldn't. Ideally this would never get called as we would never
- incorrectly try and do attestations for local users on local groups.
-
- Args:
- group_id
- user_id
- """
- return await self.db_pool.simple_delete(
- table="group_attestations_renewals",
- keyvalues={"group_id": group_id, "user_id": user_id},
- desc="remove_attestation_renewal",
- )
-
- def get_group_stream_token(self) -> int:
- return self._group_updates_id_gen.get_current_token() # type: ignore[attr-defined]
-
- async def delete_group(self, group_id: str) -> None:
- """Deletes a group fully from the database.
-
- Args:
- group_id: The group ID to delete.
- """
-
- def _delete_group_txn(txn: LoggingTransaction) -> None:
- tables = [
- "groups",
- "group_users",
- "group_invites",
- "group_rooms",
- "group_summary_rooms",
- "group_summary_room_categories",
- "group_room_categories",
- "group_summary_users",
- "group_summary_roles",
- "group_roles",
- "group_attestations_renewals",
- "group_attestations_remote",
- ]
-
- for table in tables:
- self.db_pool.simple_delete_txn(
- txn, table=table, keyvalues={"group_id": group_id}
- )
-
- await self.db_pool.runInteraction("delete_group", _delete_group_txn)
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index bedacaf0..2d7633fb 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
from types import TracebackType
-from typing import TYPE_CHECKING, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Optional, Set, Tuple, Type
from weakref import WeakValueDictionary
from twisted.internet.interfaces import IReactorCore
@@ -84,6 +84,8 @@ class LockStore(SQLBaseStore):
self._on_shutdown,
)
+ self._acquiring_locks: Set[Tuple[str, str]] = set()
+
@wrap_as_background_process("LockStore._on_shutdown")
async def _on_shutdown(self) -> None:
"""Called when the server is shutting down"""
@@ -103,6 +105,21 @@ class LockStore(SQLBaseStore):
context manager if the lock is successfully acquired, which *must* be
used (otherwise the lock will leak).
"""
+ if (lock_name, lock_key) in self._acquiring_locks:
+ return None
+ try:
+ self._acquiring_locks.add((lock_name, lock_key))
+ return await self._try_acquire_lock(lock_name, lock_key)
+ finally:
+ self._acquiring_locks.discard((lock_name, lock_key))
+
+ async def _try_acquire_lock(
+ self, lock_name: str, lock_key: str
+ ) -> Optional["Lock"]:
+ """Try to acquire a lock for the given name/key. Will return an async
+ context manager if the lock is successfully acquired, which *must* be
+ used (otherwise the lock will leak).
+ """
# Check if this process has taken out a lock and if it's still valid.
lock = self._live_tokens.get((lock_name, lock_key))
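
The new _acquiring_locks set short-circuits concurrent attempts from the same process to take the same lock, rather than letting them race each other to the database. A minimal sketch of the pattern, with a placeholder coroutine standing in for the real acquisition query:

import asyncio
from typing import Optional, Set, Tuple


class GuardedLockAcquirer:
    def __init__(self) -> None:
        self._acquiring: Set[Tuple[str, str]] = set()

    async def try_acquire(self, lock_name: str, lock_key: str) -> Optional[str]:
        if (lock_name, lock_key) in self._acquiring:
            # Another coroutine in this process is already trying: give up early.
            return None
        self._acquiring.add((lock_name, lock_key))
        try:
            await asyncio.sleep(0)  # placeholder for the real database upsert
            return f"token:{lock_name}/{lock_key}"
        finally:
            # Always clear the marker, even if acquisition raised or failed.
            self._acquiring.discard((lock_name, lock_key))
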
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 40ac377c..d028be16 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -251,12 +251,36 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"get_local_media_by_user_paginate_txn", get_local_media_by_user_paginate_txn
)
- async def get_local_media_before(
+ async def get_local_media_ids(
self,
before_ts: int,
size_gt: int,
keep_profiles: bool,
+ include_quarantined_media: bool,
+ include_protected_media: bool,
) -> List[str]:
+ """
+ Retrieve a list of media IDs from the local media store.
+
+ Args:
+ before_ts: Only retrieve IDs from media that was either last accessed
+ (or if never accessed, created) before the given UNIX timestamp in ms.
+ size_gt: Only retrieve IDs from media that has a size (in bytes) greater than
+ the given integer.
+ keep_profiles: If True, exclude media IDs from the results that are used in the
+ following situations:
+ * global profile user avatar
+ * per-room profile user avatar
+ * room avatar
+ * a user's avatar in the user directory
+ include_quarantined_media: If False, exclude media IDs from the results that have
+ been marked as quarantined.
+ include_protected_media: If False, exclude media IDs from the results that have
+ been marked as protected from quarantine.
+
+ Returns:
+ A list of local media IDs.
+ """
# to find files that have never been accessed (last_access_ts IS NULL)
# compare with `created_ts`
@@ -278,10 +302,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
WHERE profiles.avatar_url = '{media_prefix}' || lmr.media_id)
AND NOT EXISTS
(SELECT 1
- FROM groups
- WHERE groups.avatar_url = '{media_prefix}' || lmr.media_id)
- AND NOT EXISTS
- (SELECT 1
FROM room_memberships
WHERE room_memberships.avatar_url = '{media_prefix}' || lmr.media_id)
AND NOT EXISTS
@@ -298,12 +318,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
sql += sql_keep
- def _get_local_media_before_txn(txn: LoggingTransaction) -> List[str]:
+ if include_quarantined_media is False:
+ # Do not include media that has been quarantined
+ sql += """
+ AND quarantined_by IS NULL
+ """
+
+ if include_protected_media is False:
+ # Do not include media that has been protected from quarantine
+ sql += """
+ AND NOT safe_from_quarantine
+ """
+
+ def _get_local_media_ids_txn(txn: LoggingTransaction) -> List[str]:
txn.execute(sql, (before_ts, before_ts, size_gt))
return [row[0] for row in txn]
return await self.db_pool.runInteraction(
- "get_local_media_before", _get_local_media_before_txn
+ "get_local_media_ids", _get_local_media_ids_txn
)
async def store_local_media(
@@ -603,15 +635,37 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_remote_media_thumbnail",
)
- async def get_remote_media_before(self, before_ts: int) -> List[Dict[str, str]]:
+ async def get_remote_media_ids(
+ self, before_ts: int, include_quarantined_media: bool
+ ) -> List[Dict[str, str]]:
+ """
+ Retrieve a list of server name, media ID tuples from the remote media cache.
+
+ Args:
+ before_ts: Only retrieve IDs from media that was either last accessed
+ (or if never accessed, created) before the given UNIX timestamp in ms.
+ include_quarantined_media: If False, exclude media IDs from the results that have
+ been marked as quarantined.
+
+ Returns:
+ A list of tuples containing:
+ * The server name of the homeserver the media originates from,
+ * The ID of the media.
+ """
sql = (
"SELECT media_origin, media_id, filesystem_id"
" FROM remote_media_cache"
" WHERE last_access_ts < ?"
)
+ if include_quarantined_media is False:
+ # Only include media that has not been quarantined
+ sql += """
+ AND quarantined_by IS NULL
+ """
+
return await self.db_pool.execute(
- "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
+ "get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts
)
async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
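
Both renamed helpers above now assemble their WHERE clause from flags, appending the quarantine-related predicates only when the caller asks for that media to be excluded. A simplified sketch of that assembly (the table name is assumed and the surrounding query is heavily trimmed; the quarantine columns are the ones visible in the diff):

def build_local_media_query(
    include_quarantined_media: bool, include_protected_media: bool
) -> str:
    # Simplified: the real query also filters on size and profile/avatar usage.
    sql = (
        "SELECT media_id FROM local_media_repository"
        " WHERE COALESCE(last_access_ts, created_ts) < ?"
    )
    if not include_quarantined_media:
        # Leave quarantined media untouched.
        sql += " AND quarantined_by IS NULL"
    if not include_protected_media:
        # Leave media that has been protected from quarantine untouched.
        sql += " AND NOT safe_from_quarantine"
    return sql


# e.g. a purge job that must not delete quarantined or protected media:
print(build_local_media_query(False, False))
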
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index 1480a0f0..14294a0b 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -14,12 +14,16 @@
import calendar
import logging
import time
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING, Dict, List, Tuple, cast
from synapse.metrics import GaugeBucketCollector
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
@@ -71,8 +75,8 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
self._last_user_visit_update = self._get_start_of_day()
@wrap_as_background_process("read_forward_extremities")
- async def _read_forward_extremities(self):
- def fetch(txn):
+ async def _read_forward_extremities(self) -> None:
+ def fetch(txn: LoggingTransaction) -> List[Tuple[int, int]]:
txn.execute(
"""
SELECT t1.c, t2.c
@@ -85,7 +89,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
) t2 ON t1.room_id = t2.room_id
"""
)
- return txn.fetchall()
+ return cast(List[Tuple[int, int]], txn.fetchall())
res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
@@ -95,7 +99,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(x[0] - 1) * x[1] for x in res if x[1]
)
- async def count_daily_e2ee_messages(self):
+ async def count_daily_e2ee_messages(self) -> int:
"""
Returns an estimate of the number of messages sent in the last day.
@@ -103,20 +107,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
call to this function, it will return None.
"""
- def _count_messages(txn):
+ def _count_messages(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(*) FROM events
WHERE type = 'm.room.encrypted'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
- async def count_daily_sent_e2ee_messages(self):
- def _count_messages(txn):
+ async def count_daily_sent_e2ee_messages(self) -> int:
+ def _count_messages(txn: LoggingTransaction) -> int:
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
like_clause = "%:" + self.hs.hostname
@@ -129,29 +133,29 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"""
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction(
"count_daily_sent_e2ee_messages", _count_messages
)
- async def count_daily_active_e2ee_rooms(self):
- def _count(txn):
+ async def count_daily_active_e2ee_rooms(self) -> int:
+ def _count(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
WHERE type = 'm.room.encrypted'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction(
"count_daily_active_e2ee_rooms", _count
)
- async def count_daily_messages(self):
+ async def count_daily_messages(self) -> int:
"""
Returns an estimate of the number of messages sent in the last day.
@@ -159,20 +163,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
call to this function, it will return None.
"""
- def _count_messages(txn):
+ def _count_messages(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(*) FROM events
WHERE type = 'm.room.message'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_messages", _count_messages)
- async def count_daily_sent_messages(self):
- def _count_messages(txn):
+ async def count_daily_sent_messages(self) -> int:
+ def _count_messages(txn: LoggingTransaction) -> int:
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
like_clause = "%:" + self.hs.hostname
@@ -185,22 +189,22 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"""
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction(
"count_daily_sent_messages", _count_messages
)
- async def count_daily_active_rooms(self):
- def _count(txn):
+ async def count_daily_active_rooms(self) -> int:
+ def _count(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
WHERE type = 'm.room.message'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
@@ -226,7 +230,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_monthly_users", self._count_users, thirty_days_ago
)
- def _count_users(self, txn, time_from):
+ def _count_users(self, txn: LoggingTransaction, time_from: int) -> int:
"""
Returns number of users seen in the past time_from period
"""
@@ -238,7 +242,10 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
) u
"""
txn.execute(sql, (time_from,))
- (count,) = txn.fetchone()
+ # Mypy knows that fetchone() might return None if there are no rows.
+ # We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
+ # returns exactly one row.
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
async def count_r30_users(self) -> Dict[str, int]:
@@ -252,7 +259,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
A mapping of counts globally as well as broken out by platform.
"""
- def _count_r30_users(txn):
+ def _count_r30_users(txn: LoggingTransaction) -> Dict[str, int]:
thirty_days_in_secs = 86400 * 30
now = int(self._clock.time())
thirty_days_ago_in_secs = now - thirty_days_in_secs
@@ -317,7 +324,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
results["all"] = count
return results
@@ -344,7 +351,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
- "web" (any web application -- it's not possible to distinguish Element Web here)
"""
- def _count_r30v2_users(txn):
+ def _count_r30v2_users(txn: LoggingTransaction) -> Dict[str, int]:
thirty_days_in_secs = 86400 * 30
now = int(self._clock.time())
sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs
@@ -441,11 +448,8 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
thirty_days_in_secs * 1000,
),
)
- row = txn.fetchone()
- if row is None:
- results["all"] = 0
- else:
- results["all"] = row[0]
+ (count,) = cast(Tuple[int], txn.fetchone())
+ results["all"] = count
return results
@@ -453,7 +457,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_r30v2_users", _count_r30v2_users
)
- def _get_start_of_day(self):
+ def _get_start_of_day(self) -> int:
"""
Returns millisecond unixtime for start of UTC day.
"""
@@ -467,7 +471,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
Generates daily visit data for use in cohort/ retention analysis
"""
- def _generate_user_daily_visits(txn):
+ def _generate_user_daily_visits(txn: LoggingTransaction) -> None:
logger.info("Calling _generate_user_daily_visits")
today_start = self._get_start_of_day()
a_day_in_milliseconds = 24 * 60 * 60 * 1000
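Editor's note: the recurring `cast(Tuple[int], txn.fetchone())` change above relies on an ungrouped `SELECT COUNT(...)` always yielding exactly one row, so `fetchone()` can never return None. A small sketch of that reasoning using the standard library's sqlite3 rather than Synapse's own transaction wrappers:

import sqlite3
from typing import Tuple, cast

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (type TEXT, stream_ordering INTEGER)")

cur = conn.execute(
    "SELECT COUNT(*) FROM events WHERE type = ? AND stream_ordering > ?",
    ("m.room.message", 0),
)
# Even with no matching rows, an ungrouped COUNT(*) returns one row: (0,).
# The cast only narrows the type for mypy; it does nothing at runtime.
(count,) = cast(Tuple[int], cur.fetchone())
print(count)  # 0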
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 5beb8f1d..9a63f953 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -122,6 +122,51 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
"count_users_by_service", _count_users_by_service
)
+ async def get_monthly_active_users_by_service(
+ self, start_timestamp: Optional[int] = None, end_timestamp: Optional[int] = None
+ ) -> List[Tuple[str, str]]:
+ """Generates list of monthly active users and their services.
+ Please see "get_monthly_active_count_by_service" docstring for more details
+ about services.
+
+ Arguments:
+ start_timestamp: If specified, only include users that were first active
+ at or after this point
+ end_timestamp: If specified, only include users that were first active
+ at or before this point
+
+ Returns:
+ A list of tuples (appservice_id, user_id). "native" is emitted as the
+ appservice for users that don't come from appservices (i.e. native Matrix
+ users).
+
+ """
+ if start_timestamp is not None and end_timestamp is not None:
+ where_clause = 'WHERE "timestamp" >= ? and "timestamp" <= ?'
+ query_params = [start_timestamp, end_timestamp]
+ elif start_timestamp is not None:
+ where_clause = 'WHERE "timestamp" >= ?'
+ query_params = [start_timestamp]
+ elif end_timestamp is not None:
+ where_clause = 'WHERE "timestamp" <= ?'
+ query_params = [end_timestamp]
+ else:
+ where_clause = ""
+ query_params = []
+
+ def _list_users(txn: LoggingTransaction) -> List[Tuple[str, str]]:
+ sql = f"""
+ SELECT COALESCE(appservice_id, 'native'), user_id
+ FROM monthly_active_users
+ LEFT JOIN users ON monthly_active_users.user_id=users.name
+ {where_clause};
+ """
+
+ txn.execute(sql, query_params)
+ return cast(List[Tuple[str, str]], txn.fetchall())
+
+ return await self.db_pool.runInteraction("list_users", _list_users)
+
async def get_registered_reserved_users(self) -> List[str]:
"""Of the reserved threepids defined in config, retrieve those that are associated
with registered users
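Editor's note: `get_monthly_active_users_by_service` above builds its WHERE clause from whichever of the two optional timestamps were supplied. A standalone sketch of just that clause construction; the helper name is hypothetical and only the branching mirrors the diff.

from typing import List, Optional, Tuple

def build_timestamp_filter(
    start_timestamp: Optional[int] = None, end_timestamp: Optional[int] = None
) -> Tuple[str, List[int]]:
    # Return a WHERE clause plus its bind parameters, covering all four cases.
    if start_timestamp is not None and end_timestamp is not None:
        return 'WHERE "timestamp" >= ? AND "timestamp" <= ?', [start_timestamp, end_timestamp]
    if start_timestamp is not None:
        return 'WHERE "timestamp" >= ?', [start_timestamp]
    if end_timestamp is not None:
        return 'WHERE "timestamp" <= ?', [end_timestamp]
    return "", []

# Example: only an upper bound supplied.
where_clause, params = build_timestamp_filter(end_timestamp=1_650_000_000_000)
sql = f"SELECT user_id FROM monthly_active_users {where_clause}"
print(sql, params)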
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index b47c5114..9769a18a 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, cast
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, cast
from synapse.api.presence import PresenceState, UserPresenceState
from synapse.replication.tcp.streams import PresenceStream
@@ -22,6 +22,7 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import (
@@ -56,7 +57,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore):
)
-class PresenceStore(PresenceBackgroundUpdateStore):
+class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -281,20 +282,30 @@ class PresenceStore(PresenceBackgroundUpdateStore):
True if the user should have full presence sent to them, False otherwise.
"""
- def _should_user_receive_full_presence_with_token_txn(
- txn: LoggingTransaction,
- ) -> bool:
- sql = """
- SELECT 1 FROM users_to_send_full_presence_to
- WHERE user_id = ?
- AND presence_stream_id >= ?
- """
- txn.execute(sql, (user_id, from_token))
- return bool(txn.fetchone())
+ token = await self._get_full_presence_stream_token_for_user(user_id)
+ if token is None:
+ return False
- return await self.db_pool.runInteraction(
- "should_user_receive_full_presence_with_token",
- _should_user_receive_full_presence_with_token_txn,
+ return from_token <= token
+
+ @cached()
+ async def _get_full_presence_stream_token_for_user(
+ self, user_id: str
+ ) -> Optional[int]:
+ """Get the presence token corresponding to the last full presence update
+ for this user.
+
+ If the user presents a sync token with a presence stream token at least
+ as old as the result, then we need to send them a full presence update.
+
+ If this user has never needed a full presence update, returns `None`.
+ """
+ return await self.db_pool.simple_select_one_onecol(
+ table="users_to_send_full_presence_to",
+ keyvalues={"user_id": user_id},
+ retcol="presence_stream_id",
+ allow_none=True,
+ desc="_get_full_presence_stream_token_for_user",
)
async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None:
@@ -307,18 +318,28 @@ class PresenceStore(PresenceBackgroundUpdateStore):
# Add user entries to the table, updating the presence_stream_id column if the user already
# exists in the table.
presence_stream_id = self._presence_id_gen.get_current_token()
- await self.db_pool.simple_upsert_many(
- table="users_to_send_full_presence_to",
- key_names=("user_id",),
- key_values=[(user_id,) for user_id in user_ids],
- value_names=("presence_stream_id",),
- # We save the current presence stream ID token along with the user ID entry so
- # that when a user /sync's, even if they syncing multiple times across separate
- # devices at different times, each device will receive full presence once - when
- # the presence stream ID in their sync token is less than the one in the table
- # for their user ID.
- value_values=[(presence_stream_id,) for _ in user_ids],
- desc="add_users_to_send_full_presence_to",
+
+ def _add_users_to_send_full_presence_to(txn: LoggingTransaction) -> None:
+ self.db_pool.simple_upsert_many_txn(
+ txn,
+ table="users_to_send_full_presence_to",
+ key_names=("user_id",),
+ key_values=[(user_id,) for user_id in user_ids],
+ value_names=("presence_stream_id",),
+ # We save the current presence stream ID token along with the user ID entry so
+ # that when a user /sync's, even if they syncing multiple times across separate
+ # devices at different times, each device will receive full presence once - when
+ # the presence stream ID in their sync token is less than the one in the table
+ # for their user ID.
+ value_values=[(presence_stream_id,) for _ in user_ids],
+ )
+ for user_id in user_ids:
+ self._invalidate_cache_and_stream(
+ txn, self._get_full_presence_stream_token_for_user, (user_id,)
+ )
+
+ return await self.db_pool.runInteraction(
+ "add_users_to_send_full_presence_to", _add_users_to_send_full_presence_to
)
async def get_presence_for_all_users(
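Editor's note: the presence change above replaces a per-call SQL lookup with a cached "last full-presence stream token" per user, turning the check into a simple comparison, and every write must invalidate that cache. A rough sketch of the shape, using a plain dict in place of Synapse's `@cached` machinery; all names here are illustrative.

from typing import Dict, Iterable, Optional

class FullPresenceTracker:
    def __init__(self) -> None:
        self._db: Dict[str, int] = {}        # stands in for users_to_send_full_presence_to
        self._cache: Dict[str, Optional[int]] = {}  # stands in for the @cached layer

    def _get_token(self, user_id: str) -> Optional[int]:
        if user_id not in self._cache:
            self._cache[user_id] = self._db.get(user_id)
        return self._cache[user_id]

    def should_receive_full_presence(self, user_id: str, from_token: int) -> bool:
        token = self._get_token(user_id)
        if token is None:
            return False
        # Send full presence if the sync token is at least as old as the stored one.
        return from_token <= token

    def add_users(self, user_ids: Iterable[str], presence_stream_id: int) -> None:
        for user_id in user_ids:
            self._db[user_id] = presence_stream_id
            # Writing through must invalidate the cached value, mirroring
            # _invalidate_cache_and_stream in the diff.
            self._cache.pop(user_id, None)

tracker = FullPresenceTracker()
tracker.add_users(["@alice:example.org"], presence_stream_id=42)
print(tracker.should_receive_full_presence("@alice:example.org", from_token=40))  # True
print(tracker.should_receive_full_presence("@alice:example.org", from_token=43))  # False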
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index e197b720..a1747f04 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -11,11 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional
+from typing import Optional
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.roommember import ProfileInfo
@@ -55,17 +54,6 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_avatar_url",
)
- async def get_from_remote_profile_cache(
- self, user_id: str
- ) -> Optional[Dict[str, Any]]:
- return await self.db_pool.simple_select_one(
- table="remote_profile_cache",
- keyvalues={"user_id": user_id},
- retcols=("displayname", "avatar_url"),
- allow_none=True,
- desc="get_from_remote_profile_cache",
- )
-
async def create_profile(self, user_localpart: str) -> None:
await self.db_pool.simple_insert(
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
@@ -91,97 +79,6 @@ class ProfileWorkerStore(SQLBaseStore):
desc="set_profile_avatar_url",
)
- async def update_remote_profile_cache(
- self, user_id: str, displayname: Optional[str], avatar_url: Optional[str]
- ) -> int:
- return await self.db_pool.simple_update(
- table="remote_profile_cache",
- keyvalues={"user_id": user_id},
- updatevalues={
- "displayname": displayname,
- "avatar_url": avatar_url,
- "last_check": self._clock.time_msec(),
- },
- desc="update_remote_profile_cache",
- )
-
- async def maybe_delete_remote_profile_cache(self, user_id: str) -> None:
- """Check if we still care about the remote user's profile, and if we
- don't then remove their profile from the cache
- """
- subscribed = await self.is_subscribed_remote_profile_for_user(user_id)
- if not subscribed:
- await self.db_pool.simple_delete(
- table="remote_profile_cache",
- keyvalues={"user_id": user_id},
- desc="delete_remote_profile_cache",
- )
-
- async def is_subscribed_remote_profile_for_user(self, user_id: str) -> bool:
- """Check whether we are interested in a remote user's profile."""
- res: Optional[str] = await self.db_pool.simple_select_one_onecol(
- table="group_users",
- keyvalues={"user_id": user_id},
- retcol="user_id",
- allow_none=True,
- desc="should_update_remote_profile_cache_for_user",
- )
-
- if res:
- return True
-
- res = await self.db_pool.simple_select_one_onecol(
- table="group_invites",
- keyvalues={"user_id": user_id},
- retcol="user_id",
- allow_none=True,
- desc="should_update_remote_profile_cache_for_user",
- )
-
- if res:
- return True
- return False
-
- async def get_remote_profile_cache_entries_that_expire(
- self, last_checked: int
- ) -> List[Dict[str, str]]:
- """Get all users who haven't been checked since `last_checked`"""
-
- def _get_remote_profile_cache_entries_that_expire_txn(
- txn: LoggingTransaction,
- ) -> List[Dict[str, str]]:
- sql = """
- SELECT user_id, displayname, avatar_url
- FROM remote_profile_cache
- WHERE last_check < ?
- """
-
- txn.execute(sql, (last_checked,))
-
- return self.db_pool.cursor_to_dict(txn)
-
- return await self.db_pool.runInteraction(
- "get_remote_profile_cache_entries_that_expire",
- _get_remote_profile_cache_entries_that_expire_txn,
- )
-
class ProfileStore(ProfileWorkerStore):
- async def add_remote_profile_cache(
- self, user_id: str, displayname: str, avatar_url: str
- ) -> None:
- """Ensure we are caching the remote user's profiles.
-
- This should only be called when `is_subscribed_remote_profile_for_user`
- would return true for the user.
- """
- await self.db_pool.simple_upsert(
- table="remote_profile_cache",
- keyvalues={"user_id": user_id},
- values={
- "displayname": displayname,
- "avatar_url": avatar_url,
- "last_check": self._clock.time_msec(),
- },
- desc="add_remote_profile_cache",
- )
+ pass
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index bfc85b3a..ba385f9f 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -69,7 +69,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
# event_forward_extremities
# event_json
# event_push_actions
- # event_reference_hashes
# event_relations
# event_search
# event_to_state_groups
@@ -220,7 +219,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"event_auth",
"event_edges",
"event_forward_extremities",
- "event_reference_hashes",
"event_relations",
"event_search",
"rejections",
@@ -324,12 +322,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
)
def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]:
- # We *immediately* delete the room from the rooms table. This ensures
- # that we don't race when persisting events (as that transaction checks
- # that the room exists).
- txn.execute("DELETE FROM rooms WHERE room_id = ?", (room_id,))
-
- # Next, we fetch all the state groups that should be deleted, before
+ # First, fetch all the state groups that should be deleted, before
# we delete that information.
txn.execute(
"""
@@ -369,7 +362,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"event_edges",
"event_json",
"event_push_actions_staging",
- "event_reference_hashes",
"event_relations",
"event_to_state_groups",
"event_auth_chains",
@@ -390,7 +382,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
(room_id,),
)
- # and finally, the tables with an index on room_id (or no useful index)
+ # next, the tables with an index on room_id (or no useful index)
for table in (
"current_state_events",
"destination_rooms",
@@ -398,8 +390,12 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"event_forward_extremities",
"event_push_actions",
"event_search",
+ "partial_state_events",
"events",
- "group_rooms",
+ "federation_inbound_events_staging",
+ "local_current_membership",
+ "partial_state_rooms_servers",
+ "partial_state_rooms",
"receipts_graph",
"receipts_linearized",
"room_aliases",
@@ -416,10 +412,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"e2e_room_keys",
"event_push_summary",
"pusher_throttle",
- "group_summary_rooms",
"room_account_data",
"room_tags",
- "local_current_membership",
+ # "rooms" happens last, to keep the foreign keys in the other tables
+ # happy
+ "rooms",
):
logger.info("[purge] removing %s from %s", room_id, table)
txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 4ed913e2..d5aefe02 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -14,14 +14,18 @@
# limitations under the License.
import abc
import logging
-from typing import TYPE_CHECKING, Dict, List, Tuple, Union
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast
from synapse.api.errors import StoreError
from synapse.config.homeserver import ExperimentalConfig
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.pusher import PusherWorkerStore
@@ -30,9 +34,12 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import (
+ AbstractStreamIdGenerator,
AbstractStreamIdTracker,
+ IdGenerator,
StreamIdGenerator,
)
+from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -54,10 +61,19 @@ def _is_experimental_rule_enabled(
and not experimental_config.msc3786_enabled
):
return False
+ if (
+ rule_id == "global/underride/.org.matrix.msc3772.thread_reply"
+ and not experimental_config.msc3772_enabled
+ ):
+ return False
return True
-def _load_rules(rawrules, enabled_map, experimental_config: ExperimentalConfig):
+def _load_rules(
+ rawrules: List[JsonDict],
+ enabled_map: Dict[str, bool],
+ experimental_config: ExperimentalConfig,
+) -> List[JsonDict]:
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
@@ -137,7 +153,7 @@ class PushRulesWorkerStore(
)
@abc.abstractmethod
- def get_max_push_rules_stream_id(self):
+ def get_max_push_rules_stream_id(self) -> int:
"""Get the position of the push rules stream.
Returns:
@@ -146,7 +162,7 @@ class PushRulesWorkerStore(
raise NotImplementedError()
@cached(max_entries=5000)
- async def get_push_rules_for_user(self, user_id):
+ async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]:
rows = await self.db_pool.simple_select_list(
table="push_rules",
keyvalues={"user_name": user_id},
@@ -158,7 +174,7 @@ class PushRulesWorkerStore(
"conditions",
"actions",
),
- desc="get_push_rules_enabled_for_user",
+ desc="get_push_rules_for_user",
)
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
@@ -168,14 +184,14 @@ class PushRulesWorkerStore(
return _load_rules(rows, enabled_map, self.hs.config.experimental)
@cached(max_entries=5000)
- async def get_push_rules_enabled_for_user(self, user_id) -> Dict[str, bool]:
+ async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
results = await self.db_pool.simple_select_list(
table="push_rules_enable",
keyvalues={"user_name": user_id},
- retcols=("user_name", "rule_id", "enabled"),
+ retcols=("rule_id", "enabled"),
desc="get_push_rules_enabled_for_user",
)
- return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
+ return {r["rule_id"]: bool(r["enabled"]) for r in results}
async def have_push_rules_changed_for_user(
self, user_id: str, last_id: int
@@ -184,29 +200,27 @@ class PushRulesWorkerStore(
return False
else:
- def have_push_rules_changed_txn(txn):
+ def have_push_rules_changed_txn(txn: LoggingTransaction) -> bool:
sql = (
"SELECT COUNT(stream_id) FROM push_rules_stream"
" WHERE user_id = ? AND ? < stream_id"
)
txn.execute(sql, (user_id, last_id))
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return bool(count)
return await self.db_pool.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
)
- @cachedList(
- cached_method_name="get_push_rules_for_user",
- list_name="user_ids",
- num_args=1,
- )
- async def bulk_get_push_rules(self, user_ids):
+ @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids")
+ async def bulk_get_push_rules(
+ self, user_ids: Collection[str]
+ ) -> Dict[str, List[JsonDict]]:
if not user_ids:
return {}
- results = {user_id: [] for user_id in user_ids}
+ results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
rows = await self.db_pool.simple_select_many_batch(
table="push_rules",
@@ -230,67 +244,16 @@ class PushRulesWorkerStore(
return results
- async def copy_push_rule_from_room_to_room(
- self, new_room_id: str, user_id: str, rule: dict
- ) -> None:
- """Copy a single push rule from one room to another for a specific user.
-
- Args:
- new_room_id: ID of the new room.
- user_id : ID of user the push rule belongs to.
- rule: A push rule.
- """
- # Create new rule id
- rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
- new_rule_id = rule_id_scope + "/" + new_room_id
-
- # Change room id in each condition
- for condition in rule.get("conditions", []):
- if condition.get("key") == "room_id":
- condition["pattern"] = new_room_id
-
- # Add the rule for the new room
- await self.add_push_rule(
- user_id=user_id,
- rule_id=new_rule_id,
- priority_class=rule["priority_class"],
- conditions=rule["conditions"],
- actions=rule["actions"],
- )
-
- async def copy_push_rules_from_room_to_room_for_user(
- self, old_room_id: str, new_room_id: str, user_id: str
- ) -> None:
- """Copy all of the push rules from one room to another for a specific
- user.
-
- Args:
- old_room_id: ID of the old room.
- new_room_id: ID of the new room.
- user_id: ID of user to copy push rules for.
- """
- # Retrieve push rules for this user
- user_push_rules = await self.get_push_rules_for_user(user_id)
-
- # Get rules relating to the old room and copy them to the new room
- for rule in user_push_rules:
- conditions = rule.get("conditions", [])
- if any(
- (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
- for c in conditions
- ):
- await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
-
@cachedList(
- cached_method_name="get_push_rules_enabled_for_user",
- list_name="user_ids",
- num_args=1,
+ cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids"
)
- async def bulk_get_push_rules_enabled(self, user_ids):
+ async def bulk_get_push_rules_enabled(
+ self, user_ids: Collection[str]
+ ) -> Dict[str, Dict[str, bool]]:
if not user_ids:
return {}
- results = {user_id: {} for user_id in user_ids}
+ results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids}
rows = await self.db_pool.simple_select_many_batch(
table="push_rules_enable",
@@ -306,7 +269,7 @@ class PushRulesWorkerStore(
async def get_all_push_rule_updates(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
"""Get updates for push_rules replication stream.
Args:
@@ -331,7 +294,9 @@ class PushRulesWorkerStore(
if last_id == current_id:
return [], current_id, False
- def get_all_push_rule_updates_txn(txn):
+ def get_all_push_rule_updates_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
sql = """
SELECT stream_id, user_id
FROM push_rules_stream
@@ -340,7 +305,10 @@ class PushRulesWorkerStore(
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
- updates = [(stream_id, (user_id,)) for stream_id, user_id in txn]
+ updates = cast(
+ List[Tuple[int, Tuple[str]]],
+ [(stream_id, (user_id,)) for stream_id, user_id in txn],
+ )
limited = False
upper_bound = current_id
@@ -356,15 +324,30 @@ class PushRulesWorkerStore(
class PushRuleStore(PushRulesWorkerStore):
+ # Because we have write access, this will be a StreamIdGenerator
+ # (see PushRulesWorkerStore.__init__)
+ _push_rules_stream_id_gen: AbstractStreamIdGenerator
+
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
+
+ self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
+ self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
+
async def add_push_rule(
self,
- user_id,
- rule_id,
- priority_class,
- conditions,
- actions,
- before=None,
- after=None,
+ user_id: str,
+ rule_id: str,
+ priority_class: int,
+ conditions: List[Dict[str, str]],
+ actions: List[Union[JsonDict, str]],
+ before: Optional[str] = None,
+ after: Optional[str] = None,
) -> None:
conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions)
@@ -400,17 +383,17 @@ class PushRuleStore(PushRulesWorkerStore):
def _add_push_rule_relative_txn(
self,
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- priority_class,
- conditions_json,
- actions_json,
- before,
- after,
- ):
+ txn: LoggingTransaction,
+ stream_id: int,
+ event_stream_ordering: int,
+ user_id: str,
+ rule_id: str,
+ priority_class: int,
+ conditions_json: str,
+ actions_json: str,
+ before: str,
+ after: str,
+ ) -> None:
# Lock the table since otherwise we'll have annoying races between the
# SELECT here and the UPSERT below.
self.database_engine.lock_table(txn, "push_rules")
@@ -470,15 +453,15 @@ class PushRuleStore(PushRulesWorkerStore):
def _add_push_rule_highest_priority_txn(
self,
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- priority_class,
- conditions_json,
- actions_json,
- ):
+ txn: LoggingTransaction,
+ stream_id: int,
+ event_stream_ordering: int,
+ user_id: str,
+ rule_id: str,
+ priority_class: int,
+ conditions_json: str,
+ actions_json: str,
+ ) -> None:
# Lock the table since otherwise we'll have annoying races between the
# SELECT here and the UPSERT below.
self.database_engine.lock_table(txn, "push_rules")
@@ -510,17 +493,17 @@ class PushRuleStore(PushRulesWorkerStore):
def _upsert_push_rule_txn(
self,
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- priority_class,
- priority,
- conditions_json,
- actions_json,
- update_stream=True,
- ):
+ txn: LoggingTransaction,
+ stream_id: int,
+ event_stream_ordering: int,
+ user_id: str,
+ rule_id: str,
+ priority_class: int,
+ priority: int,
+ conditions_json: str,
+ actions_json: str,
+ update_stream: bool = True,
+ ) -> None:
"""Specialised version of simple_upsert_txn that picks a push_rule_id
using the _push_rule_id_gen if it needs to insert the rule. It assumes
that the "push_rules" table is locked"""
@@ -600,7 +583,11 @@ class PushRuleStore(PushRulesWorkerStore):
rule_id: The rule_id of the rule to be deleted
"""
- def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
+ def delete_push_rule_txn(
+ txn: LoggingTransaction,
+ stream_id: int,
+ event_stream_ordering: int,
+ ) -> None:
# we don't use simple_delete_one_txn because that would fail if the
# user did not have a push_rule_enable row.
self.db_pool.simple_delete_txn(
@@ -661,14 +648,14 @@ class PushRuleStore(PushRulesWorkerStore):
def _set_push_rule_enabled_txn(
self,
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- enabled,
- is_default_rule,
- ):
+ txn: LoggingTransaction,
+ stream_id: int,
+ event_stream_ordering: int,
+ user_id: str,
+ rule_id: str,
+ enabled: bool,
+ is_default_rule: bool,
+ ) -> None:
new_id = self._push_rules_enable_id_gen.get_next()
if not is_default_rule:
@@ -740,7 +727,11 @@ class PushRuleStore(PushRulesWorkerStore):
"""
actions_json = json_encoder.encode(actions)
- def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
+ def set_push_rule_actions_txn(
+ txn: LoggingTransaction,
+ stream_id: int,
+ event_stream_ordering: int,
+ ) -> None:
if is_default_rule:
# Add a dummy rule to the rules table with the user specified
# actions.
@@ -794,8 +785,15 @@ class PushRuleStore(PushRulesWorkerStore):
)
def _insert_push_rules_update_txn(
- self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
- ):
+ self,
+ txn: LoggingTransaction,
+ stream_id: int,
+ event_stream_ordering: int,
+ user_id: str,
+ rule_id: str,
+ op: str,
+ data: Optional[JsonDict] = None,
+ ) -> None:
values = {
"stream_id": stream_id,
"event_stream_ordering": event_stream_ordering,
@@ -814,5 +812,56 @@ class PushRuleStore(PushRulesWorkerStore):
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)
- def get_max_push_rules_stream_id(self):
+ def get_max_push_rules_stream_id(self) -> int:
return self._push_rules_stream_id_gen.get_current_token()
+
+ async def copy_push_rule_from_room_to_room(
+ self, new_room_id: str, user_id: str, rule: dict
+ ) -> None:
+ """Copy a single push rule from one room to another for a specific user.
+
+ Args:
+ new_room_id: ID of the new room.
+ user_id : ID of user the push rule belongs to.
+ rule: A push rule.
+ """
+ # Create new rule id
+ rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
+ new_rule_id = rule_id_scope + "/" + new_room_id
+
+ # Change room id in each condition
+ for condition in rule.get("conditions", []):
+ if condition.get("key") == "room_id":
+ condition["pattern"] = new_room_id
+
+ # Add the rule for the new room
+ await self.add_push_rule(
+ user_id=user_id,
+ rule_id=new_rule_id,
+ priority_class=rule["priority_class"],
+ conditions=rule["conditions"],
+ actions=rule["actions"],
+ )
+
+ async def copy_push_rules_from_room_to_room_for_user(
+ self, old_room_id: str, new_room_id: str, user_id: str
+ ) -> None:
+ """Copy all of the push rules from one room to another for a specific
+ user.
+
+ Args:
+ old_room_id: ID of the old room.
+ new_room_id: ID of the new room.
+ user_id: ID of user to copy push rules for.
+ """
+ # Retrieve push rules for this user
+ user_push_rules = await self.get_push_rules_for_user(user_id)
+
+ # Get rules relating to the old room and copy them to the new room
+ for rule in user_push_rules:
+ conditions = rule.get("conditions", [])
+ if any(
+ (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
+ for c in conditions
+ ):
+ await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
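Editor's note: `copy_push_rule_from_room_to_room`, moved into `PushRuleStore` above, rewrites a room-specific rule for a new room in two steps: the trailing component of the rule ID is swapped for the new room ID, and any `room_id` condition is repointed. A self-contained sketch of just that rewriting step (no database involved; the real method then calls add_push_rule):

import copy
from typing import Any, Dict

def rewrite_rule_for_room(rule: Dict[str, Any], new_room_id: str) -> Dict[str, Any]:
    new_rule = copy.deepcopy(rule)
    # Keep the scope ("global/room", ...) and swap the final path component,
    # which for room-specific rules is the room ID.
    rule_id_scope = "/".join(new_rule["rule_id"].split("/")[:-1])
    new_rule["rule_id"] = rule_id_scope + "/" + new_room_id
    # Repoint any room_id condition at the new room.
    for condition in new_rule.get("conditions", []):
        if condition.get("key") == "room_id":
            condition["pattern"] = new_room_id
    return new_rule

old_rule = {
    "rule_id": "global/room/!old:example.org",
    "priority_class": 1,
    "conditions": [{"kind": "event_match", "key": "room_id", "pattern": "!old:example.org"}],
    "actions": ["dont_notify"],
}
print(rewrite_rule_for_room(old_rule, "!new:example.org")["rule_id"])
# global/room/!new:example.org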
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 91286c9b..bd0cfa7f 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -91,12 +91,6 @@ class PusherWorkerStore(SQLBaseStore):
yield PusherConfig(**r)
- async def user_has_pusher(self, user_id: str) -> bool:
- ret = await self.db_pool.simple_select_one_onecol(
- "pushers", {"user_name": user_id}, "id", allow_none=True
- )
- return ret is not None
-
async def get_pushers_by_app_id_and_pushkey(
self, app_id: str, pushkey: str
) -> Iterator[PusherConfig]:
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index d035969a..21e954cc 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -26,7 +26,7 @@ from typing import (
cast,
)
-from synapse.api.constants import ReceiptTypes
+from synapse.api.constants import EduTypes, ReceiptTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -363,7 +363,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
row["user_id"]
] = db_to_json(row["data"])
- return [{"type": "m.receipt", "room_id": room_id, "content": content}]
+ return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}]
@cachedList(
cached_method_name="_get_linearized_receipts_for_room",
@@ -411,7 +411,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
# receipts by room, event and type.
room_event = results.setdefault(
row["room_id"],
- {"type": "m.receipt", "room_id": row["room_id"], "content": {}},
+ {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
)
# The content is of the form:
@@ -476,7 +476,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
# receipts by room, event and type.
room_event = results.setdefault(
row["room_id"],
- {"type": "m.receipt", "room_id": row["room_id"], "content": {}},
+ {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
)
# The content is of the form:
@@ -597,7 +597,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return super().process_replication_rows(stream_name, instance_name, token, rows)
- def insert_linearized_receipt_txn(
+ def _insert_linearized_receipt_txn(
self,
txn: LoggingTransaction,
room_id: str,
@@ -673,8 +673,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
lock=False,
)
+ # When updating a local users read receipt, remove any push actions
+ # which resulted from the receipt's event and all earlier events.
if (
- receipt_type in (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE)
+ self.hs.is_mine_id(user_id)
+ and receipt_type in (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE)
and stream_ordering is not None
):
self._remove_old_push_actions_before_txn( # type: ignore[attr-defined]
@@ -683,6 +686,44 @@ class ReceiptsWorkerStore(SQLBaseStore):
return rx_ts
+ def _graph_to_linear(
+ self, txn: LoggingTransaction, room_id: str, event_ids: List[str]
+ ) -> str:
+ """
+ Generate a linearized event from a list of events (i.e. a list of forward
+ extremities in the room).
+
+ This should allow for calculation of the correct read receipt even if
+ servers have different event ordering.
+
+ Args:
+ txn: The transaction
+ room_id: The room ID the events are in.
+ event_ids: The list of event IDs to linearize.
+
+ Returns:
+ The linearized event ID.
+ """
+ # TODO: Make this better.
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "event_id", event_ids
+ )
+
+ sql = """
+ SELECT event_id WHERE room_id = ? AND stream_ordering IN (
+ SELECT max(stream_ordering) WHERE %s
+ )
+ """ % (
+ clause,
+ )
+
+ txn.execute(sql, [room_id] + list(args))
+ rows = txn.fetchall()
+ if rows:
+ return rows[0][0]
+ else:
+ raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
+
async def insert_receipt(
self,
room_id: str,
@@ -709,35 +750,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
linearized_event_id = event_ids[0]
else:
# we need to map points in graph -> linearized form.
- # TODO: Make this better.
- def graph_to_linear(txn: LoggingTransaction) -> str:
- clause, args = make_in_list_sql_clause(
- self.database_engine, "event_id", event_ids
- )
-
- sql = """
- SELECT event_id WHERE room_id = ? AND stream_ordering IN (
- SELECT max(stream_ordering) WHERE %s
- )
- """ % (
- clause,
- )
-
- txn.execute(sql, [room_id] + list(args))
- rows = txn.fetchall()
- if rows:
- return rows[0][0]
- else:
- raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
-
linearized_event_id = await self.db_pool.runInteraction(
- "insert_receipt_conv", graph_to_linear
+ "insert_receipt_conv", self._graph_to_linear, room_id, event_ids
)
async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
- self.insert_linearized_receipt_txn,
+ self._insert_linearized_receipt_txn,
room_id,
receipt_type,
user_id,
@@ -758,25 +778,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
now - event_ts,
)
- await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
-
- max_persisted_id = self._receipts_id_gen.get_current_token()
-
- return stream_id, max_persisted_id
-
- async def insert_graph_receipt(
- self,
- room_id: str,
- receipt_type: str,
- user_id: str,
- event_ids: List[str],
- data: JsonDict,
- ) -> None:
- assert self._can_write_to_receipts
-
await self.db_pool.runInteraction(
"insert_graph_receipt",
- self.insert_graph_receipt_txn,
+ self._insert_graph_receipt_txn,
room_id,
receipt_type,
user_id,
@@ -784,7 +788,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
data,
)
- def insert_graph_receipt_txn(
+ max_persisted_id = self._receipts_id_gen.get_current_token()
+
+ return stream_id, max_persisted_id
+
+ def _insert_graph_receipt_txn(
self,
txn: LoggingTransaction,
room_id: str,
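Editor's note: `_graph_to_linear` above collapses a set of forward extremities to a single event by picking the one with the greatest stream ordering, so read receipts land in a well-defined place even when servers saw the events in different orders. A sketch of that selection outside SQL, with invented event data:

from typing import Dict, List

def graph_to_linear(stream_orderings: Dict[str, int], event_ids: List[str]) -> str:
    # Pick the event among the given IDs with the highest stream ordering,
    # analogous to "stream_ordering IN (SELECT max(stream_ordering) ...)".
    known = [e for e in event_ids if e in stream_orderings]
    if not known:
        raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
    return max(known, key=lambda e: stream_orderings[e])

orderings = {"$a": 10, "$b": 17, "$c": 12}
print(graph_to_linear(orderings, ["$a", "$b"]))  # $b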
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 484976ca..b457bc18 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -34,7 +34,7 @@ from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
-from synapse.types import JsonDict, RoomStreamToken, StreamToken
+from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
logger = logging.getLogger(__name__)
@@ -161,7 +161,9 @@ class RelationsWorkerStore(SQLBaseStore):
if len(events) > limit and last_topo_id and last_stream_id:
next_key = RoomStreamToken(last_topo_id, last_stream_id)
if from_token:
- next_token = from_token.copy_and_replace("room_key", next_key)
+ next_token = from_token.copy_and_replace(
+ StreamKeyType.ROOM, next_key
+ )
else:
next_token = StreamToken(
room_key=next_key,
@@ -765,6 +767,59 @@ class RelationsWorkerStore(SQLBaseStore):
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
+ @cached(iterable=True)
+ async def get_mutual_event_relations_for_rel_type(
+ self, event_id: str, relation_type: str
+ ) -> Set[Tuple[str, str]]:
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="get_mutual_event_relations_for_rel_type",
+ list_name="relation_types",
+ )
+ async def get_mutual_event_relations(
+ self, event_id: str, relation_types: Collection[str]
+ ) -> Dict[str, Set[Tuple[str, str]]]:
+ """
+ Fetch event metadata for events which related to the same event as the given event.
+
+ If the given event has no relation information, returns an empty dictionary.
+
+ Args:
+ event_id: The event ID which is targeted by relations.
+ relation_types: The relation types to check for mutual relations.
+
+ Returns:
+ A dictionary of relation type to:
+ A set of tuples of:
+ The sender
+ The event type
+ """
+ rel_type_sql, rel_type_args = make_in_list_sql_clause(
+ self.database_engine, "relation_type", relation_types
+ )
+
+ sql = f"""
+ SELECT DISTINCT relation_type, sender, type FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE relates_to_id = ? AND {rel_type_sql}
+ """
+
+ def _get_event_relations(
+ txn: LoggingTransaction,
+ ) -> Dict[str, Set[Tuple[str, str]]]:
+ txn.execute(sql, [event_id] + rel_type_args)
+ result: Dict[str, Set[Tuple[str, str]]] = {
+ rel_type: set() for rel_type in relation_types
+ }
+ for rel_type, sender, type in txn.fetchall():
+ result[rel_type].add((sender, type))
+ return result
+
+ return await self.db_pool.runInteraction(
+ "get_event_relations", _get_event_relations
+ )
+
class RelationsStore(RelationsWorkerStore):
pass
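Editor's note: `get_mutual_event_relations` above groups (sender, event type) pairs by relation type for everything relating to a given event, seeding an empty set for every requested type. The grouping itself is just a dict of sets; a small sketch over an in-memory list of rows (row values are made up):

from typing import Dict, Iterable, Set, Tuple

def group_relations(
    rows: Iterable[Tuple[str, str, str]], relation_types: Iterable[str]
) -> Dict[str, Set[Tuple[str, str]]]:
    # rows are (relation_type, sender, event_type), as returned by the SQL in
    # the diff; every requested relation type gets at least an empty set.
    result: Dict[str, Set[Tuple[str, str]]] = {rt: set() for rt in relation_types}
    for rel_type, sender, event_type in rows:
        result[rel_type].add((sender, event_type))
    return result

rows = [
    ("m.annotation", "@alice:example.org", "m.reaction"),
    ("m.thread", "@bob:example.org", "m.room.message"),
]
print(group_relations(rows, ["m.annotation", "m.thread", "m.replace"]))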
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 87e9482c..68d4fc2e 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -23,6 +23,7 @@ from typing import (
Collection,
Dict,
List,
+ Mapping,
Optional,
Tuple,
Union,
@@ -45,7 +46,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import IdGenerator
-from synapse.types import JsonDict, ThirdPartyInstanceID
+from synapse.types import JsonDict, RetentionPolicy, ThirdPartyInstanceID
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.stringutils import MXC_REGEX
@@ -233,24 +234,23 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
UNION SELECT room_id from appservice_room_list
"""
- sql = """
+ sql = f"""
SELECT
COUNT(*)
FROM (
- %(published_sql)s
+ {published_sql}
) published
INNER JOIN room_stats_state USING (room_id)
INNER JOIN room_stats_current USING (room_id)
WHERE
(
- join_rules = 'public' OR join_rules = '%(knock_join_rule)s'
+ join_rules = '{JoinRules.PUBLIC}'
+ OR join_rules = '{JoinRules.KNOCK}'
+ OR join_rules = '{JoinRules.KNOCK_RESTRICTED}'
OR history_visibility = 'world_readable'
)
AND joined_members > 0
- """ % {
- "published_sql": published_sql,
- "knock_join_rule": JoinRules.KNOCK,
- }
+ """
txn.execute(sql, query_args)
return cast(Tuple[int], txn.fetchone())[0]
@@ -369,29 +369,29 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
if where_clauses:
where_clause = " AND " + " AND ".join(where_clauses)
- sql = """
+ dir = "DESC" if forwards else "ASC"
+ sql = f"""
SELECT
room_id, name, topic, canonical_alias, joined_members,
avatar, history_visibility, guest_access, join_rules
FROM (
- %(published_sql)s
+ {published_sql}
) published
INNER JOIN room_stats_state USING (room_id)
INNER JOIN room_stats_current USING (room_id)
WHERE
(
- join_rules = 'public' OR join_rules = '%(knock_join_rule)s'
+ join_rules = '{JoinRules.PUBLIC}'
+ OR join_rules = '{JoinRules.KNOCK}'
+ OR join_rules = '{JoinRules.KNOCK_RESTRICTED}'
OR history_visibility = 'world_readable'
)
AND joined_members > 0
- %(where_clause)s
- ORDER BY joined_members %(dir)s, room_id %(dir)s
- """ % {
- "published_sql": published_sql,
- "where_clause": where_clause,
- "dir": "DESC" if forwards else "ASC",
- "knock_join_rule": JoinRules.KNOCK,
- }
+ {where_clause}
+ ORDER BY
+ joined_members {dir},
+ room_id {dir}
+ """
if limit is not None:
query_args.append(limit)
@@ -699,7 +699,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn)
@cached()
- async def get_retention_policy_for_room(self, room_id: str) -> Dict[str, int]:
+ async def get_retention_policy_for_room(self, room_id: str) -> RetentionPolicy:
"""Get the retention policy for a given room.
If no retention policy has been found for this room, returns a policy defined
@@ -707,12 +707,20 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
the 'max_lifetime' if no default policy has been defined in the server's
configuration).
+ If support for retention policies is disabled, a policy with a 'min_lifetime' and
+ 'max_lifetime' of None is returned.
+
Args:
room_id: The ID of the room to get the retention policy of.
Returns:
A dict containing "min_lifetime" and "max_lifetime" for this room.
"""
+ # If the room retention feature is disabled, return a policy with no minimum nor
+ # maximum. This prevents incorrectly filtering out events when sending to
+ # the client.
+ if not self.config.retention.retention_enabled:
+ return RetentionPolicy()
def get_retention_policy_for_room_txn(
txn: LoggingTransaction,
@@ -736,10 +744,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
# If we don't know this room ID, ret will be None, in this case return the default
# policy.
if not ret:
- return {
- "min_lifetime": self.config.retention.retention_default_min_lifetime,
- "max_lifetime": self.config.retention.retention_default_max_lifetime,
- }
+ return RetentionPolicy(
+ min_lifetime=self.config.retention.retention_default_min_lifetime,
+ max_lifetime=self.config.retention.retention_default_max_lifetime,
+ )
min_lifetime = ret[0]["min_lifetime"]
max_lifetime = ret[0]["max_lifetime"]
@@ -754,10 +762,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
if max_lifetime is None:
max_lifetime = self.config.retention.retention_default_max_lifetime
- return {
- "min_lifetime": min_lifetime,
- "max_lifetime": max_lifetime,
- }
+ return RetentionPolicy(
+ min_lifetime=min_lifetime,
+ max_lifetime=max_lifetime,
+ )
async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
"""Retrieves all the local and remote media MXC URIs in a given room
@@ -994,7 +1002,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
async def get_rooms_for_retention_period_in_range(
self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
- ) -> Dict[str, Dict[str, Optional[int]]]:
+ ) -> Dict[str, RetentionPolicy]:
"""Retrieves all of the rooms within the given retention range.
Optionally includes the rooms which don't have a retention policy.
@@ -1016,7 +1024,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def get_rooms_for_retention_period_in_range_txn(
txn: LoggingTransaction,
- ) -> Dict[str, Dict[str, Optional[int]]]:
+ ) -> Dict[str, RetentionPolicy]:
range_conditions = []
args = []
@@ -1047,10 +1055,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
rooms_dict = {}
for row in rows:
- rooms_dict[row["room_id"]] = {
- "min_lifetime": row["min_lifetime"],
- "max_lifetime": row["max_lifetime"],
- }
+ rooms_dict[row["room_id"]] = RetentionPolicy(
+ min_lifetime=row["min_lifetime"],
+ max_lifetime=row["max_lifetime"],
+ )
if include_null:
# If required, do a second query that retrieves all of the rooms we know
@@ -1065,10 +1073,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
# policy in its state), add it with a null policy.
for row in rows:
if row["room_id"] not in rooms_dict:
- rooms_dict[row["room_id"]] = {
- "min_lifetime": None,
- "max_lifetime": None,
- }
+ rooms_dict[row["room_id"]] = RetentionPolicy()
return rooms_dict
@@ -1077,6 +1082,32 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
get_rooms_for_retention_period_in_range_txn,
)
+ async def get_partial_state_rooms_and_servers(
+ self,
+ ) -> Mapping[str, Collection[str]]:
+ """Get all rooms containing events with partial state, and the servers known
+ to be in the room.
+
+ Returns:
+ A dictionary of rooms with partial state, with room IDs as keys and
+ lists of servers in rooms as values.
+ """
+ room_servers: Dict[str, List[str]] = {}
+
+ rows = await self.db_pool.simple_select_list(
+ "partial_state_rooms_servers",
+ keyvalues=None,
+ retcols=("room_id", "server_name"),
+ desc="get_partial_state_rooms",
+ )
+
+ for row in rows:
+ room_id = row["room_id"]
+ server_name = row["server_name"]
+ room_servers.setdefault(room_id, []).append(server_name)
+
+ return room_servers
+
async def clear_partial_state_room(self, room_id: str) -> bool:
# this can race with incoming events, so we watch out for FK errors.
# TODO(faster_joins): this still doesn't completely fix the race, since the persist process
@@ -1108,6 +1139,24 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
keyvalues={"room_id": room_id},
)
+ async def is_partial_state_room(self, room_id: str) -> bool:
+ """Checks if this room has partial state.
+
+ Returns true if this is a "partial-state" room, which means that the state
+ at events in the room, and `current_state_events`, may not yet be
+ complete.
+ """
+
+ entry = await self.db_pool.simple_select_one_onecol(
+ table="partial_state_rooms",
+ keyvalues={"room_id": room_id},
+ retcol="room_id",
+ allow_none=True,
+ desc="is_partial_state_room",
+ )
+
+ return entry is not None
+
class _BackgroundUpdates:
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
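Editor's note: the retention changes above replace ad-hoc dicts with a `RetentionPolicy` value and fall back to the server-wide defaults whenever a room policy omits a bound, or skip filtering entirely when retention is disabled. A sketch of that fallback using a stand-in dataclass; the real `RetentionPolicy` lives in synapse.types and the default values here are invented.

from dataclasses import dataclass
from typing import Optional

@dataclass
class RetentionPolicy:
    min_lifetime: Optional[int] = None
    max_lifetime: Optional[int] = None

def effective_policy(
    room_policy: Optional[RetentionPolicy],
    retention_enabled: bool,
    default_min: Optional[int],
    default_max: Optional[int],
) -> RetentionPolicy:
    if not retention_enabled:
        # Feature off: no minimum or maximum, so nothing gets filtered out.
        return RetentionPolicy()
    if room_policy is None:
        # Unknown room: use the configured server defaults.
        return RetentionPolicy(min_lifetime=default_min, max_lifetime=default_max)
    return RetentionPolicy(
        min_lifetime=room_policy.min_lifetime if room_policy.min_lifetime is not None else default_min,
        max_lifetime=room_policy.max_lifetime if room_policy.max_lifetime is not None else default_max,
    )

print(effective_policy(RetentionPolicy(max_lifetime=86400000), True, None, 2592000000))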
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 48e83592..31bc8c56 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -15,6 +15,7 @@
import logging
from typing import (
TYPE_CHECKING,
+ Callable,
Collection,
Dict,
FrozenSet,
@@ -37,7 +38,12 @@ from synapse.metrics.background_process_metrics import (
wrap_as_background_process,
)
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import Sqlite3Engine
from synapse.storage.roommember import (
@@ -46,7 +52,7 @@ from synapse.storage.roommember import (
ProfileInfo,
RoomsForUser,
)
-from synapse.types import PersistedEventPosition, get_domain_from_id
+from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -115,7 +121,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@wrap_as_background_process("_count_known_servers")
- async def _count_known_servers(self):
+ async def _count_known_servers(self) -> int:
"""
Count the servers that this server knows about.
@@ -123,7 +129,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
`synapse_federation_known_servers` LaterGauge to collect.
"""
- def _transact(txn):
+ def _transact(txn: LoggingTransaction) -> int:
if isinstance(self.database_engine, Sqlite3Engine):
query = """
SELECT COUNT(DISTINCT substr(out.user_id, pos+1))
@@ -150,7 +156,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
self._known_servers_count = max([count, 1])
return self._known_servers_count
- def _check_safe_current_state_events_membership_updated_txn(self, txn):
+ def _check_safe_current_state_events_membership_updated_txn(
+ self, txn: LoggingTransaction
+ ) -> None:
"""Checks if it is safe to assume the new current_state_events
membership column is up to date
"""
@@ -182,7 +190,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"get_users_in_room", self.get_users_in_room_txn, room_id
)
- def get_users_in_room_txn(self, txn, room_id: str) -> List[str]:
+ def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
# If we can assume current_state_events.membership is up to date
# then we can avoid a join, which is a Very Good Thing given how
# frequently this function gets called.
@@ -222,7 +230,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
A mapping from user ID to ProfileInfo.
"""
- def _get_users_in_room_with_profiles(txn) -> Dict[str, ProfileInfo]:
+ def _get_users_in_room_with_profiles(
+ txn: LoggingTransaction,
+ ) -> Dict[str, ProfileInfo]:
sql = """
SELECT state_key, display_name, avatar_url FROM room_memberships as m
INNER JOIN current_state_events as c
@@ -250,7 +260,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
dict of membership states, pointing to a MemberSummary named tuple.
"""
- def _get_room_summary_txn(txn):
+ def _get_room_summary_txn(
+ txn: LoggingTransaction,
+ ) -> Dict[str, MemberSummary]:
# first get counts.
# We do this all in one transaction to keep the cache small.
# FIXME: get rid of this when we have room_stats
@@ -279,7 +291,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""
txn.execute(sql, (room_id,))
- res = {}
+ res: Dict[str, MemberSummary] = {}
for count, membership in txn:
res.setdefault(membership, MemberSummary([], count))
@@ -400,7 +412,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
def _get_rooms_for_local_user_where_membership_is_txn(
self,
- txn,
+ txn: LoggingTransaction,
user_id: str,
membership_list: List[str],
) -> List[RoomsForUser]:
@@ -488,7 +500,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
def _get_rooms_for_user_with_stream_ordering_txn(
- self, txn, user_id: str
+ self, txn: LoggingTransaction, user_id: str
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
# We use `current_state_events` here and not `local_current_membership`
# as a) this gets called with remote users and b) this only gets called
@@ -542,7 +554,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
def _get_rooms_for_users_with_stream_ordering_txn(
- self, txn, user_ids: Collection[str]
+ self, txn: LoggingTransaction, user_ids: Collection[str]
) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
clause, args = make_in_list_sql_clause(
@@ -575,7 +587,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, [Membership.JOIN] + args)
- result = {user_id: set() for user_id in user_ids}
+ result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = {
+ user_id: set() for user_id in user_ids
+ }
for user_id, room_id, instance, stream_id in txn:
result[user_id].add(
GetRoomsForUserWithStreamOrdering(
@@ -595,7 +609,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
if not user_ids:
return set()
- def _get_users_server_still_shares_room_with_txn(txn):
+ def _get_users_server_still_shares_room_with_txn(
+ txn: LoggingTransaction,
+ ) -> Set[str]:
sql = """
SELECT state_key FROM current_state_events
WHERE
@@ -619,7 +635,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
async def get_rooms_for_user(
- self, user_id: str, on_invalidate=None
+ self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None
) -> FrozenSet[str]:
"""Returns a set of room_ids the user is currently joined to.
@@ -654,10 +670,34 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return user_who_share_room
+ @cached(cache_context=True, iterable=True)
+ async def get_mutual_rooms_between_users(
+ self, user_ids: FrozenSet[str], cache_context: _CacheContext
+ ) -> FrozenSet[str]:
+ """
+ Returns the set of rooms that all users in `user_ids` share.
+
+ Args:
+ user_ids: A frozen set of all users to investigate and return
+ overlapping joined rooms for.
+ cache_context
+ """
+ shared_room_ids: Optional[FrozenSet[str]] = None
+ for user_id in user_ids:
+ room_ids = await self.get_rooms_for_user(
+ user_id, on_invalidate=cache_context.invalidate
+ )
+ if shared_room_ids is not None:
+ shared_room_ids &= room_ids
+ else:
+ shared_room_ids = room_ids
+
+ return shared_room_ids or frozenset()
+
async def get_joined_users_from_context(
self, event: EventBase, context: EventContext
) -> Dict[str, ProfileInfo]:
- state_group = context.state_group
+ state_group: Union[object, int] = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@@ -666,14 +706,16 @@ class RoomMemberWorkerStore(EventsWorkerStore):
state_group = object()
current_state_ids = await context.get_current_state_ids()
+ assert current_state_ids is not None
+ assert state_group is not None
return await self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids, event=event, context=context
)
async def get_joined_users_from_state(
- self, room_id, state_entry
+ self, room_id: str, state_entry: "_StateCacheEntry"
) -> Dict[str, ProfileInfo]:
- state_group = state_entry.state_group
+ state_group: Union[object, int] = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@@ -681,6 +723,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
+ assert state_group is not None
with Measure(self._clock, "get_joined_users_from_state"):
return await self._get_joined_users_from_context(
room_id, state_group, state_entry.state, context=state_entry
@@ -689,12 +732,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
async def _get_joined_users_from_context(
self,
- room_id,
- state_group,
- current_state_ids,
- cache_context,
- event=None,
- context=None,
+ room_id: str,
+ state_group: Union[object, int],
+ current_state_ids: StateMap[str],
+ cache_context: _CacheContext,
+ event: Optional[EventBase] = None,
+ context: Optional[Union[EventContext, "_StateCacheEntry"]] = None,
) -> Dict[str, ProfileInfo]:
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
@@ -765,14 +808,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return users_in_room
@cached(max_entries=10000)
- def _get_joined_profile_from_event_id(self, event_id):
+ def _get_joined_profile_from_event_id(
+ self, event_id: str
+ ) -> Optional[Tuple[str, ProfileInfo]]:
raise NotImplementedError()
@cachedList(
cached_method_name="_get_joined_profile_from_event_id",
list_name="event_ids",
)
- async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
+ async def _get_joined_profiles_from_event_ids(
+ self, event_ids: Iterable[str]
+ ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]:
"""For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info.
@@ -780,8 +827,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
event_ids: The member event IDs to lookup
Returns:
- dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
- to `user_id` and ProfileInfo (or None if not join event).
+ Map from event ID to `user_id` and ProfileInfo (or None if not join event).
"""
rows = await self.db_pool.simple_select_many_batch(
@@ -847,8 +893,47 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True
- async def get_joined_hosts(self, room_id: str, state_entry):
- state_group = state_entry.state_group
+ @cached(iterable=True, max_entries=10000)
+ async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+ """Get current hosts in room based on current state."""
+
+ # First we check if we already have `get_users_in_room` in the cache, as
+ # we can just calculate result from that
+ users = self.get_users_in_room.cache.get_immediate(
+ (room_id,), None, update_metrics=False
+ )
+ if users is not None:
+ return {get_domain_from_id(u) for u in users}
+
+ if isinstance(self.database_engine, Sqlite3Engine):
+ # If we're using SQLite then let's just always use
+ # `get_users_in_room` rather than funky SQL.
+ users = await self.get_users_in_room(room_id)
+ return {get_domain_from_id(u) for u in users}
+
+ # For PostgreSQL we can use a regex to pull the domains out of the
+ # joined users in `current_state_events` directly in SQL.
+
+ def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
+ sql = """
+ SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
+ FROM current_state_events
+ WHERE
+ type = 'm.room.member'
+ AND membership = 'join'
+ AND room_id = ?
+ """
+ txn.execute(sql, (room_id,))
+ return {d for d, in txn}
+
+ return await self.db_pool.runInteraction(
+ "get_current_hosts_in_room", get_current_hosts_in_room_txn
+ )
+
+ async def get_joined_hosts(
+ self, room_id: str, state_entry: "_StateCacheEntry"
+ ) -> FrozenSet[str]:
+ state_group: Union[object, int] = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@@ -856,6 +941,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
+ assert state_group is not None
with Measure(self._clock, "get_joined_hosts"):
return await self._get_joined_hosts(
room_id, state_group, state_entry=state_entry
@@ -863,7 +949,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(num_args=2, max_entries=10000, iterable=True)
async def _get_joined_hosts(
- self, room_id: str, state_group: int, state_entry: "_StateCacheEntry"
+ self,
+ room_id: str,
+ state_group: Union[object, int],
+ state_entry: "_StateCacheEntry",
) -> FrozenSet[str]:
# We don't use `state_group`, it's there so that we can cache based on
# it. However, it's important that it's never None, since two
@@ -881,7 +970,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# `get_joined_hosts` is called with the "current" state group for the
# room, and so consecutive calls will be for consecutive state groups
# which point to the previous state group.
- cache = await self._get_joined_hosts_cache(room_id)
+ cache = await self._get_joined_hosts_cache(room_id) # type: ignore[misc]
# If the state group in the cache matches, we already have the data we need.
if state_entry.state_group == cache.state_group:
@@ -897,6 +986,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
elif state_entry.prev_group == cache.state_group:
# The cached work is for the previous state group, so we work out
# the delta.
+ assert state_entry.delta_ids is not None
for (typ, state_key), event_id in state_entry.delta_ids.items():
if typ != EventTypes.Member:
continue
@@ -942,7 +1032,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
Returns False if they have since re-joined."""
- def f(txn):
+ def f(txn: LoggingTransaction) -> int:
sql = (
"SELECT"
" COUNT(*)"
@@ -973,7 +1063,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
The forgotten rooms.
"""
- def _get_forgotten_rooms_for_user_txn(txn):
+ def _get_forgotten_rooms_for_user_txn(txn: LoggingTransaction) -> Set[str]:
# This is a slightly convoluted query that first looks up all rooms
# that the user has forgotten in the past, then rechecks that list
# to see if any have subsequently been updated. This is done so that
@@ -1076,7 +1166,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
clause,
)
- def _is_local_host_in_room_ignoring_users_txn(txn):
+ def _is_local_host_in_room_ignoring_users_txn(
+ txn: LoggingTransaction,
+ ) -> bool:
txn.execute(sql, (room_id, Membership.JOIN, *args))
return bool(txn.fetchone())
@@ -1110,15 +1202,17 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
where_clause="forgotten = 1",
)
- async def _background_add_membership_profile(self, progress, batch_size):
+ async def _background_add_membership_profile(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
target_min_stream_id = progress.get(
- "target_min_stream_id_inclusive", self._min_stream_order_on_start
+ "target_min_stream_id_inclusive", self._min_stream_order_on_start # type: ignore[attr-defined]
)
max_stream_id = progress.get(
- "max_stream_id_exclusive", self._stream_order_on_start + 1
+ "max_stream_id_exclusive", self._stream_order_on_start + 1 # type: ignore[attr-defined]
)
- def add_membership_profile_txn(txn):
+ def add_membership_profile_txn(txn: LoggingTransaction) -> int:
sql = """
SELECT stream_ordering, event_id, events.room_id, event_json.json
FROM events
@@ -1182,13 +1276,17 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
return result
- async def _background_current_state_membership(self, progress, batch_size):
+ async def _background_current_state_membership(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""Update the new membership column on current_state_events.
This works by iterating over all rooms in alphabetical order.
"""
- def _background_current_state_membership_txn(txn, last_processed_room):
+ def _background_current_state_membership_txn(
+ txn: LoggingTransaction, last_processed_room: str
+ ) -> Tuple[int, bool]:
processed = 0
while processed < batch_size:
txn.execute(
@@ -1242,7 +1340,11 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
return row_count
-class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
+class RoomMemberStore(
+ RoomMemberWorkerStore,
+ RoomMemberBackgroundUpdateStore,
+ CacheInvalidationWorkerStore,
+):
def __init__(
self,
database: DatabasePool,
@@ -1254,7 +1356,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
async def forget(self, user_id: str, room_id: str) -> None:
"""Indicate that user_id wishes to discard history for room_id."""
- def f(txn):
+ def f(txn: LoggingTransaction) -> None:
sql = (
"UPDATE"
" room_memberships"
@@ -1288,5 +1390,5 @@ class _JoinedHostsCache:
# equal to anything else).
state_group: Union[object, int] = attr.Factory(object)
- def __len__(self):
+ def __len__(self) -> int:
return sum(len(v) for v in self.hosts_to_joined_users.values())
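The new `get_mutual_rooms_between_users` above computes shared rooms by intersecting the (cached) per-user joined-room sets instead of querying the user directory tables. A minimal standalone sketch of that intersection pattern, using a hypothetical in-memory lookup in place of Synapse's cached `get_rooms_for_user`:

import asyncio
from typing import Dict, FrozenSet, Iterable, Optional

# Hypothetical fixture standing in for the room membership store.
ROOMS_BY_USER: Dict[str, FrozenSet[str]] = {
    "@alice:example.org": frozenset({"!a:x", "!b:x", "!c:x"}),
    "@bob:example.org": frozenset({"!b:x", "!c:x"}),
    "@carol:example.org": frozenset({"!c:x", "!d:x"}),
}

async def get_rooms_for_user(user_id: str) -> FrozenSet[str]:
    # Stand-in for RoomMemberWorkerStore.get_rooms_for_user.
    return ROOMS_BY_USER.get(user_id, frozenset())

async def mutual_rooms(user_ids: Iterable[str]) -> FrozenSet[str]:
    shared: Optional[FrozenSet[str]] = None
    for user_id in user_ids:
        rooms = await get_rooms_for_user(user_id)
        shared = rooms if shared is None else shared & rooms
    return shared or frozenset()

print(asyncio.run(mutual_rooms(["@alice:example.org", "@bob:example.org", "@carol:example.org"])))
# -> frozenset({'!c:x'})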
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 3c49e7ec..78e0773b 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -14,7 +14,7 @@
import logging
import re
-from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple
import attr
@@ -27,7 +27,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict
if TYPE_CHECKING:
@@ -149,7 +149,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings
)
- async def _background_reindex_search(self, progress, batch_size):
+ async def _background_reindex_search(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
# we work through the events table from highest stream id to lowest
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
@@ -157,7 +159,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
- def reindex_search_txn(txn):
+ def reindex_search_txn(txn: LoggingTransaction) -> int:
sql = (
"SELECT stream_ordering, event_id, room_id, type, json, "
" origin_server_ts FROM events"
@@ -255,12 +257,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
return result
- async def _background_reindex_gin_search(self, progress, batch_size):
+ async def _background_reindex_gin_search(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""This handles old synapses which used GIST indexes, if any;
converting them back to be GIN as per the actual schema.
"""
- def create_index(conn):
+ def create_index(conn: LoggingDatabaseConnection) -> None:
conn.rollback()
# we have to set autocommit, because postgres refuses to
@@ -299,7 +303,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
)
return 1
- async def _background_reindex_search_order(self, progress, batch_size):
+ async def _background_reindex_search_order(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
@@ -307,7 +313,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
if not have_added_index:
- def create_index(conn):
+ def create_index(conn: LoggingDatabaseConnection) -> None:
conn.rollback()
conn.set_session(autocommit=True)
c = conn.cursor()
@@ -336,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
pg,
)
- def reindex_search_txn(txn):
+ def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]:
sql = (
"UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
" origin_server_ts = e.origin_server_ts"
@@ -644,7 +650,8 @@ class SearchStore(SearchBackgroundUpdateStore):
else:
raise Exception("Unrecognized database engine")
- args.append(limit)
+ # mypy expects to append only a `str`, not an `int`
+ args.append(limit) # type: ignore[arg-type]
results = await self.db_pool.execute(
"search_rooms", self.db_pool.cursor_to_dict, sql, *args
@@ -705,7 +712,7 @@ class SearchStore(SearchBackgroundUpdateStore):
A set of strings.
"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> Set[str]:
highlight_words = set()
for event in events:
# As a hack we simply join values of all possible keys. This is
@@ -759,11 +766,11 @@ class SearchStore(SearchBackgroundUpdateStore):
return await self.db_pool.runInteraction("_find_highlights", f)
-def _to_postgres_options(options_dict):
+def _to_postgres_options(options_dict: JsonDict) -> str:
return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
-def _parse_query(database_engine, search_term):
+def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to the database.
We use this so that we can add prefix matching, which isn't something
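`_parse_query` above rewrites the user's search term so the database can do prefix matching. A rough illustrative sketch of that idea for the PostgreSQL tsquery path; this is an approximation for intuition, not necessarily the exact transformation Synapse applies:

import re

def to_prefix_tsquery(search_term: str) -> str:
    # Keep word characters only and mark each token as a prefix match,
    # joined with AND; purely illustrative.
    tokens = [re.sub(r"\W+", "", t) for t in search_term.split()]
    return " & ".join(f"{t}:*" for t in tokens if t)

print(to_prefix_tsquery("matrix federa"))  # -> matrix:* & federa:*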
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 18ae8aee..bdd00273 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -16,6 +16,8 @@ import collections.abc
import logging
from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
+import attr
+
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@@ -26,6 +28,7 @@ from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
+ make_in_list_sql_clause,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
@@ -33,6 +36,7 @@ from synapse.storage.state import StateFilter
from synapse.types import JsonDict, JsonMapping, StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -43,6 +47,16 @@ logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class EventMetadata:
+ """Returned by `get_metadata_for_events`"""
+
+ room_id: str
+ event_type: str
+ state_key: Optional[str]
+ rejection_reason: Optional[str]
+
+
def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion:
v = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not v:
@@ -133,6 +147,57 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return room_version
+ async def get_metadata_for_events(
+ self, event_ids: Collection[str]
+ ) -> Dict[str, EventMetadata]:
+ """Get some metadata (room_id, type, state_key) for the given events.
+
+ This method is a faster alternative to fetching the full events from
+ the DB, and should be used when the full event is not needed.
+
+ Returns metadata for rejected and redacted events. Events that have not
+ been persisted are omitted from the returned dict.
+ """
+
+ def get_metadata_for_events_txn(
+ txn: LoggingTransaction,
+ batch_ids: Collection[str],
+ ) -> Dict[str, EventMetadata]:
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "e.event_id", batch_ids
+ )
+
+ sql = f"""
+ SELECT e.event_id, e.room_id, e.type, se.state_key, r.reason
+ FROM events AS e
+ LEFT JOIN state_events se USING (event_id)
+ LEFT JOIN rejections r USING (event_id)
+ WHERE {clause}
+ """
+
+ txn.execute(sql, args)
+ return {
+ event_id: EventMetadata(
+ room_id=room_id,
+ event_type=event_type,
+ state_key=state_key,
+ rejection_reason=rejection_reason,
+ )
+ for event_id, room_id, event_type, state_key, rejection_reason in txn
+ }
+
+ result_map: Dict[str, EventMetadata] = {}
+ for batch_ids in batch_iter(event_ids, 1000):
+ result_map.update(
+ await self.db_pool.runInteraction(
+ "get_metadata_for_events",
+ get_metadata_for_events_txn,
+ batch_ids=batch_ids,
+ )
+ )
+
+ return result_map
+
async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
"""Get the predecessor of an upgraded room if it exists.
Otherwise return None.
@@ -177,7 +242,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Raises:
NotFoundError if the room is unknown
"""
- state_ids = await self.get_current_state_ids(room_id)
+ state_ids = await self.get_partial_current_state_ids(room_id)
if not state_ids:
raise NotFoundError(f"Current state for room {room_id} is empty")
@@ -193,10 +258,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return create_event
@cached(max_entries=100000, iterable=True)
- async def get_current_state_ids(self, room_id: str) -> StateMap[str]:
+ async def get_partial_current_state_ids(self, room_id: str) -> StateMap[str]:
"""Get the current state event ids for a room based on the
current_state_events table.
+ This may be the partial state if we're lazy joining the room.
+
Args:
room_id: The room to get the state IDs of.
@@ -215,17 +282,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
return await self.db_pool.runInteraction(
- "get_current_state_ids", _get_current_state_ids_txn
+ "get_partial_current_state_ids", _get_current_state_ids_txn
)
# FIXME: how should this be cached?
- async def get_filtered_current_state_ids(
+ async def get_partial_filtered_current_state_ids(
self, room_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
"""Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result
of doing a fresh state resolution as per state_handler.get_current_state
+ This may be the partial state if we're lazy joining the room.
+
Args:
room_id
state_filter: The state filter used to fetch state
@@ -241,7 +310,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
if not where_clause:
# We delegate to the cached version
- return await self.get_current_state_ids(room_id)
+ return await self.get_partial_current_state_ids(room_id)
def _get_filtered_current_state_ids_txn(
txn: LoggingTransaction,
@@ -269,30 +338,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)
- async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
- """Get canonical alias for room, if any
-
- Args:
- room_id: The room ID
-
- Returns:
- The canonical alias, if any
- """
-
- state = await self.get_filtered_current_state_ids(
- room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
- )
-
- event_id = state.get((EventTypes.CanonicalAlias, ""))
- if not event_id:
- return None
-
- event = await self.get_event(event_id, allow_none=True)
- if not event:
- return None
-
- return event.content.get("canonical_alias")
-
@cached(max_entries=50000)
async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
return await self.db_pool.simple_select_one_onecol(
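The new `get_metadata_for_events` above fetches rows in batches of 1000 event IDs, building an IN-list clause per batch. A minimal sketch of that batching pattern against a plain sqlite3 connection, with a simplified local stand-in for `synapse.util.iterutils.batch_iter` and a made-up two-column `events` table:

import sqlite3
from typing import Collection, Dict, Iterable, Iterator, List, Tuple

def batch_iter(items: Iterable[str], size: int) -> Iterator[Tuple[str, ...]]:
    # Simplified stand-in for synapse.util.iterutils.batch_iter.
    batch: List[str] = []
    for item in items:
        batch.append(item)
        if len(batch) == size:
            yield tuple(batch)
            batch = []
    if batch:
        yield tuple(batch)

def fetch_room_ids(conn: sqlite3.Connection, event_ids: Collection[str]) -> Dict[str, str]:
    result: Dict[str, str] = {}
    for chunk in batch_iter(event_ids, 1000):
        placeholders = ", ".join("?" for _ in chunk)
        sql = f"SELECT event_id, room_id FROM events WHERE event_id IN ({placeholders})"
        for event_id, room_id in conn.execute(sql, chunk):
            result[event_id] = room_id
    return result

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (event_id TEXT PRIMARY KEY, room_id TEXT)")
conn.executemany("INSERT INTO events VALUES (?, ?)", [("$e1", "!r1"), ("$e2", "!r2")])
print(fetch_room_ids(conn, ["$e1", "$e2", "$missing"]))  # -> {'$e1': '!r1', '$e2': '!r2'}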
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 188afec3..445213e1 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -27,7 +27,7 @@ class StateDeltasStore(SQLBaseStore):
# attribute. TODO: can we get static analysis to enforce this?
_curr_state_delta_stream_cache: StreamChangeCache
- async def get_current_state_deltas(
+ async def get_partial_current_state_deltas(
self, prev_stream_id: int, max_stream_id: int
) -> Tuple[int, List[Dict[str, Any]]]:
"""Fetch a list of room state changes since the given stream id
@@ -42,6 +42,8 @@ class StateDeltasStore(SQLBaseStore):
- prev_event_id (str|None): previous event_id for this state key. None
if it's new state.
+ This may be the partial state if we're lazy joining the room.
+
Args:
prev_stream_id: point to get changes since (exclusive)
max_stream_id: the point that we know has been correctly persisted
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 0373af86..8e88784d 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -765,15 +765,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self,
room_id: str,
end_token: RoomStreamToken,
- ) -> Optional[EventBase]:
- """Returns the last event in a room at or before a stream ordering
+ ) -> Optional[str]:
+ """Returns the ID of the last event in a room at or before a stream ordering
Args:
room_id
end_token: The token used to stream from
Returns:
- The most recent event.
+ The ID of the most recent event, or None if there are no events in the room
+ before this stream ordering.
"""
last_row = await self.get_room_event_before_stream_ordering(
@@ -781,37 +782,28 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
stream_ordering=end_token.stream,
)
if last_row:
- _, _, event_id = last_row
- event = await self.get_event(event_id, get_prev_content=True)
- return event
-
+ return last_row[2]
return None
async def get_current_room_stream_token_for_room_id(
- self, room_id: Optional[str] = None
+ self, room_id: str
) -> RoomStreamToken:
- """Returns the current position of the rooms stream.
-
- By default, it returns a live token with the current global stream
- token. Specifying a `room_id` causes it to return a historic token with
- the room specific topological token.
- """
+ """Returns the current position of the rooms stream (historic token)."""
stream_ordering = self.get_room_max_stream_ordering()
- if room_id is None:
- return RoomStreamToken(None, stream_ordering)
- else:
- topo = await self.db_pool.runInteraction(
- "_get_max_topological_txn", self._get_max_topological_txn, room_id
- )
- return RoomStreamToken(topo, stream_ordering)
+ topo = await self.db_pool.runInteraction(
+ "_get_max_topological_txn", self._get_max_topological_txn, room_id
+ )
+ return RoomStreamToken(topo, stream_ordering)
def get_stream_id_for_event_txn(
self,
txn: LoggingTransaction,
event_id: str,
- allow_none=False,
- ) -> int:
- return self.db_pool.simple_select_one_onecol_txn(
+ allow_none: bool = False,
+ ) -> Optional[int]:
+ # Type ignore: we pass keyvalues a Dict[str, str]; the function wants
+ # Dict[str, Any]. I think mypy is unhappy because Dict is invariant?
+ return self.db_pool.simple_select_one_onecol_txn( # type: ignore[call-overload]
txn=txn,
table="events",
keyvalues={"event_id": event_id},
@@ -873,7 +865,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
rows = txn.fetchall()
- return rows[0][0] if rows else 0
+ # An aggregate function like MAX() will always return one row per group
+ # so we can safely rely on the lookup here. For example, when we
+ # look up a `room_id` which does not exist, `rows` will look like
+ # `[(None,)]`
+ return rows[0][0] if rows[0][0] is not None else 0
@staticmethod
def _set_before_and_after(
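The comment about MAX() above is easy to verify: an aggregate query always yields exactly one row, so a missing room shows up as `[(None,)]` rather than an empty result set. A quick standalone check with sqlite3 (table contents and room IDs are made up):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (room_id TEXT, topological_ordering INTEGER)")
conn.execute("INSERT INTO events VALUES ('!known:x', 7)")

rows = conn.execute(
    "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
    ("!unknown:x",),
).fetchall()
print(rows)  # -> [(None,)]
print(rows[0][0] if rows[0][0] is not None else 0)  # -> 0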
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 028db69a..ddb25b5c 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -441,7 +441,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
(EventTypes.RoomHistoryVisibility, ""),
)
- current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined]
+ # Getting the partial state is fine, as we're not looking at membership
+ # events.
+ current_state_ids = await self.get_partial_filtered_current_state_ids( # type: ignore[attr-defined]
room_id, StateFilter.from_types(types_to_filter)
)
@@ -729,49 +731,6 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
- async def get_mutual_rooms_for_users(
- self, user_id: str, other_user_id: str
- ) -> Set[str]:
- """
- Returns the rooms that a local user shares with another local or remote user.
-
- Args:
- user_id: The MXID of a local user
- other_user_id: The MXID of the other user
-
- Returns:
- A set of room ID's that the users share.
- """
-
- def _get_mutual_rooms_for_users_txn(
- txn: LoggingTransaction,
- ) -> List[Dict[str, str]]:
- txn.execute(
- """
- SELECT p1.room_id
- FROM users_in_public_rooms as p1
- INNER JOIN users_in_public_rooms as p2
- ON p1.room_id = p2.room_id
- AND p1.user_id = ?
- AND p2.user_id = ?
- UNION
- SELECT room_id
- FROM users_who_share_private_rooms
- WHERE
- user_id = ?
- AND other_user_id = ?
- """,
- (user_id, other_user_id, user_id, other_user_id),
- )
- rows = self.db_pool.cursor_to_dict(txn)
- return rows
-
- rows = await self.db_pool.runInteraction(
- "get_mutual_rooms_for_users", _get_mutual_rooms_for_users_txn
- )
-
- return {row["room_id"] for row in rows}
-
async def get_user_directory_stream_pos(self) -> Optional[int]:
"""
Get the stream ID of the user directory stream.
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index 5de70f31..fa9eadac 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -195,6 +195,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
+ STATE_GROUP_EDGES_UNIQUE_INDEX_UPDATE_NAME = "state_group_edges_unique_idx"
def __init__(
self,
@@ -217,6 +218,21 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
columns=["room_id"],
)
+ # `state_group_edges` can cause severe performance issues if duplicate
+ # rows are introduced, which can accidentally be done by well-meaning
+ # server admins when trying to restore a database dump, etc.
+ # See https://github.com/matrix-org/synapse/issues/11779.
+ # Introduce a unique index to guard against that.
+ self.db_pool.updates.register_background_index_update(
+ self.STATE_GROUP_EDGES_UNIQUE_INDEX_UPDATE_NAME,
+ index_name="state_group_edges_unique_idx",
+ table="state_group_edges",
+ columns=["state_group", "prev_state_group"],
+ unique=True,
+ # The old index was on (state_group) and was not unique.
+ replaces_index="state_group_edges_idx",
+ )
+
async def _background_deduplicate_state(
self, progress: dict, batch_size: int
) -> int:
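The unique index registered above guards `state_group_edges` against duplicate rows, which the linked issue describes as a cause of severe performance problems. A tiny sqlite3 demonstration of the effect, with the DDL approximated locally rather than generated by `register_background_index_update`:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE state_group_edges (state_group BIGINT, prev_state_group BIGINT)")
conn.execute(
    "CREATE UNIQUE INDEX state_group_edges_unique_idx "
    "ON state_group_edges (state_group, prev_state_group)"
)
conn.execute("INSERT INTO state_group_edges VALUES (2, 1)")
try:
    conn.execute("INSERT INTO state_group_edges VALUES (2, 1)")  # duplicate edge
except sqlite3.IntegrityError as exc:
    print("rejected:", exc)  # -> rejected: UNIQUE constraint failed: ...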
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 7614d76a..609a2b88 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -189,7 +189,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
group: int,
state_filter: StateFilter,
) -> Tuple[MutableStateMap[str], bool]:
- """Checks if group is in cache. See `_get_state_for_groups`
+ """Checks if group is in cache. See `get_state_for_groups`
Args:
cache: the state group cache to use
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index afb7d505..f51b3d22 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -11,25 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Mapping
from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
from .postgres import PostgresEngine
from .sqlite import Sqlite3Engine
-def create_engine(database_config) -> BaseDatabaseEngine:
+def create_engine(database_config: Mapping[str, Any]) -> BaseDatabaseEngine:
name = database_config["name"]
if name == "sqlite3":
- import sqlite3
-
- return Sqlite3Engine(sqlite3, database_config)
+ return Sqlite3Engine(database_config)
if name == "psycopg2":
- # Note that psycopg2cffi-compat provides the psycopg2 module on pypy.
- import psycopg2
-
- return PostgresEngine(psycopg2, database_config)
+ return PostgresEngine(database_config)
raise RuntimeError("Unsupported database engine '%s'" % (name,))
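With the change above, `create_engine` only needs the database config mapping; the driver module is now imported inside the engine class. A minimal usage sketch, assuming Synapse itself is importable (e.g. run from a Synapse checkout or install):

from synapse.storage.engines import Sqlite3Engine, create_engine

engine = create_engine({"name": "sqlite3", "args": {"database": ":memory:"}})
assert isinstance(engine, Sqlite3Engine)
print(engine.server_version)  # whatever version the local sqlite3 module reports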
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index 143cd98c..971ff826 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -13,9 +13,12 @@
# limitations under the License.
import abc
from enum import IntEnum
-from typing import Generic, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, TypeVar
-from synapse.storage.types import Connection
+from synapse.storage.types import Connection, Cursor, DBAPI2Module
+
+if TYPE_CHECKING:
+ from synapse.storage.database import LoggingDatabaseConnection
class IsolationLevel(IntEnum):
@@ -32,7 +35,7 @@ ConnectionType = TypeVar("ConnectionType", bound=Connection)
class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
- def __init__(self, module, database_config: dict):
+ def __init__(self, module: DBAPI2Module, config: Mapping[str, Any]):
self.module = module
@property
@@ -69,7 +72,7 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
...
@abc.abstractmethod
- def check_new_database(self, txn) -> None:
+ def check_new_database(self, txn: Cursor) -> None:
"""Gets called when setting up a brand new database. This allows us to
apply stricter checks on new databases versus existing database.
"""
@@ -79,8 +82,11 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
def convert_param_style(self, sql: str) -> str:
...
+ # This method would ideally take a plain ConnectionType, but it seems that
+ # the Sqlite engine expects to use LoggingDatabaseConnection.cursor
+ # instead of sqlite3.Connection.cursor: only the former takes a txn_name.
@abc.abstractmethod
- def on_new_connection(self, db_conn: ConnectionType) -> None:
+ def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None:
...
@abc.abstractmethod
@@ -92,7 +98,7 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
...
@abc.abstractmethod
- def lock_table(self, txn, table: str) -> None:
+ def lock_table(self, txn: Cursor, table: str) -> None:
...
@property
@@ -102,12 +108,12 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
...
@abc.abstractmethod
- def in_transaction(self, conn: Connection) -> bool:
+ def in_transaction(self, conn: ConnectionType) -> bool:
"""Whether the connection is currently in a transaction."""
...
@abc.abstractmethod
- def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
+ def attempt_to_set_autocommit(self, conn: ConnectionType, autocommit: bool) -> None:
"""Attempt to set the connections autocommit mode.
When True queries are run outside of transactions.
@@ -119,8 +125,8 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
@abc.abstractmethod
def attempt_to_set_isolation_level(
- self, conn: Connection, isolation_level: Optional[int]
- ):
+ self, conn: ConnectionType, isolation_level: Optional[int]
+ ) -> None:
"""Attempt to set the connections isolation level.
Note: This has no effect on SQLite3, as transactions are SERIALIZABLE by default.
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index e8d29e28..391f8ed2 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -13,39 +13,47 @@
# limitations under the License.
import logging
-from typing import Mapping, Optional
+from typing import TYPE_CHECKING, Any, Mapping, NoReturn, Optional, Tuple, cast
from synapse.storage.engines._base import (
BaseDatabaseEngine,
IncorrectDatabaseSetup,
IsolationLevel,
)
-from synapse.storage.types import Connection
+from synapse.storage.types import Cursor
+
+if TYPE_CHECKING:
+ import psycopg2 # noqa: F401
+
+ from synapse.storage.database import LoggingDatabaseConnection
+
logger = logging.getLogger(__name__)
-class PostgresEngine(BaseDatabaseEngine):
- def __init__(self, database_module, database_config):
- super().__init__(database_module, database_config)
- self.module.extensions.register_type(self.module.extensions.UNICODE)
+class PostgresEngine(BaseDatabaseEngine["psycopg2.connection"]):
+ def __init__(self, database_config: Mapping[str, Any]):
+ import psycopg2.extensions
+
+ super().__init__(psycopg2, database_config)
+ psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
# Disables passing `bytes` to txn.execute, c.f. #6186. If you do
# actually want to use bytes then wrap it in `bytearray`.
- def _disable_bytes_adapter(_):
+ def _disable_bytes_adapter(_: bytes) -> NoReturn:
raise Exception("Passing bytes to DB is disabled.")
- self.module.extensions.register_adapter(bytes, _disable_bytes_adapter)
- self.synchronous_commit = database_config.get("synchronous_commit", True)
- self._version = None # unknown as yet
+ psycopg2.extensions.register_adapter(bytes, _disable_bytes_adapter)
+ self.synchronous_commit: bool = database_config.get("synchronous_commit", True)
+ self._version: Optional[int] = None # unknown as yet
self.isolation_level_map: Mapping[int, int] = {
- IsolationLevel.READ_COMMITTED: self.module.extensions.ISOLATION_LEVEL_READ_COMMITTED,
- IsolationLevel.REPEATABLE_READ: self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ,
- IsolationLevel.SERIALIZABLE: self.module.extensions.ISOLATION_LEVEL_SERIALIZABLE,
+ IsolationLevel.READ_COMMITTED: psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED,
+ IsolationLevel.REPEATABLE_READ: psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ,
+ IsolationLevel.SERIALIZABLE: psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE,
}
self.default_isolation_level = (
- self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
+ psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
self.config = database_config
@@ -53,19 +61,21 @@ class PostgresEngine(BaseDatabaseEngine):
def single_threaded(self) -> bool:
return False
- def get_db_locale(self, txn):
+ def get_db_locale(self, txn: Cursor) -> Tuple[str, str]:
txn.execute(
"SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
)
- collation, ctype = txn.fetchone()
+ collation, ctype = cast(Tuple[str, str], txn.fetchone())
return collation, ctype
- def check_database(self, db_conn, allow_outdated_version: bool = False):
+ def check_database(
+ self, db_conn: "psycopg2.connection", allow_outdated_version: bool = False
+ ) -> None:
# Get the version of PostgreSQL that we're using. As per the psycopg2
# docs: The number is formed by converting the major, minor, and
# revision numbers into two-decimal-digit numbers and appending them
# together. For example, version 8.1.5 will be returned as 80105
- self._version = db_conn.server_version
+ self._version = cast(int, db_conn.server_version)
allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
# Are we on a supported PostgreSQL version?
@@ -108,7 +118,7 @@ class PostgresEngine(BaseDatabaseEngine):
ctype,
)
- def check_new_database(self, txn):
+ def check_new_database(self, txn: Cursor) -> None:
"""Gets called when setting up a brand new database. This allows us to
apply stricter checks on new databases versus existing database.
"""
@@ -129,10 +139,10 @@ class PostgresEngine(BaseDatabaseEngine):
"See docs/postgres.md for more information." % ("\n".join(errors))
)
- def convert_param_style(self, sql):
+ def convert_param_style(self, sql: str) -> str:
return sql.replace("?", "%s")
- def on_new_connection(self, db_conn):
+ def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None:
db_conn.set_isolation_level(self.default_isolation_level)
# Set the bytea output to escape, vs the default of hex
@@ -149,14 +159,14 @@ class PostgresEngine(BaseDatabaseEngine):
db_conn.commit()
@property
- def can_native_upsert(self):
+ def can_native_upsert(self) -> bool:
"""
Can we use native UPSERTs?
"""
return True
@property
- def supports_using_any_list(self):
+ def supports_using_any_list(self) -> bool:
"""Do we support using `a = ANY(?)` and passing a list"""
return True
@@ -165,27 +175,25 @@ class PostgresEngine(BaseDatabaseEngine):
"""Do we support the `RETURNING` clause in insert/update/delete?"""
return True
- def is_deadlock(self, error):
- if isinstance(error, self.module.DatabaseError):
+ def is_deadlock(self, error: Exception) -> bool:
+ import psycopg2.extensions
+
+ if isinstance(error, psycopg2.DatabaseError):
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html
# "40001" serialization_failure
# "40P01" deadlock_detected
return error.pgcode in ["40001", "40P01"]
return False
- def is_connection_closed(self, conn):
+ def is_connection_closed(self, conn: "psycopg2.connection") -> bool:
return bool(conn.closed)
- def lock_table(self, txn, table):
+ def lock_table(self, txn: Cursor, table: str) -> None:
txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
@property
- def server_version(self):
- """Returns a string giving the server version. For example: '8.1.5'
-
- Returns:
- string
- """
+ def server_version(self) -> str:
+ """Returns a string giving the server version. For example: '8.1.5'."""
# note that this is a bit of a hack because it relies on check_database
# having been called. Still, that should be a safe bet here.
numver = self._version
@@ -197,17 +205,21 @@ class PostgresEngine(BaseDatabaseEngine):
else:
return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)
- def in_transaction(self, conn: Connection) -> bool:
- return conn.status != self.module.extensions.STATUS_READY # type: ignore
+ def in_transaction(self, conn: "psycopg2.connection") -> bool:
+ import psycopg2.extensions
+
+ return conn.status != psycopg2.extensions.STATUS_READY
- def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
- return conn.set_session(autocommit=autocommit) # type: ignore
+ def attempt_to_set_autocommit(
+ self, conn: "psycopg2.connection", autocommit: bool
+ ) -> None:
+ return conn.set_session(autocommit=autocommit)
def attempt_to_set_isolation_level(
- self, conn: Connection, isolation_level: Optional[int]
- ):
+ self, conn: "psycopg2.connection", isolation_level: Optional[int]
+ ) -> None:
if isolation_level is None:
isolation_level = self.default_isolation_level
else:
isolation_level = self.isolation_level_map[isolation_level]
- return conn.set_isolation_level(isolation_level) # type: ignore
+ return conn.set_isolation_level(isolation_level)
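`server_version` above turns psycopg2's packed integer version into a dotted string; per the comment in `check_database`, 8.1.5 is reported as 80105. A small sketch of that conversion: the pre-10 branch mirrors the line visible in the hunk, while the 10+ branch (major and minor only) is an assumption about the surrounding code the hunk does not show:

def format_pg_version(numver: int) -> str:
    if numver >= 100000:
        # PostgreSQL 10+ reports only major and minor (assumed branch).
        return "%i.%i" % (numver / 10000, numver % 10000)
    return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)

print(format_pg_version(80105))   # -> 8.1.5
print(format_pg_version(140003))  # -> 14.3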
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 6c19e559..621f2c5e 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
+import sqlite3
import struct
import threading
-import typing
-from typing import Optional
+from typing import TYPE_CHECKING, Any, List, Mapping, Optional
from synapse.storage.engines import BaseDatabaseEngine
-from synapse.storage.types import Connection
+from synapse.storage.types import Cursor
-if typing.TYPE_CHECKING:
- import sqlite3 # noqa: F401
+if TYPE_CHECKING:
+ from synapse.storage.database import LoggingDatabaseConnection
-class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
- def __init__(self, database_module, database_config):
- super().__init__(database_module, database_config)
+class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]):
+ def __init__(self, database_config: Mapping[str, Any]):
+ super().__init__(sqlite3, database_config)
database = database_config.get("args", {}).get("database")
self._is_in_memory = database in (
@@ -37,7 +37,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
if platform.python_implementation() == "PyPy":
# pypy's sqlite3 module doesn't handle bytearrays, convert them
# back to bytes.
- database_module.register_adapter(bytearray, lambda array: bytes(array))
+ sqlite3.register_adapter(bytearray, lambda array: bytes(array))
# The current max state_group, or None if we haven't looked
# in the DB yet.
@@ -49,41 +49,43 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
return True
@property
- def can_native_upsert(self):
+ def can_native_upsert(self) -> bool:
"""
Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
more work we haven't done yet to tell what was inserted vs updated.
"""
- return self.module.sqlite_version_info >= (3, 24, 0)
+ return sqlite3.sqlite_version_info >= (3, 24, 0)
@property
- def supports_using_any_list(self):
+ def supports_using_any_list(self) -> bool:
"""Do we support using `a = ANY(?)` and passing a list"""
return False
@property
def supports_returning(self) -> bool:
"""Do we support the `RETURNING` clause in insert/update/delete?"""
- return self.module.sqlite_version_info >= (3, 35, 0)
+ return sqlite3.sqlite_version_info >= (3, 35, 0)
- def check_database(self, db_conn, allow_outdated_version: bool = False):
+ def check_database(
+ self, db_conn: sqlite3.Connection, allow_outdated_version: bool = False
+ ) -> None:
if not allow_outdated_version:
- version = self.module.sqlite_version_info
+ version = sqlite3.sqlite_version_info
# Synapse is untested against older SQLite versions, and we don't want
# to let users upgrade to a version of Synapse with broken support for their
# sqlite version, because it risks leaving them with a half-upgraded db.
if version < (3, 22, 0):
raise RuntimeError("Synapse requires sqlite 3.22 or above.")
- def check_new_database(self, txn):
+ def check_new_database(self, txn: Cursor) -> None:
"""Gets called when setting up a brand new database. This allows us to
apply stricter checks on new databases versus existing database.
"""
- def convert_param_style(self, sql):
+ def convert_param_style(self, sql: str) -> str:
return sql
- def on_new_connection(self, db_conn):
+ def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None:
# We need to import here to avoid an import loop.
from synapse.storage.prepare_database import prepare_database
@@ -97,48 +99,46 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
db_conn.execute("PRAGMA foreign_keys = ON;")
db_conn.commit()
- def is_deadlock(self, error):
+ def is_deadlock(self, error: Exception) -> bool:
return False
- def is_connection_closed(self, conn):
+ def is_connection_closed(self, conn: sqlite3.Connection) -> bool:
return False
- def lock_table(self, txn, table):
+ def lock_table(self, txn: Cursor, table: str) -> None:
return
@property
- def server_version(self):
- """Gets a string giving the server version. For example: '3.22.0'
+ def server_version(self) -> str:
+ """Gets a string giving the server version. For example: '3.22.0'."""
+ return "%i.%i.%i" % sqlite3.sqlite_version_info
- Returns:
- string
- """
- return "%i.%i.%i" % self.module.sqlite_version_info
-
- def in_transaction(self, conn: Connection) -> bool:
- return conn.in_transaction # type: ignore
+ def in_transaction(self, conn: sqlite3.Connection) -> bool:
+ return conn.in_transaction
- def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
+ def attempt_to_set_autocommit(
+ self, conn: sqlite3.Connection, autocommit: bool
+ ) -> None:
# Twisted doesn't let us set attributes on the connections, so we can't
# set the connection to autocommit mode.
pass
def attempt_to_set_isolation_level(
- self, conn: Connection, isolation_level: Optional[int]
- ):
- # All transactions are SERIALIZABLE by default in sqllite
+ self, conn: sqlite3.Connection, isolation_level: Optional[int]
+ ) -> None:
+ # All transactions are SERIALIZABLE by default in sqlite
pass
# Following functions taken from: https://github.com/coleifer/peewee
-def _parse_match_info(buf):
+def _parse_match_info(buf: bytes) -> List[int]:
bufsize = len(buf)
return [struct.unpack("@I", buf[i : i + 4])[0] for i in range(0, bufsize, 4)]
-def _rank(raw_match_info):
+def _rank(raw_match_info: bytes) -> float:
"""Handle match_info called w/default args 'pcx' - based on the example rank
function http://sqlite.org/fts3.html#appendix_a
"""
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 546d6bae..c33df420 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -85,7 +85,7 @@ def prepare_database(
database_engine: BaseDatabaseEngine,
config: Optional[HomeServerConfig],
databases: Collection[str] = ("main", "state"),
-):
+) -> None:
"""Prepares a physical database for usage. Will either create all necessary tables
or upgrade from an older schema version.
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 871d4ace..5843fae6 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-SCHEMA_VERSION = 69 # remember to update the list below when updating
+SCHEMA_VERSION = 71 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
@@ -61,13 +61,23 @@ Changes in SCHEMA_VERSION = 68:
Changes in SCHEMA_VERSION = 69:
- We now write to `device_lists_changes_in_room` table.
- - Use sequence to generate future `application_services_txns.txn_id`s
+ - We now use a PostgreSQL sequence to generate future txn_ids for
+ `application_services_txns`. `application_services_state.last_txn` is no longer
+ updated.
+
+Changes in SCHEMA_VERSION = 70:
+ - event_reference_hashes is no longer written to.
+
+Changes in SCHEMA_VERSION = 71:
+ - event_edges.room_id is no longer read from.
+ - Tables related to groups are no longer accessed.
"""
SCHEMA_COMPAT_VERSION = (
# We now assume that `device_lists_changes_in_room` has been filled out for
# recent device_list_updates.
+ # ... and that `application_services_state.last_txn` is not used.
69
)
"""Limit on how far the synapse codebase can be rolled back without breaking db compat
diff --git a/synapse/storage/schema/main/delta/69/02cache_invalidation_index.sql b/synapse/storage/schema/main/delta/69/02cache_invalidation_index.sql
new file mode 100644
index 00000000..22ae3b8c
--- /dev/null
+++ b/synapse/storage/schema/main/delta/69/02cache_invalidation_index.sql
@@ -0,0 +1,18 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Background update to add an index to the cache invalidation stream, keyed by instance.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (6902, 'cache_invalidation_index_by_instance', '{}');
diff --git a/synapse/storage/schema/main/delta/70/01clean_table_purged_rooms.sql b/synapse/storage/schema/main/delta/70/01clean_table_purged_rooms.sql
new file mode 100644
index 00000000..aed79635
--- /dev/null
+++ b/synapse/storage/schema/main/delta/70/01clean_table_purged_rooms.sql
@@ -0,0 +1,19 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Clean up left over rows from bug #11833, which was fixed in #12770.
+DELETE FROM federation_inbound_events_staging WHERE room_id not in (
+ SELECT room_id FROM rooms
+);
diff --git a/synapse/storage/schema/state/delta/70/08_state_group_edges_unique.sql b/synapse/storage/schema/state/delta/70/08_state_group_edges_unique.sql
new file mode 100644
index 00000000..b8c0ee0f
--- /dev/null
+++ b/synapse/storage/schema/state/delta/70/08_state_group_edges_unique.sql
@@ -0,0 +1,17 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (7008, 'state_group_edges_unique_idx', '{}');
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index d1d58592..96aaffb5 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -1,4 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +15,7 @@
import logging
from typing import (
TYPE_CHECKING,
- Awaitable,
+ Callable,
Collection,
Dict,
Iterable,
@@ -30,15 +31,11 @@ import attr
from frozendict import frozendict
from synapse.api.constants import EventTypes
-from synapse.events import EventBase
-from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
from synapse.types import MutableStateMap, StateKey, StateMap
if TYPE_CHECKING:
from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad
- from synapse.server import HomeServer
- from synapse.storage.databases import Databases
logger = logging.getLogger(__name__)
@@ -62,7 +59,7 @@ class StateFilter:
types: "frozendict[str, Optional[FrozenSet[str]]]"
include_others: bool = False
- def __attrs_post_init__(self):
+ def __attrs_post_init__(self) -> None:
# If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary
if self.include_others:
@@ -138,7 +135,9 @@ class StateFilter:
)
@staticmethod
- def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool):
+ def freeze(
+ types: Mapping[str, Optional[Collection[str]]], include_others: bool
+ ) -> "StateFilter":
"""
Returns a (frozen) StateFilter with the same contents as the parameters
specified here, which can be made of mutable types.
@@ -530,306 +529,47 @@ class StateFilter:
new_all, new_excludes, new_wildcards, new_concrete_keys
)
+ def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool:
+ """Check if we need to wait for full state to complete to calculate this state
-_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True)
-_ALL_NON_MEMBER_STATE_FILTER = StateFilter(
- types=frozendict({EventTypes.Member: frozenset()}), include_others=True
-)
-_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False)
-
-
-class StateGroupStorage:
- """High level interface to fetching state for event."""
-
- def __init__(self, hs: "HomeServer", stores: "Databases"):
- self.stores = stores
- self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
-
- def notify_event_un_partial_stated(self, event_id: str) -> None:
- self._partial_state_events_tracker.notify_un_partial_stated(event_id)
-
- async def get_state_group_delta(
- self, state_group: int
- ) -> Tuple[Optional[int], Optional[StateMap[str]]]:
- """Given a state group try to return a previous group and a delta between
- the old and the new.
-
- Args:
- state_group: The state group used to retrieve state deltas.
-
- Returns:
- A tuple of the previous group and a state map of the event IDs which
- make up the delta between the old and new state groups.
- """
-
- state_group_delta = await self.stores.state.get_state_group_delta(state_group)
- return state_group_delta.prev_group, state_group_delta.delta_ids
-
- async def get_state_groups_ids(
- self, _room_id: str, event_ids: Collection[str]
- ) -> Dict[int, MutableStateMap[str]]:
- """Get the event IDs of all the state for the state groups for the given events
-
- Args:
- _room_id: id of the room for these events
- event_ids: ids of the events
-
- Returns:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
-
- Raises:
- RuntimeError if we don't have a state group for one or more of the events
- (ie they are outliers or unknown)
- """
- if not event_ids:
- return {}
-
- event_to_groups = await self._get_state_group_for_events(event_ids)
-
- groups = set(event_to_groups.values())
- group_to_state = await self.stores.state._get_state_for_groups(groups)
-
- return group_to_state
-
- async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
- """Get the event IDs of all the state in the given state group
-
- Args:
- state_group: A state group for which we want to get the state IDs.
-
- Returns:
- Resolves to a map of (type, state_key) -> event_id
- """
- group_to_state = await self._get_state_for_groups((state_group,))
-
- return group_to_state[state_group]
-
- async def get_state_groups(
- self, room_id: str, event_ids: Collection[str]
- ) -> Dict[int, List[EventBase]]:
- """Get the state groups for the given list of event_ids
-
- Args:
- room_id: ID of the room for these events.
- event_ids: The event IDs to retrieve state for.
-
- Returns:
- dict of state_group_id -> list of state events.
- """
- if not event_ids:
- return {}
-
- group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
-
- state_event_map = await self.stores.main.get_events(
- [
- ev_id
- for group_ids in group_to_ids.values()
- for ev_id in group_ids.values()
- ],
- get_prev_content=False,
- )
-
- return {
- group: [
- state_event_map[v]
- for v in event_id_map.values()
- if v in state_event_map
- ]
- for group, event_id_map in group_to_ids.items()
- }
-
- def _get_state_groups_from_groups(
- self, groups: List[int], state_filter: StateFilter
- ) -> Awaitable[Dict[int, StateMap[str]]]:
- """Returns the state groups for a given set of groups, filtering on
- types of state events.
-
- Args:
- groups: list of state group IDs to query
- state_filter: The state filter used to fetch state
- from the database.
-
- Returns:
- Dict of state group to state map.
- """
-
- return self.stores.state._get_state_groups_from_groups(groups, state_filter)
-
- async def get_state_for_events(
- self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
- ) -> Dict[str, StateMap[EventBase]]:
- """Given a list of event_ids and type tuples, return a list of state
- dicts for each event.
-
- Args:
- event_ids: The events to fetch the state of.
- state_filter: The state filter used to fetch state.
-
- Returns:
- A dict of (event_id) -> (type, state_key) -> [state_events]
-
- Raises:
- RuntimeError if we don't have a state group for one or more of the events
- (ie they are outliers or unknown)
- """
- event_to_groups = await self._get_state_group_for_events(event_ids)
-
- groups = set(event_to_groups.values())
- group_to_state = await self.stores.state._get_state_for_groups(
- groups, state_filter or StateFilter.all()
- )
-
- state_event_map = await self.stores.main.get_events(
- [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
- get_prev_content=False,
- )
-
- event_to_state = {
- event_id: {
- k: state_event_map[v]
- for k, v in group_to_state[group].items()
- if v in state_event_map
- }
- for event_id, group in event_to_groups.items()
- }
-
- return {event: event_to_state[event] for event in event_ids}
-
- async def get_state_ids_for_events(
- self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
- ) -> Dict[str, StateMap[str]]:
- """
- Get the state dicts corresponding to a list of events, containing the event_ids
- of the state events (as opposed to the events themselves)
-
- Args:
- event_ids: events whose state should be returned
- state_filter: The state filter used to fetch state from the database.
-
- Returns:
- A dict from event_id -> (type, state_key) -> event_id
-
- Raises:
- RuntimeError if we don't have a state group for one or more of the events
- (ie they are outliers or unknown)
- """
- event_to_groups = await self._get_state_group_for_events(event_ids)
-
- groups = set(event_to_groups.values())
- group_to_state = await self.stores.state._get_state_for_groups(
- groups, state_filter or StateFilter.all()
- )
-
- event_to_state = {
- event_id: group_to_state[group]
- for event_id, group in event_to_groups.items()
- }
-
- return {event: event_to_state[event] for event in event_ids}
-
- async def get_state_for_event(
- self, event_id: str, state_filter: Optional[StateFilter] = None
- ) -> StateMap[EventBase]:
- """
- Get the state dict corresponding to a particular event
-
- Args:
- event_id: event whose state should be returned
- state_filter: The state filter used to fetch state from the database.
-
- Returns:
- A dict from (type, state_key) -> state_event
-
- Raises:
- RuntimeError if we don't have a state group for the event (ie it is an
- outlier or is unknown)
- """
- state_map = await self.get_state_for_events(
- [event_id], state_filter or StateFilter.all()
- )
- return state_map[event_id]
-
- async def get_state_ids_for_event(
- self, event_id: str, state_filter: Optional[StateFilter] = None
- ) -> StateMap[str]:
- """
- Get the state dict corresponding to a particular event
+ If we have a state filter which is completely satisfied even with partial
+ state, then we don't need to await_full_state before we can return it.
Args:
- event_id: event whose state should be returned
- state_filter: The state filter used to fetch state from the database.
-
- Returns:
- A dict from (type, state_key) -> state_event_id
-
- Raises:
- RuntimeError if we don't have a state group for the event (ie it is an
- outlier or is unknown)
+        is_mine_id: a callable which confirms whether a given state_key matches an mxid
+ of a local user
"""
- state_map = await self.get_state_ids_for_events(
- [event_id], state_filter or StateFilter.all()
- )
- return state_map[event_id]
- def _get_state_for_groups(
- self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
- ) -> Awaitable[Dict[int, MutableStateMap[str]]]:
- """Gets the state at each of a list of state groups, optionally
- filtering by type/state_key
+ # TODO(faster_joins): it's not entirely clear that this is safe. In particular,
+ # there may be circumstances in which we return a piece of state that, once we
+ # resync the state, we discover is invalid. For example: if it turns out that
+ # the sender of a piece of state wasn't actually in the room, then clearly that
+ # state shouldn't have been returned.
+ # We should at least add some tests around this to see what happens.
- Args:
- groups: list of state groups for which we want to get the state.
- state_filter: The state filter used to fetch state.
- from the database.
+ # if we haven't requested membership events, then it depends on the value of
+ # 'include_others'
+ if EventTypes.Member not in self.types:
+ return self.include_others
- Returns:
- Dict of state group to state map.
- """
- return self.stores.state._get_state_for_groups(
- groups, state_filter or StateFilter.all()
- )
+ # if we're looking for *all* membership events, then we have to wait
+ member_state_keys = self.types[EventTypes.Member]
+ if member_state_keys is None:
+ return True
- async def _get_state_group_for_events(
- self,
- event_ids: Collection[str],
- await_full_state: bool = True,
- ) -> Mapping[str, int]:
- """Returns mapping event_id -> state_group
+ # otherwise, consider whose membership we are looking for. If it's entirely
+ # local users, then we don't need to wait.
+ for state_key in member_state_keys:
+ if not is_mine_id(state_key):
+ # remote user
+ return True
- Args:
- event_ids: events to get state groups for
- await_full_state: if true, will block if we do not yet have complete
- state at this event.
- """
- if await_full_state:
- await self._partial_state_events_tracker.await_full_state(event_ids)
-
- return await self.stores.main._get_state_group_for_events(event_ids)
+ # local users only
+ return False
- async def store_state_group(
- self,
- event_id: str,
- room_id: str,
- prev_group: Optional[int],
- delta_ids: Optional[StateMap[str]],
- current_state_ids: StateMap[str],
- ) -> int:
- """Store a new set of state, returning a newly assigned state group.
- Args:
- event_id: The event ID for which the state was calculated.
- room_id: ID of the room for which the state was calculated.
- prev_group: A previous state group for the room, optional.
- delta_ids: The delta between state at `prev_group` and
- `current_state_ids`, if `prev_group` was given. Same format as
- `current_state_ids`.
- current_state_ids: The state to store. Map of (type, state_key)
- to event_id.
-
- Returns:
- The state group ID
- """
- return await self.stores.state.store_state_group(
- event_id, room_id, prev_group, delta_ids, current_state_ids
- )
+_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True)
+_ALL_NON_MEMBER_STATE_FILTER = StateFilter(
+ types=frozendict({EventTypes.Member: frozenset()}), include_others=True
+)
+_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False)
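A hypothetical illustration (not part of the patch) of the must_await_full_state decision logic added above; is_mine_id here is a toy predicate that treats users on example.org as local.

from synapse.api.constants import EventTypes
from synapse.storage.state import StateFilter


def is_mine_id(user_id: str) -> bool:
    # Toy stand-in for the homeserver's real is_mine_id check.
    return user_id.endswith(":example.org")


# No membership requested and no wildcard: partial state is sufficient.
create_only = StateFilter.freeze({EventTypes.Create: [""]}, include_others=False)
print(create_only.must_await_full_state(is_mine_id))  # False

# Asking for *all* membership events: must wait for full state.
all_members = StateFilter.freeze({EventTypes.Member: None}, include_others=False)
print(all_members.must_await_full_state(is_mine_id))  # True

# Only local users' membership requested: no need to wait.
local_members = StateFilter.freeze(
    {EventTypes.Member: ["@alice:example.org"]}, include_others=False
)
print(local_members.must_await_full_state(is_mine_id))  # False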
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index d7d6f1d9..0031df1e 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
+from types import TracebackType
+from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
from typing_extensions import Protocol
@@ -86,5 +87,80 @@ class Connection(Protocol):
def __enter__(self) -> "Connection":
...
- def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_value: Optional[BaseException],
+ traceback: Optional[TracebackType],
+ ) -> Optional[bool]:
+ ...
+
+
+class DBAPI2Module(Protocol):
+ """The module-level attributes that we use from PEP 249.
+
+ This is NOT a comprehensive stub for the entire DBAPI2."""
+
+ __name__: str
+
+ # Exceptions. See https://peps.python.org/pep-0249/#exceptions
+
+ # For our specific drivers:
+    # - Python's sqlite3 module doesn't contain the same descriptions as the
+ # DBAPI2 spec, see https://docs.python.org/3/library/sqlite3.html#exceptions
+ # - Psycopg2 maps every Postgres error code onto a unique exception class which
+ # extends from this hierarchy. See
+ # https://docs.python.org/3/library/sqlite3.html?highlight=sqlite3#exceptions
+ # https://www.postgresql.org/docs/current/errcodes-appendix.html#ERRCODES-TABLE
+ Warning: Type[Exception]
+ Error: Type[Exception]
+
+ # Errors are divided into `InterfaceError`s (something went wrong in the database
+ # driver) and `DatabaseError`s (something went wrong in the database). These are
+ # both subclasses of `Error`, but we can't currently express this in type
+ # annotations due to https://github.com/python/mypy/issues/8397
+ InterfaceError: Type[Exception]
+ DatabaseError: Type[Exception]
+
+ # Everything below is a subclass of `DatabaseError`.
+
+ # Roughly: the database rejected a nonsensical value. Examples:
+ # - An integer was too big for its data type.
+ # - An invalid date time was provided.
+ # - A string contained a null code point.
+ DataError: Type[Exception]
+
+ # Roughly: something went wrong in the database, but it's not within the application
+ # programmer's control. Examples:
+ # - We failed to establish a connection to the database.
+ # - The connection to the database was lost.
+ # - A deadlock was detected.
+ # - A serialisation failure occurred.
+ # - The database ran out of resources, such as storage, memory, connections, etc.
+ # - The database encountered an error from the operating system.
+ OperationalError: Type[Exception]
+
+ # Roughly: we've given the database data which breaks a rule we asked it to enforce.
+ # Examples:
+ # - Stop, criminal scum! You violated the foreign key constraint
+ # - Also check constraints, non-null constraints, etc.
+ IntegrityError: Type[Exception]
+
+ # Roughly: something went wrong within the database server itself.
+ InternalError: Type[Exception]
+
+ # Roughly: the application did something silly that needs to be fixed. Examples:
+ # - We don't have permissions to do something.
+ # - We tried to create a table with duplicate column names.
+ # - We tried to use a reserved name.
+ # - We referred to a column that doesn't exist.
+ ProgrammingError: Type[Exception]
+
+ # Roughly: we've tried to do something that this database doesn't support.
+ NotSupportedError: Type[Exception]
+
+ def connect(self, **parameters: object) -> Connection:
...
+
+
+__all__ = ["Cursor", "Connection", "DBAPI2Module"]
diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py
index a61a951e..211437cf 100644
--- a/synapse/storage/util/partial_state_events_tracker.py
+++ b/synapse/storage/util/partial_state_events_tracker.py
@@ -21,6 +21,7 @@ from twisted.internet.defer import Deferred
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.databases.main.room import RoomWorkerStore
from synapse.util import unwrapFirstError
logger = logging.getLogger(__name__)
@@ -118,3 +119,62 @@ class PartialStateEventsTracker:
observer_set.discard(observer)
if not observer_set:
del self._observers[event_id]
+
+
+class PartialCurrentStateTracker:
+ """Keeps track of which rooms have partial state, after partial-state joins"""
+
+ def __init__(self, store: RoomWorkerStore):
+ self._store = store
+
+ # a map from room id to a set of Deferreds which are waiting for that room to be
+ # un-partial-stated.
+ self._observers: Dict[str, Set[Deferred[None]]] = defaultdict(set)
+
+ def notify_un_partial_stated(self, room_id: str) -> None:
+ """Notify that we now have full current state for a given room
+
+ Unblocks any callers to await_full_state() for that room.
+
+ Args:
+ room_id: the room that now has full current state.
+ """
+ observers = self._observers.pop(room_id, None)
+ if not observers:
+ return
+ logger.info(
+ "Notifying %i things waiting for un-partial-stating of room %s",
+ len(observers),
+ room_id,
+ )
+ with PreserveLoggingContext():
+ for o in observers:
+ o.callback(None)
+
+ async def await_full_state(self, room_id: str) -> None:
+ # We add the deferred immediately so that the DB call to check for
+        # partial state doesn't race with the room being un-partial-stated.
+ d: Deferred[None] = Deferred()
+ self._observers.setdefault(room_id, set()).add(d)
+
+ try:
+ # Check if the room has partial current state or not.
+ has_partial_state = await self._store.is_partial_state_room(room_id)
+ if not has_partial_state:
+ return
+
+ logger.info(
+ "Awaiting un-partial-stating of room %s",
+ room_id,
+ )
+
+ await make_deferred_yieldable(d)
+
+ logger.info("Room has un-partial-stated")
+ finally:
+            # Remove the added observer, and remove the room entry if it is now empty.
+ ds = self._observers.get(room_id)
+ if ds is not None:
+ ds.discard(d)
+ if not ds:
+ self._observers.pop(room_id, None)
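A self-contained toy sketch (plain Twisted, no Synapse classes; ToyTracker is hypothetical) of the observer pattern PartialCurrentStateTracker uses: waiters park a Deferred per room and are woken when the room is marked as having full current state.

from collections import defaultdict
from typing import DefaultDict, Set

from twisted.internet import defer, task


class ToyTracker:
    def __init__(self) -> None:
        self._partial_rooms: Set[str] = {"!room:example.org"}
        self._observers: DefaultDict[str, Set["defer.Deferred[None]"]] = defaultdict(set)

    def notify_un_partial_stated(self, room_id: str) -> None:
        # Wake everyone waiting for this room.
        self._partial_rooms.discard(room_id)
        for d in self._observers.pop(room_id, set()):
            d.callback(None)

    async def await_full_state(self, room_id: str) -> None:
        if room_id not in self._partial_rooms:
            return
        d: "defer.Deferred[None]" = defer.Deferred()
        self._observers[room_id].add(d)
        await d


async def main(reactor) -> None:
    tracker = ToyTracker()
    waiter = defer.ensureDeferred(tracker.await_full_state("!room:example.org"))
    # Simulate the state resync finishing a moment later.
    reactor.callLater(0.1, tracker.notify_un_partial_stated, "!room:example.org")
    await waiter
    print("room un-partial-stated")


task.react(lambda reactor: defer.ensureDeferred(main(reactor)))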
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index acf17ba6..54e0b1a2 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -54,7 +54,6 @@ class EventSources:
push_rules_key = self.store.get_max_push_rules_stream_id()
to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token()
- groups_key = self.store.get_group_stream_token()
token = StreamToken(
room_key=self.sources.room.get_current_key(),
@@ -65,7 +64,8 @@ class EventSources:
push_rules_key=push_rules_key,
to_device_key=to_device_key,
device_list_key=device_list_key,
- groups_key=groups_key,
+ # Groups key is unused.
+ groups_key=0,
)
return token
diff --git a/synapse/types.py b/synapse/types.py
index 9ac688b2..0586d2cb 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -24,6 +24,7 @@ from typing import (
Mapping,
Match,
MutableMapping,
+ NoReturn,
Optional,
Set,
Tuple,
@@ -35,7 +36,8 @@ from typing import (
import attr
from frozendict import frozendict
from signedjson.key import decode_verify_key_bytes
-from typing_extensions import TypedDict
+from signedjson.types import VerifyKey
+from typing_extensions import Final, TypedDict
from unpaddedbase64 import decode_base64
from zope.interface import Interface
@@ -55,6 +57,7 @@ from synapse.util.stringutils import parse_and_validate_server_name
if TYPE_CHECKING:
from synapse.appservice.api import ApplicationService
from synapse.storage.databases.main import DataStore, PurgeEventsStore
+ from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
# Define a state map type from type/state_key to T (usually an event ID or
# event)
@@ -114,7 +117,7 @@ class Requester:
app_service: Optional["ApplicationService"]
authenticated_entity: str
- def serialize(self):
+ def serialize(self) -> Dict[str, Any]:
"""Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize`
@@ -132,7 +135,9 @@ class Requester:
}
@staticmethod
- def deserialize(store, input):
+ def deserialize(
+ store: "ApplicationServiceWorkerStore", input: Dict[str, Any]
+ ) -> "Requester":
"""Converts a dict that was produced by `serialize` back into a
Requester.
@@ -236,10 +241,10 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
domain: str
# Because this is a frozen class, it is deeply immutable.
- def __copy__(self):
+ def __copy__(self: DS) -> DS:
return self
- def __deepcopy__(self, memo):
+ def __deepcopy__(self: DS, memo: Dict[str, object]) -> DS:
return self
@classmethod
@@ -315,29 +320,6 @@ class EventID(DomainSpecificString):
SIGIL = "$"
-@attr.s(slots=True, frozen=True, repr=False)
-class GroupID(DomainSpecificString):
- """Structure representing a group ID."""
-
- SIGIL = "+"
-
- @classmethod
- def from_string(cls: Type[DS], s: str) -> DS:
- group_id: DS = super().from_string(s) # type: ignore
-
- if not group_id.localpart:
- raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM)
-
- if contains_invalid_mxid_characters(group_id.localpart):
- raise SynapseError(
- 400,
- "Group ID can only contain characters a-z, 0-9, or '=_-./'",
- Codes.INVALID_PARAM,
- )
-
- return group_id
-
-
mxid_localpart_allowed_characters = set(
"_-./=" + string.ascii_lowercase + string.digits
)
@@ -625,6 +607,22 @@ class RoomStreamToken:
return "s%d" % (self.stream,)
+class StreamKeyType:
+ """Known stream types.
+
+ A stream is a list of entities ordered by an incrementing "stream token".
+ """
+
+ ROOM: Final = "room_key"
+ PRESENCE: Final = "presence_key"
+ TYPING: Final = "typing_key"
+ RECEIPT: Final = "receipt_key"
+ ACCOUNT_DATA: Final = "account_data_key"
+ PUSH_RULES: Final = "push_rules_key"
+ TO_DEVICE: Final = "to_device_key"
+ DEVICE_LIST: Final = "device_list_key"
+
+
@attr.s(slots=True, frozen=True, auto_attribs=True)
class StreamToken:
"""A collection of keys joined together by underscores in the following
@@ -641,7 +639,7 @@ class StreamToken:
6. `push_rules_key`: `541479`
7. `to_device_key`: `274711`
8. `device_list_key`: `265584`
- 9. `groups_key`: `1`
+ 9. `groups_key`: `1` (note that this key is now unused)
You can see how many of these keys correspond to the various
fields in a "/sync" response:
@@ -693,6 +691,7 @@ class StreamToken:
push_rules_key: int
to_device_key: int
device_list_key: int
+ # Note that the groups key is no longer used and may have bogus values.
groups_key: int
_SEPARATOR = "_"
@@ -724,21 +723,26 @@ class StreamToken:
str(self.push_rules_key),
str(self.to_device_key),
str(self.device_list_key),
+ # Note that the groups key is no longer used, but it is still
+ # serialized so that there will not be confusion in the future
+ # if additional tokens are added.
str(self.groups_key),
]
)
@property
- def room_stream_id(self):
+ def room_stream_id(self) -> int:
return self.room_key.stream
- def copy_and_advance(self, key, new_value) -> "StreamToken":
+ def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken":
"""Advance the given key in the token to a new value if and only if the
new value is after the old value.
+
+        :raises TypeError: if `key` is not one of the keys tracked by a StreamToken.
"""
- if key == "room_key":
+ if key == StreamKeyType.ROOM:
new_token = self.copy_and_replace(
- "room_key", self.room_key.copy_and_advance(new_value)
+ StreamKeyType.ROOM, self.room_key.copy_and_advance(new_value)
)
return new_token
@@ -751,7 +755,7 @@ class StreamToken:
else:
return self
- def copy_and_replace(self, key, new_value) -> "StreamToken":
+ def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken":
return attr.evolve(self, **{key: new_value})
@@ -793,14 +797,14 @@ class ThirdPartyInstanceID:
# Deny iteration because it will bite you if you try to create a singleton
# set by:
# users = set(user)
- def __iter__(self):
+ def __iter__(self) -> NoReturn:
raise ValueError("Attempted to iterate a %s" % (type(self).__name__,))
# Because this class is a frozen class, it is deeply immutable.
- def __copy__(self):
+ def __copy__(self) -> "ThirdPartyInstanceID":
return self
- def __deepcopy__(self, memo):
+ def __deepcopy__(self, memo: Dict[str, object]) -> "ThirdPartyInstanceID":
return self
@classmethod
@@ -852,25 +856,28 @@ class DeviceListUpdates:
return bool(self.changed or self.left)
-def get_verify_key_from_cross_signing_key(key_info):
+def get_verify_key_from_cross_signing_key(
+ key_info: Mapping[str, Any]
+) -> Tuple[str, VerifyKey]:
"""Get the key ID and signedjson verify key from a cross-signing key dict
Args:
- key_info (dict): a cross-signing key dict, which must have a "keys"
+ key_info: a cross-signing key dict, which must have a "keys"
property that has exactly one item in it
Returns:
- (str, VerifyKey): the key ID and verify key for the cross-signing key
+ the key ID and verify key for the cross-signing key
"""
- # make sure that exactly one key is provided
+ # make sure that a `keys` field is provided
if "keys" not in key_info:
raise ValueError("Invalid key")
keys = key_info["keys"]
- if len(keys) != 1:
- raise ValueError("Invalid key")
- # and return that one key
- for key_id, key_data in keys.items():
+ # and that it contains exactly one key
+ if len(keys) == 1:
+ key_id, key_data = next(iter(keys.items()))
return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data))
+ else:
+ raise ValueError("Invalid key")
@attr.s(auto_attribs=True, frozen=True, slots=True)
@@ -906,3 +913,9 @@ class UserProfile(TypedDict):
user_id: str
display_name: Optional[str]
avatar_url: Optional[str]
+
+
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class RetentionPolicy:
+ min_lifetime: Optional[int] = None
+ max_lifetime: Optional[int] = None
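A hedged usage sketch for get_verify_key_from_cross_signing_key (the key material is generated on the fly purely for illustration): the "keys" dict must contain exactly one entry, whose key ID and decoded verify key are returned.

from signedjson.key import (
    encode_verify_key_base64,
    generate_signing_key,
    get_verify_key,
)

from synapse.types import get_verify_key_from_cross_signing_key

# Build a fake cross-signing key dict with exactly one entry under "keys".
verify_b64 = encode_verify_key_base64(get_verify_key(generate_signing_key("1")))
key_info = {
    "user_id": "@alice:example.org",
    "usage": ["master"],
    "keys": {f"ed25519:{verify_b64}": verify_b64},
}

key_id, verify_key = get_verify_key_from_cross_signing_key(key_info)
print(key_id)  # -> "ed25519:<base64 public key>"

# A dict with zero or multiple entries under "keys" raises ValueError("Invalid key").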
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index eda92d86..867f315b 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -595,13 +595,14 @@ def cached(
def cachedList(
*, cached_method_name: str, list_name: str, num_args: Optional[int] = None
) -> Callable[[F], _CachedFunction[F]]:
- """Creates a descriptor that wraps a function in a `CacheListDescriptor`.
+ """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
- Used to do batch lookups for an already created cache. A single argument
+ Used to do batch lookups for an already created cache. One of the arguments
is specified as a list that is iterated through to lookup keys in the
original cache. A new tuple consisting of the (deduplicated) keys that weren't in
- the cache gets passed to the original function, the result of which is stored in the
- cache.
+    the cache gets passed to the original function, which is expected to result
+    in a map of key to value for each passed value. The new results are stored in the
+ original cache. Note that any missing values are cached as None.
Args:
cached_method_name: The name of the single-item lookup method.
@@ -614,11 +615,11 @@ def cachedList(
Example:
class Example:
- @cached(num_args=2)
- def do_something(self, first_arg):
+ @cached()
+ def do_something(self, first_arg, second_arg):
...
- @cachedList(do_something.cache, list_name="second_args", num_args=2)
+ @cachedList(cached_method_name="do_something", list_name="second_args")
def batch_do_something(self, first_arg, second_args):
...
"""
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 45ff0de6..a3b60578 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -13,6 +13,7 @@
# limitations under the License.
import logging
+import math
import threading
import weakref
from enum import Enum
@@ -40,6 +41,7 @@ from twisted.internet.interfaces import IReactorTime
from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.metrics.jemalloc import get_jemalloc_stats
from synapse.util import Clock, caches
from synapse.util.caches import CacheMetric, EvictionReason, register_cache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
@@ -106,10 +108,16 @@ GLOBAL_ROOT = ListNode["_Node"].create_root_node()
@wrap_as_background_process("LruCache._expire_old_entries")
-async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None:
+async def _expire_old_entries(
+ clock: Clock, expiry_seconds: int, autotune_config: Optional[dict]
+) -> None:
"""Walks the global cache list to find cache entries that haven't been
- accessed in the given number of seconds.
+    accessed in the given number of seconds, evicting them early if a given memory threshold has been breached.
"""
+ if autotune_config:
+ max_cache_memory_usage = autotune_config["max_cache_memory_usage"]
+ target_cache_memory_usage = autotune_config["target_cache_memory_usage"]
+ min_cache_ttl = autotune_config["min_cache_ttl"] / 1000
now = int(clock.time())
node = GLOBAL_ROOT.prev_node
@@ -119,11 +127,36 @@ async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None:
logger.debug("Searching for stale caches")
+ evicting_due_to_memory = False
+
+ # determine if we're evicting due to memory
+ jemalloc_interface = get_jemalloc_stats()
+ if jemalloc_interface and autotune_config:
+ try:
+ jemalloc_interface.refresh_stats()
+ mem_usage = jemalloc_interface.get_stat("allocated")
+ if mem_usage > max_cache_memory_usage:
+ logger.info("Begin memory-based cache eviction.")
+ evicting_due_to_memory = True
+ except Exception:
+ logger.warning(
+ "Unable to read allocated memory, skipping memory-based cache eviction."
+ )
+
while node is not GLOBAL_ROOT:
# Only the root node isn't a `_TimedListNode`.
assert isinstance(node, _TimedListNode)
- if node.last_access_ts_secs > now - expiry_seconds:
+ # if node has not aged past expiry_seconds and we are not evicting due to memory usage, there's
+ # nothing to do here
+ if (
+ node.last_access_ts_secs > now - expiry_seconds
+ and not evicting_due_to_memory
+ ):
+ break
+
+        # if the entry is newer than min_cache_ttl, stop: neither it nor anything accessed more recently should be evicted
+ if evicting_due_to_memory and now - node.last_access_ts_secs < min_cache_ttl:
break
cache_entry = node.get_cache_entry()
@@ -136,10 +169,29 @@ async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None:
assert cache_entry is not None
cache_entry.drop_from_cache()
+ # Check mem allocation periodically if we are evicting a bunch of caches
+ if jemalloc_interface and evicting_due_to_memory and (i + 1) % 100 == 0:
+ try:
+ jemalloc_interface.refresh_stats()
+ mem_usage = jemalloc_interface.get_stat("allocated")
+ if mem_usage < target_cache_memory_usage:
+ evicting_due_to_memory = False
+ logger.info("Stop memory-based cache eviction.")
+ except Exception:
+ logger.warning(
+ "Unable to read allocated memory, this may affect memory-based cache eviction."
+ )
+ # If we've failed to read the current memory usage then we
+ # should stop trying to evict based on memory usage
+ evicting_due_to_memory = False
+
# If we do lots of work at once we yield to allow other stuff to happen.
if (i + 1) % 10000 == 0:
logger.debug("Waiting during drop")
- await clock.sleep(0)
+ if node.last_access_ts_secs > now - expiry_seconds:
+ await clock.sleep(0.5)
+ else:
+ await clock.sleep(0)
logger.debug("Waking during drop")
node = next_node
@@ -156,21 +208,28 @@ async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None:
def setup_expire_lru_cache_entries(hs: "HomeServer") -> None:
"""Start a background job that expires all cache entries if they have not
- been accessed for the given number of seconds.
+ been accessed for the given number of seconds, or if a given memory usage threshold has been
+ breached.
"""
- if not hs.config.caches.expiry_time_msec:
+ if not hs.config.caches.expiry_time_msec and not hs.config.caches.cache_autotuning:
return
- logger.info(
- "Expiring LRU caches after %d seconds", hs.config.caches.expiry_time_msec / 1000
- )
+ if hs.config.caches.expiry_time_msec:
+ expiry_time = hs.config.caches.expiry_time_msec / 1000
+ logger.info("Expiring LRU caches after %d seconds", expiry_time)
+ else:
+ expiry_time = math.inf
global USE_GLOBAL_LIST
USE_GLOBAL_LIST = True
clock = hs.get_clock()
clock.looping_call(
- _expire_old_entries, 30 * 1000, clock, hs.config.caches.expiry_time_msec / 1000
+ _expire_old_entries,
+ 30 * 1000,
+ clock,
+ expiry_time,
+ hs.config.caches.cache_autotuning,
)
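For reference, a hypothetical autotune_config dict of the shape consumed by _expire_old_entries above (the key names come from the code; the values are illustrative). The memory thresholds are byte counts, and min_cache_ttl is in milliseconds, matching the division by 1000.

autotune_config = {
    # Start memory-based eviction once jemalloc reports more than ~1 GiB allocated.
    "max_cache_memory_usage": 1024 * 1024 * 1024,
    # Stop evicting again once usage falls below ~768 MiB.
    "target_cache_memory_usage": 768 * 1024 * 1024,
    # Never evict entries accessed within the last 5 minutes (milliseconds).
    "min_cache_ttl": 5 * 60 * 1000,
}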
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 81bfed26..d0a69ff8 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -16,8 +16,8 @@ import random
from types import TracebackType
from typing import TYPE_CHECKING, Any, Optional, Type
-import synapse.logging.context
from synapse.api.errors import CodeMessageException
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage import DataStore
from synapse.util import Clock
@@ -265,4 +265,4 @@ class RetryDestinationLimiter:
logger.exception("Failed to store destination_retry_timings")
# we deliberately do this in the background.
- synapse.logging.context.run_in_background(store_retry_timings)
+ run_as_background_process("store_retry_timings", store_retry_timings)
diff --git a/synapse/visibility.py b/synapse/visibility.py
index de6d2ffc..8aaa8c70 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -20,9 +20,9 @@ from typing_extensions import Final
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.events import EventBase
from synapse.events.utils import prune_event
-from synapse.storage import Storage
+from synapse.storage.controllers import StorageControllers
from synapse.storage.state import StateFilter
-from synapse.types import StateMap, get_domain_from_id
+from synapse.types import RetentionPolicy, StateMap, get_domain_from_id
logger = logging.getLogger(__name__)
@@ -47,7 +47,7 @@ _HISTORY_VIS_KEY: Final[Tuple[str, str]] = (EventTypes.RoomHistoryVisibility, ""
async def filter_events_for_client(
- storage: Storage,
+ storage: StorageControllers,
user_id: str,
events: List[EventBase],
is_peeking: bool = False,
@@ -94,7 +94,7 @@ async def filter_events_for_client(
if filter_send_to_client:
room_ids = {e.room_id for e in events}
- retention_policies = {}
+ retention_policies: Dict[str, RetentionPolicy] = {}
for room_id in room_ids:
retention_policies[
@@ -137,7 +137,7 @@ async def filter_events_for_client(
# events.
if not event.is_state():
retention_policy = retention_policies[event.room_id]
- max_lifetime = retention_policy.get("max_lifetime")
+ max_lifetime = retention_policy.max_lifetime
if max_lifetime is not None:
oldest_allowed_ts = storage.main.clock.time_msec() - max_lifetime
@@ -162,16 +162,7 @@ async def filter_events_for_client(
state = event_id_to_state[event.event_id]
# get the room_visibility at the time of the event.
- visibility_event = state.get(_HISTORY_VIS_KEY, None)
- if visibility_event:
- visibility = visibility_event.content.get(
- "history_visibility", HistoryVisibility.SHARED
- )
- else:
- visibility = HistoryVisibility.SHARED
-
- if visibility not in VISIBILITY_PRIORITY:
- visibility = HistoryVisibility.SHARED
+ visibility = get_effective_room_visibility_from_state(state)
# Always allow history visibility events on boundaries. This is done
# by setting the effective visibility to the least restrictive
@@ -267,8 +258,25 @@ async def filter_events_for_client(
return [ev for ev in filtered_events if ev]
+def get_effective_room_visibility_from_state(state: StateMap[EventBase]) -> str:
+ """Get the actual history vis, from a state map including the history_visibility event
+
+ Handles missing and invalid history visibility events.
+ """
+ visibility_event = state.get(_HISTORY_VIS_KEY, None)
+ if not visibility_event:
+ return HistoryVisibility.SHARED
+
+ visibility = visibility_event.content.get(
+ "history_visibility", HistoryVisibility.SHARED
+ )
+ if visibility not in VISIBILITY_PRIORITY:
+ visibility = HistoryVisibility.SHARED
+ return visibility
+
+
async def filter_events_for_server(
- storage: Storage,
+ storage: StorageControllers,
server_name: str,
events: List[EventBase],
redact: bool = True,
@@ -360,7 +368,7 @@ async def filter_events_for_server(
async def _event_to_history_vis(
- storage: Storage, events: Collection[EventBase]
+ storage: StorageControllers, events: Collection[EventBase]
) -> Dict[str, str]:
"""Get the history visibility at each of the given events
@@ -407,7 +415,7 @@ async def _event_to_history_vis(
async def _event_to_memberships(
- storage: Storage, events: Collection[EventBase], server_name: str
+ storage: StorageControllers, events: Collection[EventBase], server_name: str
) -> Dict[str, StateMap[EventBase]]:
"""Get the remote membership list at each of the given events
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index d547df8a..bc75ddd3 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -404,7 +404,6 @@ class AuthTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
"abcd",
- self.hs.config.server.server_name,
id="1234",
namespaces={
"users": [{"regex": "@_appservice.*:sender", "exclusive": True}]
@@ -433,7 +432,6 @@ class AuthTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
"abcd",
- self.hs.config.server.server_name,
id="1234",
namespaces={
"users": [{"regex": "@_appservice.*:sender", "exclusive": True}]
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 985d6e39..a269c477 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -20,7 +20,7 @@ from unittest.mock import patch
import jsonschema
from frozendict import frozendict
-from synapse.api.constants import EventContentFields
+from synapse.api.constants import EduTypes, EventContentFields
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.events import make_event_from_dict
@@ -85,13 +85,13 @@ class FilteringTestCase(unittest.HomeserverTestCase):
"org.matrix.not_labels": ["#work"],
},
"ephemeral": {
- "types": ["m.receipt", "m.typing"],
+ "types": [EduTypes.RECEIPT, EduTypes.TYPING],
"not_rooms": ["!726s6s6q:example.com"],
"not_senders": ["@spam:example.com"],
},
},
"presence": {
- "types": ["m.presence"],
+ "types": [EduTypes.PRESENCE],
"not_senders": ["@alice:example.com"],
},
"event_format": "client",
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 483d5463..f661a9ff 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -31,7 +31,6 @@ class TestRatelimiter(unittest.HomeserverTestCase):
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
appservice = ApplicationService(
None,
- "example.com",
id="foo",
rate_limited=True,
sender="@as:example.com",
@@ -62,7 +61,6 @@ class TestRatelimiter(unittest.HomeserverTestCase):
def test_allowed_appservice_via_can_requester_do_action(self):
appservice = ApplicationService(
None,
- "example.com",
id="foo",
rate_limited=False,
sender="@as:example.com",
diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py
new file mode 100644
index 00000000..532b6763
--- /dev/null
+++ b/tests/appservice/test_api.py
@@ -0,0 +1,101 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, List, Mapping
+from unittest.mock import Mock
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.appservice import ApplicationService
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests import unittest
+
+PROTOCOL = "myproto"
+TOKEN = "myastoken"
+URL = "http://mytestservice"
+
+
+class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ self.api = hs.get_application_service_api()
+ self.service = ApplicationService(
+ id="unique_identifier",
+ sender="@as:test",
+ url=URL,
+ token="unused",
+ hs_token=TOKEN,
+ )
+
+ def test_query_3pe_authenticates_token(self):
+ """
+ Tests that 3pe queries to the appservice are authenticated
+ with the appservice's token.
+ """
+
+ SUCCESS_RESULT_USER = [
+ {
+ "protocol": PROTOCOL,
+ "userid": "@a:user",
+ "fields": {
+ "more": "fields",
+ },
+ }
+ ]
+ SUCCESS_RESULT_LOCATION = [
+ {
+ "protocol": PROTOCOL,
+ "alias": "#a:room",
+ "fields": {
+ "more": "fields",
+ },
+ }
+ ]
+
+ URL_USER = f"{URL}/_matrix/app/unstable/thirdparty/user/{PROTOCOL}"
+ URL_LOCATION = f"{URL}/_matrix/app/unstable/thirdparty/location/{PROTOCOL}"
+
+ self.request_url = None
+
+ async def get_json(url: str, args: Mapping[Any, Any]) -> List[JsonDict]:
+ if not args.get(b"access_token"):
+ raise RuntimeError("Access token not provided")
+
+ self.assertEqual(args.get(b"access_token"), TOKEN)
+ self.request_url = url
+ if url == URL_USER:
+ return SUCCESS_RESULT_USER
+ elif url == URL_LOCATION:
+ return SUCCESS_RESULT_LOCATION
+ else:
+ raise RuntimeError(
+ "URL provided was invalid. This should never be seen."
+ )
+
+ # We assign to a method, which mypy doesn't like.
+ self.api.get_json = Mock(side_effect=get_json) # type: ignore[assignment]
+
+ result = self.get_success(
+ self.api.query_3pe(self.service, "user", PROTOCOL, {b"some": [b"field"]})
+ )
+ self.assertEqual(self.request_url, URL_USER)
+ self.assertEqual(result, SUCCESS_RESULT_USER)
+ result = self.get_success(
+ self.api.query_3pe(
+ self.service, "location", PROTOCOL, {b"some": [b"field"]}
+ )
+ )
+ self.assertEqual(self.request_url, URL_LOCATION)
+ self.assertEqual(result, SUCCESS_RESULT_LOCATION)
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index edc584d0..3018d3fc 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -23,7 +23,7 @@ from tests.test_utils import simple_async_mock
def _regex(regex: str, exclusive: bool = True) -> Namespace:
- return Namespace(exclusive, None, re.compile(regex))
+ return Namespace(exclusive, re.compile(regex))
class ApplicationServiceTestCase(unittest.TestCase):
@@ -33,7 +33,6 @@ class ApplicationServiceTestCase(unittest.TestCase):
sender="@as:test",
url="some_url",
token="some_token",
- hostname="matrix.org", # only used by get_groups_for_user
)
self.event = Mock(
event_id="$abc:xyz",
diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py
index 4bb82e81..d2b3c299 100644
--- a/tests/config/test_cache.py
+++ b/tests/config/test_cache.py
@@ -38,6 +38,7 @@ class CacheConfigTests(TestCase):
"SYNAPSE_NOT_CACHE": "BLAH",
}
self.config.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.resize_all_caches()
self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0})
@@ -52,6 +53,7 @@ class CacheConfigTests(TestCase):
"SYNAPSE_CACHE_FACTOR_FOO": 1,
}
self.config.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.resize_all_caches()
self.assertEqual(
dict(self.config.cache_factors),
@@ -71,6 +73,7 @@ class CacheConfigTests(TestCase):
config = {"caches": {"per_cache_factors": {"foo": 3}}}
self.config.read_config(config)
+ self.config.resize_all_caches()
self.assertEqual(cache.max_size, 300)
@@ -82,6 +85,7 @@ class CacheConfigTests(TestCase):
"""
config = {"caches": {"per_cache_factors": {"foo": 2}}}
self.config.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.resize_all_caches()
cache = LruCache(100)
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
@@ -99,6 +103,7 @@ class CacheConfigTests(TestCase):
config = {"caches": {"global_factor": 4}}
self.config.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.resize_all_caches()
self.assertEqual(cache.max_size, 400)
@@ -110,6 +115,7 @@ class CacheConfigTests(TestCase):
"""
config = {"caches": {"global_factor": 1.5}}
self.config.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.resize_all_caches()
cache = LruCache(100)
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
@@ -128,6 +134,7 @@ class CacheConfigTests(TestCase):
"SYNAPSE_CACHE_FACTOR_CACHE_B": 3,
}
self.config.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.resize_all_caches()
cache_a = LruCache(100)
add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor)
@@ -148,6 +155,7 @@ class CacheConfigTests(TestCase):
config = {"caches": {"event_cache_size": "10k"}}
self.config.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.resize_all_caches()
cache = LruCache(
max_size=self.config.event_cache_size,
diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py
index 06e0545a..8fa710c9 100644
--- a/tests/crypto/test_event_signing.py
+++ b/tests/crypto/test_event_signing.py
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import nacl.signing
-import signedjson.types
-from unpaddedbase64 import decode_base64
+from signedjson.key import decode_signing_key_base64
+from signedjson.types import SigningKey
from synapse.api.room_versions import RoomVersions
from synapse.crypto.event_signing import add_hashes_and_signatures
@@ -25,7 +23,7 @@ from tests import unittest
# Perform these tests using given secret key so we get entirely deterministic
# signatures output that we can test against.
-SIGNING_KEY_SEED = decode_base64("YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1")
+SIGNING_KEY_SEED = "YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1"
KEY_ALG = "ed25519"
KEY_VER = "1"
@@ -36,14 +34,9 @@ HOSTNAME = "domain"
class EventSigningTestCase(unittest.TestCase):
def setUp(self):
- # NB: `signedjson` expects `nacl.signing.SigningKey` instances which have been
- # monkeypatched to include new `alg` and `version` attributes. This is captured
- # by the `signedjson.types.SigningKey` protocol.
- self.signing_key: signedjson.types.SigningKey = nacl.signing.SigningKey( # type: ignore[assignment]
- SIGNING_KEY_SEED
+ self.signing_key: SigningKey = decode_signing_key_base64(
+ KEY_ALG, KEY_VER, SIGNING_KEY_SEED
)
- self.signing_key.alg = KEY_ALG
- self.signing_key.version = KEY_VER
def test_sign_minimal(self):
event_dict = {
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index d00ef24c..820a1a54 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -19,8 +19,8 @@ import attr
import canonicaljson
import signedjson.key
import signedjson.sign
-from nacl.signing import SigningKey
from signedjson.key import encode_verify_key_base64, get_verify_key
+from signedjson.types import SigningKey
from twisted.internet import defer
from twisted.internet.defer import Deferred, ensureDeferred
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index 3deb14c3..ffc3012a 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -439,7 +439,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
for edu in edus:
# Make sure we're only checking presence-type EDUs
- if edu["edu_type"] != EduTypes.Presence:
+ if edu["edu_type"] != EduTypes.PRESENCE:
continue
# EDUs can contain multiple presence updates
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index defbc68c..8ddce83b 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -29,7 +29,7 @@ class TestEventContext(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self.user_id = self.register_user("u1", "pass")
self.user_tok = self.login("u1", "pass")
@@ -87,7 +87,7 @@ class TestEventContext(unittest.HomeserverTestCase):
def _check_serialize_deserialize(self, event, context):
serialized = self.get_success(context.serialize(event, self.store))
- d_context = EventContext.deserialize(self.storage, serialized)
+ d_context = EventContext.deserialize(self._storage_controllers, serialized)
self.assertEqual(context.state_group, d_context.state_group)
self.assertEqual(context.rejected, d_context.rejected)
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 6b26353d..01a1db61 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -19,7 +19,7 @@ from signedjson.types import BaseKey, SigningKey
from twisted.internet import defer
-from synapse.api.constants import RoomEncryptionAlgorithms
+from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms
from synapse.rest import admin
from synapse.rest.client import login
from synapse.types import JsonDict, ReadReceipt
@@ -30,16 +30,16 @@ from tests.unittest import HomeserverTestCase, override_config
class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
- mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
- # Ensure a new Awaitable is created for each call.
- mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable(
- ["test", "host2"]
- )
- return self.setup_test_homeserver(
- state_handler=mock_state_handler,
+ hs = self.setup_test_homeserver(
federation_transport_client=Mock(spec=["send_transaction"]),
)
+ hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(
+ return_value=make_awaitable({"test", "host2"})
+ )
+
+ return hs
+
@override_config({"send_federation": True})
def test_send_receipts(self):
mock_send_transaction = (
@@ -63,7 +63,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
data["edus"],
[
{
- "edu_type": "m.receipt",
+ "edu_type": EduTypes.RECEIPT,
"content": {
"room_id": {
"m.read": {
@@ -103,7 +103,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
data["edus"],
[
{
- "edu_type": "m.receipt",
+ "edu_type": EduTypes.RECEIPT,
"content": {
"room_id": {
"m.read": {
@@ -138,7 +138,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
data["edus"],
[
{
- "edu_type": "m.receipt",
+ "edu_type": EduTypes.RECEIPT,
"content": {
"room_id": {
"m.read": {
@@ -322,8 +322,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# expect signing key update edu
self.assertEqual(len(self.edus), 2)
- self.assertEqual(self.edus.pop(0)["edu_type"], "m.signing_key_update")
- self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update")
+ self.assertEqual(self.edus.pop(0)["edu_type"], EduTypes.SIGNING_KEY_UPDATE)
+ self.assertEqual(
+ self.edus.pop(0)["edu_type"], EduTypes.UNSTABLE_SIGNING_KEY_UPDATE
+ )
# sign the devices
d1_json = build_device_dict(u1, "D1", device1_signing_key)
@@ -348,7 +350,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.assertEqual(len(self.edus), 2)
stream_id = None # FIXME: there is a discontinuity in the stream IDs: see #7142
for edu in self.edus:
- self.assertEqual(edu["edu_type"], "m.device_list_update")
+ self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE)
c = edu["content"]
if stream_id is not None:
self.assertEqual(c["prev_id"], [stream_id])
@@ -388,7 +390,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# expect three edus, in an unknown order
self.assertEqual(len(self.edus), 3)
for edu in self.edus:
- self.assertEqual(edu["edu_type"], "m.device_list_update")
+ self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE)
c = edu["content"]
self.assertGreaterEqual(
c.items(),
@@ -435,7 +437,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.assertEqual(len(self.edus), 3)
stream_id = None
for edu in self.edus:
- self.assertEqual(edu["edu_type"], "m.device_list_update")
+ self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE)
c = edu["content"]
self.assertEqual(c["prev_id"], [stream_id] if stream_id is not None else [])
if stream_id is not None:
@@ -487,7 +489,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# there should be a single update for this user.
self.assertEqual(len(self.edus), 1)
edu = self.edus.pop(0)
- self.assertEqual(edu["edu_type"], "m.device_list_update")
+ self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE)
c = edu["content"]
# synapse uses an empty prev_id list to indicate "needs a full resync".
@@ -544,7 +546,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# ... and we should get a single update for this user.
self.assertEqual(len(self.edus), 1)
edu = self.edus.pop(0)
- self.assertEqual(edu["edu_type"], "m.device_list_update")
+ self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE)
c = edu["content"]
# synapse uses an empty prev_id list to indicate "needs a full resync".
@@ -560,7 +562,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
"""Check that the given EDU is an update for the given device
Returns the stream_id.
"""
- self.assertEqual(edu["edu_type"], "m.device_list_update")
+ self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE)
content = edu["content"]
expected = {
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index b19365b8..413b3c94 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -134,6 +134,8 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
super().prepare(reactor, clock, hs)
+ self._storage_controllers = hs.get_storage_controllers()
+
# create the room
creator_user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")
@@ -207,7 +209,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
# the room should show that the new user is a member
r = self.get_success(
- self.hs.get_state_handler().get_current_state(self._room_id)
+ self._storage_controllers.state.get_current_state(self._room_id)
)
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
@@ -258,7 +260,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
# the room should show that the new user is a member
r = self.get_success(
- self.hs.get_state_handler().get_current_state(self._room_id)
+ self._storage_controllers.state.get_current_state(self._room_id)
)
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
diff --git a/tests/federation/transport/server/__init__.py b/tests/federation/transport/server/__init__.py
new file mode 100644
index 00000000..3a5f22c0
--- /dev/null
+++ b/tests/federation/transport/server/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/federation/transport/server/test__base.py b/tests/federation/transport/server/test__base.py
new file mode 100644
index 00000000..e63885c1
--- /dev/null
+++ b/tests/federation/transport/server/test__base.py
@@ -0,0 +1,141 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from http import HTTPStatus
+from typing import Dict, List, Tuple
+
+from synapse.api.errors import Codes
+from synapse.federation.transport.server import BaseFederationServlet
+from synapse.federation.transport.server._base import Authenticator, _parse_auth_header
+from synapse.http.server import JsonResource, cancellable
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util.ratelimitutils import FederationRateLimiter
+
+from tests import unittest
+from tests.http.server._base import EndpointCancellationTestHelperMixin
+
+
+class CancellableFederationServlet(BaseFederationServlet):
+ PATH = "/sleep"
+
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self.clock = hs.get_clock()
+
+ @cancellable
+ async def on_GET(
+ self, origin: str, content: None, query: Dict[bytes, List[bytes]]
+ ) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+ async def on_POST(
+ self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
+ ) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+
+class BaseFederationServletCancellationTests(
+ unittest.FederatingHomeserverTestCase, EndpointCancellationTestHelperMixin
+):
+ """Tests for `BaseFederationServlet` cancellation."""
+
+ skip = "`BaseFederationServlet` does not support cancellation yet."
+
+ path = f"{CancellableFederationServlet.PREFIX}{CancellableFederationServlet.PATH}"
+
+ def create_test_resource(self):
+ """Overrides `HomeserverTestCase.create_test_resource`."""
+ resource = JsonResource(self.hs)
+
+ CancellableFederationServlet(
+ hs=self.hs,
+ authenticator=Authenticator(self.hs),
+ ratelimiter=self.hs.get_federation_ratelimiter(),
+ server_name=self.hs.hostname,
+ ).register(resource)
+
+ return resource
+
+ def test_cancellable_disconnect(self) -> None:
+ """Test that handlers with the `@cancellable` flag can be cancelled."""
+ channel = self.make_signed_federation_request(
+ "GET", self.path, await_result=False
+ )
+
+ # Advance past all the rate limiting logic. If we disconnect too early, the
+ # request won't be processed.
+ self.pump()
+
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=True,
+ expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
+ )
+
+ def test_uncancellable_disconnect(self) -> None:
+ """Test that handlers without the `@cancellable` flag cannot be cancelled."""
+ channel = self.make_signed_federation_request(
+ "POST",
+ self.path,
+ content={},
+ await_result=False,
+ )
+
+ # Advance past all the rate limiting logic. If we disconnect too early, the
+ # request won't be processed.
+ self.pump()
+
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=False,
+ expected_body={"result": True},
+ )
+
+
+class BaseFederationAuthorizationTests(unittest.TestCase):
+ def test_authorization_header(self) -> None:
+ """Tests that the Authorization header is parsed correctly."""
+
+ # test a "normal" Authorization header
+ self.assertEqual(
+ _parse_auth_header(
+ b'X-Matrix origin=foo,key="ed25519:1",sig="sig",destination="bar"'
+ ),
+ ("foo", "ed25519:1", "sig", "bar"),
+ )
+ # test an Authorization with extra spaces, upper-case names, and escaped
+ # characters
+ self.assertEqual(
+ _parse_auth_header(
+ b'X-Matrix ORIGIN=foo,KEY="ed25\\519:1",SIG="sig",destination="bar"'
+ ),
+ ("foo", "ed25519:1", "sig", "bar"),
+ )
+ self.assertEqual(
+ _parse_auth_header(
+ b'X-Matrix origin=foo,key="ed25519:1",sig="sig",destination="bar",extra_field=ignored'
+ ),
+ ("foo", "ed25519:1", "sig", "bar"),
+ )
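
The assertions above pin down the `(origin, key, sig, destination)` tuple that `_parse_auth_header` extracts from a federation `Authorization: X-Matrix ...` header, including case-insensitive parameter names, backslash escapes, and ignored extra fields. As a rough stand-alone illustration of that parsing (a sketch only, not the implementation under test in synapse/federation/transport/server/_base.py; quoted values containing commas are not handled here):

from typing import Dict

def parse_x_matrix_params(header: bytes) -> Dict[str, str]:
    scheme, _, param_str = header.decode("ascii").partition(" ")
    if scheme != "X-Matrix":
        raise ValueError("not an X-Matrix Authorization header")
    params: Dict[str, str] = {}
    for part in param_str.split(","):
        name, _, value = part.strip().partition("=")
        # Names are case-insensitive; values may be quoted and may contain
        # backslash escapes, as the tests above exercise.
        params[name.lower()] = value.strip('"').replace("\\", "")
    return params

# parse_x_matrix_params(b'X-Matrix origin=foo,key="ed25519:1",sig="sig",destination="bar"')
# -> {'origin': 'foo', 'key': 'ed25519:1', 'sig': 'sig', 'destination': 'bar'}
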
diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index 5f001c33..cfd550a0 100644
--- a/tests/federation/transport/test_server.py
+++ b/tests/federation/transport/test_server.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.api.constants import EduTypes
+
from tests import unittest
from tests.unittest import DEBUG, override_config
@@ -50,7 +52,7 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
"/_matrix/federation/v1/send/txn_id_1234/",
content={
"edus": [
- {"edu_type": "m.device_list_update", "content": {"foo": "bar"}}
+ {"edu_type": EduTypes.DEVICE_LIST_UPDATE, "content": {"foo": "bar"}}
],
"pdus": [],
},
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 5b0cd1ab..d96d5aa1 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -22,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
import synapse.storage
+from synapse.api.constants import EduTypes
from synapse.appservice import (
ApplicationService,
TransactionOneTimeKeyCounts,
@@ -434,16 +435,6 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
},
)
- # "Complete" a transaction.
- # All this really does for us is make an entry in the application_services_state
- # database table, which tracks the current stream_token per stream ID per AS.
- self.get_success(
- self.hs.get_datastores().main.complete_appservice_txn(
- 0,
- interested_appservice,
- )
- )
-
# Now, pretend that we receive a large burst of read receipts (300 total) that
# all come in at once.
for i in range(300):
@@ -486,7 +477,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
# Check that the ephemeral event is a read receipt with the expected structure
latest_read_receipt = all_ephemeral_events[-1]
- self.assertEqual(latest_read_receipt["type"], "m.receipt")
+ self.assertEqual(latest_read_receipt["type"], EduTypes.RECEIPT)
event_id = list(latest_read_receipt["content"].keys())[0]
self.assertEqual(
@@ -706,7 +697,6 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
# Create an application service
appservice = ApplicationService(
token=random_string(10),
- hostname="example.com",
id=random_string(10),
sender="@as:example.com",
rate_limited=False,
@@ -785,7 +775,6 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
# Create an appservice that is interested in "local_user"
appservice = ApplicationService(
token=random_string(10),
- hostname="example.com",
id=random_string(10),
sender="@as:example.com",
rate_limited=False,
@@ -852,7 +841,6 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
self._service_token = "VERYSECRET"
self._service = ApplicationService(
self._service_token,
- "as1.invalid",
"as1",
"@as.sender:test",
namespaces={
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 11ad4422..53d49ca8 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -298,6 +298,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main
self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler()
+ self._storage_controllers = hs.get_storage_controllers()
# Create user
self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -335,7 +336,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
def _get_canonical_alias(self):
"""Get the canonical alias state of the room."""
return self.get_success(
- self.state_handler.get_current_state(
+ self._storage_controllers.state.get_current_state_event(
self.room_id, EventTypes.CanonicalAlias, ""
)
)
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 060ba5f5..e0eda545 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -50,7 +50,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
hs = self.setup_test_homeserver(federation_http_client=None)
self.handler = hs.get_federation_handler()
self.store = hs.get_datastores().main
- self.state_store = hs.get_storage().state
+ self.state_storage_controller = hs.get_storage_controllers().state
self._event_auth_handler = hs.get_event_auth_handler()
return hs
@@ -237,7 +237,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
)
current_state = self.get_success(
self.store.get_events_as_list(
- (self.get_success(self.store.get_current_state_ids(room_id))).values()
+ (
+ self.get_success(self.store.get_partial_current_state_ids(room_id))
+ ).values()
)
)
@@ -276,7 +278,11 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# federation handler wanting to backfill the fake event.
self.get_success(
federation_event_handler._process_received_pdu(
- self.OTHER_SERVER_NAME, event, state=current_state
+ self.OTHER_SERVER_NAME,
+ event,
+ state_ids={
+ (e.type, e.state_key): e.event_id for e in current_state
+ },
)
)
@@ -332,8 +338,11 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
most_recent_prev_event_depth,
) = self.get_success(self.store.get_max_depth_of(prev_event_ids))
# mapping from (type, state_key) -> state_event_id
+ assert most_recent_prev_event_id is not None
prev_state_map = self.get_success(
- self.state_store.get_state_ids_for_event(most_recent_prev_event_id)
+ self.state_storage_controller.get_state_ids_for_event(
+ most_recent_prev_event_id
+ )
)
# List of state event ID's
prev_state_ids = list(prev_state_map.values())
@@ -505,7 +514,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
self.get_success(d)
# sanity-check: the room should show that the new user is a member
- r = self.get_success(self.store.get_current_state_ids(room_id))
+ r = self.get_success(self.store.get_partial_current_state_ids(room_id))
self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id)
return join_event
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 489ba577..1a36c25c 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -70,7 +70,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
) -> None:
OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
main_store = self.hs.get_datastores().main
- state_storage = self.hs.get_storage().state
+ state_storage_controller = self.hs.get_storage_controllers().state
# create the room
user_id = self.register_user("kermit", "test")
@@ -91,7 +91,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join")
)
- initial_state_map = self.get_success(main_store.get_current_state_ids(room_id))
+ initial_state_map = self.get_success(
+ main_store.get_partial_current_state_ids(room_id)
+ )
auth_event_ids = [
initial_state_map[("m.room.create", "")],
@@ -146,9 +148,12 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
)
if prev_exists_as_outlier:
prev_event.internal_metadata.outlier = True
- persistence = self.hs.get_storage().persistence
+ persistence = self.hs.get_storage_controllers().persistence
self.get_success(
- persistence.persist_event(prev_event, EventContext.for_outlier())
+ persistence.persist_event(
+ prev_event,
+ EventContext.for_outlier(self.hs.get_storage_controllers()),
+ )
)
else:
@@ -214,7 +219,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
# check that the state at that event is as expected
state = self.get_success(
- state_storage.get_state_ids_for_event(pulled_event.event_id)
+ state_storage_controller.get_state_ids_for_event(pulled_event.event_id)
)
expected_state = {
(e.type, e.state_key): e.event_id for e in state_at_prev_event
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index f4f7ab48..44da96c7 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -37,7 +37,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.handler = self.hs.get_event_creation_handler()
- self.persist_event_storage = self.hs.get_storage().persistence
+ self._persist_event_storage_controller = (
+ self.hs.get_storage_controllers().persistence
+ )
self.user_id = self.register_user("tester", "foobar")
self.access_token = self.login("tester", "foobar")
@@ -65,7 +67,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.persist_event_storage.persist_event(memberEvent, memberEventContext)
+ self._persist_event_storage_controller.persist_event(
+ memberEvent, memberEventContext
+ )
)
return memberEvent, memberEventContext
@@ -129,7 +133,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertNotEqual(event1.event_id, event3.event_id)
ret_event3, event_pos3, _ = self.get_success(
- self.persist_event_storage.persist_event(event3, context)
+ self._persist_event_storage_controller.persist_event(event3, context)
)
# Assert that the returned values match those from the initial event
@@ -143,7 +147,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertNotEqual(event1.event_id, event3.event_id)
events, _ = self.get_success(
- self.persist_event_storage.persist_events([(event3, context)])
+ self._persist_event_storage_controller.persist_events([(event3, context)])
)
ret_event4 = events[0]
@@ -166,7 +170,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertNotEqual(event1.event_id, event2.event_id)
events, _ = self.get_success(
- self.persist_event_storage.persist_events(
+ self._persist_event_storage_controller.persist_events(
[(event1, context1), (event2, context2)]
)
)
diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py
index 0482a1ea..a95868b5 100644
--- a/tests/handlers/test_receipts.py
+++ b/tests/handlers/test_receipts.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
+from copy import deepcopy
from typing import List
-from synapse.api.constants import ReceiptTypes
+from synapse.api.constants import EduTypes, ReceiptTypes
from synapse.types import JsonDict
from tests import unittest
@@ -39,7 +39,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
}
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
}
],
[],
@@ -64,7 +64,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
},
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
}
],
[
@@ -79,7 +79,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
}
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
}
],
)
@@ -105,7 +105,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
},
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
}
],
[
@@ -120,43 +120,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
}
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
- }
- ],
- )
-
- def test_handles_missing_content_of_m_read(self):
- self._test_filters_private(
- [
- {
- "content": {
- "$14356419ggffg114394fHBLK:matrix.org": {ReceiptTypes.READ: {}},
- "$1435641916114394fHBLK:matrix.org": {
- ReceiptTypes.READ: {
- "@user:jki.re": {
- "ts": 1436451550453,
- }
- }
- },
- },
- "room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
- }
- ],
- [
- {
- "content": {
- "$14356419ggffg114394fHBLK:matrix.org": {ReceiptTypes.READ: {}},
- "$1435641916114394fHBLK:matrix.org": {
- ReceiptTypes.READ: {
- "@user:jki.re": {
- "ts": 1436451550453,
- }
- }
- },
- },
- "room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
}
],
)
@@ -176,7 +140,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
},
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
}
],
[
@@ -191,7 +155,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
},
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
}
],
)
@@ -210,7 +174,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
},
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
},
{
"content": {
@@ -223,7 +187,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
},
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
},
],
[
@@ -238,7 +202,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
}
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
}
],
)
@@ -260,7 +224,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
},
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
},
],
[
@@ -273,7 +237,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
},
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
},
],
)
@@ -302,7 +266,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
},
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
}
],
[
@@ -327,14 +291,38 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
}
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
}
],
)
+ def test_we_do_not_mutate(self):
+ """Ensure the input values are not modified."""
+ events = [
+ {
+ "content": {
+ "$1435641916114394fHBLK:matrix.org": {
+ ReceiptTypes.READ_PRIVATE: {
+ "@rikj:jki.re": {
+ "ts": 1436451550453,
+ }
+ }
+ }
+ },
+ "room_id": "!jEsUZKDJdhlrceRyVU:example.org",
+ "type": EduTypes.RECEIPT,
+ }
+ ]
+ original_events = deepcopy(events)
+ self._test_filters_private(events, [])
+ # Since the events are fed in from a cache they should not be modified.
+ self.assertEqual(events, original_events)
+
def _test_filters_private(
self, events: List[JsonDict], expected_output: List[JsonDict]
):
"""Tests that the _filter_out_private returns the expected output"""
- filtered_events = self.event_source.filter_out_private(events, "@me:server.org")
+ filtered_events = self.event_source.filter_out_private_receipts(
+ events, "@me:server.org"
+ )
self.assertEqual(filtered_events, expected_output)
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index e74eb717..05466556 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -179,7 +179,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
result_children_ids.append(
[
(cs["room_id"], cs["state_key"])
- for cs in result_room.get("children_state")
+ for cs in result_room["children_state"]
]
)
@@ -772,7 +772,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
{
"room_id": public_room,
"world_readable": False,
- "join_rules": JoinRules.PUBLIC,
+ "join_rule": JoinRules.PUBLIC,
},
),
(
@@ -780,7 +780,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
{
"room_id": knock_room,
"world_readable": False,
- "join_rules": JoinRules.KNOCK,
+ "join_rule": JoinRules.KNOCK,
},
),
(
@@ -788,7 +788,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
{
"room_id": not_invited_room,
"world_readable": False,
- "join_rules": JoinRules.INVITE,
+ "join_rule": JoinRules.INVITE,
},
),
(
@@ -796,7 +796,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
{
"room_id": invited_room,
"world_readable": False,
- "join_rules": JoinRules.INVITE,
+ "join_rule": JoinRules.INVITE,
},
),
(
@@ -804,7 +804,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
{
"room_id": restricted_room,
"world_readable": False,
- "join_rules": JoinRules.RESTRICTED,
+ "join_rule": JoinRules.RESTRICTED,
"allowed_room_ids": [],
},
),
@@ -813,7 +813,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
{
"room_id": restricted_accessible_room,
"world_readable": False,
- "join_rules": JoinRules.RESTRICTED,
+ "join_rule": JoinRules.RESTRICTED,
"allowed_room_ids": [self.room],
},
),
@@ -822,7 +822,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
{
"room_id": world_readable_room,
"world_readable": True,
- "join_rules": JoinRules.INVITE,
+ "join_rule": JoinRules.INVITE,
},
),
(
@@ -830,7 +830,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
{
"room_id": joined_room,
"world_readable": False,
- "join_rules": JoinRules.INVITE,
+ "join_rule": JoinRules.INVITE,
},
),
)
@@ -911,7 +911,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
{
"room_id": fed_room,
"world_readable": False,
- "join_rules": JoinRules.INVITE,
+ "join_rule": JoinRules.INVITE,
},
)
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 865b8b7e..db3302a4 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -160,6 +160,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Blow away caches (supported room versions can only change due to a restart).
self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
self.store._get_event_cache.clear()
+ self.store._event_ref.clear()
# The rooms should be excluded from the sync response.
# Get a new request key.
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 5f2e26a5..7af13331 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -21,6 +21,7 @@ from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
+from synapse.api.constants import EduTypes
from synapse.api.errors import AuthError
from synapse.federation.transport.server import TransportLayerServer
from synapse.server import HomeServer
@@ -128,10 +129,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
hs.get_event_auth_handler().check_host_in_room = check_host_in_room
- def get_joined_hosts_for_room(room_id: str):
+ async def get_current_hosts_in_room(room_id: str):
return {member.domain for member in self.room_members}
- self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
+ hs.get_storage_controllers().state.get_current_hosts_in_room = (
+ get_current_hosts_in_room
+ )
async def get_users_in_room(room_id: str):
return {str(u) for u in self.room_members}
@@ -145,7 +148,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
- self.datastore.get_current_state_deltas = Mock(return_value=(0, None))
+ self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None))
self.datastore.get_to_device_stream_token = lambda: 0
self.datastore.get_new_device_msgs_for_remote = (
@@ -184,7 +187,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
events[0],
[
{
- "type": "m.typing",
+ "type": EduTypes.TYPING,
"room_id": ROOM_ID,
"content": {"user_ids": [U_APPLE.to_string()]},
}
@@ -209,7 +212,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"farm",
path="/_matrix/federation/v1/send/1000000",
data=_expect_edu_transaction(
- "m.typing",
+ EduTypes.TYPING,
content={
"room_id": ROOM_ID,
"user_id": U_APPLE.to_string(),
@@ -231,7 +234,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"PUT",
"/_matrix/federation/v1/send/1000000",
_make_edu_transaction_json(
- "m.typing",
+ EduTypes.TYPING,
content={
"room_id": ROOM_ID,
"user_id": U_ONION.to_string(),
@@ -254,7 +257,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
events[0],
[
{
- "type": "m.typing",
+ "type": EduTypes.TYPING,
"room_id": ROOM_ID,
"content": {"user_ids": [U_ONION.to_string()]},
}
@@ -270,7 +273,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"PUT",
"/_matrix/federation/v1/send/1000000",
_make_edu_transaction_json(
- "m.typing",
+ EduTypes.TYPING,
content={
"room_id": OTHER_ROOM_ID,
"user_id": U_ONION.to_string(),
@@ -324,7 +327,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"farm",
path="/_matrix/federation/v1/send/1000000",
data=_expect_edu_transaction(
- "m.typing",
+ EduTypes.TYPING,
content={
"room_id": ROOM_ID,
"user_id": U_APPLE.to_string(),
@@ -345,7 +348,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
events[0],
- [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
+ [
+ {
+ "type": EduTypes.TYPING,
+ "room_id": ROOM_ID,
+ "content": {"user_ids": []},
+ }
+ ],
)
def test_typing_timeout(self) -> None:
@@ -379,7 +388,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
events[0],
[
{
- "type": "m.typing",
+ "type": EduTypes.TYPING,
"room_id": ROOM_ID,
"content": {"user_ids": [U_APPLE.to_string()]},
}
@@ -402,7 +411,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
events[0],
- [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
+ [
+ {
+ "type": EduTypes.TYPING,
+ "room_id": ROOM_ID,
+ "content": {"user_ids": []},
+ }
+ ],
)
# SYN-230 - see if we can still set after timeout
@@ -433,7 +448,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
events[0],
[
{
- "type": "m.typing",
+ "type": EduTypes.TYPING,
"room_id": ROOM_ID,
"content": {"user_ids": [U_APPLE.to_string()]},
}
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 4d658d29..9e39cd97 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -60,7 +60,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.appservice = ApplicationService(
token="i_am_an_app_service",
- hostname="test",
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
# Note: this user does not match the regex above, so that tests
@@ -954,7 +953,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
self.get_success(
- self.hs.get_storage().persistence.persist_event(event, context)
+ self.hs.get_storage_controllers().persistence.persist_event(event, context)
)
def test_local_user_leaving_room_remains_in_user_directory(self) -> None:
diff --git a/tests/http/server/__init__.py b/tests/http/server/__init__.py
new file mode 100644
index 00000000..3a5f22c0
--- /dev/null
+++ b/tests/http/server/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
new file mode 100644
index 00000000..b9f1a381
--- /dev/null
+++ b/tests/http/server/_base.py
@@ -0,0 +1,100 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from http import HTTPStatus
+from typing import Any, Callable, Optional, Union
+from unittest import mock
+
+from twisted.internet.error import ConnectionDone
+
+from synapse.http.server import (
+ HTTP_STATUS_REQUEST_CANCELLED,
+ respond_with_html_bytes,
+ respond_with_json,
+)
+from synapse.types import JsonDict
+
+from tests import unittest
+from tests.server import FakeChannel, ThreadedMemoryReactorClock
+
+
+class EndpointCancellationTestHelperMixin(unittest.TestCase):
+ """Provides helper methods for testing cancellation of endpoints."""
+
+ def _test_disconnect(
+ self,
+ reactor: ThreadedMemoryReactorClock,
+ channel: FakeChannel,
+ expect_cancellation: bool,
+ expected_body: Union[bytes, JsonDict],
+ expected_code: Optional[int] = None,
+ ) -> None:
+ """Disconnects an in-flight request and checks the response.
+
+ Args:
+ reactor: The twisted reactor running the request handler.
+ channel: The `FakeChannel` for the request.
+ expect_cancellation: `True` if request processing is expected to be
+ cancelled, `False` if the request should run to completion.
+ expected_body: The expected response for the request.
+ expected_code: The expected status code for the request. Defaults to `200`
+ or `499` depending on `expect_cancellation`.
+ """
+ # Determine the expected status code.
+ if expected_code is None:
+ if expect_cancellation:
+ expected_code = HTTP_STATUS_REQUEST_CANCELLED
+ else:
+ expected_code = HTTPStatus.OK
+
+ request = channel.request
+ self.assertFalse(
+ channel.is_finished(),
+ "Request finished before we could disconnect - "
+ "was `await_result=False` passed to `make_request`?",
+ )
+
+ # We're about to disconnect the request. This also disconnects the channel, so
+ # we have to rely on mocks to extract the response.
+ respond_method: Callable[..., Any]
+ if isinstance(expected_body, bytes):
+ respond_method = respond_with_html_bytes
+ else:
+ respond_method = respond_with_json
+
+ with mock.patch(
+ f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
+ ) as respond_mock:
+ # Disconnect the request.
+ request.connectionLost(reason=ConnectionDone())
+
+ if expect_cancellation:
+ # An immediate cancellation is expected.
+ respond_mock.assert_called_once()
+ args, _kwargs = respond_mock.call_args
+ code, body = args[1], args[2]
+ self.assertEqual(code, expected_code)
+ self.assertEqual(request.code, expected_code)
+ self.assertEqual(body, expected_body)
+ else:
+ respond_mock.assert_not_called()
+
+ # The handler is expected to run to completion.
+ reactor.pump([1.0])
+ respond_mock.assert_called_once()
+ args, _kwargs = respond_mock.call_args
+ code, body = args[1], args[2]
+ self.assertEqual(code, expected_code)
+ self.assertEqual(request.code, expected_code)
+ self.assertEqual(body, expected_body)
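
The `@cancellable` decorator used by the test servlets in this patch marks a handler as safe to interrupt when the client disconnects, which is exactly the distinction `_test_disconnect` above checks against handlers that must run to completion. A self-contained sketch of that marker idea (an approximation for orientation, not Synapse's actual `synapse.http.server.cancellable`):

from typing import Any, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Any])

def cancellable(method: F) -> F:
    # Tag the handler; the request machinery can then consult this flag before
    # deciding whether a client disconnect should cancel the running handler.
    method.cancellable = True  # type: ignore[attr-defined]
    return method

@cancellable
async def on_GET() -> None:
    ...

assert getattr(on_GET, "cancellable", False)
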
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index 638babae..006dbab0 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -26,7 +26,7 @@ from twisted.web.http import HTTPChannel
from synapse.api.errors import RequestSendFailed
from synapse.http.matrixfederationclient import (
- MAX_RESPONSE_SIZE,
+ JsonParser,
MatrixFederationHttpClient,
MatrixFederationRequest,
)
@@ -609,9 +609,9 @@ class FederationClientTests(HomeserverTestCase):
while not test_d.called:
protocol.dataReceived(b"a" * chunk_size)
sent += chunk_size
- self.assertLessEqual(sent, MAX_RESPONSE_SIZE)
+ self.assertLessEqual(sent, JsonParser.MAX_RESPONSE_SIZE)
- self.assertEqual(sent, MAX_RESPONSE_SIZE)
+ self.assertEqual(sent, JsonParser.MAX_RESPONSE_SIZE)
f = self.failureResultOf(test_d)
self.assertIsInstance(f.value, RequestSendFailed)
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
index a80bfb9f..b3655d7b 100644
--- a/tests/http/test_servlet.py
+++ b/tests/http/test_servlet.py
@@ -12,16 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
+from http import HTTPStatus
from io import BytesIO
+from typing import Tuple
from unittest.mock import Mock
-from synapse.api.errors import SynapseError
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import cancellable
from synapse.http.servlet import (
+ RestServlet,
parse_json_object_from_request,
parse_json_value_from_request,
)
+from synapse.http.site import SynapseRequest
+from synapse.rest.client._base import client_patterns
+from synapse.server import HomeServer
+from synapse.types import JsonDict
from tests import unittest
+from tests.http.server._base import EndpointCancellationTestHelperMixin
def make_request(content):
@@ -40,19 +49,21 @@ class TestServletUtils(unittest.TestCase):
"""Basic tests for parse_json_value_from_request."""
# Test round-tripping.
obj = {"foo": 1}
- result = parse_json_value_from_request(make_request(obj))
- self.assertEqual(result, obj)
+ result1 = parse_json_value_from_request(make_request(obj))
+ self.assertEqual(result1, obj)
# Results don't have to be objects.
- result = parse_json_value_from_request(make_request(b'["foo"]'))
- self.assertEqual(result, ["foo"])
+ result2 = parse_json_value_from_request(make_request(b'["foo"]'))
+ self.assertEqual(result2, ["foo"])
# Test empty.
with self.assertRaises(SynapseError):
parse_json_value_from_request(make_request(b""))
- result = parse_json_value_from_request(make_request(b""), allow_empty_body=True)
- self.assertIsNone(result)
+ result3 = parse_json_value_from_request(
+ make_request(b""), allow_empty_body=True
+ )
+ self.assertIsNone(result3)
# Invalid UTF-8.
with self.assertRaises(SynapseError):
@@ -76,3 +87,52 @@ class TestServletUtils(unittest.TestCase):
# Test not an object
with self.assertRaises(SynapseError):
parse_json_object_from_request(make_request(b'["foo"]'))
+
+
+class CancellableRestServlet(RestServlet):
+ """A `RestServlet` with a mix of cancellable and uncancellable handlers."""
+
+ PATTERNS = client_patterns("/sleep$")
+
+ def __init__(self, hs: HomeServer):
+ super().__init__()
+ self.clock = hs.get_clock()
+
+ @cancellable
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+
+class TestRestServletCancellation(
+ unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
+):
+ """Tests for `RestServlet` cancellation."""
+
+ servlets = [
+ lambda hs, http_server: CancellableRestServlet(hs).register(http_server)
+ ]
+
+ def test_cancellable_disconnect(self) -> None:
+ """Test that handlers with the `@cancellable` flag can be cancelled."""
+ channel = self.make_request("GET", "/sleep", await_result=False)
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=True,
+ expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
+ )
+
+ def test_uncancellable_disconnect(self) -> None:
+ """Test that handlers without the `@cancellable` flag cannot be cancelled."""
+ channel = self.make_request("POST", "/sleep", await_result=False)
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=False,
+ expected_body={"result": True},
+ )
diff --git a/tests/http/test_site.py b/tests/http/test_site.py
index 8c13b4f6..b2dbf76d 100644
--- a/tests/http/test_site.py
+++ b/tests/http/test_site.py
@@ -36,7 +36,7 @@ class SynapseRequestTestCase(HomeserverTestCase):
# as a control case, first send a regular request.
# complete the connection and wire it up to a fake transport
- client_address = IPv6Address("TCP", "::1", "2345")
+ client_address = IPv6Address("TCP", "::1", 2345)
protocol = factory.buildProtocol(client_address)
transport = StringTransport()
protocol.makeConnection(transport)
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 8bc84aaa..169e29b5 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -399,7 +399,7 @@ class ModuleApiTestCase(HomeserverTestCase):
for edu in edus:
# Make sure we're only checking presence-type EDUs
- if edu["edu_type"] != EduTypes.Presence:
+ if edu["edu_type"] != EduTypes.PRESENCE:
continue
# EDUs can contain multiple presence updates
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 5dba1870..9b623d00 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Set, Tuple, Union
import frozendict
@@ -26,7 +26,12 @@ from tests import unittest
class PushRuleEvaluatorTestCase(unittest.TestCase):
- def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluatorForEvent:
+ def _get_evaluator(
+ self,
+ content: JsonDict,
+ relations: Optional[Dict[str, Set[Tuple[str, str]]]] = None,
+ relations_match_enabled: bool = False,
+ ) -> PushRuleEvaluatorForEvent:
event = FrozenEvent(
{
"event_id": "$event_id",
@@ -42,7 +47,12 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
sender_power_level = 0
power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
return PushRuleEvaluatorForEvent(
- event, room_member_count, sender_power_level, power_levels
+ event,
+ room_member_count,
+ sender_power_level,
+ power_levels,
+ relations or set(),
+ relations_match_enabled,
)
def test_display_name(self) -> None:
@@ -276,3 +286,71 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
push_rule_evaluator.tweaks_for_actions(actions),
{"sound": "default", "highlight": True},
)
+
+ def test_relation_match(self) -> None:
+ """Test the relation_match push rule kind."""
+
+ # Check if the experimental feature is disabled.
+ evaluator = self._get_evaluator(
+ {}, {"m.annotation": {("@user:test", "m.reaction")}}
+ )
+ condition = {"kind": "relation_match"}
+ # Oddly, an unknown condition always matches.
+ self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+
+ # A push rule evaluator with the experimental rule enabled.
+ evaluator = self._get_evaluator(
+ {}, {"m.annotation": {("@user:test", "m.reaction")}}, True
+ )
+
+ # Check just relation type.
+ condition = {
+ "kind": "org.matrix.msc3772.relation_match",
+ "rel_type": "m.annotation",
+ }
+ self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+
+ # Check relation type and sender.
+ condition = {
+ "kind": "org.matrix.msc3772.relation_match",
+ "rel_type": "m.annotation",
+ "sender": "@user:test",
+ }
+ self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+ condition = {
+ "kind": "org.matrix.msc3772.relation_match",
+ "rel_type": "m.annotation",
+ "sender": "@other:test",
+ }
+ self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
+
+ # Check relation type and event type.
+ condition = {
+ "kind": "org.matrix.msc3772.relation_match",
+ "rel_type": "m.annotation",
+ "type": "m.reaction",
+ }
+ self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+
+ # Check just sender, this fails since rel_type is required.
+ condition = {
+ "kind": "org.matrix.msc3772.relation_match",
+ "sender": "@user:test",
+ }
+ self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
+
+ # Check sender glob.
+ condition = {
+ "kind": "org.matrix.msc3772.relation_match",
+ "rel_type": "m.annotation",
+ "sender": "@*:test",
+ }
+ self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+
+ # Check event type glob.
+ condition = {
+ "kind": "org.matrix.msc3772.relation_match",
+ "rel_type": "m.annotation",
+ "event_type": "*.reaction",
+ }
+ self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
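
The conditions above are fed to the evaluator in isolation; in a real configuration a condition of this kind would sit inside a full push rule. A rough illustration of that shape (the rule_id and actions below are assumptions for the example, not taken from this patch):

example_rule = {
    "rule_id": ".org.example.annotated_event",  # hypothetical identifier
    "default": False,
    "enabled": True,
    "conditions": [
        {
            "kind": "org.matrix.msc3772.relation_match",
            "rel_type": "m.annotation",
            "type": "m.reaction",
        }
    ],
    "actions": ["notify", {"set_tweak": "highlight", "value": False}],
}
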
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index a7602b4c..970d5e53 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Dict, List, Optional, Tuple
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Set, Tuple
from twisted.internet.address import IPv4Address
from twisted.internet.protocol import Protocol
@@ -32,6 +33,7 @@ from synapse.server import HomeServer
from tests import unittest
from tests.server import FakeTransport
+from tests.utils import USE_POSTGRES_FOR_TESTS
try:
import hiredis
@@ -475,22 +477,25 @@ class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub."""
def __init__(self):
- self._subscribers = set()
+ self._subscribers_by_channel: Dict[
+ bytes, Set["FakeRedisPubSubProtocol"]
+ ] = defaultdict(set)
- def add_subscriber(self, conn):
+ def add_subscriber(self, conn, channel: bytes):
"""A connection has called SUBSCRIBE"""
- self._subscribers.add(conn)
+ self._subscribers_by_channel[channel].add(conn)
def remove_subscriber(self, conn):
- """A connection has called UNSUBSCRIBE"""
- self._subscribers.discard(conn)
+ """A connection has lost connection"""
+ for subscribers in self._subscribers_by_channel.values():
+ subscribers.discard(conn)
- def publish(self, conn, channel, msg) -> int:
+ def publish(self, conn, channel: bytes, msg) -> int:
"""A connection want to publish a message to subscribers."""
- for sub in self._subscribers:
+ for sub in self._subscribers_by_channel[channel]:
sub.send(["message", channel, msg])
- return len(self._subscribers)
+ return len(self._subscribers_by_channel)
def buildProtocol(self, addr):
return FakeRedisPubSubProtocol(self)
@@ -531,9 +536,10 @@ class FakeRedisPubSubProtocol(Protocol):
num_subscribers = self._server.publish(self, channel, message)
self.send(num_subscribers)
elif command == b"SUBSCRIBE":
- (channel,) = args
- self._server.add_subscriber(self)
- self.send(["subscribe", channel, 1])
+ for idx, channel in enumerate(args):
+ num_channels = idx + 1
+ self._server.add_subscriber(self, channel)
+ self.send(["subscribe", channel, num_channels])
# Since we use SET/GET to cache things we can safely no-op them.
elif command == b"SET":
@@ -576,3 +582,27 @@ class FakeRedisPubSubProtocol(Protocol):
def connectionLost(self, reason):
self._server.remove_subscriber(self)
+
+
+class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase):
+ """
+ A test case that enables Redis, providing a fake Redis server.
+ """
+
+ if not hiredis:
+ skip = "Requires hiredis"
+
+ if not USE_POSTGRES_FOR_TESTS:
+ # Redis replication only takes place on Postgres
+ skip = "Requires Postgres"
+
+ def default_config(self) -> Dict[str, Any]:
+ """
+ Overrides the default config to enable Redis.
+ Even if the test only uses make_worker_hs, the main process needs Redis
+ enabled otherwise it won't create a Fake Redis server to listen on the
+ Redis port and accept fake TCP connections.
+ """
+ base = super().default_config()
+ base["redis"] = {"enabled": True}
+ return base
diff --git a/tests/replication/http/__init__.py b/tests/replication/http/__init__.py
new file mode 100644
index 00000000..3a5f22c0
--- /dev/null
+++ b/tests/replication/http/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py
new file mode 100644
index 00000000..a5ab093a
--- /dev/null
+++ b/tests/replication/http/test__base.py
@@ -0,0 +1,106 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from http import HTTPStatus
+from typing import Tuple
+
+from twisted.web.server import Request
+
+from synapse.api.errors import Codes
+from synapse.http.server import JsonResource, cancellable
+from synapse.replication.http import REPLICATION_PREFIX
+from synapse.replication.http._base import ReplicationEndpoint
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+
+from tests import unittest
+from tests.http.server._base import EndpointCancellationTestHelperMixin
+
+
+class CancellableReplicationEndpoint(ReplicationEndpoint):
+ NAME = "cancellable_sleep"
+ PATH_ARGS = ()
+ CACHE = False
+
+ def __init__(self, hs: HomeServer):
+ super().__init__(hs)
+ self.clock = hs.get_clock()
+
+ @staticmethod
+ async def _serialize_payload() -> JsonDict:
+ return {}
+
+ @cancellable
+ async def _handle_request( # type: ignore[override]
+ self, request: Request
+ ) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+
+class UncancellableReplicationEndpoint(ReplicationEndpoint):
+ NAME = "uncancellable_sleep"
+ PATH_ARGS = ()
+ CACHE = False
+
+ def __init__(self, hs: HomeServer):
+ super().__init__(hs)
+ self.clock = hs.get_clock()
+
+ @staticmethod
+ async def _serialize_payload() -> JsonDict:
+ return {}
+
+ async def _handle_request( # type: ignore[override]
+ self, request: Request
+ ) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+
+class ReplicationEndpointCancellationTestCase(
+ unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
+):
+ """Tests for `ReplicationEndpoint` cancellation."""
+
+ def create_test_resource(self):
+ """Overrides `HomeserverTestCase.create_test_resource`."""
+ resource = JsonResource(self.hs)
+
+ CancellableReplicationEndpoint(self.hs).register(resource)
+ UncancellableReplicationEndpoint(self.hs).register(resource)
+
+ return resource
+
+ def test_cancellable_disconnect(self) -> None:
+ """Test that handlers with the `@cancellable` flag can be cancelled."""
+ path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/"
+ channel = self.make_request("POST", path, await_result=False)
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=True,
+ expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
+ )
+
+ def test_uncancellable_disconnect(self) -> None:
+ """Test that handlers without the `@cancellable` flag cannot be cancelled."""
+ path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/"
+ channel = self.make_request("POST", path, await_result=False)
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=False,
+ expected_body={"result": True},
+ )
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 85be79d1..c5705256 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -32,7 +32,7 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase):
self.master_store = hs.get_datastores().main
self.slaved_store = self.worker_hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
def replicate(self):
"""Tell the master side of replication that something has happened, and then
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 297a9e77..6d3d4afe 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -262,7 +262,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
)
msg, msgctx = self.build_event()
self.get_success(
- self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)])
+ self._storage_controllers.persistence.persist_events(
+ [(j2, j2ctx), (msg, msgctx)]
+ )
)
self.replicate()
@@ -323,12 +325,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
if backfill:
self.get_success(
- self.storage.persistence.persist_events(
+ self._storage_controllers.persistence.persist_events(
[(event, context)], backfilled=True
)
)
else:
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(
+ self._storage_controllers.persistence.persist_event(event, context)
+ )
return event
diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py
index 5bbbd5fb..19f57115 100644
--- a/tests/replication/slave/storage/test_receipts.py
+++ b/tests/replication/slave/storage/test_receipts.py
@@ -31,7 +31,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
def prepare(self, reactor, clock, homeserver):
super().prepare(reactor, clock, homeserver)
self.room_creator = homeserver.get_room_creation_handler()
- self.persist_event_storage = self.hs.get_storage().persistence
+ self.persist_event_storage_controller = (
+ self.hs.get_storage_controllers().persistence
+ )
# Create a test user
self.ourUser = UserID.from_string(OUR_USER_ID)
@@ -61,7 +63,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
)
)
self.get_success(
- self.persist_event_storage.persist_event(memberEvent, memberEventContext)
+ self.persist_event_storage_controller.persist_event(
+ memberEvent, memberEventContext
+ )
)
# Join the second user to the second room
@@ -76,7 +80,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
)
)
self.get_success(
- self.persist_event_storage.persist_event(memberEvent, memberEventContext)
+ self.persist_event_storage_controller.persist_event(
+ memberEvent, memberEventContext
+ )
)
def test_return_empty_with_no_data(self):
diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
new file mode 100644
index 00000000..e6a19eaf
--- /dev/null
+++ b/tests/replication/tcp/test_handler.py
@@ -0,0 +1,73 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tests.replication._base import RedisMultiWorkerStreamTestCase
+
+
+class ChannelsTestCase(RedisMultiWorkerStreamTestCase):
+ def test_subscribed_to_enough_redis_channels(self) -> None:
+ # The default main process is subscribed to the USER_IP channel.
+ self.assertCountEqual(
+ self.hs.get_replication_command_handler()._channels_to_subscribe_to,
+ ["USER_IP"],
+ )
+
+ def test_background_worker_subscribed_to_user_ip(self) -> None:
+ # The default main process is subscribed to the USER_IP channel.
+ worker1 = self.make_worker_hs(
+ "synapse.app.generic_worker",
+ extra_config={
+ "worker_name": "worker1",
+ "run_background_tasks_on": "worker1",
+ "redis": {"enabled": True},
+ },
+ )
+ self.assertIn(
+ "USER_IP",
+ worker1.get_replication_command_handler()._channels_to_subscribe_to,
+ )
+
+ # Advance so the Redis subscription gets processed
+ self.pump(0.1)
+
+ # The counts are 2 because both the main process and the worker are subscribed.
+ self.assertEqual(len(self._redis_server._subscribers_by_channel[b"test"]), 2)
+ self.assertEqual(
+ len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 2
+ )
+
+ def test_non_background_worker_not_subscribed_to_user_ip(self) -> None:
+ # The default main process is subscribed to the USER_IP channel.
+ worker2 = self.make_worker_hs(
+ "synapse.app.generic_worker",
+ extra_config={
+ "worker_name": "worker2",
+ "run_background_tasks_on": "worker1",
+ "redis": {"enabled": True},
+ },
+ )
+ self.assertNotIn(
+ "USER_IP",
+ worker2.get_replication_command_handler()._channels_to_subscribe_to,
+ )
+
+ # Advance so the Redis subscription gets processed
+ self.pump(0.1)
+
+ # The count is 2 because both the main process and the worker are subscribed.
+ self.assertEqual(len(self._redis_server._subscribers_by_channel[b"test"]), 2)
+ # For USER_IP, the count is 1 because only the main process is subscribed.
+ self.assertEqual(
+ len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 1
+ )
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 5f142e84..a7ca6806 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -14,7 +14,6 @@
import logging
from unittest.mock import patch
-from synapse.api.room_versions import RoomVersion
from synapse.rest import admin
from synapse.rest.client import login, room, sync
from synapse.storage.util.id_generators import MultiWriterIdGenerator
@@ -64,21 +63,10 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# We control the room ID generation by patching out the
# `_generate_room_id` method
- async def generate_room(
- creator_id: str, is_public: bool, room_version: RoomVersion
- ):
- await self.store.store_room(
- room_id=room_id,
- room_creator_user_id=creator_id,
- is_public=is_public,
- room_version=room_version,
- )
- return room_id
-
with patch(
"synapse.handlers.room.RoomCreationHandler._generate_room_id"
) as mock:
- mock.side_effect = generate_room
+ mock.side_effect = lambda: room_id
self.helper.create_room_as(user_id, tok=tok)
def test_basic(self):
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 40571b75..82ac5991 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -14,7 +14,6 @@
import urllib.parse
from http import HTTPStatus
-from typing import List
from parameterized import parameterized
@@ -23,7 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.http.server import JsonResource
from synapse.rest.admin import VersionServlet
-from synapse.rest.client import groups, login, room
+from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.util import Clock
@@ -49,93 +48,6 @@ class VersionTestCase(unittest.HomeserverTestCase):
)
-class DeleteGroupTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- login.register_servlets,
- groups.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- self.other_user = self.register_user("user", "pass")
- self.other_user_token = self.login("user", "pass")
-
- @unittest.override_config({"experimental_features": {"groups_enabled": True}})
- def test_delete_group(self) -> None:
- # Create a new group
- channel = self.make_request(
- "POST",
- b"/create_group",
- access_token=self.admin_user_tok,
- content={"localpart": "test"},
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
-
- group_id = channel.json_body["group_id"]
-
- self._check_group(group_id, expect_code=HTTPStatus.OK)
-
- # Invite/join another user
-
- url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user)
- channel = self.make_request(
- "PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={}
- )
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
-
- url = "/groups/%s/self/accept_invite" % (group_id,)
- channel = self.make_request(
- "PUT", url.encode("ascii"), access_token=self.other_user_token, content={}
- )
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
-
- # Check other user knows they're in the group
- self.assertIn(group_id, self._get_groups_user_is_in(self.admin_user_tok))
- self.assertIn(group_id, self._get_groups_user_is_in(self.other_user_token))
-
- # Now delete the group
- url = "/_synapse/admin/v1/delete_group/" + group_id
- channel = self.make_request(
- "POST",
- url.encode("ascii"),
- access_token=self.admin_user_tok,
- content={"localpart": "test"},
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
-
- # Check group returns HTTPStatus.NOT_FOUND
- self._check_group(group_id, expect_code=HTTPStatus.NOT_FOUND)
-
- # Check users don't think they're in the group
- self.assertNotIn(group_id, self._get_groups_user_is_in(self.admin_user_tok))
- self.assertNotIn(group_id, self._get_groups_user_is_in(self.other_user_token))
-
- def _check_group(self, group_id: str, expect_code: int) -> None:
- """Assert that trying to fetch the given group results in the given
- HTTP status code
- """
-
- url = "/groups/%s/profile" % (group_id,)
- channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
- )
-
- self.assertEqual(expect_code, channel.code, msg=channel.json_body)
-
- def _get_groups_user_is_in(self, access_token: str) -> List[str]:
- """Returns the list of groups the user is in (given their access token)"""
- channel = self.make_request("GET", b"/joined_groups", access_token=access_token)
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
-
- return channel.json_body["groups"]
-
-
class QuarantineMediaTestCase(unittest.HomeserverTestCase):
"""Test /quarantine_media admin API."""
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 95282f07..ca6af941 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -2467,7 +2467,6 @@ PURGE_TABLES = [
"event_push_actions",
"event_search",
"events",
- "group_rooms",
"receipts_graph",
"receipts_linearized",
"room_aliases",
@@ -2484,9 +2483,9 @@ PURGE_TABLES = [
"e2e_room_keys",
"event_push_summary",
"pusher_throttle",
- "group_summary_rooms",
"room_account_data",
"room_tags",
# "state_groups", # Current impl leaves orphaned state groups around.
"state_groups_state",
+ "federation_inbound_events_staging",
]
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 0cdf1dec..0d441022 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -2579,7 +2579,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
other_user_tok = self.login("user", "pass")
event_builder_factory = self.hs.get_event_builder_factory()
event_creation_handler = self.hs.get_event_creation_handler()
- storage = self.hs.get_storage()
+ storage_controllers = self.hs.get_storage_controllers()
# Create two rooms, one with a local user only and one with both a local
# and remote user.
@@ -2604,7 +2604,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
event_creation_handler.create_new_client_event(builder)
)
- self.get_success(storage.persistence.persist_event(event, context))
+ self.get_success(storage_controllers.persistence.persist_event(event, context))
# Now get rooms
url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index e0a11da9..a43a1372 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -548,7 +548,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
as_token,
- self.hs.config.server.server_name,
id="1234",
namespaces={"users": [{"regex": user_id, "exclusive": True}]},
sender=user_id,
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 9653f458..05355c7f 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -195,8 +195,17 @@ class UIAuthTests(unittest.HomeserverTestCase):
self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass)
self.device_id = "dev1"
+
+ # Force-enable password login for just long enough to log in.
+ auth_handler = self.hs.get_auth_handler()
+ allow_auth_for_login = auth_handler._password_enabled_for_login
+ auth_handler._password_enabled_for_login = True
+
self.user_tok = self.login("test", self.user_pass, self.device_id)
+ # Restore password login to however it was.
+ auth_handler._password_enabled_for_login = allow_auth_for_login
+
def delete_device(
self,
access_token: str,
@@ -263,6 +272,38 @@ class UIAuthTests(unittest.HomeserverTestCase):
},
)
+ @override_config({"password_config": {"enabled": "only_for_reauth"}})
+ def test_ui_auth_with_passwords_for_reauth_only(self) -> None:
+ """
+ Test user interactive authentication outside of registration.
+ """
+
+ # Attempt to delete this device.
+ # Returns a 401 as per the spec
+ channel = self.delete_device(
+ self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
+ )
+
+ # Grab the session
+ session = channel.json_body["session"]
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+ # Make another request providing the UI auth flow.
+ self.delete_device(
+ self.user_tok,
+ self.device_id,
+ HTTPStatus.OK,
+ {
+ "auth": {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": self.user},
+ "password": self.user_pass,
+ "session": session,
+ },
+ },
+ )
+
def test_grandfathered_identifier(self) -> None:
"""Check behaviour without "identifier" dict
diff --git a/tests/rest/client/test_device_lists.py b/tests/rest/client/test_devices.py
index a8af4e24..aa982224 100644
--- a/tests/rest/client/test_device_lists.py
+++ b/tests/rest/client/test_devices.py
@@ -13,8 +13,13 @@
# limitations under the License.
from http import HTTPStatus
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.errors import NotFoundError
from synapse.rest import admin, devices, room, sync
from synapse.rest.client import account, login, register
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
@@ -157,3 +162,41 @@ class DeviceListsTestCase(unittest.HomeserverTestCase):
self.assertNotIn(
alice_user_id, changed_device_lists, bob_sync_channel.json_body
)
+
+
+class DevicesTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.handler = hs.get_device_handler()
+
+ @unittest.override_config({"delete_stale_devices_after": 72000000})
+ def test_delete_stale_devices(self) -> None:
+ """Tests that stale devices are automatically removed after a set time of
+ inactivity.
+ The configuration is set to delete devices that haven't been used in the past 20h.
+ """
+ # Register a user and create 2 devices for them.
+ user_id = self.register_user("user", "password")
+ tok1 = self.login("user", "password", device_id="abc")
+ tok2 = self.login("user", "password", device_id="def")
+
+ # Sync them so they have a last_seen value.
+ self.make_request("GET", "/sync", access_token=tok1)
+ self.make_request("GET", "/sync", access_token=tok2)
+
+ # Advance half a day and sync again with one of the devices, so that the next
+ # time the background job runs we don't delete this device (since it will look
+ # for devices that haven't been used for over 20 hours).
+ self.reactor.advance(43200)
+ self.make_request("GET", "/sync", access_token=tok1)
+
+ # Advance another half a day, and check that the device that has synced still
+ # exists but the one that hasn't synced has been removed.
+ self.reactor.advance(43200)
+ self.get_success(self.handler.get_device(user_id, "abc"))
+ self.get_failure(self.handler.get_device(user_id, "def"), NotFoundError)
diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py
index 1b1392fa..a9b7db9d 100644
--- a/tests/rest/client/test_events.py
+++ b/tests/rest/client/test_events.py
@@ -19,6 +19,7 @@ from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
+from synapse.api.constants import EduTypes
from synapse.rest.client import events, login, room
from synapse.server import HomeServer
from synapse.util import Clock
@@ -103,7 +104,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
c
for c in channel.json_body["chunk"]
if not (
- c.get("type") == "m.presence"
+ c.get("type") == EduTypes.PRESENCE
and c["content"].get("user_id") == self.user_id
)
]
diff --git a/tests/rest/client/test_groups.py b/tests/rest/client/test_groups.py
deleted file mode 100644
index e067cf82..00000000
--- a/tests/rest/client/test_groups.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.rest.client import groups, room
-
-from tests import unittest
-from tests.unittest import override_config
-
-
-class GroupsTestCase(unittest.HomeserverTestCase):
- user_id = "@alice:test"
- room_creator_user_id = "@bob:test"
-
- servlets = [room.register_servlets, groups.register_servlets]
-
- @override_config({"enable_group_creation": True})
- def test_rooms_limited_by_visibility(self) -> None:
- group_id = "+spqr:test"
-
- # Alice creates a group
- channel = self.make_request("POST", "/create_group", {"localpart": "spqr"})
- self.assertEqual(channel.code, 200, msg=channel.text_body)
- self.assertEqual(channel.json_body, {"group_id": group_id})
-
- # Bob creates a private room
- room_id = self.helper.create_room_as(self.room_creator_user_id, is_public=False)
- self.helper.auth_user_id = self.room_creator_user_id
- self.helper.send_state(
- room_id, "m.room.name", {"name": "bob's secret room"}, tok=None
- )
- self.helper.auth_user_id = self.user_id
-
- # Alice adds the room to her group.
- channel = self.make_request(
- "PUT", f"/groups/{group_id}/admin/rooms/{room_id}", {}
- )
- self.assertEqual(channel.code, 200, msg=channel.text_body)
- self.assertEqual(channel.json_body, {})
-
- # Alice now tries to retrieve the room list of the space.
- channel = self.make_request("GET", f"/groups/{group_id}/rooms")
- self.assertEqual(channel.code, 200, msg=channel.text_body)
- self.assertEqual(
- channel.json_body, {"chunk": [], "total_room_count_estimate": 0}
- )
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index 4920468f..f4ea1209 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -1112,7 +1112,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
self.service = ApplicationService(
id="unique_identifier",
token="some_token",
- hostname="example.com",
sender="@asbot:example.com",
namespaces={
ApplicationService.NS_USERS: [
@@ -1125,7 +1124,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
self.another_service = ApplicationService(
id="another__identifier",
token="another_token",
- hostname="example.com",
sender="@as2bot:example.com",
namespaces={
ApplicationService.NS_USERS: [
diff --git a/tests/rest/client/test_mutual_rooms.py b/tests/rest/client/test_mutual_rooms.py
index 7b7d283b..a4327f7a 100644
--- a/tests/rest/client/test_mutual_rooms.py
+++ b/tests/rest/client/test_mutual_rooms.py
@@ -36,12 +36,10 @@ class UserMutualRoomsTest(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
- config["update_user_directory"] = True
return self.setup_test_homeserver(config=config)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
- self.handler = hs.get_user_directory_handler()
def _get_mutual_rooms(self, token: str, other_user: str) -> FakeChannel:
return self.make_request(
diff --git a/tests/rest/client/test_notifications.py b/tests/rest/client/test_notifications.py
new file mode 100644
index 00000000..700f6587
--- /dev/null
+++ b/tests/rest/client/test_notifications.py
@@ -0,0 +1,91 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import Mock
+
+from twisted.test.proto_helpers import MemoryReactor
+
+import synapse.rest.admin
+from synapse.rest.client import login, notifications, receipts, room
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests.test_utils import simple_async_mock
+from tests.unittest import HomeserverTestCase
+
+
+class HTTPPusherTests(HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ receipts.register_servlets,
+ notifications.register_servlets,
+ ]
+
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ self.store = homeserver.get_datastores().main
+ self.module_api = homeserver.get_module_api()
+ self.event_creation_handler = homeserver.get_event_creation_handler()
+ self.sync_handler = homeserver.get_sync_handler()
+ self.auth_handler = homeserver.get_auth_handler()
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ # Mock out the calls over federation.
+ fed_transport_client = Mock(spec=["send_transaction"])
+ fed_transport_client.send_transaction = simple_async_mock({})
+
+ return self.setup_test_homeserver(
+ federation_transport_client=fed_transport_client,
+ )
+
+ def test_notify_for_local_invites(self) -> None:
+ """
+ Local users will get notified for invites
+ """
+
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+ other_user_id = self.register_user("otheruser", "pass")
+ other_access_token = self.login("otheruser", "pass")
+
+ # Create a room
+ room = self.helper.create_room_as(user_id, tok=access_token)
+
+ # Check we start with no pushes
+ channel = self.make_request(
+ "GET",
+ "/notifications",
+ access_token=other_access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(len(channel.json_body["notifications"]), 0, channel.json_body)
+
+ # Send an invite
+ self.helper.invite(room=room, src=user_id, targ=other_user_id, tok=access_token)
+
+ # We should have a notification now
+ channel = self.make_request(
+ "GET",
+ "/notifications",
+ access_token=other_access_token,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body)
+ self.assertEqual(
+ channel.json_body["notifications"][0]["event"]["content"]["membership"],
+ "invite",
+ channel.json_body,
+ )
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 9aebf173..afb08b27 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -56,7 +56,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
as_token,
- self.hs.config.server.server_name,
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
sender="@as:test",
@@ -80,7 +79,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
as_token,
- self.hs.config.server.server_name,
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
sender="@as:test",
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 27dee8f6..62e4db23 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -896,6 +896,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
relation_type: str,
assertion_callable: Callable[[JsonDict], None],
expected_db_txn_for_event: int,
+ access_token: Optional[str] = None,
) -> None:
"""
Makes requests to various endpoints which should include bundled aggregations
@@ -907,7 +908,9 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
for relation-specific assertions.
expected_db_txn_for_event: The number of database transactions which
are expected for a call to /event/.
+ access_token: The access token to use; defaults to self.user_token.
"""
+ access_token = access_token or self.user_token
def assert_bundle(event_json: JsonDict) -> None:
"""Assert the expected values of the bundled aggregations."""
@@ -921,7 +924,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
channel = self.make_request(
"GET",
f"/rooms/{self.room}/event/{self.parent_id}",
- access_token=self.user_token,
+ access_token=access_token,
)
self.assertEqual(200, channel.code, channel.json_body)
assert_bundle(channel.json_body)
@@ -932,7 +935,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
channel = self.make_request(
"GET",
f"/rooms/{self.room}/messages?dir=b",
- access_token=self.user_token,
+ access_token=access_token,
)
self.assertEqual(200, channel.code, channel.json_body)
assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
@@ -941,7 +944,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
channel = self.make_request(
"GET",
f"/rooms/{self.room}/context/{self.parent_id}",
- access_token=self.user_token,
+ access_token=access_token,
)
self.assertEqual(200, channel.code, channel.json_body)
assert_bundle(channel.json_body["event"])
@@ -949,7 +952,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
# Request sync.
filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}')
channel = self.make_request(
- "GET", f"/sync?filter={filter}", access_token=self.user_token
+ "GET", f"/sync?filter={filter}", access_token=access_token
)
self.assertEqual(200, channel.code, channel.json_body)
room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
@@ -962,7 +965,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
"/search",
# Search term matches the parent message.
content={"search_categories": {"room_events": {"search_term": "Hi"}}},
- access_token=self.user_token,
+ access_token=access_token,
)
self.assertEqual(200, channel.code, channel.json_body)
chunk = [
@@ -995,7 +998,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations,
)
- self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
+ self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 6)
def test_annotation_to_annotation(self) -> None:
"""Any relation to an annotation should be ignored."""
@@ -1031,36 +1034,66 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations,
)
- self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
+ self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6)
def test_thread(self) -> None:
"""
Test that threads get correctly bundled.
"""
- self._send_relation(RelationTypes.THREAD, "m.room.test")
- channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ # The root message is from "user", send replies as "user2".
+ self._send_relation(
+ RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+ )
+ channel = self._send_relation(
+ RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+ )
thread_2 = channel.json_body["event_id"]
- def assert_thread(bundled_aggregations: JsonDict) -> None:
- self.assertEqual(2, bundled_aggregations.get("count"))
- self.assertTrue(bundled_aggregations.get("current_user_participated"))
- # The latest thread event has some fields that don't matter.
- self.assert_dict(
- {
- "content": {
- "m.relates_to": {
- "event_id": self.parent_id,
- "rel_type": RelationTypes.THREAD,
- }
+ # This needs two assertion functions which are identical except for whether
+ # the current_user_participated flag is True, so create a factory for the
+ # two versions.
+ def _gen_assert(participated: bool) -> Callable[[JsonDict], None]:
+ def assert_thread(bundled_aggregations: JsonDict) -> None:
+ self.assertEqual(2, bundled_aggregations.get("count"))
+ self.assertEqual(
+ participated, bundled_aggregations.get("current_user_participated")
+ )
+ # The latest thread event has some fields that don't matter.
+ self.assert_dict(
+ {
+ "content": {
+ "m.relates_to": {
+ "event_id": self.parent_id,
+ "rel_type": RelationTypes.THREAD,
+ }
+ },
+ "event_id": thread_2,
+ "sender": self.user2_id,
+ "type": "m.room.test",
},
- "event_id": thread_2,
- "sender": self.user_id,
- "type": "m.room.test",
- },
- bundled_aggregations.get("latest_event"),
- )
+ bundled_aggregations.get("latest_event"),
+ )
- self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 10)
+ return assert_thread
+
+ # The "user" sent the root event and is making queries for the bundled
+ # aggregations: they have participated.
+ self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
+ # The "user2" sent replies in the thread and is making queries for the
+ # bundled aggregations: they have participated.
+ #
+ # Note that this re-uses some cached values, so the total number of
+ # queries is much smaller.
+ self._test_bundled_aggregations(
+ RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token
+ )
+
+ # A user with no interactions with the thread: they have not participated.
+ user3_id, user3_token = self._create_user("charlie")
+ self.helper.join(self.room, user=user3_id, tok=user3_token)
+ self._test_bundled_aggregations(
+ RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token
+ )
def test_thread_with_bundled_aggregations_for_latest(self) -> None:
"""
@@ -1106,7 +1139,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations["latest_event"].get("unsigned"),
)
- self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 10)
+ self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)
def test_nested_thread(self) -> None:
"""
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 7b8fe6d0..ac9c1133 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict
from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
@@ -129,7 +130,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
We do this by setting a very long time between purge jobs.
"""
store = self.hs.get_datastores().main
- storage = self.hs.get_storage()
+ storage_controllers = self.hs.get_storage_controllers()
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
# Send a first event, which should be filtered out at the end of the test.
@@ -154,7 +155,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(2, len(events), "events retrieved from database")
filtered_events = self.get_success(
- filter_events_for_client(storage, self.user_id, events)
+ filter_events_for_client(storage_controllers, self.user_id, events)
)
# We should only get one event back.
@@ -252,16 +253,24 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- config = self.default_config()
- config["retention"] = {
+ def default_config(self) -> Dict[str, Any]:
+ config = super().default_config()
+
+ retention_config = {
"enabled": True,
}
+ # Update this config with what's in the default config so that
+ # override_config works as expected.
+ retention_config.update(config.get("retention", {}))
+ config["retention"] = retention_config
+
+ return config
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
mock_federation_client = Mock(spec=["backfill"])
self.hs = self.setup_test_homeserver(
- config=config,
federation_client=mock_federation_client,
)
return self.hs
@@ -295,6 +304,24 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
self._test_retention(room_id, expected_code_for_first_event=404)
+ @unittest.override_config({"retention": {"enabled": False}})
+ def test_visibility_when_disabled(self) -> None:
+ """Retention policies should be ignored when the retention feature is disabled."""
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.Retention,
+ body={"max_lifetime": one_day_ms},
+ tok=self.token,
+ )
+
+ resp = self.helper.send(room_id=room_id, body="test", tok=self.token)
+
+ self.reactor.advance(one_day_ms * 2 / 1000)
+
+ self.get_event(room_id, resp["event_id"])
+
def _test_retention(
self, room_id: str, expected_code_for_first_event: int = 200
) -> None:
diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py
index 41a1bf6d..9d5cb60d 100644
--- a/tests/rest/client/test_room_batch.py
+++ b/tests/rest/client/test_room_batch.py
@@ -71,7 +71,6 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
self.appservice = ApplicationService(
token="i_am_an_app_service",
- hostname="test",
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
# Note: this user does not have to match the regex above
@@ -88,7 +87,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.clock = clock
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self.virtual_user_id, _ = self.register_appservice_user(
"as_user_potato", self.appservice.token
@@ -168,7 +167,9 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
# Fetch the state_groups
state_group_map = self.get_success(
- self.storage.state.get_state_groups_ids(room_id, historical_event_ids)
+ self._storage_controllers.state.get_state_groups_ids(
+ room_id, historical_event_ids
+ )
)
# We expect all of the historical events to be using the same state_group
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 9443daa0..f523d89b 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -26,6 +26,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import (
+ EduTypes,
EventContentFields,
EventTypes,
Membership,
@@ -925,7 +926,7 @@ class RoomJoinTestCase(RoomBase):
) -> bool:
return return_value
- callback_mock = Mock(side_effect=user_may_join_room)
+ callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None)
self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
# Join a first room, without being invited to it.
@@ -1116,6 +1117,264 @@ class RoomMessagesTestCase(RoomBase):
self.assertEqual(200, channel.code, msg=channel.result["body"])
+class RoomPowerLevelOverridesTestCase(RoomBase):
+ """Tests that the power levels can be overridden with server config."""
+
+ user_id = "@sid1:red"
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.admin_user_id = self.register_user("admin", "pass")
+ self.admin_access_token = self.login("admin", "pass")
+
+ def power_levels(self, room_id: str) -> Dict[str, Any]:
+ return self.helper.get_state(
+ room_id, "m.room.power_levels", self.admin_access_token
+ )
+
+ def test_default_power_levels_with_room_override(self) -> None:
+ """
+ Create a room, providing power level overrides.
+ Confirm that the room's power levels reflect the overrides.
+
+ See https://github.com/matrix-org/matrix-spec/issues/492
+ - currently we overwrite each key of power_level_content_override
+ completely.
+ """
+
+ room_id = self.helper.create_room_as(
+ self.user_id,
+ extra_content={
+ "power_level_content_override": {"events": {"custom.event": 0}}
+ },
+ )
+ self.assertEqual(
+ {
+ "custom.event": 0,
+ },
+ self.power_levels(room_id)["events"],
+ )
+
+ @unittest.override_config(
+ {
+ "default_power_level_content_override": {
+ "public_chat": {"events": {"custom.event": 0}},
+ }
+ },
+ )
+ def test_power_levels_with_server_override(self) -> None:
+ """
+ With a server configured to modify the room-level defaults,
+ create a room without providing any extra power level overrides.
+ Confirm that the room's power levels reflect the server-level overrides.
+
+ Similar to https://github.com/matrix-org/matrix-spec/issues/492,
+ we overwrite each key of power_level_content_override completely.
+ """
+
+ room_id = self.helper.create_room_as(self.user_id)
+ self.assertEqual(
+ {
+ "custom.event": 0,
+ },
+ self.power_levels(room_id)["events"],
+ )
+
+ @unittest.override_config(
+ {
+ "default_power_level_content_override": {
+ "public_chat": {
+ "events": {"server.event": 0},
+ "ban": 13,
+ },
+ }
+ },
+ )
+ def test_power_levels_with_server_and_room_overrides(self) -> None:
+ """
+ With a server configured to modify the room-level defaults,
+ create a room, providing different overrides.
+ Confirm that the room's power levels reflect both overrides, and
+ choose the room overrides where they clash.
+ """
+
+ room_id = self.helper.create_room_as(
+ self.user_id,
+ extra_content={
+ "power_level_content_override": {"events": {"room.event": 0}}
+ },
+ )
+
+ # Room override wins over server config
+ self.assertEqual(
+ {"room.event": 0},
+ self.power_levels(room_id)["events"],
+ )
+
+ # But where there is no room override, server config wins
+ self.assertEqual(13, self.power_levels(room_id)["ban"])
+
+
+class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
+ """
+ Tests that we can really do various otherwise-prohibited actions
+ based on overriding the power levels in config.
+ """
+
+ user_id = "@sid1:red"
+
+ def test_creator_can_post_state_event(self) -> None:
+ # Given I am the creator of a room
+ room_id = self.helper.create_room_as(self.user_id)
+
+ # When I send a state event
+ path = "/rooms/{room_id}/state/custom.event/my_state_key".format(
+ room_id=urlparse.quote(room_id),
+ )
+ channel = self.make_request("PUT", path, "{}")
+
+ # Then I am allowed
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
+
+ def test_normal_user_can_not_post_state_event(self) -> None:
+ # Given I am a normal member of a room
+ room_id = self.helper.create_room_as("@some_other_guy:red")
+ self.helper.join(room=room_id, user=self.user_id)
+
+ # When I send a state event
+ path = "/rooms/{room_id}/state/custom.event/my_state_key".format(
+ room_id=urlparse.quote(room_id),
+ )
+ channel = self.make_request("PUT", path, "{}")
+
+ # Then I am not allowed because state events require PL>=50
+ self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ "You don't have permission to post that to the room. "
+ "user_level (0) < send_level (50)",
+ channel.json_body["error"],
+ )
+
+ @unittest.override_config(
+ {
+ "default_power_level_content_override": {
+ "public_chat": {"events": {"custom.event": 0}},
+ }
+ },
+ )
+ def test_with_config_override_normal_user_can_post_state_event(self) -> None:
+ # Given the server has config allowing normal users to post my event type,
+ # and I am a normal member of a room
+ room_id = self.helper.create_room_as("@some_other_guy:red")
+ self.helper.join(room=room_id, user=self.user_id)
+
+ # When I send a state event
+ path = "/rooms/{room_id}/state/custom.event/my_state_key".format(
+ room_id=urlparse.quote(room_id),
+ )
+ channel = self.make_request("PUT", path, "{}")
+
+ # Then I am allowed
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
+
+ @unittest.override_config(
+ {
+ "default_power_level_content_override": {
+ "public_chat": {"events": {"custom.event": 0}},
+ }
+ },
+ )
+ def test_any_room_override_defeats_config_override(self) -> None:
+ # Given the server has config allowing normal users to post my event type
+ # And I am a normal member of a room
+ # But the room was created with special permissions
+ extra_content: Dict[str, Any] = {
+ "power_level_content_override": {"events": {}},
+ }
+ room_id = self.helper.create_room_as(
+ "@some_other_guy:red", extra_content=extra_content
+ )
+ self.helper.join(room=room_id, user=self.user_id)
+
+ # When I send a state event
+ path = "/rooms/{room_id}/state/custom.event/my_state_key".format(
+ room_id=urlparse.quote(room_id),
+ )
+ channel = self.make_request("PUT", path, "{}")
+
+ # Then I am not allowed
+ self.assertEqual(403, channel.code, msg=channel.result["body"])
+
+ @unittest.override_config(
+ {
+ "default_power_level_content_override": {
+ "public_chat": {"events": {"custom.event": 0}},
+ }
+ },
+ )
+ def test_specific_room_override_defeats_config_override(self) -> None:
+ # Given the server has config allowing normal users to post my event type,
+ # and I am a normal member of a room,
+ # but the room was created with special permissions for this event type
+ extra_content = {
+ "power_level_content_override": {"events": {"custom.event": 1}},
+ }
+ room_id = self.helper.create_room_as(
+ "@some_other_guy:red", extra_content=extra_content
+ )
+ self.helper.join(room=room_id, user=self.user_id)
+
+ # When I send a state event
+ path = "/rooms/{room_id}/state/custom.event/my_state_key".format(
+ room_id=urlparse.quote(room_id),
+ )
+ channel = self.make_request("PUT", path, "{}")
+
+ # Then I am not allowed
+ self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ "You don't have permission to post that to the room. "
+ + "user_level (0) < send_level (1)",
+ channel.json_body["error"],
+ )
+
+ @unittest.override_config(
+ {
+ "default_power_level_content_override": {
+ "public_chat": {"events": {"custom.event": 0}},
+ "private_chat": None,
+ "trusted_private_chat": None,
+ }
+ },
+ )
+ def test_config_override_applies_only_to_specific_preset(self) -> None:
+ # Given the server has config for public_chats,
+ # and I am a normal member of a private_chat room
+ room_id = self.helper.create_room_as("@some_other_guy:red", is_public=False)
+ self.helper.invite(room=room_id, src="@some_other_guy:red", targ=self.user_id)
+ self.helper.join(room=room_id, user=self.user_id)
+
+ # When I send a state event
+ path = "/rooms/{room_id}/state/custom.event/my_state_key".format(
+ room_id=urlparse.quote(room_id),
+ )
+ channel = self.make_request("PUT", path, "{}")
+
+ # Then I am not allowed because the public_chat config does not
+ # affect this room, since this room is a private_chat
+ self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ "You don't have permission to post that to the room. "
+ + "user_level (0) < send_level (50)",
+ channel.json_body["error"],
+ )
+
+
class RoomInitialSyncTestCase(RoomBase):
"""Tests /rooms/$room_id/initialSync."""
@@ -1154,7 +1413,7 @@ class RoomInitialSyncTestCase(RoomBase):
e["content"]["user_id"]: e for e in channel.json_body["presence"]
}
self.assertTrue(self.user_id in presence_by_user)
- self.assertEqual("m.presence", presence_by_user[self.user_id]["type"])
+ self.assertEqual(EduTypes.PRESENCE, presence_by_user[self.user_id]["type"])
class RoomMessageListTestCase(RoomBase):
@@ -2598,7 +2857,9 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
# allow everything for now.
- mock = Mock(return_value=make_awaitable(True))
+ # `spec` argument is needed for this function mock to have `__qualname__`, which
+ # is needed for `Measure` metrics buried in SpamChecker.
+ mock = Mock(return_value=make_awaitable(True), spec=lambda *x: None)
self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock)
# Send a 3PID invite into the room and check that it succeeded.
diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py
index c3942889..6435800f 100644
--- a/tests/rest/client/test_sendtodevice.py
+++ b/tests/rest/client/test_sendtodevice.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.api.constants import EduTypes
from synapse.rest import admin
from synapse.rest.client import login, sendtodevice, sync
@@ -139,7 +140,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
for i in range(3):
self.get_success(
federation_registry.on_edu(
- "m.direct_to_device",
+ EduTypes.DIRECT_TO_DEVICE,
"remote_server",
{
"sender": "@user:remote_server",
@@ -172,7 +173,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
# and we can send more messages
self.get_success(
federation_registry.on_edu(
- "m.direct_to_device",
+ EduTypes.DIRECT_TO_DEVICE,
"remote_server",
{
"sender": "@user:remote_server",
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index ae5ada3b..d9bd8c4a 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -17,7 +17,7 @@ from unittest.mock import Mock, patch
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EduTypes, EventTypes
from synapse.rest.client import (
directory,
login,
@@ -226,7 +226,7 @@ class RoomTestCase(_ShadowBannedBase):
events[0],
[
{
- "type": "m.typing",
+ "type": EduTypes.TYPING,
"room_id": room_id,
"content": {"user_ids": [self.other_user_id]},
}
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 01083376..e3efd1f1 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
+from http import HTTPStatus
from typing import List, Optional
from parameterized import parameterized
@@ -21,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import (
+ EduTypes,
EventContentFields,
EventTypes,
ReceiptTypes,
@@ -485,30 +487,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Test that we didn't override the public read receipt
self.assertIsNone(self._get_read_receipt())
- @parameterized.expand(
- [
- # Old Element version, expected to send an empty body
- (
- "agent1",
- "Element/1.2.2 (Linux; U; Android 9; MatrixAndroidSDK_X 0.0.1)",
- 200,
- ),
- # Old SchildiChat version, expected to send an empty body
- ("agent2", "SchildiChat/1.2.1 (Android 10)", 200),
- # Expected 400: Denies empty body starting at version 1.3+
- ("agent3", "Element/1.3.6 (Android 10)", 400),
- ("agent4", "SchildiChat/1.3.6 (Android 11)", 400),
- # Contains "Riot": Receipts with empty bodies expected
- ("agent5", "Element (Riot.im) (Android 9)", 200),
- # Expected 400: Does not contain "Android"
- ("agent6", "Element/1.2.1", 400),
- # Expected 400: Different format, missing "/" after Element; existing build that should allow empty bodies, but minimal ongoing usage
- ("agent7", "Element dbg/1.1.8-dev (Android)", 400),
- ]
- )
- def test_read_receipt_with_empty_body(
- self, name: str, user_agent: str, expected_status_code: int
- ) -> None:
+ def test_read_receipt_with_empty_body_is_rejected(self) -> None:
# Send a message as the first user
res = self.helper.send(self.room_id, body="hello", tok=self.tok)
@@ -517,16 +496,16 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
"POST",
f"/rooms/{self.room_id}/receipt/m.read/{res['event_id']}",
access_token=self.tok2,
- custom_headers=[("User-Agent", user_agent)],
)
- self.assertEqual(channel.code, expected_status_code)
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST)
+ self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON", channel.json_body)
def _get_read_receipt(self) -> Optional[JsonDict]:
"""Syncs and returns the read receipt."""
# Checks if event is a read receipt
def is_read_receipt(event: JsonDict) -> bool:
- return event["type"] == "m.receipt"
+ return event["type"] == EduTypes.RECEIPT
# Sync
channel = self.make_request(
@@ -678,12 +657,13 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
self._check_unread_count(3)
# Check that custom events with a body increase the unread counter.
- self.helper.send_event(
+ result = self.helper.send_event(
self.room_id,
"org.matrix.custom_type",
{"body": "hello"},
tok=self.tok2,
)
+ event_id = result["event_id"]
self._check_unread_count(4)
# Check that edits don't increase the unread counter.
@@ -693,7 +673,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
content={
"body": "hello",
"msgtype": "m.text",
- "m.relates_to": {"rel_type": RelationTypes.REPLACE},
+ "m.relates_to": {
+ "rel_type": RelationTypes.REPLACE,
+ "event_id": event_id,
+ },
},
tok=self.tok2,
)
diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py
index d6da5107..61b66d76 100644
--- a/tests/rest/client/test_typing.py
+++ b/tests/rest/client/test_typing.py
@@ -17,6 +17,7 @@
from twisted.test.proto_helpers import MemoryReactor
+from synapse.api.constants import EduTypes
from synapse.rest.client import room
from synapse.server import HomeServer
from synapse.types import UserID
@@ -67,7 +68,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
events[0],
[
{
- "type": "m.typing",
+ "type": EduTypes.TYPING,
"room_id": self.room_id,
"content": {"user_ids": [self.user_id]},
}
diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index c86fc5df..98c1039d 100644
--- a/tests/rest/client/test_upgrade_room.py
+++ b/tests/rest/client/test_upgrade_room.py
@@ -76,7 +76,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
"""
Upgrading a room should work fine.
"""
- # THe user isn't in the room.
+ # The user isn't in the room.
roomless = self.register_user("roomless", "pass")
roomless_token = self.login(roomless, "pass")
@@ -249,7 +249,9 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
new_space_id = channel.json_body["replacement_room"]
- state_ids = self.get_success(self.store.get_current_state_ids(new_space_id))
+ state_ids = self.get_success(
+ self.store.get_partial_current_state_ids(new_space_id)
+ )
# Ensure the new room is still a space.
create_event = self.get_success(
@@ -263,3 +265,35 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
self.assertIn((EventTypes.SpaceChild, self.room_id), state_ids)
# The child that was removed should not be copied over.
self.assertNotIn((EventTypes.SpaceChild, old_room_id), state_ids)
+
+ def test_custom_room_type(self) -> None:
+ """Test upgrading a room that has a custom room type set."""
+ test_room_type = "com.example.my_custom_room_type"
+
+ # Create a room with a custom room type.
+ room_id = self.helper.create_room_as(
+ self.creator,
+ tok=self.creator_token,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: test_room_type}
+ },
+ )
+
+ # Upgrade the room!
+ channel = self._upgrade_room(room_id=room_id)
+ self.assertEqual(200, channel.code, channel.result)
+ self.assertIn("replacement_room", channel.json_body)
+
+ new_room_id = channel.json_body["replacement_room"]
+
+ state_ids = self.get_success(
+ self.store.get_partial_current_state_ids(new_room_id)
+ )
+
+ # Ensure the new room is the same type as the old room.
+ create_event = self.get_success(
+ self.store.get_event(state_ids[(EventTypes.Create, "")])
+ )
+ self.assertEqual(
+ create_event.content.get(EventContentFields.ROOM_TYPE), test_room_type
+ )
diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py
new file mode 100644
index 00000000..14af07c5
--- /dev/null
+++ b/tests/rest/media/test_media_retention.py
@@ -0,0 +1,321 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import io
+from typing import Iterable, Optional, Tuple
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.rest import admin
+from synapse.rest.client import login, register, room
+from synapse.server import HomeServer
+from synapse.types import UserID
+from synapse.util import Clock
+
+from tests import unittest
+from tests.unittest import override_config
+from tests.utils import MockClock
+
+
+class MediaRetentionTestCase(unittest.HomeserverTestCase):
+
+ ONE_DAY_IN_MS = 24 * 60 * 60 * 1000
+ THIRTY_DAYS_IN_MS = 30 * ONE_DAY_IN_MS
+
+ servlets = [
+ room.register_servlets,
+ login.register_servlets,
+ register.register_servlets,
+ admin.register_servlets_for_client_rest_resource,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ # We need to be able to test advancing time in the homeserver, so we
+ # replace the test homeserver's default clock with a MockClock, which
+ # supports advancing time.
+ return self.setup_test_homeserver(clock=MockClock())
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.remote_server_name = "remote.homeserver"
+ self.store = hs.get_datastores().main
+
+ # Create a user to upload media with
+ test_user_id = self.register_user("alice", "password")
+
+ # Inject media (recently accessed, old access, never accessed, old access
+ # quarantined media) into both the local store and the remote cache, plus
+ # one additional local media that is marked as protected from quarantine.
+ media_repository = hs.get_media_repository()
+ test_media_content = b"example string"
+
+ def _create_media_and_set_attributes(
+ last_accessed_ms: Optional[int],
+ is_quarantined: Optional[bool] = False,
+ is_protected: Optional[bool] = False,
+ ) -> str:
+ # "Upload" some media to the local media store
+ mxc_uri = self.get_success(
+ media_repository.create_content(
+ media_type="text/plain",
+ upload_name=None,
+ content=io.BytesIO(test_media_content),
+ content_length=len(test_media_content),
+ auth_user=UserID.from_string(test_user_id),
+ )
+ )
+
+ media_id = mxc_uri.split("/")[-1]
+
+ # Set the last recently accessed time for this media
+ if last_accessed_ms is not None:
+ self.get_success(
+ self.store.update_cached_last_access_time(
+ local_media=(media_id,),
+ remote_media=(),
+ time_ms=last_accessed_ms,
+ )
+ )
+
+ if is_quarantined:
+ # Mark this media as quarantined
+ self.get_success(
+ self.store.quarantine_media_by_id(
+ server_name=self.hs.config.server.server_name,
+ media_id=media_id,
+ quarantined_by="@theadmin:test",
+ )
+ )
+
+ if is_protected:
+ # Mark this media as protected from quarantine
+ self.get_success(
+ self.store.mark_local_media_as_safe(
+ media_id=media_id,
+ safe=True,
+ )
+ )
+
+ return media_id
+
+ def _cache_remote_media_and_set_attributes(
+ media_id: str,
+ last_accessed_ms: Optional[int],
+ is_quarantined: Optional[bool] = False,
+ ) -> str:
+ # Pretend to cache some remote media
+ self.get_success(
+ self.store.store_cached_remote_media(
+ origin=self.remote_server_name,
+ media_id=media_id,
+ media_type="text/plain",
+ media_length=1,
+ time_now_ms=clock.time_msec(),
+ upload_name="testfile.txt",
+ filesystem_id="abcdefg12345",
+ )
+ )
+
+ # Set the last recently accessed time for this media
+ if last_accessed_ms is not None:
+ self.get_success(
+ hs.get_datastores().main.update_cached_last_access_time(
+ local_media=(),
+ remote_media=((self.remote_server_name, media_id),),
+ time_ms=last_accessed_ms,
+ )
+ )
+
+ if is_quarantined:
+ # Mark this media as quarantined
+ self.get_success(
+ self.store.quarantine_media_by_id(
+ server_name=self.remote_server_name,
+ media_id=media_id,
+ quarantined_by="@theadmin:test",
+ )
+ )
+
+ return media_id
+
+ # Start with the local media store
+ self.local_recently_accessed_media = _create_media_and_set_attributes(
+ last_accessed_ms=self.THIRTY_DAYS_IN_MS,
+ )
+ self.local_not_recently_accessed_media = _create_media_and_set_attributes(
+ last_accessed_ms=self.ONE_DAY_IN_MS,
+ )
+ self.local_not_recently_accessed_quarantined_media = (
+ _create_media_and_set_attributes(
+ last_accessed_ms=self.ONE_DAY_IN_MS,
+ is_quarantined=True,
+ )
+ )
+ self.local_not_recently_accessed_protected_media = (
+ _create_media_and_set_attributes(
+ last_accessed_ms=self.ONE_DAY_IN_MS,
+ is_protected=True,
+ )
+ )
+ self.local_never_accessed_media = _create_media_and_set_attributes(
+ last_accessed_ms=None,
+ )
+
+ # And now the remote media store
+ self.remote_recently_accessed_media = _cache_remote_media_and_set_attributes(
+ media_id="a",
+ last_accessed_ms=self.THIRTY_DAYS_IN_MS,
+ )
+ self.remote_not_recently_accessed_media = (
+ _cache_remote_media_and_set_attributes(
+ media_id="b",
+ last_accessed_ms=self.ONE_DAY_IN_MS,
+ )
+ )
+ self.remote_not_recently_accessed_quarantined_media = (
+ _cache_remote_media_and_set_attributes(
+ media_id="c",
+ last_accessed_ms=self.ONE_DAY_IN_MS,
+ is_quarantined=True,
+ )
+ )
+ # Remote media will always have a "last accessed" attribute, as it would not
+ # be fetched from the remote homeserver unless instigated by a user.
+
+ @override_config(
+ {
+ "media_retention": {
+ # Enable retention for local media
+ "local_media_lifetime": "30d"
+ # Cached remote media should not be purged
+ }
+ }
+ )
+ def test_local_media_retention(self) -> None:
+ """
+ Tests that local media that has not been accessed recently is purged, while
+ cached remote media is unaffected.
+ """
+ # Advance 31 days (in seconds)
+ self.reactor.advance(31 * 24 * 60 * 60)
+
+ # Check that media has been correctly purged.
+ # Local media accessed <30 days ago should still exist.
+ # Remote media should be unaffected.
+ self._assert_if_mxc_uris_purged(
+ purged=[
+ (
+ self.hs.config.server.server_name,
+ self.local_not_recently_accessed_media,
+ ),
+ (self.hs.config.server.server_name, self.local_never_accessed_media),
+ ],
+ not_purged=[
+ (self.hs.config.server.server_name, self.local_recently_accessed_media),
+ (
+ self.hs.config.server.server_name,
+ self.local_not_recently_accessed_quarantined_media,
+ ),
+ (
+ self.hs.config.server.server_name,
+ self.local_not_recently_accessed_protected_media,
+ ),
+ (self.remote_server_name, self.remote_recently_accessed_media),
+ (self.remote_server_name, self.remote_not_recently_accessed_media),
+ (
+ self.remote_server_name,
+ self.remote_not_recently_accessed_quarantined_media,
+ ),
+ ],
+ )
+
+ @override_config(
+ {
+ "media_retention": {
+ # Enable retention for cached remote media
+ "remote_media_lifetime": "30d"
+ # Local media should not be purged
+ }
+ }
+ )
+ def test_remote_media_cache_retention(self) -> None:
+ """
+ Tests that entries from the remote media cache that have not been accessed
+ recently are purged, while local media is unaffected.
+ """
+ # Advance 31 days (in seconds)
+ self.reactor.advance(31 * 24 * 60 * 60)
+
+ # Check that media has been correctly purged.
+ # Local media should be unaffected.
+ # Remote media accessed <30 days ago should still exist.
+ self._assert_if_mxc_uris_purged(
+ purged=[
+ (self.remote_server_name, self.remote_not_recently_accessed_media),
+ ],
+ not_purged=[
+ (self.remote_server_name, self.remote_recently_accessed_media),
+ (self.hs.config.server.server_name, self.local_recently_accessed_media),
+ (
+ self.hs.config.server.server_name,
+ self.local_not_recently_accessed_media,
+ ),
+ (
+ self.hs.config.server.server_name,
+ self.local_not_recently_accessed_quarantined_media,
+ ),
+ (
+ self.hs.config.server.server_name,
+ self.local_not_recently_accessed_protected_media,
+ ),
+ (
+ self.remote_server_name,
+ self.remote_not_recently_accessed_quarantined_media,
+ ),
+ (self.hs.config.server.server_name, self.local_never_accessed_media),
+ ],
+ )
+
+ def _assert_if_mxc_uris_purged(
+ self, purged: Iterable[Tuple[str, str]], not_purged: Iterable[Tuple[str, str]]
+ ) -> None:
+ def _assert_mxc_uri_purge_state(
+ server_name: str, media_id: str, expect_purged: bool
+ ) -> None:
+ """Given an MXC URI, assert whether it has been purged or not."""
+ if server_name == self.hs.config.server.server_name:
+ found_media_dict = self.get_success(
+ self.store.get_local_media(media_id)
+ )
+ else:
+ found_media_dict = self.get_success(
+ self.store.get_cached_remote_media(server_name, media_id)
+ )
+
+ mxc_uri = f"mxc://{server_name}/{media_id}"
+
+ if expect_purged:
+ self.assertIsNone(
+ found_media_dict, msg=f"{mxc_uri} unexpectedly not purged"
+ )
+ else:
+ self.assertIsNotNone(
+ found_media_dict,
+ msg=f"{mxc_uri} unexpectedly purged",
+ )
+
+ # Assert that the given MXC URIs have either been correctly purged or not.
+ for server_name, media_id in purged:
+ _assert_mxc_uri_purge_state(server_name, media_id, expect_purged=True)
+ for server_name, media_id in not_purged:
+ _assert_mxc_uri_purge_state(server_name, media_id, expect_purged=False)
diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py
index 62e30881..ea9e5889 100644
--- a/tests/rest/media/v1/test_html_preview.py
+++ b/tests/rest/media/v1/test_html_preview.py
@@ -145,7 +145,7 @@ class SummarizeTestCase(unittest.TestCase):
)
-class CalcOgTestCase(unittest.TestCase):
+class OpenGraphFromHtmlTestCase(unittest.TestCase):
if not lxml:
skip = "url preview feature requires lxml"
@@ -235,6 +235,21 @@ class CalcOgTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
+ # Another variant is a title with no content.
+ html = b"""
+ <html>
+ <head><title></title></head>
+ <body>
+ <h1>Title</h1>
+ </body>
+ </html>
+ """
+
+ tree = decode_body(html, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
+
+ self.assertEqual(og, {"og:title": "Title", "og:description": "Title"})
+
def test_h1_as_title(self) -> None:
html = b"""
<html>
@@ -250,6 +265,26 @@ class CalcOgTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
+ def test_empty_description(self) -> None:
+ """Description tags with empty content should be ignored."""
+ html = b"""
+ <html>
+ <meta property="og:description" content=""/>
+ <meta property="og:description"/>
+ <meta name="description" content=""/>
+ <meta name="description"/>
+ <meta name="description" content="Finally!"/>
+ <body>
+ <h1>Title</h1>
+ </body>
+ </html>
+ """
+
+ tree = decode_body(html, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
+
+ self.assertEqual(og, {"og:title": "Title", "og:description": "Finally!"})
+
def test_missing_title_and_broken_h1(self) -> None:
html = b"""
<html>
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 3b24d0ac..2c321f8d 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -656,6 +656,41 @@ class URLPreviewTests(unittest.HomeserverTestCase):
server.data,
)
+ def test_nonexistent_image(self) -> None:
+ """If the preview image doesn't exist, ensure some data is returned."""
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+ end_content = (
+ b"""<html><body><img src="http://cdn.matrix.org/foo.jpg"></body></html>"""
+ )
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ )
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+
+ # The image should not be in the result.
+ self.assertNotIn("og:image", channel.json_body)
+
def test_data_url(self) -> None:
"""
Requesting to preview a data URL is not supported.
diff --git a/tests/scripts/test_new_matrix_user.py b/tests/scripts/test_new_matrix_user.py
index 19a145ee..22f99c6a 100644
--- a/tests/scripts/test_new_matrix_user.py
+++ b/tests/scripts/test_new_matrix_user.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List
from unittest.mock import Mock, patch
from synapse._scripts.register_new_matrix_user import request_registration
@@ -49,8 +50,8 @@ class RegisterTestCase(TestCase):
requests.post = post
# The fake stdout will be written here
- out = []
- err_code = []
+ out: List[str] = []
+ err_code: List[int] = []
with patch("synapse._scripts.register_new_matrix_user.requests", requests):
request_registration(
@@ -85,8 +86,8 @@ class RegisterTestCase(TestCase):
requests.get = get
# The fake stdout will be written here
- out = []
- err_code = []
+ out: List[str] = []
+ err_code: List[int] = []
with patch("synapse._scripts.register_new_matrix_user.requests", requests):
request_registration(
@@ -137,8 +138,8 @@ class RegisterTestCase(TestCase):
requests.post = post
# The fake stdout will be written here
- out = []
- err_code = []
+ out: List[str] = []
+ err_code: List[int] = []
with patch("synapse._scripts.register_new_matrix_user.requests", requests):
request_registration(
diff --git a/tests/server.py b/tests/server.py
index 8f30e250..b9f46597 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -109,6 +109,17 @@ class FakeChannel:
_ip: str = "127.0.0.1"
_producer: Optional[Union[IPullProducer, IPushProducer]] = None
resource_usage: Optional[ContextResourceUsage] = None
+ _request: Optional[Request] = None
+
+ @property
+ def request(self) -> Request:
+ assert self._request is not None
+ return self._request
+
+ @request.setter
+ def request(self, request: Request) -> None:
+ assert self._request is None
+ self._request = request
@property
def json_body(self):
@@ -322,6 +333,8 @@ def make_request(
channel = FakeChannel(site, reactor, ip=client_ip)
req = request(channel, site)
+ channel.request = req
+
req.content = BytesIO(content)
# Twisted expects to be at the end of the content when parsing the request.
req.content.seek(0, SEEK_END)
@@ -736,6 +749,7 @@ def setup_test_homeserver(
if config is None:
config = default_config(name, parse=True)
+ config.caches.resize_all_caches()
config.ldap_enabled = False
if "clock" not in kwargs:
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 9ee9509d..07e29788 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -75,6 +75,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
return_value=make_awaitable("!something:localhost")
)
+ self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = Mock(
+ return_value=make_awaitable("!something:localhost")
+ )
self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None))
self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({}))
@@ -102,6 +105,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event
+ self._rlsn._server_notices_manager.maybe_get_notice_room_for_user.assert_called_once()
self._send_notice.assert_called_once()
def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self):
@@ -300,7 +304,10 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
hasn't been reached (since it's the only user and the limit is 5), so users
shouldn't receive a server notice.
"""
- self.register_user("user", "password")
+ m = Mock(return_value=make_awaitable(None))
+ self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = m
+
+ user_id = self.register_user("user", "password")
tok = self.login("user", "password")
channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
@@ -309,6 +316,8 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
"rooms", channel.json_body, "Got invites without server notice"
)
+ m.assert_called_once_with(user_id)
+
def test_invite_with_notice(self):
"""Tests that, if the MAU limit is hit, the server notices user invites each user
to a room in which it has sent a notice.
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index c237a8c7..38963ce4 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -154,6 +154,31 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# We should have fetched the event from the DB
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
+ def test_event_ref(self):
+ """Test that we reuse events that are still in memory but have fallen
+ out of the cache, rather than requesting them from the DB.
+ """
+
+ # Reset the event cache
+ self.store._get_event_cache.clear()
+
+ with LoggingContext("test") as ctx:
+ # We keep hold of the event even though we never use it.
+ event = self.get_success(self.store.get_event(self.event_id)) # noqa: F841
+
+ # We should have fetched the event from the DB
+ self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
+
+ # Reset the event cache
+ self.store._get_event_cache.clear()
+
+ with LoggingContext("test") as ctx:
+ self.get_success(self.store.get_event(self.event_id))
+
+ # Since the event is still in memory we shouldn't have fetched it
+ # from the DB
+ self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 0)
+
def test_dedupe(self):
"""Test that if we request the same event multiple times we only pull it
out once.
diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
index 74c6224e..3cc2a58d 100644
--- a/tests/storage/databases/main/test_lock.py
+++ b/tests/storage/databases/main/test_lock.py
@@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer, reactor
+from twisted.internet.base import ReactorBase
+from twisted.internet.defer import Deferred
+
from synapse.server import HomeServer
from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS
@@ -22,6 +26,56 @@ class LockTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs: HomeServer):
self.store = hs.get_datastores().main
+ def test_acquire_contention(self):
+ # Track the number of tasks holding the lock.
+ # Should be at most 1.
+ in_lock = 0
+ max_in_lock = 0
+
+ release_lock: "Deferred[None]" = Deferred()
+
+ async def task():
+ nonlocal in_lock
+ nonlocal max_in_lock
+
+ lock = await self.store.try_acquire_lock("name", "key")
+ if not lock:
+ return
+
+ async with lock:
+ in_lock += 1
+ max_in_lock = max(max_in_lock, in_lock)
+
+ # Block to allow other tasks to attempt to take the lock.
+ await release_lock
+
+ in_lock -= 1
+
+ # Start 3 tasks.
+ task1 = defer.ensureDeferred(task())
+ task2 = defer.ensureDeferred(task())
+ task3 = defer.ensureDeferred(task())
+
+ # Give the reactor a kick so that the database transaction returns.
+ self.pump()
+
+ release_lock.callback(None)
+
+ # Run the tasks to completion.
+ # To work around `Linearizer`s using a different reactor to sleep when
+ # contended (#12841), we call `runUntilCurrent` on
+ # `twisted.internet.reactor`, which is a different reactor to that used
+ # by the homeserver.
+ assert isinstance(reactor, ReactorBase)
+ self.get_success(task1)
+ reactor.runUntilCurrent()
+ self.get_success(task2)
+ reactor.runUntilCurrent()
+ self.get_success(task3)
+
+ # At most one task should have held the lock at a time.
+ self.assertEqual(max_in_lock, 1)
+
def test_simple_lock(self):
"""Test that we can take out a lock and that while we hold it nobody
else can take it out.
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 1bf93e79..1047ed09 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -14,7 +14,7 @@
import json
import os
import tempfile
-from typing import List, Optional, cast
+from typing import List, cast
from unittest.mock import Mock
import yaml
@@ -149,15 +149,12 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
outfile.write(yaml.dump(as_yaml))
self.as_yaml_files.append(as_token)
- def _set_state(
- self, id: str, state: ApplicationServiceState, txn: Optional[int] = None
- ):
+ def _set_state(self, id: str, state: ApplicationServiceState):
return self.db_pool.runOperation(
self.engine.convert_param_style(
- "INSERT INTO application_services_state(as_id, state, last_txn) "
- "VALUES(?,?,?)"
+ "INSERT INTO application_services_state(as_id, state) VALUES(?,?)"
),
- (id, state.value, txn),
+ (id, state.value),
)
def _insert_txn(self, as_id, txn_id, events):
@@ -283,17 +280,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
res = self.get_success(
self.db_pool.runQuery(
self.engine.convert_param_style(
- "SELECT last_txn FROM application_services_state WHERE as_id=?"
- ),
- (service.id,),
- )
- )
- self.assertEqual(1, len(res))
- self.assertEqual(txn_id, res[0][0])
-
- res = self.get_success(
- self.db_pool.runQuery(
- self.engine.convert_param_style(
"SELECT * FROM application_services_txns WHERE txn_id=?"
),
(txn_id,),
@@ -316,14 +302,13 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
res = self.get_success(
self.db_pool.runQuery(
self.engine.convert_param_style(
- "SELECT last_txn, state FROM application_services_state WHERE as_id=?"
+ "SELECT state FROM application_services_state WHERE as_id=?"
),
(service.id,),
)
)
self.assertEqual(1, len(res))
- self.assertEqual(txn_id, res[0][0])
- self.assertEqual(ApplicationServiceState.UP.value, res[0][1])
+ self.assertEqual(ApplicationServiceState.UP.value, res[0][0])
res = self.get_success(
self.db_pool.runQuery(
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index a8ffb52c..cce8e75c 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -60,7 +60,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine)
db._db_pool = self.db_pool
- self.datastore = SQLBaseStore(db, None, hs)
+ self.datastore = SQLBaseStore(db, None, hs) # type: ignore[arg-type]
@defer.inlineCallbacks
def test_insert_1col(self):
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index bbf079b2..f37505b6 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -13,6 +13,7 @@
# limitations under the License.
import synapse.api.errors
+from synapse.api.constants import EduTypes
from tests.unittest import HomeserverTestCase
@@ -266,10 +267,12 @@ class DeviceStoreTestCase(HomeserverTestCase):
# (This is a temporary arrangement for backwards compatibility!)
self.assertEqual(len(device_updates), 2, device_updates)
self.assertEqual(
- device_updates[0][0], "m.signing_key_update", device_updates[0]
+ device_updates[0][0], EduTypes.SIGNING_KEY_UPDATE, device_updates[0]
)
self.assertEqual(
- device_updates[1][0], "org.matrix.signing_key_update", device_updates[1]
+ device_updates[1][0],
+ EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
+ device_updates[1],
)
# Check there are no more device updates left.
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 401020fd..a0ce077a 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -393,7 +393,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
# We need to persist the events to the events and state_events
# tables.
persist_events_store._store_event_txn(
- txn, [(e, EventContext()) for e in events]
+ txn,
+ [(e, EventContext(self.hs.get_storage_controllers())) for e in events],
)
# Actually call the function that calculates the auth chain stuff.
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 645d564d..d92a9ac5 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -58,15 +58,6 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
(room_id, event_id),
)
- txn.execute(
- (
- "INSERT INTO event_reference_hashes "
- "(event_id, algorithm, hash) "
- "VALUES (?, 'sha256', ?)"
- ),
- (event_id, bytearray(b"ffff")),
- )
-
for i in range(0, 20):
self.get_success(
self.store.db_pool.runInteraction("insert", insert_event, i)
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index ef5e2587..2ff88e64 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -31,7 +31,8 @@ class ExtremPruneTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.state = self.hs.get_state_handler()
- self.persistence = self.hs.get_storage().persistence
+ self._persistence = self.hs.get_storage_controllers().persistence
+ self._state_storage_controller = self.hs.get_storage_controllers().state
self.store = self.hs.get_datastores().main
self.register_user("user", "pass")
@@ -69,9 +70,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
def persist_event(self, event, state=None):
"""Persist the event, with optional state"""
context = self.get_success(
- self.state.compute_event_context(event, old_state=state)
+ self.state.compute_event_context(event, state_ids_before_event=state)
)
- self.get_success(self.persistence.persist_event(event, context))
+ self.get_success(self._persistence.persist_event(event, context))
def assert_extremities(self, expected_extremities):
"""Assert the current extremities for the room"""
@@ -103,9 +104,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
- state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+ state_before_gap = self.get_success(
+ self._state_storage_controller.get_current_state_ids(self.room_id)
+ )
- self.persist_event(remote_event_2, state=state_before_gap.values())
+ self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
@@ -135,17 +138,20 @@ class ExtremPruneTestCase(HomeserverTestCase):
# setting. The state resolution across the old and new event will then
# include it, and so the resolved state won't match the new state.
state_before_gap = dict(
- self.get_success(self.state.get_current_state(self.room_id))
+ self.get_success(
+ self._state_storage_controller.get_current_state_ids(self.room_id)
+ )
)
state_before_gap.pop(("m.room.history_visibility", ""))
context = self.get_success(
self.state.compute_event_context(
- remote_event_2, old_state=state_before_gap.values()
+ remote_event_2,
+ state_ids_before_event=state_before_gap,
)
)
- self.get_success(self.persistence.persist_event(remote_event_2, context))
+ self.get_success(self._persistence.persist_event(remote_event_2, context))
# Check that we haven't dropped the old extremity.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
@@ -177,9 +183,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
- state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+ state_before_gap = self.get_success(
+ self._state_storage_controller.get_current_state_ids(self.room_id)
+ )
- self.persist_event(remote_event_2, state=state_before_gap.values())
+ self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
@@ -207,9 +215,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
- state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+ state_before_gap = self.get_success(
+ self._state_storage_controller.get_current_state_ids(self.room_id)
+ )
- self.persist_event(remote_event_2, state=state_before_gap.values())
+ self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
@@ -247,9 +257,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
- state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+ state_before_gap = self.get_success(
+ self._state_storage_controller.get_current_state_ids(self.room_id)
+ )
- self.persist_event(remote_event_2, state=state_before_gap.values())
+ self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
@@ -289,9 +301,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
- state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+ state_before_gap = self.get_success(
+ self._state_storage_controller.get_current_state_ids(self.room_id)
+ )
- self.persist_event(remote_event_2, state=state_before_gap.values())
+ self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id, local_message_event_id])
@@ -323,9 +337,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
- state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+ state_before_gap = self.get_success(
+ self._state_storage_controller.get_current_state_ids(self.room_id)
+ )
- self.persist_event(remote_event_2, state=state_before_gap.values())
+ self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([local_message_event_id, remote_event_2.event_id])
@@ -340,7 +356,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.state = self.hs.get_state_handler()
- self.persistence = self.hs.get_storage().persistence
+ self._persistence = self.hs.get_storage_controllers().persistence
self.store = self.hs.get_datastores().main
def test_remote_user_rooms_cache_invalidated(self):
@@ -377,7 +393,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
)
context = self.get_success(self.state.compute_event_context(remote_event_1))
- self.get_success(self.persistence.persist_event(remote_event_1, context))
+ self.get_success(self._persistence.persist_event(remote_event_1, context))
# Call `get_rooms_for_user` to add the remote user to the cache
rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
@@ -424,7 +440,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
)
context = self.get_success(self.state.compute_event_context(remote_event_1))
- self.get_success(self.persistence.persist_event(remote_event_1, context))
+ self.get_success(self._persistence.persist_event(remote_event_1, context))
# Call `get_users_in_room` to add the remote user to the cache
users = self.get_success(self.store.get_users_in_room(room_id))
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 4c29ad79..e8b4a564 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -407,3 +407,86 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.assertEqual(result[service1], 2)
self.assertEqual(result[service2], 1)
self.assertEqual(result[native], 1)
+
+ def test_get_monthly_active_users_by_service(self):
+ # (No users, no filtering) -> empty result
+ result = self.get_success(self.store.get_monthly_active_users_by_service())
+
+ self.assertEqual(len(result), 0)
+
+ # (Some users, no filtering) -> non-empty result
+ appservice1_user1 = "@appservice1_user1:example.com"
+ appservice2_user1 = "@appservice2_user1:example.com"
+ service1 = "service1"
+ service2 = "service2"
+ self.get_success(
+ self.store.register_user(
+ user_id=appservice1_user1, password_hash=None, appservice_id=service1
+ )
+ )
+ self.get_success(self.store.upsert_monthly_active_user(appservice1_user1))
+ self.get_success(
+ self.store.register_user(
+ user_id=appservice2_user1, password_hash=None, appservice_id=service2
+ )
+ )
+ self.get_success(self.store.upsert_monthly_active_user(appservice2_user1))
+
+ result = self.get_success(self.store.get_monthly_active_users_by_service())
+
+ self.assertEqual(len(result), 2)
+ self.assertIn((service1, appservice1_user1), result)
+ self.assertIn((service2, appservice2_user1), result)
+
+ # (Some users, end-timestamp filtering) -> non-empty result
+ appservice1_user2 = "@appservice1_user2:example.com"
+ timestamp1 = self.reactor.seconds()
+ self.reactor.advance(5)
+ timestamp2 = self.reactor.seconds()
+ self.get_success(
+ self.store.register_user(
+ user_id=appservice1_user2, password_hash=None, appservice_id=service1
+ )
+ )
+ self.get_success(self.store.upsert_monthly_active_user(appservice1_user2))
+
+ result = self.get_success(
+ self.store.get_monthly_active_users_by_service(
+ end_timestamp=round(timestamp1 * 1000)
+ )
+ )
+
+ self.assertEqual(len(result), 2)
+ self.assertNotIn((service1, appservice1_user2), result)
+
+ # (Some users, start-timestamp filtering) -> non-empty result
+ result = self.get_success(
+ self.store.get_monthly_active_users_by_service(
+ start_timestamp=round(timestamp2 * 1000)
+ )
+ )
+
+ self.assertEqual(len(result), 1)
+ self.assertIn((service1, appservice1_user2), result)
+
+ # (Some users, full-timestamp filtering) -> non-empty result
+ native_user1 = "@native_user1:example.com"
+ native = "native"
+ timestamp3 = self.reactor.seconds()
+ self.reactor.advance(100)
+ self.get_success(
+ self.store.register_user(
+ user_id=native_user1, password_hash=None, appservice_id=native
+ )
+ )
+ self.get_success(self.store.upsert_monthly_active_user(native_user1))
+
+ result = self.get_success(
+ self.store.get_monthly_active_users_by_service(
+ start_timestamp=round(timestamp2 * 1000),
+ end_timestamp=round(timestamp3 * 1000),
+ )
+ )
+
+ self.assertEqual(len(result), 1)
+ self.assertIn((service1, appservice1_user2), result)
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 08cc6023..8dfaa055 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -31,7 +31,7 @@ class PurgeTests(HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.user_id)
self.store = hs.get_datastores().main
- self.storage = self.hs.get_storage()
+ self._storage_controllers = self.hs.get_storage_controllers()
def test_purge_history(self):
"""
@@ -51,7 +51,9 @@ class PurgeTests(HomeserverTestCase):
# Purge everything before this topological token
self.get_success(
- self.storage.purge_events.purge_history(self.room_id, token_str, True)
+ self._storage_controllers.purge_events.purge_history(
+ self.room_id, token_str, True
+ )
)
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
@@ -79,7 +81,9 @@ class PurgeTests(HomeserverTestCase):
# Purge everything before this topological token
f = self.get_failure(
- self.storage.purge_events.purge_history(self.room_id, event, True),
+ self._storage_controllers.purge_events.purge_history(
+ self.room_id, event, True
+ ),
SynapseError,
)
self.assertIn("greater than forward", f.value.args[0])
@@ -98,14 +102,17 @@ class PurgeTests(HomeserverTestCase):
first = self.helper.send(self.room_id, body="test1")
# Get the current room state.
- state_handler = self.hs.get_state_handler()
create_event = self.get_success(
- state_handler.get_current_state(self.room_id, "m.room.create", "")
+ self._storage_controllers.state.get_current_state_event(
+ self.room_id, "m.room.create", ""
+ )
)
self.assertIsNotNone(create_event)
# Purge everything before this topological token
- self.get_success(self.storage.purge_events.purge_room(self.room_id))
+ self.get_success(
+ self._storage_controllers.purge_events.purge_room(self.room_id)
+ )
# The events aren't found.
self.store._invalidate_get_event_cache(create_event.event_id)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index d8d17ef3..6c4e63b7 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -31,7 +31,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage = hs.get_storage_controllers()
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -71,7 +71,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(self._storage.persistence.persist_event(event, context))
return event
@@ -93,7 +93,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(self._storage.persistence.persist_event(event, context))
return event
@@ -114,7 +114,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(self._storage.persistence.persist_event(event, context))
return event
@@ -268,7 +268,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
- self.get_success(self.storage.persistence.persist_event(event_1, context_1))
+ self.get_success(self._storage.persistence.persist_event(event_1, context_1))
event_2, context_2 = self.get_success(
self.event_creation_handler.create_new_client_event(
@@ -287,7 +287,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
)
- self.get_success(self.storage.persistence.persist_event(event_2, context_2))
+ self.get_success(self._storage.persistence.persist_event(event_2, context_2))
# fetch one of the redactions
fetched = self.get_success(self.store.get_event(redaction_event_id1))
@@ -411,7 +411,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
self.get_success(
- self.storage.persistence.persist_event(redaction_event, context)
+ self._storage.persistence.persist_event(redaction_event, context)
)
# Now lets jump to the future where we have censored the redaction event
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 5b011e18..3c79dabc 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
# Room events need the full datastore, for persist_event() and
# get_room_state()
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self.event_factory = hs.get_event_factory()
self.room = RoomID.from_string("!abcde:test")
@@ -88,7 +88,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
def inject_room_event(self, **kwargs):
self.get_success(
- self.storage.persistence.persist_event(
+ self._storage_controllers.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
)
)
@@ -101,7 +101,9 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
)
state = self.get_success(
- self.store.get_current_state(room_id=self.room.to_string())
+ self._storage_controllers.state.get_current_state(
+ room_id=self.room.to_string()
+ )
)
self.assertEqual(1, len(state))
@@ -118,7 +120,9 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
)
state = self.get_success(
- self.store.get_current_state(room_id=self.room.to_string())
+ self._storage_controllers.state.get_current_state(
+ room_id=self.room.to_string()
+ )
)
self.assertEqual(1, len(state))
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index 8dfc1e1d..e747c6b5 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -99,7 +99,9 @@ class EventSearchInsertionTest(HomeserverTestCase):
prev_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
prev_event = self.get_success(store.get_event(prev_event_ids[0]))
prev_state_map = self.get_success(
- self.hs.get_storage().state.get_state_ids_for_event(prev_event_ids[0])
+ self.hs.get_storage_controllers().state.get_state_ids_for_event(
+ prev_event_ids[0]
+ )
)
event_dict = {
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index a2a9c05f..1218786d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -34,7 +34,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: TestHomeServer) -> None:
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: TestHomeServer) -> None: # type: ignore[override]
# We can't test the RoomMemberStore on its own without the other event
# storage logic
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index f88f1c55..8043bdbd 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class StateStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self.storage = hs.get_storage_controllers()
self.state_datastore = self.storage.state.stores.state
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 7f1964eb..5b60cf52 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -134,7 +134,6 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.appservice = ApplicationService(
token="i_am_an_app_service",
- hostname="test",
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
sender="@as:test",
diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py
index 303e190b..cae14151 100644
--- a/tests/storage/util/test_partial_state_events_tracker.py
+++ b/tests/storage/util/test_partial_state_events_tracker.py
@@ -17,8 +17,12 @@ from unittest import mock
from twisted.internet.defer import CancelledError, ensureDeferred
-from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
+from synapse.storage.util.partial_state_events_tracker import (
+ PartialCurrentStateTracker,
+ PartialStateEventsTracker,
+)
+from tests.test_utils import make_awaitable
from tests.unittest import TestCase
@@ -115,3 +119,56 @@ class PartialStateEventsTrackerTestCase(TestCase):
self.tracker.notify_un_partial_stated("event1")
self.successResultOf(d2)
+
+
+class PartialCurrentStateTrackerTestCase(TestCase):
+ def setUp(self) -> None:
+ self.mock_store = mock.Mock(spec_set=["is_partial_state_room"])
+
+ self.tracker = PartialCurrentStateTracker(self.mock_store)
+
+ def test_does_not_block_for_full_state_rooms(self):
+ self.mock_store.is_partial_state_room.return_value = make_awaitable(False)
+
+ self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
+
+ def test_blocks_for_partial_room_state(self):
+ self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
+
+ d = ensureDeferred(self.tracker.await_full_state("room_id"))
+
+ # there should be no result yet
+ self.assertNoResult(d)
+
+ # notifying that the room has been de-partial-stated should unblock
+ self.tracker.notify_un_partial_stated("room_id")
+ self.successResultOf(d)
+
+ def test_un_partial_state_race(self):
+ # We should correctly handle a race between awaiting the state and us
+ # un-partialling the state
+ async def is_partial_state_room(events):
+ self.tracker.notify_un_partial_stated("room_id")
+ return True
+
+ self.mock_store.is_partial_state_room.side_effect = is_partial_state_room
+
+ self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
+
+ def test_cancellation(self):
+ self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
+
+ d1 = ensureDeferred(self.tracker.await_full_state("room_id"))
+ self.assertNoResult(d1)
+
+ d2 = ensureDeferred(self.tracker.await_full_state("room_id"))
+ self.assertNoResult(d2)
+
+ d1.cancel()
+ self.assertFailure(d1, CancelledError)
+
+ # d2 should still be waiting!
+ self.assertNoResult(d2)
+
+ self.tracker.notify_un_partial_stated("room_id")
+ self.successResultOf(d2)
diff --git a/tests/test_mau.py b/tests/test_mau.py
index 5bbc361a..f14fcb7d 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -105,7 +105,6 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.store.services_cache.append(
ApplicationService(
token=as_token,
- hostname=self.hs.hostname,
id="SomeASID",
sender="@as_sender:test",
namespaces={"users": [{"regex": "@as_*", "exclusive": True}]},
@@ -251,7 +250,6 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.store.services_cache.append(
ApplicationService(
token=as_token_1,
- hostname=self.hs.hostname,
id="SomeASID",
sender="@as_sender_1:test",
namespaces={"users": [{"regex": "@as_1.*", "exclusive": True}]},
@@ -262,7 +260,6 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.store.services_cache.append(
ApplicationService(
token=as_token_2,
- hostname=self.hs.hostname,
id="AnotherASID",
sender="@as_sender_2:test",
namespaces={"users": [{"regex": "@as_2.*", "exclusive": True}]},
diff --git a/tests/test_server.py b/tests/test_server.py
index f2ffbc89..0f1eb43c 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -13,18 +13,28 @@
# limitations under the License.
import re
+from http import HTTPStatus
+from typing import Tuple
from twisted.internet.defer import Deferred
from twisted.web.resource import Resource
from synapse.api.errors import Codes, RedirectException, SynapseError
from synapse.config.server import parse_listener_def
-from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource
-from synapse.http.site import SynapseSite
+from synapse.http.server import (
+ DirectServeHtmlResource,
+ DirectServeJsonResource,
+ JsonResource,
+ OptionsResource,
+ cancellable,
+)
+from synapse.http.site import SynapseRequest, SynapseSite
from synapse.logging.context import make_deferred_yieldable
+from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
+from tests.http.server._base import EndpointCancellationTestHelperMixin
from tests.server import (
FakeSite,
ThreadedMemoryReactorClock,
@@ -363,3 +373,100 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result)
+
+
+class CancellableDirectServeJsonResource(DirectServeJsonResource):
+ def __init__(self, clock: Clock):
+ super().__init__()
+ self.clock = clock
+
+ @cancellable
+ async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+ async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+
+class CancellableDirectServeHtmlResource(DirectServeHtmlResource):
+ ERROR_TEMPLATE = "{code} {msg}"
+
+ def __init__(self, clock: Clock):
+ super().__init__()
+ self.clock = clock
+
+ @cancellable
+ async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, bytes]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, b"ok"
+
+ async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, bytes]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, b"ok"
+
+
+class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMixin):
+ """Tests for `DirectServeJsonResource` cancellation."""
+
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+ self.clock = Clock(self.reactor)
+ self.resource = CancellableDirectServeJsonResource(self.clock)
+ self.site = FakeSite(self.resource, self.reactor)
+
+ def test_cancellable_disconnect(self) -> None:
+ """Test that handlers with the `@cancellable` flag can be cancelled."""
+ channel = make_request(
+ self.reactor, self.site, "GET", "/sleep", await_result=False
+ )
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=True,
+ expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
+ )
+
+ def test_uncancellable_disconnect(self) -> None:
+ """Test that handlers without the `@cancellable` flag cannot be cancelled."""
+ channel = make_request(
+ self.reactor, self.site, "POST", "/sleep", await_result=False
+ )
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=False,
+ expected_body={"result": True},
+ )
+
+
+class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMixin):
+ """Tests for `DirectServeHtmlResource` cancellation."""
+
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+ self.clock = Clock(self.reactor)
+ self.resource = CancellableDirectServeHtmlResource(self.clock)
+ self.site = FakeSite(self.resource, self.reactor)
+
+ def test_cancellable_disconnect(self) -> None:
+ """Test that handlers with the `@cancellable` flag can be cancelled."""
+ channel = make_request(
+ self.reactor, self.site, "GET", "/sleep", await_result=False
+ )
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=True,
+ expected_body=b"499 Request cancelled",
+ )
+
+ def test_uncancellable_disconnect(self) -> None:
+ """Test that handlers without the `@cancellable` flag cannot be cancelled."""
+ channel = make_request(
+ self.reactor, self.site, "POST", "/sleep", await_result=False
+ )
+ self._test_disconnect(
+ self.reactor, channel, expect_cancellation=False, expected_body=b"ok"
+ )
diff --git a/tests/test_state.py b/tests/test_state.py
index e4baa691..95f81beb 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -88,6 +88,9 @@ class _DummyStore:
return groups
+ async def get_state_ids_for_group(self, state_group, state_filter=None):
+ return self._group_to_state[state_group]
+
async def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids
):
@@ -126,6 +129,19 @@ class _DummyStore:
async def get_room_version_id(self, room_id):
return RoomVersions.V1.identifier
+ async def get_state_group_for_events(self, event_ids):
+ res = {}
+ for event in event_ids:
+ res[event] = self._event_to_state_group[event]
+ return res
+
+ async def get_state_for_groups(self, groups):
+ res = {}
+ for group in groups:
+ state = self._group_to_state[group]
+ res[group] = state
+ return res
+
class DictObj(dict):
def __init__(self, **kwargs):
@@ -163,12 +179,12 @@ class Graph:
class StateTestCase(unittest.TestCase):
def setUp(self):
self.dummy_store = _DummyStore()
- storage = Mock(main=self.dummy_store, state=self.dummy_store)
+ storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store)
hs = Mock(
spec_set=[
"config",
"get_datastores",
- "get_storage",
+ "get_storage_controllers",
"get_auth",
"get_state_handler",
"get_clock",
@@ -183,7 +199,7 @@ class StateTestCase(unittest.TestCase):
hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs)
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
- hs.get_storage.return_value = storage
+ hs.get_storage_controllers.return_value = storage_controllers
self.state = StateHandler(hs)
self.event_id = 0
@@ -426,7 +442,12 @@ class StateTestCase(unittest.TestCase):
]
context = yield defer.ensureDeferred(
- self.state.compute_event_context(event, old_state=old_state)
+ self.state.compute_event_context(
+ event,
+ state_ids_before_event={
+ (e.type, e.state_key): e.event_id for e in old_state
+ },
+ )
)
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
@@ -451,7 +472,12 @@ class StateTestCase(unittest.TestCase):
]
context = yield defer.ensureDeferred(
- self.state.compute_event_context(event, old_state=old_state)
+ self.state.compute_event_context(
+ event,
+ state_ids_before_event={
+ (e.type, e.state_key): e.event_id for e in old_state
+ },
+ )
)
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
diff --git a/tests/test_types.py b/tests/test_types.py
index 80888a74..0b10dae8 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -13,7 +13,7 @@
# limitations under the License.
from synapse.api.errors import SynapseError
-from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart
+from synapse.types import RoomAlias, UserID, map_username_to_mxid_localpart
from tests import unittest
@@ -62,25 +62,6 @@ class RoomAliasTestCase(unittest.HomeserverTestCase):
self.assertFalse(RoomAlias.is_valid(id_string))
-class GroupIDTestCase(unittest.TestCase):
- def test_parse(self):
- group_id = GroupID.from_string("+group/=_-.123:my.domain")
- self.assertEqual("group/=_-.123", group_id.localpart)
- self.assertEqual("my.domain", group_id.domain)
-
- def test_validate(self):
- bad_ids = ["$badsigil:domain", "+:empty"] + [
- "+group" + c + ":domain" for c in "A%?æ£"
- ]
- for id_string in bad_ids:
- try:
- GroupID.from_string(id_string)
- self.fail("Parsing '%s' should raise exception" % id_string)
- except SynapseError as exc:
- self.assertEqual(400, exc.code)
- self.assertEqual("M_INVALID_PARAM", exc.errcode)
-
-
class MapUsernameTestCase(unittest.TestCase):
def testPassThrough(self):
self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234")
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index c654e36e..8027c7a8 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -70,7 +70,7 @@ async def inject_event(
"""
event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
- persistence = hs.get_storage().persistence
+ persistence = hs.get_storage_controllers().persistence
assert persistence is not None
await persistence.persist_event(event, context)
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index d0230f9e..f338af6c 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -34,7 +34,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
super(FilterEventsForServerTestCase, self).setUp()
self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory()
- self.storage = self.hs.get_storage()
+ self._storage_controllers = self.hs.get_storage_controllers()
self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
@@ -60,7 +60,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
events_to_filter.append(evt)
filtered = self.get_success(
- filter_events_for_server(self.storage, "test_server", events_to_filter)
+ filter_events_for_server(
+ self._storage_controllers, "test_server", events_to_filter
+ )
)
# the result should be 5 redacted events, and 5 unredacted events.
@@ -80,7 +82,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
outlier = self._inject_outlier()
self.assertEqual(
self.get_success(
- filter_events_for_server(self.storage, "remote_hs", [outlier])
+ filter_events_for_server(
+ self._storage_controllers, "remote_hs", [outlier]
+ )
),
[outlier],
)
@@ -89,7 +93,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
evt = self._inject_message("@unerased:local_hs")
filtered = self.get_success(
- filter_events_for_server(self.storage, "remote_hs", [outlier, evt])
+ filter_events_for_server(
+ self._storage_controllers, "remote_hs", [outlier, evt]
+ )
)
self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
self.assertEqual(filtered[0], outlier)
@@ -99,7 +105,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# ... but other servers should only be able to see the outlier (the other should
# be redacted)
filtered = self.get_success(
- filter_events_for_server(self.storage, "other_server", [outlier, evt])
+ filter_events_for_server(
+ self._storage_controllers, "other_server", [outlier, evt]
+ )
)
self.assertEqual(filtered[0], outlier)
self.assertEqual(filtered[1].event_id, evt.event_id)
@@ -132,7 +140,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# ... and the filtering happens.
filtered = self.get_success(
- filter_events_for_server(self.storage, "test_server", events_to_filter)
+ filter_events_for_server(
+ self._storage_controllers, "test_server", events_to_filter
+ )
)
for i in range(0, len(events_to_filter)):
@@ -168,7 +178,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
event, context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(
+ self._storage_controllers.persistence.persist_event(event, context)
+ )
return event
def _inject_room_member(
@@ -194,7 +206,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(
+ self._storage_controllers.persistence.persist_event(event, context)
+ )
return event
def _inject_message(
@@ -216,7 +230,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(
+ self._storage_controllers.persistence.persist_event(event, context)
+ )
return event
def _inject_outlier(self) -> EventBase:
@@ -234,7 +250,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
event.internal_metadata.outlier = True
self.get_success(
- self.storage.persistence.persist_event(event, EventContext.for_outlier())
+ self._storage_controllers.persistence.persist_event(
+ event, EventContext.for_outlier(self._storage_controllers)
+ )
)
return event
@@ -291,7 +309,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(
self.get_success(
filter_events_for_client(
- self.hs.get_storage(), "@user:test", [invite_event, reject_event]
+ self.hs.get_storage_controllers(),
+ "@user:test",
+ [invite_event, reject_event],
)
),
[invite_event, reject_event],
@@ -301,7 +321,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(
self.get_success(
filter_events_for_client(
- self.hs.get_storage(), "@other:test", [invite_event, reject_event]
+ self.hs.get_storage_controllers(),
+ "@other:test",
+ [invite_event, reject_event],
)
),
[],
diff --git a/tests/unittest.py b/tests/unittest.py
index 9afa68c1..e7f255b4 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -831,7 +831,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
self.site,
method=method,
path=path,
- content=content or "",
+ content=content if content is not None else "",
shorthand=False,
await_result=await_result,
custom_headers=custom_headers,
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 321fc177..67173a4f 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -14,8 +14,9 @@
from typing import List
-from unittest.mock import Mock
+from unittest.mock import Mock, patch
+from synapse.metrics.jemalloc import JemallocStats
from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries
from synapse.util.caches.treecache import TreeCache
@@ -316,3 +317,58 @@ class TimeEvictionTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get("key1"), None)
self.assertEqual(cache.get("key2"), 3)
+
+
+class MemoryEvictionTestCase(unittest.HomeserverTestCase):
+ @override_config(
+ {
+ "caches": {
+ "cache_autotuning": {
+ "max_cache_memory_usage": "700M",
+ "target_cache_memory_usage": "500M",
+ "min_cache_ttl": "5m",
+ }
+ }
+ }
+ )
+ @patch("synapse.util.caches.lrucache.get_jemalloc_stats")
+ def test_evict_memory(self, jemalloc_interface) -> None:
+ mock_jemalloc_class = Mock(spec=JemallocStats)
+ jemalloc_interface.return_value = mock_jemalloc_class
+
+ # set the return value of get_stat() to be greater than max_cache_memory_usage
+ mock_jemalloc_class.get_stat.return_value = 924288000
+
+ setup_expire_lru_cache_entries(self.hs)
+ cache = LruCache(4, clock=self.hs.get_clock())
+
+ cache["key1"] = 1
+ cache["key2"] = 2
+
+ # advance the reactor less than the min_cache_ttl
+ self.reactor.advance(60 * 2)
+
+ # our items should still be in the cache
+ self.assertEqual(cache.get("key1"), 1)
+ self.assertEqual(cache.get("key2"), 2)
+
+ # advance the reactor past the min_cache_ttl
+ self.reactor.advance(60 * 6)
+
+ # the items should be cleared from cache
+ self.assertEqual(cache.get("key1"), None)
+ self.assertEqual(cache.get("key2"), None)
+
+ # add more stuff to caches
+ cache["key1"] = 1
+ cache["key2"] = 2
+
+ # set the return value of get_stat() to be lower than target_cache_memory_usage
+ mock_jemalloc_class.get_stat.return_value = 10000
+
+ # advance the reactor past the min_cache_ttl
+ self.reactor.advance(60 * 6)
+
+ # the items should still be in the cache
+ self.assertEqual(cache.get("key1"), 1)
+ self.assertEqual(cache.get("key2"), 2)
diff --git a/tests/utils.py b/tests/utils.py
index d4ba3a9b..3059c453 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -264,7 +264,7 @@ class MockClock:
async def create_room(hs, room_id: str, creator_id: str):
"""Creates and persist a creation event for the given room"""
- persistence_store = hs.get_storage().persistence
+ persistence_store = hs.get_storage_controllers().persistence
store = hs.get_datastores().main
event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler()